module Synthesizer.LLVM.Simple.Value (
T, decons,
twoPi, square, sqrt,
max, min, limit, fraction,
(%==), (%/=), (%<), (%<=), (%>), (%>=), not,
(%&&), (%||),
(?), (??),
lift0, lift1, lift2, lift3,
unlift0, unlift1, unlift2, unlift3, unlift4, unlift5,
constantValue, constant,
fromInteger', fromRational',
Flatten(flattenCode, unfoldCode), Registers,
flatten, unfold,
flattenCodeTraversable, unfoldCodeTraversable,
flattenFunction,
) where
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Util.Loop (Phi, )
import LLVM.Core (CodeGenFunction, )
import qualified LLVM.Core as LLVM
import qualified Synthesizer.Basic.Phase as Phase
import qualified Data.Vault.Lazy as Vault
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS
import Control.Monad (liftM2, liftM3, )
import Control.Applicative (Applicative, pure, (<*>), )
import Control.Functor.HT (unzip, unzip3, )
import qualified Synthesizer.LLVM.Frame.Stereo as Stereo
import qualified Algebra.NormedSpace.Euclidean as NormedEuc
import qualified Algebra.NormedSpace.Sum as NormedSum
import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.RealRing as RealRing
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Module as Module
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive
import qualified Number.Complex as Complex
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified System.Unsafe as Unsafe
import qualified Prelude as P
import NumericPrelude.Numeric hiding (pi, sqrt, fromRational', fraction, )
import NumericPrelude.Base hiding (min, max, unzip, unzip3, not, )
-- | A value of type @a@ that is obtained by generating LLVM code.
--
-- The wrapped generator lives in 'Compute', i.e. in 'LLVM.CodeGenFunction'
-- with a 'Vault.Vault' threaded as state, and works for every function
-- result type @r@.
newtype T a = Cons {code :: forall r. Compute r a}

-- | Run the generator of a 'T', starting with an empty 'Vault.Vault'.
decons :: T a -> (forall r. LLVM.CodeGenFunction r a)
decons (Cons gen) =
   MS.evalStateT gen Vault.empty

-- | Mapping wraps the mapped generator into a new node via 'consUnique'.
instance Functor T where
   fmap f (Cons gen) = consUnique (fmap f gen)

-- | 'pure' embeds a ready value via 'constantValue'; '<*>' runs the
-- function generator first, then the argument generator, and wraps the
-- combination via 'consUnique'.
instance Applicative T where
   pure x = constantValue x
   Cons genF <*> Cons genX = consUnique (genF <*> genX)

-- | Code generation monad for 'T': 'LLVM.CodeGenFunction' with a
-- 'Vault.Vault' as state.
type Compute r a =
   MS.StateT Vault.Vault (LLVM.CodeGenFunction r) a
-- | Wrap a generator into a 'T' that owns a freshly allocated 'Vault.Key'.
--
-- The key is handed to 'consKey', which uses it to store the generated
-- result in the 'Vault.Vault' state, so that a 'T' used several times
-- within one 'Compute' run generates its code only once.
--
-- NOTE(review): the key is allocated with 'Unsafe.performIO'. The sharing
-- scheme seems to rely on the optimizer neither duplicating nor merging
-- these calls; consider whether a NOINLINE pragma is warranted.
consUnique :: (forall r. Compute r a) -> T a
consUnique code0 =
Unsafe.performIO $
fmap (consKey code0) Vault.newKey
-- | Build a 'T' whose generator first looks up @key@ in the current
-- 'Vault.Vault' state. On a hit the stored result is returned without
-- running @code0@ again. On a miss @code0@ is run and its result is
-- inserted under @key@ before being returned.
consKey :: (forall r. Compute r a) -> Vault.Key a -> T a
consKey code0 key =
Cons (do
ma <- MS.gets (Vault.lookup key)
case ma of
Just a -> return a
Nothing -> do
-- cache miss: generate the code once, then remember the result
a <- code0
MS.modify (Vault.insert key a)
return a)
-- | Additive group structure on generated values.
--
-- Each method is lifted from the corresponding "LLVM.Extra.Arithmetic"
-- operation: 'A.add' for '+', 'A.sub' for '-', 'A.neg' for 'negate'.
-- 'zero' is the constant 'A.zero'.
instance (A.Additive a) => Additive.C (T a) where
   zero = constantValue A.zero
   (+) = lift2 A.add
   (-) = lift2 A.sub
   negate = lift1 A.neg
-- | Ring structure on generated values.
--
-- Multiplication is lifted from 'A.mul'; 'one' and integer literals become
-- constants.
instance (A.PseudoRing a, A.IntegerConstant a) =>
      Ring.C (T a) where
   one = constantValue A.one
   x * y = lift2 A.mul x y
   fromInteger n = fromInteger' n

-- | Scaling of a generated vector by a generated scalar via 'A.scale'.
instance (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) =>
      Module.C (T a) (T v) where
   k *> w = lift2 A.scale k w
-- | 'Enum' instance so that 'succ' and 'pred' can be used on generated
-- values.
--
-- 'succ' adds and 'pred' subtracts the constant 'A.one'. 'toEnum' embeds an
-- 'Int' as a constant. 'fromEnum' is not supported and raises an error
-- (presumably because the value is only known when the generated code
-- runs).
instance (A.Additive a, A.IntegerConstant a) => Enum (T a) where
   succ x = x + constantValue A.one
   pred x = x - constantValue A.one
   fromEnum _ = error "CodeGenFunction Value: fromEnum"
   toEnum = constantValue . A.fromInteger' . fromIntegral
-- | Field structure: division via 'A.fdiv'. Rational literals are converted
-- with 'Field.fromRational'' and then embedded with the module-level
-- 'fromRational''.
instance (A.Field a, A.RationalConstant a) => Field.C (T a) where
   x / y = lift2 A.fdiv x y
   fromRational' r = fromRational' (Field.fromRational' r)

-- | Algebraic functions: 'sqrt' via 'A.sqrt'. Roots and rational powers
-- are expressed through 'A.pow'.
instance (A.Transcendental a, A.RationalConstant a) => Algebraic.C (T a) where
   sqrt x = lift1 A.sqrt x
   root n x = lift2 A.pow x (1 / fromInteger n)
   x ^/ r = lift2 A.pow x (Field.fromRational' r)
-- | Transcendental functions lifted from "LLVM.Extra.Arithmetic".
-- The inverse trigonometric functions have no counterpart there and raise
-- an error when used.
instance (A.Transcendental a, A.RationalConstant a) => Trans.C (T a) where
   pi = lift0 A.pi
   sin x = lift1 A.sin x
   cos x = lift1 A.cos x
   x ** y = lift2 A.pow x y
   exp x = lift1 A.exp x
   log x = lift1 A.log x
   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"
-- | Prelude 'P.Num' instance, so that generated values can be used with
-- ordinary Prelude arithmetic.
--
-- Every method is lifted from the matching "LLVM.Extra.Arithmetic"
-- operation ('A.sub' backs '-'). Integer literals become constants via
-- 'fromInteger''.
instance
   (A.PseudoRing a, A.Real a, A.IntegerConstant a) =>
   P.Num (T a) where
   fromInteger = fromInteger'
   (+) = lift2 A.add
   (-) = lift2 A.sub
   (*) = lift2 A.mul
   negate = lift1 A.neg
   abs = lift1 A.abs
   signum = lift1 A.signum
-- | Prelude 'P.Fractional' instance: division via 'A.fdiv'; rational
-- literals become constants via the module-level 'fromRational''.
instance
   (A.Field a, A.Real a, A.RationalConstant a) =>
   P.Fractional (T a) where
   x / y = lift2 A.fdiv x y
   fromRational r = fromRational' r
-- | Prelude 'P.Floating' instance for generated values.
--
-- The elementary functions are lifted from "LLVM.Extra.Arithmetic". The
-- inverse trigonometric functions are not available there and raise an
-- error when used. The hyperbolic functions and their inverses are built
-- from 'exp', 'log' and 'sqrt' using the standard identities.
instance
   (A.Transcendental a, A.Real a, A.RationalConstant a) =>
   P.Floating (T a) where
   pi = lift0 A.pi
   sin = lift1 A.sin
   cos = lift1 A.cos
   (**) = lift2 A.pow
   exp = lift1 A.exp
   log = lift1 A.log
   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"
   sinh x = (exp x - exp (-x)) / 2
   cosh x = (exp x + exp (-x)) / 2
   asinh x = log (x + sqrt (x*x + 1))
   acosh x = log (x + sqrt (x*x - 1))
   -- atanh x = log ((1+x)/(1-x)) / 2, written as a difference of logarithms
   atanh x = (log (1 + x) - log (1 - x)) / 2
-- | The constant 2·π.
twoPi ::
   (A.Transcendental a, A.RationalConstant a) =>
   T a
twoPi = 2 * Trans.pi

-- | Square of a value, via 'A.square'.
square ::
   (A.PseudoRing a) =>
   T a -> T a
square x = lift1 A.square x

-- | Square root, via 'A.sqrt'. Requires only 'A.Algebraic'.
sqrt ::
   (A.Algebraic a) =>
   T a -> T a
sqrt x = lift1 A.sqrt x

-- | Minimum and maximum, via 'A.min' and 'A.max'.
min, max :: (A.Real a) => T a -> T a -> T a
min x y = lift2 A.min x y
max x y = lift2 A.max x y

-- | Clamp a value into the range given by a (lower, upper) pair.
limit :: (A.Real a) => (T a, T a) -> T a -> T a
limit (lo, hi) = max lo . min hi

-- | Fractional part, via 'A.fraction'.
fraction :: (A.Fraction a) => T a -> T a
fraction x = lift1 A.fraction x
-- | Absolute value and sign, via 'A.abs' and 'A.signum'.
instance (A.Real a, A.PseudoRing a, A.IntegerConstant a) =>
      Absolute.C (T a) where
   abs x = lift1 A.abs x
   signum x = lift1 A.signum x

-- | Sum norm of a scalar: its absolute value.
instance (A.Real a, A.IntegerConstant a, a ~ A.Scalar a, A.PseudoModule a) =>
      NormedSum.C (T a) (T a) where
   norm x = lift1 A.abs x

-- | Squared Euclidean norm of a scalar: its square.
instance (A.Real a, A.IntegerConstant a, a ~ A.Scalar a, A.PseudoModule a) =>
      NormedEuc.Sqr (T a) (T a) where
   normSqr x = lift1 A.square x

-- | Euclidean norm as the square root of the squared norm.
instance
   (NormedEuc.Sqr (T a) (T v),
    A.RationalConstant a, A.Algebraic a) =>
      NormedEuc.C (T a) (T v) where
   norm x = lift1 A.sqrt (NormedEuc.normSqr x)
infix 4 %==, %/=, %<, %<=, %>=, %>

-- | Comparisons, generated with 'LLVM.cmp' and the matching predicate.
(%==), (%/=), (%<), (%<=), (%>), (%>=) ::
   (LLVM.CmpRet a) =>
   T (LLVM.Value a) -> T (LLVM.Value a) -> T (LLVM.Value (LLVM.CmpResult a))
x %== y = lift2 (LLVM.cmp LLVM.CmpEQ) x y
x %/= y = lift2 (LLVM.cmp LLVM.CmpNE) x y
x %>  y = lift2 (LLVM.cmp LLVM.CmpGT) x y
x %>= y = lift2 (LLVM.cmp LLVM.CmpGE) x y
x %<  y = lift2 (LLVM.cmp LLVM.CmpLT) x y
x %<= y = lift2 (LLVM.cmp LLVM.CmpLE) x y

infixr 3 %&&
infixr 2 %||

-- | Conjunction built on '?': the code of the right operand is placed in
-- the branch taken when the left operand is true.
(%&&) :: T (LLVM.Value Bool) -> T (LLVM.Value Bool) -> T (LLVM.Value Bool)
x %&& y = x ? (y, constant False)

-- | Disjunction built on '?': the code of the right operand is placed in
-- the branch taken when the left operand is false.
(%||) :: T (LLVM.Value Bool) -> T (LLVM.Value Bool) -> T (LLVM.Value Bool)
x %|| y = x ? (constant True, y)

-- | Boolean negation, via 'LLVM.inv'.
not :: T (LLVM.Value Bool) -> T (LLVM.Value Bool)
not x = lift1 LLVM.inv x
infix 0 ?
-- | @c ? (t, f)@ generates an LLVM if-then-else ('C.ifThenElse') on the
-- boolean @c@ and yields the flattened registers of @t@ or @f@ accordingly,
-- unfolded back into a structured value.
--
-- Both branches start from the same snapshot (@shared@) of the current
-- 'Vault.Vault' and are run with 'MS.evalStateT', so results first
-- generated inside a branch are not recorded in the outer cache.
-- Presumably this is because such values are not available on the other
-- path or after the join. The 'Phi' constraint is required by
-- 'C.ifThenElse' to merge the branch results.
(?) ::
(Flatten value, Registers value ~ a, Phi a) =>
T (LLVM.Value Bool) -> (value, value) -> value
c ? (t, f) =
unfoldCode $ consUnique $ do
b <- code c
-- snapshot of the cache shared by both branches
shared <- MS.get
MT.lift $
C.ifThenElse b
(MS.evalStateT (flattenCode t) shared)
(MS.evalStateT (flattenCode f) shared)
infix 0 ??

-- | Branch-free selection via 'LLVM.select': all three operands are
-- generated, then the result is chosen by the condition.
(??) ::
   (LLVM.IsFirstClass a, LLVM.CmpRet a) =>
   T (LLVM.Value (LLVM.CmpResult a)) ->
   (T (LLVM.Value a), T (LLVM.Value a)) ->
   T (LLVM.Value a)
cond ?? (onTrue, onFalse) =
   lift3 LLVM.select cond onTrue onFalse
-- | Turn an argument-less code generator into a 'T' (wrapped by
-- 'consUnique').
lift0 ::
   (forall r. CodeGenFunction r a) ->
   T a
lift0 gen =
   consUnique (MT.lift gen)

-- | Lift a unary code generator: generate the argument, then apply @f@.
lift1 ::
   (forall r. a -> CodeGenFunction r b) ->
   T a -> T b
lift1 f x =
   consUnique $ do
      xv <- code x
      MT.lift (f xv)

-- | Lift a binary code generator; arguments are generated left to right.
lift2 ::
   (forall r. a -> b -> CodeGenFunction r c) ->
   T a -> T b -> T c
lift2 f x y =
   consUnique
      (code x >>= \xv ->
       code y >>= \yv ->
       MT.lift (f xv yv))

-- | Lift a ternary code generator; arguments are generated left to right.
lift3 ::
   (forall r. a -> b -> c -> CodeGenFunction r d) ->
   T a -> T b -> T c -> T d
lift3 f x y z =
   consUnique
      (code x >>= \xv ->
       code y >>= \yv ->
       code z >>= \zv ->
       MT.lift (f xv yv zv))
-- | Specialisation of 'unlift0' to a plain 'T'.
_unlift0 ::
   T a ->
   (forall r. CodeGenFunction r a)
_unlift0 x = decons x

-- | Generate the code of a structured value, starting from an empty cache.
unlift0 ::
   (Flatten value) =>
   value ->
   (forall r. CodeGenFunction r (Registers value))
unlift0 v = flatten v

-- | Specialisation of 'unlift1' to a plain 'T' result.
_unlift1 ::
   (T a -> T b) ->
   (forall r. a -> CodeGenFunction r b)
_unlift1 f a = unlift1 f a

-- | Turn a function on 'T' values into a code generator: each argument is
-- embedded with 'constantValue' and the result is flattened.
unlift1 ::
   (Flatten value) =>
   (T a -> value) ->
   (forall r. a -> CodeGenFunction r (Registers value))
unlift1 f a =
   flatten $ f $ constantValue a

-- | Specialisation of 'unlift2' to a plain 'T' result.
_unlift2 ::
   (T a -> T b -> T c) ->
   (forall r. a -> b -> CodeGenFunction r c)
_unlift2 f a b = unlift2 f a b

-- | Two-argument variant of 'unlift1'.
unlift2 ::
   (Flatten value) =>
   (T a -> T b -> value) ->
   (forall r. a -> b -> CodeGenFunction r (Registers value))
unlift2 f a b =
   flatten $ f (constantValue a) (constantValue b)

-- | Three-argument variant of 'unlift1'.
unlift3 ::
   (Flatten value) =>
   (T a -> T b -> T c -> value) ->
   (forall r. a -> b -> c -> CodeGenFunction r (Registers value))
unlift3 f a b c =
   flatten $ f (constantValue a) (constantValue b) (constantValue c)

-- | Four-argument variant of 'unlift1'.
unlift4 ::
   (Flatten value) =>
   (T a -> T b -> T c -> T d -> value) ->
   (forall r. a -> b -> c -> d -> CodeGenFunction r (Registers value))
unlift4 f a b c d =
   flatten
      (f (constantValue a) (constantValue b)
         (constantValue c) (constantValue d))

-- | Five-argument variant of 'unlift1'.
unlift5 ::
   (Flatten value) =>
   (T a -> T b -> T c -> T d -> T e -> value) ->
   (forall r. a -> b -> c -> d -> e -> CodeGenFunction r (Registers value))
unlift5 f a b c d e =
   flatten
      (f (constantValue a) (constantValue b) (constantValue c)
         (constantValue d) (constantValue e))
-- | Embed an already available value; its generator just returns it.
constantValue :: a -> T a
constantValue x = consUnique $ return x

-- | Embed a Haskell constant as an LLVM constant value ('LLVM.valueOf').
constant :: (LLVM.IsConst a) => a -> T (LLVM.Value a)
constant x = constantValue (LLVM.valueOf x)

-- | Integer literal as a constant, via 'A.fromInteger''.
fromInteger' :: (A.IntegerConstant a) => Integer -> T a
fromInteger' n = constantValue (A.fromInteger' n)

-- | Prelude rational literal as a constant, via 'A.fromRational''.
fromRational' :: (A.RationalConstant a) => P.Rational -> T a
fromRational' q = constantValue (A.fromRational' q)
-- | Conversion between a structure built from 'T' values and a single 'T'
-- holding the corresponding structure of raw register values.
class Flatten value where
   -- | Register-level representation of @value@.
   type Registers value :: *
   -- | Generate the code of all components within the current 'Compute'
   -- state.
   flattenCode :: value -> Compute r (Registers value)
   -- | Rebuild the structured value from a 'T' of registers.
   unfoldCode :: T (Registers value) -> value

-- | Generate the code of a structured value, starting with an empty cache.
flatten ::
   (Flatten value) =>
   value -> CodeGenFunction r (Registers value)
flatten v =
   MS.evalStateT (flattenCode v) Vault.empty

-- | Turn already computed registers into a structured value.
unfold ::
   (Flatten value) =>
   (Registers value) -> value
unfold regs =
   unfoldCode (constantValue regs)
-- | 'flattenCode' for a traversable container: flatten each element in
-- traversal order.
flattenCodeTraversable ::
   (Flatten value, Trav.Traversable f) =>
   f value -> Compute r (f (Registers value))
flattenCodeTraversable xs =
   Trav.mapM flattenCode xs

-- | 'unfoldCode' for a traversable applicative container, using one
-- projection per element position (see 'getters').
unfoldCodeTraversable ::
   (Flatten value, Trav.Traversable f, Applicative f) =>
   T (f (Registers value)) -> f value
unfoldCodeTraversable x =
   unfoldFromGetters getters x

-- | Apply each projection in a container to the register 'T' and unfold
-- each projected part.
unfoldFromGetters ::
   (Functor f, Flatten b) =>
   f (a -> Registers b) -> T a -> f b
unfoldFromGetters projections x =
   fmap (\proj -> unfoldCode (fmap proj x)) projections
-- | A container of projection functions, one per element position: the
-- function at position @n@ (0-based, in traversal order) extracts the
-- @n@-th element of 'Fold.toList' of its argument.
--
-- Positions are produced by traversing @pure (state action)@ with a
-- counter that starts at 0 and is incremented with 'succ'.
--
-- NOTE(review): '!!' is partial. This assumes every @f a@ has at least as
-- many elements as @pure x@ (fixed-shape containers such as stereo
-- frames). Confirm this when adding new instances.
getters ::
(Trav.Traversable f, Applicative f) =>
f (f a -> a)
getters =
-- turn each position index into a projection function
fmap (\n x -> Fold.toList x !! n) $
MS.evalState (Trav.sequenceA (pure (MS.state $ \n -> (n, succ n)))) 0
-- | Turn a function on structured values into a code generator working on
-- registers: unfold the input, apply @f@, flatten the result.
flattenFunction ::
   (Flatten a, Flatten b) =>
   (a -> b) -> (Registers a -> CodeGenFunction r (Registers b))
flattenFunction f regs =
   flatten (f (unfold regs))
-- | Pairs are flattened component-wise, first component first.
instance (Flatten a, Flatten b) => Flatten (a,b) where
   type Registers (a,b) = (Registers a, Registers b)
   flattenCode (a,b) = do
      ra <- flattenCode a
      rb <- flattenCode b
      return (ra, rb)
   unfoldCode regs =
      case unzip regs of
         (ra, rb) -> (unfoldCode ra, unfoldCode rb)

-- | Triples are flattened component-wise, left to right.
instance (Flatten a, Flatten b, Flatten c) => Flatten (a,b,c) where
   type Registers (a,b,c) = (Registers a, Registers b, Registers c)
   flattenCode (a,b,c) = do
      ra <- flattenCode a
      rb <- flattenCode b
      rc <- flattenCode c
      return (ra, rb, rc)
   unfoldCode regs =
      case unzip3 regs of
         (ra, rb, rc) -> (unfoldCode ra, unfoldCode rb, unfoldCode rc)

-- | Stereo frames use the generic 'Trav.Traversable' helpers.
instance Flatten a => Flatten (Stereo.T a) where
   type Registers (Stereo.T a) = Stereo.T (Registers a)
   flattenCode s = flattenCodeTraversable s
   unfoldCode regs = unfoldCodeTraversable regs

-- | Complex numbers are flattened as real part, then imaginary part.
instance Flatten a => Flatten (Complex.T a) where
   type Registers (Complex.T a) = Complex.T (Registers a)
   flattenCode s = do
      re <- flattenCode (Complex.real s)
      im <- flattenCode (Complex.imag s)
      return (re Complex.+: im)
   unfoldCode regs =
      unfoldFromGetters (Complex.real Complex.+: Complex.imag) regs

-- | Phases are represented by their representative value.
instance (RealRing.C a, Flatten a) => Flatten (Phase.T a) where
   type Registers (Phase.T a) = Registers a
   flattenCode s = flattenCode (Phase.toRepresentative s)
   unfoldCode regs = Phase.fromRepresentative (unfoldCode regs)

-- | A single 'T' is already flat: its registers are its value.
instance Flatten (T a) where
   type Registers (T a) = a
   flattenCode x = code x
   unfoldCode x = x

-- | The unit type carries no registers.
instance Flatten () where
   type Registers () = ()
   flattenCode u = return u
   unfoldCode _ = ()