{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Values whose LLVM code is generated on demand and shared.

Every 'T' value owns a private key into a 'Vault.Vault'.
While code is generated (see 'decons' and 'flatten')
the result of a value is stored under its key,
so a value that is used several times emits its instructions only once.
Results that are first computed inside a branch of '?'
are only shared within that branch.
-}
module Synthesizer.LLVM.Simple.Value (
   T, decons,
   twoPi, square, sqrt, max, min, limit, fraction,
   (%==), (%/=), (%<), (%<=), (%>), (%>=), not,
   (%&&), (%||),
   (?), (??),
   lift0, lift1, lift2, lift3,
   unlift0, unlift1, unlift2, unlift3, unlift4, unlift5,
   constantValue, constant,
   fromInteger', fromRational',
   Flatten(flattenCode, unfoldCode), Registers,
   flatten, unfold,
   flattenCodeTraversable,
   unfoldCodeTraversable,
   flattenFunction,
   ) where

import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import LLVM.Core (CodeGenFunction)
import qualified LLVM.Core as LLVM

import qualified Synthesizer.Basic.Phase as Phase

import qualified Data.Vault.Lazy as Vault
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS
import Control.Monad (liftM2, liftM3)
import Control.Applicative (Applicative, pure, (<*>))
import Control.Functor.HT (unzip, unzip3)

import qualified Synthesizer.LLVM.Frame.Stereo as Stereo

-- import qualified Algebra.NormedSpace.Maximum as NormedMax
import qualified Algebra.NormedSpace.Euclidean as NormedEuc
import qualified Algebra.NormedSpace.Sum as NormedSum
import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.RealRing as RealRing
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Module as Module
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import qualified Number.Complex as Complex

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold

import qualified System.Unsafe as Unsafe

