{-# LANGUAGE TemplateHaskell #-}

module MonadLab.State (
   stateT
 ) where

import Prelude hiding (Monad)
import Language.Haskell.TH
import MonadLab.CommonTypes


-- | The state monad transformer, packaged as a MonadLab 'MonadTransformer'.
--
-- Given a 'StateT' layer whose state type is @t@, this yields a function
-- from the underlying monad @m@ to a tuple of:
--
--   * the transformed type constructor (a result type @a@ becomes
--     @t -> m (a, t)@, see 'stateTransTypeCon'),
--   * the transformed @return@,
--   * the transformed @bind@,
--   * the non-proper morphisms: this layer's @get@/@put@ first, followed by
--     every non-proper morphism of @m@ lifted through this layer,
--   * this layer's lift composed with the underlying monad's base lift
--     ('getBaseLift').
--
-- Only 'StateT' layers are meaningful here. Any other layer is a caller
-- error and is reported with a descriptive message rather than a bare
-- pattern-match failure.
stateT :: Layer -> MonadTransformer
stateT l@(StateT _ t) = \m ->
  ( stateTransTypeCon t m
  , stateTransReturn m
  , stateTransBind m
  , stateTransLayerNPM l m : stateTransLiftLayerNPMs m
  , [| $(stateTransLift m) . $(getBaseLift m) |]
  )
stateT _ = error "MonadLab.State.stateT: expected a StateT layer"

--------------------------------

-- | Type constructor of the state-transformed monad.
--
-- The state type is @s@ and the underlying type constructor comes from
-- 'getTypeCon'. A result type @a@ is mapped to @s -> underlying (a, s)@.
stateTransTypeCon :: TypeQ -> Monad -> MonadTypeCon
stateTransTypeCon s m = transformed
  where
    underlying = getTypeCon m
    -- the pair type (a, s) threaded through the underlying monad
    withState a = appT (appT (tupleT 2) a) s
    transformed a = appT (appT arrowT s) (underlying (withState a))

-- | @return@ for the transformed monad: pair the value with the unchanged
-- incoming state and hand the pair to the underlying @return@.
stateTransReturn :: Monad -> ReturnExpQ
stateTransReturn m = [| \x st -> $underlyingReturn (x, st) |]
  where
    underlyingReturn = getReturn m

-- | @bind@ for the transformed monad.
--
-- Runs the first computation on the incoming state. The continuation then
-- receives the produced value together with the resulting state, and the
-- two steps are sequenced with the underlying @bind@.
stateTransBind :: Monad -> BindExpQ
stateTransBind m =
  [| \action k st -> $underlyingBind (action st) (\(x, st') -> k x st') |]
  where
    underlyingBind = getBind m

-- | Non-proper morphism @get@: yield the current state as the result and
-- leave the state itself unchanged.
stateTransGet :: Monad -> NonProperMorphismExpQ
stateTransGet m = [| \st -> $(getReturn m) (st, st) |]

-- | Non-proper morphism @put@: discard the incoming state, install the
-- supplied one, and yield @()@.
stateTransPut :: Monad -> NonProperMorphismExpQ
stateTransPut m = [| \new _old -> $(getReturn m) ((), new) |]

-- | @lift@ for this transformer.
--
-- Runs an underlying computation without touching the state, pairing its
-- result with the incoming state.
stateTransLift :: Monad -> LiftExpQ
stateTransLift m = [| \inner st -> $bindU inner (\x -> $returnU (x, st)) |]
  where
    bindU   = getBind m
    returnU = getReturn m

--------------------------------

-- | The non-proper morphisms this layer introduces, tagged with the layer:
-- @get@ first, then @put@.
stateTransLayerNPM :: Layer -> Monad -> LayerNPM
stateTransLayerNPM layer m = (layer, map ($ m) [stateTransGet, stateTransPut])

-- | Lift the non-proper morphisms already present in the underlying monad
-- @m@ (as returned by 'getLayerNPMs') through this state layer.
--
-- Each layer kind is matched together with the exact number of morphisms
-- it is expected to carry. The list layer's @merge@/@halt@ are rebuilt
-- from @m@ rather than lifted from the supplied expressions. A layer kind
-- or morphism count that is not recognised is reported with a descriptive
-- error instead of an opaque pattern-match failure.
stateTransLiftLayerNPMs :: Monad -> [LayerNPM]
stateTransLiftLayerNPMs m = map liftLayer (getLayerNPMs m)
  where
    liftLayer :: LayerNPM -> LayerNPM
    liftLayer lnpm = case lnpm of
      (Io, [liftIO]) ->
        (Io, [stateTransLiftLiftIO m liftIO])
      (List, [_merge, _halt]) ->
        (List, [stateTransLiftMerge m, stateTransLiftHalt m])
      (StateT n t, [get, put]) ->
        (StateT n t, [stateTransLiftGet m get, stateTransLiftPut m put])
      (EnvT n t, [rdEnv, inEnv]) ->
        (EnvT n t, [stateTransLiftRdEnv m rdEnv, stateTransLiftInEnv inEnv])
      (ErrorT n t, [throw, catch]) ->
        (ErrorT n t, [stateTransLiftThrow m throw, stateTransLiftCatch catch])
      (WriterT n t, [tell, listen, pass]) ->
        ( WriterT n t
        , [ stateTransLiftTell m tell
          , stateTransLiftListen m listen
          , stateTransLiftPass m pass
          ]
        )
      (ContT t, [callcc]) ->
        (ContT t, [stateTransLiftCallCC callcc])
      _ ->
        error "MonadLab.State.stateTransLiftLayerNPMs: unexpected layer or wrong number of non-proper morphisms"

-- | Lift an inner layer's @get@ through this state layer by applying this
-- layer's lift to it.
stateTransLiftGet :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftGet m innerGet = appE (stateTransLift m) innerGet

-- | Lift an inner layer's @put@ through this state layer by applying
-- 'composition' to this layer's lift and the inner @put@.
stateTransLiftPut :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftPut m innerPut = appE (appE composition lifter) innerPut
  where
    lifter = stateTransLift m

-- | Lift an inner layer's @rdEnv@ through this state layer by applying this
-- layer's lift to it.
stateTransLiftRdEnv :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftRdEnv m innerRdEnv = appE (stateTransLift m) innerRdEnv

-- | Lift an inner layer's @inEnv@ through this state layer.
--
-- The state-taking computation is first run on the current state, and the
-- inner @inEnv@ is applied to the environment and that resulting
-- underlying computation.
stateTransLiftInEnv :: NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftInEnv innerInEnv = [| \env comp st -> $innerInEnv env (comp st) |]

-- | Lift an inner layer's @throw@ through this state layer by applying
-- 'composition' to this layer's lift and the inner @throw@.
stateTransLiftThrow :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftThrow m innerThrow = appE (appE composition lifter) innerThrow
  where
    lifter = stateTransLift m

-- | Lift an inner layer's @catch@ through this state layer.
--
-- Both the protected computation and the handler are run starting from
-- the state current at the time @catch@ is entered.
stateTransLiftCatch :: NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftCatch innerCatch =
  [| \action handler st -> $innerCatch (action st) (\err -> handler err st) |]

-- | Rebuild the list layer's @merge@ for the transformed monad.
--
-- Applies 'composition' to @joinT@ and @liftBase@ (presumably giving
-- @joinT . liftBase@):
--
--   * @liftBase@ is this layer's lift composed with the underlying monad's
--     base lift ('getBaseLift').
--   * @joinT@ flattens a nested computation using the transformed bind with
--     the identity continuation.
stateTransLiftMerge :: Monad -> NonProperMorphismExpQ
stateTransLiftMerge m = [| $composition $joinT $liftBase |]
  where
    joinT    = [| \nested -> $(stateTransBind m) nested id |]
    liftBase = [| $(stateTransLift m) . $(getBaseLift m) |]

-- | Rebuild the list layer's @halt@ for the transformed monad.
--
-- Pushes the empty list through this layer's lift composed with the
-- underlying monad's base lift ('getBaseLift').
stateTransLiftHalt :: Monad -> NonProperMorphismExpQ
stateTransLiftHalt m = appE liftBase [| [] |]
  where
    liftBase = [| $(stateTransLift m) . $(getBaseLift m) |]


-- | Lift an inner layer's @liftIO@ through this state layer by
-- post-composing it with this layer's lift.
stateTransLiftLiftIO :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftLiftIO m innerLiftIO = [| $lifter . $innerLiftIO |]
  where
    lifter = stateTransLift m

-- | Lift an inner layer's @tell@ through this state layer by applying
-- 'composition' to this layer's lift and the inner @tell@.
stateTransLiftTell :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftTell m innerTell = appE (appE composition lifter) innerTell
  where
    lifter = stateTransLift m

-- | Lift an inner layer's @listen@ through this state layer.
--
-- The computation is run on the current state under the inner @listen@,
-- yielding @((x, st'), w)@. That is reshaped into @((x, w), st')@, so the
-- captured output @w@ joins the result and the new state is threaded on.
stateTransLiftListen :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftListen m innerListen =
  [| \comp st ->
       $bindU ($innerListen (comp st))
              (\((x, st'), w) -> $returnU ((x, w), st')) |]
  where
    bindU   = getBind m
    returnU = getReturn m

-- | Lift an inner layer's @pass@ through this state layer.
--
-- The computation is run on the current state, giving @((x, f), st')@.
-- That is reshaped into @((x, st'), f)@ so the inner @pass@ sees the
-- output-transforming function @f@ in its expected position, while the
-- new state travels with the result.
stateTransLiftPass :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftPass m innerPass =
  [| \comp st ->
       $innerPass ($bindU (comp st)
                          (\((x, f), st') -> $returnU ((x, st'), f))) |]
  where
    bindU   = getBind m
    returnU = getReturn m

-- | Lift an inner layer's @callcc@ through this state layer.
--
-- The escape continuation handed to the body takes the state current at
-- the point of escape and passes the value together with that state to
-- the inner continuation.
stateTransLiftCallCC :: NonProperMorphismExpQ -> NonProperMorphismExpQ
stateTransLiftCallCC innerCallCC =
  [| \body st0 -> $innerCallCC (\exit -> body (\x st1 -> exit (x, st1)) st0) |]