module Data.Array.Parallel.Lifted.TH.Repr (
scalarInstances, tupleInstances,
voidPRInstance, unitPRInstance, wrapPRInstance
) where
import qualified Data.Array.Parallel.Unlifted as U
import Data.Array.Parallel.Lifted.PArray
import Data.Array.Parallel.Base.DTrace (traceFn)
import Language.Haskell.TH
import Data.List (intercalate)
-- | Name of the variable introduced by a type-variable binder.  A kind
--   annotation, when present, is ignored.
tyBndrVar :: TyVarBndr -> Name
tyBndrVar bndr = case bndr of
  PlainTV name    -> name
  KindedTV name _ -> name
-- | Apply a type to a list of argument types, left to right:
--   @mkAppTs t [a, b]@ is @AppT (AppT t a) b@.
mkAppTs :: Type -> [Type] -> Type
mkAppTs hd []           = hd
mkAppTs hd (arg : rest) = mkAppTs (AppT hd arg) rest
-- | Turn each name into a type-variable type.
varTs :: [Name] -> [TypeQ]
varTs names = [varT n | n <- names]
-- | Monadic counterpart of 'mkAppTs': apply a type to a list of argument
--   types, left to right.
appTs :: TypeQ -> [TypeQ] -> TypeQ
appTs hd args = foldl (\acc arg -> acc `appT` arg) hd args
-- | Turn each name into a variable expression.
varEs :: [Name] -> [ExpQ]
varEs names = [varE n | n <- names]
-- | Apply an expression to a list of argument expressions, left to right.
appEs :: ExpQ -> [ExpQ] -> ExpQ
appEs fun args = foldl (\acc arg -> acc `appE` arg) fun args
-- | A case alternative with an unguarded right-hand side and no local
--   declarations.
normalMatch :: PatQ -> ExpQ -> MatchQ
normalMatch pat xx = match pat rhs noDecls
  where
    rhs     = normalB xx
    noDecls = []
-- | Turn each name into a variable pattern.
varPs :: [Name] -> [PatQ]
varPs names = [varP n | n <- names]
-- | A plain positional constructor whose fields all carry the
--   'notStrict' annotation.
vanillaC :: Name -> [TypeQ] -> ConQ
vanillaC con tys = normalC con [strictType notStrict ty | ty <- tys]
-- | A function definition consisting of a single clause with an unguarded
--   body and no local declarations.
simpleFunD :: Name -> [PatQ] -> ExpQ -> DecQ
simpleFunD name pats xx = funD name [onlyClause]
  where
    onlyClause = clause pats (normalB xx) []
-- | An inline pragma declaration for the given name, built from
--   @inlineSpecNoPhase True False@ (no phase control).
inlineD :: Name -> DecQ
inlineD name = pragInlD name spec
  where
    spec = inlineSpecNoPhase True False
-- | A @data instance PData (tycon tyargs...)@ declaration with a single
--   vanilla constructor @con@ whose fields have the types @tys@.  The
--   instance has an empty context and no deriving clause.
instance_PData :: TypeQ -> [Name] -> Name -> [TypeQ] -> DecQ
instance_PData tycon tyargs con tys
  = dataInstD noContext ''PData [instHead] [vanillaC con tys] noDeriving
  where
    noContext  = cxt []
    instHead   = tycon `appTs` varTs tyargs
    noDeriving = []
-- | A @newtype instance PData (tycon tyargs...)@ declaration whose single
--   constructor @con@ wraps a value of type @ty@.  The instance has an
--   empty context and no deriving clause.
newtype_instance_PData :: Name -> [Name] -> Name -> TypeQ -> DecQ
newtype_instance_PData tycon tyargs con ty
  = newtypeInstD noContext ''PData [instHead] (vanillaC con [ty]) noDeriving
  where
    noContext  = cxt []
    instHead   = conT tycon `appTs` varTs tyargs
    noDeriving = []
