{-# LANGUAGE FlexibleInstances, TemplateHaskell, CPP, PatternGuards #-}
module HAppS.State.ComponentTH
    ( mkMethods
    ) where

import Data.Char
import Language.Haskell.TH
import HAppS.State.Types
import HAppS.Data.Serialize
import HAppS.State.ComponentSystem
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Writer
import Data.Maybe
import Data.List
import Data.Generics.Basics

-- | Build a TH context from a list of constraint types, removing duplicate
--   constraints (via 'nub') after the individual constraints have been run.
nubCxt tsQ
    = do ts <- cxt tsQ
         return $ nub ts

{-
  Error cases:
    not all state keys are tyvars.
    method not using the keys in either the args or the result.
  Checked:
    component not being a data declaration with a single constructor.
    method using a tyvar that isn't a key.
    method is not a function.
-}

-- | Generate the declarations for a component and its methods.
--
--   The first argument names the component type, the second lists the method
--   functions. The result is the concatenation of: one event-class instance
--   per method, a single @Methods@ instance for the component type, one data
--   type per method (deriving @Typeable@), and @Serialize@ \/ @Version@
--   instances for each of those data types.
mkMethods :: Name -> [Name] -> Q [Dec]
#ifndef __HADDOCK__
mkMethods componentName componentMethods
    = do keys <- liftM (requireSimpleCon componentName) $ reify componentName
         methodInfos <- getMethodInfos keys componentMethods
         eventInstances <- genEventInstances methodInfos
         let handlers = genComponentHandlers methodInfos
             stType   = mkType componentName keys
             context' = mkKeyConstraints keys ++ concatMap (mkMethodConstraints keys) methodInfos
         methodsInstance <- instanceD (nubCxt context') (appT (conT ''Methods) stType)
                              [ funD 'methods [clause [wildP] (normalB handlers) []] ]
         methodStructs <- genMethodStructs [''Typeable] methodInfos
         serializeInstances <- genSerializeInstances methodInfos
         return (eventInstances ++ [methodsInstance] ++ methodStructs ++ serializeInstances)
#endif

-- | For every key, a @Typeable key@ constraint followed by (after all the
--   @Typeable@ ones) a @Serialize key@ constraint.
mkKeyConstraints :: [Name] -> [TypeQ]
#ifndef __HADDOCK__
mkKeyConstraints keys
    = [ appT (conT ''Typeable) (varT key) | key <- keys ] ++
      [ appT (conT ''Serialize) (varT key) | key <- keys ]
#endif

-- | The method's own context, with the method's key variables renamed to the
--   given keys (see 'substMethodContext').
mkMethodConstraints :: [Name] -> MethodInfo -> [TypeQ]
#ifndef __HADDOCK__
mkMethodConstraints keys method
    = map return (substMethodContext method keys)
#endif

-- | Rename type variables in the method's context: the i-th key of the method
--   is replaced by the i-th name in @keys@. Variables that are not method
--   keys are left untouched. The body of a @forall@ is renamed, its binder
--   list and context are not.
substMethodContext method keys
    = map rename (methodContext method)
    where
      relation = zip (methodKeys method) keys
      rename (VarT old) | Just new <- lookup old relation = VarT new
      rename (AppT l r)               = AppT (rename l) (rename r)
      rename (ForallT ctx names body) = ForallT ctx names (rename body)
      rename other                    = other

-- | Apply the type constructor @name@ to the type variables @args@.
mkType name args = foldl appT (conT name) (map varT args)

-- | For each method, generate a @Serialize@ instance and an (empty-bodied)
--   @Version@ instance for the method's data type. @putCopy@ writes every
--   constructor field with @safePut@; @getCopy@ reads the same number of
--   fields with @safeGet@ and rebuilds the constructor.
genSerializeInstances :: [MethodInfo] -> Q [Dec]
#ifndef __HADDOCK__
genSerializeInstances methods
    = liftM concat $ forM methods $ \method ->
      let constraints = nubCxt $ mkKeyConstraints (methodKeys method) ++ map return (methodContext method)
          upperMethod = upperName (methodName method)
          eventType   = mkType upperMethod (methodKeys method)
          -- \(Con arg1 .. argN) -> do { safePut arg1; ..; safePut argN; return () }
          encode = do
              args <- replicateM (length (methodArgs method)) (newName "arg")
              lamE [conP upperMethod $ map varP args ] $
                doE $ [ noBindS $ appE (varE 'safePut) (varE arg) | arg <- args] ++
                      [ noBindS [| return () |] ]
          -- do { arg1 <- safeGet; ..; argN <- safeGet; return (Con arg1 .. argN) }
          decode = do
              args <- replicateM (length (methodArgs method)) (newName "arg")
              doE $ [ bindS (varP arg) (varE 'safeGet) | arg <- args] ++
                    [ noBindS $ appE (varE 'return) $ foldl appE (conE upperMethod) $ map varE args ]
      in do s <- instanceD constraints (appT (conT ''Serialize) eventType)
                   [funD 'putCopy [clause [] (normalB [| contain . $(encode) |]) []]
                   ,funD 'getCopy [clause [] (normalB [| contain $(decode) |]) []]]
            v <- instanceD constraints (appT (conT ''Version) eventType) []
            return [s,v]
#endif

{-
  [ Update $ \(SetComponent c) -> setComponent c
  , Query $ \GetComponent -> getComponent
  ]
-}
-- | Build the list expression of handlers (shape shown above): each entry
--   wraps, in the method's event constructor ('methodEv'), a lambda that
--   unpacks the method's data constructor and applies the method function to
--   the unpacked fields.
genComponentHandlers :: [MethodInfo] -> ExpQ
#ifndef __HADDOCK__
genComponentHandlers methods
    = listE (map handlerFor methods)
    where
      handlerFor method = do
          args <- replicateM (length (methodArgs method)) (newName "arg")
          appE (conE (methodEv method)) $
            lamE [conP (upperName (methodName method)) (map varP args)] $
              foldl appE (varE (methodName method)) $ map varE args
#endif

-- | One event-class instance per method (see 'genEventInstance').
genEventInstances :: [MethodInfo] -> Q [Dec]
genEventInstances methodsInfo = mapM genEventInstance methodsInfo

-- instance (cxt keys, Serialize keys) => QueryEvent (GetPageCurrent key) WikiRevision
-- | Generate an empty-bodied instance of the method's event class
--   ('methodClass') relating the method's data type to its result type. The
--   instance context requires @Serialize@ for both of those types, plus the
--   key constraints and the method's own (renamed) context.
genEventInstance :: MethodInfo -> Q Dec
#ifndef __HADDOCK__
genEventInstance method
    = do let keys = methodKeys method
             eventType = foldl appT (conT (upperName (methodName method))) (map varT keys)
             resultType = return (methodResult method)
         instanceD (nubCxt $ [appT (conT ''Serialize) eventType
                             ,appT (conT ''Serialize) resultType] ++
                             mkKeyConstraints keys ++
                             mkMethodConstraints keys method
                   )
                   (appT (appT (conT (methodClass method)) eventType) resultType)
                   []
#endif

-- | One data declaration per method, each deriving the classes in @derv@.
genMethodStructs :: [Name] -> [MethodInfo] -> Q [Dec]
genMethodStructs derv methods
    = liftM concat (mapM (genMethodStruct derv) methods)

-- FIXME: allow class constraints on keys.
-- | Declare the data type for a method: a single constructor named after the
--   method (first letter upper-cased, same name as the type), parameterised
--   by the method's keys, with one non-strict field per method argument.
genMethodStruct :: [Name] -> MethodInfo -> Q [Dec]
genMethodStruct derv method
    = do let structName = upperName (methodName method)
             c = NormalC structName (zip (repeat NotStrict) (methodArgs method))
         return [ DataD [] structName (methodKeys method) [c] derv ]

-- | Drop any module qualification from a name and upper-case its first letter.
upperName = mkName . upperFirst . nameBase

-- | Upper-case the first character of a string. Calls 'error' on the empty
--   string.
upperFirst :: String -> String
upperFirst (x:xs) = toUpper x : xs
upperFirst "" = error "ComponentTH.UpperFirst []"

-- | Everything the generators need to know about one method.
data MethodInfo
    = Method { methodName   :: Name    -- ^ the method function; filled in by 'getMethodInfo'
             , methodKeys   :: [Name]  -- ^ type variables occurring in the state type
             , methodContext:: [Type]  -- ^ constraints from the method's signature that were kept
             , methodArgs   :: [Type]  -- ^ argument types, in order
             , methodClass  :: Name    -- ^ event class used for the generated instance
             , methodEv     :: Name    -- ^ event constructor used to wrap the handler
             , methodResult :: Type    -- ^ result type of the method
             }

-- get and validate method information.
-- | Reify every method and validate it against the component's keys.
--
--   A method is accepted only when it has as many keys as the component and
--   the type variables of its arguments and result are exactly its keys (each
--   key occurring once in the state type). Otherwise the returned action is
--   an 'error':
--
--   * \"Method too general\" when an argument or the result mentions a type
--     variable that is not a key;
--
--   * \"Method not general enough\" when some key (counting repetitions in
--     the state type) is not matched by a type variable of the arguments or
--     result.
getMethodInfos :: [Name] -> [Name] -> Q [MethodInfo]
getMethodInfos sessionKeys names
    = do ms <- mapM getMethodInfo names
         mapM worker ms
    where
      worker m
          | length keys /= length sessionKeys
              = error $ "Inconsistent keys: " ++ pprint (methodName m) ++ ": " ++ show (sessionKeys, methodKeys m)
          -- Classify by explicit list difference rather than by comparing the
          -- sorted lists: the lexicographic comparison picked the message
          -- based on the arbitrary ordering of the names involved.
          | not (null (tyVars \\ keys)) = error $ "Method too general: " ++ pprint (methodName m)
          | not (null (keys \\ tyVars)) = error $ "Method not general enough: " ++ pprint (methodName m)
          | otherwise = return m
          where
            keys   = methodKeys m
            tyVars = methodTyVars m
      -- Type variables reachable through type applications only.
      getArgKeys (AppT t1 t2) = getArgKeys t1 ++ getArgKeys t2
      getArgKeys (VarT key)   = [key]
      getArgKeys _            = []
      -- Distinct type variables of the result and the arguments.
      methodTyVars m = nub $ concatMap getArgKeys (methodResult m:methodArgs m)

-- | Reify a single method and analyse its type with 'getTypes'. Calls 'error'
--   when the name does not refer to a variable.
getMethodInfo :: Name -> Q MethodInfo
getMethodInfo method
    = do methodInfo <- reify method
         case methodInfo of
           VarI _name funcType _decl _fixity
             -> return (getTypes funcType){methodName = method}
           _ -> error $ "Method is not a function: " ++ nameBase method ++ " is a " ++ showInfo methodInfo

-- | Short description of what kind of thing a reified name is; used in the
--   error message of 'getMethodInfo'.
showInfo (ClassI _) = "class"
showInfo (TyConI _) = "type constructor"
showInfo (PrimTyConI _ _ _) = "primitive type constructor"
showInfo (DataConI _ _ _ _) = "data constructor"
showInfo (VarI _ _ _ _) = "variable"
showInfo (TyVarI _ _) = "type variable"
showInfo x = pprint x

-- Cases:
--   forall m. MonadState state m => X -> m Y
--   forall key. key -> Update ()
--   forall key m. MonadState state m => key -> m ()
--   X -> Ev (ReaderT state STM) Y
--   X -> Ev (StateT state STM) Y
-- | Analyse a method's type. An outermost @forall@ is stripped and its
--   context handed to @getTypes'@; any other type is analysed with an empty
--   context.
getTypes :: Type -> MethodInfo
getTypes (ForallT _ cxt t) = getTypes' cxt t
getTypes t = getTypes' [] t

-- FIXME: only allow type variables used by the component.
-- | Split a method type (with its context) into a 'MethodInfo'.
--
--   Arguments are peeled off the function arrows left to right. The final
--   result must have one of these shapes, tried in this order:
--
--   * @Update state res@ or @Query state res@;
--
--   * @ev (StateT state stm) res@ or @ev (ReaderT state stm) res@, where
--     @ev@ and @stm@ are any type constructors;
--
--   * @m res@, where the context contains @MonadState state m@ or
--     @MonadReader state m@.
--
--   The @Update@ \/ @StateT@ \/ @MonadState@ forms yield an update event, the
--   others a query event. Anything else calls 'error'. The resulting
--   'methodName' is a placeholder that errors when forced; the caller is
--   expected to overwrite it.
getTypes' :: Cxt -> Type -> MethodInfo
#ifndef __HADDOCK__
getTypes' ctx t
    = case runWriter (worker t) of
        ((keys, className, typeName, res), args)
          -> Method { methodName    = error "Method name not set"
                    , methodKeys    = keys
                    , methodContext = filter (isRelevant keys) ctx
                    , methodArgs    = args
                    , methodClass   = className
                    , methodEv      = typeName
                    , methodResult  = res }
    where
      -- recursive case: A -> B
      -- The argument types are accumulated in a plain list writer, in order.
      worker (AppT (AppT ArrowT t1) t2)
          = do tell [t1]
               worker t2
      -- end case: Update state res  ||  Query state res
      -- (when neither guard holds, matching falls through to the next case)
      worker (AppT (AppT (ConT c) state) r)
          | c == ''Update = return (getStateKeys state, ''UpdateEvent, 'Update, r)
          | c == ''Query  = return (getStateKeys state, ''QueryEvent, 'Query, r)
      -- end case: Ev (ReaderT state STM) res  ||  Ev (StateT state STM) res
      worker (AppT (AppT (ConT _ev) (AppT (AppT (ConT m) state) (ConT _stm))) r)
          | m == ''StateT  = return (getStateKeys state, ''UpdateEvent, 'Update, r)
          | m == ''ReaderT = return (getStateKeys state, ''QueryEvent, 'Query, r)
      -- end case: m res (check if m is an instance of MonadState)
      worker (AppT name r)
          | Just state <- isMonadState name  = return (getStateKeys state, ''UpdateEvent, 'Update, r)
          | Just state <- isMonadReader name = return (getStateKeys state, ''QueryEvent, 'Query, r)
      -- error case
      worker other = error ("Unexpected method type: " ++ pprint other)

      -- Type variables of a state type. Deliberately strict: a state type may
      -- only be built from constructors, variables and applications.
      getStateKeys (AppT r r') = getStateKeys r ++ getStateKeys r'
      getStateKeys (VarT key)  = [key]
      getStateKeys (ConT _st)  = []
      getStateKeys v = error $ "Bad state type: " ++ pprint v ++ " (expected a constant, an application or a type variable)"

      -- Type variables of a context predicate. Unlike 'getStateKeys' this
      -- never fails: predicates may mention list, tuple or function types
      -- (e.g. @Show [key]@), which are simply skipped over.
      contextTyVars (AppT r r') = contextTyVars r ++ contextTyVars r'
      contextTyVars (VarT v)    = [v]
      contextTyVars _           = []

      -- The state type of the first predicate @cls state name@ in the context.
      monadClassState cls name
          = listToMaybe [ state | AppT (AppT (ConT m) state) mName <- ctx, mName == name, m == cls ]
      isMonadState  = monadClassState ''MonadState
      isMonadReader = monadClassState ''MonadReader

      -- A predicate is kept when it mentions at least one key and does not
      -- mention MonadState / MonadReader.
      isRelevant keys p = isAcceptableContext p && any (`elem` keys) (contextTyVars p)
      isAcceptableContext (AppT r r') = isAcceptableContext r && isAcceptableContext r'
      isAcceptableContext (ConT con)  = con `notElem` [''MonadState, ''MonadReader]
      isAcceptableContext _           = True
#endif

-- | Extract the type-variable names from the reified declaration of a
--   component type. Data, newtype and type-synonym declarations are
--   accepted; anything else calls 'error', naming the offending component.
requireSimpleCon :: Name -> Info -> [Name]
requireSimpleCon name info
    = case info of
        TyConI (DataD _ _ names _ _derv)    -> names
        TyConI (NewtypeD _ _ names _ _derv) -> names
        TyConI (TySynD _ names _)           -> names
        _ -> error $ "Cannot create component from '"++pprint name++"'. Expected a data structure."