module Feldspar.Compiler.Transformation.GraphToImperative where
import Feldspar.Core.Graph
import Feldspar.Range
import qualified Feldspar.Core.Graph as Graph
import Feldspar.Core.Types hiding (typeOf)
import qualified Feldspar.Core.Types as CoreTypes
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Imperative.CodeGeneration
import qualified Feldspar.Compiler.Imperative.Representation as Representation
import Feldspar.Compiler.Transformation.GraphUtils
import Data.List
import qualified Data.Map as Map
import qualified Data.Maybe as Maybe
import Feldspar.Compiler.Error
import Feldspar.Compiler.Imperative.Semantics
-- | Translate a hierarchical graph into imperative procedures.
--
-- The first procedure is built from the graph's own interface and top-level
-- hierarchy. It is followed by one procedure per source returned by
-- 'collectSources' for the graph, in that order.
graphToImperative :: HierarchicalGraph -> [Procedure InitSemInf]
graphToImperative g =
    transformSourceToProcedure mainSource
      : map transformSourceToProcedure (collectSources g)
  where
    mainSource = ProcedureSource (hierGraphInterface g) (graphHierarchy g)
-- | Everything needed to generate one imperative procedure: the interface
-- it exposes and the hierarchy of nodes that forms its body.
data ProcedureSource
  = ProcedureSource
  { interface :: Interface  -- ^ interface (input node id, outputs and their types)
  , hierarchy :: Hierarchy  -- ^ hierarchy of nodes making up the procedure body
  }
-- | Structures that may contain 'NoInline' nodes. Each such node yields a
-- 'ProcedureSource' that 'graphToImperative' compiles into its own procedure.
class Collect t where
  -- | All procedure sources found in the value, in traversal order.
  collectSources :: t -> [ProcedureSource]
-- | A graph's procedure sources are exactly those of its top-level hierarchy.
instance Collect HierarchicalGraph where
  collectSources = collectSources . graphHierarchy
-- | A hierarchy wraps a list of (node, sub-hierarchies) pairs; its sources
-- are those of that list.
instance Collect Hierarchy where
  collectSources hier = case hier of
    Hierarchy pairs -> collectSources pairs
-- | The sources of a list are the sources of its elements, concatenated in
-- list order.
instance Collect t => Collect [t] where
  collectSources = foldr (\x rest -> collectSources x ++ rest) []
-- | A 'NoInline' node contributes one source built from its interface and
-- its single body hierarchy (any other number of hierarchies is an error).
-- Every node's sub-hierarchies are then searched recursively.
instance Collect (Node, [Hierarchy]) where
  collectSources (node, subHierarchies) =
      ownSource ++ collectSources subHierarchies
    where
      ownSource = case function node of
        NoInline _ ifc
          | [body] <- subHierarchies -> [ProcedureSource ifc body]
          | otherwise -> error $ "Graph error: malformed hierarchy list in the 'NoInline' node with id " ++ show (nodeId node)
        _ -> []
-- | Build one imperative procedure from a procedure source.
--
-- * The hierarchy node whose id equals the interface input id provides the
--   input parameters; every other node becomes local declarations.
-- * Each leaf of the interface output type becomes one 'FunOut' parameter,
--   named by 'outName'.
-- * The body runs every node's program in hierarchy order, then copies the
--   interface output into the output parameters.
--
-- The procedure name is the literal @"PLACEHOLDER"@.
-- Reports an 'InvariantViolation' through 'handleError' unless exactly one
-- node carries the interface input id.
transformSourceToProcedure :: ProcedureSource -> Procedure InitSemInf
transformSourceToProcedure (ProcedureSource ifc (Hierarchy pairs)) = Procedure
    { procedureName = "PLACEHOLDER"
    , inParameters = inputDecls
    , outParameters = outputDecls
    , procedureBody = Block
        { blockDeclarations = localDecls
        , blockInstructions = Program
            { programConstruction = SequenceProgram $ Sequence
                { sequenceProgramList =
                    map transformNodeToProgram pairs
                      ++ copyToOutput (interfaceOutput ifc) (interfaceOutputType ifc) True
                , sequenceSemInf = ()
                }
            , programSemInf = ()
            }
        , blockSemInf = ()
        }
    , procedureSemInf = ()
    }
  where
    nodes = map fst pairs
    -- The input node is identified by its id, not by its function.
    (inputNodes, localNodes) = partition (\node -> nodeId node == interfaceInput ifc) nodes
    inputDecls = case inputNodes of
      [inputNode] -> transformNodeToFormalParameter inputNode
      [] -> handleError "GraphToImperative" InvariantViolation $
              "no input node found among nodes: " ++ show nodes
      -- Report every offending node, not just the first one.
      _ -> handleError "GraphToImperative" InvariantViolation $
              "exactly one input node expected; found nodeIds==" ++ show (map nodeId inputNodes)
    localDecls = concatMap transformNodeToLocalDeclaration localNodes
    outputDecls = tupleWalk outputParameter (interfaceOutputType ifc)
    -- The 'FunOut' formal parameter for the output leaf at @path@.
    outputParameter :: [Int] -> StorableType -> FormalParameter InitSemInf
    outputParameter path typ = FormalParameter
      { formalParameterVariable =
          Representation.Variable FunOut (compileStorableType typ) (outName path) ()
      , formalParameterSemInf = ()
      }
