{-# LANGUAGE TemplateHaskell, PatternGuards, CPP #-}

-- | Template Haskell support for specialising benchmark code written against
-- the generic interface in "NoSlow.Backend.Interface" to one backend.
--
-- @specialise@ reifies the type synonyms @Spec_Vector@ and @Spec_Elem@ and
-- the class @Spec_Context@ (looked up by unqualified name), rewrites the
-- signatures of the given declarations accordingly and redirects every use of
-- an interface function to the function with the same base name qualified
-- with @Impl@.  @calls@ builds a list of benchmark invocations for such
-- declarations.
module NoSlow.Backend.TH ( Spec(..), specialise, calls ) where

import NoSlow.Util.Base ( named, Unsupported(..), noinline )
import NoSlow.Util.Computation ( deepSeq )
import qualified NoSlow.Backend.Interface as I

import Language.Haskell.TH

import Control.Applicative ( Applicative(..) )
import Control.Monad ( ap, liftM, liftM2, liftM3 )
import Data.List ( nub )
import Data.Maybe ( isJust )
import qualified Data.Map as M

-- | Name of the module defining the backend interface, derived from the
-- @I.Vector@ class name rather than hard-coded.
interface_module :: String
interface_module
  = case nameModule ''I.Vector of
      Just s  -> s
      Nothing -> error "NoSlow.Backend.TH.interface_module: I.Vector has no module"

-- ---------------------------------------------------------------------------
-- Type substitution

-- | A substitution mapping type variables to types.
type TySubst = [(Name, Type)]

-- | Syntax that contains types.
class TypeLike a where
  -- | Replace the free occurrences of the substituted variables.
  substTy    :: TySubst -> a -> a
  -- | Free type variables in order of occurrence (may contain duplicates).
  freeTyVars :: a -> [Name]

instance TypeLike Type where
  -- Variables bound by the forall are dropped from the substitution so that
  -- only free occurrences are replaced, consistent with freeTyVars.
  substTy s (ForallT vars cxt ty) = ForallT vars (substTy s' cxt) (substTy s' ty)
    where
      bound = map tyVarBndrName vars
      s'    = [entry | entry@(v, _) <- s, v `notElem` bound]
  substTy s (VarT var)
    | Just ty <- lookup var s = ty
    | otherwise               = VarT var
  substTy s (ty1 `AppT` ty2) = substTy s ty1 `AppT` substTy s ty2
#if __GLASGOW_HASKELL__ > 610
  substTy s (ty `SigT` k) = substTy s ty `SigT` k
#endif
  substTy _ ty = ty

  freeTyVars (ForallT vars cxt ty)
    = [v | v <- freeTyVars cxt ++ freeTyVars ty
         , v `notElem` map tyVarBndrName vars]
  freeTyVars (VarT v) = [v]
  freeTyVars (ty1 `AppT` ty2) = freeTyVars ty1 ++ freeTyVars ty2
#if __GLASGOW_HASKELL__ > 610
  freeTyVars (ty `SigT` _) = freeTyVars ty
#endif
  freeTyVars _ = []

instance TypeLike a => TypeLike [a] where
  substTy s  = map (substTy s)
  freeTyVars = concatMap freeTyVars

#if __GLASGOW_HASKELL__ > 610
instance TypeLike Pred where
  substTy s (ClassP cls tys)   = ClassP cls (substTy s tys)
  substTy s (ty1 `EqualP` ty2) = substTy s ty1 `EqualP` substTy s ty2

  freeTyVars (ClassP _ tys)     = freeTyVars tys
  freeTyVars (ty1 `EqualP` ty2) = freeTyVars ty1 ++ freeTyVars ty2
#else
type Pred = Type
#endif

-- | Split a class constraint into the class name and its arguments;
-- Nothing for any other kind of constraint.
splitClsPred :: Pred -> Maybe (Name, [Type])
#if __GLASGOW_HASKELL__ > 610
splitClsPred (ClassP cls tys) = Just (cls, tys)
splitClsPred (EqualP _ _)     = Nothing
#else
splitClsPred ty = go ty []
  where
    go (ty1 `AppT` ty2) tys = go ty1 (ty2 : tys)
    go (ConT cls)       tys = Just (cls, tys)
    go _                _   = Nothing
#endif

-- | Number of arrows at the top of a (possibly quantified) function type,
-- e.g. 2 for @forall a. C a => a -> Int -> a@.
funTyArity :: Type -> Int
funTyArity (ForallT _ _ ty) = funTyArity ty
#if __GLASGOW_HASKELL__ > 610
funTyArity (SigT ty _) = funTyArity ty
#endif
funTyArity (AppT (AppT ArrowT _) ty) = funTyArity ty + 1
funTyArity _ = 0

#if __GLASGOW_HASKELL__ <= 610
type TyVarBndr = Name
#endif

-- | The variable introduced by a binder, ignoring any kind annotation.
tyVarBndrName :: TyVarBndr -> Name
#if __GLASGOW_HASKELL__ > 610
tyVarBndrName (PlainTV name)    = name
tyVarBndrName (KindedTV name _) = name
#else
tyVarBndrName = id
#endif

-- | A binder without a kind annotation.
tyVarBndr :: Name -> TyVarBndr
#if __GLASGOW_HASKELL__ > 610
tyVarBndr = PlainTV
#else
tyVarBndr = id
#endif

-- ---------------------------------------------------------------------------
-- Specialisation

-- | A class name paired with its argument types.
type Context = (Name, [Type])

-- | What to specialise to, as reified by @specialise@.
data Spec = Spec
  { specVector  :: ([Name], Type)
    -- ^ parameters and right-hand side of the @Spec_Vector@ synonym
  , specElem    :: ([Name], Type)
    -- ^ parameters and right-hand side of the @Spec_Elem@ synonym
  , specContext :: Type -> Type -> Cxt
    -- ^ context of the @Spec_Context@ class with its two parameters
    --   replaced by the given vector and element types
  }

-- | Specialisation monad: reads the @Spec@ and accumulates the names of the
-- @Impl@-qualified backend functions that have been referenced.
newtype SM a = SM { runSM :: Spec -> [Name] -> (a, [Name]) }

instance Functor SM where
  fmap = liftM

instance Applicative SM where
  pure x = SM $ \_ ns -> (x, ns)
  (<*>)  = ap

instance Monad SM where
  return = pure
  SM p >>= q = SM $ \s ns -> case p s ns of
                               (x, ns') -> runSM (q x) s ns'

-- | The current specialisation.
getSpec :: SM Spec
getSpec = SM $ \s ns -> (s, ns)

-- | Run a computation under a different specialisation.
withSpec :: Spec -> SM a -> SM a
withSpec spec (SM p) = SM $ \_ -> p spec

-- | Record that a backend function has been referenced.
reference :: Name -> SM ()
reference name = SM $ \_ ns -> ((), name : ns)

-- | Specialise a group of declarations to the backend in scope.
--
-- A top-level declaration is kept (with its signature rewritten) when every
-- backend function it references is supported; each kept signature also
-- gets a NOINLINE pragma (GHC > 6.10 only).  A name whose declaration
-- references a backend function of type @Unsupported@ is instead defined as
-- @name :: Unsupported; name = Unsupported@.
specialise :: Q [Dec] -> Q [Dec]
specialise decsq
  = do
      TyConI (TySynD _ spec_vector_vars spec_vector_ty)
        <- reify (mkName "Spec_Vector")
      TyConI (TySynD _ spec_elem_vars spec_elem_ty)
        <- reify (mkName "Spec_Elem")
      ClassI (ClassD spec_cxt _ [cxt_v, cxt_a] _ _)
        <- reify (mkName "Spec_Context")
      let spec = Spec
                   { specVector  = (map tyVarBndrName spec_vector_vars, spec_vector_ty)
                   , specElem    = (map tyVarBndrName spec_elem_vars, spec_elem_ty)
                   , specContext = mk_context (tyVarBndrName cxt_v)
                                              (tyVarBndrName cxt_a)
                                              spec_cxt
                   }
      decs <- decsq
      scs  <- mapM (specialiseTopDec spec) decs
      let bad_names = [name | Right name <- scs]
          good_decs = [dec  | Left dec <- scs, decName dec `notElem` bad_names]
      return $ good_decs
            ++ noinline_pragmas good_decs
            ++ [SigD name (ConT ''Unsupported) | name <- bad_names]
            ++ [ValD (VarP name) (NormalB (ConE 'Unsupported)) [] | name <- bad_names]
  where
    mk_context v a cxt ty_v ty_a = substTy [(v, ty_v), (a, ty_a)] cxt

    -- Named so as not to shadow the imported noinline used by calls.
#if __GLASGOW_HASKELL__ > 610
    noinline_pragmas decs = [ PragmaD $ InlineP name $ InlineSpec False False Nothing
                            | SigD name _ <- decs ]
#else
    noinline_pragmas _ = []
#endif

-- | The name bound by a declaration accepted by @specialise@.
decName :: Dec -> Name
decName (SigD name _)          = name
decName (FunD name _)          = name
decName (ValD (VarP name) _ _) = name
decName _ = error "NoSlow.Backend.TH.decName: unsupported declaration"

-- | Specialise one top-level declaration.  Returns the rewritten declaration
-- when every backend function it references is supported, and otherwise the
-- name it binds.
specialiseTopDec :: Spec -> Dec -> Q (Either Dec Name)
specialiseTopDec spec dec
  = case runSM (specialiseDec dec) spec [] of
      (dec', names) -> do
        ss <- mapM isSupported names
        return $ if and ss then Left dec' else Right (decName dec)

-- | False when the name reifies to a variable of type @Unsupported@, True
-- for anything else.
isSupported :: Name -> Q Bool
isSupported name = liftM supported (reify name)
  where
    supported (VarI _ (ConT c) _ _) | c == ''Unsupported = False
    supported _                                        = True

-- | Specialise a signature or binding.
specialiseDec :: Dec -> SM Dec
specialiseDec (SigD name ty)       = SigD name `liftM` specialiseTy ty
specialiseDec (FunD name clauses)  = FunD name `liftM` mapM specialiseClause clauses
specialiseDec (ValD pat body decs) = liftM2 (ValD pat) (specialiseBody body)
                                                       (mapM specialiseDec decs)
specialiseDec _ = error "specialiseDec: unsupported declaration"

-- | Rewrite a type signature for the current specialisation.
specialiseTy :: Type -> SM Type
specialiseTy ty = do
  spec <- getSpec
  return $ specialiseTy' spec ty

-- | A constraint, classified by whether it is an @I.Vector v a@ constraint.
data SplitPred = VectorPred Type Type
               | OtherPred Pred

-- | Specialise a quantified type; other types are returned unchanged.
specialiseTy' :: Spec -> Type -> Type
specialiseTy' spec (ForallT vars cxt ty)
  = specialiseForall spec (map tyVarBndrName vars) cxt ty
specialiseTy' _ ty = ty

-- | Specialise @forall vars. cxt => ty@.
--
-- Variables used as the vector (resp. element) argument of an @I.Vector@
-- constraint are replaced by the right-hand side of @Spec_Vector@ (resp.
-- @Spec_Elem@), whose own parameters become fresh variables named
-- @var_param@.  Each such constraint is replaced by the @Spec_Context@
-- context.  Other variables are kept under their unqualified base name, and
-- constraints left without free type variables are dropped.
specialiseForall :: Spec -> [Name] -> Cxt -> Type -> Type
specialiseForall spec vars cxt ty
  = mk_forall (map tyVarBndr vars')
              (filter keep_pred $ concatMap (subst_split_pred subst) split_preds)
              (substTy subst ty)
  where
    mk_forall [] _  body = body
    mk_forall bs ps body = ForallT bs ps body

    split_pred p
      | Just (cls, [vect_ty, elem_ty]) <- splitClsPred p
      , cls == ''I.Vector = VectorPred vect_ty elem_ty
      | otherwise         = OtherPred p

    split_preds = map split_pred cxt

    -- A variable may occur in several I.Vector constraints but must be bound
    -- and substituted only once.
    vector_vars = nub [v | VectorPred (VarT v) _ <- split_preds]
    elem_vars   = nub [a | VectorPred _ (VarT a) <- split_preds]
    other_vars  = filter (\v -> v `notElem` vector_vars && v `notElem` elem_vars) vars
    other_vars' = map rename other_vars

    (vs1, s1) = subst_for_vars vector_vars (specVector spec)
    (vs2, s2) = subst_for_vars elem_vars   (specElem spec)
    vars'     = vs1 ++ vs2 ++ other_vars'
    subst     = s1 ++ s2 ++ zipWith (\v v' -> (v, VarT v')) other_vars other_vars'

    subst_for_vars vs poly_ty = (concat new_vss, s)
      where
        (new_vss, s) = unzip [subst_for_var v poly_ty | v <- vs]

    -- An identity synonym (type S a = a) just renames the variable.
    subst_for_var v ([v'], VarT v'') | v' == v'' = ([rename v], (v, VarT (rename v)))
    subst_for_var v (poly_vars, poly_ty)
      = (new_vars, (v, substTy (zip poly_vars (map VarT new_vars)) poly_ty))
      where
        new_vars = [mkName $ nameBase v ++ '_' : nameBase var | var <- poly_vars]

    rename = mkName . nameBase

    subst_split_pred s (VectorPred vect_ty elem_ty)
      = specContext spec (substTy s vect_ty) (substTy s elem_ty)
    subst_split_pred s (OtherPred p) = [substTy s p]

    keep_pred = not . null . freeTyVars

-- | Specialise the body and local declarations of a clause.
specialiseClause :: Clause -> SM Clause
specialiseClause (Clause pats body decs)
  = liftM2 (Clause pats) (specialiseBody body) (mapM specialiseDec decs)

-- | Specialise a right-hand side.  An unguarded body first has a leading
-- @named s@ wrapper removed (see @removeNamed@).
specialiseBody :: Body -> SM Body
specialiseBody (NormalB e)   = liftM NormalB $ specialiseExp (snd $ removeNamed e)
specialiseBody (GuardedB ps) = liftM GuardedB $ mapM spec ps
  where
    spec (g, e) = liftM2 (,) (specialiseGuard g) (specialiseExp e)

-- | Specialise a boolean guard.  Pattern guards are not supported.
specialiseGuard :: Guard -> SM Guard
specialiseGuard (NormalG e) = liftM NormalG $ specialiseExp e
specialiseGuard (PatG _)    = error "specialiseGuard: pattern guard"

-- | Specialise an expression: every variable from the interface module is
-- replaced by the @Impl@-qualified name with the same base name, and that
-- name is recorded with @reference@.  Do blocks, comprehensions, arithmetic
-- sequences and record construction/update raise an error; any other
-- unlisted expression form is returned unchanged.
specialiseExp :: Exp -> SM Exp
specialiseExp (VarE v)
  | Just m <- nameModule v
  , m == interface_module
  = do
      let name = qualify "Impl" v
      reference name
      return $ VarE name
specialiseExp (AppE e1 e2) = liftM2 AppE (specialiseExp e1) (specialiseExp e2)
specialiseExp (InfixE me1 e2 me3)
  = liftM3 InfixE (mspec me1) (specialiseExp e2) (mspec me3)
  where
    mspec = maybe (return Nothing) (liftM Just . specialiseExp)
specialiseExp (LamE pats e) = LamE pats `liftM` specialiseExp e
specialiseExp (TupE es) = TupE `liftM` mapM specialiseExp es
specialiseExp (CondE e1 e2 e3)
  = liftM3 CondE (specialiseExp e1) (specialiseExp e2) (specialiseExp e3)
specialiseExp (LetE decs e) = liftM2 LetE (mapM specialiseDec decs) (specialiseExp e)
specialiseExp (CaseE e ms) = liftM2 CaseE (specialiseExp e) (mapM specialiseMatch ms)
specialiseExp (DoE _) = error "specialiseExp: do"
specialiseExp (CompE _) = error "specialiseExp: comp"
specialiseExp (ArithSeqE _) = error "specialiseExp: seq"
specialiseExp (ListE es) = ListE `liftM` mapM specialiseExp es
specialiseExp (SigE e ty) = liftM2 SigE (specialiseExp e) (specialiseTy ty)
specialiseExp (RecConE _ _) = error "specialiseExp: rec_con"
specialiseExp (RecUpdE _ _) = error "specialiseExp: rec_upd"
specialiseExp e = return e

-- | Specialise the body and local declarations of a case alternative.
specialiseMatch :: Match -> SM Match
specialiseMatch (Match pat body decs)
  = liftM2 (Match pat) (specialiseBody body) (mapM specialiseDec decs)

-- ---------------------------------------------------------------------------
-- Benchmark invocation lists

-- | @calls fn m decs@ builds a list expression with one element per
-- signature in @decs@ whose @m@-qualified counterpart is supported:
--
-- > fn s (\tag pat -> deepSeq (noinline M.name tag a b ...) ())
--
-- Here @pat@ is @()@, a single variable or a tuple of variables, depending
-- on the arity of the signature minus one, and @s@ is the string given to
-- @named@ in the definition (see @collectNames@), defaulting to the base
-- name of the binding.
calls :: Name -> String -> Q [Dec] -> ExpQ
calls fn modname qdecs
  = do
      decs <- qdecs
      let env = M.fromList (collectNames decs)
      cs <- mapM (call env) [(name, ty) | SigD name ty <- decs]
      return $ ListE [c | Just c <- cs]
  where
    call env (name, ty)
      = do
          ok <- isSupported name'
          return $ if ok
                     then Just (VarE fn `AppE` LitE (StringL tag) `AppE` invoke (VarE name'))
                     else Nothing
      where
        name' = qualify modname name
        tag   = M.findWithDefault (nameBase name) (nameBase name) env
        -- The first argument of each benchmark is the tag itself.
        arity = funTyArity ty - 1

        invoke e = LamE [VarP t, pat]
                 $ VarE 'deepSeq `AppE` applied `AppE` ConE '()
          where
            applied = foldl AppE (VarE 'noinline `AppE` e `AppE` VarE t) (map VarE vs)

        t  = mkName "tag"
        vs = take arity [mkName [c] | c <- ['a' ..]]

        pat | arity == 0 = ConP '() []
            | arity == 1 = VarP (head vs)
            | otherwise  = TupP (map VarP vs)

-- | The name with module prefix @m@ and the base name of the given name.
qualify :: String -> Name -> Name
qualify m name = mkName $ m ++ '.' : nameBase name

-- | Map the base name of each single-equation binding whose right-hand side
-- is @named s e@ (or @named s $ e@) to @s@.
collectNames :: [Dec] -> [(String, String)]
collectNames = foldr collect1 []
  where
    collect1 (FunD name [Clause _ (NormalB e) _]) xs
      | (Just s, _) <- removeNamed e = (nameBase name, s) : xs
    collect1 (ValD (VarP name) (NormalB e) _) xs
      | (Just s, _) <- removeNamed e = (nameBase name, s) : xs
    collect1 _ xs = xs

-- | Strip a leading @named s@ (applied directly or with @$@) from an
-- expression, returning @s@ when it was present.
removeNamed :: Exp -> (Maybe String, Exp)
removeNamed (VarE f `AppE` LitE (StringL s) `AppE` e)
  | f == 'named = (Just s, e)
removeNamed (InfixE (Just (VarE f `AppE` LitE (StringL s))) (VarE apply) (Just e))
  | f == 'named
  , apply == '($) = (Just s, e)
removeNamed e = (Nothing, e)