module Feldspar.Compiler.Transformation.GraphToImperative where
import Feldspar.Core.Graph
import Feldspar.Core.Types hiding (typeOf)
import Feldspar.Compiler.Imperative.Representation hiding (Array)
import Feldspar.Compiler.Transformation.GraphUtils
import Data.List
import qualified Data.Map as Map
-- | Compile a hierarchical graph into imperative functions.
--
-- The head of the result is the top-level function. It is named by the given
-- string and built from the graph's interface and top-level hierarchy. It is
-- followed by one function per source reported by 'collectSources' for the
-- graph, in that order.
graphToImperative :: String -> HierarchicalGraph -> [ImpFunction]
graphToImperative s g =
  transformSourceToImpFunction topLevel : map transformSourceToImpFunction (collectSources g)
  where
    topLevel = ImpFunctionSource
      { functionName = s
      , interface    = hierGraphInterface g
      , hierarchy    = graphHierarchy g
      }
-- | Everything needed to emit one imperative function: its name, its
-- interface, and the node hierarchy that is translated into its body
-- (see 'transformSourceToImpFunction').
data ImpFunctionSource
  = ImpFunctionSource
    { functionName :: String   -- ^ name given to the generated function
    , interface :: Interface   -- ^ input/output interface of the function
    , hierarchy :: Hierarchy   -- ^ nodes whose programs form the function body
    }
-- | A function source is displayed as its function name only; the interface
-- and hierarchy are not rendered.
instance Show ImpFunctionSource where
  show = functionName
-- | Structures that may contain nodes which must be compiled as separate
-- imperative functions. In the instances of this module, only 'NoInline'
-- nodes produce such sources.
class Collect t where
  -- | All function sources contained in the value, in traversal order.
  collectSources :: t -> [ImpFunctionSource]
-- | A graph contributes the sources found in its top-level hierarchy.
instance Collect HierarchicalGraph where
  collectSources = collectSources . graphHierarchy
-- | A hierarchy contributes the sources of each of its
-- (node, sub-hierarchies) entries, in order.
instance Collect Hierarchy where
  collectSources (Hierarchy entries) = concatMap collectSources entries
-- | A list contributes the sources of all its elements, concatenated in
-- element order.
instance (Collect t) => Collect [t] where
  collectSources = foldr (\x rest -> collectSources x ++ rest) []
-- | A node contributes its own source (when it is a 'NoInline' node) followed
-- by every source found in its sub-hierarchies.
--
-- A 'NoInline' node must own exactly one hierarchy, which becomes the body of
-- the separately generated function; any other count is reported with
-- 'error'.
instance Collect (Node,[Hierarchy]) where
  collectSources (n, hs) = own ++ collectSources hs
    where
      own = case function n of
        NoInline fname ifc
          | [h] <- hs -> [ImpFunctionSource fname ifc h]
          | otherwise -> error $ "Graph error: malformed hierarchy list in the 'NoInline' node with id " ++ show (nodeId n)
        _ -> []
-- | Lower one 'ImpFunctionSource' into an imperative function ('Fun').
--
-- * Every top-level node of the hierarchy is turned into declarations via
--   'transformNodeToDeclaration'. The declarations that belong to the
--   interface's input node become the function's input parameters. All
--   others become locals of the function body.
-- * One output parameter (named via 'outName', kind 'OutKind') is declared
--   per leaf of the interface output type.
-- * The body runs the program of every top-level node in order, then copies
--   the interface outputs into the output parameters.
transformSourceToImpFunction :: ImpFunctionSource -> ImpFunction
transformSourceToImpFunction (ImpFunctionSource n ifc (Hierarchy pairs))
  = Fun
      { funName       = n
      , inParameters  = inputDecls
      , outParameters = outputDecls
      , prg           = CompPrg
          { locals = localDecls
          , body   = Seq ( map transformNodeToProgram pairs
                        ++ copyToOutput (interfaceOutput ifc) (interfaceOutputType ifc) True
                         ) []
          }
      }
  where
    (inputDecls, localDecls) = partition isInputDecl declarations

    -- Node declarations are named "varN" or "varN_i_j..." (varPrefix of the
    -- node id followed by varPath). A bare prefix test would also claim the
    -- declarations of other nodes whose id merely starts with the same
    -- digits (e.g. "var12" when the input node is 1). Therefore require an
    -- exact match, or the prefix followed by the '_' path separator.
    inputPrefix = varPrefix (interfaceInput ifc)
    isInputDecl d =
      let declName = name (var d)
      in  declName == inputPrefix || (inputPrefix ++ "_") `isPrefixOf` declName

    outputDecls = tupleWalk transformSourceToDecl (interfaceOutputType ifc)
    transformSourceToDecl path typ
      = Decl
          { var       = Var (outName path) OutKind ctyp
          , declType  = ctyp
          , initVal   = Nothing
          , semInfVar = unknownSemInfVar
          }
      where
        ctyp = compileStorableType typ

    declarations  = concatMap transformNodeToDeclaration topLevelNodes
    topLevelNodes = map fst pairs