-- | Formal parameters for every output leaf of a node.
--
-- Each leaf of the node's output type becomes a 'Value' parameter named
-- @var<nodeId>_<path>@ (see 'varName').
--
-- Formal parameters carry no initial value, so the node's function is not
-- inspected. Earlier code computed and zipped in initial values that were then
-- discarded. That dead computation could also raise a spurious
-- "malformed output type of array node" error.
transformNodeToFormalParameter :: Node -> [FormalParameter InitSemInf]
transformNodeToFormalParameter n = tupleWalk genDecl (outputType n)
  where
    genDecl path typ = FormalParameter
      { formalParameterVariable =
          Representation.Variable Value (compileStorableType typ) (varName (nodeId n) path) ()
      , formalParameterSemInf = ()
      }
-- | Local declarations for every output leaf of a node.
--
-- Each leaf becomes a 'Value' variable named @var<nodeId>_<path>@.
--
-- * An 'Array' node must have a single-leaf output type. That leaf is
--   initialised with the node's constant data.
-- * For every other node, all leaves are declared without an initial value.
transformNodeToLocalDeclaration :: Node -> [LocalDeclaration InitSemInf]
transformNodeToLocalDeclaration n =
    tupleWalk declare $ tupleZip (outputType n, initialValues (function n))
  where
    declare path (typ, initVal) = LocalDeclaration
      { localVariable = Representation.Variable
          { variableRole = Value
          , variableType = compileStorableType typ
          , variableName = varName (nodeId n) path
          , variableSemInf = ()
          }
      , localInitValue = initVal
      , localDeclarationSemInf = ()
      }
    initialValues (Array d) = case outputType n of
      One t -> One (Just (compileStorableData d t))
      _ -> error "Error: malformed output type of array node."
    initialValues _ = noValues (outputType n)
    -- Same tuple shape as the argument, with 'Nothing' at every leaf.
    noValues (One _) = One Nothing
    noValues (Tup xs) = Tup (map noValues xs)
-- | Formal parameters of all given nodes, concatenated in list order.
transformNodeListToFormalParameters :: [Node] -> [FormalParameter InitSemInf]
transformNodeListToFormalParameters ns =
    [param | node <- ns, param <- transformNodeToFormalParameter node]
-- | Local declarations of all given nodes, concatenated in list order.
transformNodeListToLocalDeclarations :: [Node] -> [LocalDeclaration InitSemInf]
transformNodeListToLocalDeclarations ns =
    [decl | node <- ns, decl <- transformNodeToLocalDeclaration node]
