{-# LANGUAGE CPP,
BangPatterns,
DataKinds,
FlexibleContexts,
FlexibleInstances,
GADTs,
KindSignatures,
PolyKinds,
StandaloneDeriving,
TypeOperators,
RankNTypes #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Language.Hakaru.CodeGen.CodeGenMonad
( CodeGen
, CG(..)
, runCodeGen
, runCodeGenBlock
, runCodeGenWith
, emptyCG
, declare
, declare'
, assign
, putStat
, putExprStat
, extDeclare
, extDeclareTypes
, funCG
, whenPar
, parDo
, seqDo
, reserveIdent
, genIdent
, genIdent'
, createIdent
, createIdent'
, lookupIdent
, ifCG
, whileCG
, doWhileCG
, forCG
, reductionCG
, codeBlockCG
, putMallocStat
) where
import Control.Monad.State.Strict
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid (Monoid(..))
import Control.Applicative ((<$>))
#endif
import Language.Hakaru.Syntax.ABT hiding (var)
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import Language.Hakaru.CodeGen.Types
import Language.Hakaru.CodeGen.AST
import Language.Hakaru.CodeGen.Libs
import Data.Number.Nat (fromNat)
import qualified Data.IntMap.Strict as IM
import qualified Data.Text as T
import qualified Data.Set as S
-- | State threaded through C code generation.
--
-- The three output lists ('extDecls', 'declarations', 'statements') are
-- accumulated newest-first; 'runCodeGen' reverses them before returning.
data CG = CG
  { freshNames    :: [String]       -- ^ supply of unused name suffixes, consumed by 'genIdent'' and 'createIdent''
  , reservedNames :: S.Set String   -- ^ identifiers 'genIdent'' must not produce (filled by 'reserveIdent')
  , extDecls      :: [CExtDecl]     -- ^ top-level declarations, newest first; 'extDeclare' skips duplicates
  , declarations  :: [CDecl]        -- ^ declarations of the current scope, newest first
  , statements    :: [CStat]        -- ^ statements of the current scope, newest first
  , varEnv        :: Env            -- ^ Hakaru variable -> C identifier bindings
  , managedMem    :: Bool           -- ^ when set, 'putMallocStat' allocates with 'gcMalloc' instead of 'mallocE'
  , sharedMem     :: Bool           -- ^ when set, 'whenPar' actions run (e.g. OpenMP pragmas before loops)
  , simd          :: Bool           -- ^ NOTE(review): not read in this module; presumably consumed elsewhere
  , distributed   :: Bool           -- ^ NOTE(review): not read in this module; presumably consumed elsewhere
  , logProbs      :: Bool           -- ^ NOTE(review): not read in this module; 'emptyCG' sets it to True
  }
-- | Initial generator state: the full fresh-name supply ('cNameStream'), no
-- reserved names, no output, an empty variable environment, and every flag
-- off except 'logProbs'.
emptyCG :: CG
emptyCG =
  CG { freshNames    = cNameStream
     , reservedNames = mempty
     , extDecls      = mempty
     , declarations  = []
     , statements    = []
     , varEnv        = emptyEnv
     , managedMem    = False
     , sharedMem     = False
     , simd          = False
     , distributed   = False
     , logProbs      = True
     }

-- | Code generation is a (strict) state computation over 'CG'.
type CodeGen = State CG
-- | Run a generator starting from 'emptyCG'. Returns, in emission order, the
-- external declarations, the declarations and the statements it produced.
-- The computation's own result is discarded.
runCodeGen :: CodeGen a -> ([CExtDecl],[CDecl], [CStat])
runCodeGen m = (reverse extds, reverse decls, reverse stmts)
  where
    final = execState m emptyCG
    extds = extDecls final
    decls = declarations final
    stmts = statements final
-- | Run a generator with an empty statement list and return what it emitted
-- as a compound statement. Nothing is added to the enclosing statements.
--
-- Declarations made by the generator are not placed inside the returned
-- block. They are hoisted into the enclosing scope's declarations. All other
-- state changes (fresh names, environment, external declarations, flags)
-- persist.
runCodeGenBlock :: CodeGen a -> CodeGen CStat
runCodeGenBlock m =
  do outer <- get
     let inner   = execState m (outer { statements = [], declarations = [] })
         hoisted = declarations inner ++ declarations outer
     put (inner { statements = statements outer, declarations = hoisted })
     return (CCompound [ CBlockStat s | s <- reverse (statements inner) ])
-- | Run a generator from the given starting state. Returns only the external
-- declarations it accumulated, in emission order.
runCodeGenWith :: CodeGen a -> CG -> [CExtDecl]
runCodeGenWith cg start = reverse . extDecls $ execState cg start
-- | Run the action only when 'sharedMem' is set.
whenPar :: CodeGen () -> CodeGen ()
whenPar m = gets sharedMem >>= \enabled -> when enabled m

-- | Run a generator with 'sharedMem' switched on, then restore the previous
-- setting.
parDo :: CodeGen a -> CodeGen a
parDo = withSharedMem True

-- | Run a generator with 'sharedMem' switched off, then restore the previous
-- setting.
seqDo :: CodeGen a -> CodeGen a
seqDo = withSharedMem False

-- | Temporarily set 'sharedMem' while a generator runs. Only the flag is
-- restored afterwards; every other state change the generator made is kept.
withSharedMem :: Bool -> CodeGen a -> CodeGen a
withSharedMem flag m =
  do saved <- gets sharedMem
     modify (\cg -> cg { sharedMem = flag })
     result <- m
     modify (\cg -> cg { sharedMem = saved })
     return result
-- | Mark @s@ as reserved, so 'genIdent'' never produces it, and return it as
-- an identifier.
reserveIdent :: String -> CodeGen Ident
reserveIdent s =
  do modify (\cg -> cg { reservedNames = S.insert s (reservedNames cg) })
     return (Ident s)
-- | Generate a fresh identifier with an empty prefix (see 'genIdent'').
genIdent :: CodeGen Ident
genIdent = genIdent' ""

-- | Generate a fresh identifier @s ++ "_" ++ n@.
--
-- @n@ is the first entry of the fresh-name supply whose resulting identifier
-- is not reserved. Entries skipped because of a reservation stay in the
-- supply for later use; only the chosen entry is consumed.
genIdent' :: String -> CodeGen Ident
genIdent' s =
  do cg <- get
     let (remaining, chosen) = takeFree (freshNames cg) (reservedNames cg)
     put cg { freshNames = remaining }
     return (Ident chosen)
  where
    withPrefix n = s ++ "_" ++ n
    takeFree supply reserved =
      case break (\n -> not (S.member (withPrefix n) reserved)) supply of
        (skipped, n:rest) -> (skipped ++ rest, withPrefix n)
        (_, [])           -> error "should not happen, names is infinite"
-- | Create a fresh C identifier for a Hakaru variable and bind it in the
-- environment. No tag is inserted; see 'createIdent''.
createIdent :: Variable (a :: Hakaru) -> CodeGen Ident
createIdent = createIdent' ""

-- | Create a fresh C identifier @<hint>_<s>_<fresh>@ for a Hakaru variable
-- and bind it in the environment, so 'lookupIdent' can find it later. Exactly
-- one entry of the fresh-name supply is consumed.
--
-- Characters of the variable's hint that are not legal in a C identifier
-- (anything other than ASCII letters, digits and @_@) are escaped as
-- @u<code point>@.
createIdent' :: String -> Variable (a :: Hakaru) -> CodeGen Ident
createIdent' s var@(Variable name _ _) =
  do !cg <- get
     let (fresh, rest) = case freshNames cg of
           (n:ns) -> (n, ns)
           []     -> error "createIdent': fresh name supply exhausted"
         ident = Ident $ concat [ concatMap toAscii . T.unpack $ name
                                , "_", s, "_", fresh ]
         env'  = updateEnv var ident (varEnv cg)
     put $! cg { freshNames = rest
               , varEnv     = env' }
     return ident
  where
    -- A plain numeric range such as 48..122 would also accept ASCII
    -- punctuation (':', '?', '@', '[', '\\', ']', '^', '`'), which is not
    -- allowed in C identifiers. Test for the identifier alphabet explicitly.
    toAscii c
      | isCIdentChar c = [c]
      | otherwise      = "u" ++ show (fromEnum c)
    isCIdentChar c =  ('a' <= c && c <= 'z')
                   || ('A' <= c && c <= 'Z')
                   || ('0' <= c && c <= '9')
                   || c == '_'
-- | Return the C identifier bound to a Hakaru variable (by 'createIdent' /
-- 'createIdent''). Calls 'error' when the variable has no binding.
lookupIdent :: Variable (a :: Hakaru) -> CodeGen Ident
lookupIdent var =
  do !env <- gets varEnv
     maybe (error $ "lookupIdent: var not found --" ++ show var)
           return
           (lookupVar var env)
-- | Declare a variable of the given Hakaru type under the given identifier in
-- the current scope.
--
-- Measures, arrays and data types first emit their supporting external
-- declarations via 'extDeclareTypes'. Function types only emit their
-- supporting types; no variable declaration is added for them.
declare :: Sing (a :: Hakaru) -> Ident -> CodeGen ()
declare typ ident =
  case typ of
    SInt       -> declare' (typeDeclaration SInt ident)
    SNat       -> declare' (typeDeclaration SNat ident)
    SProb      -> declare' (typeDeclaration SProb ident)
    SReal      -> declare' (typeDeclaration SReal ident)
    SMeasure t -> extDeclareTypes typ >> declare' (mdataDeclaration t ident)
    SArray t   -> extDeclareTypes typ >> declare' (arrayDeclaration t ident)
    SData _ _  -> extDeclareTypes typ >> declare' (datumDeclaration typ ident)
    SFun _ _   -> extDeclareTypes typ
-- | Emit, via 'extDeclare', the external declarations a Hakaru type depends
-- on. Component types are always processed before the declaration that uses
-- them. Scalar types need nothing. For function types only the argument and
-- result types are processed.
extDeclareTypes :: Sing (a :: Hakaru) -> CodeGen ()
extDeclareTypes SInt = return ()
extDeclareTypes SNat = return ()
extDeclareTypes SReal = return ()
extDeclareTypes SProb = return ()
-- element type first, then the measure/array declaration itself
extDeclareTypes (SMeasure i) = extDeclareTypes i >> extDeclare (mdataStruct i)
extDeclareTypes (SArray i) = extDeclareTypes i >> extDeclare (arrayStruct i)
extDeclareTypes (SFun x y) = extDeclareTypes x >> extDeclareTypes y
-- every constant field type of every constructor, then the datum declaration
extDeclareTypes d@(SData _ i) = extDeclDatum i >> extDeclare (datumStruct d)
  where
    -- constructors are visited from the last one to the first
    extDeclDatum :: Sing (a :: [[HakaruFun]]) -> CodeGen ()
    extDeclDatum SVoid = return ()
    extDeclDatum (SPlus p s) = extDeclDatum s >> datumProdTypes p
    -- fields of one constructor, likewise visited last to first
    datumProdTypes :: Sing (a :: [HakaruFun]) -> CodeGen ()
    datumProdTypes SDone = return ()
    datumProdTypes (SEt x p) = datumProdTypes p >> datumPrimTypes x
    -- SIdent (a reference to the datum itself) needs nothing extra
    datumPrimTypes :: Sing (a :: HakaruFun) -> CodeGen ()
    datumPrimTypes SIdent = return ()
    datumPrimTypes (SKonst s) = extDeclareTypes s
-- | Add a declaration to the current scope.
declare' :: CDecl -> CodeGen ()
declare' d = modify (\cg -> cg { declarations = d : declarations cg })

-- | Append a statement to the current scope.
putStat :: CStat -> CodeGen ()
putStat s = modify (\cg -> cg { statements = s : statements cg })

-- | Append the expression statement @e;@ to the current scope.
putExprStat :: CExpr -> CodeGen ()
putExprStat e = putStat (CExpr (Just e))

-- | Append the assignment statement @ident = e;@ to the current scope.
assign :: Ident -> CExpr -> CodeGen ()
assign ident e = putExprStat (CVar ident .=. e)

-- | Register a top-level declaration. A declaration equal to one already
-- registered is ignored, so each external definition appears once.
extDeclare :: CExtDecl -> CodeGen ()
extDeclare d = modify addOnce
  where
    addOnce cg
      | d `elem` extDecls cg = cg
      | otherwise            = cg { extDecls = d : extDecls cg }
-- | Bindings from Hakaru variables, keyed by their 'Nat' identifier, to the C
-- identifiers generated for them.
newtype Env = Env (IM.IntMap Ident)
  deriving Show

-- | Environment with no bindings.
emptyEnv :: Env
emptyEnv = Env IM.empty

-- | Bind (or rebind) a variable to an identifier. The underlying map is
-- forced to WHNF together with the resulting environment.
updateEnv :: Variable (a :: Hakaru) -> Ident -> Env -> Env
updateEnv (Variable _ vid _) ident (Env bindings) =
  let !bindings' = IM.insert (fromNat vid) ident bindings
  in Env bindings'

-- | Look up the identifier bound to a variable, if any.
lookupVar :: Variable (a :: Hakaru) -> Env -> Maybe Ident
lookupVar (Variable _ vid _) (Env bindings) = fromNat vid `IM.lookup` bindings
-- | Emit a top-level C function definition with the given return type
-- specifiers, name and parameter declarations.
--
-- The body generator runs with empty statement and declaration lists and
-- with the fresh-name supply restarted at 'cNameStream'. Its declarations
-- are placed before its statements in the function body. Afterwards the
-- caller's statements, declarations and name supply are restored. Other
-- state the body changed (environment, reserved names, external
-- declarations, flags) is kept.
--
-- NOTE(review): because the supply restarts for every function body, two
-- bodies can generate the same identifier. That is harmless for locals, but
-- names that escape to the top level from inside a body (e.g. helper
-- functions declared there) could collide. Confirm.
funCG :: [CTypeSpec] -> Ident -> [CDecl] -> CodeGen () -> CodeGen ()
funCG ts ident args m =
  do outer <- get
     let inner = execState m (outer { statements   = []
                                    , declarations = []
                                    , freshNames   = cNameStream })
         body  = map CBlockDecl (reverse (declarations inner))
              ++ map CBlockStat (reverse (statements inner))
     put (inner { statements   = statements outer
                , declarations = declarations outer
                , freshNames   = freshNames outer })
     extDeclare (CFunDefExt (CFunDef (map CTypeSpec ts)
                                     (CDeclr Nothing (CDDeclrIdent ident))
                                     args
                                     (CCompound body)))
-- | Emit @if (bE) { m1 } else { m2 }@.
--
-- Each branch is generated in its own scope, so its declarations stay inside
-- the branch. The else-part is omitted when the second generator emits
-- nothing. State other than the scoped statements and declarations flows
-- from the then-branch into the else-branch and then back to the caller.
ifCG :: CExpr -> CodeGen () -> CodeGen () -> CodeGen ()
ifCG bE m1 m2 =
  do cg <- get
     let afterThen = runScoped m1 cg
         afterElse = runScoped m2 afterThen
     put (afterElse { statements = statements cg, declarations = declarations cg })
     let elsePart | null (items afterElse) = Nothing
                  | otherwise              = Just (CCompound (items afterElse))
     putStat (CIf bE (CCompound (items afterThen)) elsePart)
  where
    runScoped m st = execState m (st { statements = [], declarations = [] })
    items st = map CBlockDecl (reverse (declarations st))
            ++ map CBlockStat (reverse (statements st))
-- | Emit a loop with condition @bE@ whose body is generated in its own scope.
-- The flag is passed through to 'CWhile'; 'whileCG' passes False and
-- 'doWhileCG' passes True.
whileCG' :: Bool -> CExpr -> CodeGen () -> CodeGen ()
whileCG' isDoWhile bE m =
  do cg <- get
     let inner = execState m (cg { statements = [], declarations = [] })
         body  = map CBlockDecl (reverse (declarations inner))
              ++ map CBlockStat (reverse (statements inner))
     put (inner { statements = statements cg, declarations = declarations cg })
     putStat (CWhile bE (CCompound body) isDoWhile)

-- | Emit a @while@ loop.
whileCG :: CExpr -> CodeGen () -> CodeGen ()
whileCG = whileCG' False

-- | Emit a @do ... while@ loop.
doWhileCG :: CExpr -> CodeGen () -> CodeGen ()
doWhileCG = whileCG' True
-- | Emit @for (iter; cond; inc) { body }@, with the body in its own scope.
--
-- The body is generated with 'sharedMem' switched off, so 'whenPar' actions
-- inside it (such as pragmas of nested loops) are suppressed. The caller's
-- setting is restored afterwards. When 'sharedMem' is on, an OpenMP
-- @parallel for@ pragma is emitted immediately before the loop.
forCG
  :: CExpr
  -> CExpr
  -> CExpr
  -> CodeGen ()
  -> CodeGen ()
forCG iter cond inc body =
  do cg <- get
     let inner    = execState body (cg { statements   = []
                                       , declarations = []
                                       , sharedMem    = False })
         loopBody = map CBlockDecl (reverse (declarations inner))
                 ++ map CBlockStat (reverse (statements inner))
     put (inner { statements   = statements cg
                , declarations = declarations cg
                , sharedMem    = sharedMem cg })
     whenPar (putStat (CPPStat (ompToPP (OMP (Parallel [For])))))
     putStat (CFor (Just iter) (Just cond) (Just inc) (CCompound loopBody))
-- | Emit a @for@ loop that reduces into @acc@, with the body in its own scope.
--
-- The body is generated with 'sharedMem' switched off, so 'whenPar' actions
-- inside it are suppressed. The caller's setting is restored afterwards.
-- When 'sharedMem' is on, an OpenMP @parallel for@ pragma with a reduction
-- clause on @acc@ is emitted before the loop:
--
--   * @Left op@ uses the given C operator.
--   * @Right (typ, unit, mul)@ first declares a custom reduction via
--     'declareReductionCG'.
reductionCG
  :: Either CBinaryOp
            ( Sing (a :: Hakaru)
            , CExpr -> CodeGen ()
            , CExpr -> CExpr -> CodeGen () )
  -> CExpr
  -> CExpr
  -> CExpr
  -> CExpr
  -> CodeGen ()
  -> CodeGen ()
reductionCG op acc iter cond inc body =
  do cg <- get
     let inner    = execState body (cg { statements   = []
                                       , declarations = []
                                       , sharedMem    = False })
         loopBody = map CBlockDecl (reverse (declarations inner))
                 ++ map CBlockStat (reverse (statements inner))
     put (inner { statements   = statements cg
                , declarations = declarations cg
                , sharedMem    = sharedMem cg })
     whenPar $
       do reduction <- case op of
                         Left binop             -> return (Left binop)
                         Right (typ, unit, mul) -> Right <$> declareReductionCG typ unit mul
          putStat . CPPStat . ompToPP $
            OMP (Parallel [For, Reduction reduction [acc]])
     putStat (CFor (Just iter) (Just cond) (Just inc) (CCompound loopBody))
-- | Declare a custom OpenMP reduction for values of the given Hakaru type and
-- return the reduction's identifier.
--
-- Two helper C functions are emitted as external declarations:
--
--   * @unit(T* in)@, whose body is generated by @unit@. It is used as the
--     initializer @unit(&omp_priv)@.
--   * @mul(T* out, T* in)@, whose body is generated by @mul@. It is used as
--     the combiner @mul(&omp_in, &omp_out)@.
--
-- A @DeclareRed@ pragma tying them together is appended to the current
-- statements.
declareReductionCG
  :: Sing (a :: Hakaru)
  -> (CExpr -> CodeGen ())
  -> (CExpr -> CExpr -> CodeGen ())
  -> CodeGen Ident
declareReductionCG typ unit mul =
  do -- Bind each fresh name on its own line. A list pattern such as
     -- @(a:b:c:[]) <- mapM ...@ is failable and needs a MonadFail instance,
     -- which 'State' (StateT over Identity) lacks on modern GHC.
     redId  <- genIdent' "red"
     unitId <- genIdent' "unit"
     mulId  <- genIdent' "mul"
     let declType = typePtrDeclaration typ
     inId <- genIdent' "in"
     funCG [CVoid] unitId [declType inId] $
       unit (CVar inId)
     outId <- genIdent' "out"
     in2Id <- genIdent' "in"
     funCG [CVoid] mulId [declType outId, declType in2Id] $
       mul (CVar outId) (CVar in2Id)
     let typ' = case buildType typ of
                  (x:_) -> x
                  _     -> error $ "buildType{" ++ (show typ) ++ "}"
     -- NOTE(review): the combiner passes &omp_in to the parameter named
     -- "out" and &omp_out to the one named "in". OpenMP expects the combined
     -- value to end up in omp_out, so unless @mul@ writes into its second
     -- argument these look swapped. Confirm against callers.
     putStat . CPPStat . ompToPP $
       OMP (DeclareRed redId
                       typ'
                       (CCall (CVar mulId)
                              (map (address . CVar . Ident) ["omp_in","omp_out"]))
                       (CCall (CVar unitId)
                              [address . CVar . Ident $ "omp_priv"]))
     return redId
-- | Emit a generator's output as a nested compound statement @{ ... }@. The
-- block has its own scope for declarations, which are placed before its
-- statements.
codeBlockCG :: CodeGen () -> CodeGen ()
codeBlockCG body =
  do cg <- get
     let inner = execState body (cg { statements = [], declarations = [] })
     put (inner { statements = statements cg, declarations = declarations cg })
     putStat . CCompound $
       [ CBlockDecl d | d <- reverse (declarations inner) ] ++
       [ CBlockStat s | s <- reverse (statements inner) ]
-- | Emit @loc = (cast) alloc(size * sizeof(T))@, where @T@ is the C type
-- built for @typ@.
--
-- The allocator is 'gcMalloc' when 'managedMem' is set and 'mallocE'
-- otherwise. NOTE(review): the Bool argument of 'CTypeName' presumably
-- selects the pointer form used for the cast. Confirm.
putMallocStat :: CExpr -> CExpr -> Sing (a :: Hakaru) -> CodeGen ()
putMallocStat loc size typ =
  do managed <- gets managedMem
     let allocate | managed   = gcMalloc
                  | otherwise = mallocE
         cType = buildType typ
         bytes = size .*. CSizeOfType (CTypeName cType False)
     putExprStat (loc .=. CCast (CTypeName cType True) (allocate bytes))