-- | Declarations for the variables of one node: one 'Normal' variable per
-- leaf of the node's output type, named "var<id>" plus the leaf's tuple path.
--
-- 'Array' nodes get their constant data as the initial value. They must have
-- a single ('One') output type, otherwise 'error' is raised. All other nodes
-- are declared without an initial value.
transformNodeToDeclaration :: Node -> [Declaration]
transformNodeToDeclaration n = tupleWalk mkDecl (tupleZip (outTyps, initVals))
  where
    outTyps = outputType n

    mkDecl path (typ, ini) =
      let ctyp = compileStorableType typ
      in  Decl
            { var       = Var (varPrefix (nodeId n) ++ varPath path) Normal ctyp
            , declType  = ctyp
            , initVal   = ini
            , semInfVar = unknownSemInfVar
            }

    initVals = case function n of
      Array d
        | One t <- outTyps -> One (Just (compileStorableData d t))
        | otherwise        -> error "Error: malformed output type of array node."
      _ -> allNothing outTyps

    -- Same tuple shape as the argument, with 'Nothing' at every leaf.
    allNothing (One _)  = One Nothing
    allNothing (Tup xs) = Tup (map allNothing xs)
-- | Declarations of all the given nodes, concatenated in node order.
transformNodeListToDeclarations :: [Node] -> [Declaration]
transformNodeListToDeclarations nodes = nodes >>= transformNodeToDeclaration
-- | Translate one node, together with the sub-hierarchies it owns, into an
-- imperative 'Program'.
--
-- Per node function:
--
-- * 'Input', 'Array': no code is emitted ('Empty'). For array nodes the
--   constant data is attached as the initial value of the node's declaration
--   in 'transformNodeToDeclaration'.
-- * 'Function', 'NoInline': a call of the named function, passing the node's
--   inputs as 'In' parameters and the node's own variables as 'Out'
--   parameters.
-- * 'IfThenElse': needs exactly two hierarchies (then, else) and a node input
--   shaped @(condition, input)@. Each branch copies the input into its
--   interface's input node, runs its nodes, and copies its interface output
--   into this node's variables.
-- * 'While': first copies the node input into the node's own variables (the
--   loop state). It then emits a 'SeqLoop' with two blocks. The condition
--   block copies the state into the condition interface's input and runs the
--   condition nodes. The body block copies the state into the body
--   interface's input, runs the body nodes, and copies the body output back
--   into the state.
-- * 'Parallel': needs exactly one hierarchy containing exactly one 'Input'
--   node. Emits a 'ParLoop' whose variable is that input node's variable
--   (presumably the loop index) and whose bound comes from this node's single
--   input. The loop body runs the non-input nodes and copies the interface
--   output into the element of this node's variable indexed by the input
--   node's variable.
--
-- Malformed nodes are reported with 'error'.
transformNodeToProgram :: (Node, [Hierarchy]) -> Program
transformNodeToProgram (n,hs) = case function n of
    Input -> Empty
    Array _ -> Empty
    Function s -> Primitive
        (CFun s $ passInArgs (input n) (inputType n) ++ passOutArgs (nodeId n) (outputType n))
        (SemInfPrim Map.empty False)
    NoInline s ifc -> Primitive
        (CFun s $ passInArgs (input n) (inputType n) ++ passOutArgs (nodeId n) (outputType n))
        (SemInfPrim Map.empty False)
    Feldspar.Core.Graph.IfThenElse thenIfc elseIfc -> case hs of
        [thenH, elseH] -> case (input n, inputType n) of
            (Tup [cond, inp], Tup [One condTyp, inTyp])
                -- Both branches must accept exactly the node's data input type.
                | interfaceInputType thenIfc /= inTyp || interfaceInputType elseIfc /= inTyp
                    -> error "Error in 'ifThenElse' node: incorrect interface input type."
                -- NOTE(review): this checks the type of the *condition input*,
                -- although the message talks about the node output.
                | compileStorableType condTyp /= Feldspar.Compiler.Imperative.Representation.BoolType
                    -> error "Error in 'ifThenElse' node: node output is expected to be 'Bool'."
                | otherwise -> Feldspar.Compiler.Imperative.Representation.IfThenElse
                    condVar
                    (mkBranch n thenIfc thenH)
                    (mkBranch n elseIfc elseH)
                    []
                where
                    -- One branch: declare the branch's nodes, copy the data
                    -- input into the branch interface's input node, run the
                    -- nodes, then copy the branch output into this node's
                    -- variables.
                    -- NOTE(review): the 'n' parameter shadows the outer node
                    -- (it is always called with that same node) and 'h' is
                    -- unused.
                    mkBranch :: Node -> Interface -> Hierarchy -> CompleteProgram
                    mkBranch n ifc h@(Hierarchy pairs) = CompPrg
                        (transformNodeListToDeclarations $ map fst pairs)
                        (Seq (copyResult inp (interfaceInput ifc) inTyp False
                            ++ transformNodeListToPrograms pairs
                            ++ copyResult (interfaceOutput ifc) (nodeId n) (outputType n) True)
                            [])
                    -- The condition must be a variable (not a constant).
                    condVar = case cond of
                        One (Variable (id,path)) -> Var (varName id path) Normal Feldspar.Compiler.Imperative.Representation.BoolType
                        _ -> error "Error in 'ifThenElse' node: condition is not a variable."
            otherwise -> error $ "Error in 'ifThenElse' node: incorrect node input or node input type"
        otherwise -> error $ "Error in 'ifThenElse' node: two hierarchies expected, found " ++ show (length hs)
    While condIfc bodyIfc -> Seq
        -- Initialise the loop state (this node's variables) from the node input.
        (copyResult (input n) (nodeId n) (outputType n) False ++
        [SeqLoop
            -- The condition interface must output a single variable.
            (case interfaceOutput condIfc of
                One (Variable (id,path)) -> Var (varName id path) Normal Feldspar.Compiler.Imperative.Representation.BoolType
                _ -> error "Error in a while loop: Malformed interface output of condition calculation."
            )
            (CompPrg
                (transformNodeListToDeclarations condNodes)
                (Seq (copyStateToCond ++ calculationCond) [])
            )
            (CompPrg
                (transformNodeListToDeclarations bodyNodes)
                (Seq (copyStateToBody ++ calculationBody ++ copyResultToState) [])
            )
            []
        ]) []
        where
            -- First hierarchy: condition; second: loop body.
            (Hierarchy condHier, Hierarchy bodyHier) = case hs of
                [c,b] -> (c,b)
                _ -> error $ "Error in a while node: expected 2 hierarchies, but found " ++ show (length hs)
            condNodes = map fst condHier
            bodyNodes = map fst bodyHier
            copyStateToCond = copyNode (nodeId n) (interfaceInput condIfc) (outputType n) False
            calculationCond = transformNodeListToPrograms condHier
            copyStateToBody = copyNode (nodeId n) (interfaceInput bodyIfc) (outputType n) False
            calculationBody = transformNodeListToPrograms bodyHier
            -- Feed the body's result back into the loop state.
            copyResultToState = copyResult (interfaceOutput bodyIfc) (nodeId n) (outputType n) True
    Parallel _ ifc ->
        ParLoop (Var (varName inpId []) Normal $ Numeric ImpSigned S32) num 1 prg []
        where
            -- Loop bound: the node's single input.
            num = case (input n, inputType n) of
                (One inp, One intyp) -> transformSourceToExpr inp intyp
                otherwise -> error "Invalid input of a Parallel node."
            -- NOTE(review): this error (and the one for 'inpId' below) says
            -- "More than one", but it also fires when there are none.
            hist = case hs of
                [(Hierarchy hist)] -> hist
                _ -> error "More than one Hierarchy in a Parallel construct"
            isInp (node,hs) = case (function node) of
                Input -> True
                _ -> False
            -- The input node of the inner hierarchy provides the loop
            -- variable; the remaining nodes form the loop body.
            (inps,notInps) = partition isInp hist
            inpId = case inps of
                [(node,hs)] -> nodeId node
                _ -> error "More than one input node inside the Hierarchy of a Parallel construct"
            topLevelNodes = map fst notInps
            declarations = concatMap transformNodeToDeclaration topLevelNodes
            outSrc = case interfaceOutput ifc of
                One src -> src
                _ -> error "The interfaceOutput of a Parallel is not (One ...) "
            outTyp = case interfaceOutputType ifc of
                One typ -> typ
                _ -> error "The interfaceOutputType of a Parallel is not (One ...) "
            -- Loop body: run the non-input nodes, then copy the interface
            -- output into element [inpId] of this node's variable.
            -- NOTE(review): the array variable and the element expression
            -- are both annotated with 'intType' whatever 'outTyp' is; confirm
            -- this is intended for non-integer element types.
            prg = CompPrg
                { locals = declarations
                , body = Seq ( map transformNodeToProgram notInps ++
                    [ Primitive ( makeCopyFromExprs
                            (transformSourceToExpr outSrc outTyp)
                            (Expr (LeftExpr $ ArrayElem (LVar (Var (varName (nodeId n) []) Normal intType)) (Expr (genVar inpId [] intType) intType)) intType)
                        )
                        (SemInfPrim Map.empty True)
                    ]
                    ) []
                }
-- | Programs of all the given nodes (with their sub-hierarchies), in order.
transformNodeListToPrograms :: [(Node, [Hierarchy])] -> [Program]
transformNodeListToPrograms entries = [transformNodeToProgram entry | entry <- entries]
-- | Base name of a node's variables: "var" followed by the shown node id.
varPrefix :: NodeId -> String
varPrefix = ("var" ++) . show
-- | Name suffix for a tuple path: each index rendered as '_' followed by the
-- shown index (the empty path gives the empty suffix).
varPath :: [Int] -> String
varPath path = [c | i <- path, c <- '_' : show i]
-- | Full variable name of the leaf at @path@ of node @nid@'s output.
varName :: NodeId -> [Int] -> String
varName nid path = concat [varPrefix nid, varPath path]
-- | L-value expression for the 'Normal' variable of node @nid@ at tuple
-- @path@, with the given imperative type.
genVar :: NodeId -> [Int] -> Type -> UntypedExpression
genVar nid path typ = LeftExpr . LVar $ Var (varName nid path) Normal typ
-- | Base name of a generated function's output parameters; 'outName'
-- appends the tuple path to it.
outPrefix :: String
outPrefix = "out"
-- | Name of the output parameter for the leaf at the given tuple path.
outName :: [Int] -> String
outName = (outPrefix ++) . varPath
-- | L-value expression for the output parameter ('OutKind') at the given
-- tuple path, with the given imperative type.
genOut :: [Int] -> Type -> UntypedExpression
genOut path typ = LeftExpr . LVar $ Var (outName path) OutKind typ
-- | 'In' parameters for every leaf of a node's input.
--
-- Constant leaves become constant expressions of the leaf's element type.
-- Variable leaves become references to the producing node's variable.
passInArgs :: Tuple Source -> Tuple StorableType -> [Parameter]
passInArgs tup typs = tupleWalk (\_ arg -> toParam arg) (tupleZip (tup, typs))
  where
    toParam (Constant primData, StorableType _ elemTyp) =
      In (compilePrimData primData elemTyp)
    toParam (Variable (nid, path), styp) =
      let ctyp = compileStorableType styp
      in  In (Expr (genVar nid path ctyp) ctyp)
-- | 'Out' parameters referring to each leaf variable of node @nid@'s output.
passOutArgs :: NodeId -> Tuple StorableType -> [Parameter]
passOutArgs nid typs = tupleWalk outParam typs
  where
    outParam path styp =
      let ctyp = compileStorableType styp
      in  Out (Normal, Expr (genVar nid path ctyp) ctyp)
-- | Imperative type of a storable type. Each dimension, outermost first,
-- wraps the element type in an 'ImpArrayType' of that (known) length.
compileStorableType :: StorableType -> Type
compileStorableType (StorableType dims elemTyp) =
  foldr (ImpArrayType . Just) (compilePrimitiveType elemTyp) dims
-- | Imperative type of a primitive Feldspar type. The unit type is
-- represented as a boolean.
compilePrimitiveType :: PrimitiveType -> Type
compilePrimitiveType UnitType = Feldspar.Compiler.Imperative.Representation.BoolType
compilePrimitiveType Feldspar.Core.Types.BoolType = Feldspar.Compiler.Imperative.Representation.BoolType
compilePrimitiveType IntType = Numeric ImpSigned S32
compilePrimitiveType Feldspar.Core.Types.FloatType = Feldspar.Compiler.Imperative.Representation.FloatType
-- | Constant for storable data; nested data becomes nested 'ArrayConst's that
-- keep the original length.
compileStorableDataToConst :: StorableData -> Constant
compileStorableDataToConst sd = case sd of
  PrimitiveData pd    -> compilePrimDataToConst pd
  StorableData len ds -> ArrayConst len (map compileStorableDataToConst ds)
-- | Constant for a primitive value; the unit value is encoded as 'False'.
compilePrimDataToConst :: PrimitiveData -> Constant
compilePrimDataToConst pd = case pd of
  UnitData    -> BoolConst False
  BoolData b  -> BoolConst b
  IntData i   -> IntConst i
  FloatData f -> FloatConst f
-- | Constant expression for storable data of the given storable type.
-- Primitive data is typed by the element type only. Array data is typed by
-- the full compiled storable type.
compileStorableData :: StorableData -> StorableType -> ImpLangExpr
compileStorableData sd typ = case sd of
  PrimitiveData pd -> case typ of
    StorableType _ elemTyp -> compilePrimData pd elemTyp
  StorableData {} -> Expr (ConstExpr (compileStorableDataToConst sd)) (compileStorableType typ)
-- | Constant expression for a primitive value, typed by the compiled
-- primitive type.
compilePrimData :: PrimitiveData -> PrimitiveType -> ImpLangExpr
compilePrimData pd ptyp = Expr constant resultType
  where
    constant   = ConstExpr (compilePrimDataToConst pd)
    resultType = compilePrimitiveType ptyp
-- | Imperative type used for characters: @Numeric ImpSigned S8@
-- (presumably a signed 8-bit integer).
charType :: Type
charType = Numeric ImpSigned S8

-- | Imperative type used for integers: @Numeric ImpSigned S32@, the same type
-- that 'IntType' compiles to.
intType :: Type
intType = Numeric ImpSigned S32
-- | Expression reading a source. A constant becomes a constant expression of
-- the element type. A variable becomes a reference typed by the full
-- compiled storable type.
transformSourceToExpr :: Source -> StorableType -> ImpLangExpr
transformSourceToExpr src styp = case (src, styp) of
    (Constant primData, StorableType _ elemTyp) -> compilePrimData primData elemTyp
    (Variable (nid, path), _) -> Expr (genVar nid path ctyp) ctyp
  where
    ctyp = compileStorableType styp
-- | Copy instruction between two node variables, each given as
-- (node id, tuple path, storable type).
makeCopyFromIds :: (NodeId,[Int],StorableType) -> (NodeId,[Int],StorableType) -> Instruction
makeCopyFromIds (idFrom, pathFrom, typeFrom) (idTo, pathTo, typeTo) =
  makeCopyFromExprs (varExpr idFrom pathFrom typeFrom) (varExpr idTo pathTo typeTo)
  where
    varExpr nid path styp =
      let ctyp = compileStorableType styp
      in  Expr (genVar nid path ctyp) ctyp
-- | A call of the "copy" function: the source as an 'In' argument and the
-- destination as a 'Normal' 'Out' argument.
makeCopyFromExprs :: ImpLangExpr -> ImpLangExpr -> Instruction
makeCopyFromExprs src dst = CFun "copy" args
  where
    args = [In src, Out (Normal, dst)]
-- | One copy per leaf of the given type structure, copying node @fromId@'s
-- variable into node @toId@'s variable at the same tuple path. The Bool flag
-- is stored in each copy's 'SemInfPrim'.
copyNode :: NodeId -> NodeId -> Tuple StorableType -> Bool -> [Program]
copyNode fromId toId typeStructure isOutputCopying = tupleWalk copyLeaf typeStructure
  where
    copyLeaf path typ =
      Primitive (makeCopyFromIds (fromId, path, typ) (toId, path, typ)) semInf
    semInf = SemInfPrim Map.empty isOutputCopying
-- | One copy per leaf, copying each source of @ifcOut@ into node @nid@'s
-- variable at the same tuple path. The Bool flag is stored in each copy's
-- 'SemInfPrim'.
copyResult :: Tuple Source -> NodeId -> Tuple StorableType -> Bool -> [Program]
copyResult ifcOut nid outTyp isOutputCopying = tupleWalk copyLeaf (tupleZip (ifcOut, outTyp))
  where
    copyLeaf path (src, styp) =
      let ctyp   = compileStorableType styp
          target = Expr (genVar nid path ctyp) ctyp
      in  Primitive (makeCopyFromExprs (transformSourceToExpr src styp) target)
                    (SemInfPrim Map.empty isOutputCopying)
-- | One copy per leaf, copying each source of @ifcOut@ into the function's
-- output parameter at the same tuple path. The Bool flag is stored in each
-- copy's 'SemInfPrim'.
copyToOutput :: Tuple Source -> Tuple StorableType -> Bool -> [Program]
copyToOutput ifcOut outTyp isOutputCopying = tupleWalk copyLeaf (tupleZip (ifcOut, outTyp))
  where
    copyLeaf path (src, styp) =
      let ctyp   = compileStorableType styp
          target = Expr (genOut path ctyp) ctyp
      in  Primitive (makeCopyFromExprs (transformSourceToExpr src styp) target)
                    (SemInfPrim Map.empty isOutputCopying)