{-# LANGUAGE TemplateHaskell, PatternGuards, CPP #-}

-- | Template Haskell machinery that instantiates generic declarations,
-- written against the interface in "NoSlow.Backend.Interface", for a
-- concrete backend described by a 'Spec'.
--
-- References to interface names are rewritten into references qualified
-- with the backend module ('specModule'), type variables listed in
-- 'specVars' are substituted, and declarations that would refer to an
-- 'Unsupported' backend operation are replaced by 'Unsupported' stubs.
module NoSlow.Backend.TH (
  Spec(..), specialise, calls
) where

import NoSlow.Util.Base ( named, Unsupported(..) )
import qualified NoSlow.Backend.Interface as I

import Language.Haskell.TH

import Control.Monad ( liftM, liftM2, liftM3 )
import Data.Maybe    ( isJust )
import qualified Data.Map as M

-- | Name of the module in which 'I.Vector' is defined.  Variables from this
-- module are the ones 'specialiseExp' redirects to the backend module.
interface_module :: String
interface_module = case nameModule ''I.Vector of
                     Just s  -> s
                     Nothing -> error "interface_module: I.Vector has no module"

-- | A class constraint: the class name and its argument types.
type Context = (Name, [Type])

-- | Describes how generic declarations are specialised for one backend.
data Spec = Spec
  { specModule  :: String
    -- ^ Module whose identically named variables replace every reference
    -- to a variable of the interface module.
  , specContext :: Name -> Name -> [Context]
    -- ^ Given the (renamed) type variables @v@ and @a@ of an
    -- @I.Vector v a@ constraint, the constraints to use in its place.
  , specVars    :: [(String, Type)]
    -- ^ Concrete types substituted for type variables, keyed by the base
    -- name of the variable (e.g. @"v"@ and @"a"@).
  }

-- | Specialisation monad: read-only access to the 'Spec' plus a state that
-- accumulates every backend name the rewritten code refers to.
newtype SM a = SM { runSM :: Spec -> [Name] -> (a, [Name]) }

instance Monad SM where
  return x = SM $ \_ ns -> (x, ns)
  SM p >>= q = SM $ \s ns -> case p s ns of
                               (x, ns') -> runSM (q x) s ns'

instance Functor SM where
  fmap = liftM

-- | The specialisation currently in effect.
getSpec :: SM Spec
getSpec = SM $ \s ns -> (s, ns)

-- | Record that the rewritten code refers to the given name.
reference :: Name -> SM ()
reference name = SM $ \_ ns -> ((), name : ns)

-- | Specialise a group of declarations.
--
-- Each declaration is rewritten with 'specialiseTopDec'.  If a declaration
-- refers to an unsupported backend operation, every declaration of that
-- name is dropped and replaced by the stub
-- @name :: Unsupported; name = Unsupported@.
--
-- The result also starts with a type synonym @Vector_Type@: the type bound
-- to @v@ in 'specVars' applied to the type bound to @a@ (or to a synonym
-- parameter @a@ when 'specVars' does not bind it).  On GHC > 6.10 a
-- NOINLINE pragma is added for every surviving type signature.
specialise :: Spec -> Q [Dec] -> Q [Dec]
specialise spec decsq
  = do
      decs <- decsq
      scs  <- mapM (specialiseTopDec spec) decs
      let bad_names = [name | Right name <- scs]
          good_decs = [dec | Left dec <- scs
                           , decName dec `notElem` bad_names]
      return $ vector_type
             : good_decs
            ++ noinline good_decs
            ++ [SigD name (ConT ''Unsupported) | name <- bad_names]
            ++ [ValD (VarP name) (NormalB (ConE 'Unsupported)) [] | name <- bad_names]
  where
    vector_type = TySynD (mkName "Vector_Type") ty_vars (ty_con `AppT` elem_ty)

    ty_con
      | Just ty <- lookup "v" (specVars spec) = ty
      | otherwise = error "specialise: specVars has no type for the variable v"

    (ty_vars, elem_ty)
      | Just ty <- lookup "a" (specVars spec) = ([], ty)
      | otherwise                             = ([binder a], VarT a)
      where
        a = mkName "a"

#if __GLASGOW_HASKELL__ > 610
    binder = PlainTV
#else
    binder = id
#endif

#if __GLASGOW_HASKELL__ > 610
    noinline ds = [PragmaD $ InlineP name $ InlineSpec False False Nothing
                    | SigD name _ <- ds]
#else
    noinline _ = []
#endif

-- | Specialise a single top-level declaration.  Returns the rewritten
-- declaration when every name it references is supported, and otherwise
-- the name the declaration binds, so that 'specialise' can stub it out.
specialiseTopDec :: Spec -> Dec -> Q (Either Dec Name)
specialiseTopDec spec dec
  = case runSM (specialiseDec dec) spec [] of
      (dec', names) -> do
        ss <- mapM isSupported names
        return $ if and ss then Left dec' else Right (decName dec)

-- | The name bound by a type signature, function definition or simple
-- variable binding.
decName :: Dec -> Name
decName (SigD name _)          = name
decName (FunD name _)          = name
decName (ValD (VarP name) _ _) = name
decName dec
  = error $ "decName: declaration does not bind a single variable:\n" ++ pprint dec

-- | Whether a name is supported: it is unsupported exactly when it reifies
-- to a variable whose type is 'Unsupported'.
isSupported :: Name -> Q Bool
isSupported name = liftM supported (reify name)
  where
    supported (VarI _ (ConT c) _ _) | c == ''Unsupported = False
    supported _                                          = True

-- | Specialise the types and expressions inside a declaration.  Only type
-- signatures, function definitions and value bindings are handled.
specialiseDec :: Dec -> SM Dec
specialiseDec (SigD name ty)       = SigD name `liftM` specialiseTy ty
specialiseDec (FunD name clauses)  = FunD name `liftM` mapM specialiseClause clauses
specialiseDec (ValD pat body decs) = liftM2 (ValD pat) (specialiseBody body)
                                                       (mapM specialiseDec decs)
specialiseDec dec
  = error $ "specialiseDec: unsupported declaration:\n" ++ pprint dec

-- | Specialise a type according to the current 'Spec'.
specialiseTy :: Type -> SM Type
specialiseTy ty = do
  spec <- getSpec
  return $ specialiseTy' spec ty

-- | Specialise a type:
--
-- * type variables bound in 'specVars' are replaced by their types and
--   removed from the @forall@, which is dropped when no variable is left;
--
-- * every @I.Vector v a@ constraint is replaced by the constraints that
--   'specContext' returns;
--
-- * constraints that no longer mention any type variable are dropped;
--
-- * the remaining type variables are renamed to unqualified names.
specialiseTy' :: Spec -> Type -> Type
specialiseTy' spec pty
  | ForallT vars cxt ty <- pty
  = mk_forall (map rename_bndr $ filter keep_bndr vars)
              (spec_cxt cxt)
              (spec_ty ty)
  | otherwise = spec_ty pty
  where
#if __GLASGOW_HASKELL__ > 610
    rename_bndr (PlainTV v)    = PlainTV (rename v)
    rename_bndr (KindedTV v k) = KindedTV (rename v) k

    keep_bndr (PlainTV v)    = keep_var v
    keep_bndr (KindedTV v _) = keep_var v
#else
    rename_bndr = rename
    keep_bndr   = keep_var
#endif

    -- A binder survives only if 'specVars' does not fix its type.
    keep_var v = not $ isJust $ lookup (nameBase v) (specVars spec)

    rename = mkName . nameBase

    mk_forall [] _ ty = ty
    mk_forall vs c ty = ForallT vs c ty

    spec_cxt = filter var_pred . concatMap spec_pred

#if __GLASGOW_HASKELL__ > 610
    spec_pred (ClassP cls [VarT v, VarT a])
      | cls == ''I.Vector = map (uncurry ClassP)
                          $ specContext spec (rename v) (rename a)
    spec_pred (ClassP cls tys) = [ClassP cls $ map spec_ty tys]

    var_pred (ClassP _ tys) = any var_ty tys
#else
    spec_pred (AppT (AppT (ConT cls) (VarT v)) (VarT a))
      | cls == ''I.Vector = map mk_pred $ specContext spec (rename v) (rename a)
    spec_pred p = [spec_ty p]

    var_pred = var_ty

    mk_pred (cls, tys) = foldl AppT (ConT cls) tys
#endif

    spec_ty (VarT v)
      | Just ty <- lookup (nameBase v) (specVars spec) = ty
      | otherwise                                      = VarT (rename v)
    spec_ty (AppT t u) = spec_ty t `AppT` spec_ty u
    spec_ty ty         = ty

    var_ty (VarT _)   = True
    var_ty (AppT t u) = var_ty t || var_ty u
    var_ty _          = False

-- | Specialise the body and local declarations of a function clause.
specialiseClause :: Clause -> SM Clause
specialiseClause (Clause pats body decs)
  = liftM2 (Clause pats) (specialiseBody body) (mapM specialiseDec decs)

-- | Specialise a right-hand side after stripping a top-level 'named'
-- annotation (see 'removeNamed').  Guarded right-hand sides are rejected.
specialiseBody :: Body -> SM Body
specialiseBody (NormalB e)  = liftM NormalB $ specialiseExp (snd $ removeNamed e)
specialiseBody (GuardedB _) = error "specialiseBody: guards"

-- | Specialise an expression.  Each variable from the interface module is
-- replaced by the identically named variable qualified with 'specModule',
-- and the replacement is recorded with 'reference'.  Types in signatures are
-- specialised with 'specialiseTy'.  The syntactic forms listed with an
-- 'error' are rejected.  Any other expression is returned unchanged.
specialiseExp :: Exp -> SM Exp
specialiseExp (VarE v)
  | Just m <- nameModule v
  , m == interface_module
  = do
      backend <- specModule `fmap` getSpec
      let name = qualify backend v
      reference name
      return $ VarE name
specialiseExp (AppE e1 e2) = liftM2 AppE (specialiseExp e1) (specialiseExp e2)
specialiseExp (InfixE me1 e2 me3)
  = liftM3 InfixE (mspec me1) (specialiseExp e2) (mspec me3)
  where
    mspec Nothing  = return Nothing
    mspec (Just e) = liftM Just (specialiseExp e)
specialiseExp (LamE pats e) = LamE pats `liftM` specialiseExp e
specialiseExp (TupE es) = TupE `liftM` mapM specialiseExp es
specialiseExp (CondE e1 e2 e3)
  = liftM3 CondE (specialiseExp e1) (specialiseExp e2) (specialiseExp e3)
specialiseExp (LetE decs e) = liftM2 LetE (mapM specialiseDec decs) (specialiseExp e)
specialiseExp (CaseE _ _) = error "specialiseExp: case"
specialiseExp (DoE _) = error "specialiseExp: do"
specialiseExp (CompE _) = error "specialiseExp: comp"
specialiseExp (ArithSeqE _) = error "specialiseExp: seq"
specialiseExp (ListE es) = ListE `liftM` mapM specialiseExp es
specialiseExp (SigE e ty) = liftM2 SigE (specialiseExp e) (specialiseTy ty)
specialiseExp (RecConE _ _) = error "specialiseExp: rec_con"
specialiseExp (RecUpdE _ _) = error "specialiseExp: rec_upd"
specialiseExp e = return e

-- | Build the list expression @[fn "tag" M.f, ...]@.  It has one element
-- for each type signature @f :: ...@ in the declarations whose counterpart
-- @M.f@ (with @M@ the given module name) is supported.  The tag is the
-- string passed to 'named' at the top of @f@'s definition; without one,
-- it defaults to @f@'s base name.
calls :: Name -> String -> Q [Dec] -> ExpQ
calls fn modName qdecs
  = do
      decs <- qdecs
      let env = M.fromList (collectNames decs)
      cs <- mapM (call env) [name | SigD name _ <- decs]
      return $ ListE [c | Just c <- cs]
  where
    call env name = do
        ok <- isSupported name'
        return $ if ok
                   then Just (VarE fn `AppE` LitE (StringL tag) `AppE` VarE name')
                   else Nothing
      where
        name' = qualify modName name
        tag   = M.findWithDefault (nameBase name) (nameBase name) env

-- | Qualify the base name of a 'Name' with the given module name.
qualify :: String -> Name -> Name
qualify m name = mkName $ m ++ '.' : nameBase name

-- | Collect the tags that 'named' attaches at the top of the right-hand side
-- of single-clause function definitions and simple variable bindings.  The
-- result pairs each defined base name with its tag.  Variable bindings are
-- included because 'specialiseBody' strips 'named' from them too.
collectNames :: [Dec] -> [(String, String)]
collectNames = foldr collect1 []
  where
    collect1 (FunD name [Clause _ (NormalB e) _]) xs = tagged name e xs
    collect1 (ValD (VarP name) (NormalB e) _)      xs = tagged name e xs
    collect1 _                                     xs = xs

    tagged name e xs
      | (Just s, _) <- removeNamed e = (nameBase name, s) : xs
      | otherwise                    = xs

-- | Split off a top-level @named "s" e@ or @named "s" $ e@ annotation.
-- Returns the tag, if there is one, together with the expression it wraps.
removeNamed :: Exp -> (Maybe String, Exp)
removeNamed (VarE f `AppE` LitE (StringL s) `AppE` e)
  | f == 'named = (Just s, e)
removeNamed (InfixE (Just (VarE f `AppE` LitE (StringL s))) (VarE apply) (Just e))
  | f == 'named
  , apply == '($)
  = (Just s, e)
removeNamed e = (Nothing, e)