-- | Translate one node of a hierarchy, together with its sub-hierarchies,
-- into an imperative program.
--
-- * 'Graph.Input' and 'Array' nodes produce an empty program.
-- * 'Function' and 'NoInline' nodes become a call to the named procedure.
--   The arguments are the node inputs followed by the node's output
--   variables.
-- * 'Graph.IfThenElse' becomes a branch on the condition variable. Each
--   branch copies the payload input into its interface input, runs its
--   hierarchy, and copies its interface output into the node's variables.
-- * 'While' first copies the node input into the node's own (state)
--   variables, then emits a 'SequentialLoop':
--   - the condition-calculation block copies the state into the condition
--     interface input and runs the condition hierarchy;
--   - the core copies the state into the body interface input, runs the body
--     hierarchy, and writes the body result back into the state.
-- * 'Parallel' becomes a 'ParallelLoop' indexed by the variable of the
--   hierarchy's input node. Each iteration stores its result in the element
--   of the node's output array selected by that index.
--
-- Malformed nodes (wrong number of hierarchies, unexpected tuple shapes or
-- types) abort with 'error'.
transformNodeToProgram :: (Node, [Hierarchy]) -> Program InitSemInf
transformNodeToProgram (n, hs) = case function n of
    Graph.Input -> emptyProg
    Array _ -> emptyProg
    Function s -> callProg s
    NoInline s _ -> callProg s
    Graph.IfThenElse thenIfc elseIfc -> ifThenElseProg thenIfc elseIfc
    While condIfc bodyIfc -> whileProg condIfc bodyIfc
    Parallel ifc -> parallelProg ifc
  where
    -- A program that does nothing.
    emptyProg :: Program InitSemInf
    emptyProg = Program (EmptyProgram $ Empty ()) ()

    -- Sequence of programs wrapped in a block with the given declarations.
    seqBlock :: [LocalDeclaration InitSemInf] -> [Program InitSemInf] -> Block InitSemInf
    seqBlock decls progs = Block
      { blockDeclarations = decls
      , blockInstructions = Program
          { programConstruction = SequenceProgram $ Sequence
              { sequenceProgramList = progs
              , sequenceSemInf = ()
              }
          , programSemInf = ()
          }
      , blockSemInf = ()
      }

    -- Call procedure @name@ with the node inputs followed by the node outputs.
    callProg :: String -> Program InitSemInf
    callProg name = Program
      { programConstruction = PrimitiveProgram $ Primitive
          { primitiveInstruction = Instruction
              { instructionData = ProcedureCallInstruction $ ProcedureCall
                  { nameOfProcedureToCall = name
                  , actualParametersOfProcedureToCall =
                      passInArgs (input n) (inputType n)
                        ++ passOutArgs (nodeId n) (outputType n)
                  , procedureCallSemInf = ()
                  }
              , instructionSemInf = ()
              }
          , primitiveSemInf = False
          }
      , programSemInf = ()
      }

    -- Validate the (condition, payload) input and emit a two-way branch.
    ifThenElseProg :: Interface -> Interface -> Program InitSemInf
    ifThenElseProg thenIfc elseIfc = case hs of
      [thenH, elseH] -> case (input n, inputType n) of
        (Tup [cond, inp], Tup [One condTyp, inTyp])
          | interfaceInputType thenIfc /= inTyp || interfaceInputType elseIfc /= inTyp
            -> error "Error in 'ifThenElse' node: incorrect interface input type."
          -- This checks the condition input's type, not the node output.
          | compileStorableType condTyp /= Representation.BoolType
            -> error "Error in 'ifThenElse' node: condition is expected to be 'Bool'."
          | otherwise -> Program
              { programConstruction = BranchProgram $ Branch
                  { branchConditionVariable = condVariable cond
                  , thenBlock = branchBlock inp inTyp thenIfc thenH
                  , elseBlock = branchBlock inp inTyp elseIfc elseH
                  , branchSemInf = ()
                  }
              , programSemInf = ()
              }
        _ -> error "Error in 'ifThenElse' node: incorrect node input or node input type"
      _ -> error $ "Error in 'ifThenElse' node: two hierarchies expected, found " ++ show (length hs)

    -- One branch: copy the payload into the branch's interface input, run
    -- the branch hierarchy, and copy its output into the node's variables.
    branchBlock :: Tuple Source -> Tuple StorableType -> Interface -> Hierarchy -> Block InitSemInf
    branchBlock inp inTyp ifc (Hierarchy branchPairs) = seqBlock
      (transformNodeListToLocalDeclarations (map fst branchPairs))
      (copyResult inp (interfaceInput ifc) inTyp False
         ++ transformNodeListToPrograms branchPairs
         ++ copyResult (interfaceOutput ifc) (nodeId n) (outputType n) True)

    -- The branch condition must be a single variable source.
    condVariable :: Tuple Source -> Representation.Variable InitSemInf
    condVariable (One (Graph.Variable (vid, path))) =
      Representation.Variable Value Representation.BoolType (varName vid path) ()
    condVariable _ = error "Error in 'ifThenElse' node: condition is not a variable."

    -- Initialise the node state from its input, then loop over the
    -- condition and body hierarchies.
    whileProg :: Interface -> Interface -> Program InitSemInf
    whileProg condIfc bodyIfc = Program
      { programConstruction = SequenceProgram $ Sequence
          { sequenceProgramList =
              copyResult (input n) (nodeId n) (outputType n) True ++ [loopProg]
          , sequenceSemInf = ()
          }
      , programSemInf = ()
      }
      where
        (Hierarchy condHier, Hierarchy bodyHier) = case hs of
          [c, b] -> (c, b)
          _ -> error $ "Error in a while node: expected 2 hierarchies, but found " ++ show (length hs)
        loopProg = Program
          { programConstruction = SequentialLoopProgram $ SequentialLoop
              { sequentialLoopCondition = loopCondition
              , conditionCalculation = seqBlock
                  (transformNodeListToLocalDeclarations (map fst condHier))
                  (copyNode (nodeId n) (interfaceInput condIfc) (outputType n) False
                     ++ transformNodeListToPrograms condHier)
              , sequentialLoopCore = seqBlock
                  (transformNodeListToLocalDeclarations (map fst bodyHier))
                  (copyNode (nodeId n) (interfaceInput bodyIfc) (outputType n) False
                     ++ transformNodeListToPrograms bodyHier
                     ++ copyResult (interfaceOutput bodyIfc) (nodeId n) (outputType n) True)
              , sequentialLoopSemInf = ()
              }
          , programSemInf = ()
          }
        -- The condition interface output must be a single variable or a
        -- Boolean constant.
        loopCondition = case interfaceOutput condIfc of
          One (Graph.Variable (vid, path)) ->
            varToExpr $ Representation.Variable Value Representation.BoolType (varName vid path) ()
          One (Graph.Constant (BoolData x)) -> Expression
            { expressionData = ConstantExpression $ Representation.Constant
                { constantData = BoolConstant $ BoolConstantType x ()
                , constantSemInf = ()
                }
            , expressionSemInf = ()
            }
          unknown -> error $ "Error in a while loop: Malformed interface output of condition calculation: " ++ show unknown

    -- Parallel loop over the hierarchy's input node. Each iteration stores
    -- its element into the node's output array.
    parallelProg :: Interface -> Program InitSemInf
    parallelProg ifc = Program
      { programConstruction = ParallelLoopProgram $ ParallelLoop
          (Representation.Variable Value intType (varName inpId []) ())
          num
          1
          (seqBlock declarations (map transformNodeToProgram notInps ++ [storeElem]))
          ()
      , programSemInf = ()
      }
      where
        num = case (input n, inputType n) of
          (One inp, One intyp) -> transformSourceToExpr inp intyp
          _ -> error "Invalid input of a Parallel node."
        hist = case hs of
          [Hierarchy h] -> h
          _ -> error $ "Expected exactly one Hierarchy in a Parallel construct, found " ++ show (length hs)
        isInp (node, _) = case function node of
          Graph.Input -> True
          _ -> False
        (inps, notInps) = partition isInp hist
        inpId = case inps of
          [(node, _)] -> nodeId node
          _ -> error $ "Expected exactly one input node inside the Hierarchy of a Parallel construct, found " ++ show (length inps)
        declarations = transformNodeListToLocalDeclarations (map fst notInps)
        outSrc = case interfaceOutput ifc of
          One src -> src
          _ -> error "The interfaceOutput of a Parallel is not (One ...) "
        outTypElem = case interfaceOutputType ifc of
          One typ -> typ
          _ -> error "The interfaceOutputType of a Parallel is not (One ...) "
        outTypArray = case outputType n of
          One typ -> typ
          _ -> error "The outputType of a Parallel is not (One ...) "
        -- Copy this iteration's result into the indexed array element.
        storeElem = Program
          { programConstruction = PrimitiveProgram $ Primitive
              { primitiveInstruction =
                  makeCopyFromExprs (transformSourceToExpr outSrc outTypElem) outElem
              , primitiveSemInf = True
              }
          , programSemInf = ()
          }
        -- The node's output array element selected by the loop variable.
        outElem = Expression
          { expressionData = LeftValueExpression $ LeftValue
              { leftValueData = ArrayElemReferenceLeftValue $ ArrayElemReference
                  { arrayName = LeftValue
                      { leftValueData = VariableLeftValue $ Representation.Variable
                          { variableRole = Value
                          , variableType = compileStorableType outTypArray
                          , variableName = varName (nodeId n) []
                          , variableSemInf = ()
                          }
                      , leftValueSemInf = ()
                      }
                  , arrayIndex = genVar inpId [] intType
                  , arrayElemReferenceSemInf = ()
                  }
              , leftValueSemInf = ()
              }
          , expressionSemInf = ()
          }
