{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TemplateHaskell #-}
module Happstack.State.ComponentTH
    ( mkMethods
    ) where

import Data.Char
import Language.Haskell.TH
import Happstack.State.Types
import Happstack.Data.Serialize
import Happstack.State.ComponentSystem
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Maybe
import Data.List
import Data.Generics.Basics

-- | Run a list of constraint builders and drop repeated constraints, so a
-- generated instance context never lists the same constraint twice.
nubCxt :: [TypeQ] -> Q [Type]
nubCxt = liftM nub . cxt

{-
Error cases:
  - not all state keys are type variables.
  - a method does not use the keys in either its arguments or its result.
  - Checked: the component is not a data declaration with a single constructor.
    NOTE(review): requireSimpleCon accepts any data, newtype or type synonym
    declaration and does not look at the number of constructors.
  - Checked: a method uses a type variable that is not a key.
  - Checked: a method is not a function.
-}

-- | Generate the event plumbing for a MACID component.
--
-- Given the name of a state type and the names of its methods, this splice
-- emits, in order: an @UpdateEvent@ or @QueryEvent@ instance per method, a
-- 'Methods' instance for the state type (one handler per method), a data
-- type per method (method name with an upper-cased first letter, holding the
-- method arguments and deriving 'Typeable'), and 'Serialize' and 'Version'
-- instances for those data types.
--
-- Every MACID state type needs this splice for its events to be handled.
--
-- Example: @$(mkMethods ''Foo ['fooUpdate, 'fooQuery])@
mkMethods :: Name -> [Name] -> Q [Dec]
mkMethods componentName componentMethods = do
    info <- reify componentName
    -- lazily bound: requireSimpleCon only fails once the keys are needed
    let keys = requireSimpleCon componentName info
    infos <- getMethodInfos keys componentMethods
    eventInstances <- genEventInstances infos
    let constraints = mkKeyConstraints keys
                   ++ concatMap (mkMethodConstraints keys) infos
        methodsBody = funD 'methods
                        [clause [wildP] (normalB (genComponentHandlers infos)) []]
    methodsInstance <- instanceD (nubCxt constraints)
                                 (conT ''Methods `appT` mkType componentName keys)
                                 [methodsBody]
    structDecs    <- genMethodStructs [''Typeable] infos
    serializeDecs <- genSerializeInstances infos
    return (concat [eventInstances, [methodsInstance], structDecs, serializeDecs])

-- | A 'Typeable' constraint for every key, followed by a 'Serialize'
-- constraint for every key.
mkKeyConstraints :: [Name] -> [TypeQ]
mkKeyConstraints keys =
    concatMap (\cls -> map (appT (conT cls) . varT) keys) [''Typeable, ''Serialize]

-- | The method context, rewritten to use the given (component) key names.
mkMethodConstraints :: [Name] -> MethodInfo -> [TypeQ]
mkMethodConstraints keys method = [ return ty | ty <- substMethodContext method keys ]

-- | Rename the method keys inside its context to the given names, matching
-- them up by position.  Variables that are not method keys are left alone.
substMethodContext :: MethodInfo -> [Name] -> [Type]
substMethodContext method keys = map subst (methodContext method)
  where
    renaming = zip (methodKeys method) keys
    subst ty = case ty of
        VarT old | Just new <- lookup old renaming -> VarT new
        AppT f x         -> AppT (subst f) (subst x)
        ForallT a b body -> ForallT a b (subst body)
        _                -> ty

-- | The type constructor @name@ applied to type variables @args@, in order.
mkType :: Name -> [Name] -> TypeQ
mkType name = foldl' (\acc v -> acc `appT` varT v) (conT name)

-- | For every method struct, a 'Serialize' instance (fields written and read
-- one after another with safePut/safeGet, wrapped in contain) plus an
-- empty 'Version' instance.
genSerializeInstances :: [MethodInfo] -> Q [Dec]
genSerializeInstances = liftM concat . mapM serializeAndVersion
  where
    serializeAndVersion method = do
        let keys        = methodKeys method
            con         = upperName (methodName method)
            arity       = length (methodArgs method)
            constraints = nubCxt (mkKeyConstraints keys ++ map return (methodContext method))
            headFor cls = conT cls `appT` mkType con keys
            -- builds: \(Con a1 .. an) -> do { safePut a1; ..; safePut an; return () }
            encoder = do
                vars <- replicateM arity (newName "arg")
                lamE [conP con (map varP vars)] $
                    doE (map (noBindS . appE (varE 'safePut) . varE) vars
                         ++ [noBindS [| return () |]])
            -- builds: do { a1 <- safeGet; ..; an <- safeGet; return (Con a1 .. an) }
            decoder = do
                vars <- replicateM arity (newName "arg")
                doE (map (\v -> bindS (varP v) (varE 'safeGet)) vars
                     ++ [noBindS (varE 'return `appE` foldl' appE (conE con) (map varE vars))])
        serializeInst <- instanceD constraints (headFor ''Serialize)
            [ funD 'putCopy [clause [] (normalB [| contain . $(encoder) |]) []]
            , funD 'getCopy [clause [] (normalB [| contain $(decoder) |]) []]
            ]
        versionInst <- instanceD constraints (headFor ''Version) []
        return [serializeInst, versionInst]

-- | The handler list for the 'Methods' instance, e.g.
--
-- > [ Update $ \(SetComponent c) -> setComponent c
-- > , Query  $ \GetComponent     -> getComponent ]
genComponentHandlers :: [MethodInfo] -> ExpQ
genComponentHandlers = listE . map handlerFor
  where
    handlerFor method = do
        vars <- replicateM (length (methodArgs method)) (newName "arg")
        let pat  = conP (upperName (methodName method)) (map varP vars)
            call = foldl' appE (varE (methodName method)) (map varE vars)
        conE (methodEv method) `appE` lamE [pat] call

-- | One event-class instance per method.
genEventInstances :: [MethodInfo] -> Q [Dec]
genEventInstances = mapM genEventInstance

-- | @instance ctx => EventClass (MethodStruct keys) result@, where the
-- context requires Serialize for the struct and the result, the key
-- constraints, and the method own relevant constraints.
genEventInstance :: MethodInfo -> Q Dec
genEventInstance method =
    instanceD (nubCxt constraints) instanceHead []
  where
    keys       = methodKeys method
    resultType = return (methodResult method)
    eventType  = mkType (upperName (methodName method)) keys
    constraints = [ conT ''Serialize `appT` eventType
                  , conT ''Serialize `appT` resultType
                  ]
                  ++ mkKeyConstraints keys
                  ++ mkMethodConstraints keys method
    instanceHead = conT (methodClass method) `appT` eventType `appT` resultType

-- | The argument-carrying data declarations for all methods.
genMethodStructs :: [Name] -> [MethodInfo] -> Q [Dec]
genMethodStructs derv = liftM concat . mapM (genMethodStruct derv)

-- FIXME: allow class constraints on keys.
-- | Declare the data type that carries a method arguments: a single
-- constructor named after the method with its first letter upper-cased,
-- one non-strict field per argument, parameterised over the method keys
-- and deriving the given classes.
genMethodStruct :: [Name] -> MethodInfo -> Q [Dec]
genMethodStruct derv method =
    return [DataD [] structName (methodKeys method) [constructor] derv]
  where
    structName  = upperName (methodName method)
    constructor = NormalC structName [ (NotStrict, t) | t <- methodArgs method ]

-- | Name of the generated type and constructor for a method
-- (e.g. @fooUpdate@ becomes @FooUpdate@).
upperName :: Name -> Name
upperName = mkName . upperFirst . nameBase

-- | Upper-case the first character of a non-empty string.
upperFirst :: String -> String
upperFirst (x:xs) = toUpper x : xs
upperFirst ""     = error "ComponentTH.UpperFirst []"

-- | Everything the generators need to know about one component method.
--
-- NOTE: the fields must stay lazy: 'getTypes' fills 'methodName' with an
-- error placeholder that 'getMethodInfo' overwrites afterwards.
data MethodInfo = Method
    { methodName    :: Name   -- ^ the method being wrapped
    , methodKeys    :: [Name] -- ^ type variables of the method state type
    , methodContext :: [Type] -- ^ context constraints mentioning a key
                              --   (MonadState/MonadReader excluded)
    , methodArgs    :: [Type] -- ^ argument types, in order
    , methodClass   :: Name   -- ^ @UpdateEvent@ or @QueryEvent@
    , methodEv      :: Name   -- ^ @Update@ or @Query@ handler constructor
    , methodResult  :: Type   -- ^ result type of the event
    }

-- | Reify and validate every method of a component.
--
-- Each method must have as many keys as the component (@sessionKeys@), and
-- the type variables in its arguments and result must be exactly its keys.
-- A violation aborts compilation with a message naming the method.
getMethodInfos :: [Name] -> [Name] -> Q [MethodInfo]
getMethodInfos sessionKeys names = mapM getMethodInfo names >>= mapM validate
  where
    -- 'error' rather than 'fail': older GHCs drop the message given to
    -- 'fail' in the Q monad.
    validate m
      | length (methodKeys m) /= length sessionKeys
      = error $ "Inconsistent keys: " ++ pprint (methodName m) ++ ": "
             ++ show (sessionKeys, methodKeys m)
      | not (null extra)
      = error $ "Method too general: " ++ pprint (methodName m)
             ++ " (type variables that are not keys: " ++ showNames extra ++ ")"
      | not (null missing)
      = error $ "Method not general enough: " ++ pprint (methodName m)
             ++ " (keys unused by the arguments and result: " ++ showNames missing ++ ")"
      | otherwise = return m
      where
        tyVars  = methodTyVars m
        -- multiset differences: both empty exactly when the (sorted) lists
        -- are equal, which was the original acceptance condition
        extra   = tyVars \\ methodKeys m
        missing = methodKeys m \\ tyVars
    methodTyVars m = nub $ concatMap typeVarNames (methodResult m : methodArgs m)
    showNames = intercalate ", " . map pprint

-- | Reify a method name and extract its 'MethodInfo'.  Anything that is not
-- a variable aborts compilation.
getMethodInfo :: Name -> Q MethodInfo
getMethodInfo method = do
    info <- reify method
    case info of
      VarI _name funcType _decl _fixity ->
          return (getTypes funcType) { methodName = method }
      _ -> error $ "Method is not a function: " ++ nameBase method
                   ++ " is a " ++ showInfo info

-- | Short description of what a reified name refers to, for error messages.
showInfo :: Info -> String
showInfo (ClassI _)          = "class"
showInfo (TyConI _)          = "type constructor"
showInfo (PrimTyConI _ _ _)  = "primitive type constructor"
showInfo (DataConI _ _ _ _)  = "data constructor"
showInfo (VarI _ _ _ _)      = "variable"
showInfo (TyVarI _ _)        = "type variable"
showInfo x                   = pprint x

-- Cases:
--   forall m. MonadState state m => X -> m Y
--   forall key. key -> Update ()
--   forall key m. MonadState state m => key -> m ()
--   X -> Ev (ReaderT state STM) Y
--   X -> Ev (StateT state STM) Y

-- | Analyse a method type, taking its context from an outer @forall@ if any.
getTypes :: Type -> MethodInfo
getTypes (ForallT _ contxt t) = getTypes' contxt t
getTypes t                    = getTypes' [] t

-- FIXME: only allow type variables used by the component.
-- | Split a method type into its arguments and its final event type.  The
-- returned 'methodName' is a placeholder that the caller must overwrite.
getTypes' :: Cxt -> Type -> MethodInfo
getTypes' contxt t =
    case runWriter (worker t) of
      ((keys, className, typeName, res), args) ->
        Method { methodName    = error "Method name not set"
               , methodKeys    = keys
               , methodContext = filter (isRelevant keys) contxt
               , methodArgs    = args
               , methodClass   = className
               , methodEv      = typeName
               , methodResult  = res
               }
  where
    -- recursive case: A -> B (record A as an argument, in order)
    worker (AppT (AppT ArrowT t1) t2) = do
        tell [t1]
        worker t2
    -- end case: Update state res || Query state res
    worker (AppT (AppT (ConT c) state) r)
      | c == ''Update = return (getStateKeys state, ''UpdateEvent, 'Update, r)
      | c == ''Query  = return (getStateKeys state, ''QueryEvent,  'Query,  r)
    -- end case: Ev (ReaderT state STM) res || Ev (StateT state STM) res
    worker (AppT (AppT (ConT _ev) (AppT (AppT (ConT m) state) (ConT _stm))) r)
      | m == ''StateT  = return (getStateKeys state, ''UpdateEvent, 'Update, r)
      | m == ''ReaderT = return (getStateKeys state, ''QueryEvent,  'Query,  r)
    -- end case: m res, where the context constrains m by MonadState/MonadReader
    worker (AppT name r)
      | Just state <- isMonadState contxt name
      = return (getStateKeys state, ''UpdateEvent, 'Update, r)
      | Just state <- isMonadReader contxt name
      = return (getStateKeys state, ''QueryEvent, 'Query, r)
    -- error case
    worker c = error ("Unexpected method type: " ++ pprint c)

-- | Type variables of a state type, left to right.  Only type variables,
-- type constructors and applications are accepted.
getStateKeys :: Type -> [Name]
getStateKeys (AppT r r') = getStateKeys r ++ getStateKeys r'
getStateKeys (VarT key)  = [key]
getStateKeys (ConT _st)  = []
getStateKeys v = error $ "Bad state type: " ++ pprint v
                      ++ " (expected a constant, an application or a type variable)"

-- | Every type variable occurring in a type, left to right.  Unlike
-- 'getStateKeys' this is total: list, tuple and arrow types contribute no
-- variables instead of aborting.
typeVarNames :: Type -> [Name]
typeVarNames (AppT l r) = typeVarNames l ++ typeVarNames r
typeVarNames (VarT v)   = [v]
typeVarNames _          = []

-- | State type of a @MonadState state m@ constraint on the given monad.
isMonadState :: [Type] -> Type -> Maybe Type
isMonadState = findMonadConstraint ''MonadState

-- | State type of a @MonadReader state m@ constraint on the given monad.
isMonadReader :: [Type] -> Type -> Maybe Type
isMonadReader = findMonadConstraint ''MonadReader

-- | First @cls state m@ constraint in the context whose monad argument is
-- the given type, returning its @state@ argument.
findMonadConstraint :: Name -> [Type] -> Type -> Maybe Type
findMonadConstraint cls contxt monad =
    listToMaybe [ state | AppT (AppT (ConT c) state) m <- contxt, m == monad, c == cls ]

-- | Whether a context constraint should be carried into generated instances:
-- it must not be MonadState/MonadReader and must mention one of the keys.
isRelevant :: [Name] -> Type -> Bool
isRelevant keys t = isAcceptableContext t && any (`elem` keys) (typeVarNames t)

-- | False when the constraint mentions MonadState or MonadReader anywhere.
isAcceptableContext :: Type -> Bool
isAcceptableContext (AppT r r') = isAcceptableContext r && isAcceptableContext r'
isAcceptableContext (ConT con)  = con `notElem` [''MonadState, ''MonadReader]
isAcceptableContext _           = True

-- | Type parameters of a data, newtype or type synonym declaration; anything
-- else aborts compilation.
requireSimpleCon :: Name -> Info -> [Name]
requireSimpleCon _ (TyConI (DataD _ _ names _ _derv))    = names
requireSimpleCon _ (TyConI (NewtypeD _ _ names _ _derv)) = names
requireSimpleCon _ (TyConI (TySynD _ names _))           = names
requireSimpleCon name _ =
    error $ "Cannot create component from '" ++ pprint name ++ "'. Expected a data structure."