module Feldspar.Compiler.Imperative.FromCore
(
Compilable(..)
, numArgs
, fromCore
) where
import Control.Monad.State
import Control.Monad.Writer
import Data.List
import Data.String
import Data.Bits
import qualified Feldspar.DSL.Expression as Lang
import qualified Feldspar.DSL.Lambda as Lang
import Feldspar.DSL.Lambda hiding (Value, Variable)
import Feldspar.DSL.Network
import Feldspar.Set (universal)
import qualified Feldspar.Core.Types as Lang
import Feldspar.Core.Representation hiding (variable)
import qualified Feldspar.Core.Representation as Lang
import qualified Feldspar.Core.Functions.Array as Lang
import qualified Feldspar.Core.Functions.Tuple as Lang
import Feldspar.Core.Functions.Num ()
import Feldspar.Range (Range(..),BoundedInt)
import Data.Typeable (Typeable, typeOf)
import Feldspar.Compiler.Imperative.Representation hiding (blockProgram)
import Feldspar.Compiler.Imperative.Frontend
import Feldspar.Compiler.Backend.C.CodeGeneration
import Feldspar.Compiler.Error
import Feldspar.Compiler.Backend.C.Library
-- | Code-generation monad.
--
-- The state is an 'Integer' (seeded with 0 by 'fromCore'; presumably the
-- fresh-name counter behind 'newName' — TODO confirm). The writer collects
-- top-level definitions such as the procedures emitted with 'tell'.
type Transformer a = StateT Integer (Writer [Definition ()]) a

-- | Position of one component inside a (possibly nested) multi-output
-- network; turned into a variable-name suffix by 'pathIdent'.
type Path = [Int]

-- | Every component of a multi-output network: its path and Feldspar type.
type MultiVar = [(Path, Lang.TypeRep)]

-- | Imperative type used for loop indices and offset arithmetic
-- (unsigned, 32 bits).
indexType :: Type
indexType = NumType Unsigned S32
-- | A single (one-component) destination for generated code.
--
-- A 'SingleLoc' pairs the imperative lvalue expression ('singleLV') with a
-- Feldspar network that denotes the same value ('feedback').
--
-- 'Shifted' wraps a location with an additional offset, given both as an
-- imperative expression and as a Feldspar length network. The offsets are
-- summed by 'sumShifts' and folded into element indices by
-- 'indexLocSingle'.
--
-- Note: 'singleLV' and 'feedback' are fields of the 'SingleLoc' constructor
-- only, so applying them to a 'Shifted' value fails at runtime; use
-- 'getSingleLV' to reach the underlying lvalue.
data SingleLoc a
    = Shifted (Expression (), FeldNetwork (Out ()) Lang.Length) (SingleLoc a)
    | SingleLoc
        { singleLV :: Expression ()
        , feedback :: FeldNetwork (Out ()) a
        }
-- | The lvalue underneath any number of 'Shifted' wrappers.
getSingleLV :: SingleLoc a -> Expression ()
getSingleLV loc = case loc of
    Shifted _ inner -> getSingleLV inner
    SingleLoc lv _  -> lv
-- | Where the result of a network node is written.
--
-- * 'S' is a single location. It is only available for single-output
--   networks, as its result index is @Out ()@.
-- * 'M' is an identifier from which one variable per output path is derived
--   with 'pathIdent' (see 'genDeclarations' and 'genLocExprs').
data Location ra a
  where
    S :: { single :: SingleLoc a } -> Location (Out ()) a
    M :: { multi :: Ident } -> Location ra a
-- | Variable name for the component at @path@ of the value named @ident@.
-- The path elements are appended with @_@ separators, e.g. @x@ and @[1,2]@
-- give @x_1_2@. The empty path yields @ident@ itself.
pathIdent :: Ident -> Path -> Ident
pathIdent ident path = intercalate "_" (ident : map show path)
-- | View a location as a single location.
--
-- An 'S' location is returned unchanged and the type is ignored. For 'M',
-- the result reads and writes the variable @ident@ of type @typ@, and its
-- feedback network is the Feldspar variable of the same name.
toSingle :: Type -> Location (Out ()) a -> SingleLoc a
toSingle _ (S single) = single
toSingle typ (M ident) =
    SingleLoc (varExpr (variable ident typ)) (Lang.Variable ident)
-- | Feldspar-level array indexing on output networks.
--
-- Both arguments are lifted with @fromOutEdge universal@, indexed with
-- 'Lang.getIx', and the result is turned back into a network.
getIx :: Lang.Type a
      => FeldNetwork (Out ()) [a]
      -> FeldNetwork (Out ()) Lang.Index
      -> FeldNetwork (Out ()) a
