module Feldspar.Compiler.Plugins.Unroll where
import Feldspar.Compiler.Imperative.Representation
import Feldspar.Compiler.Options
import Prelude
import Feldspar.Compiler.Imperative.Semantics
import Feldspar.Compiler.PluginArchitecture
-- Entry point of the unroll plugin.  'NoUnroll' leaves the program untouched;
-- 'Unroll n' runs Unroll_1 (downwards info: the count n) and then Unroll_2
-- (downwards info: 'Nothing') on the result.  Upwards results are discarded.
instance Plugin UnrollPlugin where
    type ExternalInfo UnrollPlugin = UnrollStrategy
    executePlugin UnrollPlugin strategy prg =
        case strategy of
            NoUnroll     -> prg
            Unroll count -> runPhase2 (runPhase1 count prg)
      where
        runPhase1 n p = fst (executeTransformationPhase Unroll_1 n p)
        runPhase2 p = fst (executeTransformationPhase Unroll_2 Nothing p)
-- | Tag type for the unroll plugin.  Its 'TransformationPhase' instance
-- carries no information in any direction.
data UnrollPlugin = UnrollPlugin

instance TransformationPhase UnrollPlugin where
    type Downwards UnrollPlugin = ()
    type Upwards UnrollPlugin = ()
    type From UnrollPlugin = ()
    type To UnrollPlugin = ()
-- | First unroll phase: takes the unroll count as downwards information and
-- rewrites parallel loops with 'trParLoop1'.
data Unroll_1 = Unroll_1

-- Every parallel loop program reports 'True' upwards.  Together with the
-- 'Combine' 'Bool' instance (logical or) this presumably lets a loop learn
-- that its core contains another parallel loop -- TODO confirm against the
-- framework's upwards-combination rules.
instance TransformationPhase Unroll_1 where
    type Downwards Unroll_1 = Int
    type Upwards Unroll_1 = Bool
    type From Unroll_1 = ()
    type To Unroll_1 = UnrollSemInf

    upwardsParallelLoopProgramInProgram _ _ _ _ _ = True
    transformParallelLoopProgramInProgram Unroll_1 = trParLoop1
-- | Second unroll phase: consumes the 'SemInfPrg' descriptors attached to
-- programs, renaming per-copy variables and offsetting the loop counter.
data Unroll_2 = Unroll_2

instance TransformationPhase Unroll_2 where
    type Downwards Unroll_2 = Maybe SemInfPrg
    type Upwards Unroll_2 = ()
    type From Unroll_2 = UnrollSemInf
    type To Unroll_2 = ()

    -- A program that carries its own descriptor overrides the inherited
    -- one.  Otherwise the inherited descriptor is passed down unchanged.
    downwardsProgram Unroll_2 inherited prg = case programSemInf prg of
        Nothing -> inherited
        own     -> own

    transformVariable Unroll_2 = trVariable
    transformVariableLeftValueInLeftValue Unroll_2 d = VariableLeftValue . trVariable d
    transformLeftValueExpressionInExpression Unroll_2 = trLVIE
-- | Semantic information of the intermediate program produced by 'Unroll_1'
-- and consumed by 'Unroll_2'.
data UnrollSemInf = UnrollSemInf

instance SemanticInfo UnrollSemInf where
    -- Program-level nodes may carry an unroll-copy descriptor.
    type ProgramInfo UnrollSemInf = Maybe SemInfPrg
    type EmptyInfo UnrollSemInf = Maybe SemInfPrg
    type PrimitiveInfo UnrollSemInf = Maybe SemInfPrg
    type SequenceInfo UnrollSemInf = Maybe SemInfPrg
    -- All remaining node kinds carry no information.
    type ActualParameterInfo UnrollSemInf = ()
    type ArrayConstantInfo UnrollSemInf = ()
    type ArrayElemReferenceInfo UnrollSemInf = ()
    type AssignmentInfo UnrollSemInf = ()
    type BlockInfo UnrollSemInf = ()
    type BoolConstantInfo UnrollSemInf = ()
    type BranchInfo UnrollSemInf = ()
    type ConstantInfo UnrollSemInf = ()
    type ExpressionInfo UnrollSemInf = ()
    type FloatConstantInfo UnrollSemInf = ()
    type FormalParameterInfo UnrollSemInf = ()
    type FunctionCallInfo UnrollSemInf = ()
    type InstructionInfo UnrollSemInf = ()
    type IntConstantInfo UnrollSemInf = ()
    type LeftValueInfo UnrollSemInf = ()
    type LocalDeclarationInfo UnrollSemInf = ()
    type ParallelLoopInfo UnrollSemInf = ()
    type ProcedureCallInfo UnrollSemInf = ()
    type ProcedureInfo UnrollSemInf = ()
    type SequentialLoopInfo UnrollSemInf = ()
    type VariableInfo UnrollSemInf = ()
