{-# LANGUAGE CPP,
             BangPatterns,
             DataKinds,
             FlexibleContexts,
             GADTs,
             KindSignatures,
             ScopedTypeVariables,
             RankNTypes,
             TypeOperators #-}

----------------------------------------------------------------
--                                                    2016.06.23
-- |
-- Module      :  Language.Hakaru.CodeGen.Flatten
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  zsulliva@indiana.edu
-- Stability   :  experimental
-- Portability :  GHC-only
--
--   Flatten takes Hakaru ABTs and C vars and returns a CStatement
-- assigning the var to the flattened ABT.
--
----------------------------------------------------------------

module Language.Hakaru.CodeGen.Flatten
  ( flattenABT
  , flattenVar
  , flattenTerm
  , flattenWithName
  , flattenWithName'
  , localVar
  , localVar'
  , opComment )
  where

import Language.Hakaru.CodeGen.CodeGenMonad
import Language.Hakaru.CodeGen.AST
import Language.Hakaru.CodeGen.Libs
import Language.Hakaru.CodeGen.Types

import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.Datum hiding (Ident)
import Language.Hakaru.Syntax.Reducer
import qualified Language.Hakaru.Syntax.Prelude as HKP
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.Sing

import           Control.Monad.State.Strict
import           Control.Monad (replicateM)
import           Control.Applicative (pure)
import           Data.Number.Natural
import           Data.Monoid hiding (Product,Sum)
import           Data.Ratio
import qualified Data.List.NonEmpty as NE
import qualified Data.Sequence      as S
import qualified Data.Foldable      as F
import qualified Data.Traversable   as T

#if __GLASGOW_HASKELL__ < 710
import           Data.Functor
#endif

-- | Build a C comment that frames an operation name with a run of dashes on
-- both sides, e.g. @---- Begin Summate ----@. The dash run gets shorter as the
-- name gets longer; once the computed width is not positive, 'replicate'
-- yields no dashes at all.
opComment :: String -> CStat
opComment opStr = CComment (unwords [dashes, opStr, dashes])
  where
    width  = (50 - length opStr) `div` 2 - 8
    dashes = replicate width '-'

--------------------------------------------------------------------------------
--                                 Top
--                                Level                                       --
--------------------------------------------------------------------------------

{-

Flattening an ABT produces a continuation that takes a CExpr representing
a location where the value of the ABT should be stored. The continuation runs
in the CodeGen monad and returns (); it emits the C statements that compute the
ABT and assign the result to that location.

(An earlier version of this comment described the continuation as returning
`CodeGen Bool`, signalling whether a Reject occurs inside the ABT. The code
below does not do that: every continuation has type `CExpr -> CodeGen ()`.)

-}

-- | Declare a fresh local C variable of the given Hakaru type and return it.
localVar :: Sing (a :: Hakaru) -> CodeGen CExpr
localVar typ = localVar' typ ""

-- | Like 'localVar', but the fresh identifier is generated from a name hint.
localVar' :: Sing (a :: Hakaru) -> String -> CodeGen CExpr
localVar' typ s =
  do eId <- genIdent' s
     declare typ eId
     return (CVar eId)

-- | Flatten a closed term into a freshly declared variable (named after the
-- hint) and return that variable as an expression.
flattenWithName'
  :: ABT Term abt
  => abt '[] a
  -> String
  -> CodeGen CExpr
flattenWithName' abt hint =
  do ident <- genIdent' hint
     declare (typeOf abt) ident
     let cvar = CVar ident
     flattenABT abt cvar
     return cvar

-- | 'flattenWithName'' with an empty name hint.
flattenWithName
  :: ABT Term abt
  => abt '[] a
  -> CodeGen CExpr
flattenWithName abt = flattenWithName' abt ""

-- | Flatten a closed ABT: dispatch on whether it is a variable or a term.
flattenABT
  :: ABT Term abt
  => abt '[] a
  -> (CExpr -> CodeGen ())
flattenABT abt = caseVarSyn abt flattenVar flattenTerm

-- | Assign the C identifier bound to a Hakaru variable to the location.
-- Note that variables find their values in the state of the CodeGen monad.
flattenVar
  :: Variable (a :: Hakaru)
  -> (CExpr -> CodeGen ())
flattenVar v = \loc ->
  do v' <- CVar <$> lookupIdent v
     putExprStat $ loc .=. v'

-- | Flatten a single term constructor by dispatching to the specialised
-- flatteners below.
flattenTerm
  :: ABT Term abt
  => Term abt a
  -> (CExpr -> CodeGen ())
-- SCon can contain mochastic terms
flattenTerm (x :$ ys)         = flattenSCon x ys
flattenTerm (NaryOp_ t s)     = flattenNAryOp t s
flattenTerm (Literal_ x)      = flattenLit x
flattenTerm (Empty_ _)        = error "TODO: flattenTerm{Empty}"
flattenTerm (Datum_ d)        = flattenDatum d
flattenTerm (Case_ c bs)      = flattenCase c bs
flattenTerm (Bucket b e rs)   = flattenBucket b e rs
flattenTerm (Array_ s e)      = flattenArray s e
flattenTerm (ArrayLiteral_ s) = flattenArrayLiteral s

---------------------
-- Mochastic Terms --
---------------------
flattenTerm (Superpose_ wes)  = flattenSuperpose wes
-- NOTE(review): Dirac below writes a weight of 0 as well (weights look like
-- they are kept in the log domain), so a rejected measure gets the same
-- weight as a point mass. Confirm that this is the intended encoding.
flattenTerm (Reject_ _)       = \loc -> putExprStat (mdataWeight loc .=. (intE 0))


--------------------------------------------------------------------------------
--                                  SCon                                      --
--------------------------------------------------------------------------------

-- | Flatten a syntactic constructor applied to its arguments.
flattenSCon
  :: ( ABT Term abt )
  => SCon args a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())

flattenSCon Let_ =
  \(expr :* body :* End) ->
    \loc ->
      caseBind body $ \v@(Variable _ _ typ) body' ->
        do ident <- createIdent v
           -- function-typed bindings get no C declaration here
           case typ of
             (SFun _ _) -> return ()
             _          -> declare typ ident
           flattenABT expr (CVar ident)
           flattenABT body' loc