-- | Decompose a type application into its head and its arguments (in
--   source order).  Succeeds only when the head is a type constructor, a
--   tuple constructor, the list constructor or the function arrow; any
--   other head yields 'Nothing'.
splitConAppTy :: Type -> Maybe (Type, [Type])
splitConAppTy = go []
  where
    -- Peel applications off the spine, accumulating arguments in order.
    go args ty = case ty of
      AppT fun arg -> go (arg : args) fun
      ConT _       -> Just (ty, args)
      TupleT _     -> Just (ty, args)
      ListT        -> Just (ty, args)
      ArrowT       -> Just (ty, args)
      _            -> Nothing
-- | Expand a type synonym at the head of a type, one level deep.
--
--   If the type is an application whose head constructor reifies to a
--   type synonym, the synonym's right-hand side is returned with the
--   synonym's binders replaced by the corresponding arguments.  Arguments
--   beyond the synonym's binders are re-applied to the expanded type
--   rather than being dropped.  Every other type is returned unchanged.
normaliseTy :: Type -> Q Type
normaliseTy ty
  = case splitConAppTy ty of
      Just (ConT tycon, args) -> do
        info <- reify tycon
        case info of
          TyConI (TySynD _ bndrs rhs) ->
            let vars           = map tyBndrVar bndrs
                -- Only the first @length vars@ arguments instantiate
                -- binders; the rest apply to the synonym's expansion.
                (bound, extra) = splitAt (length vars) args
            in  return $ foldl AppT (substTy (zip vars bound) rhs) extra
          _ -> return ty
      _ -> return ty
-- | Substitute types for type variables according to the given
--   association list.  The substitution descends through type
--   applications and kind signatures; variables missing from the list and
--   all other type forms are left as they are.  Calls 'error' when it
--   meets a @forall@ type.
substTy :: [(Name, Type)] -> Type -> Type
substTy env = go
  where
    go (ForallT _ _ _) = error "DPH gen: can't substitute in forall ty"
    go (VarT v)        = maybe (VarT v) id (lookup v env)
    go (AppT t u)      = AppT (go t) (go u)
    go (SigT t k)      = SigT (go t) k
    go t               = t
-- | Split a function type into its argument types and its final result
--   type.  A type that is not an arrow applied to exactly two arguments
--   is returned as the result with no arguments.
splitFunTy :: Type -> ([Type], Type)
splitFunTy ty = case splitConAppTy ty of
  Just (ArrowT, [arg, r]) -> (arg : fst rest, snd rest)
    where
      rest = splitFunTy r
  _ -> ([], ty)
-- | Classification of a method argument or result type, relative to the
--   class's type variable @a@ (assigned by @methodVals@).
data Val
  = ScalarVal  -- ^ the type variable @a@ itself
  | PDataVal   -- ^ @PData a@
  | ListVal    -- ^ @[a]@
  | UnitVal    -- ^ the unit type
  | OtherVal   -- ^ anything else

-- | Derives a variable name (as a string) from a base string.
type NameGen = String -> String

-- | An argument's classification paired with the generator used to name
--   the variables bound for it.
type ArgVal = (Val, NameGen)
-- | Generate the body of a @PR@ instance.
--
--   Reifies the @PR@ class and, for every method signature found in it,
--   produces an inline pragma followed (after all pragmas) by a definition
--   built by @mk_method@.  @mk_method@ receives the method name, the
--   classified arguments zipped with the method's entry in 'nameGens', and
--   the classified result.  Calls 'error' if a method has no entry in
--   'nameGens'.
genPR_methods :: (Name -> [ArgVal] -> Val -> DecQ) -> Q [Dec]
genPR_methods mk_method = do
    ClassI (ClassD _ _ _ _ decs) _ <- reify ''PR
    let sigs = [(name, ty) | SigD name ty <- decs]
    inls <- mapM (inlineD . mkName . nameBase . fst) sigs
    defs <- mapM gen sigs
    return (inls ++ defs)
  where
    gen (name, ty) = maybe (noGen name) (build name ty) (lookup name nameGens)

    build name ty gs = do
      (args, res) <- methodVals ty
      mk_method name (zip args gs) res

    noGen name = error $ "DPH gen: no name generator for " ++ show name
