{-# LANGUAGE TemplateHaskell #-}

module MonadLab.Declarations (
   createMonadDecls, createRunDecl,
   createResBaseDecls, createResDecl,
   createReactiveBaseDecls, createReactiveDecl,
   createStepTypeSig, createStepDecl,
   createSignalTypeSig, createSignalDecl,
   createSignullTypeSig, createSignullDecl
 ) where

import Prelude hiding (Monad)
import qualified Prelude (Monad)
import Language.Haskell.TH
import MonadLab.CommonTypes
import MonadLab.Reduction


-------------------------------
--General Declarations

-- | Build a simple value binding @name = body@ with no @where@ clause.
mkValDecl :: String -> ExpQ -> DecQ
mkValDecl name body = valD pat (normalB body) noWhere
  where
    pat     = varP (mkName name)
    noWhere = []


-- | All top-level declarations for a generated monad:
-- * the newtype wrapper;
-- * its 'Prelude.Monad' instance;
-- * the non-proper morphisms of every layer.
--
-- @return@ and @>>=@ are emitted only as methods of the instance, which
-- 'createMonadInstanceDecl' builds itself. Previously they were also
-- generated here and then thrown away; that duplicated work is removed.
createMonadDecls :: MonadName -> Monad -> Q [Dec]
createMonadDecls monadName m = do
    typeDecl          <- createTypeDecl monadName m
    monadInstanceDecl <- createMonadInstanceDecl monadName m
    npmDecls          <- createNPMDecls monadName m
    return (typeDecl : monadInstanceDecl : npmDecls)

-- | Declare the wrapper type @newtype M a = M { deM :: T a }@.
-- The representation @T@ comes from applying the monad's type constructor
-- ('getTypeCon') to the type variable @a@.
createTypeDecl :: MonadName -> Monad -> DecQ
createTypeDecl monadName m = do
    repr <- getTypeCon m (varT tyVar)
    return (NewtypeD [] wrapper [tyVar] (RecC wrapper [(unwrapper, NotStrict, repr)]) [])
  where
    tyVar     = mkName "a"
    wrapper   = mkName monadName
    unwrapper = mkName ("de" ++ monadName)



-- | Declare @instance Prelude.Monad M@ with no context.
-- Its methods are the generated @return@ and @>>=@ bindings.
createMonadInstanceDecl :: MonadName -> Monad -> DecQ
createMonadInstanceDecl monadName m = instanceD noContext instanceHead methods
  where
    noContext    = return []
    instanceHead = appT (conT (mkName "Prelude.Monad")) (conT (mkName monadName))
    methods      = [createReturnDecl monadName m, createBindDecl monadName m]

-- | Declare @return@ for the wrapper monad.
-- It applies @composition@ to the wrapper constructor and the underlying
-- return ('getReturn'), then passes the binding through 'reduceM'.
createReturnDecl :: MonadName -> Monad -> DecQ
createReturnDecl monadName m =
    reduceM (mkValDecl "return" [| $composition $wrap $unitOf |])
  where
    wrap   = conE (mkName monadName)
    unitOf = getReturn m

-- | Declare @>>=@ for the wrapper monad.
-- The body unwraps the left operand with @deM@ and runs the underlying bind
-- ('getBind'). The continuation is post-composed with @deM@, and the result
-- is re-wrapped with the constructor before going through 'reduceM'.
createBindDecl :: MonadName -> Monad -> DecQ
createBindDecl monadName m = reduceM (mkValDecl ">>=" body)
  where
    unwrap = dyn ("de" ++ monadName)
    wrap   = conE (mkName monadName)
    bindE  = getBind m
    body   = [| \x -> \f -> $wrap $ $bindE ($unwrap x) ($composition $unwrap f) |]

---------------------------------------------

-- | Declare @runM@ by applying @composition@ to the supplied run expression
-- and the @deM@ unwrapper.
-- The layer list is accepted for interface compatibility but is not used.
createRunDecl :: [Layer] -> MonadName -> RunExpQ -> DecQ
createRunDecl _layers monadName run =
    reduceM (mkValDecl runName [| $composition $run $unwrap |])
  where
    runName = "run" ++ monadName
    unwrap  = dyn ("de" ++ monadName)

---------------------------------------------

-- | Generate the non-proper-morphism declarations for every layer of the monad.
-- Layers are processed in the order given by 'getLayerNPMs'.
createNPMDecls :: MonadName -> Monad -> Q [Dec]
createNPMDecls monadName m = sequence (concatMap declsForLayer (getLayerNPMs m))
  where
    -- Route a layer (paired with its morphisms) to the generator for its kind.
    declsForLayer :: LayerNPM -> [DecQ]
    declsForLayer npm = case fst npm of
        Io          -> createIoDecls        monadName npm
        List        -> createListMonadDecls monadName npm
        ErrorT _ _  -> createErrorDecls     monadName npm
        StateT _ _  -> createStateDecls     monadName npm
        EnvT _ _    -> createEnvDecls       monadName npm
        WriterT _ _ -> createWriterDecls    monadName npm
        ContT _     -> createContDecls      monadName npm

---------------------------------------------
--Basic Resumption Declarations

-- | Quoted declarations for the basic resumption monad transformer: the
-- 'ResT' data type (constructors @Done@ and @Pause@) and its
-- 'Prelude.Monad' instance.
--
-- How the instance binds:
-- * @Done v >>= f@ continues directly with @f v@.
-- * @Pause m >>= f@ pushes the rest of the computation (@>>= f@) inside the
--   inner monadic action @m@, so each pause is kept.
--
-- NOTE(review): the @Prelude.Monad m =>@ context on the data declaration is
-- a datatype context, which needs the deprecated DatatypeContexts extension.
-- No Functor/Applicative instances are generated, which GHC >= 7.10 requires
-- for a Monad instance. Confirm the target compiler.
createResBaseDecls :: Q [Dec]
createResBaseDecls = [d|
			data Prelude.Monad m => ResT m a = Done a | Pause (m (ResT m a))

			instance Prelude.Monad m => Prelude.Monad (ResT m) where
				return v       = Done v
				Done  v >>= f  = f v
				Pause m >>= f  = Pause (m >>= \r -> return (r >>= f))
		     |]

-- | Declare the alias @type Res = ResT Inner@, with no type parameters.
createResDecl :: MonadName -> MonadName -> DecQ
createResDecl resName innerName = tySynD (mkName resName) [] resT
  where
    resT = appT (conT (mkName "ResT")) (conT (mkName innerName))


---------------------------------------------
--Reactive Resumption Declarations

-- | Quoted declarations for the reactive resumption monad transformer: the
-- 'ReactT' data type and its 'Prelude.Monad' instance.
--
-- A value is one of:
-- * finished, @D a@;
-- * paused, @P (q, k)@: a request of type @q@ plus a continuation
--   @k :: r -> m (ReactT q r m a)@ that, given a response, produces the
--   next step inside the inner monad @m@.
--
-- Binding a paused computation keeps the same request and applies @>>= f@
-- to whatever the continuation yields.
--
-- NOTE(review): the data declaration uses a datatype context (deprecated
-- DatatypeContexts extension). No Functor/Applicative instances are
-- generated, which GHC >= 7.10 requires for a Monad instance. Confirm the
-- target compiler.
createReactiveBaseDecls :: Q [Dec]
createReactiveBaseDecls = [d|
				data Prelude.Monad m => ReactT q r m a = D a
								 | P (q, r -> m (ReactT q r m a))
				
				instance Prelude.Monad m => Prelude.Monad (ReactT q r m) where
					return         = D
					D v      >>= f = f v
					P (r, s) >>= f = P (r, \rsp -> (s rsp) >>= \m -> return (m >>= f))
			  |]


-- | Declare the alias @type React = ReactT Req Rsp Inner@, with no type
-- parameters.
createReactiveDecl :: MonadName -> TypeQ -> TypeQ -> MonadName -> DecQ
createReactiveDecl reactiveName reqT rspT innerName =
    tySynD (mkName reactiveName) [] (appT (appT (appT reactT reqT) rspT) innerT)
  where
    reactT = conT (mkName "ReactT")
    innerT = conT (mkName innerName)

-- | Signature @stepReact :: forall a. Inner a -> React a@, using a fresh
-- type variable.
createStepTypeSig :: MonadName -> MonadName -> DecQ
createStepTypeSig reactiveName innerName = do
    tv <- newName "a"
    let argTy    = AppT (ConT (mkName innerName)) (VarT tv)
        resultTy = AppT (ConT (mkName reactiveName)) (VarT tv)
        fnTy     = AppT (AppT ArrowT argTy) resultTy
    return (SigD (mkName ("step" ++ reactiveName)) (ForallT [tv] [] fnTy))

-- | Declare @stepReact = \x -> P (Cont, \Ack -> x >>= return . D)@.
-- The AST is built by hand with a fresh name for @x@. The inner monad name
-- is accepted but not used.
createStepDecl :: MonadName -> MonadName -> DecQ
createStepDecl reactiveName _innerName = do
    x <- newName "x"
    let con          = ConE . mkName
        var          = VarE . mkName
        returnDotD   = InfixE (Just (var "return")) (var ".") (Just (con "D"))
        continuation = LamE [ConP (mkName "Ack") []]
                            (InfixE (Just (VarE x)) (var ">>=") (Just returnDotD))
        body         = LamE [VarP x] (AppE (con "P") (TupE [con "Cont", continuation]))
    mkValDecl ("step" ++ reactiveName) (return body)

-- | Signature @signalReact :: Req -> React Rsp@.
-- The request type is run first, then the response type.
createSignalTypeSig :: MonadName -> TypeQ -> TypeQ -> DecQ
createSignalTypeSig reactiveName reqT rspT =
    reqT >>= \req ->
    rspT >>= \rsp ->
    return (SigD sigName (AppT (AppT ArrowT req) (AppT (ConT (mkName reactiveName)) rsp)))
  where
    sigName = mkName ("signal" ++ reactiveName)
		
-- | Declare @signalReact = \q -> P (q, return . return)@.
-- This pauses with request @q@ and finishes with the response it receives.
createSignalDecl :: MonadName -> DecQ
createSignalDecl reactiveName = mkValDecl declName [| \q -> $pauseCon (q, return . return) |]
  where
    declName = "signal" ++ reactiveName
    pauseCon = conE (mkName "P")


-- | Signature @signullReact :: Req -> React ()@.
createSignullTypeSig :: MonadName -> TypeQ -> DecQ
createSignullTypeSig reactiveName reqT = fmap mkSig reqT
  where
    unitTy    = ConT (mkName "()")
    mkSig req = SigD (mkName ("signull" ++ reactiveName))
                     (AppT (AppT ArrowT req) (AppT (ConT (mkName reactiveName)) unitTy))
		
-- | Declare @signullReact = \q -> P (q, \_ -> return . return $ ())@.
-- This pauses with request @q@, ignores the response, and finishes with @()@.
createSignullDecl :: MonadName -> DecQ
createSignullDecl reactiveName = mkValDecl declName [| \q -> $pauseCon (q, \_ -> return . return $ ()) |]
  where
    declName = "signull" ++ reactiveName
    pauseCon = conE (mkName "P")


-------------------------------
--State Declarations

-- | Declarations for a state layer: @get<S><M>@, @put<S><M>@ and the derived
-- @update<S><M>@.
--
-- The layer must carry exactly two non-proper morphisms, in the order
-- @[get, put]@. Any other shape is reported as a Template Haskell failure
-- with a descriptive message, instead of an opaque pattern-match crash.
createStateDecls :: MonadName -> LayerNPM -> [DecQ]
createStateDecls monadName (StateT stateName _, [get, put]) =
    [ createGetDecl    monadName stateName get
    , createPutDecl    monadName stateName put
    , createUpdateDecl monadName stateName
    ]
createStateDecls _ _ =
    [fail "MonadLab.Declarations.createStateDecls: expected a StateT layer with exactly two non-proper morphisms [get, put]"]

-- | Declare @get<S><M>@: the layer's get morphism applied to the wrapper
-- constructor, then passed through 'reduceM'.
createGetDecl :: MonadName -> StateName -> NonProperMorphismExpQ -> DecQ
createGetDecl monadName stateName get =
    reduceM (mkValDecl getName [| $wrap $get |])
  where
    getName = "get" ++ stateName ++ monadName
    wrap    = conE (mkName monadName)

-- | Declare @put<S><M>@ by applying @composition@ to the wrapper constructor
-- and the layer's put morphism, then passing the binding through 'reduceM'.
createPutDecl :: MonadName -> StateName -> NonProperMorphismExpQ -> DecQ
createPutDecl monadName stateName put =
    reduceM (mkValDecl putName [| $composition $wrap $put |])
  where
    putName = "put" ++ stateName ++ monadName
    wrap    = conE (mkName monadName)

-- | Declare @update<S><M> f@ using the previously generated get/put.
-- It reads the state, stores @f s@, and returns the new state from a second
-- @get@. Unlike the other declarations, it is not passed through 'reduceM'.
createUpdateDecl :: MonadName -> StateName -> DecQ
createUpdateDecl monadName stateName =
    mkValDecl ("update" ++ suffix) [| \f -> $getter >>= \s -> $putter (f s) >> $getter |]
  where
    suffix = stateName ++ monadName
    getter = dyn ("get" ++ suffix)
    putter = dyn ("put" ++ suffix)


-------------------------------
--Environment Declarations

-- | Declarations for an environment layer: @rd<E><M>@ and @in<E><M>@.
--
-- The layer must carry exactly two non-proper morphisms, in the order
-- @[rdEnv, inEnv]@. Any other shape is reported as a Template Haskell
-- failure with a descriptive message, instead of an opaque pattern-match
-- crash.
createEnvDecls :: MonadName -> LayerNPM -> [DecQ]
createEnvDecls monadName (EnvT envName _, [rdEnv, inEnv]) =
    [ createRdEnvDecl monadName envName rdEnv
    , createInEnvDecl monadName envName inEnv
    ]
createEnvDecls _ _ =
    [fail "MonadLab.Declarations.createEnvDecls: expected an EnvT layer with exactly two non-proper morphisms [rdEnv, inEnv]"]

-- | Declare @rd<E><M> = M $ rdEnv@, passed through 'reduceM'.
createRdEnvDecl :: MonadName -> EnvName -> NonProperMorphismExpQ -> DecQ
createRdEnvDecl monadName envName rdEnv =
    reduceM (mkValDecl declName [| $wrap $ $rdEnv |])
  where
    declName = "rd" ++ envName ++ monadName
    wrap     = conE (mkName monadName)

-- | Declare @in<E><M> r m@.
-- It unwraps @m@, runs the layer's inEnv morphism with environment @r@, and
-- re-wraps the result. The binding is passed through 'reduceM'.
createInEnvDecl :: MonadName -> EnvName -> NonProperMorphismExpQ -> DecQ
createInEnvDecl monadName envName inEnv =
    reduceM (mkValDecl declName localized)
  where
    declName  = "in" ++ envName ++ monadName
    unwrap    = dyn ("de" ++ monadName)
    wrap      = conE (mkName monadName)
    localized = [| \r -> \m -> $wrap $ $inEnv r ($unwrap m) |]


-------------------------------
--Error Declarations

-- | Declarations for an error layer: @throw<E><M>@ and @catch<E><M>@.
--
-- The layer must carry exactly two non-proper morphisms, in the order
-- @[throw, catch]@. Any other shape is reported as a Template Haskell
-- failure with a descriptive message, instead of an opaque pattern-match
-- crash. The morphism variables are renamed so they do not shadow the
-- @catch@ that older Preludes export.
createErrorDecls :: MonadName -> LayerNPM -> [DecQ]
createErrorDecls monadName (ErrorT errorName _, [throwNPM, catchNPM]) =
    [ createThrowDecl monadName errorName throwNPM
    , createCatchDecl monadName errorName catchNPM
    ]
createErrorDecls _ _ =
    [fail "MonadLab.Declarations.createErrorDecls: expected an ErrorT layer with exactly two non-proper morphisms [throw, catch]"]

-- | Declare @throw<E><M>@ by applying @composition@ to the wrapper
-- constructor and the layer's throw morphism, then passing the binding
-- through 'reduceM'.
createThrowDecl :: MonadName -> ErrorName -> NonProperMorphismExpQ -> DecQ
createThrowDecl monadName errorName throw =
    reduceM (mkValDecl declName [| $composition $wrap $throw |])
  where
    declName = "throw" ++ errorName ++ monadName
    wrap     = conE (mkName monadName)

-- | Declare @catch<E><M> x h@.
-- It unwraps the protected action, post-composes the handler with the
-- unwrapper, runs the layer's catch morphism, and re-wraps the result.
-- The binding is passed through 'reduceM'.
createCatchDecl :: MonadName -> ErrorName -> NonProperMorphismExpQ -> DecQ
createCatchDecl monadName errorName catchNPM =
    reduceM (mkValDecl declName body)
  where
    declName = "catch" ++ errorName ++ monadName
    unwrap   = dyn ("de" ++ monadName)
    wrap     = conE (mkName monadName)
    body     = [| \x -> \h -> $wrap $ $catchNPM ($unwrap x) ($composition $unwrap h) |]


-------------------------------
--List Declarations

-- | Declarations for a list (nondeterminism) layer: @merge<M>@ and @halt<M>@.
--
-- The layer must carry exactly two non-proper morphisms, in the order
-- @[merge, halt]@. Any other shape is reported as a Template Haskell
-- failure with a descriptive message, instead of an opaque pattern-match
-- crash.
createListMonadDecls :: MonadName -> LayerNPM -> [DecQ]
createListMonadDecls monadName (List, [merge, halt]) =
    [ createMergeDecl monadName merge
    , createHaltDecl  monadName halt
    ]
createListMonadDecls _ _ =
    [fail "MonadLab.Declarations.createListMonadDecls: expected a List layer with exactly two non-proper morphisms [merge, halt]"]


-- | Declare @merge<M> xs@.
-- It unwraps every element of @xs@, runs the layer's merge morphism, and
-- re-wraps the result. The binding is passed through 'reduceM'.
createMergeDecl :: MonadName -> NonProperMorphismExpQ -> DecQ
createMergeDecl monadName merge =
    reduceM (mkValDecl declName body)
  where
    declName = "merge" ++ monadName
    unwrap   = dyn ("de" ++ monadName)
    wrap     = conE (mkName monadName)
    body     = [| \xs -> $wrap $ $merge (map $unwrap xs) |]

-- | Declare @halt<M>@: the layer's halt morphism applied to the wrapper
-- constructor, then passed through 'reduceM'.
createHaltDecl :: MonadName -> NonProperMorphismExpQ -> DecQ
createHaltDecl monadName halt =
    reduceM (mkValDecl declName [| $wrap $halt |])
  where
    declName = "halt" ++ monadName
    wrap     = conE (mkName monadName)


-------------------------------
--Io Declarations

-- | Declarations for an IO layer: @liftIO<M>@.
--
-- The layer must carry exactly one non-proper morphism, @[liftIO]@. Any
-- other shape is reported as a Template Haskell failure with a descriptive
-- message, instead of an opaque pattern-match crash.
createIoDecls :: MonadName -> LayerNPM -> [DecQ]
createIoDecls monadName (Io, [liftIO]) = [createLiftIODecl monadName liftIO]
createIoDecls _ _ =
    [fail "MonadLab.Declarations.createIoDecls: expected an Io layer with exactly one non-proper morphism [liftIO]"]


-- | Declare @liftIO<M> = M . liftIO@, passed through 'reduceM'.
-- Unlike the sibling generators, this one composes with @.@ directly rather
-- than @composition@.
createLiftIODecl :: MonadName -> NonProperMorphismExpQ -> DecQ
createLiftIODecl monadName liftIO =
    reduceM (mkValDecl declName [| $wrap . $liftIO |])
  where
    declName = "liftIO" ++ monadName
    wrap     = conE (mkName monadName)


-------------------------------
--Writer Declarations

-- | Declarations for a writer layer: @tell<W><M>@, @listen<W><M>@,
-- @pass<W><M>@ and the derived @censor<W><M>@.
--
-- The layer must carry exactly three non-proper morphisms, in the order
-- @[tell, listen, pass]@. Any other shape is reported as a Template Haskell
-- failure with a descriptive message, instead of an opaque pattern-match
-- crash.
createWriterDecls :: MonadName -> LayerNPM -> [DecQ]
createWriterDecls monadName (WriterT writerName _, [tell, listen, pass]) =
    [ createTellDecl   monadName writerName tell
    , createListenDecl monadName writerName listen
    , createPassDecl   monadName writerName pass
    , createCensorDecl monadName writerName
    ]
createWriterDecls _ _ =
    [fail "MonadLab.Declarations.createWriterDecls: expected a WriterT layer with exactly three non-proper morphisms [tell, listen, pass]"]

-- | Declare @tell<W><M>@ by applying @composition@ to the wrapper
-- constructor and the layer's tell morphism, then passing the binding
-- through 'reduceM'.
createTellDecl :: MonadName -> WriterName -> NonProperMorphismExpQ -> DecQ
createTellDecl monadName writerName tell =
    reduceM (mkValDecl declName [| $composition $wrap $tell |])
  where
    declName = "tell" ++ writerName ++ monadName
    wrap     = conE (mkName monadName)

-- | Declare @listen<W><M> m@.
-- It unwraps @m@, runs the layer's listen morphism, and re-wraps the
-- result. The binding is passed through 'reduceM'.
createListenDecl :: MonadName -> WriterName -> NonProperMorphismExpQ -> DecQ
createListenDecl monadName writerName listen =
    reduceM (mkValDecl declName body)
  where
    declName = "listen" ++ writerName ++ monadName
    unwrap   = dyn ("de" ++ monadName)
    wrap     = conE (mkName monadName)
    body     = [| \m -> $wrap ( $listen ($unwrap m) ) |]

-- | Declare @pass<W><M> m@.
-- It unwraps @m@, runs the layer's pass morphism, and re-wraps the result.
-- The binding is passed through 'reduceM'.
createPassDecl :: MonadName -> WriterName -> NonProperMorphismExpQ -> DecQ
createPassDecl monadName writerName pass =
    reduceM (mkValDecl declName body)
  where
    declName = "pass" ++ writerName ++ monadName
    unwrap   = dyn ("de" ++ monadName)
    wrap     = conE (mkName monadName)
    body     = [| \m -> $wrap ( $pass ($unwrap m) ) |]

-- | Declare @censor<W><M> f m = pass<W><M> (m >>= \a -> return (a, f))@.
-- It is built on the generated @pass@ binding and passed through 'reduceM'.
createCensorDecl :: MonadName -> WriterName -> DecQ
createCensorDecl monadName writerName =
    reduceM (mkValDecl ("censor" ++ suffix) body)
  where
    suffix = writerName ++ monadName
    passE  = dyn ("pass" ++ suffix)
    body   = [| \f -> \m -> $passE (m >>= \a -> return (a, f)) |]




-------------------------------
--Continuation Declarations

-- | Declarations for a continuation layer: @callcc<M>@.
--
-- The layer must carry exactly one non-proper morphism, @[callcc]@. Any
-- other shape is reported as a Template Haskell failure with a descriptive
-- message, instead of an opaque pattern-match crash.
createContDecls :: MonadName -> LayerNPM -> [DecQ]
createContDecls monadName (ContT _, [callcc]) = [createCallCCDecl monadName callcc]
createContDecls _ _ =
    [fail "MonadLab.Declarations.createContDecls: expected a ContT layer with exactly one non-proper morphism [callcc]"]

-- | Declare @callcc<M> h@.
-- The user's function @h@ receives escape continuations re-wrapped with the
-- constructor, and its result is unwrapped with @deM@ before the layer's
-- callcc morphism runs. The whole result is re-wrapped, and the binding is
-- passed through 'reduceM'.
createCallCCDecl :: MonadName -> NonProperMorphismExpQ -> DecQ
createCallCCDecl monadName callcc =
    reduceM (mkValDecl ("callcc" ++ monadName) body)
  where
    unwrap = dyn ("de" ++ monadName)
    wrap   = conE (mkName monadName)
    body   = [| \h -> $wrap $ $callcc ($unwrap . (\y -> h ($wrap . y))) |]