getIx arr ix = undoEdge (unData (Lang.getIx arrData ixData))
  where
    arrData = fromOutEdge universal arr
    ixData  = fromOutEdge universal ix
-- | Feldspar-level addition of two word-valued output networks, used to
-- offset feedback indices by the lengths of 'Shifted' locations.
addExpr
    :: FeldNetwork (Out ()) Lang.DefaultWord
    -> FeldNetwork (Out ()) Lang.DefaultWord
    -> FeldNetwork (Out ()) Lang.DefaultWord
addExpr x y = undoEdge (unData (dx + dy))
  where
    dx = fromOutEdge universal x
    dy = fromOutEdge universal y
-- | Imperative infix @(+)@ of two expressions, typed as 'indexType'.
add :: Expression () -> Expression () -> Expression ()
add lhs rhs = FunctionCall "(+)" indexType InfixOp [lhs, rhs] () ()
-- | Location of one element of an array location.
--
-- 'Shifted' wrappers are peeled off and their offsets collected in the
-- first argument. At the underlying 'SingleLoc', all collected offsets are
-- added to the imperative index (@ixVar@) and to the Feldspar index
-- (@ixExpr@), and the element is selected with 'arrayElem' and 'getIx'.
indexLocSingle
    :: Lang.Type a
    => [(Expression (), FeldNetwork (Out ()) Lang.Length)]
    -> SingleLoc [a]
    -> Expression ()
    -> FeldNetwork (Out ()) Lang.Index
    -> SingleLoc a
indexLocSingle shifts (Shifted shift inner) ixVar ixExpr =
    indexLocSingle (shift : shifts) inner ixVar ixExpr
indexLocSingle shifts (SingleLoc lv fb) ixVar ixExpr =
    SingleLoc (arrayElem lv offsetVar) (getIx fb offsetExpr)
  where
    offsetVar  = foldl add ixVar [e | (e, _) <- shifts]
    offsetExpr = foldl addExpr ixExpr [n | (_, n) <- shifts]
-- | Total imperative offset of a location: the sum of all 'Shifted'
-- offsets, ending in the constant 0.
sumShifts :: SingleLoc a -> Expression ()
sumShifts loc = case loc of
    SingleLoc _ _        -> createConstantExpression (intConst 0)
    Shifted (e, _) inner -> e `add` sumShifts inner
-- | Element location at the given index of an array location, taking all
-- 'Shifted' offsets into account (see 'indexLocSingle').
indexLoc
    :: Lang.Type a
    => SingleLoc [a]
    -> Expression ()
    -> FeldNetwork (Out ()) Lang.Index
    -> SingleLoc a
indexLoc = indexLocSingle []
-- | Location of the first component of a pair location.
--
-- The lvalue becomes the first member of the struct type of the original
-- lvalue, and the feedback becomes 'Lang.getFst' of the original feedback.
-- Only plain 'SingleLoc' values are accepted; the struct-type pattern is
-- lazy and is checked only when the member name is needed.
selectFst :: (Lang.Type a, Lang.Type b)
          => SingleLoc (a,b)
          -> SingleLoc a
selectFst (SingleLoc lv fb) = SingleLoc (StructField lv fstName () ()) fstFb
  where
    StructType members = typeof lv
    fstName = fst (head members)
    fstFb   = undoEdge (unData (Lang.getFst (fromOutEdge universal fb)))
-- | Location of the second component of a pair location.
--
-- The lvalue becomes the second member of the struct type of the original
-- lvalue, and the feedback becomes 'Lang.getSnd' of the original feedback.
-- Only plain 'SingleLoc' values are accepted.
selectSnd :: (Lang.Type a, Lang.Type b)
          => SingleLoc (a,b)
          -> SingleLoc b
selectSnd (SingleLoc lv fb) = SingleLoc (StructField lv sndName () ()) sndFb
  where
    StructType members = typeof lv
    sndName = fst (head (tail members))
    sndFb   = undoEdge (unData (Lang.getSnd (fromOutEdge universal fb)))
-- | Network that reads back the value stored at a location.
--
-- For a single location this is its feedback network, which requires a
-- plain 'SingleLoc'. For a multi-variable location it is the Feldspar
-- variable named by the identifier.
locToNode :: Location (Out ra) a -> FeldNetwork (Out ra) a
locToNode loc = case loc of
    S (SingleLoc _ node) -> node
    M ident              -> Lang.Variable ident