-- Boolean upwards information (used by Unroll_1) is merged with logical or.
-- Note: this instance applies to every 'Bool' upwards type, not only this
-- plugin's.
instance Combine Bool where
    combine x y = x || y
-- | Descriptor that 'trParLoop1' attaches to the top-level program of each
-- unrolled copy of a loop body, and that 'trVariable' / 'trLVIE' read back.
data SemInfPrg = SemInfPrg
    { position :: Int       -- ^ Index of the copy, counted from 0.  Added to
                            --   the loop counter and used as the @_u@ suffix.
    , varNames :: [String]  -- ^ Names of the loop body's local declarations.
    , loopVar  :: String    -- ^ Name of the loop counter variable.
    }
    deriving (Eq, Show)
-- The default copy descriptor is "none".  Presumably used by the framework
-- for program nodes that are not explicitly annotated -- TODO confirm.
instance Default (Maybe SemInfPrg) where
    defaultValue = Nothing
-- | Unroll_2 treatment of a left value used as an expression.
--
-- The rewrite applies inside copy @k@ of an unrolled loop body, i.e. when
-- the downwards info is @Just s@ with @position s == k@.  There, a plain
-- variable whose name equals @loopVar s@ becomes the expression
-- @<transformed left value> + k@, built as an infix @+@ function call.
-- Every other left value is returned in its recursively transformed form.
-- The name test is done on the left value as it was before transformation.
--
-- NOTE(review): the addition is always typed as a signed 32-bit integer,
-- whatever the loop counter's type is -- TODO confirm this is intended.
trLVIE :: Downwards Unroll_2 -> LeftValue UnrollSemInf -> InfoFromLeftValueParts Unroll_2 -> ExpressionData ()
trLVIE d (LeftValue leftValue _) u = case (d, originalName) of
    (Just semInf, Just name)
        | name == loopVar semInf -> offsetBy (position semInf)
    _ -> orig
  where
    -- Name of the (untransformed) left value if it is a plain variable.
    originalName = case leftValue of
        VariableLeftValue var -> Just (getVarName var)
        _                     -> Nothing

    -- The recursively transformed left value, as an expression.
    orig = LeftValueExpression $ LeftValue (recursivelyTransformedLeftValueData u) ()

    -- @orig + k@ as an infix function call.
    offsetBy k = FunctionCallExpression $
        FunctionCall InfixOp (Numeric ImpSigned S32) "+" [Expression orig (), constK] ()
      where
        constK = Expression (ConstantExpression $ Constant (IntConstant $ IntConstantType k ()) ()) ()
-- | Unroll_2 treatment of a variable occurrence.
--
-- The renaming applies inside copy @k@ of an unrolled loop body, i.e. when
-- the downwards info is @Just s@ with @position s == k@.  There, a variable
-- whose name is listed in @varNames s@ is renamed to @<name>_u<k>@, the same
-- suffix 'trParLoop1' gives the duplicated declarations.  Other variables
-- keep their name.  In every case the semantic info is reset to '()'.
trVariable d v = case d of
    Just semInf
        | getVarName v `elem` varNames semInf ->
            v { variableName = variableName v ++ "_u" ++ show (position semInf)
              , variableSemInf = ()
              }
    _ -> v { variableSemInf = () }
