{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TemplateHaskell #-}
-- | Desugaring of the Template Haskell declarations of a process function.
--
-- Each step rewrites the expressions found in the right-hand sides of
-- function clauses and value bindings, and may lift newly generated helper
-- definitions into the @where@ clause of the clause or binding they came
-- from.  The last step removes expression type signatures, which are not
-- recognized during translation.
module ForSyDe.Deep.Process.Desugar (desugarTransform) where

import qualified Language.Haskell.TH as TH
import Language.Haskell.TH
import Control.Monad.State.Lazy
import Data.Generics
import qualified Data.Param.FSVec as V

-- | Run all desugaring steps, in this order, over a list of declarations:
--
--   1. @extractTypedLambda@    lifts type-annotated lambdas
--   2. @extractInfixOpSection@ lifts type-annotated operator sections
--   3. @specializeHof@         specializes vector higher-order functions
--   4. @deleteSigE@            drops the remaining expression signatures
--
-- The order matters.  Steps 1-3 only fire on expressions carrying a type
-- signature ('SigE'), which step 4 removes.  Step 3 also only accepts a
-- named function argument, which is what steps 1 and 2 turn typed lambdas
-- and sections into.
desugarTransform :: [Dec] -> Q [Dec]
desugarTransform decs =
  return decs
    =+=> extractTypedLambda
    =+=> extractInfixOpSection
    =+=> specializeHof
    =+=> deleteSigE
#ifdef DEVELOPER
    >>= dumpTree
  where
    -- Debugging aid: print the desugared declarations.
    dumpTree :: [Dec] -> Q [Dec]
    dumpTree ts = do
      runIO $ mapM_ (print . ppr) ts
      return ts
#endif

-- | Thread a declaration list through one declaration-accumulating step.
--
-- @f@ is applied (via @everywhereM@, i.e. bottom-up) to every node of type
-- @a@ in each declaration body.  The helper declarations it emits are added
-- to the enclosing @where@ clause by @foldAccDecs@.
(=+=>) :: Typeable a => Q [Dec] -> (a -> DecAccM a) -> Q [Dec]
decs =+=> f = decs >>= foldAccDecs (mkM f)

-- | Type signature plus single-clause definition for a generated helper.
mkFunDecls :: Name -> Type -> [Pat] -> Exp -> [Dec]
mkFunDecls name ty pats rhs =
  [SigD name ty, FunD name [Clause pats (NormalB rhs) []]]

-- | Lift every type-annotated lambda expression into a named function.
--
-- > ... ((\x -> ...) :: T) ...
--
-- becomes
--
-- > ... lambda ...
-- >   where lambda :: T
-- >         lambda x = ...
--
-- The generated signature and definition are prepended to the declarations
-- accumulated for the current scope.
extractTypedLambda :: Exp -> DecAccM Exp
extractTypedLambda (SigE (LamE pats lamBody) ty) = do
  name <- lift $ newName "lambda"
  modify (mkFunDecls name ty pats lamBody ++)
  return (VarE name)
extractTypedLambda e = return e

-- | Lift a type-annotated one-sided operator section into a named function.
--
-- > ... ((+3) :: Int32 -> Int32) ...
--
-- becomes
--
-- > ... infix_section ...
-- >   where infix_section :: Int32 -> Int32
-- >         infix_section a = a + 3
--
-- Only genuine sections (exactly one operand missing) are lifted.  A
-- type-annotated complete infix application such as @(x + y) :: Int32@ is
-- left untouched, because lifting it would yield a one-argument function
-- with a non-function signature.  The new argument is plugged into the
-- outermost hole only; sections nested inside the present operand are not
-- modified.
extractInfixOpSection :: Exp -> DecAccM Exp
extractInfixOpSection (SigE sec ty)
  | isSection sec = do
      name    <- lift $ newName "infix_section"
      argname <- lift $ newName "a"
      let rhs = fillSection (VarE argname) sec
      modify (mkFunDecls name ty [VarP argname] rhs ++)
      return (VarE name)
extractInfixOpSection e = return e

-- | True for a one-sided operator section such as @(+3)@ or @(3+)@.
isSection :: Exp -> Bool
isSection (InfixE Nothing  _ (Just _)) = True
isSection (InfixE (Just _) _ Nothing)  = True
isSection _                            = False

-- | Plug an expression into the missing operand of a one-sided section.
-- Any other expression is returned unchanged.
fillSection :: Exp -> Exp -> Exp
fillSection arg (InfixE Nothing op r) = InfixE (Just arg) op r
fillSection arg (InfixE l op Nothing) = InfixE l op (Just arg)
fillSection _   e                     = e

-- | Vector higher-order functions handled by @specializeHof@.
hofNames :: [TH.Name]
hofNames = [ 'V.foldr, 'V.foldl, 'V.map, 'V.zipWith, 'V.zipWith3 ]

-- | Is the name one of @hofNames@?
isHigherOrderFunctionName :: TH.Name -> Bool
isHigherOrderFunctionName = (`elem` hofNames)

-- | Specialize a type-annotated application of a vector higher-order
-- function (see @hofNames@) to a named function and one type-annotated
-- argument.
--
-- > (V.map f (xs :: T)) :: R
--
-- becomes
--
-- > (map_f (xs :: T)) :: R
-- >   where map_f :: T -> R
-- >         map_f v = V.map f v
--
-- The generated signature and definition are appended to the declarations
-- accumulated for the current scope.  Both type signatures are kept here;
-- @deleteSigE@ removes them afterwards.
specializeHof :: Exp -> DecAccM Exp
specializeHof
  (SigE (AppE hofapp@(AppE (VarE hofName) (VarE argFunName))
              arg1@(SigE _ argtype))
        rettype)
  | isHigherOrderFunctionName hofName = do
      name    <- lift . newName $ nameBase hofName ++ "_" ++ nameBase argFunName
      argname <- lift $ newName "v"
      let fnType = AppT (AppT ArrowT argtype) rettype
          rhs    = AppE hofapp (VarE argname)
      modify (++ mkFunDecls name fnType [VarP argname] rhs)
      return (SigE (AppE (VarE name) arg1) rettype)
specializeHof e = return e

-- | Drop an expression type signature, keeping the annotated expression.
-- Signatures are not recognized during translation, but the earlier steps
-- rely on them, so this step must run last.
deleteSigE :: Exp -> DecAccM Exp
deleteSigE (SigE e _) = return e
deleteSigE e          = return e

-- | Transformation monad: 'Q' plus the helper declarations generated so far
-- for the scope currently being traversed.
type DecAccM t = StateT [Dec] Q t

-- | Stateful fold over a list of declarations.
--
-- The generic transformation is applied everywhere in the right-hand side
-- of every clause of every function (multi-clause functions included) and
-- of every value binding.  The helper declarations accumulated while doing
-- so are appended to the @where@ clause of that clause or binding, after
-- its existing local declarations.  Those local declarations are themselves
-- transformed recursively.  All other declarations pass through unchanged.
--
-- Because helpers are lifted into the @where@ clause, they can refer to the
-- clause arguments and to @where@-bound names, but not to variables bound
-- inside the right-hand side itself (lambda, @let@ or @case@ binders).
foldAccDecs :: GenericM DecAccM -> [Dec] -> Q [Dec]
foldAccDecs transform = mapM apply
  where
    apply :: Dec -> Q Dec
    apply (FunD name clauses) = fmap (FunD name) (mapM applyClause clauses)
    apply (ValD pat body decls) = do
      (body', decls') <- transformScope body decls
      return (ValD pat body' decls')
    apply d = return d

    -- Each clause has its own where clause, so each collects its own helpers.
    applyClause :: Clause -> Q Clause
    applyClause (Clause pats body decls) = do
      (body', decls') <- transformScope body decls
      return (Clause pats body' decls')

    -- Transform a right-hand side, then its existing local declarations.
    -- Return the new right-hand side and the extended local declarations.
    transformScope :: Body -> [Dec] -> Q (Body, [Dec])
    transformScope body decls = do
      (newBody, newDecls) <- runStateT (everywhereM transform body) []
      transfDecls <- foldAccDecs transform decls
      return (newBody, transfDecls ++ newDecls)