{-# LANGUAGE CPP #-}
module Control.Arrow.Static where

import Control.Arrow
import Control.Applicative
import Control.Category
import Control.Comonad
import Control.Monad.Instances
import Control.Monad (ap)
import Data.Functor.Apply
import Data.Functor.Plus
import Data.Monoid
import Data.Semigroup
import Data.Semigroupoid
import Prelude hiding ((.), id)

#ifdef LANGUAGE_DeriveDataTypeable 
import Data.Typeable
#endif

-- | @'Static' f a b@ wraps a function from @a@ to @b@ that is held inside
-- the functor @f@, i.e. a value of type @f (a -> b)@.
--
-- 'runStatic' unwraps the newtype and returns the underlying @f (a -> b)@.
--
-- A 'Typeable' instance is derived only when the @LANGUAGE_DeriveDataTypeable@
-- preprocessor macro is defined at build time.
newtype Static f a b = Static { runStatic :: f (a -> b) } 
#ifdef LANGUAGE_DeriveDataTypeable
  deriving (Typeable)
#endif

-- | Maps over the result type @b@: every wrapped function @a -> b@ inside
-- @f@ is post-composed with the supplied function.
instance Functor f => Functor (Static f a) where
  fmap h (Static inner) = Static (fmap (\g -> h . g) inner)

-- | Combines two wrapped computations with the underlying 'Apply' of @f@.
-- Inside @f@ the two functions are merged with 'ap' for the function monad,
-- so both receive the same input @a@.
instance Apply f => Apply (Static f a) where
  Static funs <.> Static args = Static (fmap ap funs <.> args)

-- | Choice between two 'Static' values is delegated to the 'Alt' instance
-- of the underlying functor @f@.
instance Alt f => Alt (Static f a) where
  lhs <!> rhs = Static (runStatic lhs <!> runStatic rhs)

-- | The 'zero' of @'Static' f a@ is the wrapped 'zero' of the underlying
-- functor @f@.
instance Plus f => Plus (Static f a) where
  zero = Static { runStatic = zero }

-- | 'pure' wraps a constant function (ignoring the input @a@) in @f@;
-- '<*>' sequences the effects of @f@ and feeds the same input @a@ to both
-- wrapped functions via 'ap' for the function monad.
instance Applicative f => Applicative (Static f a) where
  pure x = Static (pure (const x))
  Static funs <*> Static args = Static (fmap ap funs <*> args)

-- | Extends over the underlying @f@. For each position the callback sees a
-- 'Static' whose wrapped functions have their argument prefixed (via '<>')
-- with the input @m@ supplied to the resulting function.
instance (Extend f, Semigroup a) => Extend (Static f a) where
  extend k (Static w) = Static (extend step w)
    where
      -- Shift every wrapped function so that it first prepends @m@ to its
      -- argument, then hand the shifted structure to the callback.
      step wf m = k (Static (fmap (\g -> g . (m <>)) wf))

-- | 'extract' pulls the wrapped function out of @f@ and applies it to
-- 'mempty'.
instance (Comonad f, Semigroup a, Monoid a) => Comonad (Static f a) where
  extract w = extract (runStatic w) mempty

-- | Composition of 'Static' values: the wrapped functions are composed
-- inside @f@, with the effects combined by the 'Apply' instance of @f@
-- (left operand's effects first).
instance Apply f => Semigroupoid (Static f) where
  o (Static outer) (Static inner) = Static (fmap (.) outer <.> inner)

-- | 'id' is the identity function lifted with 'pure'; '.' composes the
-- wrapped functions inside @f@, sequencing effects with '<*>' (left
-- operand's effects first).
instance Applicative f => Category (Static f) where
  id = Static { runStatic = pure id }
  (.) (Static outer) (Static inner) = Static (fmap (.) outer <*> inner)

-- | Arrow structure obtained by lifting the function arrow through @f@:
-- 'arr' embeds a plain function with 'pure', and every combinator applies
-- the corresponding function-arrow combinator inside @f@.
instance Applicative f => Arrow (Static f) where
  arr fn = Static (pure fn)
  first sa = Static (fmap first (runStatic sa))
  second sa = Static (fmap second (runStatic sa))
  Static l *** Static r = Static (fmap (***) l <*> r)
  Static l &&& Static r = Static (fmap (&&&) l <*> r)

-- | The zero arrow is the wrapped 'empty' of the underlying 'Alternative'.
instance Alternative f => ArrowZero (Static f) where
  zeroArrow = Static { runStatic = empty }
  
-- | '<+>' is delegated to '<|>' of the underlying 'Alternative'.
instance Alternative f => ArrowPlus (Static f) where
  lhs <+> rhs = Static (runStatic lhs <|> runStatic rhs)

-- | Choice combinators obtained by applying the function-arrow versions
-- ('left', 'right', '+++', '|||') to the wrapped functions inside @f@.
instance Applicative f => ArrowChoice (Static f) where
  left sa = Static (fmap left (runStatic sa))
  right sa = Static (fmap right (runStatic sa))
  Static l +++ Static r = Static (fmap (+++) l <*> r)
  Static l ||| Static r = Static (fmap (|||) l <*> r)