-- | Path and Feldspar type of every edge of an input network.
multiVarIn :: FeldNetwork (In ra) a -> MultiVar
multiVarIn net = listEdge (\path edge -> (path, edgeType (edgeInfo edge))) net
-- | An expression that needs no preparatory statements.
simpleExpr :: Expression () -> Transformer ([Program ()], Expression ())
simpleExpr e = return ([] :: [Program ()], e)
-- | Whether a node must be generated as statements into a location
-- (control structures, non-inlined calls, pairs, length and element
-- updates, array literals) rather than as a plain expression.
isComplex :: FeldNetwork (Out ra) a -> Bool
isComplex net = case net of
    Inject (Node Condition)    :$: _ :$: _ :$: _       -> True
    Inject (Node Parallel)     :$: _ :$: _ :$: _       -> True
    Inject (Node Sequential)   :$: _ :$: _ :$: _ :$: _ -> True
    Inject (Node ForLoop)      :$: _ :$: _ :$: _       -> True
    Inject (Node (NoInline _)) :$: _ :$: _             -> True
    Inject (Node SetLength)    :$: _ :$: _             -> True
    Inject (Node Pair)         :$: _ :$: _             -> True
    Inject (Node SetIx)        :$: _ :$: _ :$: _       -> True
    _                                                  -> isArrayLit net
-- | Constant expression for a Feldspar literal value.
genLiteralExpression ::
    Lang.Type a => a -> Transformer ([Program ()], Expression ())
genLiteralExpression lit =
    simpleExpr (ConstExpr (compileDataRep (Lang.dataRep lit)) ())
-- | Expression for a (possibly matched) variable node.
--
-- The variable name is the traced base identifier extended with the match
-- path. The caller must ensure 'traceVar' succeeds; otherwise the lazy
-- @Just@ pattern fails when the name is demanded.
genVarExpression ::
    Type -> FeldNetwork (Out ()) a -> Transformer ([Program ()], Expression ())
genVarExpression typ net = simpleExpr (varExpr (variable name typ))
  where
    Just base = traceVar net
    name = pathIdent base (matchPath net)
-- | Call expression of function @fun@ applied to every argument edge.
--
-- The argument expressions are generated in edge order, and their
-- preparatory statements are concatenated in the same order.
genApplyExpression
    :: Type -> String -> FeldNetwork (In ra) a
    -> Transformer ([Program ()], Expression ())
genApplyExpression typ fun args = do
    (preProgs, argExprs) <- fmap unzip (sequence (listEdge (const genExpressionIn) args))
    let call = FunctionCall fun typ SimpleFun argExprs () ()
    return (concat preProgs, call)
-- | Expression for a node that needs statements.
--
-- A fresh variable (prefix @w@) is declared in a block, the node is
-- generated into it, and that variable is returned as the expression.
genComplexExpression ::
    Type -> FeldNetwork (Out ()) a -> Transformer ([Program ()], Expression ())
genComplexExpression typ net = do
    tmp <- newName "w"
    body <- genNode (M tmp) net
    let tmpVar = variable tmp typ
        scope  = block [Declaration tmpVar Nothing ()] body
    return ([BlockProgram scope ()], VarExpr tmpVar ())
-- | 'genExpression' for an input edge, typed by the edge's Feldspar type.
genExpressionIn
    :: FeldNetwork (In ()) a
    -> Transformer ([Program ()], Expression ())
genExpressionIn edge =
    genExpression (compileTypeRep (edgeType (edgeInfo edge))) (undoEdge edge)
-- | Expression (plus preparatory statements) computing a node.
--
-- The cases are tried in order:
--
-- * non-array literals become constants;
-- * function nodes become calls;
-- * traceable variables become variable references;
-- * a match on anything else is an invariant violation;
-- * everything else goes through a temporary ('genComplexExpression').
genExpression
    :: Type
    -> FeldNetwork (Out ()) a
    -> Transformer ([Program ()], Expression ())
genExpression _ net@(Inject (Node (Literal lit)))
    | not (isArrayLit net) = genLiteralExpression lit
genExpression typ (Inject (Node (Function fun _)) :$: args) =
    genApplyExpression typ fun args
genExpression typ net
    | Just _ <- traceVar net = genVarExpression typ net
genExpression _ (Inject m :$: _)
    | isMatch m = localError InvariantViolation "matching on non-variable"
