{-#LANGUAGE Arrows, FlexibleInstances, MultiParamTypeClasses, FlexibleContexts,
  UndecidableInstances, FunctionalDependencies, NoMonomorphismRestriction #-}

module Control.Arrow.Transformer.Automaton.Monad
    (monadToAuto, co, autoToMonad, readerArrow, swapsnd,
     pushError,popError,rstrength,
     ArrowAddAutomaton (..), dispatch) where

import Control.Monad
import Control.Monad.Cont
import Control.Monad.State (MonadState (..))

import Control.Arrow
import Control.Arrow.Operations
import qualified Control.Arrow.Transformer as AT
import Control.Arrow.Transformer.All

import Data.Maybe
import qualified Data.Map as M

-- | Unwrap an 'ArrowMonad' computation into the underlying arrow that
-- takes @()@ as input.
unAM :: ArrowMonad a b -> a () b
unAM (ArrowMonad f) = f

-- | Turn a coroutine-style computation in 'ContT' into an automaton.
-- The computation receives the first input and suspends (via 'co') each
-- time it emits an output. If it ever returns normally instead of
-- suspending, stepping the automaton calls @error "automaton ended"@.
monadToAuto
  :: (ArrowAddAutomaton a a', ArrowApply a') =>
     (i -> ContT (o, a i o) (ArrowMonad a') z) -> a i o
monadToAuto f = liftAutomaton step
  where
    -- the final continuation is never meant to be reached
    step = proc i -> unAM (runContT (f i) (error "automaton ended")) -<< ()


-- | Emit @o@ as the current step's output and suspend. The returned
-- automaton, when stepped with the next input, resumes the captured
-- continuation with that input.
co
  :: (ArrowApply a', ArrowAddAutomaton a a') =>
     o -> ContT (o, a i o) (ArrowMonad a') i
co o = ContT $ \resume ->
  let next = liftAutomaton (proc i -> unAM (resume i) -<< ())
  in return (o, next)

-- | Run an automaton whose outputs are @Either o z@ inside 'ContT',
-- starting from input @i@. Each @Left o@ is yielded with 'co', and the
-- next input is fed to the continuation automaton. The first @Right z@
-- ends the loop and returns @z@.
autoToMonad
  :: (ArrowApply a', ArrowAddAutomaton a a') =>
     a i (Either o z)
     -> i
     -> ContT (o, a i o) (ArrowMonad a') z
autoToMonad f i = do
  (result, f') <- lift (ArrowMonad (proc () -> elimAutomaton f -< i))
  either (\o -> co o >>= autoToMonad f') return result



-- | Arrows @a@ that are automata over an underlying arrow @a'@. One step
-- of @a@ is an @a'@ computation that yields an output together with the
-- automaton to use for the next step. The functional dependency says
-- that @a@ determines @a'@.
class ArrowAddAutomaton a a' | a -> a' where
    -- | Expose one step of the automaton as an @a'@ arrow returning the
    -- output and the continuation automaton.
    elimAutomaton :: a i o -> a' i (o, a i o)
    -- | Build an automaton from a single-step @a'@ arrow. For 'Automaton'
    -- this is the newtype constructor (the counterpart of 'elimAutomaton').
    liftAutomaton :: a' i (o, a i o) -> a i o
    -- | Lift a plain @a'@ arrow into an automaton. In the 'Automaton'
    -- instance it runs the same arrow at every step.
    constantAutomaton :: a' i o -> a i o

instance (Arrow a) => ArrowAddAutomaton (Automaton a) a where
    -- The step arrow is exactly the newtype's payload.
    elimAutomaton (Automaton step) = step
    liftAutomaton step = Automaton step
    -- Run the same arrow at every step; the continuation is the automaton
    -- itself, tied as a knot instead of being rebuilt on each step.
    constantAutomaton f = self
      where
        self = Automaton (f >>> arr (\o -> (o, self)))

instance (Arrow a, Arrow a', ArrowAddAutomaton a a')
    => ArrowAddAutomaton (StateArrow s a) (StateArrow s a') where
    -- Work on the explicit (input, state) view and convert back
    -- with 'autoState' / 'stateAuto' / 'stateArrow'.
    elimAutomaton m = autoState (elimAutomaton (runState m))
    liftAutomaton m = stateArrow (liftAutomaton (stateAuto m))
    constantAutomaton m = stateArrow (constantAutomaton (runState m))
    

-- NOTE(review): orphan instance; neither 'MonadState' nor 'ArrowMonad'
-- is defined in this module.
instance (ArrowState s a, ArrowApply a) => MonadState s (ArrowMonad a) where
    -- Read the arrow's state.
    get = ArrowMonad fetch
    -- Replace the arrow's state with the given value.
    put newState = ArrowMonad (proc () -> store -< newState)

instance (Arrow a, Arrow a', ArrowAddAutomaton a a')
    => ArrowAddAutomaton (ReaderArrow r a) (ReaderArrow r a') where
    -- Step the explicit (input, env) view, then re-wrap the continuation.
    elimAutomaton m =
        readerArrow (elimAutomaton (runReader m)) >>> second (arr readerArrow)

    -- Unwrap the continuation to the explicit view before lifting.
    liftAutomaton m =
        readerArrow (liftAutomaton (runReader m >>> second (arr runReader)))

    constantAutomaton m = readerArrow (constantAutomaton (runReader m))

instance (ArrowChoice a, ArrowChoice a', ArrowAddAutomaton a a')
    => ArrowAddAutomaton (ErrorArrow ex a) (ErrorArrow ex a') where
    -- Step the Either-valued view of the automaton, re-wrap the
    -- continuation, then move a Left out of the pair so it is raised.
    elimAutomaton m =
        pushError (elimAutomaton (popError m)
                   >>> second (arr pushError)
                   >>> arr rstrength)

    -- When the lifted step raises, the continuation used is this same
    -- lifted automaton (the default passed to 'revrstrength').
    liftAutomaton m = pushError (liftAutomaton step)
      where
        step = popError m
               >>> arr (revrstrength (liftAutomaton m))
               >>> second (arr popError)

    constantAutomaton m = pushError (constantAutomaton (popError m))


-- | An automaton that routes each input @(i, k)@ to a per-key automaton.
-- A key seen for the first time starts from @def k@. After every step
-- the key's continuation is kept, so each key's automaton advances
-- independently of the others.
dispatch
  :: (Ord k, ArrowAddAutomaton a a', ArrowApply a')
  => (k -> a i o) -> a (i, k) o
dispatch = dispatch0 M.empty

-- | Worker for 'dispatch'. @mp@ holds the current automaton for every key
-- seen so far; keys missing from @mp@ start from @def k@.
dispatch0
  :: (Ord k, ArrowAddAutomaton a a', ArrowApply a')
  => M.Map k (a i o) -> (k -> a i o) -> a (i, k) o
dispatch0 mp def = liftAutomaton $ proc (i, k) -> do
    -- step this key's automaton on i (-<< because it depends on k)
    (o, next) <- elimAutomaton (M.findWithDefault (def k) k mp) -<< i
    returnA -< (o, dispatch0 (M.insert k next mp) def)
                    

-- Utility functions

-- | Exchange the second component of the inner pair with the outer
-- second component. Fully lazy: neither pair is forced until one of its
-- components is demanded.
swapsnd :: ((a, b), c) -> ((a, c), b)
swapsnd t = ((x, z), y)
  where
    (inner, z) = t
    (x, y) = inner

-- | Push a pair into an 'Either': a @Left@ drops the second component,
-- a @Right a@ is paired with it.
rstrength :: (Either ex a, b) -> Either ex (a, b)
rstrength (e, b) = fmap (\a -> (a, b)) e

-- | Inverse direction of 'rstrength': pull the second component out of
-- an 'Either'. A @Left@ has no such component, so @def@ is used instead.
revrstrength :: b -> Either ex (a,b) -> (Either ex a, b)
revrstrength def = either (\ex -> (Left ex, def)) (\(a, b) -> (Right a, b))

-- | Convert an explicit-state automaton step into a 'StateArrow' step
-- whose continuation is also a 'StateArrow'.
autoState :: (Arrow a, Arrow a') => a' (i,s) ((o,s), a (i,s) (o,s)) ->
             StateArrow s a' i (o,StateArrow s a i o)
autoState f = stateArrow stepped
  where
    -- wrap the continuation, then move the new state to the outer pair
    stepped = f >>> second (arr stateArrow) >>> arr swapsnd

-- | Inverse direction of 'autoState': expose a 'StateArrow' step and its
-- continuation with the state threaded explicitly.
stateAuto :: (Arrow a, Arrow a') => StateArrow s a' i (o,StateArrow s a i o) ->
             a' (i,s) ((o,s), a (i,s) (o,s))
stateAuto f = runState unwrapped >>> arr swapsnd
  where
    -- unwrap the continuation before leaving the StateArrow
    unwrapped = f >>> second (arr runState)


-- Simulate the unexported data constructors of StateArrow,
-- ReaderArrow and ErrorArrow.

-- | Build a 'StateArrow' from an arrow that threads the state explicitly.
-- It reads the current state, runs @f@ on @(input, state)@, stores the
-- resulting state and returns the output.
stateArrow :: (Arrow a) => a (t, s) (b, s) -> StateArrow s a t b
stateArrow f = proc input -> do
    current <- fetch -< ()
    (output, updated) <- AT.lift f -< (input, current)
    store -< updated
    returnA -< output

-- | Build a 'ReaderArrow' from an arrow that takes the environment
-- explicitly. It reads the environment and runs @f@ on
-- @(input, environment)@.
readerArrow :: (Arrow a) => a (e,r) b -> ReaderArrow r a e b
readerArrow f = proc input -> do
    env <- readState -< ()
    AT.lift f -< (input, env)

-- | Run an 'ErrorArrow' in the underlying arrow, reporting success as
-- @Right@ and a raised exception as @Left@.
popError :: (ArrowChoice a) => ErrorArrow ex a e b -> a e (Either ex b)
popError f = runError (f >>> arr Right) onRaise
  where
    -- the handler drops the input and reports the exception as Left
    onRaise = arr snd >>> arr Left

-- | Inverse direction of 'popError': a @Left ex@ result is raised in the
-- 'ErrorArrow', and a @Right b@ result is returned unchanged.
pushError :: (ArrowChoice a) => a e (Either ex b) -> ErrorArrow ex a e b
pushError = (>>> (raise ||| arr id)) . AT.lift