-- | Classify the argument and result types of a class method.
--
--   The method type must be a @forall@ whose first binder is the class
--   variable; that binder may be plain or kind-annotated.  The body of the
--   @forall@ is normalised with 'normaliseTy', split with 'splitFunTy',
--   and each component is classified relative to the class variable.
--   Calls 'error' for any type that is not a @forall@ with at least one
--   binder.
methodVals :: Type -> Q ([Val], Val)
methodVals (ForallT (bndr : _) _ ty)
  = do
      ty' <- normaliseTy ty
      let (args, res) = splitFunTy ty'
      return (map val args, val res)
  where
    -- The class variable, whichever binder form it was reified with.
    vv = tyBndrVar bndr

    val (VarT n) | vv == n = ScalarVal
    val (AppT (ConT c) (VarT n))
      | c == ''PData && vv == n = PDataVal
      | c == ''[]    && vv == n = ListVal
    val (AppT ListT (VarT n)) | vv == n = ListVal
    val (ConT c) | c == ''() = UnitVal
    val (TupleT 0) = UnitVal
    val _ = OtherVal
methodVals tt
  = error $ "DPH gen: methodVals: no match for " ++ show tt
-- | How a method argument is taken apart in the generated code.
data Split
  = PatSplit PatQ
    -- ^ bind the argument directly with this pattern
  | CaseSplit PatQ ExpQ PatQ
    -- ^ bind the argument with the first pattern, then wrap the body in a
    --   @case@ on the expression, matched against the second pattern

-- | What is passed on to the recursive calls for one method argument.
data Arg
  = RecArg [ExpQ] [ExpQ]
    -- ^ the second list supplies one expression per recursive call; the
    --   first list is not consulted by 'recursiveMethod'
  | OtherArg ExpQ
    -- ^ the same expression is passed to every recursive call

-- | Description of a family of methods generated by 'recursiveMethod'.
data Gen = Gen
  { recursiveCalls :: Int
    -- ^ number of recursive calls to generate
  , recursiveName  :: Name -> Name
    -- ^ maps a method name to the name of the function called recursively
  , split          :: ArgVal -> (Split, Arg)
    -- ^ how to take an argument apart
  , join           :: Val -> [ExpQ] -> ExpQ
    -- ^ how to combine the recursive calls into a result
  , typeName       :: String
    -- ^ type description passed to the tracing function
  }
-- | Generate a method definition that splits its arguments, calls a
--   function once per component and joins the results.
--
--   The generated function is named after the method's base name.  Its
--   body is @traceFn@ applied to the method name and the generator's
--   'typeName', applied in turn to the joined recursive calls; the latter
--   are wrapped in one @case@ per 'CaseSplit' argument.  Arguments
--   classified as 'OtherVal' are bound to a variable and passed unchanged
--   to every recursive call.
recursiveMethod :: Gen -> Name -> [ArgVal] -> Val -> DecQ
recursiveMethod gen name avs res
  = simpleFunD (mkName base) (map outerPat splits) (tracer `appE` body)
  where
    base   = nameBase name
    tracer = varE 'traceFn `appEs` [stringE base, stringE (typeName gen)]

    (splits, args) = unzip (map split_arg avs)

    -- One argument list per recursive call, then the joined result.
    calls = recurse (recursiveCalls gen) (trans (map expand args))
    body  = foldr mk_case (join gen res calls) splits

    outerPat (PatSplit p)      = p
    outerPat (CaseSplit p _ _) = p

    split_arg av@(val, g) = case val of
      OtherVal -> let v = mkName (g "")
                  in  (PatSplit (varP v), OtherArg (varE v))
      _        -> split gen av

    mk_case sp xx = case sp of
      PatSplit _             -> xx
      CaseSplit _ scrut pat' -> caseE scrut [normalMatch pat' xx]

    expand (RecArg _ es) = es
    expand (OtherArg e)  = repeat e

    -- Transpose, truncating to the shortest row ('OtherArg' rows are
    -- infinite).
    trans rows = case rows of
      []        -> []
      [xs]      -> map (: []) xs
      xs : rest -> zipWith (:) xs (trans rest)

    recurse n argss
      | n == 0     = []
      | null argss = replicate n callee
      | otherwise  = map (appEs callee) (take n argss)

    callee = varE (recursiveName gen name)