genExpression typ net = genComplexExpression typ net
-- | Uninitialised declarations of one variable per component, each named
-- with 'pathIdent' from @ident@ and typed from the component's Feldspar
-- type.
genDeclarations :: Ident -> MultiVar -> [Declaration ()]
genDeclarations ident vars = map declare vars
  where
    declare (path, t) =
        Declaration (variable (pathIdent ident path) (compileTypeRep t)) Nothing ()
-- | Copy a (possibly multi-component) value from the right-hand location to
-- the left-hand location, one 'copyProg' per component.
--
-- The equations are tried in order:
--
-- * single component, single-location target, multi-variable source;
-- * single-location target and source (the 'MultiVar' list is ignored);
-- * multi-variable target and source: one copy per path in @vars@;
-- * single component, multi-variable target, single-location source.
--
-- In the single-component 'M' cases the bare identifier is used, with no
-- path suffix. That agrees with 'pathIdent' only for the empty path.
--
-- NOTE(review): only 'SingleLoc' is matched inside 'S'. A 'Shifted'
-- location, or an 'S'/'M' mix with more than one component, falls through
-- every equation and fails with a pattern-match error at runtime.
genMultiCopy :: MultiVar -> Location ra a -> Location rb a -> [Program ()]
genMultiCopy [(_,typ)] (S (SingleLoc lv _)) (M rIdent) = [copyProg lv rhs]
  where
    rhs = varExpr $ variable rIdent $ compileTypeRep typ
genMultiCopy _ (S (SingleLoc lhs _)) (S (SingleLoc rhs _)) = [copyProg lhs rhs]
genMultiCopy vars (M ident) (M rIdent) =
    [ copyProg (varExpr lVar) (varExpr rVar)
    | (path,typ) <- vars
    , let ident' = pathIdent ident path
    , let rIdent' = pathIdent rIdent path
    , let typ' = compileTypeRep typ
    , let lVar = variable ident' typ'
    , let rVar = variable rIdent' typ'
    ]
genMultiCopy [(_,typ)] (M ident) (S (SingleLoc rhs _)) = [copyProg lhs rhs]
  where
    lhs = varExpr $ variable ident $ compileTypeRep typ
-- | Generate a let-binding, then hand its code to a continuation.
--
-- The bound network is generated into fresh variables named after the
-- binding's base name. It is wrapped in a block declaring those variables.
-- The continuation receives that block together with the let body, which
-- is instantiated with a reference to the fresh variable. Only 'Let' nodes
-- are accepted; callers check with 'isLet' first.
genLet
    :: ([Program ()] -> FeldNetwork ra a -> Transformer b)
    -> FeldNetwork ra a
    -> Transformer b
genLet withBody (Let baseName :$: bound :$: Lambda scope) = do
    boundIdent <- newName baseName
    boundProg <- genNode (M boundIdent) bound
    let boundDecls = genDeclarations boundIdent (resTypes bound)
        boundBlock = BlockProgram (block boundDecls boundProg) ()
    withBody [boundBlock] (scope (Lang.Variable boundIdent))
-- | Compute a node as an expression and assign it to a single location.
--
-- A plain location gets a 'copyProg'. A shifted location gets a
-- 'copyProgPos' at the summed shift offset.
genNodeExpression ::
    SingleLoc a -> Type -> FeldNetwork (Out ()) a -> Transformer [Program ()]
genNodeExpression loc typ net = do
    (pre, rhs) <- genExpression typ net
    return (pre ++ [assign rhs])
  where
    assign rhs = case loc of
        SingleLoc lv _ -> copyProg lv rhs
        Shifted _ _    -> copyProgPos (getSingleLV loc) (sumShifts loc) rhs
-- | Generate a node into a single location. Complex nodes (see
-- 'isComplex') go through 'genNode'; the rest are generated as an
-- expression plus assignment.
genNodeSingle
    :: SingleLoc a
    -> Type
    -> FeldNetwork (Out ()) a
    -> Transformer [Program ()]
genNodeSingle loc typ net
    | isComplex net = genNode (S loc) net
    | otherwise     = genNodeExpression loc typ net
-- | Generate statements that compute network @a@ and store its result in
-- @loc@.
--
-- There is one equation per kind of node:
--
-- * let-bindings;
-- * conditions;
-- * parallel and sequential array construction;
-- * for-loops;
-- * array literals;
-- * non-inlined calls (which also emit a procedure with 'tell');
-- * length and element updates;
-- * pairs;
-- * scalar literals and function applications.
--
-- Array producers written to a 'Shifted' location (the continuation of an
-- earlier producer) update the length with 'genSetLength', which handles
-- both plain and shifted locations.
genNode :: forall ra a
         . Location ra a
        -> FeldNetwork ra a
        -> Transformer [Program ()]
