module Feldspar.Compiler.Plugins.HandlePrimitives
( HandlePrimitives(..)
, makeAssignment
, makePrimitive
) where
import Data.List (find)
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Imperative.Semantics (SemanticInfo)
import Feldspar.Compiler.Imperative.CodeGeneration (simpleType, typeof, toLeftValue)
import Feldspar.Compiler.PluginArchitecture (TransformationPhase(..), Plugin(..), InfoFromPrimitiveParts(..), InfoFromProcedureParts(..))
import Feldspar.Compiler.PluginArchitecture.DefaultConvert (Combine(..))
import Feldspar.Compiler.Options
import Feldspar.Compiler.Error
-- | Raises an internal compiler error attributed to this plugin
-- ("PluginArch/HandlePrimitives") with the given message.
handlePrimitivesError msg = handleError "PluginArch/HandlePrimitives" InternalError msg
-- | Preparatory transformation phase, run by 'executePlugin' before
-- 'HandlePrimitives'. A procedure whose body contains a call to @trace@ gets
-- its body wrapped between calls to @traceStart@ and @traceEnd@.
data HandleTraceFunctions = HandleTraceFunctions

instance TransformationPhase HandleTraceFunctions where
  -- No semantic information is attached to or produced for the tree.
  type From HandleTraceFunctions = ()
  type To HandleTraceFunctions = ()
  -- Nothing is propagated downwards.
  type Downwards HandleTraceFunctions = ()
  -- Upwards: whether a @trace@ call was seen (presumably merged across
  -- sibling parts through the 'Combine' 'Bool' instance).
  type Upwards HandleTraceFunctions = Bool
  -- Procedure level: consume the flag and insert traceStart/traceEnd.
  transformProcedure = transformProcedure'
  -- Primitive level: raise the flag for @trace@ calls.
  upwardsPrimitiveProgramInProgram = upwardsPrimitiveProgramInProgram'
