{-# LANGUAGE CPP,
BangPatterns,
DataKinds,
FlexibleContexts,
GADTs,
KindSignatures,
ScopedTypeVariables,
RankNTypes,
TypeOperators #-}
module Language.Hakaru.CodeGen.Flatten
( flattenABT
, flattenVar
, flattenTerm
, flattenWithName
, flattenWithName'
, localVar
, localVar'
, opComment
) where
import Language.Hakaru.CodeGen.CodeGenMonad
import Language.Hakaru.CodeGen.AST
import Language.Hakaru.CodeGen.Libs
import Language.Hakaru.CodeGen.Types
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.Datum hiding (Ident)
import Language.Hakaru.Syntax.Reducer
import qualified Language.Hakaru.Syntax.Prelude as HKP
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.Sing
import Control.Monad.State.Strict
import Control.Monad (replicateM)
import Control.Applicative (pure)
import Data.Number.Natural
import Data.Monoid hiding (Product,Sum)
import Data.Ratio
import qualified Data.List.NonEmpty as NE
import qualified Data.Sequence as S
import qualified Data.Foldable as F
import qualified Data.Traversable as T
#if __GLASGOW_HASKELL__ < 710
import Data.Functor
#endif
-- | Build a banner comment statement of the form @---- opStr ----@ used to
-- bracket the C code emitted for an operation. Each side is padded with
-- @(50 - length opStr) `div` 2 - 8@ dashes (none when that is negative).
opComment :: String -> CStat
opComment opStr =
  let dashes = replicate ((50 - length opStr) `div` 2 - 8) '-'
  in CComment (dashes ++ " " ++ opStr ++ " " ++ dashes)
-- | Declare a fresh local C variable of the given Hakaru type and return it
-- as an expression.
localVar :: Sing (a :: Hakaru) -> CodeGen CExpr
localVar = flip localVar' ""

-- | Like 'localVar', but the fresh identifier is generated from the given
-- name hint.
localVar' :: Sing (a :: Hakaru) -> String -> CodeGen CExpr
localVar' typ s = do
  ident <- genIdent' s
  declare typ ident
  pure (CVar ident)
-- | Declare a fresh C variable (named from @hint@) with the expression's
-- type. Flatten the expression into it and return the variable.
flattenWithName'
  :: ABT Term abt
  => abt '[] a
  -> String
  -> CodeGen CExpr
flattenWithName' abt hint =
  genIdent' hint >>= \ident -> do
    declare (typeOf abt) ident
    let target = CVar ident
    flattenABT abt target
    return target

-- | flattenWithName' with an empty name hint.
flattenWithName
  :: ABT Term abt
  => abt '[] a
  -> CodeGen CExpr
flattenWithName = flip flattenWithName' ""
-- | Emit C code storing the value of a Hakaru expression into @loc@. The
-- expression is dispatched on whether it is a variable or a term.
flattenABT
  :: ABT Term abt
  => abt '[] a
  -> (CExpr -> CodeGen ())
flattenABT abt loc =
  caseVarSyn abt (\var -> flattenVar var loc) (\term -> flattenTerm term loc)

-- | Assign a Hakaru variable to @loc@. The variable's C identifier is found
-- in the state of the CodeGen monad via 'lookupIdent'.
flattenVar
  :: Variable (a :: Hakaru)
  -> (CExpr -> CodeGen ())
flattenVar v loc = do
  ident <- lookupIdent v
  putExprStat (loc .=. CVar ident)
-- | Emit C code that stores the value of a Hakaru term into the given
-- location.
--
-- Measure-typed locations hold an @mdata@ with a sample and a weight. The
-- weight is kept in the log domain: the Dirac case of flattenSCon stores 0
-- (probability one), and MBind adds weights.
flattenTerm
  :: ABT Term abt
  => Term abt a
  -> (CExpr -> CodeGen ())
-- SCon can contain mochastic terms
flattenTerm (x :$ ys)         = flattenSCon x ys
flattenTerm (NaryOp_ t s)     = flattenNAryOp t s
flattenTerm (Literal_ x)      = flattenLit x
flattenTerm (Empty_ _)        = error "TODO: flattenTerm{Empty}"
flattenTerm (Datum_ d)        = flattenDatum d
flattenTerm (Case_ c bs)      = flattenCase c bs
flattenTerm (Bucket b e rs)   = flattenBucket b e rs
flattenTerm (Array_ s e)      = flattenArray s e
flattenTerm (ArrayLiteral_ s) = flattenArrayLiteral s
---------------------
-- Mochastic Terms --
---------------------
flattenTerm (Superpose_ wes)  = flattenSuperpose wes
-- Reject has probability zero. Weights are log-domain, so its weight is
-- log 0 (-infinity). Storing 0 here would mean probability one.
flattenTerm (Reject_ _)       =
  \loc -> putExprStat (mdataWeight loc .=. (logE (floatE 0)))
--------------------------------------------------------------------------------
--                                 SCon                                       --
--------------------------------------------------------------------------------
-- | Emit C code for a syntactic-constructor application, storing the result
-- in the given location.
flattenSCon
  :: ( ABT Term abt )
  => SCon args a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())
-- let: create an identifier for the bound variable and declare it unless it
-- is function-typed (SFun). Store the bound expression in it, then flatten
-- the body into loc.
flattenSCon Let_ = \(expr :* body :* End) loc ->
  caseBind body $ \var@(Variable _ _ varTyp) rest -> do
    ident <- createIdent var
    case varTyp of
      (SFun _ _) -> return ()
      _          -> declare varTyp ident
    flattenABT expr (CVar ident)
    flattenABT rest loc