-- | One program per (node, sub-hierarchies) pair, in list order.
transformNodeListToPrograms :: [(Node, [Hierarchy])] -> [Program InitSemInf]
transformNodeListToPrograms = fmap transformNodeToProgram
-- | Variable-name prefix for a node: @"var"@ followed by the shown node id.
varPrefix :: NodeId -> String
varPrefix = ("var" ++) . show
-- | Tuple-path suffix of a variable name: each path index is rendered as
-- @'_'@ followed by the index, so @[1,2]@ becomes @"_1_2"@ and @[]@ becomes
-- @""@.
varPath :: [Int] -> String
varPath = foldr (\i rest -> '_' : show i ++ rest) ""
-- | Full name of the variable holding leaf @path@ of node @nid@.
varName :: NodeId -> [Int] -> String
varName nid path = concat [varPrefix nid, varPath path]
-- | Expression reading the 'Value' variable for leaf @path@ of node @nid@
-- (named by 'varName'), with imperative type @typ@.
genVar :: NodeId -> [Int] -> Type -> Expression InitSemInf
genVar nid path typ = Expression
    { expressionData = LeftValueExpression lval
    , expressionSemInf = ()
    }
  where
    lval = LeftValue
      { leftValueData = VariableLeftValue var
      , leftValueSemInf = ()
      }
    var = Representation.Variable
      { variableRole = Value
      , variableType = typ
      , variableName = varName nid path
      , variableSemInf = ()
      }
