module Feldspar.Compiler.Backend.C.Plugin.HandlePrimitives
( HandlePrimitives(..)
, completeFunProcName
) where
import Data.List (find)
import Data.Maybe (fromJust)
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Backend.C.CodeGeneration (typeof, defaultMemberName)
import Feldspar.Transformation
import Feldspar.Compiler.Backend.C.Options
import Feldspar.Compiler.Error
-- | Abort with an 'InternalError' tagged with this plugin's name; used at
-- any result type wherever lowering reaches an impossible situation.
handlePrimitivesError msg = handleError "PluginArch/HandlePrimitives" InternalError msg
-- | Pre-pass that detects calls to @trace@ and, for every procedure
-- containing one, brackets the procedure body with @traceStart@ /
-- @traceEnd@ procedure calls.
data HandleTraceFunctions = HandleTraceFunctions
-- | A 'Bool' Up value starts out as "nothing found".
instance Default Bool where
    def = False

-- | Merging two 'Bool' Up values: found if either side found it.
instance Combine Bool where
    combine = (||)
-- | Transformation descriptor for 'HandleTraceFunctions'.
--
-- The only information carried is the upward 'Bool', which records whether
-- a @trace@ function call occurs below the current node.  There is no
-- downward information and no state; the From/To types are both @()@.
instance Transformation HandleTraceFunctions where
    type From HandleTraceFunctions = ()
    type To HandleTraceFunctions = ()
    type Down HandleTraceFunctions = ()
    type Up HandleTraceFunctions = Bool
    type State HandleTraceFunctions = ()
-- | Flags an expression (upwards) when it is, or contains, a call to @trace@.
instance Transformable HandleTraceFunctions Expression where
    transform t s d expr = base { up = combine isTrace (up base) }
      where
        base = defaultTransform t s d expr
        -- Only a function call literally named "trace" sets the flag here;
        -- flags from sub-expressions arrive through 'up base'.
        isTrace = case expr of
            FunctionCall fname _ _ _ _ _ -> fname == "trace"
            _                            -> False
-- | For a 'Procedure' whose body contains a @trace@ call, wrap the body
-- between @traceStart@ and @traceEnd@ procedure calls.  If the body is
-- already a 'Sequence', the calls are spliced into it; otherwise a new
-- three-element 'Sequence' is built.  Other definitions are traversed with
-- the default transformation.
instance Transformable HandleTraceFunctions Definition where
    transform t s d p@Procedure{}
        | up tr     = tr { result = instrument (result tr) }
        | otherwise = tr
      where
        tr = defaultTransform t s d p
        instrument prc = prc { procBody = bracketBlock (procBody prc) }
        bracketBlock blk = blk { blockBody = bracketProg (blockBody blk) }
        bracketProg sq@Sequence{} =
            sq { sequenceProgs = traceStart : sequenceProgs sq ++ [traceEnd] }
        bracketProg prog = Sequence [traceStart, prog, traceEnd] () ()
        traceStart = ProcedureCall "traceStart" [] () ()
        traceEnd = ProcedureCall "traceEnd" [] () ()
    transform t s d p = defaultTransform t s d p