genNode loc a | isLet a = genLet genBody a
  where
    genBody letProg body = liftM (letProg++) $ genNode loc body
genNode loc (Inject (Node Condition) :$: cond :$: t :$: e) = do
    (condProg,condExpr) <- genExpressionIn cond
    thenProg <- genEdge loc t
    elseProg <- genEdge loc e
    let branchProg = Branch condExpr (block [] thenProg) (block [] elseProg) () ()
    return (condProg ++ [branchProg])
genNode loc a@(Inject (Node Parallel) :$: len :$: Lambda ixf :$: cont) = do
    (lenProg,lenExpr) <- genExpressionIn len
    ixIdent <- newName "i"
    let bodyExpr = ixf (Lang.Variable ixIdent)
        ixVar    = variable ixIdent indexType
        ixExpr   = Lang.Variable ixIdent
        -- NOTE(review): the element type is passed to 'toSingle' here, while
        -- the Sequential case passes the array type. This matters only when
        -- @loc@ is 'M', where the type goes into the generated variable.
        -- Confirm which is intended.
        loc'     = toSingle typ loc
        locBody  = indexLoc loc' (varExpr ixVar) ixExpr
        locCont  = Shifted (lenExpr, undoEdge len) loc'
        setLen   = genSetLength loc' lenExpr
    bodyProg <- genEdge (S locBody) bodyExpr
    contProg <- genEdge (S locCont) cont
    return (lenProg ++ [setLen, ParLoop ixVar lenExpr 1 (block [] bodyProg) () ()] ++ contProg)
  where
    [(_, Lang.ArrayType _ t)] = resTypes a
    typ = compileTypeRep t
genNode loc a@(Inject (Node Sequential) :$: len :$: initial :$: Lambda step :$: Lambda cont) = do
    (lenProg,lenExpr) <- genExpressionIn len
    stepIdent <- newName "x"
    let stIdent = stepIdent ++ "_2"
    tempIdent <- newName "temp"
    ixIdent <- newName "i"
    initProg <- genEdge (M stIdent) initial
    let step'        = step (Lang.Variable ixIdent)
        stepExpr     = shallowApply step' (Lang.Variable stIdent)
        tempVarElem  = variable (tempIdent ++ "_1") elemTyp
        ixVar        = variable ixIdent indexType
        ixExpr       = Lang.Variable ixIdent
        loc'         = toSingle arrTyp loc
        locElem      = indexLoc loc' (varExpr ixVar) ixExpr
        locCont      = Shifted (lenExpr, undoEdge len) loc'
        tempDeclElem = genDeclarations tempIdent [([1],tElem)]
        tempDeclsSt  = genDeclarations (tempIdent ++ "_2") (multiVarIn initial)
        stDecls      = genDeclarations stIdent (multiVarIn initial)
        elemCopy     = copyProg (singleLV locElem) (varExpr tempVarElem)
        stCopy       = genMultiCopy (multiVarIn initial) (M stIdent) (M $ tempIdent ++ "_2")
        contExpr     = cont (Lang.Variable stIdent)
        -- @loc'@ is 'Shifted' when this loop is the continuation of an
        -- earlier producer. 'singleLV' would fail on it, so use the same
        -- length update as the Parallel case.
        setLen       = genSetLength loc' lenExpr
    stepProg <- genEdge (M tempIdent) stepExpr
    contProg <- genEdge (S locCont) contExpr
    let loopBody = block (tempDeclElem ++ tempDeclsSt) (stepProg ++ stCopy ++ [elemCopy])
        loopProg = ParLoop ixVar lenExpr 1 loopBody () ()
    return [BlockProgram (block stDecls (initProg ++ lenProg ++ [setLen, loopProg] ++ contProg)) ()]
  where
    [(_, tArr@(Lang.ArrayType _ tElem))] = resTypes a
    arrTyp  = compileTypeRep tArr
    elemTyp = compileTypeRep tElem
