-- | Template Haskell support
--
-- (In a separate file for convenience)
{-# LANGUAGE MagicHash #-}
module Control.Distributed.Process.Internal.Closure.TH
  ( -- * User-level API
    remotable
  , mkStatic
  , functionSDict
  , functionTDict
  ) where

import Prelude hiding (lookup)
import Data.Accessor ((^=))
import Data.Typeable (typeOf)
import Control.Applicative ((<$>))
import Language.Haskell.TH
  ( -- Q monad and operations
    Q
  , reify
    -- Names
  , Name
  , mkName
  , nameBase
    -- Algebraic data types
  , Dec
  , Exp
  , Type(AppT, ForallT, VarT, ArrowT)
  , Info(VarI)
  , TyVarBndr(PlainTV, KindedTV)
  , Pred(ClassP)
    -- Lifted constructors
    -- .. Literals
  , stringL
    -- .. Patterns
  , normalB
  , clause
    -- .. Expressions
  , varE
  , litE
    -- .. Top-level declarations
  , funD
  , sigD
  )
import Control.Distributed.Process.Internal.Types
  ( RemoteTable
  , Static(Static)
  , StaticLabel(StaticLabel)
  , remoteTableLabel
  , SerializableDict(SerializableDict)
  , Process
  )
import Control.Distributed.Process.Internal.Dynamic
  ( Dynamic(..)
  , unsafeCoerce#
  , toDyn
  )

--------------------------------------------------------------------------------
-- User-level API                                                             --
--------------------------------------------------------------------------------

-- | Generate the supporting declarations for a list of top-level names.
--
-- For every name this emits the @\<name\>__static@ definition (plus the
-- @\<name\>__sdict@ / @\<name\>__tdict@ dictionaries when the type of the name
-- has the required shape, see 'generateDefs'), followed by a single
-- @__remoteTable :: RemoteTable -> RemoteTable@ definition that registers
-- every generated entry.
remotable :: [Name] -> Q [Dec]
remotable ns = do
  perName <- mapM generateDefs ns
  let (declGroups, registrationGroups) = unzip perName
  tableDecls <- createMetaData (concat registrationGroups)
  return (concat declGroups ++ tableDecls)

-- | Refer to the static value generated for a name.
--
-- If @f : forall a1 .. an. T@
-- then @$(mkStatic 'f) :: forall a1 .. an. Static T@.
--
-- The splice is only a reference to @f__static@; that definition exists only
-- if 'f' was passed to 'remotable'.
mkStatic :: Name -> Q Exp
mkStatic n = varE (staticName n)

-- | Refer to the serialization dictionary for the argument of a function.
--
-- If @f : T1 -> T2@ is a monomorphic function
-- then @$(functionSDict 'f) :: Static (SerializableDict T1)@.
--
-- The splice is only a reference to @f__sdict@; that definition exists only
-- if 'f' was passed to 'remotable'.
functionSDict :: Name -> Q Exp
functionSDict n = varE (sdictName n)

-- | Refer to the serialization dictionary for the result of a process.
--
-- If @f : T1 -> Process T2@ is a monomorphic function
-- then @$(functionTDict 'f) :: Static (SerializableDict T2)@.
--
-- The splice is only a reference to @f__tdict@; that definition exists only
-- if 'f' was passed to 'remotable'.
functionTDict :: Name -> Q Exp
functionTDict n = varE (tdictName n)

--------------------------------------------------------------------------------
-- Internal (Template Haskell)                                                --
--------------------------------------------------------------------------------

-- | Generate the @__remoteTable@ definition.
--
-- Each argument is an expression of type @RemoteTable -> RemoteTable@; the
-- generated function is their composition (see 'compose').
createMetaData :: [Q Exp] -> Q [Dec]
createMetaData registrations =
  [d| __remoteTable :: RemoteTable -> RemoteTable ;
      __remoteTable = $(compose registrations)
    |]

-- | Generate everything 'remotable' needs for a single name.
--
-- The result pairs the new top-level declarations with the expressions (each
-- of type @RemoteTable -> RemoteTable@) that register them.
--
-- The following entries are produced, in this order:
--
-- * always: the static value for the name itself;
--
-- * if the name has the monomorphic type @T1 -> T2@: a static
--   'SerializableDict' for @T1@;
--
-- * if the name has the monomorphic type @T1 -> Process T2@: a static
--   'SerializableDict' for @T2@.
--
-- A leading @forall@ is only stripped when it carries an empty context; any
-- other type is treated as having no type variables.
generateDefs :: Name -> Q ([Dec], [Q Exp])
generateDefs n = do
    processTyp <- [t| Process |]
    lookedUp   <- getType n
    case lookedUp of
      Just (origName, fullTyp) -> do
        let (binders, monoTyp) = stripForall fullTyp
        -- The main "static" entry
        (staticDecs, staticRegs) <- staticEntry origName binders monoTyp
        -- TODO: we should check if the argument/result types are instances of
        -- Serializable, but we cannot
        -- http://hackage.haskell.org/trac/ghc/ticket/7066
        (sdictDecs, sdictRegs) <- argumentDict origName binders monoTyp
        (tdictDecs, tdictRegs) <-
          resultDict processTyp origName binders monoTyp
        return ( staticDecs ++ sdictDecs ++ tdictDecs
               , staticRegs ++ sdictRegs ++ tdictRegs
               )
      _ ->
        -- NOTE(review): 'getType' yields 'Nothing' for every reification
        -- result other than 'VarI', so this message also covers names that
        -- resolve to something which is not a plain variable.
        fail $ "remotable: " ++ show n ++ " not found"
  where
    -- Split off the type variable binders of a context-free @forall@.
    stripForall :: Type -> ([TyVarBndr], Type)
    stripForall (ForallT vars [] mono) = (vars, mono)
    stripForall other                  = ([], other)

    -- Dictionary for @T1@ when the (monomorphic) type is @T1 -> T2@.
    argumentDict :: Name -> [TyVarBndr] -> Type -> Q ([Dec], [Q Exp])
    argumentDict origName [] (AppT (AppT ArrowT arg) _res) =
      dictEntry (sdictName origName) arg
    argumentDict _ _ _ =
      return ([], [])

    -- Dictionary for @T2@ when the (monomorphic) type is @T1 -> Process T2@.
    -- The first argument is the reified 'Process' type constructor that the
    -- head of the result type is compared against.
    resultDict :: Type -> Name -> [TyVarBndr] -> Type -> Q ([Dec], [Q Exp])
    resultDict procTyp origName [] (AppT (AppT ArrowT _arg) (AppT con res))
      | con == procTyp = dictEntry (tdictName origName) res
    resultDict _ _ _ _ =
      return ([], [])

    -- Declarations and registration for the static value itself.  A value
    -- without type variables is wrapped with 'toDyn'; otherwise the value is
    -- coerced into a 'Dynamic' whose first field is an 'error' call.
    staticEntry :: Name -> [TyVarBndr] -> Type -> Q ([Dec], [Q Exp])
    staticEntry origName binders typ = do
      decs <- generateStatic origName binders typ
      let label = stringE (show origName)
          value = case binders of
            [] -> [| toDyn $(varE origName) |]
            _  -> [| Dynamic (error "Polymorphic value")
                             (unsafeCoerce# $(varE origName)) |]
      return (decs, [ [| registerStatic $label $value |] ])

    -- Declarations and registration for a 'SerializableDict' at the given
    -- type, stored under the given name.
    dictEntry :: Name -> Type -> Q ([Dec], [Q Exp])
    dictEntry dictName typ = do
      decs <- generateDict dictName typ
      let label = stringE (show dictName)
          value =
            [| toDyn (SerializableDict :: SerializableDict $(return typ)) |]
      return (decs, [ [| registerStatic $label $value |] ])

-- | Store a dynamic value in the remote table under the given label.
--
-- Generated code refers to this function by name, so it has to stay a
-- top-level definition of this module.
registerStatic :: String -> Dynamic -> RemoteTable -> RemoteTable
registerStatic label dyn table = (remoteTableLabel label ^= Just dyn) table

-- | Generate the signature and definition of @\<name\>__static@.
--
-- The signature quantifies over the given type variables and requires
-- @Typeable@ for each of them.  The label of the static value is the 'show'n
-- name.
generateStatic :: Name -> [TyVarBndr] -> Type -> Q [Dec]
generateStatic n xs typ = do
    staticCon <- [t| Static |]
    let quantified = ForallT xs (map typeableOf xs) (staticCon `AppT` typ)
    signature  <- sigD (staticName n) (return quantified)
    definition <- sfnD (staticName n)
      [| Static $ StaticLabel $(stringE (show n))
                              (typeOf (undefined :: $(return typ))) |]
    return [signature, definition]
  where
    -- @Typeable v@ constraint for the variable bound by a binder
    typeableOf :: TyVarBndr -> Pred
    typeableOf binder = ClassP (mkName "Typeable") [VarT (binderName binder)]

    -- The variable a binder introduces; a kind annotation is ignored
    binderName :: TyVarBndr -> Name
    binderName (PlainTV v)    = v
    binderName (KindedTV v _) = v

-- | Generate the signature and definition of a static serialization
-- dictionary with name 'n' for type 'typ'.
generateDict :: Name -> Type -> Q [Dec]
generateDict n typ = do
  signature  <- sigD n [t| Static (SerializableDict $(return typ)) |]
  definition <- sfnD n
    [| Static $ StaticLabel $(stringE (show n))
                            (typeOf (undefined :: SerializableDict $(return typ))) |]
  return [signature, definition]

-- | Name of the static value generated for a name
staticName :: Name -> Name
staticName = withSuffix "__static"

-- | Name of the argument dictionary generated for a name
sdictName :: Name -> Name
sdictName = withSuffix "__sdict"

-- | Name of the result dictionary generated for a name
tdictName :: Name -> Name
tdictName = withSuffix "__tdict"

-- | Unqualified name made of the base of the given name and a suffix
withSuffix :: String -> Name -> Name
withSuffix suffix n = mkName (nameBase n ++ suffix)

--------------------------------------------------------------------------------
-- Generic Template Haskell auxiliary functions                               --
--------------------------------------------------------------------------------

-- | Compose a list of function-valued expressions, left to right.
--
-- The empty list yields @id@, and a singleton list yields its only element
-- unchanged.
compose :: [Q Exp] -> Q Exp
compose exprs = case exprs of
  []         -> [| id |]
  [single]   -> single
  (f : rest) -> [| $f . $(compose rest) |]

-- | Literal string as an expression
stringE :: String -> Q Exp
stringE str = litE (stringL str)

-- | Look up the "original name" (module:name) and type of a top-level
-- function.
--
-- Yields 'Nothing' when reification returns anything other than 'VarI'.
getType :: Name -> Q (Maybe (Name, Type))
getType name = do
  info <- reify name
  return $ case info of
    VarI origName typ _ _ -> Just (origName, typ)
    _                     -> Nothing

-- | Variation on 'funD' which takes a single expression to define the
-- function: one clause, no arguments, no guards and no @where@ bindings.
sfnD :: Name -> Q Exp -> Q Dec
sfnD n e = funD n [onlyClause]
  where
    onlyClause = clause [] (normalB e) []