-- | Name prefix of the procedure's output ('FunOut') parameters; 'outName'
-- appends the tuple path to it.
outPrefix :: String
outPrefix = "out"
-- | Name of the output parameter for leaf @path@ of the procedure output:
-- 'outPrefix' followed by 'varPath' of the path.
outName :: [Int] -> String
outName = (outPrefix ++) . varPath
-- | Expression referring to the 'FunOut' output parameter for leaf @path@
-- (named by 'outName'), with imperative type @typ@.
genOut :: [Int] -> Type -> Expression InitSemInf
genOut path typ = Expression
    { expressionData = LeftValueExpression lval
    , expressionSemInf = ()
    }
  where
    lval = LeftValue
      { leftValueData = VariableLeftValue outVar
      , leftValueSemInf = ()
      }
    outVar = Representation.Variable
      { variableRole = FunOut
      , variableType = typ
      , variableName = outName path
      , variableSemInf = ()
      }
-- | Input actual parameters, one per leaf of the source tuple, in leaf order.
--
-- A constant leaf is passed as a literal of the storable type's element
-- type. A variable leaf is passed as a read of the corresponding node
-- variable.
passInArgs :: Tuple Source -> Tuple StorableType -> [ActualParameter InitSemInf]
passInArgs tup typs = tupleWalk (\_ leaf -> inputArg leaf) (tupleZip (tup, typs))
  where
    inputArg (src, typ) = ActualParameter
      { actualParameterData = InputActualParameter (argExpr src typ)
      , actualParameterSemInf = ()
      }
    argExpr (Graph.Constant primData) (StorableType _ elemTyp) = compilePrimData primData elemTyp
    argExpr (Graph.Variable (nid, path)) typ = genVar nid path (compileStorableType typ)
-- | Output actual parameters: one left value per leaf of @typs@, referring
-- to node @nid@'s variable for that leaf.
passOutArgs :: NodeId -> Tuple StorableType -> [ActualParameter InitSemInf]
passOutArgs nid = tupleWalk $ \path t ->
    let target = toLeftValue (genVar nid path (compileStorableType t))
    in ActualParameter
         { actualParameterData = OutputActualParameter target
         , actualParameterSemInf = ()
         }