genNode loc a@(Inject (Node ForLoop) :$: len :$: initial :$: Lambda body) = do
    (lenProg,lenExpr) <- genExpressionIn len
    tempIdent <- newName "temp"
    ixIdent <- newName "i"
    initProg <- genEdge loc initial
    let body'     = body (Lang.Variable ixIdent)
        bodyExpr  = shallowApply body' (locToNode loc)
        ixVar     = variable ixIdent indexType
        tempDecls = genDeclarations tempIdent (resTypes a)
    bodyProg <- genEdge (M tempIdent) bodyExpr
    -- Copy the new loop state from the temporaries back into @loc@.
    -- It is named 'copyBack' so it does not shadow the imported 'copyProg'.
    let copyBack = genMultiCopy (resTypes a) loc (M tempIdent)
    return [BlockProgram (block tempDecls (lenProg ++ initProg ++ [ParLoop ixVar lenExpr 1 (block [] (bodyProg ++ copyBack)) () ()])) ()]
genNode (S (Shifted _ _)) a@(Inject (Node (Literal _))) | isArrayLit a && isEmpty a = return []
genNode loc a@(Inject (Node (Literal b))) | isArrayLit a = return [initialize (getSingleLV l) (sumShifts l) v]
  where
    l = toSingle typ loc
    v = compileDataRep $ Lang.dataRep b
    typ = compileTypeRep $ Lang.typeRep' b
genNode loc a@(Inject (Node (NoInline name)) :$: Lambda body :$: x) = do
    param <- newName "in"
    result <- newName "out"
    prolog <- sequence $ listEdge (const genExpressionIn) x
    let prologProgs = concatMap fst prolog
        prologArgs  = map (flip In () . snd) prolog
        argVars     = genVars param x
        resVars     = genVars result (body $ Lang.Variable param)
        outArgs     = map (flip Out ()) $ genLocExprs loc a
    prog <- genEdge (M result) (body $ Lang.Variable param)
    tell [procedure name argVars resVars (block [] prog)]
    return $ prologProgs ++ [procedureCall name prologArgs outArgs]
-- NOTE(review): both SetLength equations use 'singleLV', which fails if
-- @loc@ is 'Shifted'.
genNode (S loc) (Inject (Node SetLength) :$: len :$: (Inject (Edge edge) :$: a))
    | not (isComplex a), Just name <- traceVar a = do
        (lenProg,lenExpr) <- genExpressionIn len
        let typ = compileTypeRep $ edgeType edge
        let arrProg = [copyProgLen (singleLV loc) (varExpr $ createVariable name typ) lenExpr]
        return $ lenProg ++ arrProg
genNode (S loc) (Inject (Node SetLength) :$: len :$: arr) = do
    (lenProg,lenExpr) <- genExpressionIn len
    arrProg <- genEdge (S loc) arr
    return $ lenProg ++ arrProg ++ [setLength (singleLV loc) lenExpr]
genNode loc a@(Inject (Node SetIx) :$: ix :$: val :$: e) = do
    (ixProg,ixExpr) <- genExpressionIn ix
    copy <- genEdge loc e
    -- NOTE(review): as in the Parallel case, the element type is passed to
    -- 'toSingle' for what is the array location. Confirm this is intended.
    let updLoc = indexLoc (toSingle elemType loc) ixExpr $ undoEdge ix
    update <- genEdge (S updLoc) val
    return $ copy ++ ixProg ++ update
  where
    [(_, (Lang.ArrayType _ tElem))] = resTypes a
    elemType = compileTypeRep tElem
genNode loc a@(Inject (Node Pair) :$: x :$: y) = do
    prog1 <- genEdge (S $ selectFst $ toSingle (nodeType a) loc) x
    prog2 <- genEdge (S $ selectSnd $ toSingle (nodeType a) loc) y
    return $ prog1 ++ prog2
  where
    nodeType = compileTypeRep . snd . head . resTypes
genNode loc a@(Inject (Node (Literal _))) = genNodeSingle (toSingle (nodeType a) loc) (nodeType a) a
  where
    nodeType = compileTypeRep . snd . head . resTypes
genNode loc a@(Inject (Node (Function _ _)) :$: _) = genNodeSingle (toSingle (nodeType a) loc) (nodeType a) a
  where
    nodeType = compileTypeRep . snd . head . resTypes

-- | Length-setting statement for an array being produced into a location.
--
-- For a plain location this is 'setLength' of its lvalue. For a 'Shifted'
-- location (the continuation of an earlier producer) it is
-- 'increaseLength' on the underlying lvalue, because 'singleLV' is
-- undefined there.
genSetLength :: SingleLoc a -> Expression () -> Program ()
genSetLength (SingleLoc lv _) len = setLength lv len
genSetLength (Shifted _ inner) len = increaseLength (getSingleLV inner) len
-- | Lvalue expressions of a location.
--
-- A single location gives its 'singleLV', which requires a plain
-- 'SingleLoc'. A multi-variable location gives one variable per result
-- path of the network.
genLocExprs :: Location ra a -> FeldNetwork ra a -> [Expression ()]
genLocExprs (S single) _ = [singleLV single]
genLocExprs (M ident) net = map toExpr (resTypes net)
  where
    toExpr (path, t) = varExpr (variable (pathIdent ident path) (compileTypeRep t))