-- Lambdas produce functions and then return a function label expression
flattenSCon Lam_ =
  \(body :* End) ->
    \loc ->
      do -- externally declare closure and function
         closureTypeSpec <- coalesceLambda body extDeclClosure
         -- declare local closure var
         closureId <- genIdent' "closure"
         declare' (buildDeclaration closureTypeSpec closureId)
         -- capture environment in closure
         putExprStat $ loc .=. (CVar closureId)
  where
    -- Collect the variables of directly nested lambdas so that they become
    -- the parameters of a single C function.
    coalesceLambda
      :: ( ABT Term abt )
      => abt '[x] a
      -> (forall (ys :: [Hakaru]) b. List1 Variable ys -> abt '[] b -> r)
      -> r
    coalesceLambda abt k =
      caseBind abt $ \v abt' ->
        caseVarSyn abt' (const (k (Cons1 v Nil1) abt')) $ \term ->
          case term of
            (Lam_ :$ body :* End) ->
              coalesceLambda body $ \vars abt'' -> k (Cons1 v vars) abt''
            _ -> k (Cons1 v Nil1) abt'

    -- given a parameter, create identifiers corresponding to Hakaru vars,
    -- and return a CTypeSpec for the param
    -- Will this fail if the parameter is a SFun?
    mkVarIdandSpec :: Variable (a :: Hakaru) -> CodeGen (Ident,[CTypeSpec])
    mkVarIdandSpec v@(Variable _ _ typ) =
      do extDeclareTypes typ
         vId <- createIdent v
         return (vId,buildType typ)

    -- Declare the closure struct and the C function for the lambda body;
    -- returns the type spec of the closure struct.
    extDeclClosure
      :: ( ABT Term abt )
      => List1 Variable (ys :: [Hakaru])
      -> abt '[] b
      -> CodeGen CTypeSpec
    extDeclClosure vars body' =
      do funcId <- genIdent' "fn"
         idAndSpecs <- sequence $ foldMap11 (\v -> [mkVarIdandSpec v]) vars
         let fVars = freeVars body'
             typ   = typeOf body'
         sId@(Ident sname) <- extDeclClosureStruct typ (fmap snd idAndSpecs) fVars
         funCG (head . buildType $ typ)
               funcId
               ([buildDeclaration (callStruct sname) (Ident "env")]
                ++ (fmap (\(vId,specs) -> buildDeclaration' specs vId) idAndSpecs))
               ((putStat . CReturn . Just) =<< flattenWithName body')
         return (callStruct sname)

    -- Declare a struct holding a function pointer named "fn" together with
    -- one field per free variable of the lambda body.
    extDeclClosureStruct
      :: forall (a :: Hakaru) (ys :: [Hakaru])
      .  Sing a
      -> [[CTypeSpec]]
      -> VarSet (KindOf a)
      -> CodeGen Ident
    extDeclClosureStruct retTyp paramTypeSpecs freeVars =
      do sId@(Ident sname) <- genIdent' "clos"
         freeVarDecls <-
           mapM (\(SomeVariable v@(Variable _ _ typ)) ->
                   do extDeclareTypes typ
                      vId <- createIdent v
                      return (typeDeclaration typ vId))
                (fromVarSet freeVars)
         let funPtrDecl =
               CDecl (fmap CTypeSpec $ buildType retTyp)
                     [( CDeclr Nothing
                               (CDDeclrFun
                                  (CDDeclrRec
                                     (CDeclr (Just $ CPtrDeclr [])
                                             (CDDeclrIdent . Ident $ "fn")))
                                  ([callStruct sname]++(concat paramTypeSpecs)))
                      , Nothing)]
         extDeclare $ CDeclExt $
           CDecl [ CTypeSpec $ buildStruct (Just sId) ([funPtrDecl]++freeVarDecls) ]
                 []
         return sId

flattenSCon App_ =
  \(fun :* arg :* End) ->
    \loc ->
      do closE  <- flattenWithName' fun "clos"
         -- The parameter is the flattened *argument*; previously `fun` was
         -- flattened a second time here and `arg` was never used.
         paramE <- flattenWithName' arg "param"
         -- NOTE(review): the function generated for Lam_ takes the closure
         -- struct ("env") as its first parameter, but only the argument is
         -- passed here. Confirm the intended calling convention.
         putExprStat $ loc .=. (CCall (CMember closE (Ident "fn") True) [paramE])

flattenSCon (PrimOp_ op) = flattenPrimOp op

flattenSCon (ArrayOp_ op) = flattenArrayOp op

flattenSCon (Summate _ sr) =
  \(lo :* hi :* body :* End) ->
    \loc ->
      let semiTyp = sing_HSemiring sr
      in do loE <- flattenWithName' lo "lo"
            hiE <- flattenWithName' hi "hi"
            putStat $ opComment "Begin Summate"
            case semiTyp of
              -- special prob branch
              SProb ->
                do summateArrId <- genIdent' "summate_arr"
                   declare (SArray SProb) summateArrId
                   let summateArrE = CVar summateArrId
                   putExprStat $ arraySize summateArrE .=. (hiE .-. loE)
                   putExprStat $ arrayData summateArrE
                     .=. (castToPtrOf [CDouble]
                           (mallocE ((arraySize summateArrE)
                                     .*. (CSizeOfType (CTypeName [CDouble] False)))))
                   lseSummateArrayCG body summateArrE loc
                   putExprStat $ freeE (arrayData summateArrE)

              _ ->
                caseBind body $ \v body' ->
                  do iterI <- createIdent v
                     declare SNat iterI
                     accI <- genIdent' "acc"
                     declare semiTyp accI
                     -- additive identity of the accumulator
                     assign accI (case semiTyp of
                                    SReal -> floatE 0
                                    _     -> intE 0)
                     let accVar  = CVar accI
                         iterVar = CVar iterI
                     reductionCG (Left CAddOp)
                                 accI
                                 (iterVar .=. loE)
                                 (iterVar .<. hiE)
                                 (CUnary CPostIncOp iterVar) $
                       (putExprStat . (accVar .+=.) =<< flattenWithName body')
                     putExprStat $ loc .=. accVar
            putStat $ opComment "End Summate"

flattenSCon (Product _ sr) =
  \(lo :* hi :* body :* End) ->
    \loc ->
      let semiTyp = sing_HSemiring sr
      in do loE <- flattenWithName' lo "lo"
            hiE <- flattenWithName' hi "hi"
            putStat $ opComment "Begin Product"
            case semiTyp of
              -- special prob branch
              SProb -> kahanSummationCG body loE hiE loc

              _ ->
                caseBind body $ \v body' ->
                  do iterI <- createIdent v
                     declare SNat iterI
                     accI <- genIdent' "acc"
                     declare semiTyp accI
                     -- multiplicative identity of the accumulator
                     assign accI (case semiTyp of
                                    SReal -> floatE 1
                                    _     -> intE 1)
                     let accVar  = CVar accI
                         iterVar = CVar iterI
                     reductionCG (Left CMulOp)
                                 accI
                                 (iterVar .=. loE)
                                 (iterVar .<. hiE)
                                 (CUnary CPostIncOp iterVar) $
                       (putExprStat . (accVar .*=.) =<< flattenWithName body')
                     putExprStat (loc .=. accVar)
            putStat $ opComment "End Product"


--------------------
-- SCon Coercions --
--------------------

{-
  Helpers found by searching "Coercion Helpers"
-}

flattenSCon (CoerceTo_ ctyp) =
  \(e :* End) ->
    \loc ->
      do eE <- flattenWithName e
         cE <- coerceToCG ctyp eE
         putExprStat $ loc .=. cE

flattenSCon (UnsafeFrom_ ctyp) =
  \(e :* End) ->
    \loc ->
      do eE <- flattenWithName e
         cE <- coerceFromCG ctyp eE
         putExprStat $ loc .=. cE

-----------------------------------
-- SCons in the Stochastic Monad --
-----------------------------------

flattenSCon (MeasureOp_ op) = flattenMeasureOp op

flattenSCon Dirac =
  \(e :* End) ->
    \loc ->
      do sE <- flattenWithName' e "samp"
         putExprStat $ mdataWeight loc .=. (floatE 0)
         putExprStat $ mdataSample loc .=. sE

flattenSCon MBind =
  \(ma :* b :* End) ->
    \loc ->
      caseBind b $ \v@(Variable _ _ typ) mb ->
        do -- first
           mE <- flattenWithName' ma "m"
           -- assign that sample to var
           vId <- createIdent v
           declare typ vId
           assign vId (mdataSample mE)
           flattenABT mb loc
           putExprStat $ mdataWeight loc .+=. (mdataWeight mE)

-- for now plates make use of a global sample
flattenSCon Plate =
  \(size :* b :* End) ->
    \loc ->
      caseBind b $ \v@(Variable _ _ typ) body ->
        do sizeE <- flattenWithName' size "s"
           isMM <- managedMem <$> get
           when (not isMM) (error "plate will leak memory without the '-g' flag and boehm-gc")
           putExprStat $ (arraySize . mdataSample $ loc) .=. sizeE
           putMallocStat (arrayData . mdataSample $ loc) sizeE (typeOf body)

           weightId <- genIdent' "w"
           declare SProb weightId
           let weightE = CVar weightId
           assign weightId (floatE 0)

           itId <- createIdent v
           declare SNat itId
           let itE     = CVar itId
               currInd = index (arrayData . mdataSample $ loc) itE

           sampId <- genIdent' "samp"
           declare (typeOf $ body) sampId
           let sampE = CVar sampId

           reductionCG (Left CAddOp)
                       weightId
                       (itE .=. (intE 0))
                       (itE .<. sizeE)
                       (CUnary CPostIncOp itE)
                       (do flattenABT body sampE
                           putExprStat (currInd .=. (mdataSample sampE))
                           putExprStat (weightE .+=. (mdataWeight sampE)))
           putExprStat $ mdataWeight loc .=. weightE

-----------------------------------
-- SCon's that aren't implemented --
-----------------------------------

flattenSCon x = \_ -> \_ -> error $ "TODO: flattenSCon: " ++ show x


--------------------------------------------------------------------------------
--                                 NaryOps                                    --
--------------------------------------------------------------------------------

-- | Flatten an n-ary operator by flattening every operand and folding the
-- binary C operator over the results. Boolean operators work on the "index"
-- member; sums of probabilities use log-sum-exp.
flattenNAryOp
  :: ABT Term abt
  => NaryOp a
  -> S.Seq (abt '[] a)
  -> (CExpr -> CodeGen ())
flattenNAryOp op args =
  \loc ->
    do es <- T.mapM flattenWithName args
       case op of
         And -> boolNaryOp op es loc
         Or  -> boolNaryOp op es loc
         Xor -> boolNaryOp op es loc
         Iff -> boolNaryOp op es loc

         (Sum HSemiring_Prob) -> logSumExpCG es loc

         _ -> let opE = F.foldr (binaryOp op)
                                (S.index es 0)
                                (S.drop 1 es)
              in  putExprStat (loc .=. opE)
  where
    boolNaryOp op' es' loc' =
      let indexOf x = CMember x (Ident "index") True
          es''      = fmap indexOf es'
          expr      = F.foldr (binaryOp op')
                              (S.index es'' 0)
                              (S.drop 1 es'')
      in  putExprStat ((indexOf loc') .=. expr)


--------------------------------------------------------------------------------
--                                  Literals                                  --
--------------------------------------------------------------------------------

-- | Assign a literal to the location. Probabilities are emitted as
-- @log1p(x - 1)@.
flattenLit
  :: Literal a
  -> (CExpr -> CodeGen ())
flattenLit lit =
  \loc ->
    case lit of
      (LNat x)  -> putExprStat $ loc .=. (intE $ fromIntegral x)
      (LInt x)  -> putExprStat $ loc .=. (intE x)
      (LReal x) -> putExprStat $ loc .=. (floatE $ fromRational x)
      (LProb x) -> let rat = fromNonNegativeRational x
                       x'  = (fromIntegral $ numerator rat)
                           / (fromIntegral $ denominator rat)
                       xE  = log1pE (floatE x' .-. intE 1)
                   in  putExprStat (loc .=. xE)


--------------------------------------------------------------------------------
--                             Array and ArrayOps                             --
--------------------------------------------------------------------------------

-- | Flatten an array comprehension: store the size, allocate the data (with
-- the GC allocator when managed memory is enabled) and fill it in a for loop.
flattenArray
  :: (ABT Term abt)
  => (abt '[] 'HNat)
  -> (abt '[ 'HNat ] a)
  -> (CExpr -> CodeGen ())
flattenArray arity body =
  \loc ->
    caseBind body $ \v body' ->
      do let arityE = arraySize loc
             dataE  = arrayData loc
             typ    = typeOf body'
         flattenABT arity arityE

         isManagedMem <- managedMem <$> get
         let malloc' = if isManagedMem then gcMalloc else mallocE
         putExprStat $ dataE
           .=. (CCast (CTypeName (buildType typ) True)
                      (malloc' (arityE .*. (CSizeOfType (CTypeName (buildType typ) False)))))

         itId <- createIdent v
         declare SNat itId
         let itE     = CVar itId
             currInd = index dataE itE

         putStat $ opComment "Begin Array"
         forCG (itE .=. (intE 0))
               (itE .<. arityE)
               (CUnary CPostIncOp itE)
               (flattenABT body' currInd)
         putStat $ opComment "End Array"

-- | Flatten an array literal into a freshly declared array whose elements are
-- assigned one by one, in index order.
--
-- Raises an 'error' on an empty literal: the element type is read off the
-- first element, so it cannot be determined for @[]@.
flattenArrayLiteral
  :: ( ABT Term abt )
  => [abt '[] a]
  -> (CExpr -> CodeGen ())
flattenArrayLiteral es =
  \loc ->
    do arrId <- genIdent
       isManagedMem <- managedMem <$> get
       let arity   = fromIntegral . length $ es
           typ     = case es of
                       (e:_) -> typeOf e
                       []    -> error "flattenArrayLiteral: cannot determine the element type of an empty array literal"
           arrE    = CVar arrId
           malloc' = if isManagedMem then gcMalloc else mallocE
       declare (SArray typ) arrId
       putExprStat $ (arrayData arrE)
         .=. (CCast (CTypeName (buildType typ) True)
                    (malloc' ((intE arity) .*. (CSizeOfType (CTypeName (buildType typ) False)))))
       putExprStat $ arraySize arrE .=. (intE arity)
       -- Emit the element assignments first-to-last. The previous
       -- accumulate-by-cons `foldl` emitted them last-to-first.
       mapM_ (\(i,e) -> assignIndex e i arrE) (zip [0..] es)
       putExprStat $ loc .=. arrE
  where
    -- Flatten one element and store it at offset `ix` of the array data.
    assignIndex
      :: ( ABT Term abt )
      => abt '[] a
      -> Integer
      -> (CExpr -> CodeGen ())
    assignIndex e ix arr =
      do eE <- flattenWithName e
         putExprStat $ indirect ((arrayData arr) .+. (intE ix)) .=. eE


--------------
-- ArrayOps --
--------------

-- | Flatten array indexing and size queries; reductions are not implemented.
flattenArrayOp
  :: ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs )
  => ArrayOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())

flattenArrayOp (Index _) =
  \(arr :* ind :* End) ->
    \loc ->
      do indE <- flattenWithName ind
         arrE <- flattenWithName arr
         let valE = index (CMember arrE (Ident "data") True) indE
         putExprStat (loc .=. valE)

flattenArrayOp (Size _) =
  \(arr :* End) ->
    \loc ->
      do arrE <- flattenWithName arr
         putExprStat (loc .=. (CMember arrE (Ident "size") True))

flattenArrayOp (Reduce _) = error "TODO: flattenArrayOp"
  -- \(fun :* base :* arr :* End) ->
  --   do funE  <- flattenABT fun
  --      baseE <- flattenABT base
  --      arrE  <- flattenABT arr
  --      accI  <- genIdent' "acc"
  --      iterI <- genIdent' "iter"
  --      let sizeE = CMember arrE (Ident "size") True
  --          iterE = CVar iterI
  --          accE  = CVar accI
  --          cond  = iterE .<. sizeE
  --          inc   = CUnary CPostIncOp iterE
  --      declare (typeOf base) accI
  --      declare SInt iterI
  --      assign accI baseE
  --      forCG (iterE .=. (intE 0)) cond inc $
  --        assign accI $ CCall funE [accE]
  --      return accE


--------------------------------------------------------------------------------
--                            Bucket and Reducers                             --
--------------------------------------------------------------------------------
{-
  Declarations for buckets - since we will have some product of monoids we need
  unique names for each one.

  Ex: bucket i from 0 to 100:
        fanout(add(\_ -> 1),add(\_ -> 2))

  we will need to keep track of two ints:

      int x;
      int y;

      for (i = 0; i < 100; i++)
      {
        x += 1;
        y += 2;
      }

  Summary objects are nested pairs, e.g. pair(nat,pair(real,array(nat)))
-}

-- | Flatten a bucket: initialise the summary stored at the location, then
-- accumulate into it while iterating from @lo@ (inclusive) to @hi@
-- (exclusive).
flattenBucket
  :: (ABT Term abt)
  => abt '[] 'HNat
  -> abt '[] 'HNat
  -> Reducer abt '[] a
  -> (CExpr -> CodeGen ())
flattenBucket lo hi red = \loc ->
  do putStat $ opComment "Begin Bucket"
     loE <- flattenWithName' lo "lo"
     hiE <- flattenWithName' hi "hi"
     itId <- genIdent' "it"
     declare SNat itId
     let itE = CVar itId
     initRed red loc
     forCG (itE .=. loE)
           (itE .<. hiE)
           (CUnary CPostIncOp itE)
           (accumRed red itE loc)
     putStat $ opComment "End Bucket"
  where
    -- Emit the code that sets the summary at `loc` to the reducer's identity.
    initRed
      :: (ABT Term abt)
      => Reducer abt xs a
      -> (CExpr -> CodeGen ())
    initRed mr = \loc ->
      case mr of
        (Red_Fanout mr1 mr2) ->
          initRed mr1 (datumFst loc) >> initRed mr2 (datumSnd loc)
        (Red_Split _ mr1 mr2) ->
          initRed mr1 (datumFst loc) >> initRed mr2 (datumSnd loc)
        (Red_Index s _ body) ->
          let (vs,s') = caseBinds s
              btyp    = typeOfReducer body
          in do sequence_ . foldMap11
                  (\v' -> case v' of
                            (Variable _ _ typ') ->
                              [declare typ' =<< createIdent v']) $ vs
                sE <- flattenWithName s'
                putExprStat $ arraySize loc .=. sE
                putMallocStat (arrayData loc) sE btyp
                itId <- genIdent
                declare SNat itId
                let itE = CVar itId
                forCG (itE .=. (intE 0))
                      (itE .<. sE)
                      (CUnary CPostIncOp itE)
                      (initRed body (index (arrayData loc) itE))
        Red_Nop -> return ()
        (Red_Add sr _) ->
          putExprStat $ loc .=. (addMonoidIdentity . sing_HSemiring $ sr)

    -- Emit the code that folds the element at iteration `itE` into `loc`.
    accumRed
      :: (ABT Term abt)
      => Reducer abt xs a
      -> CExpr
      -> (CExpr -> CodeGen ())
    accumRed mr itE = \loc ->
      case mr of
        (Red_Index _ a body) ->
          caseBind a $ \v@(Variable _ _ typ) a' ->
            let (vs,a'') = caseBinds a'
            in do vId <- createIdent v
                  declare typ vId
                  putExprStat $ (CVar vId) .=. itE
                  sequence_ . foldMap11
                    (\v' -> case v' of
                              (Variable _ _ typ') ->
                                [declare typ' =<< createIdent v']) $ vs
                  aE <- flattenWithName' a'' "index"
                  accumRed body itE (index (arrayData loc) aE)
        (Red_Fanout mr1 mr2) ->
          accumRed mr1 itE (datumFst loc) >> accumRed mr2 itE (datumSnd loc)
        (Red_Split b mr1 mr2) ->
          caseBind b $ \v@(Variable _ _ typ) b' ->
            let (vs,b'') = caseBinds b'
            in do vId <- createIdent v
                  declare typ vId
                  putExprStat $ (CVar vId) .=. itE
                  sequence_ . foldMap11
                    (\v' -> case v' of
                              (Variable _ _ typ') ->
                                [declare typ' =<< createIdent v']) $ vs
                  bE <- flattenWithName' b'' "cond"
                  ifCG (bE ... "index" .==. (intE 0))
                       (accumRed mr1 itE (datumFst loc))
                       (accumRed mr2 itE (datumSnd loc))
        Red_Nop -> return ()
        (Red_Add sr e) ->
          caseBind e $ \v@(Variable _ _ typ) e' ->
            let (vs,e'') = caseBinds e'
            in do vId <- createIdent v
                  declare typ vId
                  putExprStat $ (CVar vId) .=. itE
                  sequence_ . foldMap11
                    (\v' -> case v' of
                              (Variable _ _ typ') ->
                                [declare typ' =<< createIdent v']) $ vs
                  eE <- flattenWithName e''
                  case sing_HSemiring sr of
                    SProb -> logSumExpCG (S.fromList [loc,eE]) loc
                    _     -> putExprStat $ loc .+=. eE

-- | The C expression used to initialise an additive accumulator of the given
-- type. Probabilities start at @log(0)@; arrays defer to their element type.
addMonoidIdentity :: Sing (a :: Hakaru) -> CExpr
addMonoidIdentity s =
  case s of
    SNat     -> intE 0
    SInt     -> intE 0
    SReal    -> floatE 0
    SProb    -> logE (floatE 0)
    SArray x -> addMonoidIdentity x
    x        -> error $ "addMonoidIdentity{" ++ show x ++ "}"


--------------------------------------------------------------------------------
--                               Datum and Case                               --
--------------------------------------------------------------------------------
{-

Datum are sums of products of types. This maps to a C structure. flattenDatum
will produce a literal of some datum type. This will also produce a global
struct representing that datum which will be needed for the C compiler.

-}

-- | Declare the C types for a datum and assign its constructor tag and
-- fields to the location.
flattenDatum
  :: (ABT Term abt)
  => Datum (abt '[]) (HData' a)
  -> (CExpr -> CodeGen ())
flattenDatum (Datum _ typ code) =
  \loc ->
    do extDeclareTypes typ
       assignDatum code loc

-- | Write the constructor index into the "index" member, then assign every
-- field of the selected constructor under the "sum" member.
assignDatum
  :: (ABT Term abt)
  => DatumCode xss (abt '[]) c
  -> CExpr
  -> CodeGen ()
assignDatum code ident =
  let tagIx     = getIndex code
      indexExpr = CMember ident (Ident "index") True
  in  do putExprStat (indexExpr .=. (intE tagIx))
         sequence_ $ assignSum code ident
  where
    -- Position of the chosen constructor within the sum.
    getIndex :: DatumCode xss b c -> Integer
    getIndex (Inl _)    = 0
    getIndex (Inr rest) = succ (getIndex rest)

assignSum
  :: (ABT Term abt)
  => DatumCode xs (abt '[]) c
  -> CExpr
  -> [CodeGen ()]
assignSum code ident = fst $ runState (assignSum' code ident) cNameStream

-- The State holds the remaining member names; one name is skipped per `Inr`.
-- The names are inspected with explicit `case`s rather than failable
-- do-patterns, because `State` has no `MonadFail` instance.
assignSum'
  :: (ABT Term abt)
  => DatumCode xs (abt '[]) c
  -> CExpr
  -> State [String] [CodeGen ()]
assignSum' (Inr rest) topIdent =
  do modify (drop 1)
     assignSum' rest topIdent
assignSum' (Inl prod) topIdent =
  do names <- get
     case names of
       (name:_) -> return $ assignProd prod topIdent (CVar . Ident $ name)
       []       -> error "assignDatum: cNameStream exhausted"

assignProd
  :: (ABT Term abt)
  => DatumStruct xs (abt '[]) c
  -> CExpr
  -> CExpr
  -> [CodeGen ()]
assignProd dstruct topIdent sumIdent =
  fst $ runState (assignProd' dstruct topIdent sumIdent) cNameStream

assignProd'
  :: (ABT Term abt)
  => DatumStruct xs (abt '[]) c
  -> CExpr
  -> CExpr
  -> State [String] [CodeGen ()]
assignProd' Done _ _ = return []
assignProd' (Et (Konst d) rest) topIdent (CVar sumIdent) =
  do names <- get
     case names of
       [] -> error "assignDatum: cNameStream exhausted"
       (name:names') ->
         do put names'
            let varName =
                  CMember (CMember (CMember topIdent (Ident "sum") True)
                                   sumIdent
                                   True)
                          (Ident name)
                          True
            rest' <- assignProd' rest topIdent (CVar sumIdent)
            return $ [flattenABT d varName] ++ rest'
assignProd' _ _ _ = error $ "TODO: assignProd Ident"


----------
-- Case --
----------

-- | Flatten a case expression. Only two shapes are supported: a match on a
-- boolean (true branch then false branch) and a single branch that binds both
-- components of a pair.
flattenCase
  :: forall abt a b
  .  (ABT Term abt)
  => abt '[] a
  -> [Branch a abt b]
  -> (CExpr -> CodeGen ())
flattenCase c [ Branch (PDatum _ (PInl PDone))        trueB
              , Branch (PDatum _ (PInr (PInl PDone))) falseB ] =
  \loc ->
    do cE <- flattenWithName c
       ifCG ((cE ... "index") .==. (intE 0))
            (flattenABT trueB loc)
            (flattenABT falseB loc)

flattenCase e [ Branch (PDatum _ (PInl (PEt (PKonst PVar)
                                            (PEt (PKonst PVar) PDone)))) b ] =
  \loc ->
    do eE <- flattenWithName e
       caseBind b $ \vfst@(Variable _ _ fstTyp) b' ->
         caseBind b' $ \vsnd@(Variable _ _ sndTyp) b'' ->
           do fstId <- createIdent vfst
              sndId <- createIdent vsnd
              declare fstTyp fstId
              declare sndTyp sndId
              putExprStat $ (CVar fstId) .=. (datumFst eE)
              putExprStat $ (CVar sndId) .=. (datumSnd eE)
              flattenABT b'' loc

flattenCase e _ = error $ "TODO: flattenCase{" ++ show (typeOf e) ++ "}"


--------------------------------------------------------------------------------
--                                  PrimOp                                    --
--------------------------------------------------------------------------------

-- | Flatten a primitive operator applied to its arguments.
flattenPrimOp
  :: ( ABT Term abt
     , typs ~ UnLCs args
     , args ~ LCs typs)
  => PrimOp typs a
  -> SArgs abt args
  -> (CExpr -> CodeGen ())

flattenPrimOp Pi =
  \End ->
    \loc ->
      let piE = log1pE ((CVar . Ident $ "M_PI") .-. (intE 1))
      in  putExprStat (loc .=. piE)

flattenPrimOp Not =
  \(a :* End) ->
    \loc ->
      do aE <- flattenWithName a
         bId <- genIdent
         declare (typeOf a) bId
         let datumIndex e = CMember e (Ident "index") True
             bE           = CVar bId
         putExprStat $ datumIndex bE
           .=. (CCond (datumIndex aE .==. (intE 1))
                      (intE 0)
                      (intE 1))
         putExprStat $ loc .=. bE

flattenPrimOp RealPow =
  \(base :* power :* End) ->
    \loc ->
      do baseE  <- flattenWithName base
         powerE <- flattenWithName power
         let realPow = CCall (CVar . Ident $ "pow")
                             [ expm1E baseE .+. (intE 1), powerE]
         putExprStat $ loc .=. (log1pE (realPow .-. (intE 1)))

flattenPrimOp (NatPow baseTyp) =
  \(base :* power :* End) ->
    \loc ->
      let sBase = sing_HSemiring baseTyp
      in do baseId <- genIdent
            declare sBase baseId
            let baseE = CVar baseId
            flattenABT base baseE
            powerE <- flattenWithName power
            let powerOf x y = CCall (CVar . Ident $ "pow") [x,y]
                value = case sBase of
                          SProb -> log1pE $ (powerOf (expm1E baseE .+. (intE 1)) powerE)
                                            .-. (intE 1)
                          _     -> powerOf baseE powerE
            putExprStat $ loc .=. value

flattenPrimOp (NatRoot baseTyp) =
  \(base :* root :* End) ->
    \loc ->
      let sBase = sing_HRadical baseTyp
      in do baseId <- genIdent
            declare sBase baseId
            let baseE = CVar baseId
            flattenABT base baseE
            rootE <- flattenWithName root
            let powerOf x y = CCall (CVar . Ident $ "pow") [x,y]
                recipE = (floatE 1) ./. rootE
                value = case sBase of
                          SProb -> log1pE $ (powerOf (expm1E baseE .+. (intE 1)) recipE)
                                            .-. (intE 1)
                          _     -> powerOf baseE recipE
            putExprStat $ loc .=. value

flattenPrimOp (Recip t) =
  \(a :* End) ->
    \loc ->
      do aE <- flattenWithName a
         case t of
           HFractional_Real -> putExprStat $ loc .=. ((intE 1) ./. aE)
           HFractional_Prob -> putExprStat $ loc .=. (CUnary CMinOp aE)

-- exp : real -> prob, because of this we can just turn it into a prob without
-- taking its log, which would give us an exp in the log-domain
flattenPrimOp Exp = \(a :* End) -> flattenABT a

flattenPrimOp (Equal _) =
  \(a :* b :* End) ->
    \loc ->
      do aE <- flattenWithName a
         bE <- flattenWithName b
         -- special case for booleans: compare their "index" members
         let aE' = case (typeOf a) of
                     (SData _ (SPlus SDone (SPlus SDone SVoid))) ->
                       (CMember aE (Ident "index") True)
                     _ -> aE
         let bE' = case (typeOf b) of
                     (SData _ (SPlus SDone (SPlus SDone SVoid))) ->
                       (CMember bE (Ident "index") True)
                     _ -> bE
         putExprStat $ (CMember loc (Ident "index") True)
           .=. (CCond (aE' .==. bE') (intE 0) (intE 1))

flattenPrimOp (Less _) =
  \(a :* b :* End) ->
    \loc ->
      do aE <- flattenWithName a
         bE <- flattenWithName b
         putExprStat $ (CMember loc (Ident "index") True)
           .=. (CCond (aE .<. bE) (intE 0) (intE 1))

flattenPrimOp (Negate _) =
  \(a :* End) ->
    \loc ->
      do aE <- flattenWithName a
         putExprStat $ loc .=. (CUnary CMinOp $ aE)

flattenPrimOp t = \_ -> error $ "TODO: flattenPrimOp: " ++ show t


--------------------------------------------------------------------------------
--                           MeasureOps and Superpose                         --
--------------------------------------------------------------------------------

{-

This section contains operations in the stochastic monad.
See also (Dirac, MBind, and Plate) found in SCon. Also see Reject found at the top level. Remember in the C runtime. Measures are housed in a measure function, which takes an `struct mdata` location. The MeasureOp attempts to store a value at that location and returns 0 if it fails and 1 if it succeeds in that task. The functions uniformCG, normalCG, and gammaCG are primitives that will generate functions and call them (similar to logSumExpCG). The reduce code size and make samplers a little more readable. TODO: add inline pragmas to uniformCG, normalCG, and gammaCG -} uniformFun :: CFunDef uniformFun = CFunDef (CTypeSpec <$> retTyp) (CDeclr Nothing (CDDeclrIdent funcId)) [typeDeclaration SReal loId ,typeDeclaration SReal hiId] (CCompound . concat $ [ CBlockDecl <$> [declMD] , CBlockStat <$> comment ++ [assW,assS,CReturn . Just $ mE]] ) where r = castTo [CDouble] randE rMax = castTo [CDouble] (CVar . Ident $ "RAND_MAX") retTyp = buildType (SMeasure SReal) (mId,mE) = let ident = Ident "mdata" in (ident,CVar ident) (loId,loE) = let ident = Ident "lo" in (ident,CVar ident) (hiId,hiE) = let ident = Ident "hi" in (ident,CVar ident) value = (loE .+. ((r ./. rMax) .*. (hiE .-. loE))) comment = fmap CComment ["uniform :: real -> real -> *(mdata real) -> ()" ,"------------------------------------------------"] declMD = buildDeclaration (head retTyp) mId assW = CExpr . Just $ mdataWeight mE .=. (floatE 0) assS = CExpr . Just $ mdataSample mE .=. value funcId = Ident "uniform" uniformCG :: CExpr -> CExpr -> (CExpr -> CodeGen ()) uniformCG aE bE = \loc -> do uId <- reserveIdent "uniform" extDeclareTypes (SMeasure SReal) extDeclare . CFunDefExt $ uniformFun putExprStat $ loc .=. 
CCall (CVar uId) [aE,bE] {- This is very cryptic, but I assure you it is only building an AST for the Marsaglia Polar Method -} normalFun :: CFunDef normalFun = CFunDef (CTypeSpec <$> retTyp) (CDeclr Nothing (CDDeclrIdent (Ident "normal"))) [typeDeclaration SReal aId ,typeDeclaration SProb bId ] ( CCompound . concat $ [[CBlockDecl declMD],comment,decls,stmts]) where r = castTo [CDouble] randE rMax = castTo [CDouble] (CVar . Ident $ "RAND_MAX") retTyp = buildType (SMeasure SReal) (aId,aE) = let ident = Ident "a" in (ident,CVar ident) (bId,bE) = let ident = Ident "b" in (ident,CVar ident) (qId,qE) = let ident = Ident "q" in (ident,CVar ident) (uId,uE) = let ident = Ident "u" in (ident,CVar ident) (vId,vE) = let ident = Ident "v" in (ident,CVar ident) (rId,rE) = let ident = Ident "r" in (ident,CVar ident) (mId,mE) = let ident = Ident "mdata" in (ident,CVar ident) declMD = buildDeclaration (head retTyp) mId draw xE = CExpr . Just $ xE .=. (((r ./. rMax) .*. (floatE 2)) .-. (floatE 1)) body = seqCStat [draw uE ,draw vE ,CExpr . Just $ qE .=. ((uE .*. uE) .+. (vE .*. vE))] polar = CWhile (qE .>. (floatE 1)) body True setR = CExpr . Just $ rE .=. (sqrtE (((CUnary CMinOp (floatE 2)) .*. logE qE) ./. qE)) finalValue = aE .+. (uE .*. rE .*. bE) comment = fmap (CBlockStat . CComment) ["normal :: real -> real -> *(mdata real) -> ()" ,"Marsaglia Polar Method" ,"-----------------------------------------------"] decls = (CBlockDecl . typeDeclaration SReal) <$> [uId,vId,qId,rId] stmts = CBlockStat <$> [polar,setR, assW, assS,CReturn . Just $ mE] assW = CExpr . Just $ mdataWeight mE .=. (floatE 0) assS = CExpr . Just $ mdataSample mE .=. finalValue normalCG :: CExpr -> CExpr -> (CExpr -> CodeGen ()) normalCG aE bE = \loc -> do nId <- reserveIdent "normal" extDeclareTypes (SMeasure SReal) extDeclare . CFunDefExt $ normalFun putExprStat $ loc .=. 
(CCall (CVar nId) [aE,bE]) {- This method is from Marsaglia and Tsang "a simple method for generating gamma variables" -} gammaFun :: CFunDef gammaFun = CFunDef (CTypeSpec <$> retTyp) (CDeclr Nothing (CDDeclrIdent (Ident "gamma"))) [typeDeclaration SProb aId ,typeDeclaration SProb bId] ( CCompound . concat $ [[CBlockDecl declMD],comment,decls,stmts]) where (aId,aE) = let ident = Ident "a" in (ident,CVar ident) (bId,bE) = let ident = Ident "b" in (ident,CVar ident) (cId,cE) = let ident = Ident "c" in (ident,CVar ident) (dId,dE) = let ident = Ident "d" in (ident,CVar ident) (xId,xE) = let ident = Ident "x" in (ident,CVar ident) (vId,vE) = let ident = Ident "v" in (ident,CVar ident) (uId,uE) = let ident = Ident "u" in (ident,CVar ident) (mId,mE) = let ident = Ident "mdata" in (ident,CVar ident) retTyp = buildType (SMeasure SProb) declMD = buildDeclaration (head retTyp) mId comment = fmap (CBlockStat . CComment) ["gamma :: real -> prob -> *(mdata prob) -> ()" ,"Marsaglia and Tsang 'a simple method for generating gamma variables'" ,"--------------------------------------------------------------------"] decls = fmap CBlockDecl $ (fmap (typeDeclaration SReal) [dId,cId,vId]) ++ (fmap (typeDeclaration (SMeasure SReal)) [uId,xId]) stmts = fmap CBlockStat $ [assD,assC,outerWhile] xS = mdataSample xE uS = mdataSample uE assD = CExpr . Just $ dE .=. (aE .-. ((floatE 1) ./. (floatE 3))) assC = CExpr . Just $ cE .=. ((floatE 1) ./. (sqrtE ((floatE 9) .*. dE))) outerWhile = CWhile (intE 1) (seqCStat [innerWhile,assV,assU,exit]) False innerWhile = CWhile (vE .<=. (floatE 0)) (seqCStat [assX,assVIn]) True assX = CExpr . Just $ xE .=. (CCall (CVar . Ident $ "normal") [(floatE 0),(floatE 1)]) assVIn = CExpr . Just $ vE .=. ((floatE 1) .+. (cE .*. xS)) assV = CExpr . Just $ vE .=. (vE .*. vE .*. vE) assU = CExpr . Just $ uE .=. (CCall (CVar . Ident $ "uniform") [(floatE 0),(floatE 1)]) exitC1 = uS .<. ((floatE 1) .-. ((floatE 0.331 .*. (xS .*. xS) .*. (xS .*. 
xS)))) exitC2 = (logE uS) .<. (((floatE 0.5) .*. (xS .*. xS)) .+. (dE .*. ((floatE 1.0) .-. vE .+. (logE vE)))) assW = CExpr . Just $ mdataWeight mE .=. (floatE 0) assS = CExpr . Just $ mdataSample mE .=. (logE (dE .*. vE)) .+. bE exit = CIf (exitC1 .||. exitC2) (seqCStat [assW,assS,CReturn . Just $ mE]) Nothing gammaCG :: CExpr -> CExpr -> (CExpr -> CodeGen ()) gammaCG aE bE = \loc -> do extDeclareTypes (SMeasure SReal) (_:_:gId:[]) <- mapM reserveIdent ["uniform","normal","gamma"] mapM_ (extDeclare . CFunDefExt) [uniformFun,normalFun,gammaFun] putExprStat $ loc .=. (CCall (CVar gId) [aE,bE]) flattenMeasureOp :: forall abt typs args a . ( ABT Term abt , typs ~ UnLCs args , args ~ LCs typs ) => MeasureOp typs a -> SArgs abt args -> (CExpr -> CodeGen ()) flattenMeasureOp Uniform = \(a :* b :* End) -> \loc -> do aE <- flattenWithName a bE <- flattenWithName b uniformCG aE bE loc flattenMeasureOp Normal = \(a :* b :* End) -> \loc -> do aE <- flattenWithName a bE <- flattenWithName b normalCG aE (expE bE) loc flattenMeasureOp Poisson = \(lam :* End) -> \loc -> do lamE <- flattenWithName lam (lId:kId:pId:[]) <- mapM genIdent' ["l","k","p"] declare SProb lId declare SNat kId declare SProb pId let (lE:kE:pE:[]) = fmap CVar [lId,kId,pId] putExprStat $ lE .=. (expE (CUnary CMinOp $ expE lamE)) putExprStat $ kE .=. (intE 0) putExprStat $ pE .=. (floatE 1) doWhileCG (pE .>. lE) $ do uId <- genIdent' "u" declare (SMeasure SReal) uId let uE = CVar uId uniformCG (intE 0) (intE 1) uE putExprStat $ pE .*=. (mdataSample uE) putExprStat $ kE .+=. (intE 1) putExprStat $ mdataWeight loc .=. (intE 0) putExprStat $ mdataSample loc .=. (kE .-. (intE 1)) flattenMeasureOp Gamma = \(a :* b :* End) -> \loc -> do aE <- flattenWithName a bE <- flattenWithName b gammaCG (expE aE) bE loc flattenMeasureOp Beta = \(a :* b :* End) -> flattenABT (HKP.beta'' a b) -- I ran into a bug here where sometime I recieved a location by reference and -- others by value. 
-- Since measureOps assign a sample to mdata that they have a
-- reference to, we should enforce that when passing around mdata it is by
-- reference
--
-- Categorical: emits two passes over the flattened input array.
--   1. A for loop that tracks the largest entry in @max@ and folds every
--      entry into @ws@ through 'logSumExpCG'; afterwards @ws@ has @max@
--      subtracted from it.
--   2. A threshold @r = log(rand / RAND_MAX) + ws@ is computed, @ws@ and the
--      iterator are reset, and a while loop folds @entry - max@ into @ws@
--      until @r < ws@, at which point the sample is set to @it - 1@ and the
--      loop is left with a break.
-- The generator is switched to sequential mode for the duration and switched
-- back to parallel afterwards if it was parallel on entry.
flattenMeasureOp Categorical = \(arr :* End) ->
  \loc ->
    do arrE <- flattenWithName arr
       itId <- genIdent' "it"
       declare SNat itId
       let itE = CVar itId

       -- Accumulator for the total probability of the input array
       wSumId <- genIdent' "ws"
       declare SProb wSumId
       let wSumE = CVar wSumId
       assign wSumId (logE (intE 0))

       -- Accumulator for the max value in the input array
       wMaxId <- genIdent' "max"
       declare SProb wMaxId
       let wMaxE = CVar wMaxId
       assign wMaxId (logE (floatE 0))

       -- the array element at the current iterator position
       let currE = index (arrayData arrE) itE

       -- remember the current mode so it can be restored at the end
       isPar <- isParallel
       mkSequential

       -- Calculate the maximum value of the input array
       -- And calculate the total weight
       forCG (itE .=. (intE 0))
             (itE .<. (arraySize arrE))
             (CUnary CPostIncOp itE) $
         do ifCG (wMaxE .<. currE)
                 (putExprStat $ wMaxE .=. currE)
                 (return ())
            logSumExpCG (S.fromList [wSumE, currE]) wSumE

       -- rescale the total by the maximum, matching the `currE .-. wMaxE`
       -- terms accumulated in the second pass
       putExprStat $ wSumE .=. (wSumE .-. wMaxE)

       -- draw number from uniform(0, weightSum)
       rId <- genIdent' "r"
       declare SReal rId
       let r    = castTo [CDouble] randE
           rMax = castTo [CDouble] (CVar . Ident $ "RAND_MAX")
           rE   = CVar rId
       assign rId (logE (r ./. rMax) .+. wSumE)

       -- restart the running sum and the iterator for the selection pass
       assign wSumId (logE (intE 0))
       assign itId (intE 0)
       -- NOTE(review): the only exit from this loop is the break below; if
       -- `r < ws` never becomes true the iterator keeps advancing -- confirm
       -- that cannot happen (e.g. when rand() returns RAND_MAX).
       whileCG (intE 1) $
         do ifCG (rE .<. wSumE)
                 (do putExprStat $ mdataWeight loc .=. (intE 0)
                     -- the iterator was already advanced past the chosen slot
                     putExprStat $ mdataSample loc .=. (itE .-. (intE 1))
                     putStat CBreak)
                 (return ())
            logSumExpCG (S.fromList [wSumE, currE .-. wMaxE]) wSumE
            putExprStat $ CUnary CPostIncOp itE

       when isPar mkParallel

-- every other measure op is not implemented yet and aborts code generation
flattenMeasureOp x = error $ "TODO: flattenMeasureOp: " ++ show x

---------------
-- Superpose --
---------------

-- | Generate code for a superposition of weighted measures, writing the
-- chosen measure's sample and the combined weight into the given location.
flattenSuperpose
    :: (ABT Term abt)
    => NE.NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
    -> (CExpr -> CodeGen ())

-- do we need to normalize?
-- A single weighted measure needs no random choice: flatten the measure and
-- the weight, then add the weight to the measure's own weight.
flattenSuperpose ((w,m) NE.:| []) =
  \loc ->
    do mE <- flattenWithName m
       wE <- flattenWithName w
       putExprStat $ mdataWeight loc .=. ((mdataWeight mE) .+. wE)
       putExprStat $ mdataSample loc .=. (mdataSample mE)

-- Several weighted measures: pick one at random in proportion to its weight.
-- A running log-sum of the weights is compared against a uniform draw scaled
-- by the total weight; the first measure whose running sum exceeds the draw
-- is flattened into a temporary and control jumps to the exit label.
flattenSuperpose pairs =
  let pairs' = NE.toList pairs in
  \loc ->
    do wEs <- mapM (\(w,_) -> flattenWithName' w "w") pairs'

       -- total weight of all components
       wSumId <- genIdent' "ws"
       declare SProb wSumId
       let wSumE = CVar wSumId
       logSumExpCG (S.fromList wEs) wSumE

       -- draw number from uniform(0, weightSum)
       rId <- genIdent' "r"
       declare SReal rId
       let r    = castTo [CDouble] randE
           rMax = castTo [CDouble] (CVar . Ident $ "RAND_MAX")
           rE   = CVar rId
       assign rId ((r ./. rMax) .*. (expE wSumE))

       -- an iterator for picking a measure
       itId <- genIdent' "it"
       declare SProb itId
       let itE = CVar itId
       assign itId (logE (intE 0))

       -- an output measure to assign to; the type is taken from the first
       -- component, which a NonEmpty always has
       outId <- genIdent' "out"
       declare (typeOf . snd . NE.head $ pairs) outId
       let outE = CVar outId

       outLabel <- genIdent' "exit"

       forM_ (zip wEs pairs') $ \(wE,(_,m)) ->
         do logSumExpCG (S.fromList [itE,wE]) itE
            ifCG (rE .<. (expE itE))
                 (flattenABT m outE >> putStat (CGoto outLabel))
                 (return ())

       putStat $ CLabel outLabel (CExpr Nothing)
       putExprStat $ mdataWeight loc .=. ((mdataWeight outE) .+. wSumE)
       putExprStat $ mdataSample loc .=. (mdataSample outE)


--------------------------------------------------------------------------------
--                           Specialized Arithmetic                           --
--------------------------------------------------------------------------------

--------------------------------------
-- LogSumExp for NaryOp Add [SProb] --
--------------------------------------
{-

Special for addition of probabilities we have a logSumExp. This will compute the
sum of the probabilities safely. Just adding the exp(a .
prob) would make us lose any of the safety from underflow that we got from
storing prob in the log domain

-}

-- | Build a single C expression computing the log-sum-exp of the given
-- operands.
--
-- The result is a tree of conditional expressions that compares operands
-- pairwise to find the largest one; each leaf is that operand plus @log1p@
-- of the sum of @exp(other - largest)@ over the remaining operands.
--
-- A one-element sequence yields that element unchanged. An empty sequence is
-- rejected with an explicit error, since there is nothing to sum.
--
-- the tree traversal is a depth first search
logSumExp :: S.Seq CExpr -> CExpr
logSumExp es
  | S.null es        = error "logSumExp: cannot sum an empty sequence"
  | S.length es == 1 = S.index es 0
  | otherwise        = mkCompTree 0 1
  where
    lastIndex = S.length es - 1

    -- conditional choosing between two branches on @es[i] > es[j]@
    compIndices :: Int -> Int -> CExpr -> CExpr -> CExpr
    compIndices i j = CCond ((S.index es i) .>. (S.index es j))

    -- @i@ is the index of the largest operand seen so far, @j@ the next
    -- operand to compare it against
    mkCompTree :: Int -> Int -> CExpr
    mkCompTree i j
      | j == lastIndex = compIndices i j (logSumExp' i) (logSumExp' j)
      | otherwise      = compIndices i j
                           (mkCompTree i (succ j))
                           (mkCompTree j (succ j))

    -- expm1 (es[a] - es[b])
    diffExp :: Int -> Int -> CExpr
    diffExp a b = expm1E ((S.index es a) .-. (S.index es b))

    -- given the max index, produce a logSumExp expression; adding
    -- 'lastIndex' turns the sum of expm1 terms back into a sum of exp terms
    logSumExp' :: Int -> CExpr
    logSumExp' 0 = S.index es 0
      .+. (log1pE $ foldr (\x acc -> diffExp x 0 .+. acc)
                          (diffExp 1 0)
                          [2..lastIndex]
                  .+. (intE $ fromIntegral lastIndex))
    logSumExp' i = S.index es i
      .+. (log1pE $ foldr (\x acc -> if i == x
                                     then acc
                                     else diffExp x i .+. acc)
                          (diffExp 0 i)
                          [1..lastIndex]
                  .+. (intE $ fromIntegral lastIndex))


-- | logSumExpCG creates global functions for every n-ary logSumExp function
--   this reduces code size
--
-- The function is named @logSumExp@ followed by the number of operands; the
-- given location is assigned the result of calling it on the operands.
logSumExpCG :: S.Seq CExpr -> (CExpr -> CodeGen ())
logSumExpCG seqE =
  let size   = S.length $ seqE
      name   = "logSumExp" ++ (show size)
      funcId = Ident name
  in \loc -> do
       -- reset the names so that the function is the same for each arity
       let argIds = fmap Ident (take size cNameStream)
           decls  = fmap (typeDeclaration SProb) argIds
           vars   = fmap CVar argIds
       funCG CDouble funcId decls
         (putStat . CReturn . Just . logSumExp . S.fromList $ vars)
       putExprStat $ loc .=. (CCall (CVar funcId) (F.toList seqE))


-------------------------------------
-- LogSumExp for Summation of Prob --
-------------------------------------
{-

For summation of SProb we need a new logSumExp function that will find the max
of an array and then sum it in a loop

-}

-- | Log-sum-exp over an array whose entries are produced by a bound body.
--
-- First loop: flatten the body for each index, store the value back into the
-- array slot and track the largest value and its index. Second loop: add
-- @exp(entry - max)@ for every index other than the max index. The location
-- receives @max + log1p(sum)@.
lseSummateArrayCG
  :: ( ABT Term abt )
  => (abt '[ a ] b)
  -> CExpr
  -> (CExpr -> CodeGen ())
lseSummateArrayCG body arrayE =
  caseBind body $ \v body' ->
    \loc ->
      do (maxVId:maxIId:sumId:[]) <- mapM genIdent' ["maxV","maxI","sum"]
         itId <- createIdent v
         mapM_ (declare SProb) [maxVId,sumId]
         mapM_ (declare SNat) [maxIId,itId]
         let (maxVE:maxIE:sumE:itE:[]) = fmap CVar [maxVId,maxIId,sumId,itId]
         forCG (itE .=. intE 0)
               (itE .<. arraySize arrayE)
               (CUnary CPostIncOp itE)
               (do tmpId <- genIdent
                   declare SProb tmpId
                   let tmpE = CVar tmpId
                   flattenABT body' tmpE
                   putExprStat $ derefIndex itE .=. tmpE
                   -- the first iteration always initialises the maximum
                   putStat $ CIf ((maxVE .<. tmpE) .||. (itE .==. (intE 0)))
                                 (seqCStat . fmap (CExpr . Just) $
                                   [ maxVE .=. tmpE
                                   , maxIE .=. itE ])
                                 Nothing)
         putExprStat $ sumE .=. (floatE 0) -- the sum is actually in real space
         forCG (itE .=. intE 0)
               (itE .<. arraySize arrayE)
               (CUnary CPostIncOp itE)
               (putStat $ CIf (itE .!=. maxIE)
                              (CExpr . Just $ sumE .+=. (expE ((derefIndex itE) .-. (maxVE))))
                              Nothing)
         putExprStat $ loc .=. (maxVE .+. (log1pE sumE))
  where derefIndex xE = index (arrayData arrayE) xE


---------------------
-- Kahan Summation --
---------------------

-- | given a body and a size compute the kahan summation. This should work on
-- both probs and reals
--
-- @t@ is the running total and @c@ the running compensation; each iteration
-- flattens the body into @x@ and performs the compensated update
-- @y = x - c; z = t + y; c = (z - t) - y; t = z@. The location receives @t@.
kahanSummationCG
  :: ( ABT Term abt )
  => (abt '[ a ] b)
  -> CExpr
  -> CExpr
  -> (CExpr -> CodeGen ())
kahanSummationCG body loE hiE =
  caseBind body $ \v body' ->
    \loc ->
      do (tId:cId:[]) <- mapM genIdent' ["t","c"]
         itId <- createIdent v
         declare SNat itId
         mapM_ (declare SProb) [tId,cId]
         let (tE:cE:itE:[]) = fmap CVar [tId,cId,itId]
         putExprStat $ tE .=. (floatE 0)
         putExprStat $ cE .=. (floatE 0)
         forCG (itE .=. loE)
               (itE .<. hiE)
               (CUnary CPostIncOp itE)
               (do (xId:yId:zId:[]) <- mapM genIdent' ["x","y","z"]
                   mapM_ (declare SProb) [xId,yId,zId]
                   let (xE:yE:zE:[]) = fmap CVar [xId,yId,zId]
                   flattenABT body' xE
                   putExprStat $ yE .=. (xE .-. cE)
                   putExprStat $ zE .=. (tE .+. yE)
                   putExprStat $ cE .=. ((zE .-. tE) .-. yE)
                   putExprStat $ tE .=. zE)
         putExprStat $ loc .=. tE


--------------------------------------------------------------------------------
--                              Coercion Helpers                              --
--------------------------------------------------------------------------------

-- instance PrimCoerce Value where
--     primCoerceTo c l =
--         case (c,l) of
--         (Signed HRing_Int, VNat a) -> VInt $ fromNat a
--         (Signed HRing_Real, VProb a) -> VReal $ LF.fromLogFloat a
--         (Continuous HContinuous_Prob, VNat a) ->
--             VProb $ LF.logFloat (fromIntegral (fromNat a) :: Double)
--         (Continuous HContinuous_Real, VInt a) -> VReal $ fromIntegral a
--         _ -> error "no a defined primitive coercion"
--     primCoerceFrom c l =
--         case (c,l) of
--         (Signed HRing_Int, VInt a) -> VNat $ unsafeNat a
--         (Signed HRing_Real, VReal a) -> VProb $ LF.logFloat a
--         (Continuous HContinuous_Prob, VProb a) ->
--             VNat $ unsafeNat $ floor (LF.fromLogFloat a :: Double)
--         (Continuous HContinuous_Real, VReal a) -> VInt $ floor a
--         _ -> error "no a defined primitive coercion"


-- | Apply a chain of primitive coercions to an expression, head first: the
-- first coercion in the chain is applied to the input and the remainder of
-- the chain to its result.
coerceToCG
  :: forall (a :: Hakaru) (b :: Hakaru)
  .  Coercion a b
  -> CExpr
  -> CodeGen CExpr
coerceToCG (CCons (Signed HRing_Int) cs) e = nat2int e >>= coerceToCG cs
coerceToCG (CCons (Signed HRing_Real) cs) e = prob2real e >>= coerceToCG cs
coerceToCG (CCons (Continuous HContinuous_Prob) cs) e = nat2prob e >>= coerceToCG cs
coerceToCG (CCons (Continuous HContinuous_Real) cs) e = int2real e >>= coerceToCG cs
coerceToCG CNil e = return e

coerceFromCG :: forall (a :: Hakaru) (b :: Hakaru) .
Coercion a b -> CExpr -> CodeGen CExpr
-- Undo a chain of primitive coercions. A chain @CCons c cs@ goes from @a@
-- through @c@ and then through @cs@ to @b@, so running it backwards has to
-- undo the tail @cs@ first and the head @c@ last. For a one-step chain the
-- order is irrelevant.
coerceFromCG (CCons (Signed HRing_Int) cs) e = coerceFromCG cs e >>= int2nat
coerceFromCG (CCons (Signed HRing_Real) cs) e = coerceFromCG cs e >>= real2prob
coerceFromCG (CCons (Continuous HContinuous_Prob) cs) e = coerceFromCG cs e >>= prob2nat
coerceFromCG (CCons (Continuous HContinuous_Real) cs) e = coerceFromCG cs e >>= real2int
coerceFromCG CNil e = return e


-- safe
--
-- Widening coercions. Probabilities are kept as logarithms in the generated
-- code, so moving into prob takes a log and moving out of it exponentiates.
nat2int,nat2prob,prob2real,int2real
  :: CExpr -> CodeGen CExpr
-- nat and int share a representation here; the expression is passed through
nat2int x = return x
-- stores log(x) in a fresh local
nat2prob x =
  do x' <- localVar' SProb "n2p"
     putExprStat $ x' .=. (logE x)
     return x'
-- stores expm1(x) + 1, i.e. exp(x), in a fresh local
-- NOTE(review): the local holds a real but is declared with SProb -- confirm
-- both map to the same C type.
prob2real x =
  do x' <- localVar' SProb "p2r"
     putExprStat $ x' .=. ((expm1E x) .+. (floatE 1))
     return x'
int2real x = return (castTo [CDouble] x)


-- unsafe
{-
  Because of the hkc representation of reals and probs as doubles, (instead of
  rationals). we will just silently truncate values for prob2nat and real2int
-}
int2nat,prob2nat,real2prob,real2int
  :: CExpr -> CodeGen CExpr
-- negative ints cannot become nats: the generated code prints an error and
-- aborts, otherwise the value is cast to unsigned int
int2nat x =
  do x' <- localVar' SNat "i2n"
     ifCG (x .<. (intE 0))
          (do putExprStat $ printfE [ stringE "error: cannot coerce negative int to nat\n" ]
              putExprStat $ mkCallE "abort" [] )
          (putExprStat $ x' .=. (castTo [CUnsigned, CInt] x))
     return x'

-- the operand is a log-domain prob, so it is exponentiated back to a plain
-- number before the truncating cast (inverse of 'nat2prob')
prob2nat x = return (castTo [CUnsigned, CInt] (expE x))

-- negative reals cannot become probs: the generated code prints an error and
-- aborts, otherwise log(x) is stored in a fresh local
real2prob x =
  do x' <- localVar' SProb "r2p"
     ifCG (x .<. (intE 0))
          (do putExprStat $ printfE [ stringE "error: cannot coerce negative real to prob\n" ]
              putExprStat $ mkCallE "abort" [] )
          (putExprStat $ x' .=. (logE x))
     return x'

real2int x = return (castTo [CInt] x)