-- | Imperative type of a storable type.
--
-- Each dimension range becomes an 'ImpArrayType', with the first dimension
-- outermost. Its length comes from the range's upper bound via 'getLength'.
-- The innermost type is the compiled element type.
compileStorableType :: StorableType -> Type
compileStorableType (StorableType dims elemTyp) =
    foldr (\d inner -> ImpArrayType (getLength (upperBound d)) inner)
          (compilePrimitiveType elemTyp)
          dims
-- | Array length from an optional upper bound: 'Norm' of the bound when it
-- is known, 'Undefined' otherwise.
getLength mBound = case mBound of
    Just i -> Norm i
    Nothing -> Undefined
-- | Imperative type of a primitive type.
--
-- * Unit is represented as a Boolean.
-- * Integers must be 8, 16, 32 or 64 bits wide; any other width is reported
--   as an 'InvariantViolation' via 'handleError'.
-- * Every floating-point type maps to the single imperative 'FloatType'.
compilePrimitiveType :: PrimitiveType -> Type
compilePrimitiveType typ = case typ of
    UnitType -> Representation.BoolType
    CoreTypes.BoolType -> Representation.BoolType
    IntType sig size _ -> case width size of
      Just w -> Numeric (signedness sig) w
      Nothing -> handleError "GraphToImperative" InvariantViolation $ "unknown integer type: IntType" ++ (show sig) ++ " " ++ (show size)
    CoreTypes.FloatType _ -> Representation.FloatType
    CoreTypes.UserType userTypeName -> Representation.UserType userTypeName
  where
    signedness True = ImpSigned
    signedness False = ImpUnsigned
    width 8 = Just S8
    width 16 = Just S16
    width 32 = Just S32
    width 64 = Just S64
    width _ = Nothing
-- | Constant for storable data: primitive data becomes a scalar constant,
-- nested data becomes an array constant of its recursively compiled
-- elements.
compileStorableDataToConst :: StorableData -> Constant InitSemInf
compileStorableDataToConst storable = case storable of
    CoreTypes.PrimitiveData pd -> compilePrimDataToConst pd
    StorableData ds -> Representation.Constant
      { constantData = ArrayConstant (ArrayConstantType (map compileStorableDataToConst ds) ())
      , constantSemInf = ()
      }
-- | Scalar constant for primitive data. Unit data is encoded as the Boolean
-- constant @False@; integers are converted with 'fromInteger'.
compilePrimDataToConst :: CoreTypes.PrimitiveData -> Constant InitSemInf
compilePrimDataToConst pd = case pd of
    UnitData () -> mkConst (BoolConstant (BoolConstantType False ()))
    BoolData b -> mkConst (BoolConstant (BoolConstantType b ()))
    IntData i -> mkConst (IntConstant (IntConstantType (fromInteger i) ()))
    FloatData f -> mkConst (FloatConstant (FloatConstantType f ()))
  where
    mkConst payload = Representation.Constant
      { constantData = payload
      , constantSemInf = ()
      }
-- | Constant expression for storable data.
--
-- Primitive data is compiled via 'compilePrimData' with the element type of
-- @typ@. Any other data becomes an array constant, and @typ@ is not
-- inspected.
compileStorableData :: StorableData -> StorableType -> Expression InitSemInf
compileStorableData storable typ = case (storable, typ) of
    (CoreTypes.PrimitiveData pd, StorableType _ elemTyp) -> compilePrimData pd elemTyp
    (StorableData _, _) -> Expression (ConstantExpression (compileStorableDataToConst storable)) ()
-- | Constant expression for primitive data. The primitive type argument is
-- currently ignored.
compilePrimData :: CoreTypes.PrimitiveData -> PrimitiveType -> Expression InitSemInf
compilePrimData d _ = Expression
    { expressionData = ConstantExpression (compilePrimDataToConst d)
    , expressionSemInf = ()
    }
-- | Signed 8-bit imperative numeric type. Presumably used for characters;
-- it is not referenced in this part of the file.
charType :: Type
charType = Numeric ImpSigned S8
-- | Signed 32-bit imperative numeric type, used for parallel-loop indices.
intType :: Type
intType = Numeric ImpSigned S32
-- | Expression for a graph source.
--
-- A constant becomes a literal of the storable type's element type. A
-- variable becomes a read of the corresponding node variable.
transformSourceToExpr :: Source -> StorableType -> Expression InitSemInf
transformSourceToExpr src typ = case src of
    Graph.Constant primData -> case typ of
      StorableType _ elemTyp -> compilePrimData primData elemTyp
    Graph.Variable (nid, path) -> genVar nid path (compileStorableType typ)
