{-# LANGUAGE OverlappingInstances, PatternGuards, UndecidableInstances #-}

module Feldspar.Compiler.Imperative.FromCore
    ( Compilable(..)
    , numArgs
    , fromCore
    ) where

import Control.Monad.State
import Control.Monad.Writer
import Data.List
import Data.String
import Data.Bits

import qualified Feldspar.DSL.Expression as Lang
import qualified Feldspar.DSL.Lambda as Lang
import Feldspar.DSL.Lambda hiding (Value, Variable)
import Feldspar.DSL.Network
import Feldspar.Set (universal)
import qualified Feldspar.Core.Types as Lang
import Feldspar.Core.Representation hiding (variable)
import qualified Feldspar.Core.Representation as Lang
import qualified Feldspar.Core.Functions.Array as Lang
import qualified Feldspar.Core.Functions.Tuple as Lang
import Feldspar.Core.Functions.Num ()
import Feldspar.Range (Range(..),BoundedInt)
import Data.Typeable (Typeable, typeOf)

import Feldspar.Compiler.Imperative.Representation hiding (blockProgram)
import Feldspar.Compiler.Imperative.Frontend
import Feldspar.Compiler.Backend.C.CodeGeneration
import Feldspar.Compiler.Error
import Feldspar.Compiler.Backend.C.Library

-- | Code-generation monad: an 'Integer' state layered over a writer that
-- collects generated @'Definition' ()@ values. 'fromCore' runs it with the
-- state initialised to 0 and keeps only the written definitions.
-- NOTE(review): the 'Integer' state is presumably the counter behind
-- 'newName' (defined elsewhere) -- confirm.
type Transformer a = StateT Integer (Writer [Definition ()]) a

-- | Path of a matching variable
type Path = [Int]

-- | The variables making up a multi-edge: one (path, type) pair per variable.
type MultiVar = [(Path, Lang.TypeRep)]

-- | Type given to loop index variables and to the result of the index
-- addition built by 'add': an unsigned 32-bit numeric type.
indexType :: Type
indexType = NumType Unsigned S32

-- | Where to place the result of a single-edge
data SingleLoc a
    = Shifted (Expression (), FeldNetwork (Out ()) Lang.Length) (SingleLoc a)
        -- ^ A shifted location
    | SingleLoc
        { singleLV :: Expression ()
        , feedback :: FeldNetwork (Out ()) a
        }
        -- ^ A location represented by a 'LeftValue ()' and a corresponding
        -- expression. The expression is used for feedback.
-- | Container of a SingleLoc: looks through every 'Shifted' wrapper and
-- returns the left-value expression stored in the innermost 'SingleLoc'.
getSingleLV :: SingleLoc a -> Expression ()
getSingleLV (Shifted _ inner) = getSingleLV inner
getSingleLV (SingleLoc lhs _) = lhs

-- | Where to place the result of a general program
data Location ra a where
    S :: { single :: SingleLoc a } -> Location (Out ()) a
    M :: { multi :: Ident } -> Location ra a

-- | Append a path to an identifier: the identifier and the shown path
-- components are joined with underscores.
pathIdent :: Ident -> Path -> Ident
pathIdent ident path = concat (intersperse "_" (ident : map show path))

-- | Turn a single-edge 'Location' into a 'SingleLoc'. An 'S' location is
-- returned unchanged (the type argument is not used); an 'M' location becomes
-- a variable of the given type together with a feedback variable of the same
-- name.
toSingle :: Type -> Location (Out ()) a -> SingleLoc a
toSingle _ (S loc) = loc
toSingle typ (M ident) =
    SingleLoc (varExpr (variable ident typ)) (Lang.Variable ident)

-- | Network-level array indexing, built on 'Lang.getIx'.
getIx :: Lang.Type a => FeldNetwork (Out ()) [a] -> FeldNetwork (Out ()) Lang.Index -> FeldNetwork (Out ()) a
getIx arr ix =
    undoEdge (unData (Lang.getIx (fromOutEdge universal arr) (fromOutEdge universal ix)))
  -- TODO Think about size

-- | Network-level addition of two word-sized values.
addExpr :: FeldNetwork (Out ()) Lang.DefaultWord -> FeldNetwork (Out ()) Lang.DefaultWord -> FeldNetwork (Out ()) Lang.DefaultWord
addExpr x y =
    undoEdge (unData (fromOutEdge universal x + fromOutEdge universal y))
  -- TODO Think about size

-- | Imperative-level addition: an infix @(+)@ call of type 'indexType'.
add :: Expression () -> Expression () -> Expression ()
add lhs rhs = FunctionCall "(+)" indexType InfixOp [lhs, rhs] () ()

-- | Index into a location while collecting the shifts of any 'Shifted'
-- wrappers; the collected shifts are added to both the index variable and the
-- index expression once the underlying 'SingleLoc' is reached.
indexLocSingle :: Lang.Type a
    => [(Expression (), FeldNetwork (Out ()) Lang.Length)] -- ^ Accumulated shifts
    -> SingleLoc [a] -- ^ Original location
    -> Expression () -- ^ Index variable
    -> FeldNetwork (Out ()) Lang.Index -- ^ Index expression
    -> SingleLoc a
indexLocSingle shifts loc ixVar ixExpr = case loc of
    Shifted shift inner ->
        indexLocSingle (shift : shifts) inner ixVar ixExpr
    SingleLoc lv fb ->
        let shiftedIxVar  = foldl add ixVar (map fst shifts)
            shiftedIxExpr = foldl addExpr ixExpr (map snd shifts)
        in  SingleLoc (arrayElem lv shiftedIxVar) (getIx fb shiftedIxExpr)

-- | Sum the shifts of a shifted location
sumShifts :: SingleLoc a -> Expression ()
sumShifts (Shifted (shift, _) rest) = add shift (sumShifts rest)
sumShifts (SingleLoc _ _) = createConstantExpression (intConst 0)

-- |
-- Indexing into a location
indexLoc :: Lang.Type a
    => SingleLoc [a] -- ^ Original location
    -> Expression () -- ^ Index variable
    -> FeldNetwork (Out ()) Lang.Index -- ^ Index expression
    -> SingleLoc a
indexLoc loc ixVar ixExpr = indexLocSingle [] loc ixVar ixExpr

-- | Location of the first component of a pair location. The member name is
-- taken from the first field of the struct type of the left value.
-- Only defined for an unshifted 'SingleLoc'.
selectFst :: (Lang.Type a, Lang.Type b) => SingleLoc (a,b) -> SingleLoc a
selectFst (SingleLoc e fb) = SingleLoc (StructField e member () ()) fb'
  where
    member = fst $ head $ s
    (StructType s) = typeof e
    fb' = undoEdge $ unData $ Lang.getFst $ fromOutEdge universal fb

-- | Location of the second component of a pair location. The member name is
-- taken from the second field of the struct type of the left value.
-- Only defined for an unshifted 'SingleLoc'.
selectSnd :: (Lang.Type a, Lang.Type b) => SingleLoc (a,b) -> SingleLoc b
selectSnd (SingleLoc e fb) = SingleLoc (StructField e member () ()) fb'
  where
    member = fst $ head $ tail $ s
    (StructType s) = typeof e
    fb' = undoEdge $ unData $ Lang.getSnd $ fromOutEdge universal fb

-- | Network node that reads back the contents of a location: the feedback
-- expression of an unshifted single location, or a variable named after the
-- identifier of a multi location.
locToNode :: Location (Out ra) a -> FeldNetwork (Out ra) a
locToNode (S (SingleLoc _ fb)) = fb
locToNode (M ident) = Lang.Variable ident
  -- TODO Ignoring shift

-- | Paths and types of all edges of a multi-edge.
multiVarIn :: FeldNetwork (In ra) a -> MultiVar
multiVarIn = listEdge $ \path a -> (path, edgeType (edgeInfo a))

-- | An expression that needs no supporting statements.
simpleExpr :: Expression () -> Transformer ([Program ()], Expression ())
simpleExpr expr = return ([],expr)

-- | Whether a node has to be compiled by 'genNode' (statements writing to a
-- location) rather than as a plain expression: the node kinds listed below
-- and array literals.
isComplex :: FeldNetwork (Out ra) a -> Bool
isComplex (Inject (Node Condition) :$: _ :$: _ :$: _) = True
isComplex (Inject (Node Parallel) :$: _ :$: _ :$: _) = True
isComplex (Inject (Node Sequential) :$: _ :$: _ :$: _ :$: _) = True
isComplex (Inject (Node ForLoop) :$: _ :$: _ :$: _) = True
isComplex (Inject (Node (NoInline _)) :$: _ :$: _) = True
isComplex (Inject (Node SetLength) :$: _ :$: _) = True
isComplex (Inject (Node Pair) :$: _ :$: _) = True
isComplex (Inject (Node SetIx) :$: _ :$: _ :$: _) = True
isComplex a = isArrayLit a

-- | Compile a literal value to a constant expression.
genLiteralExpression :: Lang.Type a => a -> Transformer ([Program ()], Expression ())
genLiteralExpression = simpleExpr . flip ConstExpr () . compileDataRep . Lang.dataRep

-- | Compile a node that traces back to a variable into a variable expression
-- of the given type. The caller must have checked that 'traceVar' succeeds;
-- the 'Just' pattern below fails otherwise.
genVarExpression :: Type -> FeldNetwork (Out ()) a -> Transformer ([Program ()], Expression ())
genVarExpression typ a = simpleExpr (varExpr var)
  where
    Just ident = traceVar a
    var = variable (pathIdent ident $ matchPath a) typ

-- | Compile a function application: every argument edge is compiled with
-- 'genExpressionIn', the supporting statements are concatenated and the
-- argument expressions are passed to a 'SimpleFun' call.
genApplyExpression :: Type -> String -> FeldNetwork (In ra) a -> Transformer ([Program ()], Expression ())
genApplyExpression typ fun a = do
    (progs,exprs) <- liftM unzip $ sequence $ listEdge (const genExpressionIn) a
    return (concat progs, FunctionCall fun typ SimpleFun exprs () ())

-- | Generate an expression that is not a literal, function call or a variable.
-- The node is compiled by 'genNode' into a freshly named variable (base name
-- @"w"@) declared in its own block; the resulting expression is that variable.
genComplexExpression :: Type -> FeldNetwork (Out ()) a -> Transformer ([Program ()], Expression ())
genComplexExpression typ a = do
    ident <- newName "w"
    let var = variable ident typ
        decl = Declaration var Nothing ()
    prog <- genNode (M ident) a
    return ([BlockProgram (block [decl] prog) ()], VarExpr var ())

-- | Like 'genExpression', for a single in-edge; the type is read from the edge.
genExpressionIn :: FeldNetwork (In ()) a -> Transformer ([Program ()], Expression ())
genExpressionIn a = genExpression typ (undoEdge a)
  where
    typ = compileTypeRep $ edgeType $ edgeInfo a

-- | Generate an expression plus support code
genExpression :: Type -> FeldNetwork (Out ()) a -> Transformer ([Program ()], Expression ())
-- Non-array literals become constants; array literals fall through to the
-- last clause.
genExpression _ a@(Inject (Node (Literal lit)))
    | not (isArrayLit a) = genLiteralExpression lit
genExpression typ (Inject (Node (Function fun _)) :$: a) = genApplyExpression typ fun a
genExpression typ a
    | Just _ <- traceVar a = genVarExpression typ a
genExpression typ (Inject m :$: _)
    | isMatch m = localError InvariantViolation "matching on non-variable"
genExpression typ a = genComplexExpression typ a

-- | One uninitialised declaration per variable of a multi-variable; names are
-- built with 'pathIdent'.
genDeclarations :: Ident -> MultiVar -> [Declaration ()]
  -- TODO Would be nice to have (Out ra)
genDeclarations ident vars =
    [ Declaration var Nothing ()
    | (path,typ) <- vars
    , let ident' = pathIdent ident path
    , let var = variable ident' $ compileTypeRep typ
    ]

-- | Copy statements from the second location (source) to the first
-- (destination). The 'MultiVar' supplies the variable paths and types for
-- multi locations. Shifted single locations and a single location paired with
-- a multi location of more than one variable are not supported and are
-- reported through 'localError'.
genMultiCopy :: MultiVar -> Location ra a -> Location rb a -> [Program ()]
genMultiCopy [(_,typ)] (S (SingleLoc lv _)) (M rIdent) = [copyProg lv rhs]
  where
    rhs = varExpr $ variable rIdent $ compileTypeRep typ
genMultiCopy _ (S (SingleLoc lhs _)) (S (SingleLoc rhs _)) = [copyProg lhs rhs]
genMultiCopy vars (M ident) (M rIdent) =
    [ copyProg (varExpr lVar) (varExpr rVar)
    | (path,typ) <- vars
    , let ident' = pathIdent ident path
    , let rIdent' = pathIdent rIdent path
    , let typ' = compileTypeRep typ
    , let lVar = variable ident' typ'
    , let rVar = variable rIdent' typ'
    ]
  -- TODO Ignoring shift
genMultiCopy [(_,typ)] (M ident) (S (SingleLoc rhs _)) = [copyProg lhs rhs]
  where
    lhs = varExpr $ variable ident $ compileTypeRep typ
-- Explicit failure instead of an anonymous pattern-match error.
genMultiCopy _ _ _ =
    localError InternalError "genMultiCopy: unsupported combination of locations"

-- | Generate code for a 'Let' expression. The bound value is compiled into
-- freshly named variables declared in their own block; that block is handed
-- to the body generator together with the body instantiated with the new
-- variable.
genLet
    :: ([Program ()] -> FeldNetwork ra a -> Transformer b) -- ^ Generator for the body
    -> FeldNetwork ra a
    -> Transformer b
genLet gen (Let base :$: a :$: Lambda f) = do
    aIdent <- newName base
    aProg <- genNode (M aIdent) a
    let aBlock = block (genDeclarations aIdent (resTypes a)) aProg
        letProg = f (Lang.Variable aIdent)
    gen [BlockProgram aBlock ()] letProg

-- | Compile a node as an expression and copy the result into the location.
-- For a shifted location the copy is positioned at the sum of the shifts.
genNodeExpression :: SingleLoc a -> Type -> FeldNetwork (Out ()) a -> Transformer [Program ()]
genNodeExpression loc typ a = do
    (prog,rhs) <- genExpression typ a
    return $ prog ++ case loc of
        SingleLoc l _ -> [copyProg l rhs]
        l@(Shifted _ _) -> [copyProgPos (getSingleLV l) (sumShifts l) rhs]

-- | Compile a node into a single location: complex nodes (see 'isComplex')
-- go through 'genNode', everything else through 'genNodeExpression'.
genNodeSingle :: SingleLoc a -> Type -> FeldNetwork (Out ()) a -> Transformer [Program ()]
genNodeSingle loc _ a
    | isComplex a = genNode (S loc) a
genNodeSingle loc typ a = genNodeExpression loc typ a

-- | Generate the statements that compute a node and store its result in the
-- given location.
--
-- TODO network must be a 'Node' or a (nested) 'Let' resulting in a 'Node'
genNode :: forall ra a .
       Location ra a
    -> FeldNetwork ra a -- TODO Would be nice to have (Out ra)
    -> Transformer [Program ()]
-- Let: emit the binding block, then continue with the body.
genNode loc a
    | isLet a = genLet genBody a
  where
    genBody letProg body = liftM (letProg++) $ genNode loc body
-- Condition: both branches write to the same location.
genNode loc (Inject (Node Condition) :$: cond :$: t :$: e) = do
    (condProg,condExpr) <- genExpressionIn cond
    thenProg <- genEdge loc t
    elseProg <- genEdge loc e
    let branchProg = Branch condExpr (block [] thenProg) (block [] elseProg) () ()
    return (condProg ++ [branchProg])
-- Parallel: set (or, for a shifted location, increase) the length, fill the
-- elements in a loop, then compile the continuation shifted by the length.
-- NOTE(review): 'toSingle' receives the element type here (and in the SetIx
-- clause) but the array type in the Sequential clause -- confirm which one is
-- intended.
genNode loc a@(Inject (Node Parallel) :$: len :$: Lambda ixf :$: cont) = do
    (lenProg,lenExpr) <- genExpressionIn len
    ixIdent <- newName "i"
    let bodyExpr = ixf (Lang.Variable ixIdent)
        ixVar = variable ixIdent indexType
        ixExpr = Lang.Variable ixIdent
        loc' = toSingle typ loc
        locBody = indexLoc loc' (varExpr ixVar) ixExpr
        locCont = Shifted (lenExpr, undoEdge len) loc'
        setLen = case loc' of
            SingleLoc _ _ -> setLength (singleLV loc') lenExpr
            Shifted _ l -> increaseLength (getSingleLV l) lenExpr
    bodyProg <- genEdge (S locBody) bodyExpr
    contProg <- genEdge (S locCont) cont
    return (lenProg ++ [setLen,ParLoop ixVar lenExpr 1 (block [] bodyProg) () ()] ++ contProg)
    -- TODO Should use \"default size\" for index type
  where
    [(_, Lang.ArrayType _ t)] = resTypes a
    typ = compileTypeRep t
-- Sequential: the state lives in <x>_2. Each iteration computes the element
-- into <temp>_1 and the next state into <temp>_2, copies the state back and
-- stores the element; the continuation is compiled shifted by the length.
genNode loc a@(Inject (Node Sequential) :$: len :$: init :$: Lambda step :$: Lambda cont) = do
    (lenProg,lenExpr) <- genExpressionIn len
    stepIdent <- newName "x"
    let stIdent = stepIdent ++ "_2"
    tempIdent <- newName "temp"
    ixIdent <- newName "i"
    initProg <- genEdge (M stIdent) init
    let step' = step (Lang.Variable ixIdent)
        stepExpr = shallowApply step' (Lang.Variable stIdent)
        tempVarElem = variable (tempIdent ++ "_1") elemTyp
        ixVar = variable ixIdent indexType
        ixExpr = Lang.Variable ixIdent
        loc' = toSingle arrTyp loc
        locElem = indexLoc loc' (varExpr ixVar) ixExpr
        locCont = Shifted (lenExpr, undoEdge len) loc'
        tempDeclElem = genDeclarations tempIdent [([1],tElem)]
        tempDeclsSt = genDeclarations (tempIdent ++ "_2") (multiVarIn init)
        stDecls = genDeclarations stIdent (multiVarIn init)
        -- 'indexLoc' always yields an unshifted 'SingleLoc', so 'singleLV'
        -- is safe here.
        elemCopy = copyProg (singleLV locElem) (varExpr tempVarElem)
        stCopy = genMultiCopy (multiVarIn init) (M stIdent) (M $ tempIdent ++ "_2")
        apa = cont (Lang.Variable stIdent)
        -- Same treatment as in the Parallel clause: a shifted location (this
        -- node is the continuation of another array) has its length
        -- increased. Applying 'singleLV' to a 'Shifted' value would fail at
        -- run time.
        setLen = case loc' of
            SingleLoc lv _ -> setLength lv lenExpr
            Shifted _ _ -> increaseLength (getSingleLV loc') lenExpr
    stepProg <- genEdge (M tempIdent) stepExpr
    contProg <- genEdge (S locCont) apa
    return $
        [ BlockProgram
            (block stDecls
                (  initProg
                ++ lenProg
                ++ [ setLen
                   , ParLoop ixVar lenExpr 1
                        (block (tempDeclElem ++ tempDeclsSt)
                               (stepProg ++ stCopy ++ [elemCopy]))
                        () ()
                   ]
                ++ contProg))
            ()
        ]
  where
    [(_, tArr@(Lang.ArrayType _ tElem))] = resTypes a
    arrTyp = compileTypeRep tArr
    elemTyp = compileTypeRep tElem
-- ForLoop: the location itself holds the state. Each iteration computes the
-- new state into temporaries and copies it back into the location.
genNode loc a@(Inject (Node ForLoop) :$: len :$: init :$: Lambda body) = do
    (lenProg,lenExpr) <- genExpressionIn len
    tempIdent <- newName "temp"
    ixIdent <- newName "i"
    initProg <- genEdge loc init
    let body' = body (Lang.Variable ixIdent)
        bodyExpr = shallowApply body' (locToNode loc)
        ixVar = variable ixIdent indexType
        tempDecls = genDeclarations tempIdent (resTypes a)
    bodyProg <- genEdge (M tempIdent) bodyExpr
    -- Named 'copyBack' so that the imported 'copyProg' is not shadowed.
    let copyBack = genMultiCopy (resTypes a) loc (M tempIdent)
    return
        [ BlockProgram
            (block tempDecls
                (lenProg ++ initProg ++
                    [ParLoop ixVar lenExpr 1 (block [] (bodyProg ++ copyBack)) () ()]))
            ()
        ]
-- An empty array literal appended to a shifted location needs no code.
genNode (S (Shifted _ _)) a@(Inject (Node (Literal _)))
    | isArrayLit a && isEmpty a = return []
-- Array literal: element-wise initialisation starting at the summed shift.
genNode loc a@(Inject (Node (Literal b)))
    | isArrayLit a = return [initialize (getSingleLV l) (sumShifts l) v]
  where
    l = toSingle typ loc
    v = compileDataRep $ Lang.dataRep b
    typ = compileTypeRep $ Lang.typeRep' b
-- NoInline: the body becomes a separate procedure (emitted through the
-- writer) and a call to it is generated in place.
genNode loc a@(Inject (Node (NoInline name)) :$: Lambda body :$: x) = do
    param <- newName "in"
    result <- newName "out"
    prolog <- sequence $ listEdge (const genExpressionIn) x
    let prologProgs = concatMap fst prolog
        prologArgs = map (flip In () . snd) prolog
        argVars = genVars param x
        resVars = genVars result (body $ Lang.Variable param)
        outArgs = map (flip Out ()) $ genLocExprs loc a
    prog <- genEdge (M result) (body $ Lang.Variable param)
    tell [procedure name argVars resVars (block [] prog)]
    return $ prologProgs ++ [procedureCall name prologArgs outArgs]
-- SetLength applied directly to a variable: a single length-limited copy.
genNode (S loc) (Inject (Node SetLength) :$: len :$: (Inject (Edge edge) :$: a))
    | not (isComplex a), Just name <- traceVar a = do
        (lenProg,lenExpr) <- genExpressionIn len
        let typ = compileTypeRep $ edgeType edge
        let arrProg = [copyProgLen (singleLV loc) (varExpr $ createVariable name typ) lenExpr]
        return $ lenProg ++ arrProg
-- SetLength in general: compute the array, then set the length.
genNode (S loc) a@(Inject (Node SetLength) :$: len :$: arr) = do
    (lenProg,lenExpr) <- genExpressionIn len
    arrProg <- genEdge (S loc) arr
    return $ lenProg ++ arrProg ++ [setLength (singleLV loc) lenExpr]
-- SetIx: compute the array into the location, then overwrite one element.
genNode loc a@(Inject (Node SetIx) :$: ix :$: val :$: e) = do
    (ixProg,ixExpr) <- genExpressionIn ix
    copy <- genEdge loc e
    let updLoc = indexLoc (toSingle elemType loc) ixExpr $ undoEdge ix
    update <- genEdge (S updLoc) val
    return $ copy ++ ixProg ++ update
  where
    [(_, (Lang.ArrayType _ tElem))] = resTypes a
    elemType = compileTypeRep tElem
-- Pair: each component is compiled into its own struct field.
genNode loc a@(Inject (Node Pair) :$: x :$: y) = do
    prog1 <- genEdge (S $ selectFst $ toSingle (nodeType a) loc) x
    prog2 <- genEdge (S $ selectSnd $ toSingle (nodeType a) loc) y
    return $ prog1 ++ prog2
  where
    nodeType = compileTypeRep . snd . head . resTypes
-- Remaining literals and function applications are single-location nodes.
genNode loc a@(Inject (Node (Literal _))) = genNodeSingle (toSingle (nodeType a) loc) (nodeType a) a
  where
    nodeType = compileTypeRep . snd . head . resTypes
genNode loc a@(Inject (Node (Function _ _)) :$: _) = genNodeSingle (toSingle (nodeType a) loc) (nodeType a) a
  where
    nodeType = compileTypeRep . snd . head . resTypes

-- | Expressions denoting a location: the left value of a single location
-- (via the partial selector 'singleLV'), or one variable per result of the
-- node for a multi location.
genLocExprs :: Location ra a -> FeldNetwork ra a -> [Expression ()]
genLocExprs (S loc) _ = [singleLV loc]
genLocExprs (M ident) a =
    [ varExpr $ variable (pathIdent ident path) (compileTypeRep typ)
    | (path,typ) <- resTypes a
    ]

-- | Split a two-component group edge into its components.
viewGroup2 :: FeldNetwork (In (ra,rb)) (a,b) -> (FeldNetwork (In ra) a, FeldNetwork (In rb) b)
viewGroup2 (Inject Group2 :$: a :$: b) = (a,b)
  -- TODO: Move somewhere else

-- | Compile one single-edge of a multi-edge into the variable obtained by
-- appending the edge's path to the identifier.
genEdgeSingle :: Ident -> Path -> FeldNetwork (In ()) a -> Transformer [Program ()]
genEdgeSingle ident path a = genNodeSingle (toSingle typ $ M ident') typ (undoEdge a)
  where
    ident' = pathIdent ident path
    typ = compileTypeRep $ edgeType $ edgeInfo a

-- | Generate code for a multi-edge
genEdge :: Location (Out ra) a -> FeldNetwork (In ra) a -> Transformer [Program ()]
genEdge loc a
    | isLet a = genLet genBody a
  where
    genBody letProg body = liftM (letProg++) $ genEdge loc body
genEdge loc (Inject (Edge edge) :$: a) = genNodeSingle (toSingle typ loc) typ a
  where
    typ = compileTypeRep $ edgeType edge
genEdge (M ident) a = liftM concat $ sequence $ listEdge (genEdgeSingle ident) a

-- | Generate a variable of the same type as the given single-edge
genVar :: Ident -> Path -> FeldNetwork (In ()) a -> Variable ()
genVar ident path a = variable (pathIdent ident path) typ
  where
    typ = compileTypeRep $ edgeType $ edgeInfo a

-- | Variables for all edges of a multi-edge. A top-level 'Let' is looked
-- through by applying its body to a placeholder variable named @"TODO"@.
-- NOTE(review): this assumes the edge types do not depend on the bound
-- variable -- confirm.
genVars :: Ident -> FeldNetwork (In ra) a -> [Variable ()]
genVars ident (Let _ :$: a :$: Lambda f) = genVars ident (f (Lang.Variable "TODO"))
genVars ident a = listEdge (genVar ident) a

-- | Programs that can be compiled to an imperative procedure.
class Compilable t where
    toImperativeM
        :: String -- ^ Name of procedure
        -> [Variable ()] -- ^ Free variables
        -> t -- ^ Program to compile
        -> Transformer ()

    -- | Returns a list containing the number of edges in each curried argument
    buildInParamDescriptor :: t -> [Int]

-- | Base case: a program without further arguments. The body is compiled
-- into result variables derived from a name obtained with @newName "out"@ and
-- the finished procedure is emitted through the writer.
instance Syntactic a => Compilable a where
    toImperativeM procName freeVars prog = do
        ident <- newName "out"
        body <- genEdge (M ident) prog'
        let resVars = genVars ident prog'
        tell [Procedure procName freeVars resVars (block [] body) () ()]
      where
        prog' = feldSharing (toEdge prog)

    buildInParamDescriptor _ = []

-- | Function case: the argument is replaced by a variable derived from a name
-- obtained with @newName "in"@, whose variables are appended to the free
-- variables before the result is compiled.
instance (Syntactic a, Compilable t) => Compilable (a -> t) where
    toImperativeM procName freeVars prog = do
        ident <- newName "in"
        let arg = Lang.variable universal ident
        let vars = genVars ident (toEdge arg)
        toImperativeM procName (freeVars ++ vars) (prog arg)

    buildInParamDescriptor prog = countEdges (toEdge arg) : buildInParamDescriptor (prog arg)
      where
        arg = Lang.variable universal "argument"

-- | Number of curried arguments of a compilable program.
numArgs :: Compilable a => a -> Int
numArgs = length . buildInParamDescriptor

-- | Compile a program to a module containing every definition emitted while
-- compiling it under the given procedure name. Compilation starts with no
-- free variables and with the 'Transformer' state set to 0.
fromCore :: Compilable t => String -> t -> Module ()
fromCore procName prog = Module (execWriter $ evalStateT (toImperativeM procName [] prog) 0) ()

-- | Statements that store a constant in a location. For an array constant the
-- length is set to the shift plus the number of elements and every element is
-- initialised recursively at index @shift + i@ (with shift 0 for the element
-- itself); any other constant is copied directly and the shift is not used.
initialize :: Expression () -> Expression () -> Constant () -> Program ()
initialize loc shift (ArrayConst vs _ _) = createProgramSequence $
    (setLength loc $ shift `add` intConstExpr (toInteger $ length vs))
    : (map (\(v,i) -> initialize (arrayElem loc $ shift `add` i) (intConstExpr 0) v) $ zip vs $ map intConstExpr [0..])
initialize loc _ v = copyProg loc $ createConstantExpression v

-- | Compilation of a data representation to an imperative constant
compileDataRep :: Lang.DataRep -> Constant ()
compileDataRep (Lang.BoolData x) = BoolConst x () ()
compileDataRep (Lang.IntData x) = IntConst x () ()
compileDataRep (Lang.FloatData x) = FloatConst (fromRational $ toRational x) () ()
compileDataRep (Lang.ComplexData r i) = ComplexConst (compileDataRep r) (compileDataRep i) () ()
compileDataRep (Lang.ArrayData xs) = ArrayConst (map compileDataRep xs) () ()
compileDataRep (Lang.StructData _) = localError InternalError "Struct constants not supported yet."
-- | Compilation of a type representation to an imperative type.
-- Struct members are named by appending their 1-based position to
-- 'defaultMemberName'.
compileTypeRep :: Lang.TypeRep -> Type
compileTypeRep Lang.BoolType = BoolType
compileTypeRep (Lang.IntType range) = compileNumericType range
compileTypeRep Lang.FloatType = FloatType
compileTypeRep (Lang.ComplexType inner) = ComplexType (compileTypeRep inner)
compileTypeRep (Lang.UserType userTypeName) = UserType userTypeName
compileTypeRep (Lang.ArrayType dim elemTyp) =
    ArrayType (getLength dim) (compileTypeRep elemTyp)
compileTypeRep (Lang.StructType memberTypes) =
    StructType (zip memberNames (map compileTypeRep memberTypes))
  where
    memberNames = [ defaultMemberName ++ show n | n <- [1..] ]

-- | Numeric type based on a range: signedness and size are both derived from
-- the range's upper bound.
compileNumericType :: (BoundedInt a, Typeable a) => Range a -> Type
compileNumericType range = NumType sign size
  where
    sign = intSign range
    size = intSize range

-- | Sign based on a range
intSign :: BoundedInt a => Range a -> Signedness
intSign range =
    if isSigned (upperBound range)
        then Signed
        else Unsigned

-- | Size based on a range: the bit size of the upper bound, which must be
-- 8, 16, 32 or 64.
intSize :: (BoundedInt a, Typeable a) => Range a -> Size
intSize range
    | bits == 8 = S8
    | bits == 16 = S16
    | bits == 32 = S32
    | bits == 64 = S64
    | otherwise =
        localError InvariantViolation $ "unknown integer type: " ++ show (typeOf bound)
  where
    bound = upperBound range
    bits = bitSize bound

-- | Compilation of a length: an upper bound equal to 'maxBound' means the
-- length is undefined, any other upper bound becomes a literal length.
getLength :: Range Lang.Length -> Length
getLength range =
    let upper = upperBound range
    in  if upper == maxBound
            then UndefinedLen
            else LiteralLen (fromIntegral upper)

-- | Customized error function: 'handleError' with this module's location
-- string fixed.
localError = handleError "Backends :: C :: ConstTransformation"