-- | Plugin that lowers primitive function calls (array indexing, pairs,
-- tracing and the platform's primitive table) into C-level constructs.
data HandlePrimitives = HandlePrimitives
-- | Transformation descriptor for 'HandlePrimitives'.
--
-- * Down: the @defArrSize@ value from the plugin's external info, the target
--   'Platform', and an optional destination expression.  The destination is
--   used by the @pair@ and @trace@ cases of the 'Expression' instance to
--   write results into an existing location instead of a fresh variable.
-- * Up: declarations and programs generated while lowering, which the
--   enclosing statement list / block must hoist.
-- * State: a serial counter that 'makeVariable' appends to variable names
--   to keep them unique.
instance Transformation HandlePrimitives where
    type From HandlePrimitives = ()
    type To HandlePrimitives = ()
    type Down HandlePrimitives = (Int, Platform, Maybe (Expression ()))
    type Up HandlePrimitives = ([Declaration ()], [Program ()])
    type State HandlePrimitives = Int
-- | Entry point of the plugin.
--
-- With 'NoPrimitiveInstructionHandling' the procedure is returned untouched.
-- Otherwise the trace pre-pass runs first, then primitive lowering, starting
-- from serial number 0 and with no destination expression.
instance Plugin HandlePrimitives where
    type ExternalInfo HandlePrimitives = (Int, DebugOption, Platform)
    executePlugin _ (defArrSize, dbg, platform) procedure = case dbg of
        NoPrimitiveInstructionHandling -> procedure
        _ -> result (transform HandlePrimitives 0 (defArrSize, platform, Nothing) traced)
      where
        traced = result (transform HandleTraceFunctions () () procedure)
-- | Up values of 'HandlePrimitives' are merged by concatenating the
-- generated declarations and the generated programs, left operand first.
instance Combine ([Declaration ()], [Program ()]) where
    combine (xl, xi) (yl, yi) = (xl ++ yl, xi ++ yi)
-- | By default no declarations are generated.
instance Default [Declaration ()] where
    def = []
-- | By default no programs are generated.
instance Default [Program ()] where
    def = []
-- | A block absorbs the declarations generated beneath it as extra locals
-- and reports an empty Up value to its parent.  Generated programs should
-- already have been placed by the statement-level instances, so a
-- non-empty program list reaching a block is an internal error.
instance Transformable HandlePrimitives Block where
    transform t s d b =
        checked { result = addToBlock (result checked) (up checked)
                , up     = ([], [])
                }
      where
        raw = defaultTransform t s d b
        checked
            | null (snd (up raw)) = raw
            | otherwise = handlePrimitivesError "transform Block: upwards program list is not empty."
-- | Statement-level lowering.
--
-- * A @copy@ procedure call becomes a @copyArray@ call when the output has
--   an array type, and a plain 'Assign' otherwise.
-- * For a 'SeqLoop', code generated while lowering the loop condition is
--   added to the condition-calculation block.
-- * All other programs use the default traversal.
instance Transformable HandlePrimitives Program where
    transform t s d p@(ProcedureCall "copy" [o@(Out out _), i@(In inp _)] _ _) = case typeof out of
        (ArrayType _ _) -> Result (ProcedureCall "copyArray" [out', inp'] () ()) arrS' arrU'
        _ -> Result (Assign lhs rhs () ()) assS' assU'
      where
        -- Array case: transform the actual parameters, threading the state.
        (Result out' arrS arrU1) = transform t s d o
        (Result inp' arrS' arrU2) = transform t arrS d i
        arrU' = arrU1 `combine` arrU2
        -- Scalar case: transform the wrapped expressions instead.  Only the
        -- branch selected above is ever demanded.
        (Result lhs assS assU1) = transform t s d out
        (Result rhs assS' assU2) = transform t assS d inp
        assU' = assU1 `combine` assU2
    transform t s d (SeqLoop c cc p inf1 inf2) = Result (SeqLoop (result tr1) cc' (result tr3) (convert inf1) $ convert inf2) (state tr3) ([],[]) where
        -- Condition expression, condition-calculation block, loop body, in
        -- that order for state threading.
        tr1 = transform t s d c
        tr2 = transform t (state tr1) d cc
        tr3 = transform t (state tr2) d p
        -- Code produced by the condition expression must run inside the
        -- condition block, so it is appended there rather than sent upwards.
        -- (The condition block's own Up is reset by the Block instance;
        -- presumably the body is a Block too — confirm.)
        cc' = addToBlock (result tr2) (up tr1)
    transform t s d p = defaultTransform t s d p
-- | Lowers a statement list.  Programs generated by a statement are
-- inserted immediately before it; generated declarations are passed
-- upwards; the state is threaded left to right.
instance Transformable1 HandlePrimitives [] Program where
    transform1 _ s _ [] = Result1 [] s def
    transform1 t s d (prog : rest) = Result1 progs' (state1 restR) (decls, [])
      where
        headR = transform t s d prog
        restR = transform1 t (state headR) d rest
        (headDecls, headPre) = up headR
        progs' = headPre ++ result headR : result1 restR
        decls = headDecls ++ fst (up1 restR)
-- | Lowers a declaration's initialiser.
--
-- If lowering the initialiser generated no extra programs, the initialiser
-- stays inline.  Otherwise the declaration is emitted without an
-- initialiser, and the generated programs are passed upwards, followed by
-- an assignment of the lowered initialiser to the declared variable.  In
-- both cases, code produced while transforming the variable comes first in
-- the Up value.
instance Transformable HandlePrimitives Declaration where
    transform t s d@(das, pfm, _) (Declaration v i inf) =
        Result (Declaration (result tr1) i' $ convert inf) (state1 tr2) (combine (up tr1) u')
      where
        tr1 = transform t s d v
        tr2 = transform1 t (state tr1) d i
        (i', u') = case (up1 tr2, result1 tr2) of
            (u@(_, []), initExpr) -> (initExpr, u)
            ((ls, is), Just initExpr) ->
                (Nothing, (ls, is ++ [makeAssignment (das, pfm) initExpr (vToE $ result tr1)]))
            (_, Nothing) ->
                -- Programs cannot be generated for an absent initialiser.
                handlePrimitivesError "transform Declaration: programs generated for a missing initializer."
-- | Expression-level lowering of primitive function calls.
--
-- * @getFst@ / @getSnd@ applied directly to a @pair@ call are short-cut to
--   the corresponding component.
-- * @(!)@ becomes an 'ArrayElem'; @setIx@ emits an element assignment and
--   yields the array.
-- * Other @getFst@ / @getSnd@ calls become 'StructField' accesses on the
--   members @defaultMemberName ++ "1"@ / @defaultMemberName ++ "2"@.
-- * @pair@ and @trace@ write into the destination expression from the Down
--   value when one is given, and otherwise into a freshly declared variable.
-- * Any other call is looked up in the platform's primitive table by name
--   and input types; unmatched calls are left as the default traversal
--   produced them.
--
-- Generated declarations and programs go into the Up value, after those
-- produced by the arguments.  Non-'FunctionCall' expressions are traversed
-- with the destination cleared.
instance Transformable HandlePrimitives Expression where
    transform t s d@(das, pfm, me) f@(FunctionCall nameS ot origRole origInps _ _) = res
      where
        -- Projections of a literal pair are resolved without building the pair.
        res = case (nameS, origInps) of
            ("getFst", [FunctionCall "pair" _ _ [fs,sn] _ _]) -> transform t s (das, pfm, Nothing) fs
            ("getSnd", [FunctionCall "pair" _ _ [fs,sn] _ _]) -> transform t s (das, pfm, Nothing) sn
            _ -> Result e' s' $ combine (up tr) (l',p')
        -- Arguments are lowered first, with no destination.
        tr = defaultTransform t s (das, pfm, Nothing) f
        s2 = state tr
        -- (new state, new declarations, new programs, replacement expression)
        (s',l',p',e') = case (nameS, inps, me) of
            ("(!)", [arr, idx], _) -> (s2, [], [], ArrayElem arr idx () ())
            ("setIx", [arr, idx, val], _) -> (s2
                , []
                , [ makeAssignment d' val (ArrayElem arr idx () ()) ]
                , arr
                )
            ("getFst", [l], _) -> (s2, [], [], StructField l (defaultMemberName ++ "1") () ())
            ("getSnd", [l], _) -> (s2, [], [], StructField l (defaultMemberName ++ "2") () ())
            -- Destination given: assign both components straight into it.
            ("pair", [a,b], Just e) -> (s2
                , []
                , [ makeAssignment d' a (StructField e (defaultMemberName ++ "1") () ())
                  , makeAssignment d' b (StructField e (defaultMemberName ++ "2") () ())
                  ]
                , e
                )
            -- No destination: declare a fresh "stc" variable and fill it.
            ("pair", [a,b], Nothing) -> (s3
                , [ makeDeclaration stc Nothing ]
                , [ makeAssignment d' a (StructField (VarExpr stc ()) (defaultMemberName ++ "1") () ())
                  , makeAssignment d' b (StructField (VarExpr stc ()) (defaultMemberName ++ "2") () ())
                  ]
                , VarExpr stc ()
                ) where (s3, stc) = makeVariable ot "stc" s2
            -- Store the traced value in the destination, then call the
            -- platform's "trace" procedure with (value, label).
            ("trace", [lab, orig], Just e) -> (s2
                , []
                , [ makeAssignment d' orig e
                  , makeProcedureCall pfm (Proc "trace" firstInFP) [e, lab] []
                  ]
                , e
                )
            -- Same as above, but via a fresh "trc" variable.
            ("trace", [lab, orig], Nothing) -> (s3
                , [ makeDeclaration trc Nothing ]
                , [ makeAssignment d' orig (VarExpr trc ())
                  , makeProcedureCall pfm (Proc "trace" firstInFP) [VarExpr trc (), lab] []
                  ]
                , VarExpr trc ()
                ) where (s3, trc) = makeVariable ot "trc" s2
            -- Platform primitives: a 'Right' entry is a program template
            -- instantiated with the inputs and output type; a 'Left' entry
            -- is a direct C primitive description.
            _ -> case (find matchPrimitive $ primitives pfm) of
                Just (fd,Right tp) -> transformPrgDesc d' s2 (tp fd inps ot)
                Just (fd,Left cd) -> transformCPrimDesc d' s2 cd inps ot
                Nothing -> (s2, [], [], result tr)
        -- A table entry matches on function name and input types.
        matchPrimitive (fd,_) = (fName fd == nameS) && (matchTypes' (inputs fd) inps)
        -- The already-lowered arguments.
        inps = funCallParams $ result tr
        d' = (das, pfm)
    transform t s d@(das, pfm, _) p = defaultTransform t s (das, pfm, Nothing) p
-- | Append generated declarations to a block's locals and generated
-- programs to the end of its body.  A body that is not already a
-- 'Sequence' is wrapped in one, even when there are no programs to add.
addToBlock :: Block () -> ([Declaration ()], [Program ()]) -> Block ()
addToBlock blk (newDecls, newProgs) =
    blk { locals    = locals blk ++ newDecls
        , blockBody = appendTo (blockBody blk)
        }
  where
    appendTo (Sequence progs () ()) = Sequence (progs ++ newProgs) () ()
    appendTo prog                   = Sequence (prog : newProgs) () ()
-- | Lower a call through a C primitive description.
--
-- Unary/binary operators, functions, casts and plain assignments become an
-- expression directly.  Anything else (e.g. a procedure, or an operator
-- with the wrong arity) is turned into a procedure call that writes into a
-- fresh "vhp" variable, which becomes the result expression.
--
-- Returns (new serial, new declarations, new programs, result expression).
transformCPrimDesc :: (Int,Platform) -> Int -> CPrimDesc -> [Expression ()] -> Type -> (Int, [Declaration ()], [Program ()], Expression ())
transformCPrimDesc (_, pfm) serial cd inps ot =
    case (cd, inps) of
        (Op1 op, [_])    -> plain (FunctionCall op ot PrefixOp inps () ())
        (Op2 op, [_, _]) -> plain (FunctionCall op ot InfixOp inps () ())
        (Fun _ _, _)     -> plain (FunctionCall funName ot SimpleFun inps () ())
        (Cas, [x])       -> plain (Cast ot x () ())
        (Assig, [x])     -> plain x
        _                -> ( serial'
                            , [makeDeclaration ov Nothing]
                            , [makeProcedureCall pfm cd inps [vToE ov]]
                            , vToE ov
                            )
  where
    -- No declarations or programs needed; the serial is unchanged.
    plain e = (serial, [], [], e)
    funName = completeFunProcName pfm cd (map typeof inps) [ot]
    (serial', ov) = makeVariable ot "vhp" serial
-- | Instantiate a program template ('PrgDesc') for a primitive call.
--
-- Each 'Crt' entry creates a fresh variable, optionally initialised from
-- previously created ones.  Each line is an assignment ('Asg') or a
-- procedure call ('Prc') over those variables.  The final 'Rgt' is the
-- result expression.
--
-- Returns (new serial, declarations, programs, result expression).
-- Declaring the same label twice is an internal error.
transformPrgDesc :: (Int,Platform) -> Int -> PrgDesc -> (Int, [Declaration ()], [Program ()], Expression ())
transformPrgDesc down@(_,pfm) serial (PrgDesc crts lns rgt)
    = (serial', map (\(_,_,v,me) -> makeDeclaration v me) vars, ins, transformRgt vars rgt)
  where
    (serial', vars') = foldl transformCrtFold (serial, []) (map searchDuplicateLabels crts)
    (vars, ins) = foldl transformLineFold (vars', []) lns
    -- Labels must be unique.  Variables are resolved by label with 'find',
    -- so a second declaration of the same label (even with a different
    -- type or initialiser) would otherwise be silently unreachable.
    crtLabel (Crt _ v _) = v
    searchDuplicateLabels c
        | length (filter (== crtLabel c) (map crtLabel crts)) > 1 =
            handlePrimitivesError $ "multiple declaration: " ++ show c
        | otherwise = c
    -- Create the C variable for a label; the Bool records whether it has a
    -- value (initialised here or assigned by a later line).
    transformCrtFold (n, vs) (Crt t v@(Var s) (Just r)) = (n', vs ++ [(v, True, mv, Just $ transformRgt vs r)])
      where
        (n', mv) = makeVariable t s n
    transformCrtFold (n, vs) (Crt t v@(Var s) Nothing) = (n', vs ++ [(v, False, mv, Nothing)])
      where
        (n', mv) = makeVariable t s n
    transformLineFold (vs, is) ln = case ln of
        Asg v r -> (updateVars [v], is ++ [makeAssignment down (transformRgt' r) (transformVarL' v)])
        Prc cd inps outs -> (updateVars outs, is ++ [makeProcedureCall pfm cd (map transformRgt' inps) (map transformVarL' outs)])
      where
        -- Mark the written labels as having a value.
        updateVars xs = map (\y@(v',_,vv,mr) -> if elem v' xs then (v',True,vv,mr) else y) vs
        transformRgt' = transformRgt vs
        transformVarL' = vToE . resolveVar vs
    transformRgt _ (Exp e) = e
    transformRgt vs (Fnc cd rgts ot) = makeFunctionCallOrCast down cd (map (transformRgt vs) rgts) ot
    transformRgt vs (VarR v) = vToE $ resolveVar vs v
    -- Map a template label to the C variable created for it.
    resolveVar vs v@(Var _) = case find (\(v',_,_,_) -> v' == v) vs of
        Just (_,_,vv,_) -> vv
        Nothing -> handlePrimitivesError $ "Not declared: " ++ show v
-- | Lower a C primitive that must be expressible as a single expression
-- (operator, function call, cast or plain assignment).  Anything that
-- would need auxiliary declarations or programs is an internal error.
makeFunctionCallOrCast :: (Int,Platform) -> CPrimDesc -> [Expression ()] -> Type -> Expression ()
makeFunctionCallOrCast down cd inps ot
    -- The serial number is irrelevant here: the only outcome that would
    -- allocate a variable (non-empty declarations) is rejected below.
    = case transformCPrimDesc down 1 cd inps ot of
        (_, [], [], ed) -> ed
        _ -> handlePrimitivesError $
                "it's not a FunctionCall: " ++ show cd
                ++ ", number of inputs: " ++ show (length inps)
-- | Create a value variable named @prefix ++ show serial@ and return it
-- together with the next serial number.
makeVariable :: Type -> String -> Int -> (Int, Variable ())
makeVariable ty prefix serial = (serial + 1, Variable varName ty Value ())
  where
    varName = prefix ++ show serial
-- | Declare a variable with an optional initial value.
makeDeclaration :: Variable () -> Maybe (Expression ()) -> Declaration ()
makeDeclaration var initVal = Declaration var initVal ()
-- | Assign @inp@ to @out@.
--
-- A self-assignment (the same variable, or the same array element with an
-- equal index) produces 'Empty'.  Array-typed values are copied with a
-- @copyArray@ procedure call (output first, then input).  Everything else
-- becomes a plain 'Assign'.  The first argument is accepted for interface
-- compatibility but is not used.
makeAssignment :: (Int,Platform) -> Expression () -> Expression () -> Program ()
makeAssignment (_, _) inp out
    | sameVariable inp out = Empty () ()
    | otherwise = case typeof inp of
        ArrayType _ _ -> ProcedureCall "copyArray" [eToOut out, eToIn inp] () ()
        _             -> Assign out inp () ()
  where
    sameVariable (VarExpr v1 _) (VarExpr v2 _) = v1 == v2
    sameVariable (ArrayElem a1 i1 _ _) (ArrayElem a2 i2 _ _) = a1 == a2 && i1 == i2
    sameVariable _ _ = False
-- | Emit a call to a C procedure primitive.  Inputs are passed first as
-- 'In' parameters, then outputs as 'Out' parameters.  The procedure name is
-- completed from the platform and the argument types (see
-- 'completeFunProcName').  Any descriptor other than 'Proc' is an internal
-- error.
makeProcedureCall :: Platform -> CPrimDesc -> [Expression ()] -> [Expression ()] -> Program ()
makeProcedureCall pfm cd@(Proc _ _) inps outs =
    ProcedureCall procName (map eToIn inps ++ map eToOut outs) () ()
  where
    procName = completeFunProcName pfm cd (map typeof inps) (map typeof outs)
makeProcedureCall _ cd _ _ =
    handlePrimitivesError $ "Wrong C primitive description in makeProcedureCall:\n" ++ show cd
-- | Do the expressions match the type descriptors pairwise?  The lists must
-- also have the same length.
matchTypes' :: [TypeDesc] -> [Expression ()] -> Bool
matchTypes' descs exprs = case (descs, exprs) of
    ([], [])         -> True
    (d : ds, e : es) -> machTypes d (typeof e) && matchTypes' ds es
    _                -> False
-- | Build the C name of a function/procedure primitive.
--
-- When the descriptor's name-postfix setting is 'noneFP', the bare C name
-- is used.  Otherwise the name gets @"_fun"@ for functions, followed by
-- @_<type>@ for the first 'useInputs' input types and the first
-- 'useOutputs' output types.
completeFunProcName :: Platform -> CPrimDesc -> [Type] -> [Type] -> String
completeFunProcName pfm desc its ots
    | pf == noneFP = cName desc
    | otherwise    = concat [cName desc, kindSuffix, typeSuffix]
  where
    pf = funPf desc
    kindSuffix = case desc of
        Fun _ _  -> "_fun"
        Proc _ _ -> ""
    typeSuffix = concatMap (\ty -> '_' : toFunName pfm ty) namedTypes
    namedTypes = take (useInputs pf) its ++ take (useOutputs pf) ots
-- | Name fragment for a type, as used in primitive names.  A nested array
-- type is named by its inner array; an array of a non-array element is
-- named @arrayOf_<elem>@.  Other types use the platform's type table with
-- spaces replaced by underscores.
toFunName :: Platform -> Type -> String
toFunName pfm ty = case ty of
    ArrayType _ inner@(ArrayType _ _) -> toFunName pfm inner
    ArrayType _ inner -> "arrayOf_" ++ toFunName pfm inner
    _ -> maybe (handlePrimitivesError $ "Unhandled type in platform " ++ name pfm)
               (\(_, _, cType) -> map underscore cType)
               (find (\(t', _, _) -> ty == t') (types pfm))
  where
    underscore ' ' = '_'
    underscore c   = c
-- | Infix @*@ call of the two expressions, typed as an unsigned 32-bit number.
prod_const a b = FunctionCall "*" u32 InfixOp [a, b] () ()
  where
    u32 = NumType Unsigned S32
-- | Is this actual parameter an input ('In') rather than an output ('Out')?
isInparam param = case param of
    In _ _  -> True
    Out _ _ -> False
-- | Expression carried by an actual parameter, whether 'In' or 'Out'.
aToE param = case param of
    In x ()  -> x
    Out x () -> x
-- | Wrap an expression as an input actual parameter.
eToIn expr = In expr ()
-- | Wrap an expression as an output actual parameter.
eToOut expr = Out expr ()
-- | Integer constant expression.
intToCe n = ConstExpr (IntConst n () ()) ()
-- | Expression reading the given variable.
vToE var = VarExpr var ()