{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE Rank2Types #-}
module Synthesizer.LLVM.Causal.Private where
import qualified Synthesizer.LLVM.Generator.Private as Sig
import Synthesizer.LLVM.Private (getPairPtrs, noLocalPtr, unbool)
import qualified Synthesizer.Causal.Class as CausalClass
import qualified Synthesizer.Causal.Utility as ArrowUtil
import Synthesizer.Causal.Class (($>))
import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)
import qualified LLVM.Extra.Multi.Value.Marshal as Marshal
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.MaybeContinuation as MaybeCont
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction)
import qualified Type.Data.Num.Decimal as TypeNum
import qualified Control.Category as Cat
import Control.Arrow (Arrow, arr, first, (&&&), (<<<))
import Control.Category (Category)
import Control.Applicative (Applicative, pure, liftA2, (<*>), (<$>))
import Data.Tuple.Strict (mapFst, zipPair)
import Data.Word (Word)
import qualified Number.Ratio as Ratio
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive
import NumericPrelude.Base hiding (map, zip, zipWith, init)
import qualified Prelude as P
data T a b =
forall global local state.
(Memory.C global, LLVM.IsSized local, Memory.C state) =>
Cons (forall r c.
(Tuple.Phi c) =>
global -> LLVM.Value (LLVM.Ptr local) ->
a -> state -> MaybeCont.T r c (b, state))
(forall r. CodeGenFunction r (global, state))
(forall r. global -> CodeGenFunction r ())
type instance CausalClass.ProcessOf Sig.T = T
instance CausalClass.C T where
type SignalOf T = Sig.T
toSignal (Cons next start stop) = Sig.Cons
(\global local -> next global local ())
start
stop
fromSignal (Sig.Cons next start stop) = Cons
(\global local _ -> next global local)
start
stop
noGlobal ::
(LLVM.IsSized local, Memory.C state) =>
(forall r c.
(Tuple.Phi c) =>
LLVM.Value (LLVM.Ptr local) -> a -> state -> MaybeCont.T r c (b, state)) ->
(forall r. CodeGenFunction r state) ->
T a b
noGlobal next start =
Cons (const next) (fmap ((,) ()) start) return
simple ::
(Memory.C state) =>
(forall r c. (Tuple.Phi c) => a -> state -> MaybeCont.T r c (b, state)) ->
(forall r. CodeGenFunction r state) ->
T a b
simple next start = noGlobal (noLocalPtr next) start
mapAccum ::
(Memory.C state) =>
(forall r. a -> state -> CodeGenFunction r (b, state)) ->
(forall r. CodeGenFunction r state) ->
T a b
mapAccum next =
simple (\a s -> MaybeCont.lift $ next a s)
map ::
(forall r. a -> CodeGenFunction r b) ->
T a b
map f =
mapAccum (\a s -> fmap (flip (,) s) $ f a) (return ())
zipWith ::
(forall r. a -> b -> CodeGenFunction r c) ->
T (a,b) c
zipWith f = map (uncurry f)
instance Category T where
id = map return
Cons nextB startB stopB . Cons nextA startA stopA = Cons
(\(globalA, globalB) local a (sa0,sb0) -> do
(localA,localB) <- getPairPtrs local
(b,sa1) <- nextA globalA localA a sa0
(c,sb1) <- nextB globalB localB b sb0
return (c, (sa1,sb1)))
(liftA2 zipPair startA startB)
(\(globalA, globalB) -> stopA globalA >> stopB globalB)
instance Arrow T where
arr f = map (return . f)
first (Cons next start stop) = Cons (firstNext next) start stop
firstNext ::
(Functor m) =>
(global -> local -> a -> s -> m (b, s)) ->
global -> local -> (a, c) -> s -> m ((b, c), s)
firstNext next global local (b,d) s0 =
fmap
(\(c,s1) -> ((c,d), s1))
(next global local b s0)
instance Functor (T a) where
fmap = flip (>>^)
instance Applicative (T a) where
pure = ArrowUtil.pure
(<*>) = ArrowUtil.apply
infixr 1 >>^, ^>>
(>>^) :: T a b -> (b -> c) -> T a c
Cons next start stop >>^ f =
Cons
(\global local a state -> mapFst f <$> next global local a state)
start stop
(^>>) :: (a -> b) -> T b c -> T a c
f ^>> Cons next start stop =
Cons
(\global local -> next global local . f)
start stop
mapProc ::
(forall r. b -> CodeGenFunction r c) ->
T a b -> T a c
mapProc f x = map f <<< x
zipProcWith ::
(forall r. b -> c -> CodeGenFunction r d) ->
T a b -> T a c -> T a d
zipProcWith f x y = zipWith f <<< x&&&y
instance (A.Additive b) => Additive.C (T a b) where
zero = pure A.zero
negate = mapProc A.neg
(+) = zipProcWith A.add
(-) = zipProcWith A.sub
instance (A.PseudoRing b, A.IntegerConstant b) => Ring.C (T a b) where
one = pure A.one
fromInteger n = pure (A.fromInteger' n)
(*) = zipProcWith A.mul
instance (A.Field b, A.RationalConstant b) => Field.C (T a b) where
fromRational' x = pure (A.fromRational' $ Ratio.toRational98 x)
(/) = zipProcWith A.fdiv
instance (A.PseudoRing b, A.Real b, A.IntegerConstant b) => P.Num (T a b) where
fromInteger n = pure (A.fromInteger' n)
negate = mapProc A.neg
(+) = zipProcWith A.add
(-) = zipProcWith A.sub
(*) = zipProcWith A.mul
abs = mapProc A.abs
signum = mapProc A.signum
instance
(A.Field b, A.Real b, A.RationalConstant b) => P.Fractional (T a b) where
fromRational x = pure (A.fromRational' x)
(/) = zipProcWith A.fdiv
loop ::
(Memory.C c) =>
(forall r. CodeGenFunction r c) -> T (a,c) (b,c) -> T a b
loop initial (Cons next start stop) = Cons
(\global local a0 (c0,s0) -> do
((b1,c1), s1) <- next global local (a0,c0) s0
return (b1,(c1,s1)))
(liftA2 (\ini (global,s) -> (global,(ini,s))) initial start)
stop
replicateSerial ::
(Tuple.Undefined a, Tuple.Phi a) =>
Exp Word -> T a a -> T a a
replicateSerial n proc =
(\a -> ((),a)) ^>> replicateControlled n (snd^>>proc)
replicateControlled ::
(Tuple.Undefined a, Tuple.Phi a) =>
Exp Word -> T (c,a) a -> T (c,a) a
replicateControlled n (Cons next start stop) = Cons
(\(len,globalStates) local (c,a) () ->
MaybeCont.fromMaybe $ fmap (\(_,ms) -> flip (,) () <$> ms) $
MaybeCont.arrayLoop len globalStates a $
\globalStatePtr a0 -> do
(global, s0) <- MaybeCont.lift $ Memory.load globalStatePtr
(a1,s1) <- next global local (c,a0) s0
MaybeCont.lift $
Memory.store s1 =<<
LLVM.getElementPtr0 globalStatePtr (TypeNum.d1, ())
return a1)
(do
MultiValue.Cons len <- Expr.unExp n
globalStates <- LLVM.arrayMalloc len
C.arrayLoop len globalStates () $ \globalStatePtr () ->
flip Memory.store globalStatePtr =<< start
return ((len,globalStates), ()))
(\(len,globalStates) -> do
C.arrayLoop len globalStates () $ \globalStatePtr () ->
stop =<< Memory.load
=<< LLVM.getElementPtr0 globalStatePtr (TypeNum.d0, ())
LLVM.free globalStates)
replicateControlledAlt ::
(Tuple.Undefined a, Tuple.Phi a) =>
(Tuple.Undefined c, Tuple.Phi c) =>
Exp Word -> T (c,a) a -> T (c,a) a
replicateControlledAlt n proc =
replicateSerial n (arr fst &&& proc) >>^ snd
replicateParallel ::
(Tuple.Undefined b, Tuple.Phi b) =>
Exp Word -> Sig.T b -> T (b,b) b -> T a b -> T a b
replicateParallel n z cum p =
replicateControlled n (cum <<< first p) $> z
quantizeLift ::
(Memory.C b, Marshal.C c, MultiValue.IntegerConstant c,
MultiValue.Additive c, MultiValue.Comparison c) =>
T a b -> T (MultiValue.T c, a) b
quantizeLift (Cons next start stop) = Cons
(\global local (k, a0) yState0 -> do
(yState1, frac1) <-
MaybeCont.fromBool $
C.whileLoop
(LLVM.valueOf True, yState0)
(\(cont1, (_, frac0)) ->
LLVM.and cont1 . unbool
=<< MultiValue.cmp LLVM.CmpLE frac0 A.zero)
(\(_,((_,state01), frac0)) ->
MaybeCont.toBool $ liftA2 (,)
(next global local a0 state01)
(MaybeCont.lift $ A.add frac0 k))
frac2 <- MaybeCont.lift $ A.sub frac1 A.one
return (fst yState1, (yState1, frac2)))
(do
(global,s) <- start
return (global, ((Tuple.undef, s), A.zero)))
stop