{-# LANGUAGE TemplateHaskell #-}

module MonadLab.Env (
   envT
 ) where

import Prelude hiding (Monad)
import Language.Haskell.TH
import MonadLab.CommonTypes

-- | The environment monad transformer, packaged as a 'MonadTransformer'.
--
-- Given a layer @EnvT name envTy@, returns a function that, for an
-- underlying monad @m@, yields:
--
--   * the transformed type constructor (@t@ becomes @arrow envTy (m t)@),
--   * @return@ and @bind@ for the new monad,
--   * this layer's own non-proper morphisms (@rdEnv@, @inEnv@) followed by
--     the inner layers' morphisms lifted through this transformer, and
--   * the new base lift: this transformer's lift composed after the
--     underlying monad's base lift.
--
-- Only 'EnvT' layers are meaningful here.  Any other layer is reported with
-- an explicit error instead of an anonymous pattern-match failure.
envT :: Layer -> MonadTransformer
envT l@(EnvT _ envTy) = \m ->
  ( envTransTypeCon envTy m
  , envTransReturn m
  , envTransBind m
  , envTransLayerNPM l m : envTransLiftLayerNPMs m
  , [| $envTransLift . $(getBaseLift m) |]
  )
envT _ = error "MonadLab.Env.envT: expected an EnvT layer"

--------------------------------

-- | Type constructor of the transformed monad.  A result type @t@ is mapped
-- to @arrow envTy (inner t)@, where @inner@ is the underlying monad's type
-- constructor.  Presumably @arrow@ builds the function type from the
-- environment; see "MonadLab.CommonTypes".
envTransTypeCon :: TypeQ -> Monad -> MonadTypeCon
envTransTypeCon envTy inner = arrow envTy . getTypeCon inner

-- | @return@ for the transformed monad: ignore the environment and defer to
-- the underlying monad's @return@.
envTransReturn :: Monad -> ReturnExpQ
envTransReturn inner = [| \v -> \_ -> $(getReturn inner) v |]

-- | @bind@ for the transformed monad.  The same environment is passed both
-- to the first computation and to the continuation, and sequencing is done
-- with the underlying monad's @bind@.
envTransBind :: Monad -> BindExpQ
envTransBind inner =
  [| \comp -> \k -> \env -> $(getBind inner) (comp env) (\x -> k x env) |]

-- | @rdEnv@ for this layer: yield the current environment through the
-- underlying monad's @return@.
envTransRdEnv :: Monad -> NonProperMorphismExpQ
envTransRdEnv inner = [| \env -> $(getReturn inner) env |]

-- | @inEnv@ for this layer: run a computation under the supplied
-- environment, disregarding the one currently in effect.
envTransInEnv :: NonProperMorphismExpQ
envTransInEnv = [| \newEnv -> \comp -> \_ -> comp newEnv |]

-- | Lift an underlying computation into the transformed monad by ignoring
-- the environment argument.
envTransLift :: LiftExpQ
envTransLift = [| \comp -> \_ -> comp |]


--------------------------------

-- | Pair this layer with its own non-proper morphisms, in the order
-- @[rdEnv, inEnv]@.
envTransLayerNPM :: Layer -> Monad -> LayerNPM
envTransLayerNPM layer inner = (layer, morphisms)
  where
    morphisms = [envTransRdEnv inner, envTransInEnv]

-- | Lift every non-proper morphism of the underlying monad through the
-- environment transformer, so operations provided by inner layers stay
-- available in the transformed monad.
--
-- Each layer is matched together with the exact list of morphisms it is
-- expected to carry.  For the 'List' layer the inner @merge@/@halt@ are not
-- used; both are rebuilt from this monad's bind and base lift.  Any other
-- layer/morphism combination is reported with an explicit error instead of
-- an anonymous pattern-match failure at splice time.
envTransLiftLayerNPMs :: Monad -> [LayerNPM]
envTransLiftLayerNPMs m = map liftLayerNPM (getLayerNPMs m)
  where
    -- Lift one inner layer's morphisms; closes over the underlying monad @m@.
    liftLayerNPM :: LayerNPM -> LayerNPM
    liftLayerNPM npm = case npm of
      (Io, [liftIO]) ->
        (Io, [envTransLiftLiftIO liftIO])
      (List, [_merge, _halt]) ->
        (List, [envTransLiftMerge m, envTransLiftHalt m])
      (StateT n t, [get, put]) ->
        (StateT n t, [envTransLiftGet get, envTransLiftPut put])
      (EnvT n t, [rdEnv, inEnv]) ->
        (EnvT n t, [envTransLiftRdEnv rdEnv, envTransLiftInEnv inEnv])
      (ErrorT n t, [throw, catch]) ->
        (ErrorT n t, [envTransLiftThrow throw, envTransLiftCatch catch])
      (WriterT n t, [tell, listen, pass]) ->
        ( WriterT n t
        , [ envTransLiftTell tell
          , envTransLiftListen m listen
          , envTransLiftPass m pass
          ]
        )
      (ContT t, [callcc]) ->
        (ContT t, [envTransLiftCallCC m callcc])
      (_, ops) ->
        error ( "MonadLab.Env.envTransLiftLayerNPMs: unsupported layer or "
             ++ "unexpected number of morphisms ("
             ++ show (length ops) ++ ")" )


-- | Lift an inner layer's @rdEnv@ by applying this transformer's lift to it.
envTransLiftRdEnv :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftRdEnv innerRdEnv = appE envTransLift innerRdEnv

-- | Lift an inner layer's @inEnv@.  The wrapped computation is first run
-- with this layer's environment, and the inner @inEnv@ is then applied to
-- the resulting underlying computation.
envTransLiftInEnv :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftInEnv innerInEnv =
  [| \newInner -> \comp -> \env -> $innerInEnv newInner (comp env) |]

-- | Lift an inner state layer's @get@ by applying this transformer's lift.
envTransLiftGet :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftGet innerGet = appE envTransLift innerGet

-- | Lift an inner state layer's @put@: @composition@ applied to this
-- transformer's lift and the inner @put@.
envTransLiftPut :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftPut innerPut = appsE [composition, envTransLift, innerPut]

-- | Lift an inner error layer's @throw@: @composition@ applied to this
-- transformer's lift and the inner @throw@.
envTransLiftThrow :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftThrow innerThrow = appsE [composition, envTransLift, innerThrow]

-- | Lift an inner error layer's @catch@.  Both the protected computation
-- and the handler's result are run with the same current environment.
envTransLiftCatch :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftCatch innerCatch =
  [| \comp -> \handler -> \env ->
       $innerCatch (comp env) (\err -> handler err env) |]

-- | Rebuild @merge@ for the transformed monad.  It lifts the argument from
-- the base monad into the transformed monad, then flattens it with a join
-- written in terms of this transformer's bind.  The inner layer's own
-- @merge@ is not used.
envTransLiftMerge :: Monad -> NonProperMorphismExpQ
envTransLiftMerge inner = [| $composition $joinE $liftFromBase |]
  where
    -- join for the transformed monad, expressed via its bind
    joinE = [| \x -> $(envTransBind inner) x (\a -> a) |]
    -- base lift of the underlying monad followed by this transformer's lift
    liftFromBase = [| $envTransLift . $(getBaseLift inner) |]

-- | Rebuild @halt@ for the transformed monad: the empty list lifted from the
-- base monad all the way into the transformed monad.
envTransLiftHalt :: Monad -> NonProperMorphismExpQ
envTransLiftHalt inner = [| $liftFromBase [] |]
  where
    liftFromBase = [| $envTransLift . $(getBaseLift inner) |]

-- | Lift the inner @liftIO@ by composing this transformer's lift after it.
envTransLiftLiftIO :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftLiftIO innerLiftIO = infixApp envTransLift [| (.) |] innerLiftIO

-- | Lift an inner writer layer's @tell@: @composition@ applied to this
-- transformer's lift and the inner @tell@.
envTransLiftTell :: NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftTell innerTell = appsE [composition, envTransLift, innerTell]

-- | Lift an inner writer layer's @listen@: run the computation with the
-- current environment and apply the inner @listen@ to the result.
--
-- The 'Monad' argument is unused.  It is kept so existing callers (e.g.
-- 'envTransLiftLayerNPMs') keep working.  The previous dead @return@/@bind@
-- locals, which also shadowed Prelude's @return@, have been removed.
envTransLiftListen :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftListen _ listen = [| \comp -> \r -> $listen (comp r) |]


-- | Lift an inner writer layer's @pass@: run the computation with the
-- current environment and apply the inner @pass@ to the result.
--
-- The 'Monad' argument is unused.  It is kept so existing callers (e.g.
-- 'envTransLiftLayerNPMs') keep working.  The previous dead @return@/@bind@
-- locals, which also shadowed Prelude's @return@, have been removed.
envTransLiftPass :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftPass _ pass = [| \comp -> \r -> $pass (comp r) |]


-- | Lift an inner continuation layer's @callcc@.  The inner escape
-- continuation @k@ is wrapped so that, inside the transformed monad, it
-- ignores whatever environment it is later given.  The body @f@ runs with
-- the environment current at the @callcc@ call.  The 'Monad' argument is
-- unused.
envTransLiftCallCC :: Monad -> NonProperMorphismExpQ -> NonProperMorphismExpQ
envTransLiftCallCC _ innerCallCC =
  [| \f -> \env -> $innerCallCC (\k -> f (\a -> \_ -> k a) env) |]