import qualified Prelude as P
import NumericPrelude.Numeric hiding (pi, sqrt, fromRational', fraction)
import NumericPrelude.Base hiding (min, max, unzip, unzip3, not)


{-
The @r@ type parameter must be hidden and forall-quantified
because otherwise we would need an impossible type
where we have to quantify for @r@ and @t@ in different scopes
while having a class constraint that involves both of them.

> osci ::
>    (RealRing.C (Value.T r t),
>     IsFirstClass t, IsFloating t,
>     IsPrimitive t, IsConst t) =>
>    (forall r. Wave.T (Value.T r t) (Value.T r y)) ->
>    t -> t -> T (Value y)
-}
-- | A value of type @a@ represented by the code that computes it.
-- The code is polymorphic in the LLVM function type @r@,
-- thus the same value can be used in every generated function.
newtype T a = Cons {code :: forall r. Compute r a}

-- | Generate the code of a single value, starting with an empty cache,
-- and return the resulting LLVM value.
decons :: T a -> (forall r. LLVM.CodeGenFunction r a)
decons value = MS.evalStateT (code value) Vault.empty

-- | Mapping yields a new, separately cached value.
instance Functor T where
   fmap f x = consUnique (fmap f (code x))

-- | 'pure' wraps an existing value without emitting code;
-- application yields a new, separately cached value.
instance Applicative T where
   pure = constantValue
   f <*> x = consUnique (code f <*> code x)

-- | Code generation that threads a cache of already computed values.
type Compute r a = MS.StateT Vault.Vault (LLVM.CodeGenFunction r) a

{- |
Wrap code into a 'T' together with a freshly allocated cache key,
such that within one code generation run
the code is executed at its first use
and later uses read the cached result (see 'consKey').

The key is allocated by running 'Vault.newKey' via 'Unsafe.performIO'.
This is benign, because allocating a key has no other observable effect
and the key only serves to identify this particular value.
-}
consUnique :: (forall r. Compute r a) -> T a
consUnique code0 =
   Unsafe.performIO $ fmap (consKey code0) Vault.newKey
-- The documentation of unsafePerformIO asks for NOINLINE
-- on every function that calls it, because an inlined call
-- may perform the IO action more than once.
-- Here that would give one value several keys and thus lose sharing.
{-# NOINLINE consUnique #-}

-- | Look up the key in the cache.
-- On a miss, run the code and store its result under the key.
consKey :: (forall r. Compute r a) -> Vault.Key a -> T a
consKey code0 key = Cons (do
   ma <- MS.gets (Vault.lookup key)
   case ma of
      Just a -> return a
      Nothing -> do
         a <- code0
         MS.modify (Vault.insert key a)
         return a)

{- |
We do not require a numeric prelude superclass,
thus also LLVM only types like vectors are instances.
-}
instance (A.Additive a) => Additive.C (T a) where
   zero = constantValue A.zero
   (+) = lift2 A.add
   (-) = lift2 A.sub
   negate = lift1 A.neg

instance (A.PseudoRing a, A.IntegerConstant a) => Ring.C (T a) where
   one = constantValue A.one
   (*) = lift2 A.mul
   fromInteger = fromInteger'

{-
This instance is enough for Module here.
The difference to Module instances on Haskell tuples is,
that LLVM vectors cannot be nested.
-}
instance (a ~ A.Scalar v, A.PseudoModule v, A.IntegerConstant a) => Module.C (T a) (T v) where
   (*>) = lift2 A.scale

{- |
'succ' and 'pred' add or subtract one in generated code,
'toEnum' creates a constant.
'fromEnum' would need the value at Haskell level and thus calls 'error'.
-}
instance (A.Additive a, A.IntegerConstant a) => Enum (T a) where
   succ x = x + constantValue A.one
   pred x = x - constantValue A.one
   fromEnum _ = error "CodeGenFunction Value: fromEnum"
   toEnum = constantValue . A.fromInteger' . fromIntegral

{-
instance (IsArithmetic a, Cmp a b, Num a, IsConst a) => Real (T a) where
   toRational _ = error "CodeGenFunction Value: toRational"

instance (Cmp a b, Num a, IsConst a, IsInteger a) => Integral (T a) where
   quot = lift2 idiv
   rem  = lift2 irem
   quotRem x y = (quot x y, rem x y)
   toInteger _ = error "CodeGenFunction Value: toInteger"
-}

instance (A.Field a, A.RationalConstant a) => Field.C (T a) where
   (/) = lift2 A.fdiv
   -- The right-hand fromRational' is the module-level function
   -- taking a Prelude 'P.Rational';
   -- 'Field.fromRational'' converts the method argument to that type.
   fromRational' = fromRational' . Field.fromRational'

{-
instance (Cmp a b, Fractional a, IsConst a, IsFloating a) => RealFrac (T a) where
   properFraction _ = error "CodeGenFunction Value: properFraction"
-}

instance (A.Transcendental a, A.RationalConstant a) => Algebraic.C (T a) where
   sqrt = lift1 A.sqrt
   -- n-th root computed as the power x^(1/n)
   root n x = lift2 A.pow x (1 / fromInteger n)
   x^/r = lift2 A.pow x (Field.fromRational' r)

{- |
'asin', 'acos' and 'atan' call 'error'
(\"LLVM missing intrinsic\").
-}
instance (A.Transcendental a, A.RationalConstant a) => Trans.C (T a) where
   pi = lift0 A.pi
   sin = lift1 A.sin
   cos = lift1 A.cos
   (**) = lift2 A.pow
   exp = lift1 A.exp
   log = lift1 A.log

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"

instance (A.PseudoRing a, A.Real a, A.IntegerConstant a) => P.Num (T a) where
   fromInteger = fromInteger'
   (+) = lift2 A.add
   (-) = lift2 A.sub
   (*) = lift2 A.mul
   negate = lift1 A.neg
   abs = lift1 A.abs
   signum = lift1 A.signum

instance (A.Field a, A.Real a, A.RationalConstant a) => P.Fractional (T a) where
   fromRational = fromRational'
   (/) = lift2 A.fdiv

{- |
'asin', 'acos' and 'atan' call 'error'.
The hyperbolic functions are expressed by 'exp', 'log' and 'sqrt'.
-}
instance (A.Transcendental a, A.Real a, A.RationalConstant a) => P.Floating (T a) where
   pi = lift0 A.pi
   sin = lift1 A.sin
   cos = lift1 A.cos
   (**) = lift2 A.pow
   exp = lift1 A.exp
   log = lift1 A.log
   -- Without this binding the class default @sqrt x = x ** 0.5@ applies,
   -- which emits 'A.pow' instead of 'A.sqrt'
   -- as used by 'Algebraic.sqrt' and the module-level 'sqrt'.
   sqrt = lift1 A.sqrt

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"

   sinh x  = (exp x - exp (-x)) / 2
   cosh x  = (exp x + exp (-x)) / 2
   asinh x = log (x + sqrt (x*x + 1))
   acosh x = log (x + sqrt (x*x - 1))
   atanh x = (log (1 + x) - log (1 - x)) / 2

-- | Two times 'Trans.pi'.
twoPi :: (A.Transcendental a, A.RationalConstant a) => T a
twoPi = 2 * Trans.pi

-- | Square of a value, generated by 'A.square'.
square :: (A.PseudoRing a) => T a -> T a
square = lift1 A.square

{- |
The same as 'Algebraic.sqrt',
but needs only Algebraic constraint, not Transcendental.
-}
sqrt :: (A.Algebraic a) => T a -> T a
sqrt = lift1 A.sqrt

-- | Minimum and maximum, generated by 'A.min' and 'A.max'.
min, max :: (A.Real a) => T a -> T a -> T a
min = lift2 A.min
max = lift2 A.max

-- | @limit (l,u) x@ clamps @x@ to the range from @l@ to @u@,
-- computed as @max l (min u x)@.
limit :: (A.Real a) => (T a, T a) -> T a -> T a
limit (l,u) = max l . min u

-- | 'A.fraction' lifted to 'T'.
fraction :: (A.Fraction a) => T a -> T a
fraction = lift1 A.fraction

instance (A.Real a, A.PseudoRing a, A.IntegerConstant a) => Absolute.C (T a) where
   abs = lift1 A.abs
   signum = lift1 A.signum

{-
For useful instances with different scalar and vector type,
we would need a more flexible superclass.
-}
instance (A.Real a, A.IntegerConstant a, a ~ A.Scalar a, A.PseudoModule a) => NormedSum.C (T a) (T a) where
   norm = lift1 A.abs

instance (A.Real a, A.IntegerConstant a, a ~ A.Scalar a, A.PseudoModule a) => NormedEuc.Sqr (T a) (T a) where
   normSqr = lift1 A.square

-- | Euclidean norm as the square root of 'NormedEuc.normSqr'.
instance (NormedEuc.Sqr (T a) (T v), A.RationalConstant a, A.Algebraic a) => NormedEuc.C (T a) (T v) where
   norm = lift1 A.sqrt . NormedEuc.normSqr

{-
instance (A.Real a, A.IntegerConstant a, A.PseudoModule a a) => NormedMax.C (T a) (T a) where
   norm = lift1 A.abs
-}


infix 4 %==, %/=, %<, %<=, %>=, %>

-- | Comparisons, generated by 'LLVM.cmp' with the respective predicate.
(%==), (%/=), (%<), (%<=), (%>), (%>=) ::
   (LLVM.CmpRet a) =>
   T (LLVM.Value a) -> T (LLVM.Value a) -> T (LLVM.Value (LLVM.CmpResult a))
(%==) = lift2 $ LLVM.cmp LLVM.CmpEQ
(%/=) = lift2 $ LLVM.cmp LLVM.CmpNE
(%>)  = lift2 $ LLVM.cmp LLVM.CmpGT
(%>=) = lift2 $ LLVM.cmp LLVM.CmpGE
(%<)  = lift2 $ LLVM.cmp LLVM.CmpLT
(%<=) = lift2 $ LLVM.cmp LLVM.CmpLE


infixr 3 %&&
infixr 2 %||

-- | Lazy AND: the code of the second operand
-- is only run if the first one is true (implemented with '?').
(%&&) :: T (LLVM.Value Bool) -> T (LLVM.Value Bool) -> T (LLVM.Value Bool)
a %&& b = a ? (b, constant False)

-- | Lazy OR: the code of the second operand
-- is only run if the first one is false (implemented with '?').
(%||) :: T (LLVM.Value Bool) -> T (LLVM.Value Bool) -> T (LLVM.Value Bool)
a %|| b = a ? (constant True, b)

-- | Logical negation, generated by 'LLVM.inv'.
not :: T (LLVM.Value Bool) -> T (LLVM.Value Bool)
not = lift1 LLVM.inv


infix 0 ?

{- |
@true ? (t,f)@ evaluates @t@,
@false ? (t,f)@ evaluates @f@.
@t@ and @f@ can reuse interim results,
but they cannot contribute shared results,
since only one of them will be run.
Cf. '(??)'
-}
(?) ::
   (Flatten value, Registers value ~ a, Tuple.Phi a) =>
   T (LLVM.Value Bool) -> (value, value) -> value
c ? (t, f) =
   unfoldCode $ consUnique $ do
      b <- code c
      -- both branches start from the cache as it is before the branch
      -- and their cache updates are discarded afterwards
      shared <- MS.get
      MT.lift $ C.ifThenElse b
         (MS.evalStateT (flattenCode t) shared)
         (MS.evalStateT (flattenCode f) shared)


infix 0 ??

{- |
The expression @c ?? (t,f)@ evaluates both @t@ and @f@
and selects components from @t@ and @f@ according to @c@.
It is useful for vector values
and for sharing @t@ or @f@ with other branches of an expression.
-}
(??) ::
   (LLVM.IsFirstClass a, LLVM.CmpRet a) =>
   T (LLVM.Value (LLVM.CmpResult a)) ->
   (T (LLVM.Value a), T (LLVM.Value a)) -> T (LLVM.Value a)
c ?? (t, f) = lift3 LLVM.select c t f


-- | Lift argument-less code to a cached value.
lift0 :: (forall r. CodeGenFunction r a) -> T a
lift0 f = consUnique $ MT.lift $ f

-- | Lift a unary code generator.
-- The argument is generated (or taken from the cache) first,
-- the result is cached as well.
lift1 :: (forall r. a -> CodeGenFunction r b) -> T a -> T b
lift1 f x = consUnique $ MT.lift . f =<< code x

-- | Lift a binary code generator.
-- Arguments are generated from left to right, the result is cached.
lift2 :: (forall r. a -> b -> CodeGenFunction r c) -> T a -> T b -> T c
lift2 f x y = consUnique $ do
   xv <- code x
   yv <- code y
   MT.lift $ f xv yv

-- | Lift a ternary code generator.
-- Arguments are generated from left to right, the result is cached.
lift3 :: (forall r. a -> b -> c -> CodeGenFunction r d) -> T a -> T b -> T c -> T d
lift3 f x y z = consUnique $ do
   xv <- code x
   yv <- code y
   zv <- code z
   MT.lift $ f xv yv zv


-- Not exported: 'unlift0' specialised to a plain 'T'.
_unlift0 :: T a -> (forall r. CodeGenFunction r a)
_unlift0 = decons

-- | Generate the code of a value in a fresh cache; the same as 'flatten'.
unlift0 ::
   (Flatten value) =>
   value -> (forall r. CodeGenFunction r (Registers value))
unlift0 = flatten

-- Not exported: 'unlift1' specialised to a plain 'T' result.
_unlift1 :: (T a -> T b) -> (forall r. a -> CodeGenFunction r b)
_unlift1 = unlift1

{- |
Turn a function on values into a code generator on LLVM values.
Better type inference than flattenFunction.
-}
unlift1 ::
   (Flatten value) =>
   (T a -> value) ->
   (forall r. a -> CodeGenFunction r (Registers value))
unlift1 f a = flatten (f (constantValue a))

-- Not exported: 'unlift2' specialised to a plain 'T' result.
_unlift2 :: (T a -> T b -> T c) -> (forall r. a -> b -> CodeGenFunction r c)
_unlift2 = unlift2

-- | Two-argument version of 'unlift1'.
unlift2 ::
   (Flatten value) =>
   (T a -> T b -> value) ->
   (forall r. a -> b -> CodeGenFunction r (Registers value))
unlift2 f a b = flatten (f (constantValue a) (constantValue b))

-- | Three-argument version of 'unlift1'.
unlift3 ::
   (Flatten value) =>
   (T a -> T b -> T c -> value) ->
   (forall r. a -> b -> c -> CodeGenFunction r (Registers value))
unlift3 f a b c =
   flatten (f (constantValue a) (constantValue b) (constantValue c))

-- | Four-argument version of 'unlift1'.
unlift4 ::
   (Flatten value) =>
   (T a -> T b -> T c -> T d -> value) ->
   (forall r. a -> b -> c -> d -> CodeGenFunction r (Registers value))
unlift4 f a b c d =
   flatten $
   f (constantValue a) (constantValue b) (constantValue c) (constantValue d)

-- | Five-argument version of 'unlift1'.
unlift5 ::
   (Flatten value) =>
   (T a -> T b -> T c -> T d -> T e -> value) ->
   (forall r. a -> b -> c -> d -> e -> CodeGenFunction r (Registers value))
unlift5 f a b c d e =
   flatten $
   f (constantValue a) (constantValue b) (constantValue c)
     (constantValue d) (constantValue e)


-- | Wrap a value that needs no code, e.g. a function parameter.
constantValue :: a -> T a
constantValue x = consUnique (return x)

-- | An LLVM constant, created by 'LLVM.valueOf'.
constant :: (LLVM.IsConst a) => a -> T (LLVM.Value a)
constant = constantValue . LLVM.valueOf

-- | Integer constant, created by 'A.fromInteger''.
fromInteger' :: (A.IntegerConstant a) => Integer -> T a
fromInteger' = constantValue . A.fromInteger'

-- | Rational constant, created by 'A.fromRational''.
fromRational' :: (A.RationalConstant a) => P.Rational -> T a
fromRational' = constantValue . A.fromRational'


{- |
Structures of 'T' values that can be converted
into the corresponding structure of LLVM values ('Registers')
and back.
-}
class Flatten value where
   type Registers value :: *
   flattenCode :: value -> Compute r (Registers value)
   unfoldCode :: T (Registers value) -> value

-- | Generate the code for all parts of a structure, starting with an empty cache.
flatten :: (Flatten value) => value -> CodeGenFunction r (Registers value)
flatten x = MS.evalStateT (flattenCode x) Vault.empty

-- | Wrap existing LLVM values into a structure of values that need no code.
unfold :: (Flatten value) => (Registers value) -> value
unfold x = unfoldCode $ pure x

-- | 'flattenCode' for every element of a traversable structure, in traversal order.
flattenCodeTraversable ::
   (Flatten value, Trav.Traversable f) =>
   f value -> Compute r (f (Registers value))
flattenCodeTraversable =
   Trav.mapM flattenCode

-- | 'unfoldCode' for traversable structures, built from 'getters'.
unfoldCodeTraversable ::
   (Flatten value, Trav.Traversable f, Applicative f) =>
   T (f (Registers value)) -> f value
unfoldCodeTraversable = unfoldFromGetters getters

-- | Apply each component selector to the structured value
-- and unfold the selected parts.
unfoldFromGetters ::
   (Functor f, Flatten b) =>
   f (a -> Registers b) -> T a -> f b
unfoldFromGetters g x =
   fmap (unfoldCode . flip fmap x) g

{- |
For every position of a traversable structure,
a function that picks the element at that position.
Positions are numbered in traversal order by a counter in 'MS.State'
and picking uses 'Fold.toList',
thus this relies on 'Fold.toList' visiting elements in traversal order.
-}
getters ::
   (Trav.Traversable f, Applicative f) =>
   f (f a -> a)
getters =
   fmap (\n x -> Fold.toList x !! n) $
   MS.evalState (Trav.sequenceA (pure (MS.state $ \n -> (n, succ n)))) 0


-- | Turn a function on values into a code generator on LLVM values.
flattenFunction ::
   (Flatten a, Flatten b) =>
   (a -> b) -> (Registers a -> CodeGenFunction r (Registers b))
flattenFunction f = flatten . f . unfold

{-
This function is hardly useful,
since most functions are not of type
@(Registers a -> (forall r. CodeGenFunction r (Registers b)))@
but of type
@(forall r. Registers a -> CodeGenFunction r (Registers b))@.
We would also need a method unfoldF.
See ValueUnfoldF for some implementations.

unfoldFunction ::
   (Flatten a, Flatten b) =>
   (Registers a -> (forall r.
CodeGenFunction r (Registers b))) ->
   (a -> b)
unfoldFunction f x = unfoldF (f =<< flatten x)
-}

-- | Pairs: components are generated first to second.
instance (Flatten a, Flatten b) => Flatten (a,b) where
   type Registers (a,b) = (Registers a, Registers b)
   flattenCode (a,b) = do
      ra <- flattenCode a
      rb <- flattenCode b
      return (ra, rb)
   unfoldCode x =
      let (xa, xb) = unzip x
      in  (unfoldCode xa, unfoldCode xb)

-- | Triples: components are generated first to third.
instance (Flatten a, Flatten b, Flatten c) => Flatten (a,b,c) where
   type Registers (a,b,c) = (Registers a, Registers b, Registers c)
   flattenCode (a,b,c) = do
      ra <- flattenCode a
      rb <- flattenCode b
      rc <- flattenCode c
      return (ra, rb, rc)
   unfoldCode x =
      let (xa, xb, xc) = unzip3 x
      in  (unfoldCode xa, unfoldCode xb, unfoldCode xc)

-- | Stereo frames use the generic traversable implementation.
instance Flatten a => Flatten (Stereo.T a) where
   type Registers (Stereo.T a) = Stereo.T (Registers a)
   flattenCode s = flattenCodeTraversable s
   unfoldCode s = unfoldCodeTraversable s

-- | Complex numbers: real part first, then imaginary part.
instance Flatten a => Flatten (Complex.T a) where
   type Registers (Complex.T a) = Complex.T (Registers a)
   flattenCode s = do
      re <- flattenCode (Complex.real s)
      im <- flattenCode (Complex.imag s)
      return (re Complex.+: im)
   unfoldCode =
      unfoldFromGetters (Complex.real Complex.+: Complex.imag)

-- | Phases are stored as their representative.
instance (RealRing.C a, Flatten a) => Flatten (Phase.T a) where
   type Registers (Phase.T a) = Registers a
   flattenCode = flattenCode . Phase.toRepresentative
   -- could also be unsafeFromRepresentative
   unfoldCode = Phase.fromRepresentative . unfoldCode

-- | A single value is its own structure.
instance Flatten (T a) where
   type Registers (T a) = a
   flattenCode x = code x
   unfoldCode x = x

-- | The empty structure generates no code.
instance Flatten () where
   type Registers () = ()
   flattenCode u = return u
   unfoldCode _ = ()