-- | Copy instruction between two node variables, each given as
-- (node id, tuple path, storable type).
makeCopyFromIds :: (NodeId,[Int],StorableType) -> (NodeId,[Int],StorableType) -> Instruction InitSemInf
makeCopyFromIds (idFrom, pathFrom, typeFrom) (idTo, pathTo, typeTo) =
    makeCopyFromExprs (asExpr idFrom pathFrom typeFrom) (asExpr idTo pathTo typeTo)
  where
    asExpr nid path typ = genVar nid path (compileStorableType typ)
-- | Instruction calling the procedure @"copy"@ with @from@ as its input
-- argument and @to@ (converted by 'toLeftValue') as its output argument.
makeCopyFromExprs :: Expression InitSemInf -> Expression InitSemInf -> Instruction InitSemInf
makeCopyFromExprs from to = Instruction
    { instructionData = ProcedureCallInstruction call
    , instructionSemInf = ()
    }
  where
    call = ProcedureCall
      { nameOfProcedureToCall = "copy"
      , actualParametersOfProcedureToCall = [sourceArg, targetArg]
      , procedureCallSemInf = ()
      }
    sourceArg = ActualParameter
      { actualParameterData = InputActualParameter from
      , actualParameterSemInf = ()
      }
    targetArg = ActualParameter
      { actualParameterData = OutputActualParameter (toLeftValue to)
      , actualParameterSemInf = ()
      }
-- | One copy program per leaf of @typeStructure@. Each copies node
-- @fromId@'s variable at that path into node @toId@'s variable at the same
-- path. @isOutputCopying@ is stored as each primitive's semantic-info flag.
copyNode :: NodeId -> NodeId -> Tuple StorableType -> Bool -> [Program InitSemInf]
copyNode fromId toId typeStructure isOutputCopying = tupleWalk copyLeaf typeStructure
  where
    copyLeaf path typ = Program
      { programConstruction = PrimitiveProgram $ Primitive
          { primitiveInstruction = makeCopyFromIds (fromId, path, typ) (toId, path, typ)
          , primitiveSemInf = isOutputCopying
          }
      , programSemInf = ()
      }
-- | One copy program per leaf of the source tuple @ifcOut@. Each copies the
-- source value into node @nid@'s variable at the same path.
-- @isOutputCopying@ is stored as each primitive's semantic-info flag.
copyResult :: Tuple Source -> NodeId -> Tuple StorableType -> Bool -> [Program InitSemInf]
copyResult ifcOut nid outTyp isOutputCopying =
    tupleWalk copyLeaf (tupleZip (ifcOut, outTyp))
  where
    copyLeaf path (src, typ) = Program
      { programConstruction = PrimitiveProgram $ Primitive
          { primitiveInstruction = makeCopyFromExprs (transformSourceToExpr src typ) dest
          , primitiveSemInf = isOutputCopying
          }
      , programSemInf = ()
      }
      where
        dest = genVar nid path (compileStorableType typ)
-- | One copy program per leaf of the source tuple @ifcOut@. Each copies the
-- source value into the procedure output parameter at the same path.
-- @isOutputCopying@ is stored as each primitive's semantic-info flag.
copyToOutput :: Tuple Source -> Tuple StorableType -> Bool -> [Program InitSemInf]
copyToOutput ifcOut outTyp isOutputCopying =
    tupleWalk copyLeaf (tupleZip (ifcOut, outTyp))
  where
    copyLeaf path (src, typ) = Program
      { programConstruction = PrimitiveProgram $ Primitive
          { primitiveInstruction = makeCopyFromExprs (transformSourceToExpr src typ) dest
          , primitiveSemInf = isOutputCopying
          }
      , programSemInf = ()
      }
      where
        dest = genOut path (compileStorableType typ)
-- | Expression that reads the given variable.
varToExpr :: Representation.Variable InitSemInf -> Expression InitSemInf
varToExpr v = Expression (LeftValueExpression lval) ()
  where
    lval = LeftValue
      { leftValueData = VariableLeftValue v
      , leftValueSemInf = ()
      }