-- | Split a 'Group2' input network into its two halves. It is partial:
-- any other node fails to match. It is not referenced elsewhere in the
-- visible part of this module.
viewGroup2 :: FeldNetwork (In (ra,rb)) (a,b) -> (FeldNetwork (In ra) a, FeldNetwork (In rb) b)
viewGroup2 net = case net of
    Inject Group2 :$: fstPart :$: sndPart -> (fstPart, sndPart)
-- | Generate one input edge into the variable named after @ident@ and the
-- edge's @path@, typed by the edge's Feldspar type.
genEdgeSingle
    :: Ident
    -> Path
    -> FeldNetwork (In ()) a
    -> Transformer [Program ()]
genEdgeSingle ident path edge = genNodeSingle target edgeTyp (undoEdge edge)
  where
    edgeTyp = compileTypeRep (edgeType (edgeInfo edge))
    target  = toSingle edgeTyp (M (pathIdent ident path))
-- | Generate an input network into a location.
--
-- Let-bindings are generated first. A single edge goes through
-- 'genNodeSingle'. For a multi-variable location, every edge is generated
-- into its own path variable and the results are concatenated.
genEdge :: Location (Out ra) a -> FeldNetwork (In ra) a -> Transformer [Program ()]
genEdge loc net
    | isLet net = genLet withLetProg net
  where
    withLetProg letProg body = fmap (letProg ++) (genEdge loc body)
genEdge loc (Inject (Edge edge) :$: node) = genNodeSingle (toSingle edgeTyp loc) edgeTyp node
  where
    edgeTyp = compileTypeRep (edgeType edge)
genEdge (M ident) net = fmap concat (sequence (listEdge (genEdgeSingle ident) net))
-- | Variable for the input edge at @path@, named with 'pathIdent' and typed
-- by the edge's Feldspar type.
genVar :: Ident -> Path -> FeldNetwork (In ()) a -> Variable ()
genVar ident path edge = variable name (compileTypeRep (edgeType (edgeInfo edge)))
  where
    name = pathIdent ident path
-- | One variable per edge of a network, named from @ident@ and the edge
-- path.
--
-- A let-binding is looked through by instantiating its body with a
-- placeholder variable named "TODO". Presumably the name does not affect
-- the enumerated edges — TODO confirm.
genVars :: Ident -> FeldNetwork (In ra) a -> [Variable ()]
genVars ident net = case net of
    Let _ :$: _ :$: Lambda body -> genVars ident (body (Lang.Variable "TODO"))
    _                           -> listEdge (genVar ident) net
-- | Values that can be compiled to an imperative procedure. Instances below
-- cover 'Syntactic' programs and functions from 'Syntactic' arguments.
class Compilable t where
    -- | Emit the procedure @procName@ for @t@. The variable list holds the
    -- input-parameter variables collected so far, for the curried
    -- arguments already consumed.
    toImperativeM
        :: String
        -> [Variable ()]
        -> t
        -> Transformer ()
    -- | One entry per input argument, as computed by 'countEdges' on that
    -- argument. The list length is the number of arguments (see 'numArgs').
    buildInParamDescriptor :: t -> [Int]
-- | Base case: a complete program. After sharing is applied, its result is
-- generated into fresh output variables (prefix @out@), which become the
-- procedure's result parameters.
instance Syntactic a => Compilable a where
    toImperativeM procName freeVars prog = do
        let shared = feldSharing (toEdge prog)
        outIdent <- newName "out"
        body <- genEdge (M outIdent) shared
        let outVars = genVars outIdent shared
        tell [Procedure procName freeVars outVars (block [] body) () ()]
    buildInParamDescriptor = const []
-- | Function case: the argument becomes a fresh input variable (prefix
-- @in@). Its edge variables are appended to the parameter list before the
-- rest of the function is compiled.
instance (Syntactic a, Compilable t) => Compilable (a -> t) where
    toImperativeM procName freeVars prog = do
        inIdent <- newName "in"
        let param  = Lang.variable universal inIdent
            inVars = genVars inIdent (toEdge param)
        toImperativeM procName (freeVars ++ inVars) (prog param)
    buildInParamDescriptor prog =
        countEdges (toEdge dummy) : buildInParamDescriptor (prog dummy)
      where
        dummy = Lang.variable universal "argument"