-- | For each @PR@ method, one name generator per argument, in argument
--   order.  A generator either ignores the base name and yields a fixed
--   name, returns the base name unchanged, or appends a suffix to it.
nameGens :: [(Name, [[Char] -> [Char]])]
nameGens =
  [ ('emptyPR,      [])
  , ('replicatePR,  [fixed "n#", base])
  , ('replicatelPR, [fixed "segd", base])
  , ('repeatPR,     [fixed "n#", fixed "len#", base])
  , ('indexPR,      [base, fixed "i#"])
  , ('extractPR,    [base, fixed "i#", fixed "n#"])
  , ('bpermutePR,   [base, fixed "n#", fixed "is"])
  , ('appPR,        [suffixed "1", suffixed "2"])
  , ('applPR,       [fixed "segd", fixed "is", suffixed "1", fixed "js", suffixed "2"])
  , ('packByTagPR,  [base, fixed "n#", fixed "tags", fixed "t#"])
  , ('combine2PR,   [fixed "n#", fixed "sel", suffixed "1", suffixed "2"])
  , ('updatePR,     [suffixed "1", fixed "is", suffixed "2"])
  , ('fromListPR,   [fixed "n#", base])
  , ('nfPR,         [base])
  ]
  where
    fixed s _      = s
    base s         = s
    suffixed sfx s = s ++ sfx
-- | For the given scalar types, generate all @PData@ instances, then all
--   @Scalar@ instances, then all @PR@ instances.
scalarInstances :: [Name] -> Q [Dec]
scalarInstances tys = do
  groups <- mapM (\mk -> mapM mk tys)
              [ instance_PData_scalar
              , instance_Scalar_scalar
              , instance_PR_scalar
              ]
  return (concat groups)
-- | Constructor name for a scalar type's @PData@ instance: the type's
--   base name prefixed with @P@.
pdataScalarCon :: Name -> Name
pdataScalarCon = mkName . ('P' :) . nameBase
-- | @newtype instance PData T = PT (U.Array T)@ for a scalar type @T@.
instance_PData_scalar :: Name -> DecQ
instance_PData_scalar tycon = newtype_instance_PData tycon [] pcon payload
  where
    pcon    = pdataScalarCon tycon
    payload = conT ''U.Array `appT` conT tycon
-- | @instance Scalar T@ for a scalar type @T@.  The instance carries
--   inline pragmas for both methods, followed by @fromScalarPData@, which
--   unwraps the @PData@ constructor, and @toScalarPData@, which is that
--   constructor.
instance_Scalar_scalar :: Name -> DecQ
instance_Scalar_scalar ty
  = instanceD (cxt []) (conT ''Scalar `appT` conT ty) (pragmas ++ bodies)
  where
    pcon     = pdataScalarCon ty
    xs       = mkName "xs"
    fromName = mkName "fromScalarPData"
    toName   = mkName "toScalarPData"

    pragmas = map inlineD [fromName, toName]
    bodies  = [ simpleFunD fromName [conP pcon [varP xs]] (varE xs)
              , simpleFunD toName [] (conE pcon)
              ]
-- | @instance PR T@ for a scalar type @T@, with methods generated by
--   'scalarMethod'.
instance_PR_scalar :: Name -> DecQ
instance_PR_scalar ty =
  genPR_methods (scalarMethod ty) >>= \methods ->
    return (InstanceD [] instHead methods)
  where
    instHead = ConT ''PR `AppT` ConT ty
-- | Define a method as the variable named after it with @Scalar@
--   appended, e.g. @fooPR = fooPRScalar@.  The type name, argument
--   classifications and result classification are not used.
scalarMethod :: Name -> Name -> [ArgVal] -> Val -> DecQ
scalarMethod _ meth _ _ = simpleFunD (mkName base) [] impl
  where
    base = nameBase meth
    impl = varE (mkName (base ++ "Scalar"))
-- | @instance PR ty@ whose methods are generated by 'voidMethod' from the
--   given @void@ and @pvoid@ names.
voidPRInstance :: Name -> Name -> Name -> Q [Dec]
voidPRInstance ty void pvoid =
  genPR_methods (voidMethod void pvoid) >>= \methods ->
    return [InstanceD [] instHead methods]
  where
    instHead = ConT ''PR `AppT` ConT ty
-- | Define a method that ignores all its arguments and returns @void@ for
--   a scalar result, @pvoid@ for a @PData@ result, or @()@ for a unit
--   result.  Calls 'error' for any other result classification.
voidMethod :: Name -> Name -> Name -> [ArgVal] -> Val -> DecQ
voidMethod void pvoid meth avs res
  = simpleFunD (mkName $ nameBase meth) [wildP | _ <- avs] rhs
  where
    rhs = case res of
      ScalarVal -> varE void
      PDataVal  -> varE pvoid
      UnitVal   -> conE '()
      _         -> error "DPH gen: voidMethod: no match"
-- | @instance PR ()@ whose methods are generated by 'unitMethod' from the
--   given @PData ()@ constructor name.
unitPRInstance :: Name -> Q [Dec]
unitPRInstance punit =
  genPR_methods (unitMethod punit) >>= \methods ->
    return [InstanceD [] instHead methods]
  where
    instHead = ConT ''PR `AppT` ConT ''()
-- | Define a @PR ()@ method.
--
--   Scalar arguments are matched against @()@, @PData@ arguments against
--   the @punit@ constructor, and other arguments against a wildcard.  List
--   arguments are bound to a variable and wrapped around the result as
--   @foldr seq result xs@.  The result itself is @()@ for scalar and unit
--   results and the @punit@ constructor for @PData@ results.  Calls
--   'error' for any other argument or result classification.
unitMethod :: Name -> Name -> [ArgVal] -> Val -> DecQ
unitMethod punit meth avs res
  = simpleFunD (mkName $ nameBase meth) (map fst bound)
  $ foldr force rhs (map snd bound)
  where
    bound = [bind v g | (v, g) <- avs]

    bind val g = case val of
      ScalarVal -> (conP '() [], Nothing)
      PDataVal  -> (conP punit [], Nothing)
      ListVal   ->
        let xs = mkName (g "xs")
        in  (varP xs, Just $ \e -> varE 'foldr `appEs` [varE 'seq, e, varE xs])
      OtherVal  -> (wildP, Nothing)
      _         -> error "DPH gen: unitMethod/mkpat: no match"

    rhs = case res of
      ScalarVal -> conE '()
      PDataVal  -> conE punit
      UnitVal   -> conE '()
      _         -> error "DPH gen: unitMethod/result: no match"

    -- Apply an argument's wrapper, if it has one, to the expression.
    force wrapper e = maybe e (\f -> f e) wrapper
-- | @instance PA a => PR (ty a)@ whose methods are generated by
--   'recursiveMethod' with 'wrapGen'.
wrapPRInstance :: Name -> Name -> Name -> Name -> Q [Dec]
wrapPRInstance ty wrap unwrap pwrap =
  genPR_methods (recursiveMethod (wrapGen wrap unwrap pwrap)) >>= \methods ->
    return [InstanceD context instHead methods]
  where
    a        = VarT (mkName "a")
    context  = [ClassP ''PA [a]]
    instHead = ConT ''PR `AppT` (ConT ty `AppT` a)
-- | Generator for the wrapper type's @PR@ methods.
--
--   Each method makes a single call to the function whose name is the
--   method's base name with its last character replaced by @D@.  Scalar
--   arguments are unpacked with the @wrap@ constructor, @PData@ arguments
--   with the @pwrap@ constructor, and list arguments are passed on as
--   @map unwrap xs@.  Scalar and @PData@ results are re-wrapped with the
--   corresponding constructor; unit results are returned as they are.
wrapGen :: Name -> Name -> Name -> Gen
wrapGen wrap unwrap pwrap
  = Gen
      { recursiveCalls = 1
      , recursiveName  = calleeName
      , split          = splitArg
      , join           = joinResult
      , typeName       = "Wrap a"
      }
  where
    calleeName meth = mkName (init (nameBase meth) ++ "D")

    splitArg (val, gen) = case val of
      ScalarVal ->
        let x = mkName (gen "x")
        in  (PatSplit (conP wrap [varP x]), RecArg [] [varE x])
      PDataVal ->
        let xs = mkName (gen "xs")
        in  (PatSplit (conP pwrap [varP xs]), RecArg [] [varE xs])
      ListVal ->
        let xs = mkName (gen "xs")
        in  ( PatSplit (varP xs)
            , RecArg [] [varE 'map `appEs` [varE unwrap, varE xs]]
            )
      _ -> error "DPH gen: split': no match"

    joinResult ScalarVal [x]  = conE wrap `appE` x
    joinResult PDataVal  [xs] = conE pwrap `appE` xs
    joinResult UnitVal   [x]  = x
    joinResult _         _    = error "DPH gen: wrapGen: no match"
-- | For the given tuple arities, generate all @PData@ instances followed
--   by all @PR@ instances.
tupleInstances :: [Int] -> Q [Dec]
tupleInstances ns = do
  groups <- mapM (\mk -> mapM mk ns) [instance_PData_tup, instance_PR_tup]
  return (concat groups)
-- | Constructor name for the @PData@ instance of tuples of the given
--   arity: @P_@ followed by the arity.
pdataTupCon :: Int -> Name
pdataTupCon = mkName . ("P_" ++) . show
-- | @data instance PData (a, b, ...) = P_n (PData a) (PData b) ...@ for
--   tuples of the given arity, with type variables named @a@, @b@, ...
instance_PData_tup :: Int -> DecQ
instance_PData_tup arity
  = instance_PData (tupleT arity) vars (pdataTupCon arity) fields
  where
    vars   = map (mkName . (: [])) (take arity ['a' ..])
    fields = map (\v -> conT ''PData `appT` varT v) vars
-- | @instance (PR a, PR b, ...) => PR (a, b, ...)@ for tuples of the
--   given arity, with methods generated by 'recursiveMethod' with
--   'tupGen'.
instance_PR_tup :: Int -> DecQ
instance_PR_tup arity =
  genPR_methods (recursiveMethod (tupGen arity)) >>= \methods ->
    return (InstanceD context instHead methods)
  where
    tys      = map (VarT . mkName . (: [])) (take arity ['a' ..])
    context  = map (\ty -> ClassP ''PR [ty]) tys
    instHead = ConT ''PR `AppT` (TupleT arity `mkAppTs` tys)
-- | Generator for the @PR@ methods of tuples of the given arity.
--
--   Each method makes one call per tuple component, to the method of the
--   same name.  Scalar arguments are matched against a tuple pattern,
--   @PData@ arguments against the @P_n@ constructor, and list arguments
--   are bound to a variable and then unzipped (with @unzip@ for arity 2,
--   @unzipN@ otherwise) in a @case@.  Scalar results are tupled, @PData@
--   results are wrapped in @P_n@, and unit results are chained with
--   @seq@.
tupGen :: Int -> Gen
tupGen arity
  = Gen
      { recursiveCalls = arity
      , recursiveName  = id
      , split          = splitArg
      , join           = joinResults
      , typeName       = "(" ++ intercalate "," elemNames ++ ")"
      }
  where
    -- Base names for component variables: a, b, ... and as, bs, ...
    elemNames = [[c] | c <- take arity ['a' ..]]
    arrNames  = [c : "s" | c <- take arity ['a' ..]]

    pcon = pdataTupCon arity

    namesFrom gen bases = map (mkName . gen) bases

    splitArg (val, gen) = case val of
      ScalarVal ->
        let names = namesFrom gen elemNames
        in  (PatSplit (tupP $ varPs names), RecArg [] (varEs names))
      PDataVal ->
        let names = namesFrom gen arrNames
        in  (PatSplit (conP pcon $ varPs names), RecArg [] (varEs names))
      ListVal ->
        let xs    = mkName (gen "xs")
            names = namesFrom gen arrNames
        in  ( CaseSplit (varP xs)
                        (varE unzipName `appE` varE xs)
                        (tupP $ varPs names)
            , RecArg [] (varEs names)
            )
      _ -> error "DPH Gen: tupGen/split: no match"

    unzipName = mkName (if arity == 2 then "unzip" else "unzip" ++ show arity)

    joinResults val xs = case val of
      ScalarVal -> tupE xs
      PDataVal  -> conE pcon `appEs` xs
      UnitVal   -> foldl1 seqE xs
      _         -> error "DPH Gen: tupGen/join: no match"

    seqE x y = varE 'seq `appEs` [x, y]