-- | Transformation phase that lowers primitive procedure calls into
-- platform-specific assignments and procedure calls ('transformPrimitive'').
data HandlePrimitives = HandlePrimitives

instance TransformationPhase HandlePrimitives where
  -- Downwards: (default array size, target platform), as supplied by
  -- 'executePlugin'.
  type Downwards HandlePrimitives = (Int, Platform)
  -- Nothing is collected on the way up.
  type Upwards HandlePrimitives = ()
  -- No semantic information is attached to or produced for the tree.
  type From HandlePrimitives = ()
  type To HandlePrimitives = ()
  transformPrimitiveProgramInProgram = transformPrimitive'
-- | Plugin entry point. With the 'NoPrimitiveInstructionHandling' debug option
-- the procedure is returned untouched; otherwise 'HandleTraceFunctions' runs
-- first and its result is fed into 'HandlePrimitives'.
instance Plugin HandlePrimitives where
  -- (default array size, debug option, target platform)
  type ExternalInfo HandlePrimitives = (Int, DebugOption, Platform)
  executePlugin _ (defArrSize, debugOpt, platform) procedure =
    case debugOpt of
      NoPrimitiveInstructionHandling -> procedure
      _ -> lowerPrimitives (insertTraceCalls procedure)
    where
      insertTraceCalls = fst . executeTransformationPhase HandleTraceFunctions ()
      lowerPrimitives = fst . executeTransformationPhase HandlePrimitives (defArrSize, platform)
-- | Merges upward 'Bool' flags with logical OR, so a flag raised in any part
-- survives the merge. Used for the 'Bool' upward info of
-- 'HandleTraceFunctions'.
--
-- NOTE: this is an orphan instance: 'Combine' is imported from
-- "Feldspar.Compiler.PluginArchitecture.DefaultConvert" and 'Bool' comes from
-- the Prelude.
instance Combine Bool where
  combine = (||)
-- | Upward pass of 'HandleTraceFunctions': 'True' exactly when the primitive
-- is a procedure call named @trace@.
--
-- Any other kind of instruction (e.g. an assignment) is simply not a trace
-- call and yields 'False', instead of failing a pattern match.
upwardsPrimitiveProgramInProgram'
  :: HandleTraceFunctions -> () -> Primitive ()
  -> InfoFromPrimitiveParts HandleTraceFunctions -> ProgramConstruction () -> Bool
upwardsPrimitiveProgramInProgram' _ _ old _ _ =
  case instructionData (primitiveInstruction old) of
    ProcedureCallInstruction call -> nameOfProcedureToCall call == "trace"
    _ -> False
-- | Procedure-level step of 'HandleTraceFunctions'.
--
-- If the upward info from the procedure body reports a @trace@ call, the
-- body's program is bracketed by argument-less calls to @traceStart@ and
-- @traceEnd@:
--
--   * a sequence gets them prepended/appended to its program list;
--   * any other program construction is wrapped in a new three-element
--     sequence @[traceStart, <original program>, traceEnd]@.
--
-- Procedures without a trace call are returned unchanged.
transformProcedure' :: HandleTraceFunctions -> () -> Procedure () -> InfoFromProcedureParts HandleTraceFunctions -> Procedure ()
transformProcedure' _ _ old upwardInfos
  | not (upwardsInfoFromProcedureBody upwardInfos) = old
  | otherwise =
      old { procedureBody = body { blockInstructions = bodyPrg { programConstruction = bracketed } } }
  where
    body = procedureBody old
    bodyPrg = blockInstructions body
    bracketed = case programConstruction bodyPrg of
      SequenceProgram sq ->
        SequenceProgram sq { sequenceProgramList = traceStart : sequenceProgramList sq ++ [traceEnd] }
      _ -> SequenceProgram (Sequence [traceStart, bodyPrg, traceEnd] ())
    traceStart = traceCall "traceStart"
    traceEnd = traceCall "traceEnd"
    -- A program consisting of a single argument-less procedure call.
    traceCall procName =
      Program (PrimitiveProgram $ Primitive (Instruction (ProcedureCallInstruction $ ProcedureCall procName [] ()) ()) ()) ()
-- | Main rewrite of the 'HandlePrimitives' phase: lowers one primitive
-- procedure call into concrete imperative code for the target 'Platform'.
--
-- Parameters are split into inputs and outputs with 'isInparam'. Built-in
-- primitives handled here:
--
--   * @(!)@   @[arr, idx] -> [out]@: assigns @arr[idx]@ to @out@.
--   * @setIx@ @[original, idx, val] -> [result]@: assigns @original@ to
--     @result@, then @val@ to @result[idx]@.
--   * @copy@  @[in1] -> [out]@: assigns @in1@ to @out@.
--   * @trace@ @[label, original] -> [result]@: assigns @original@ to
--     @result@, then calls procedure @trace@ with @result@ and @label@.
--
-- Any other call is looked up in @'primitives' pfm@ by name and input types:
--
--   * a @Left cd@ entry yields a single primitive built from descriptor @cd@;
--   * a @Right tp@ entry is expanded by @tp@ into a sequence of primitives.
--
-- If nothing matches, the recursively transformed instruction is returned
-- unchanged. Instructions that are not procedure calls are also returned
-- unchanged, rather than failing a pattern match.
--
-- The 'Int' in the downward info is the default array size. It is forwarded
-- to every 'makePrimitive' / 'makeAssignment' call, including those produced
-- by @Right@ expansions, so array copies of undefined length get a real size.
transformPrimitive' :: HandlePrimitives -> (Int,Platform) -> Primitive () -> InfoFromPrimitiveParts HandlePrimitives -> ProgramConstruction ()
transformPrimitive' _ (defArrSize, pfm) old modified' =
  case (procCallOf (primitiveInstruction old), procCallOf modified) of
    (Just oldCall, Just newCall) ->
      -- The name comes from the untransformed primitive; the parameters come
      -- from the recursively transformed one.
      lowerCall (nameOfProcedureToCall oldCall) (actualParametersOfProcedureToCall newCall)
    _ -> mkPrg modified
  where
    modified = recursivelyTransformedPrimitiveInstruction modified'

    mkPrg x = PrimitiveProgram (Primitive x ())
    inPrg x = Program (mkPrg x) ()

    procCallOf instr = case instructionData instr of
      ProcedureCallInstruction call -> Just call
      _ -> Nothing

    lowerCall nameS params = case (nameS, inps, outs) of
        ("(!)", [arr, idx], [out]) ->
          mkPrg $ makeAssignment pfm (lToe $ arrayElem (toLeftValue arr) idx) out defArrSize
        ("setIx", [original, idx, val], [result]) ->
          SequenceProgram $ Sequence
            [ inPrg $ makeAssignment pfm original result defArrSize
            , inPrg $ makeAssignment pfm val (arrayElem result idx) defArrSize
            ] ()
        ("copy", [in1], [out]) ->
          mkPrg $ makeAssignment pfm in1 out defArrSize
        ("trace", [label, original], [result]) ->
          SequenceProgram $ Sequence
            [ inPrg $ makeAssignment pfm original result defArrSize
            , inPrg $ makePrimitive pfm (Proc "trace" firstInFP) [lToe result, label] [] defArrSize
            ] ()
        _ -> case find matchPrimitive (primitives pfm) of
          Just (fd, Right tp) ->
            SequenceProgram $ Sequence
              [ inPrg $ makePrimitive pfm cd' inps' outs' defArrSize
              | (cd', inps', outs') <- tp fd inps outs
              ] ()
          Just (_, Left cd) -> mkPrg $ makePrimitive pfm cd inps outs defArrSize
          Nothing -> mkPrg modified
      where
        inps = map aToE $ filter isInparam params
        outs = map aToL $ filter (not . isInparam) params
        -- Left value denoting element @idx@ of array @arr@.
        arrayElem arr idx = LeftValue (ArrayElemReferenceLeftValue $ ArrayElemReference arr idx ()) ()
        matchPrimitive (fd, _) = fName fd == nameS && matchTypes' (inputs fd) inps
        -- Same arity, and every declared input type accepts the actual argument's type.
        matchTypes' :: [TypeDesc] -> [Expression ()] -> Bool
        matchTypes' descs args =
          length descs == length args && and (zipWith machTypes descs (map typeof args))
-- | Instruction that stores the value of the source expression into the
-- destination left value. This is 'makePrimitive' with the 'Assig'
-- descriptor; the 'Int' (default array size) is passed through to it.
makeAssignment :: Platform -> Expression () -> LeftValue () -> Int -> Instruction ()
makeAssignment pfm src dst = makePrimitive pfm Assig [src] [dst]
-- | Builds the instruction implementing C primitive @desc@ applied to input
-- expressions @inps@ and output locations @outs@.
--
-- 'Assig' with exactly one input and one output:
--
--   * if the input's type is simple ('simpleType'), a plain assignment;
--   * if it is an array type, a call to the @copy@ procedure, passing the
--     total element count from 'arraySize' (undefined dimensions use the
--     'Int' argument);
--   * any other input type, or any other parameter count, is an internal
--     error.
--
-- Other descriptors:
--
--   * A non-'Proc' descriptor whose first output has a simple type becomes an
--     assignment of a function-call expression to that output: 'Op1' gives a
--     prefix operator, 'Op2' an infix operator, 'Fun' a named function.
--   * Everything else becomes a procedure call with the inputs followed by
--     the outputs.
--
-- When @funPf desc@ is not 'noneFP', the emitted name is suffixed with the
-- platform type names ('toFunName') of the parameters it selects. Any other
-- descriptor/arity combination raises an internal error. The 'Int' argument
-- is only consulted for 'Assig'.
makePrimitive :: Platform -> CPrimDesc -> [Expression ()] -> [LeftValue ()] -> Int -> Instruction ()
makePrimitive pfm Assig [in1] [out] defArrSize
  | simpleType inType = Instruction (AssignmentInstruction $ Assignment out in1 ()) ()
  | otherwise = case inType of
      ImpArrayType _ _ ->
        -- Non-simple array values are copied by calling the `copy` procedure
        -- with the total element count.
        makePrimitive pfm (Proc "copy" firstInFP) [in1, intToCe $ arraySize inType defArrSize] [out] 0
      _ -> handlePrimitivesError $ "Unknown type in makePrimitive:\n" ++ show inType
  where
    inType = typeof in1
makePrimitive _ Assig _ _ _ =
  handlePrimitivesError "Wrong number of parameters for an assignment. (Parallel assignment not allowed.)"
makePrimitive pfm desc inps outs _ =
  case outs of
    -- Pattern-matching the first output (instead of `head outs`) makes an
    -- empty output list fall through to `procCall`, which reports a proper
    -- error rather than crashing.
    (out : _) | isNotProc desc && simpleType (typeof out) ->
      Instruction (AssignmentInstruction $ Assignment out (Expression (FunctionCallExpression $ funCall out) ()) ()) ()
    _ -> Instruction (ProcedureCallInstruction procCall) ()
  where
    funCall out = case (desc, length inps, length outs) of
      (Op1 op, 1, 1) -> FunctionCall PrefixOp (typeof out) op inps ()
      (Op2 op, 2, 1) -> FunctionCall InfixOp (typeof out) op inps ()
      (Fun _ _, _, 1) -> FunctionCall SimpleFun (typeof out) completeFunName inps ()
      _ -> errorMessage
    procCall = case (desc, length outs) of
      (Fun _ _, 1) -> procCall'
      (Proc _ _, _) -> procCall'
      _ -> errorMessage
    procCall' = ProcedureCall completeProcName (inps' ++ outs') ()
    inps' = map eToA inps
    outs' = map lToA outs
    completeFunName
      | funPf desc == noneFP = cName desc
      | otherwise = cName desc ++ "_fun" ++ apsToName
    completeProcName
      | funPf desc == noneFP = cName desc
      | otherwise = cName desc ++ apsToName
    -- "_<type>" for each parameter selected by the descriptor's name policy.
    apsToName = concatMap (("_" ++) . toFunName pfm . typeof) apsToNameList
    apsToNameList = take (useInputs $ funPf desc) inps' ++ take (useOutputs $ funPf desc) outs'
    isNotProc (Proc _ _) = False
    isNotProc _ = True
    errorMessage = handlePrimitivesError $ "Wrong C primitive description or different number of parameters:\n"
      ++ show desc ++ "\n" ++ concatMap ((", " ++) . show . typeof) inps
-- | Identifier fragment naming a type, used to build type-specialised
-- function names in 'makePrimitive'.
--
-- Nested array types collapse to a single @arrayOf_@ prefix followed by the
-- innermost element type's name. A scalar type is looked up in the
-- platform's @types@ table, with spaces in its name replaced by underscores.
-- A type missing from the table is an internal error; the offending type is
-- included in the message.
toFunName :: Platform -> Type -> String
toFunName pfm (ImpArrayType _ t@(ImpArrayType _ _)) = toFunName pfm t
toFunName pfm (ImpArrayType _ t) = "arrayOf_" ++ toFunName pfm t
toFunName pfm t = case find (\(t', _, _) -> t == t') (types pfm) of
  Just (_, _, s) -> map underscore s
  Nothing -> handlePrimitivesError $ "Unhandled type in platform " ++ name pfm ++ ": " ++ show t
  where
    underscore ' ' = '_'
    underscore c = c
-- | Total number of scalar elements in a (possibly nested) array type: the
-- product of each dimension's length ('Norm' or 'Defined'). An 'Undefined'
-- dimension counts as @defaultArraySize@.
--
-- A non-array type counts as a single element. Previously this case was a
-- pattern-match failure at the top level.
arraySize :: Type -> Int -> Int
arraySize ty defaultArraySize = elems ty
  where
    elems (ImpArrayType (Norm n) inner) = n * elems inner
    elems (ImpArrayType (Defined n) inner) = n * elems inner
    elems (ImpArrayType Undefined inner) = defaultArraySize * elems inner
    elems _ = 1
-- | 'True' for an input actual parameter, 'False' for an output one.
isInparam (ActualParameter param _) = case param of
  InputActualParameter _ -> True
  OutputActualParameter _ -> False
-- | Expression carried by an input actual parameter (partial: inputs only).
aToE param = case param of
  ActualParameter (InputActualParameter expr) () -> expr

-- | Left value carried by an output actual parameter (partial: outputs only).
aToL param = case param of
  ActualParameter (OutputActualParameter lv) () -> lv

-- | Wraps an expression as an input actual parameter.
eToA expr = ActualParameter (InputActualParameter expr) ()

-- | Wraps a left value as an output actual parameter.
lToA lv = ActualParameter (OutputActualParameter lv) ()
-- | Integer constant expression holding the given 'Int'.
intToCe n = Expression constantExpr ()
  where
    constantExpr = ConstantExpression (Constant (IntConstant (IntConstantType n ())) ())

-- | Wraps a left value as a left-value expression.
lToe lv = Expression (LeftValueExpression lv) ()