-- Lambdas produce functions and then return a function label expression.
--
-- Directly nested lambdas are first coalesced into one multi-parameter C
-- function (coalesceLambda). extDeclClosure then emits two external
-- declarations:
--   * a closure struct: a function pointer field "fn" plus one field per
--     free variable of the body;
--   * a C function whose first parameter is that struct ("env"), followed by
--     the lambda parameters, and which returns the flattened body.
-- Finally a local variable of the closure struct type is declared and
-- assigned to loc.
--
-- NOTE(review): nothing in this case stores the function pointer or the free
-- variables into the local closure variable before it is copied to loc, so
-- the environment does not appear to be captured yet. Confirm.
flattenSCon Lam_ =
  \(body :* End) ->
    \loc -> do
      -- externally declare closure and function
      closureTypeSpec <- coalesceLambda body extDeclClosure
      -- declare local closure var
      closureId <- genIdent' "closure"
      declare' (buildDeclaration closureTypeSpec closureId)
      -- capture environment in closure
      putExprStat $ loc .=. (CVar closureId)
  where
    -- Peel the binder off abt. While the body is itself a lambda, keep
    -- peeling. Pass all collected variables and the innermost body to k.
    coalesceLambda
      :: ( ABT Term abt )
      => abt '[x] a
      -> (forall (ys :: [Hakaru]) b. List1 Variable ys -> abt '[] b -> r)
      -> r
    coalesceLambda abt k =
      caseBind abt $ \v abt' ->
        caseVarSyn abt' (const (k (Cons1 v Nil1) abt')) $ \term ->
          case term of
            (Lam_ :$ body :* End) ->
              coalesceLambda body $ \vars abt'' -> k (Cons1 v vars) abt''
            _ -> k (Cons1 v Nil1) abt'

    -- given a parameter, create identifiers corresponding to Hakaru vars,
    -- and return a CTypeSpec for the param
    -- Will this fail if the parameter is a SFun?
    mkVarIdandSpec :: Variable (a :: Hakaru) -> CodeGen (Ident,[CTypeSpec])
    mkVarIdandSpec v@(Variable _ _ typ) = do
      extDeclareTypes typ
      vId <- createIdent v
      return (vId,buildType typ)

    -- Emit (externally) the closure struct and the C function for the
    -- coalesced lambda. Returns the closure struct's type specifier.
    extDeclClosure
      :: ( ABT Term abt )
      => List1 Variable (ys :: [Hakaru])
      -> abt '[] b
      -> CodeGen CTypeSpec
    extDeclClosure vars body'= do
      funcId <- genIdent' "fn"
      idAndSpecs <- sequence $ foldMap11 (\v -> [mkVarIdandSpec v]) vars
      let fVars = freeVars body'
          typ = typeOf body'
      sId@(Ident sname) <- extDeclClosureStruct typ (fmap snd idAndSpecs) fVars
      funCG (head . buildType $ typ)
            funcId
            ([buildDeclaration (callStruct sname) (Ident "env")]
             ++ (fmap (\(vId,specs) -> buildDeclaration' specs vId) idAndSpecs))
            ((putStat . CReturn . Just) =<< flattenWithName body')
      return (callStruct sname)

    -- Externally declare the closure struct and return its identifier. The
    -- struct holds a pointer "fn" to a function returning retTyp, which takes
    -- the struct followed by the parameter types. It also holds one field per
    -- free variable.
    extDeclClosureStruct
      :: forall (a :: Hakaru) (ys :: [Hakaru])
      .  Sing a
      -> [[CTypeSpec]]
      -> VarSet (KindOf a)
      -> CodeGen Ident
    extDeclClosureStruct retTyp paramTypeSpecs freeVars = do
      sId@(Ident sname) <- genIdent' "clos"
      freeVarDecls <- mapM (\(SomeVariable v@(Variable _ _ typ)) -> do
                              extDeclareTypes typ
                              vId <- createIdent v
                              return (typeDeclaration typ vId)
                           ) (fromVarSet freeVars)
      let funPtrDecl =
            CDecl (fmap CTypeSpec $ buildType retTyp)
                  [( CDeclr Nothing
                       (CDDeclrFun
                          (CDDeclrRec (CDeclr (Just $ CPtrDeclr []) (CDDeclrIdent . Ident $ "fn")))
                          ([callStruct sname]++(concat paramTypeSpecs)))
                   , Nothing)]
      extDeclare $ CDeclExt
                 $ CDecl [ CTypeSpec $ buildStruct (Just sId)
                                                   ([funPtrDecl]++freeVarDecls) ]
                         []
      return sId
-- Function application: flatten the closure and the argument into fresh
-- variables, then assign the result of calling the closure's "fn" field to
-- loc.
--
-- NOTE(review): the "fn" pointer declared in the Lam_ case takes the closure
-- struct as its first parameter, but only the argument is passed here.
-- Confirm the intended calling convention.
flattenSCon App_ =
  \(fun :* arg :* End) ->
    \loc -> do
      closE <- flattenWithName' fun "clos"
      -- the parameter is the argument term, not the function again
      paramE <- flattenWithName' arg "param"
      putExprStat $ loc .=. (CCall (CMember closE (Ident "fn") True) [paramE])
flattenSCon (PrimOp_ op) = flattenPrimOp op
flattenSCon (ArrayOp_ op) = flattenArrayOp op
-- Summation of body over the iteration variable, from lo towards hi.
-- Both bounds are flattened first, and the code is bracketed by
-- "Begin/End Summate" banner comments.
--   * Prob (values stored as logs, cf. flattenLit): allocate a temporary
--     array of (hi - lo) doubles and hand it to lseSummateArrayCG together
--     with body and loc (presumably fills it and log-sum-exps it into loc),
--     then free the array.
--   * Other semirings: acc starts at 0 (0.0 for Real). The loop runs from
--     iter = lo while iter < hi, adding each body value; acc is then stored
--     into loc.
-- NOTE(review): the temporary array always uses mallocE, even when managed
-- memory (gcMalloc) is enabled, unlike flattenArray. Confirm.
flattenSCon (Summate _ sr) =
  \(lo :* hi :* body :* End) ->
    \loc ->
      let semiTyp = sing_HSemiring sr in
      do loE <- flattenWithName' lo "lo"
         hiE <- flattenWithName' hi "hi"
         putStat $ opComment "Begin Summate"
         case semiTyp of
           -- special prob branch
           SProb -> do
             summateArrId <- genIdent' "summate_arr"
             declare (SArray SProb) summateArrId
             let summateArrE = CVar summateArrId
             putExprStat $ arraySize summateArrE .=. (hiE .-. loE)
             putExprStat $ arrayData summateArrE
               .=. (castToPtrOf [CDouble]
                      (mallocE ((arraySize summateArrE) .*.
                                (CSizeOfType (CTypeName [CDouble] False)))))
             lseSummateArrayCG body summateArrE loc
             putExprStat $ freeE (arrayData summateArrE)
           _ ->
             caseBind body $ \v body' -> do
               iterI <- createIdent v
               declare SNat iterI
               accI <- genIdent' "acc"
               declare semiTyp accI
               assign accI (case semiTyp of
                              SReal -> floatE 0
                              _ -> intE 0)
               let accVar = CVar accI
                   iterVar = CVar iterI
               reductionCG (Left CAddOp)
                           accI
                           (iterVar .=. loE)
                           (iterVar .<. hiE)
                           (CUnary CPostIncOp iterVar) $
                 (putExprStat . (accVar .+=.) =<< flattenWithName body')
               putExprStat $ loc .=. accVar
         putStat $ opComment "End Summate"
-- Product of body over the iteration variable, from lo towards hi.
-- Both bounds are flattened first, and the code is bracketed by
-- "Begin/End Product" banner comments.
--   * Prob (values stored as logs): delegated to kahanSummationCG. The
--     product of probabilities is a sum of their logs, presumably summed with
--     Kahan compensation.
--   * Other semirings: acc starts at 1 (1.0 for Real). The loop runs from
--     iter = lo while iter < hi, multiplying in each body value; acc is then
--     stored into loc.
flattenSCon (Product _ sr) =
  \(lo :* hi :* body :* End) ->
    \loc ->
      let semiTyp = sing_HSemiring sr in
      do loE <- flattenWithName' lo "lo"
         hiE <- flattenWithName' hi "hi"
         putStat $ opComment "Begin Product"
         case semiTyp of
           -- special prob branch
           SProb -> kahanSummationCG body loE hiE loc
           _ ->
             caseBind body $ \v body' -> do
               iterI <- createIdent v
               declare SNat iterI
               accI <- genIdent' "acc"
               declare semiTyp accI
               assign accI (case semiTyp of
                              SReal -> floatE 1
                              _ -> intE 1)
               let accVar = CVar accI
                   iterVar = CVar iterI
               reductionCG (Left CMulOp)
                           accI
                           (iterVar .=. loE)
                           (iterVar .<. hiE)
                           (CUnary CPostIncOp iterVar) $
                 (putExprStat . (accVar .*=.) =<< flattenWithName body')
               putExprStat (loc .=. accVar)
         putStat $ opComment "End Product"
--------------------
-- SCon Coercions --
--------------------
{- Helpers found by searching "Coercion Helpers" -}
-- Coercions: flatten the operand into a fresh variable, convert it with the
-- matching coercion helper, and assign the converted expression to loc.
flattenSCon (CoerceTo_ ctyp) = \(e :* End) loc ->
  flattenWithName e >>= coerceToCG ctyp >>= putExprStat . (loc .=.)
flattenSCon (UnsafeFrom_ ctyp) = \(e :* End) loc ->
  flattenWithName e >>= coerceFromCG ctyp >>= putExprStat . (loc .=.)
-----------------------------------
-- SCons in the Stochastic Monad --
-----------------------------------
flattenSCon (MeasureOp_ op) = flattenMeasureOp op
-- dirac: store the value as the sample with weight 0
flattenSCon Dirac = \(e :* End) loc -> do
  sampleE <- flattenWithName' e "samp"
  putExprStat $ mdataWeight loc .=. (floatE 0)
  putExprStat $ mdataSample loc .=. sampleE
-- bind: run the first measure, copy its sample into the bound variable, run
-- the continuation into loc, then add the first measure's weight to loc's
flattenSCon MBind = \(ma :* b :* End) loc ->
  caseBind b $ \var@(Variable _ _ varTyp) mb -> do
    firstE <- flattenWithName' ma "m"
    varId <- createIdent var
    declare varTyp varId
    assign varId (mdataSample firstE)
    flattenABT mb loc
    putExprStat $ mdataWeight loc .+=. (mdataWeight firstE)
-- Plate: run the body measure `size` times, storing each draw's sample into
-- loc's sample array and summing the per-draw weights into loc's weight.
-- For now plates make use of a global sample.
-- Code generation aborts (Haskell `error`) unless managed memory is enabled.
-- NOTE(review): the sample array is allocated with element type
-- `typeOf body`, which is the body's measure type rather than its sample
-- type. Confirm putMallocStat sizes this as intended.
flattenSCon Plate =
  \(size :* b :* End) ->
    \loc ->
      caseBind b $ \v@(Variable _ _ typ) body ->
        do sizeE <- flattenWithName' size "s"
           isMM <- managedMem <$> get
           when (not isMM) (error "plate will leak memory without the '-g' flag and boehm-gc")
           putExprStat $ (arraySize . mdataSample $ loc) .=. sizeE
           putMallocStat (arrayData . mdataSample $ loc) sizeE (typeOf body)
           -- weight accumulator, starting at 0
           weightId <- genIdent' "w"
           declare SProb weightId
           let weightE = CVar weightId
           assign weightId (floatE 0)
           -- the plate's bound variable is the loop index
           itId <- createIdent v
           declare SNat itId
           let itE = CVar itId
               currInd = index (arrayData . mdataSample $ loc) itE
           -- mdata that receives each individual draw
           sampId <- genIdent' "samp"
           declare (typeOf $ body) sampId
           let sampE = CVar sampId
           reductionCG (Left CAddOp)
                       weightId
                       (itE .=. (intE 0))
                       (itE .<. sizeE)
                       (CUnary CPostIncOp itE)
                       (do flattenABT body sampE
                           putExprStat (currInd .=. (mdataSample sampE))
                           putExprStat (weightE .+=. (mdataWeight sampE)))
           putExprStat $ mdataWeight loc .=. weightE
------------------------------------
-- SCons that aren't implemented  --
------------------------------------
-- any remaining constructor aborts code generation with its name
flattenSCon x = \_ _ -> error ("TODO: flattenSCon: " ++ show x)
--------------------------------------------------------------------------------
--                                 NaryOps                                    --
--------------------------------------------------------------------------------
-- | Emit an n-ary operation. Each argument is first flattened into its own
-- variable.
--
-- * Boolean operators (And, Or, Xor, Iff) fold 'binaryOp' over the
--   arguments' datum "index" fields and store the result in loc's "index".
-- * A Prob sum goes through logSumExpCG.
-- * Every other operator folds 'binaryOp' over the arguments and assigns the
--   result to loc.
--
-- The fold is seeded with the first argument, so the sequence is assumed to
-- be non-empty.
flattenNAryOp :: ABT Term abt
              => NaryOp a
              -> S.Seq (abt '[] a)
              -> (CExpr -> CodeGen ())
flattenNAryOp op args loc = do
  es <- T.mapM flattenWithName args
  let foldArgs xs = F.foldr (binaryOp op) (S.index xs 0) (S.drop 1 xs)
      datumIndex x = CMember x (Ident "index") True
      onIndices = putExprStat (datumIndex loc .=. foldArgs (fmap datumIndex es))
  case op of
    And -> onIndices
    Or  -> onIndices
    Xor -> onIndices
    Iff -> onIndices
    (Sum HSemiring_Prob) -> logSumExpCG es loc
    _ -> putExprStat (loc .=. foldArgs es)
--------------------------------------------------------------------------------
--                                 Literals                                   --
--------------------------------------------------------------------------------
-- | Assign a literal to @loc@.
--
-- * Nat and Int literals become integer constants; Real becomes a float.
-- * Prob literals are stored as their logarithm, emitted as
--   @log1p(x - 1)@.
flattenLit
  :: Literal a
  -> (CExpr -> CodeGen ())
flattenLit lit loc = putExprStat (loc .=. litE)
  where
    litE :: CExpr
    litE = case lit of
      (LNat x)  -> intE $ fromIntegral x
      (LInt x)  -> intE x
      (LReal x) -> floatE $ fromRational x
      (LProb x) ->
        let rat = fromNonNegativeRational x
            x'  = (fromIntegral $ numerator rat)
                  / (fromIntegral $ denominator rat)
        in log1pE (floatE x' .-. intE 1)
--------------------------------------------------------------------------------
--                           Array and ArrayOps                               --
--------------------------------------------------------------------------------
-- | Build an array in @loc@:
--
-- 1. Flatten the arity into the array's size field.
-- 2. Allocate its data with gcMalloc when managed memory is on, mallocE
--    otherwise.
-- 3. Fill each element by flattening the body, with the bound variable as
--    the loop index running from 0 while it is less than the size.
flattenArray
  :: (ABT Term abt)
  => (abt '[] 'HNat)
  -> (abt '[ 'HNat ] a)
  -> (CExpr -> CodeGen ())
flattenArray arity body loc =
  caseBind body $ \var elemBody -> do
    let sizeE = arraySize loc
        dataE = arrayData loc
        elemTyp = typeOf elemBody
        elemTypeName isPtr = CTypeName (buildType elemTyp) isPtr
    flattenABT arity sizeE
    useGC <- managedMem <$> get
    let allocate = if useGC then gcMalloc else mallocE
    putExprStat $ dataE
      .=. (CCast (elemTypeName True)
                 (allocate (sizeE .*. (CSizeOfType (elemTypeName False)))))
    iterId <- createIdent var
    declare SNat iterId
    let iterE = CVar iterId
    putStat $ opComment "Begin Array"
    forCG (iterE .=. (intE 0))
          (iterE .<. sizeE)
          (CUnary CPostIncOp iterE)
          (flattenABT elemBody (index dataE iterE))
    putStat $ opComment "End Array"
-- | Build an array literal in a fresh array variable, then assign it to
-- @loc@.
--
-- The element type is taken from the first element, so the list is assumed
-- to be non-empty. Data is allocated with gcMalloc when managed memory is
-- on, mallocE otherwise. Element assignments are emitted last-to-first.
flattenArrayLiteral
  :: ( ABT Term abt )
  => [abt '[] a]
  -> (CExpr -> CodeGen ())
flattenArrayLiteral es loc = do
  arrId <- genIdent
  useGC <- managedMem <$> get
  let arity = fromIntegral (length es)
      elemTyp = typeOf (head es)
      arrE = CVar arrId
      allocate = if useGC then gcMalloc else mallocE
      elemTypeName isPtr = CTypeName (buildType elemTyp) isPtr
  declare (SArray elemTyp) arrId
  putExprStat $ (arrayData arrE)
    .=. (CCast (elemTypeName True)
               (allocate ((intE arity) .*. (CSizeOfType (elemTypeName False)))))
  putExprStat $ arraySize arrE .=. (intE arity)
  sequence_ (reverse (zipWith (\i e -> storeAt e i arrE) [0 ..] es))
  putExprStat $ loc .=. arrE
  where
    -- flatten one element and write it at offset i of the array's data
    storeAt
      :: ( ABT Term abt )
      => abt '[] a
      -> Integer
      -> (CExpr -> CodeGen ())
    storeAt e i arrE = do
      eE <- flattenWithName e
      putExprStat $ indirect ((arrayData arrE) .+. (intE i)) .=. eE
--------------
-- ArrayOps --
--------------
-- | Array primitives. @index@ reads element @ind@ of the array's "data"
-- field. @size@ reads its "size" field. @reduce@ is not implemented yet.
flattenArrayOp
  :: ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs
     )
  => ArrayOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())
flattenArrayOp (Index _) = \(arr :* ind :* End) loc -> do
  indE <- flattenWithName ind
  arrE <- flattenWithName arr
  putExprStat (loc .=. index (CMember arrE (Ident "data") True) indE)
flattenArrayOp (Size _) = \(arr :* End) loc ->
  flattenWithName arr >>= \arrE ->
    putExprStat (loc .=. (CMember arrE (Ident "size") True))
flattenArrayOp (Reduce _) = error "TODO: flattenArrayOp"
  -- \(fun :* base :* arr :* End) ->
  --   do funE  <- flattenABT fun
  --      baseE <- flattenABT base
  --      arrE  <- flattenABT arr
  --      accI  <- genIdent' "acc"
  --      iterI <- genIdent' "iter"
  --      let sizeE = CMember arrE (Ident "size") True
  --          iterE = CVar iterI
  --          accE  = CVar accI
  --          cond  = iterE .<. sizeE
  --          inc   = CUnary CPostIncOp iterE
  --      declare (typeOf base) accI
  --      declare SInt iterI
  --      assign accI baseE
  --      forCG (iterE .=. (intE 0)) cond inc $
  --        assign accI $ CCall funE [accE]
  --      return accE
--------------------------------------------------------------------------------
-- Bucket and Reducers --
--------------------------------------------------------------------------------
{- Declarations for buckets -
since we will have some product of monoids we need unique names for each one.
Ex:
  bucket i from 0 to 100:
    fanout(add(\_ -> 1),add(\_ -> 2))
we will need to keep track of two ints:
  int x;
  int y;
  for (i = 0; i < 100; i++) {
    x += 1;
    y += 2;
  }
Summary objects are nested pairs, e.g. pair(nat,pair(real,array(nat)))
-}
-- | Emit a bucket loop:
--
-- 1. Initialise the summary stored at loc according to the reducer.
-- 2. Loop with "it" starting at lo while it < hi, updating the summary at
--    loc on every iteration.
flattenBucket
  :: (ABT Term abt)
  => abt '[] 'HNat
  -> abt '[] 'HNat
  -> Reducer abt '[] a
  -> (CExpr -> CodeGen ())
flattenBucket lo hi red = \loc -> do
  putStat $ opComment "Begin Bucket"
  loE <- flattenWithName' lo "lo"
  hiE <- flattenWithName' hi "hi"
  itId <- genIdent' "it"
  declare SNat itId
  let itE = CVar itId
  initRed red loc
  forCG (itE .=. loE)
        (itE .<. hiE)
        (CUnary CPostIncOp itE)
        (accumRed red itE loc)
  putStat $ opComment "End Bucket"
  where
    -- Initialise the summary at loc:
    --   Red_Fanout / Red_Split: initialise both pair components
    --     (datumFst / datumSnd);
    --   Red_Index: declare the size expression's binders (they are not
    --     assigned here), set the array size, allocate it, and initialise
    --     every element;
    --   Red_Nop: nothing;
    --   Red_Add: store addMonoidIdentity of the semiring.
    initRed
      :: (ABT Term abt)
      => Reducer abt xs a
      -> (CExpr -> CodeGen ())
    initRed mr = \loc ->
      case mr of
        (Red_Fanout mr1 mr2) -> initRed mr1 (datumFst loc)
                                >> initRed mr2 (datumSnd loc)
        (Red_Split _ mr1 mr2) -> initRed mr1 (datumFst loc)
                                 >> initRed mr2 (datumSnd loc)
        (Red_Index s _ body) ->
          let (vs,s') = caseBinds s
              btyp = typeOfReducer body in
          do sequence_ . foldMap11
                 (\v' -> case v' of
                           (Variable _ _ typ') ->
                             [declare typ' =<< createIdent v'])
                 $ vs
             sE <- flattenWithName s'
             putExprStat $ arraySize loc .=. sE
             putMallocStat (arrayData loc) sE btyp
             itId <- genIdent
             declare SNat itId
             let itE = CVar itId
             forCG (itE .=. (intE 0))
                   (itE .<. sE)
                   (CUnary CPostIncOp itE)
                   (initRed body (index (arrayData loc) itE))
        Red_Nop -> return ()
        (Red_Add sr _) ->
          putExprStat $ loc .=. (addMonoidIdentity . sing_HSemiring $ sr)

    -- Emit the per-iteration update of the summary at loc for index itE.
    -- In each binding case, the reducer's first binder is declared and set
    -- to itE, and any further binders are only declared.
    --   Red_Index: flatten the index expression and update that element;
    --   Red_Fanout: update both pair components;
    --   Red_Split: update the first component when the condition's datum
    --     index is 0 (true), otherwise the second;
    --   Red_Add: add the value (logSumExpCG for Prob, += otherwise).
    accumRed
      :: (ABT Term abt)
      => Reducer abt xs a
      -> CExpr
      -> (CExpr -> CodeGen ())
    accumRed mr itE = \loc ->
      case mr of
        (Red_Index _ a body) ->
          caseBind a $ \v@(Variable _ _ typ) a' ->
            let (vs,a'') = caseBinds a' in
            do vId <- createIdent v
               declare typ vId
               putExprStat $ (CVar vId) .=. itE
               sequence_ . foldMap11
                   (\v' -> case v' of
                             (Variable _ _ typ') ->
                               [declare typ' =<< createIdent v'])
                   $ vs
               aE <- flattenWithName' a'' "index"
               accumRed body itE (index (arrayData loc) aE)
        (Red_Fanout mr1 mr2) -> accumRed mr1 itE (datumFst loc)
                                >> accumRed mr2 itE (datumSnd loc)
        (Red_Split b mr1 mr2) ->
          caseBind b $ \v@(Variable _ _ typ) b' ->
            let (vs,b'') = caseBinds b' in
            do vId <- createIdent v
               declare typ vId
               putExprStat $ (CVar vId) .=. itE
               sequence_ . foldMap11
                   (\v' -> case v' of
                             (Variable _ _ typ') ->
                               [declare typ' =<< createIdent v'])
                   $ vs
               bE <- flattenWithName' b'' "cond"
               ifCG (bE ... "index" .==. (intE 0))
                    (accumRed mr1 itE (datumFst loc))
                    (accumRed mr2 itE (datumSnd loc))
        Red_Nop -> return ()
        (Red_Add sr e) ->
          caseBind e $ \v@(Variable _ _ typ) e' ->
            let (vs,e'') = caseBinds e' in
            do vId <- createIdent v
               declare typ vId
               putExprStat $ (CVar vId) .=. itE
               sequence_ . foldMap11
                   (\v' -> case v' of
                             (Variable _ _ typ') ->
                               [declare typ' =<< createIdent v'])
                   $ vs
               eE <- flattenWithName e''
               case sing_HSemiring sr of
                 SProb -> logSumExpCG (S.fromList [loc,eE]) loc
                 _ -> putExprStat $ loc .+=. eE

-- | Additive identity used to initialise a Red_Add summary: 0 for Nat, Int
-- and Real, log 0 for Prob (Prob values are stored as logs), and the
-- element's identity for arrays.
addMonoidIdentity :: Sing (a :: Hakaru) -> CExpr
addMonoidIdentity s =
  case s of
    SNat -> intE 0
    SInt -> intE 0
    SReal -> floatE 0
    SProb -> logE (floatE 0)
    SArray x -> addMonoidIdentity x
    x -> error $ "addMonoidIdentity{" ++ show x ++ "}"
--------------------------------------------------------------------------------
--                           Datum and Case                                   --
--------------------------------------------------------------------------------
{-
Datum are sums of products of types. This maps to a C structure. flattenDatum
will produce a literal of some datum type. This will also produce a global
struct representing that datum which will be needed for the C compiler.
-}
-- | Declare (externally) the C types needed for the datum's type, then write
-- the datum's constructor index and fields into @loc@.
flattenDatum
  :: (ABT Term abt)
  => Datum (abt '[]) (HData' a)
  -> (CExpr -> CodeGen ())
flattenDatum (Datum _ typ code) loc = do
  extDeclareTypes typ
  assignDatum code loc
-- | Write a datum value into ident:
--
-- 1. Set its "index" field to the position of the chosen constructor (the
--    number of Inr wrappers before the Inl).
-- 2. Run the assignments for that constructor's arguments.
assignDatum
  :: (ABT Term abt)
  => DatumCode xss (abt '[]) c
  -> CExpr
  -> CodeGen ()
assignDatum code ident =
  let index = getIndex code
      indexExpr = CMember ident (Ident "index") True
  in do putExprStat (indexExpr .=. (intE index))
        sequence_ $ assignSum code ident
  where
    getIndex :: DatumCode xss b c -> Integer
    getIndex (Inl _) = 0
    getIndex (Inr rest) = succ (getIndex rest)

-- | Assignment actions for the constructor selected by a DatumCode. The name
-- chosen from cNameStream identifies the member of ident's "sum" field that
-- holds the constructor's arguments.
assignSum
  :: (ABT Term abt)
  => DatumCode xs (abt '[]) c
  -> CExpr
  -> [CodeGen ()]
assignSum code ident = fst $ runState (assignSum' code ident) cNameStream

-- Each Inr consumes one name. At the Inl, the current head name is the
-- constructor's member name.
-- NOTE(review): the failable pattern binds below need a MonadFail instance
-- for State on GHC >= 8.8. Confirm the supported compiler range.
assignSum'
  :: (ABT Term abt)
  => DatumCode xs (abt '[]) c
  -> CExpr
  -> State [String] [CodeGen ()]
assignSum' (Inr rest) topIdent =
  do (_:names) <- get
     put names
     assignSum' rest topIdent
assignSum' (Inl prod) topIdent =
  do (name:_) <- get
     return $ assignProd prod topIdent (CVar . Ident $ name)

-- | Assignment actions for the fields of a constructor. The fields are named
-- in order from a fresh cNameStream.
assignProd
  :: (ABT Term abt)
  => DatumStruct xs (abt '[]) c
  -> CExpr
  -> CExpr
  -> [CodeGen ()]
assignProd dstruct topIdent sumIdent =
  fst $ runState (assignProd' dstruct topIdent sumIdent) cNameStream

-- Each Konst field d is flattened into topIdent.sum.<sumIdent>.<name>. Only
-- Konst fields with a CVar member name are supported.
assignProd'
  :: (ABT Term abt)
  => DatumStruct xs (abt '[]) c
  -> CExpr
  -> CExpr
  -> State [String] [CodeGen ()]
assignProd' Done _ _ = return []
assignProd' (Et (Konst d) rest) topIdent (CVar sumIdent) =
  do (name:names) <- get
     put names
     let varName =
           CMember (CMember (CMember topIdent (Ident "sum") True)
                            sumIdent
                            True)
                   (Ident name)
                   True
     rest' <- assignProd' rest topIdent (CVar sumIdent)
     return $ [flattenABT d varName] ++ rest'
assignProd' _ _ _ = error $ "TODO: assignProd Ident"
----------
-- Case --
----------
-- | Flatten a case expression. Two shapes are supported:
--
-- * A boolean match: the true branch listed first, then the false branch.
--   The true branch runs when the scrutinee's "index" field is 0.
-- * A single branch matching a pair: both components are copied into fresh
--   variables, then the body is flattened into loc.
--
-- Anything else aborts code generation.
flattenCase
  :: forall abt a b
  .  (ABT Term abt)
  => abt '[] a
  -> [Branch a abt b]
  -> (CExpr -> CodeGen ())
flattenCase c [ Branch (PDatum _ (PInl PDone)) trueB
              , Branch (PDatum _ (PInr (PInl PDone))) falseB ] =
  \loc -> do
    scrutE <- flattenWithName c
    ifCG ((scrutE ... "index") .==. (intE 0))
         (flattenABT trueB loc)
         (flattenABT falseB loc)
flattenCase e [ Branch (PDatum _ (PInl (PEt (PKonst PVar)
                                        (PEt (PKonst PVar)
                                             PDone)))) b
              ] =
  \loc -> do
    pairE <- flattenWithName e
    caseBind b $ \firstVar@(Variable _ _ firstTyp) inner ->
      caseBind inner $ \secondVar@(Variable _ _ secondTyp) rest -> do
        firstId <- createIdent firstVar
        secondId <- createIdent secondVar
        declare firstTyp firstId
        declare secondTyp secondId
        putExprStat $ (CVar firstId) .=. (datumFst pairE)
        putExprStat $ (CVar secondId) .=. (datumSnd pairE)
        flattenABT rest loc
flattenCase e _ = error $ "TODO: flattenCase{" ++ show (typeOf e) ++ "}"
--------------------------------------------------------------------------------
--                                 PrimOp                                     --
--------------------------------------------------------------------------------
-- | Emit a primitive operation into @loc@. Prob values are stored as logs,
-- which is why several cases convert through expm1/log1p.
flattenPrimOp
  :: ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs)
  => PrimOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())
-- pi is a prob, so store log1p(M_PI - 1), i.e. log(pi)
flattenPrimOp Pi = \End loc ->
  putExprStat (loc .=. log1pE ((CVar . Ident $ "M_PI") .-. (intE 1)))
-- not: flip a boolean's datum index (1 -> 0, anything else -> 1)
flattenPrimOp Not = \(a :* End) loc -> do
  aE <- flattenWithName a
  resId <- genIdent
  declare (typeOf a) resId
  let idx e = CMember e (Ident "index") True
      resE = CVar resId
  putExprStat $ idx resE .=. (CCond (idx aE .==. (intE 1))
                                    (intE 0)
                                    (intE 1))
  putExprStat $ loc .=. resE
-- realpow on a prob base: pow(exp(base), power), taken back to log space
flattenPrimOp RealPow = \(base :* power :* End) loc -> do
  baseE <- flattenWithName base
  powerE <- flattenWithName power
  let powCall = CCall (CVar . Ident $ "pow")
                      [ expm1E baseE .+. (intE 1), powerE]
  putExprStat $ loc .=. (log1pE (powCall .-. (intE 1)))
-- natpow: C pow(base, power). For a Prob base the value is exponentiated
-- first and the result is taken back to log space.
flattenPrimOp (NatPow baseTyp) = \(base :* power :* End) loc -> do
  let sBase = sing_HSemiring baseTyp
  baseId <- genIdent
  declare sBase baseId
  let baseE = CVar baseId
  flattenABT base baseE
  powerE <- flattenWithName power
  let cPow x y = CCall (CVar . Ident $ "pow") [x,y]
      value = case sBase of
        SProb -> log1pE $ (cPow (expm1E baseE .+. (intE 1)) powerE)
                          .-. (intE 1)
        _ -> cPow baseE powerE
  putExprStat $ loc .=. value
-- natroot: pow(base, 1.0 / root), with the same Prob handling as natpow
flattenPrimOp (NatRoot baseTyp) = \(base :* root :* End) loc -> do
  let sBase = sing_HRadical baseTyp
  baseId <- genIdent
  declare sBase baseId
  let baseE = CVar baseId
  flattenABT base baseE
  rootE <- flattenWithName root
  let cPow x y = CCall (CVar . Ident $ "pow") [x,y]
      exponentE = (floatE 1) ./. rootE
      value = case sBase of
        SProb -> log1pE $ (cPow (expm1E baseE .+. (intE 1)) exponentE)
                          .-. (intE 1)
        _ -> cPow baseE exponentE
  putExprStat $ loc .=. value
-- recip: 1/x for reals. Prob values are logs, so the reciprocal is the
-- negated log.
flattenPrimOp (Recip t) = \(a :* End) loc -> do
  aE <- flattenWithName a
  case t of
    HFractional_Real -> putExprStat $ loc .=. ((intE 1) ./. aE)
    HFractional_Prob -> putExprStat $ loc .=. (CUnary CMinOp aE)
-- exp : real -> prob. Because of this we can just turn it into a prob
-- without taking its log, which would give us an exp in the log-domain.
flattenPrimOp Exp = \(a :* End) loc -> flattenABT a loc
-- equal: loc's datum index becomes 0 (true) when the operands are equal,
-- else 1. Booleans are compared by their datum index.
flattenPrimOp (Equal _) = \(a :* b :* End) loc -> do
  aE <- flattenWithName a
  bE <- flattenWithName b
  let key :: Sing (x :: Hakaru) -> CExpr -> CExpr
      key typ e = case typ of
        (SData _ (SPlus SDone (SPlus SDone SVoid))) -> (CMember e (Ident "index") True)
        _ -> e
  putExprStat $ (CMember loc (Ident "index") True)
    .=. (CCond (key (typeOf a) aE .==. key (typeOf b) bE) (intE 0) (intE 1))
-- less: loc's datum index becomes 0 (true) when a < b, else 1
flattenPrimOp (Less _) = \(a :* b :* End) loc -> do
  aE <- flattenWithName a
  bE <- flattenWithName b
  putExprStat $ (CMember loc (Ident "index") True)
    .=. (CCond (aE .<. bE) (intE 0) (intE 1))
flattenPrimOp (Negate _) = \(a :* End) loc ->
  flattenWithName a >>= \aE -> putExprStat $ loc .=. (CUnary CMinOp $ aE)
flattenPrimOp t = \_ -> error $ "TODO: flattenPrimOp: " ++ show t
--------------------------------------------------------------------------------
-- MeasureOps and Superpose --
--------------------------------------------------------------------------------
{-
This section contains operations in the stochastic monad. See also Dirac,
MBind, and Plate in flattenSCon, and Reject in flattenTerm.
Remember that in the C runtime, measures are housed in a measure function,
which takes a `struct mdata` location. The MeasureOp attempts to store a value
at that location and returns 0 if it fails and 1 if it succeeds in that task.
The functions uniformCG, normalCG, and gammaCG are primitives that generate
functions and call them (similar to logSumExpCG). These reduce code size and
make samplers a little more readable.
TODO: add inline pragmas to uniformCG, normalCG, and gammaCG
-}
-- | C function @uniform(lo, hi)@. It returns an mdata with weight 0 and
-- sample @lo + (rand()/RAND_MAX) * (hi - lo)@.
uniformFun :: CFunDef
uniformFun =
  CFunDef (CTypeSpec <$> mdataTyp)
          (CDeclr Nothing (CDDeclrIdent funcId))
          [ typeDeclaration SReal loId
          , typeDeclaration SReal hiId ]
          (CCompound (decls ++ stmts))
  where
    funcId = Ident "uniform"
    mdataTyp = buildType (SMeasure SReal)
    var name = (Ident name, CVar (Ident name))
    (mId, mE) = var "mdata"
    (loId, loE) = var "lo"
    (hiId, hiE) = var "hi"
    randD = castTo [CDouble] randE
    randMaxD = castTo [CDouble] (CVar . Ident $ "RAND_MAX")
    drawn = (loE .+. ((randD ./. randMaxD) .*. (hiE .-. loE)))
    decls = [CBlockDecl (buildDeclaration (head mdataTyp) mId)]
    stmts = fmap CBlockStat $
      fmap CComment
        [ "uniform :: real -> real -> *(mdata real) -> ()"
        , "------------------------------------------------" ]
      ++ [ CExpr . Just $ mdataWeight mE .=. (floatE 0)
         , CExpr . Just $ mdataSample mE .=. drawn
         , CReturn . Just $ mE ]
-- | Reserve the identifier "uniform" and emit 'uniformFun' as an external
-- definition. Then assign the result of calling it with @aE@, @bE@ to @loc@.
uniformCG :: CExpr -> CExpr -> (CExpr -> CodeGen ())
uniformCG aE bE loc = do
  uId <- reserveIdent "uniform"
  extDeclareTypes (SMeasure SReal)
  extDeclare (CFunDefExt uniformFun)
  putExprStat (loc .=. CCall (CVar uId) [aE,bE])
{-
This is very cryptic, but I assure you it is only building an AST for the
Marsaglia Polar Method
-}
-- | C function @normal(a, b)@: draw from a normal with mean @a@ and standard
-- deviation @b@ using the Marsaglia polar method.
--
-- 1. Repeatedly draw @u@ and @v@ uniformly from [-1, 1] with
--    @q = u*u + v*v@ until @q@ lies in (0, 1].
-- 2. Set @r = sqrt(-2 log q / q)@.
-- 3. Return an mdata with weight 0 and sample @a + u*r*b@.
--
-- Callers pass @b@ already exponentiated (see flattenMeasureOp Normal).
normalFun :: CFunDef
normalFun = CFunDef (CTypeSpec <$> retTyp)
                    (CDeclr Nothing (CDDeclrIdent (Ident "normal")))
                    [typeDeclaration SReal aId
                    ,typeDeclaration SProb bId ]
                    ( CCompound . concat
                    $ [[CBlockDecl declMD],comment,decls,stmts])
  where
    r = castTo [CDouble] randE
    rMax = castTo [CDouble] (CVar . Ident $ "RAND_MAX")
    retTyp = buildType (SMeasure SReal)
    (aId,aE) = let ident = Ident "a" in (ident,CVar ident)
    (bId,bE) = let ident = Ident "b" in (ident,CVar ident)
    (qId,qE) = let ident = Ident "q" in (ident,CVar ident)
    (uId,uE) = let ident = Ident "u" in (ident,CVar ident)
    (vId,vE) = let ident = Ident "v" in (ident,CVar ident)
    (rId,rE) = let ident = Ident "r" in (ident,CVar ident)
    (mId,mE) = let ident = Ident "mdata" in (ident,CVar ident)
    declMD = buildDeclaration (head retTyp) mId
    draw xE = CExpr . Just $ xE .=. (((r ./. rMax) .*. (floatE 2)) .-. (floatE 1))
    body = seqCStat [draw uE
                    ,draw vE
                    ,CExpr . Just $ qE .=. ((uE .*. uE) .+. (vE .*. vE))]
    -- Redraw outside the unit disc and also at the origin. With q == 0,
    -- -2 log q / q is not finite and the sample would be infinite or NaN.
    polar = CWhile ((qE .>. (floatE 1)) .||. (qE .==. (floatE 0))) body True
    setR = CExpr . Just $ rE .=. (sqrtE (((CUnary CMinOp (floatE 2)) .*. logE qE) ./. qE))
    finalValue = aE .+. (uE .*. rE .*. bE)
    comment = fmap (CBlockStat . CComment)
      ["normal :: real -> real -> *(mdata real) -> ()"
      ,"Marsaglia Polar Method"
      ,"-----------------------------------------------"]
    decls = (CBlockDecl . typeDeclaration SReal) <$> [uId,vId,qId,rId]
    stmts = CBlockStat <$> [polar,setR, assW, assS,CReturn . Just $ mE]
    assW = CExpr . Just $ mdataWeight mE .=. (floatE 0)
    assS = CExpr . Just $ mdataSample mE .=. finalValue
-- | Reserve the identifier "normal" and emit 'normalFun' as an external
-- definition. Then assign the result of calling it with @aE@, @bE@ to @loc@.
normalCG :: CExpr -> CExpr -> (CExpr -> CodeGen ())
normalCG aE bE loc = do
  nId <- reserveIdent "normal"
  extDeclareTypes (SMeasure SReal)
  extDeclare (CFunDefExt normalFun)
  putExprStat (loc .=. CCall (CVar nId) [aE,bE])
{-
This method is from Marsaglia and Tsang "a simple method for generating gamma variables"
-}
-- | C function @gamma(a, b)@: Marsaglia–Tsang gamma sampler.
--
-- * Parameters: @a@ is the shape (callers pass it already exponentiated);
--   @b@ is the log of the scale.
-- * Setup: @d = a - 1/3@ and @c = 1/sqrt(9d)@.
-- * Loop:
--   1. Draw @x ~ normal(0,1)@ until @v = 1 + c*x > 0@, then cube @v@.
--   2. Draw @u ~ uniform(0,1)@.
--   3. Accept when @u < 1 - 0.0331 x^4@ (squeeze) or
--      @log u < x^2/2 + d(1 - v + log v)@.
-- * Result: an mdata with weight 0 and sample @log(d*v) + b@.
--
-- NOTE(review): the published method assumes shape >= 1. For a < 1/3, d is
-- negative and c is NaN, and no boosting step is emitted here. Confirm.
gammaFun :: CFunDef
gammaFun = CFunDef (CTypeSpec <$> retTyp)
                   (CDeclr Nothing (CDDeclrIdent (Ident "gamma")))
                   [typeDeclaration SProb aId
                   ,typeDeclaration SProb bId]
                   ( CCompound . concat
                   $ [[CBlockDecl declMD],comment,decls,stmts])
  where
    (aId,aE) = let ident = Ident "a" in (ident,CVar ident)
    (bId,bE) = let ident = Ident "b" in (ident,CVar ident)
    (cId,cE) = let ident = Ident "c" in (ident,CVar ident)
    (dId,dE) = let ident = Ident "d" in (ident,CVar ident)
    (xId,xE) = let ident = Ident "x" in (ident,CVar ident)
    (vId,vE) = let ident = Ident "v" in (ident,CVar ident)
    (uId,uE) = let ident = Ident "u" in (ident,CVar ident)
    (mId,mE) = let ident = Ident "mdata" in (ident,CVar ident)
    retTyp = buildType (SMeasure SProb)
    declMD = buildDeclaration (head retTyp) mId
    comment = fmap (CBlockStat . CComment)
      ["gamma :: real -> prob -> *(mdata prob) -> ()"
      ,"Marsaglia and Tsang 'a simple method for generating gamma variables'"
      ,"--------------------------------------------------------------------"]
    decls = fmap CBlockDecl $ (fmap (typeDeclaration SReal) [dId,cId,vId])
                           ++ (fmap (typeDeclaration (SMeasure SReal)) [uId,xId])
    stmts = fmap CBlockStat $ [assD,assC,outerWhile]
    xS = mdataSample xE
    uS = mdataSample uE
    assD = CExpr . Just $ dE .=. (aE .-. ((floatE 1) ./. (floatE 3)))
    assC = CExpr . Just $ cE .=. ((floatE 1) ./. (sqrtE ((floatE 9) .*. dE)))
    outerWhile = CWhile (intE 1) (seqCStat [innerWhile,assV,assU,exit]) False
    innerWhile = CWhile (vE .<=. (floatE 0)) (seqCStat [assX,assVIn]) True
    assX = CExpr . Just $ xE .=. (CCall (CVar . Ident $ "normal") [(floatE 0),(floatE 1)])
    assVIn = CExpr . Just $ vE .=. ((floatE 1) .+. (cE .*. xS))
    assV = CExpr . Just $ vE .=. (vE .*. vE .*. vE)
    assU = CExpr . Just $ uE .=. (CCall (CVar . Ident $ "uniform") [(floatE 0),(floatE 1)])
    -- squeeze test; the constant from the paper is 0.0331
    exitC1 = uS .<. ((floatE 1) .-. ((floatE 0.0331 .*. (xS .*. xS) .*. (xS .*. xS))))
    -- exact test; (1 - v) + log v parenthesised so it holds under any fixity
    exitC2 = (logE uS) .<. (((floatE 0.5) .*. (xS .*. xS))
                            .+. (dE .*. (((floatE 1.0) .-. vE) .+. (logE vE))))
    assW = CExpr . Just $ mdataWeight mE .=. (floatE 0)
    -- the whole log(d*v) + b is the assigned value, independent of the
    -- relative fixities of .=. and .+.
    assS = CExpr . Just $ mdataSample mE .=. ((logE (dE .*. vE)) .+. bE)
    exit = CIf (exitC1 .||. exitC2) (seqCStat [assW,assS,CReturn . Just $ mE]) Nothing
-- | Reserve "uniform", "normal" and "gamma" and emit all three functions
-- (gammaFun calls the other two) as external definitions. Then assign the
-- result of calling gamma with @aE@, @bE@ to @loc@.
gammaCG :: CExpr -> CExpr -> (CExpr -> CodeGen ())
gammaCG aE bE loc = do
  extDeclareTypes (SMeasure SReal)
  [_, _, gId] <- mapM reserveIdent ["uniform","normal","gamma"]
  mapM_ (extDeclare . CFunDefExt) [uniformFun, normalFun, gammaFun]
  putExprStat $ loc .=. (CCall (CVar gId) [aE,bE])
-- | Emit a primitive measure into the mdata location @loc@.
flattenMeasureOp
  :: forall abt typs args a .
     ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs )
  => MeasureOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())
-- uniform: flatten both bounds and call the generated uniform function
flattenMeasureOp Uniform = \(a :* b :* End) loc -> do
  loE <- flattenWithName a
  hiE <- flattenWithName b
  uniformCG loE hiE loc
-- normal: the second argument is a prob (stored as a log), so it is
-- exponentiated before being passed as the standard deviation
flattenMeasureOp Normal = \(a :* b :* End) loc -> do
  meanE <- flattenWithName a
  sdE <- flattenWithName b
  normalCG meanE (expE sdE) loc
-- poisson: Knuth's multiplication method.
--   l = exp(-lambda), where lambda = exp(lamE) since lamE is a logged prob;
--   k = 0; p = 1;
--   do { p *= uniform(0,1); k += 1 } while (p > l);
-- The sample is k - 1 with weight 0.
flattenMeasureOp Poisson = \(lam :* End) loc -> do
  lamE <- flattenWithName lam
  [lId, kId, pId] <- mapM genIdent' ["l","k","p"]
  declare SProb lId
  declare SNat kId
  declare SProb pId
  let [lE, kE, pE] = map CVar [lId, kId, pId]
  putExprStat $ lE .=. (expE (CUnary CMinOp $ expE lamE))
  putExprStat $ kE .=. (intE 0)
  putExprStat $ pE .=. (floatE 1)
  doWhileCG (pE .>. lE) $ do
    uId <- genIdent' "u"
    declare (SMeasure SReal) uId
    let uE = CVar uId
    uniformCG (intE 0) (intE 1) uE
    putExprStat $ pE .*=. (mdataSample uE)
    putExprStat $ kE .+=. (intE 1)
  putExprStat $ mdataWeight loc .=. (intE 0)
  putExprStat $ mdataSample loc .=. (kE .-. (intE 1))
-- gamma: the shape is passed exponentiated; the scale stays a log, which
-- gammaFun adds to the log of its draw
flattenMeasureOp Gamma = \(a :* b :* End) loc -> do
  shapeE <- flattenWithName a
  scaleE <- flattenWithName b
  gammaCG (expE shapeE) scaleE loc
-- beta: rewritten through the Hakaru prelude's beta'' and flattened
flattenMeasureOp Beta = \(a :* b :* End) loc -> flattenABT (HKP.beta'' a b) loc
-- I ran into a bug here where sometimes I received a location by reference and
-- others by value. Since measureOps assign a sample to mdata that they have a
-- reference to, we should enforce that when passing around mdata it is by
-- reference.
--
-- Categorical: draw an index with probability proportional to the array's
-- entries. The entries are probs, stored as logs.
--   1. In one loop, find the maximum entry and the log-sum-exp of all
--      entries. Then subtract the maximum from that total.
--   2. Draw r = log(rand()/RAND_MAX) + total.
--   3. Walk the array again, accumulating (entry - max) with log-sum-exp.
--      On the first iteration where r < the running total, store it - 1 as
--      the sample (weight 0) and break.
-- Parallel code generation is switched off for these loops and restored
-- afterwards if it was on.
--
-- NOTE(review): if rand() returns RAND_MAX, r equals the total. After
-- rounding in the recomputed sum, the while loop may then never see
-- r < total and will read past the end of the array. Consider bounding it
-- by the array size.
flattenMeasureOp Categorical = \(arr :* End) ->
  \loc ->
    do arrE <- flattenWithName arr
       itId <- genIdent' "it"
       declare SNat itId
       let itE = CVar itId
       -- Accumulator for the total probability of the input array
       wSumId <- genIdent' "ws"
       declare SProb wSumId
       let wSumE = CVar wSumId
       assign wSumId (logE (intE 0))
       -- Accumulator for the max value in the input array
       wMaxId <- genIdent' "max"
       declare SProb wMaxId
       let wMaxE = CVar wMaxId
       assign wMaxId (logE (floatE 0))
       let currE = index (arrayData arrE) itE
       isPar <- isParallel
       mkSequential
       -- Calculate the maximum value of the input array
       -- And calculate the total weight
       forCG (itE .=. (intE 0))
             (itE .<. (arraySize arrE))
             (CUnary CPostIncOp itE) $ do
         ifCG (wMaxE .<. currE)
              (putExprStat $ wMaxE .=. currE)
              (return ())
         logSumExpCG (S.fromList [wSumE, currE]) wSumE
       -- normalise the total by the maximum (log domain division)
       putExprStat $ wSumE .=. (wSumE .-. wMaxE)
       -- draw number from uniform(0, weightSum)
       rId <- genIdent' "r"
       declare SReal rId
       let r = castTo [CDouble] randE
           rMax = castTo [CDouble] (CVar . Ident $ "RAND_MAX")
           rE = CVar rId
       assign rId (logE (r ./. rMax) .+. wSumE)
       -- reset the running total and index for the second pass
       assign wSumId (logE (intE 0))
       assign itId (intE 0)
       whileCG (intE 1) $
         do ifCG (rE .<. wSumE)
                 (do putExprStat $ mdataWeight loc .=. (intE 0)
                     putExprStat $ mdataSample loc .=. (itE .-. (intE 1))
                     putStat CBreak)
                 (return ())
            logSumExpCG (S.fromList [wSumE, currE .-. wMaxE]) wSumE
            putExprStat $ CUnary CPostIncOp itE
       when isPar mkParallel
flattenMeasureOp x = error $ "TODO: flattenMeasureOp: " ++ show x
---------------
-- Superpose --
---------------
-- | Generate code sampling from a weighted superposition of measures and
-- store the result in the mdata at @loc@.
--
-- With one component, that component is sampled directly and its weight is
-- added to the sample's weight.  With several components, a uniform draw
-- @r@ in @[0, total weight]@ selects a component by walking the running
-- (log-domain) sum of the weights.  The chosen sample's weight is then
-- increased by the total weight.
flattenSuperpose
  :: (ABT Term abt)
  => NE.NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
  -> (CExpr -> CodeGen ())
-- do we need to normalize?
flattenSuperpose ((w,m) NE.:| []) =
  \loc ->
    do mE <- flattenWithName m
       wE <- flattenWithName w
       putExprStat $ mdataWeight loc .=. ((mdataWeight mE) .+. wE)
       putExprStat $ mdataSample loc .=. (mdataSample mE)
flattenSuperpose pairs =
  \loc ->
    do let pairs' = NE.toList pairs
       wEs <- mapM (\(w,_) -> flattenWithName' w "w") pairs'
       wSumId <- genIdent' "ws"
       declare SProb wSumId
       let wSumE = CVar wSumId
       logSumExpCG (S.fromList wEs) wSumE

       -- draw number from uniform(0, weightSum)
       rId <- genIdent' "r"
       declare SReal rId
       let r    = castTo [CDouble] randE
           rMax = castTo [CDouble] (CVar . Ident $ "RAND_MAX")
           rE   = CVar rId
       assign rId ((r ./. rMax) .*. (expE wSumE))

       -- an iterator for picking a measure
       itId <- genIdent' "it"
       declare SProb itId
       let itE = CVar itId
       assign itId (logE (intE 0))

       -- an output measure to assign to
       outId <- genIdent' "out"
       declare (typeOf . snd . NE.head $ pairs) outId
       let outE = CVar outId
       outLabel <- genIdent' "exit"

       let choices = zip wEs (fmap snd pairs')
           lastIx  = length choices - 1
       forM_ (zip [0 ..] choices) $ \(i,(wE,m)) ->
         -- Once every earlier component has been rejected, r lies in the last
         -- interval.  Testing it again could fail when r equals the total
         -- weight (rand() == RAND_MAX) or through rounding, which would leave
         -- `out` unassigned, so the last component is taken unconditionally.
         if i == lastIx
           then flattenABT m outE
           else do logSumExpCG (S.fromList [itE,wE]) itE
                   ifCG (rE .<. (expE itE))
                        (flattenABT m outE >> putStat (CGoto outLabel))
                        (return ())
       putStat $ CLabel outLabel (CExpr Nothing)
       putExprStat $ mdataWeight loc .=. ((mdataWeight outE) .+. wSumE)
       putExprStat $ mdataSample loc .=. (mdataSample outE)
--------------------------------------------------------------------------------
-- Specialized Arithmetic --
--------------------------------------------------------------------------------
--------------------------------------
-- LogSumExp for NaryOp Add [SProb] --
--------------------------------------
{-
For addition of probabilities we use logSumExp, which computes the sum of the
probabilities safely. Simply exponentiating each (log-domain) prob and adding
the results would lose the protection against underflow that storing probs in
the log domain gives us.
-}
-- | Build one C expression computing @log (sum_i (exp e_i))@ for the
-- log-domain values @es@.
--
-- A nested conditional, built depth first by 'mkCompTree', selects the
-- largest element @m@.  For that branch the result is
-- @m + log1p (sum_{i /= m} (expm1 (e_i - m)) + (n - 1))@, which equals
-- @m + log (sum_i (exp (e_i - m)))@.  The comparison tree doubles at each
-- level, so the generated expression grows exponentially with @n@.
--
-- The tree needs at least two elements, so an empty sequence yields log 0
-- and a singleton is returned unchanged.
logSumExp :: S.Seq CExpr -> CExpr
logSumExp es
  | S.null es        = logE (intE 0)
  | S.length es == 1 = S.index es 0
  | otherwise        = mkCompTree 0 1
  where
    lastIndex = S.length es - 1

    compIndices :: Int -> Int -> CExpr -> CExpr -> CExpr
    compIndices i j = CCond ((S.index es i) .>. (S.index es j))

    -- the tree traversal is a depth first search
    mkCompTree :: Int -> Int -> CExpr
    mkCompTree i j
      | j == lastIndex = compIndices i j (logSumExp' i) (logSumExp' j)
      | otherwise      = compIndices i j
                           (mkCompTree i (succ j))
                           (mkCompTree j (succ j))

    diffExp :: Int -> Int -> CExpr
    diffExp a b = expm1E ((S.index es a) .-. (S.index es b))

    -- given the max index, produce a logSumExp expression
    logSumExp' :: Int -> CExpr
    logSumExp' 0 = S.index es 0
                   .+. (log1pE $ foldr (\x acc -> diffExp x 0 .+. acc)
                                       (diffExp 1 0)
                                       [2..S.length es - 1]
                                 .+. (intE $ fromIntegral lastIndex))
    logSumExp' i = S.index es i
                   .+. (log1pE $ foldr (\x acc -> if i == x
                                                    then acc
                                                    else diffExp x i .+. acc)
                                       (diffExp 0 i)
                                       [1..S.length es - 1]
                                 .+. (intE $ fromIntegral lastIndex))
-- | Assign to @loc@ the log-sum-exp of the log-domain expressions in @seqE@.
--
-- Rather than inlining 'logSumExp' at every use, a C function named
-- @logSumExp<n>@ (n = number of arguments) is emitted via 'funCG' and then
-- called.  This keeps generated code smaller.
logSumExpCG :: S.Seq CExpr -> (CExpr -> CodeGen ())
logSumExpCG seqE loc = do
  let arity     = S.length seqE
      funcIdent = Ident ("logSumExp" ++ show arity)
      -- parameter names come from the fixed cNameStream (not fresh
      -- identifiers), so every call of a given arity describes the very
      -- same function
      paramIds  = map Ident (take arity cNameStream)
      params    = map (typeDeclaration SProb) paramIds
      funBody   = putStat (CReturn (Just (logSumExp (S.fromList (map CVar paramIds)))))
  funCG CDouble funcIdent params funBody
  putExprStat (loc .=. CCall (CVar funcIdent) (F.toList seqE))
-------------------------------------
-- LogSumExp for Summation of Prob --
-------------------------------------
{-
For summation of SProb we need a new logSumExp function that will find the max
of an array and then sum it in a loop
-}
-- | Generate a two-pass log-sum-exp summation over an array.
--
-- Pass 1 evaluates @body@, with its bound variable ranging over
-- @[0, arraySize arrayE)@.  It stores each value into @arrayE@'s data and
-- records the largest value and its index.  Pass 2 sums
-- @exp (x_i - max)@ in real space over every other index.  @loc@ receives
-- @max + log1p sum@.
--
-- The maximum starts at log 0 (-inf) and its index at 0, so an empty array
-- yields log 0 instead of reading uninitialised C variables.
-- NOTE(review): if every value is log 0, pass 2 computes exp(-inf - -inf),
-- which is NaN in IEEE arithmetic.
lseSummateArrayCG
  :: ( ABT Term abt )
  => (abt '[ a ] b)
  -> CExpr
  -> (CExpr -> CodeGen ())
lseSummateArrayCG body arrayE =
  caseBind body $ \v body' ->
    \loc -> do
      (maxVId:maxIId:sumId:[]) <- mapM genIdent' ["maxV","maxI","sum"]
      itId <- createIdent v
      mapM_ (declare SProb) [maxVId,sumId]
      mapM_ (declare SNat) [maxIId,itId]
      let (maxVE:maxIE:sumE:itE:[]) = fmap CVar [maxVId,maxIId,sumId,itId]
      -- initialise the running maximum so the result is defined for an
      -- empty array
      putExprStat $ maxVE .=. (logE (intE 0))
      putExprStat $ maxIE .=. (intE 0)
      -- pass 1: evaluate the body, store it, track the maximum
      forCG (itE .=. intE 0)
            (itE .<. arraySize arrayE)
            (CUnary CPostIncOp itE)
            (do tmpId <- genIdent
                declare SProb tmpId
                let tmpE = CVar tmpId
                flattenABT body' tmpE
                putExprStat $ derefIndex itE .=. tmpE
                putStat $ CIf ((maxVE .<. tmpE) .||. (itE .==. (intE 0)))
                              (seqCStat . fmap (CExpr . Just) $
                                 [ maxVE .=. tmpE
                                 , maxIE .=. itE ])
                              Nothing)
      putExprStat $ sumE .=. (floatE 0) -- the sum is actually in real space
      -- pass 2: accumulate exp(x_i - max) for every index except the maximum
      forCG (itE .=. intE 0)
            (itE .<. arraySize arrayE)
            (CUnary CPostIncOp itE)
            (putStat $ CIf (itE .!=. maxIE)
                           (CExpr . Just $ sumE .+=. (expE ((derefIndex itE) .-. (maxVE))))
                           Nothing)
      putExprStat $ loc .=. (maxVE .+. (log1pE sumE))
  where derefIndex xE = index (arrayData arrayE) xE
---------------------
-- Kahan Summation --
---------------------
-- | Emit a loop computing the Kahan (compensated) sum of @body@, with its
-- bound variable ranging over @[loE, hiE)@, and assign the total to @loc@.
--
-- The running total @t@, the compensation @c@ and the per-iteration
-- temporaries @x@, @y@, @z@ are all declared with 'SProb', but the
-- arithmetic is plain C addition/subtraction on whatever the body produces.
-- NOTE(review): for values stored in the log domain this would not be a sum
-- of probabilities; confirm which callers rely on it for probs.
kahanSummationCG
  :: ( ABT Term abt )
  => (abt '[ a ] b)
  -> CExpr
  -> CExpr
  -> (CExpr -> CodeGen ())
kahanSummationCG body loE hiE =
  caseBind body $ \v body' loc -> do
    [totalId, compId] <- mapM genIdent' ["t","c"]
    idxId <- createIdent v
    declare SNat idxId
    forM_ [totalId, compId] (declare SProb)
    let totalE = CVar totalId
        compE  = CVar compId
        idxE   = CVar idxId
    forM_ [totalE, compE] $ \acc -> putExprStat (acc .=. (floatE 0))
    forCG (idxE .=. loE) (idxE .<. hiE) (CUnary CPostIncOp idxE) $ do
      [xId, yId, zId] <- mapM genIdent' ["x","y","z"]
      forM_ [xId, yId, zId] (declare SProb)
      let [xE, yE, zE] = map CVar [xId, yId, zId]
      flattenABT body' xE                               -- x = body(i)
      putExprStat (yE .=. (xE .-. compE))               -- y = x - c
      putExprStat (zE .=. (totalE .+. yE))              -- z = t + y
      putExprStat (compE .=. ((zE .-. totalE) .-. yE))  -- c = (z - t) - y
      putExprStat (totalE .=. zE)                       -- t = z
    putExprStat (loc .=. totalE)
--------------------------------------------------------------------------------
-- Coercion Helpers --
--------------------------------------------------------------------------------
-- instance PrimCoerce Value where
-- primCoerceTo c l =
-- case (c,l) of
-- (Signed HRing_Int, VNat a) -> VInt $ fromNat a
-- (Signed HRing_Real, VProb a) -> VReal $ LF.fromLogFloat a
-- (Continuous HContinuous_Prob, VNat a) ->
-- VProb $ LF.logFloat (fromIntegral (fromNat a) :: Double)
-- (Continuous HContinuous_Real, VInt a) -> VReal $ fromIntegral a
-- _ -> error "no a defined primitive coercion"
-- primCoerceFrom c l =
-- case (c,l) of
-- (Signed HRing_Int, VInt a) -> VNat $ unsafeNat a
-- (Signed HRing_Real, VReal a) -> VProb $ LF.logFloat a
-- (Continuous HContinuous_Prob, VProb a) ->
-- VNat $ unsafeNat $ floor (LF.fromLogFloat a :: Double)
-- (Continuous HContinuous_Real, VReal a) -> VInt $ floor a
-- _ -> error "no a defined primitive coercion"
-- | Emit code applying a (safe) coercion to @e@.
--
-- The primitive coercions of the chain are applied front to back: the head
-- primitive first, then the rest of the chain on its result.
coerceToCG
  :: forall (a :: Hakaru) (b :: Hakaru)
   . Coercion a b
  -> CExpr
  -> CodeGen CExpr
coerceToCG CNil                                   e = return e
coerceToCG (CCons (Signed HRing_Int)            cs) e = coerceToCG cs =<< nat2int e
coerceToCG (CCons (Signed HRing_Real)           cs) e = coerceToCG cs =<< prob2real e
coerceToCG (CCons (Continuous HContinuous_Prob) cs) e = coerceToCG cs =<< nat2prob e
coerceToCG (CCons (Continuous HContinuous_Real) cs) e = coerceToCG cs =<< int2real e
-- | Emit code undoing a coercion: given @e@ of the coercion's target type,
-- produce a value of its source type.
--
-- Because @CCons p cs@ applies @p@ first and then @cs@, it must be undone in
-- the opposite order: undo the rest of the chain @cs@ first, then undo @p@
-- on that result.  The inverses are partial: int2nat and real2prob abort
-- on negative input, and prob2nat and real2int truncate.
coerceFromCG
  :: forall (a :: Hakaru) (b :: Hakaru)
   . Coercion a b
  -> CExpr
  -> CodeGen CExpr
coerceFromCG CNil e = return e
coerceFromCG (CCons (Signed HRing_Int) cs) e = coerceFromCG cs e >>= int2nat
coerceFromCG (CCons (Signed HRing_Real) cs) e = coerceFromCG cs e >>= real2prob
coerceFromCG (CCons (Continuous HContinuous_Prob) cs) e = coerceFromCG cs e >>= prob2nat
coerceFromCG (CCons (Continuous HContinuous_Real) cs) e = coerceFromCG cs e >>= real2int
-- safe
-- | Safe primitive coercions used by 'coerceToCG'.  Each takes the C
-- expression for the source value and returns an expression for the
-- converted value, emitting a temporary variable where needed.
nat2int,nat2prob,prob2real,int2real
  :: CExpr -> CodeGen CExpr
-- no conversion code is emitted for nat -> int
nat2int x = return x
-- probs are stored in the log domain, so a nat n becomes log n
nat2prob x = do x' <- localVar' SProb "n2p"
                putExprStat $ x' .=. (logE x)
                return x'
-- leave the log domain: exp x, computed as expm1 x + 1.  The temporary
-- holds a real-space value, so it is declared as SReal.
prob2real x = do x' <- localVar' SReal "p2r"
                 putExprStat $ x' .=. ((expm1E x) .+. (floatE 1))
                 return x'
int2real x = return (castTo [CDouble] x)
-- unsafe
{- Because hkc represents both reals and probs as doubles (instead of
rationals), we silently truncate values in prob2nat and real2int.
-}
-- | Unsafe primitive coercions used by 'coerceFromCG'.  Negative inputs to
-- int2nat and real2prob print an error and abort; prob2nat and real2int
-- truncate.
int2nat,prob2nat,real2prob,real2int
  :: CExpr -> CodeGen CExpr
int2nat x =
  do x' <- localVar' SNat "i2n"
     ifCG (x .<. (intE 0))
          (do putExprStat $ printfE [ stringE "error: cannot coerce negative int to nat\n" ]
              putExprStat $ mkCallE "abort" [] )
          (putExprStat $ x' .=. (castTo [CUnsigned, CInt] x))
     return x'
-- A prob is stored as its logarithm, so it must be exponentiated before the
-- truncating cast.  Casting the stored value directly would turn prob 5
-- (stored as ~1.609) into 1.
prob2nat x = return (castTo [CUnsigned, CInt] (expE x))
real2prob x =
  do x' <- localVar' SProb "r2p"
     ifCG (x .<. (intE 0))
          (do putExprStat $ printfE [ stringE "error: cannot coerce negative real to prob\n" ]
              putExprStat $ mkCallE "abort" [] )
          (putExprStat $ x' .=. (logE x))
     return x'
-- C cast: truncates toward zero
real2int x = return (castTo [CInt] x)