-- | Unroll_1 treatment of a parallel loop.
--
-- An eligible loop is unrolled @d@ times, as follows:
--
--   * its step becomes @d@;
--   * its body becomes a sequence of @d@ copies of the original body;
--   * the core's local declarations are duplicated, with suffix @_u<k>@
--     for copy @k@;
--   * the top-level program of copy @k@ is tagged with
--     @'SemInfPrg' k declNames counterName@, so that Unroll_2 can rename
--     variable uses and offset the loop counter by @k@.
--
-- A loop is eligible when all of these hold:
--
--   * @d > 0@;
--   * its core's upwards info is 'False' (no parallel loop reported inside);
--   * its iteration count is either an integer constant divisible by @d@,
--     or not a constant at all.
--
-- Any other loop is rebuilt, unchanged, from its transformed parts.
--
-- NOTE(review): for non-constant iteration counts no remainder loop is
-- emitted.  The result is therefore only correct if the runtime count is a
-- multiple of @d@ -- TODO confirm the caller guarantees this.
-- NOTE(review): the original loop step is ignored when unrolling (copies
-- are offset by 0 .. d-1), which assumes the step is 1 -- TODO confirm.
trParLoop1 :: Downwards Unroll_1 -> ParallelLoop () -> InfoFromParallelLoopParts Unroll_1 -> ProgramConstruction UnrollSemInf
trParLoop1 d pl u
    | shouldUnroll = ParallelLoopProgram unrolledLoop
    | otherwise    = ParallelLoopProgram trPl
  where
    unrollNum = d

    -- A non-positive count would produce a zero/negative step and an empty
    -- body, and would make the divisibility test divide by zero.
    shouldUnroll = unrollNum > 0
        && not (upwardsInfoFromParallelLoopCore u)
        && (unrollPossible || varInExpr)

    -- Recursively transformed parts of the loop.
    origLoopCore   = recursivelyTransformedParallelLoopCore u
    iterExpr       = recursivelyTransformedNumberOfIterations u
    loopCounterVar = recursivelyTransformedParallelLoopConditionVariable u
    loopCounter    = getVarName loopCounterVar

    -- The loop rebuilt from its transformed parts, otherwise unchanged.
    trPl = ParallelLoop loopCounterVar iterExpr (parallelLoopStep pl) origLoopCore ()

    iterNum        = iterNumFromExpr iterExpr
    unrollPossible = maybe False (\n -> n `mod` unrollNum == 0) iterNum
    varInExpr      = maybe True (const False) iterNum

    origDecls = blockDeclarations origLoopCore
    origPrg   = blockInstructions origLoopCore
    declNames = map getVarNameDecl origDecls

    -- Copy indices 0 .. unrollNum - 1.
    copies = [0 .. unrollNum - 1]

    -- All declarations of copy 0, then all of copy 1, and so on.
    unrollDecls =
        [ renameDeclaration decl (getVarNameDecl decl ++ "_u" ++ show k)
        | k <- copies
        , decl <- origDecls
        ]

    copyPrgs =
        [ origPrg { programSemInf = Just (SemInfPrg k declNames loopCounter) }
        | k <- copies
        ]
    unrollPrg = Program (SequenceProgram $ Sequence copyPrgs Nothing) Nothing

    unrolledCore = origLoopCore
        { blockDeclarations = unrollDecls
        , blockInstructions = unrollPrg
        , blockSemInf = ()
        }

    unrolledLoop = trPl
        { parallelLoopStep = unrollNum
        , parallelLoopCore = unrolledCore
        , parallelLoopSemInf = ()
        }
-- | The iteration count, if the expression is a literal integer constant.
-- Any other expression yields 'Nothing'.
iterNumFromExpr expr = case expr of
    Expression (ConstantExpression (Constant (IntConstant (IntConstantType i _)) _)) _ -> Just i
    _ -> Nothing
-- | 'True' for any 'Just', 'False' for 'Nothing'.
isJust :: Maybe a -> Bool
isJust = maybe False (const True)
-- | Name of the variable introduced by a local declaration.
getVarNameDecl = getVarName . localVariable
-- | Name of a variable.
getVarName = variableName
-- | Contents of a 'Just'.  Partial: calls 'error' on 'Nothing'.
valueFromJust :: Maybe a -> a
valueFromJust = maybe (error "This was Nothing") id
-- | Give the variable of a local declaration a new name.
renameDeclaration decl newName = decl { localVariable = renamed }
  where
    renamed = renameVariable (localVariable decl) newName
-- | Replace a variable's name.  All other fields are kept.
renameVariable var newName = var { variableName = newName }
-- | @elementOf ss s@ is 'True' when @s@ occurs in @ss@.  It stops at the
-- first match instead of counting every match, so it also terminates on an
-- infinite list that contains @s@.
elementOf :: Eq a => [a] -> a -> Bool
elementOf ss s = s `elem` ss