-- | Number of input arguments of a compilable program.
numArgs :: Compilable a => a -> Int
numArgs prog = length (buildInParamDescriptor prog)
-- | Compile a program to an imperative module containing procedure
-- @procName@, plus any procedures emitted for non-inlined sub-programs.
-- The state starts at 0 and is discarded.
fromCore :: Compilable t => String -> t -> Module ()
fromCore procName prog = Module defs ()
  where
    defs = execWriter (evalStateT (toImperativeM procName [] prog) 0)
-- | Statements that store a constant into a location at offset @shift@.
--
-- An array constant first sets the length to @shift + n@ (where @n@ is the
-- number of elements). It then initialises element @shift + i@ from the
-- @i@-th value, recursively with shift 0. Any other constant is a single
-- copy.
initialize :: Expression () -> Expression () -> Constant () -> Program ()
initialize loc shift (ArrayConst elems _ _) =
    createProgramSequence (lenProg : elemProgs)
  where
    lenProg = setLength loc (shift `add` intConstExpr (toInteger (length elems)))
    elemProgs = zipWith initElem elems (map intConstExpr [0..])
    initElem v i = initialize (arrayElem loc (shift `add` i)) (intConstExpr 0) v
initialize loc _ v = copyProg loc (createConstantExpression v)
-- | Translate a Feldspar data representation into an imperative constant.
-- Struct constants are rejected with an internal error.
compileDataRep :: Lang.DataRep -> Constant ()
compileDataRep rep = case rep of
    Lang.BoolData b         -> BoolConst b () ()
    Lang.IntData n          -> IntConst n () ()
    Lang.FloatData f        -> FloatConst (fromRational (toRational f)) () ()
    Lang.ComplexData re im  -> ComplexConst (compileDataRep re) (compileDataRep im) () ()
    Lang.ArrayData elems    -> ArrayConst (map compileDataRep elems) () ()
    Lang.StructData _       -> localError InternalError "Struct constants not supported yet."
-- | Translate a Feldspar type into an imperative type.
-- Struct members are named 'defaultMemberName' followed by 1, 2, ...
compileTypeRep :: Lang.TypeRep -> Type
compileTypeRep Lang.BoolType = BoolType
compileTypeRep (Lang.IntType r) = compileNumericType r
compileTypeRep Lang.FloatType = FloatType
compileTypeRep (Lang.ComplexType inner) = ComplexType (compileTypeRep inner)
compileTypeRep (Lang.UserType userTypeName) = UserType userTypeName
compileTypeRep (Lang.ArrayType dim elemTyp) = ArrayType (getLength dim) (compileTypeRep elemTyp)
compileTypeRep (Lang.StructType memberTypes) =
    StructType (zipWith member [1..] memberTypes)
  where
    member i t = (defaultMemberName ++ show i, compileTypeRep t)
-- | Imperative integer type matching the signedness and bit width of a
-- range's element type.
compileNumericType :: (BoundedInt a, Typeable a) => Range a -> Type
compileNumericType range = NumType sign size
  where
    sign = intSign range
    size = intSize range
-- | Signedness of a range's element type, decided by 'isSigned' on its
-- upper bound.
intSign :: BoundedInt a => Range a -> Signedness
intSign r = if isSigned (upperBound r) then Signed else Unsigned
-- | Bit width of a range's element type (8, 16, 32 or 64). Any other width
-- is reported as an invariant violation naming the type.
intSize :: (BoundedInt a, Typeable a) => Range a -> Size
intSize r
    | width == 8  = S8
    | width == 16 = S16
    | width == 32 = S32
    | width == 64 = S64
    | otherwise   = localError InvariantViolation $ "unknown integer type: " ++ show (typeOf top)
  where
    top   = upperBound r
    width = bitSize top
-- | Static array length from a length range. An upper bound equal to
-- 'maxBound' means the length is unknown.
getLength :: Range Lang.Length -> Length
getLength l =
    if u == maxBound then UndefinedLen else LiteralLen (fromIntegral u)
  where
    u = upperBound l
-- | Module-local error helper: 'handleError' partially applied to a fixed
-- component label.
--
-- NOTE(review): the label names "ConstTransformation", which does not
-- match this module (Feldspar.Compiler.Imperative.FromCore). It looks like
-- a copy-paste leftover — confirm before relying on it to locate errors.
localError = handleError "Backends :: C :: ConstTransformation"