{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module LLVM.DSL.Expression where

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Tuple as LLTuple
import qualified LLVM.Extra.FastMath as FastMath
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM
import LLVM.Extra.Multi.Value (PatternTuple, Decomposed, Atom)

import qualified Control.Monad.HT as Monad
import Control.Monad.IO.Class (liftIO)

import qualified Data.Enum.Storable as Enum
import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Complex (Complex((:+)))
import Data.Bool8 (Bool8)

import qualified Foreign.Storable.Record.Tuple as StTuple

import qualified Algebra.Transcendental as Trans
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Module as Module
import qualified Algebra.Field as Field
import qualified Algebra.Ring as Ring
import qualified Algebra.Additive as Additive

import System.IO.Unsafe (unsafePerformIO)

-- 'P' is not referenced directly.
-- It keeps the Prelude methods hidden below (e.g. 'sqrt', 'pi')
-- in scope under a qualified name,
-- so that the 'Floating' instance may still bind them.
import qualified Prelude as P
import Prelude hiding
         (fst, snd, min, max, zip, unzip, zip3, unzip3, curry, uncurry,
          recip, pi, sqrt, maybe, toEnum, fromEnum, pred, succ)


{- |
An expression of the DSL:
an LLVM code generator that yields a 'MultiValue.T'.
The generator is polymorphic in the result type @r@
of the surrounding 'LLVM.CodeGenFunction',
thus an expression can be used inside any generated function.
-}
newtype Exp a =
   Exp {unExp :: forall r. LLVM.CodeGenFunction r (MultiValue.T a)}


{- |
Turn code into an expression.
This variant performs no sharing:
every 'unExp' of the result runs (and thus emits) the code again.
The memoizing alternative is '_unique'.
-}
unique :: (forall r. LLVM.CodeGenFunction r (MultiValue.T a)) -> Exp a
unique = Exp

{- |
Memoizing variant of 'unique':
the code is run at the first use,
and its result is cached in a fresh 'IORef' for all later uses.

Using IORef should be thread-safe here,
because you cannot fork within CodeGenFunction.

Not used within this module.
NOTE(review): presumably disabled because a cached value could be reused
at a place in the generated code that it does not dominate - confirm.
-}
_unique :: (forall r. LLVM.CodeGenFunction r (MultiValue.T a)) -> Exp a
_unique code = unsafePerformIO $ fmap (withKey code) $ newIORef Nothing

{- |
Expression that returns the value cached in @ref@ if there is one.
Otherwise it runs @code@, stores its result in @ref@ and returns it.
-}
withKey ::
   (forall r. LLVM.CodeGenFunction r (MultiValue.T a)) ->
   IORef (Maybe (MultiValue.T a)) -> Exp a
withKey code ref =
   Exp
      (do ma <- liftIO $ readIORef ref
          case ma of
             Just a -> return a
             Nothing -> do
                a <- code
                liftIO $ writeIORef ref $ Just a
                return a)


{- |
Types that pure functions on 'MultiValue.T' can be lifted to:
plain multi-values themselves and DSL expressions.
-}
class Value val where
   lift0 :: MultiValue.T a -> val a
   lift1 :: (MultiValue.T a -> MultiValue.T b) -> val a -> val b
   lift2 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c) ->
      val a -> val b -> val c

instance Value MultiValue.T where
   lift0 = id
   lift1 = id
   lift2 = id

instance Value Exp where
   lift0 a = unique (return a)
   lift1 f (Exp a) = unique (Monad.lift f a)
   lift2 f (Exp a) (Exp b) = unique (Monad.lift2 f a b)

-- | Lift a ternary function by zipping the first two arguments.
lift3 ::
   (Value val) =>
   (MultiValue.T a -> MultiValue.T b -> MultiValue.T c -> MultiValue.T d) ->
   val a -> val b -> val c -> val d
lift3 f a b = lift2 (MultiValue.uncurry f) (zip a b)

-- | Lift a quaternary function by zipping the first two arguments.
lift4 ::
   (Value val) =>
   (MultiValue.T a -> MultiValue.T b -> MultiValue.T c -> MultiValue.T d ->
    MultiValue.T e) ->
   val a -> val b -> val c -> val d -> val e
lift4 f a b = lift3 (MultiValue.uncurry f) (zip a b)


-- | Lift a unary code generator to expressions.
liftM ::
   (forall r. MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)) ->
   (Exp a -> Exp b)
liftM f (Exp a) = unique (f =<< a)

-- | Lift a binary code generator to expressions.
liftM2 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b ->
    LLVM.CodeGenFunction r (MultiValue.T c)) ->
   (Exp a -> Exp b -> Exp c)
liftM2 f (Exp a) (Exp b) = unique (Monad.liftJoin2 f a b)

-- | Lift a ternary code generator to expressions.
liftM3 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
    LLVM.CodeGenFunction r (MultiValue.T d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftM3 f (Exp a) (Exp b) (Exp c) = unique (Monad.liftJoin3 f a b c)


-- | Inverse of 'liftM': run an expression function on a plain multi-value.
unliftM1 ::
   (Exp a -> Exp b) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)
unliftM1 f ix = unExp (f (lift0 ix))

-- | Inverse of 'liftM2'.
unliftM2 ::
   (Exp a -> Exp b -> Exp c) ->
   MultiValue.T a -> MultiValue.T b -> LLVM.CodeGenFunction r (MultiValue.T c)
unliftM2 f ix jx = unExp (f (lift0 ix) (lift0 jx))

-- | Inverse of 'liftM3'.
unliftM3 ::
   (Exp a -> Exp b -> Exp c -> Exp d) ->
   MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
   LLVM.CodeGenFunction r (MultiValue.T d)
unliftM3 f ix jx kx = unExp (f (lift0 ix) (lift0 jx) (lift0 kx))


-- | Lift a code generator that works on the LLVM representation of a tuple.
liftTupleM ::
   (forall r.
    LLTuple.ValueOf a -> LLVM.CodeGenFunction r (LLTuple.ValueOf b)) ->
   (Exp a -> Exp b)
liftTupleM f = liftM (MultiValue.liftM f)

-- | Binary variant of 'liftTupleM'.
liftTupleM2 ::
   (forall r.
    LLTuple.ValueOf a -> LLTuple.ValueOf b ->
    LLVM.CodeGenFunction r (LLTuple.ValueOf c)) ->
   (Exp a -> Exp b -> Exp c)
liftTupleM2 f = liftM2 (MultiValue.liftM2 f)

-- | Ternary variant of 'liftTupleM'.
liftTupleM3 ::
   (forall r.
    LLTuple.ValueOf a -> LLTuple.ValueOf b -> LLTuple.ValueOf c ->
    LLVM.CodeGenFunction r (LLTuple.ValueOf d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftTupleM3 f = liftM3 (MultiValue.liftM3 f)


-- Tuples

zip :: (Value val) => val a -> val b -> val (a, b)
zip = lift2 MultiValue.zip

zip3 :: (Value val) => val a -> val b -> val c -> val (a, b, c)
zip3 = lift3 MultiValue.zip3

zip4 :: (Value val) => val a -> val b -> val c -> val d -> val (a, b, c, d)
zip4 = lift4 MultiValue.zip4

unzip :: (Value val) => val (a, b) -> (val a, val b)
unzip ab = (fst ab, snd ab)

unzip3 :: (Value val) => val (a, b, c) -> (val a, val b, val c)
unzip3 abc = (fst3 abc, snd3 abc, thd3 abc)

unzip4 :: (Value val) => val (a, b, c, d) -> (val a, val b, val c, val d)
unzip4 abcd =
   (lift1 (\(MultiValue.Cons (a,_,_,_)) -> MultiValue.Cons a) abcd,
    lift1 (\(MultiValue.Cons (_,b,_,_)) -> MultiValue.Cons b) abcd,
    lift1 (\(MultiValue.Cons (_,_,c,_)) -> MultiValue.Cons c) abcd,
    lift1 (\(MultiValue.Cons (_,_,_,d)) -> MultiValue.Cons d) abcd)

fst :: (Value val) => val (a, b) -> val a
fst = lift1 MultiValue.fst

snd :: (Value val) => val (a, b) -> val b
snd = lift1 MultiValue.snd

mapFst :: (Exp a -> Exp b) -> Exp (a, c) -> Exp (b, c)
mapFst f = liftM (MultiValue.mapFstF (unliftM1 f))

mapSnd :: (Exp b -> Exp c) -> Exp (a, b) -> Exp (a, c)
mapSnd f = liftM (MultiValue.mapSndF (unliftM1 f))

mapPair :: (Exp a0 -> Exp a1, Exp b0 -> Exp b1) -> Exp (a0, b0) -> Exp (a1, b1)
mapPair (f,g) = mapFst f . mapSnd g

swap :: (Value val) => val (a, b) -> val (b, a)
swap = lift1 MultiValue.swap

-- | Turn a function on an expression pair into a two-argument function.
curry :: (Exp (a,b) -> c) -> (Exp a -> Exp b -> c)
curry f = Tuple.curry (f . Tuple.uncurry zip)

-- | Turn a two-argument function into a function on an expression pair.
uncurry :: (Exp a -> Exp b -> c) -> (Exp (a,b) -> c)
uncurry f = Tuple.uncurry f . unzip

fst3 :: (Value val) => val (a,b,c) -> val a
fst3 = lift1 MultiValue.fst3

snd3 :: (Value val) => val (a,b,c) -> val b
snd3 = lift1 MultiValue.snd3

thd3 :: (Value val) => val (a,b,c) -> val c
thd3 = lift1 MultiValue.thd3

mapFst3 :: (Exp a0 -> Exp a1) -> Exp (a0,b,c) -> Exp (a1,b,c)
mapFst3 f = liftM (MultiValue.mapFst3F (unliftM1 f))

mapSnd3 :: (Exp b0 -> Exp b1) -> Exp (a,b0,c) -> Exp (a,b1,c)
mapSnd3 f = liftM (MultiValue.mapSnd3F (unliftM1 f))

mapThd3 :: (Exp c0 -> Exp c1) -> Exp (a,b,c0) -> Exp (a,b,c1)
mapThd3 f = liftM (MultiValue.mapThd3F (unliftM1 f))

mapTriple ::
   (Exp a0 -> Exp a1, Exp b0 -> Exp b1, Exp c0 -> Exp c1) ->
   Exp (a0,b0,c0) -> Exp (a1,b1,c1)
mapTriple (f,g,h) = mapFst3 f . mapSnd3 g . mapThd3 h


-- | Wrap a tuple expression into a 'StTuple.Tuple'.
tuple :: Exp tuple -> Exp (StTuple.Tuple tuple)
tuple = lift1 MultiValue.tuple

-- | Unwrap a 'StTuple.Tuple' expression.
untuple :: Exp (StTuple.Tuple tuple) -> Exp tuple
untuple = lift1 MultiValue.untuple


-- | Lifted 'MultiValue.modify': work on the decomposed parts of a tuple.
modifyMultiValue ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (Decomposed MultiValue.T pattern -> a) ->
   val tuple -> val (MultiValue.Composed a)
modifyMultiValue p f = lift1 $ MultiValue.modify p f

-- | Lifted 'MultiValue.modify2'.
modifyMultiValue2 ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB -> a) ->
   val tupleA -> val tupleB -> val (MultiValue.Composed a)
modifyMultiValue2 pa pb f = lift2 $ MultiValue.modify2 pa pb f

-- | Monadic variant of 'modifyMultiValue', based on 'MultiValue.modifyF'.
modifyMultiValueM ::
   (MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (forall r. Decomposed MultiValue.T pattern -> LLVM.CodeGenFunction r a) ->
   Exp tuple -> Exp (MultiValue.Composed a)
modifyMultiValueM p f = liftM (MultiValue.modifyF p f)

-- | Monadic variant of 'modifyMultiValue2', based on 'MultiValue.modifyF2'.
modifyMultiValueM2 ::
   (MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA ->
   patternB ->
   (forall r.
    Decomposed MultiValue.T patternA ->
    Decomposed MultiValue.T patternB ->
    LLVM.CodeGenFunction r a) ->
   Exp tupleA -> Exp tupleB -> Exp (MultiValue.Composed a)
modifyMultiValueM2 pa pb f = liftM2 (MultiValue.modifyF2 pa pb f)


-- | Structures of expressions that can be merged into one expression.
class Compose multituple where
   type Composed multituple
   {- |
   A nested 'zip'.
   -}
   compose :: multituple -> Exp (Composed multituple)

-- | Patterns that split one expression into a structure of expressions.
class (Composed (Decomposed Exp pattern) ~ PatternTuple pattern) => Decompose pattern where
   {- |
   Analogous to 'MultiValue.decompose'.
   -}
   decompose :: pattern -> Exp (PatternTuple pattern) -> Decomposed Exp pattern

{- |
Analogous to 'MultiValue.modify':
decompose the argument according to the pattern,
apply the function and compose the result.
-}
modify ::
   (Compose a, Decompose pattern) =>
   pattern ->
   (Decomposed Exp pattern -> a) ->
   Exp (PatternTuple pattern) -> Exp (Composed a)
modify p f = compose . f . decompose p

-- | Two-argument variant of 'modify'.
modify2 ::
   (Compose a, Decompose patternA, Decompose patternB) =>
   patternA ->
   patternB ->
   (Decomposed Exp patternA -> Decomposed Exp patternB -> a) ->
   Exp (PatternTuple patternA) ->
   Exp (PatternTuple patternB) -> Exp (Composed a)
modify2 pa pb f a b = compose $ f (decompose pa a) (decompose pb b)


instance Compose (Exp a) where
   type Composed (Exp a) = a
   compose = id

instance Decompose (Atom a) where
   decompose _ = id

instance Compose () where
   type Composed () = ()
   compose = cons

instance Decompose () where
   decompose _ _ = ()

instance (Compose a, Compose b) => Compose (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   compose = Tuple.uncurry zip . TupleHT.mapPair (compose, compose)

instance (Decompose pa, Decompose pb) => Decompose (pa,pb) where
   decompose (pa,pb) = TupleHT.mapPair (decompose pa, decompose pb) . unzip

instance (Compose a, Compose b, Compose c) => Compose (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   compose =
      TupleHT.uncurry3 zip3 . TupleHT.mapTriple (compose, compose, compose)

instance
   (Decompose pa, Decompose pb, Decompose pc) =>
      Decompose (pa,pb,pc) where
   decompose (pa,pb,pc) =
      TupleHT.mapTriple (decompose pa, decompose pb, decompose pc) . unzip3

instance
   (Compose a, Compose b, Compose c, Compose d) =>
      Compose (a,b,c,d) where
   type Composed (a,b,c,d) = (Composed a, Composed b, Composed c, Composed d)
   compose (a,b,c,d) = zip4 (compose a) (compose b) (compose c) (compose d)

instance
   (Decompose pa, Decompose pb, Decompose pc, Decompose pd) =>
      Decompose (pa,pb,pc,pd) where
   decompose (pa,pb,pc,pd) x =
      case unzip4 x of
         (a,b,c,d) ->
            (decompose pa a, decompose pb b, decompose pc c, decompose pd d)

instance (Compose tuple) => Compose (StTuple.Tuple tuple) where
   type Composed (StTuple.Tuple tuple) = StTuple.Tuple (Composed tuple)
   compose (StTuple.Tuple tup) = tuple $ compose tup

instance (Decompose p) => Decompose (StTuple.Tuple p) where
   decompose (StTuple.Tuple p) = StTuple.Tuple . decompose p . untuple

instance (Compose a) => Compose (Complex a) where
   type Composed (Complex a) = Complex (Composed a)
   compose (r:+i) = consComplex (compose r) (compose i)

instance (Decompose p) => Decompose (Complex p) where
   decompose (pRe:+pIm) =
      Tuple.uncurry (:+) .
      TupleHT.mapPair (decompose pRe, decompose pIm) .
      deconsComplex


{- |
You can construct complex numbers this way,
but they will not make you happy,
because the numeric operations require a RealFloat instance
that we could only provide with lots of undefined methods
(also in its superclasses).
You may either define your own arithmetic
or use the NumericPrelude type classes.
-}
consComplex :: Exp a -> Exp a -> Exp (Complex a)
consComplex = lift2 MultiValue.consComplex

-- | Split a complex expression into its real and imaginary part.
deconsComplex :: Exp (Complex a) -> (Exp a, Exp a)
deconsComplex c = (lift1 MultiValue.realPart c, lift1 MultiValue.imagPart c)


-- Constants and arithmetic

-- | Embed a Haskell constant into an expression.
cons :: (MultiValue.C a) => a -> Exp a
cons = lift0 . MultiValue.cons

unit :: Exp ()
unit = cons ()

zero :: (MultiValue.C a) => Exp a
zero = lift0 MultiValue.zero

add :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
add = liftM2 MultiValue.add

sub :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
sub = liftM2 MultiValue.sub

neg :: (MultiValue.Additive a) => Exp a -> Exp a
neg = liftM MultiValue.neg

one :: (MultiValue.IntegerConstant a) => Exp a
one = fromInteger' 1

mul :: (MultiValue.PseudoRing a) => Exp a -> Exp a -> Exp a
mul = liftM2 MultiValue.mul

-- | Square a value, computing the argument only once.
sqr :: (MultiValue.PseudoRing a) => Exp a -> Exp a
sqr = liftM (\x -> MultiValue.mul x x)

-- | Reciprocal, computed as @one / x@.
recip :: (MultiValue.Field a, MultiValue.IntegerConstant a) => Exp a -> Exp a
recip = fdiv one

fdiv :: (MultiValue.Field a) => Exp a -> Exp a -> Exp a
fdiv = liftM2 MultiValue.fdiv

sqrt :: (MultiValue.Algebraic a) => Exp a -> Exp a
sqrt = liftM MultiValue.sqrt

pow :: (MultiValue.Transcendental a) => Exp a -> Exp a -> Exp a
pow = liftM2 MultiValue.pow

idiv :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
idiv = liftM2 MultiValue.idiv

irem :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
irem = liftM2 MultiValue.irem

shl :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shl = liftM2 MultiValue.shl

shr :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shr = liftM2 MultiValue.shr

fromInteger' :: (MultiValue.IntegerConstant a) => Integer -> Exp a
fromInteger' = lift0 . MultiValue.fromInteger'

fromRational' :: (MultiValue.RationalConstant a) => Rational -> Exp a
fromRational' = lift0 . MultiValue.fromRational'


-- Conversions

boolPFrom8 :: Exp Bool8 -> Exp Bool
boolPFrom8 = lift1 MultiValue.boolPFrom8

bool8FromP :: Exp Bool -> Exp Bool8
bool8FromP = lift1 MultiValue.bool8FromP

intFromBool8 :: (MultiValue.NativeInteger i ir) => Exp Bool8 -> Exp i
intFromBool8 = liftM MultiValue.intFromBool8

floatFromBool8 :: (MultiValue.NativeFloating a ar) => Exp Bool8 -> Exp a
floatFromBool8 = liftM MultiValue.floatFromBool8

-- | Reinterpret a raw word as an enumeration value.
toEnum :: (LLTuple.ValueOf w ~ LLVM.Value w) => Exp w -> Exp (Enum.T w e)
toEnum = lift1 MultiValue.toEnum

-- | Get the raw word of an enumeration value.
fromEnum :: (LLTuple.ValueOf w ~ LLVM.Value w) => Exp (Enum.T w e) -> Exp w
fromEnum = lift1 MultiValue.fromEnum

succ, pred ::
   (LLVM.IsArithmetic w, SoV.IntegerConstant w) =>
   Exp (Enum.T w e) -> Exp (Enum.T w e)
succ = liftM MultiValue.succ
pred = liftM MultiValue.pred

-- | Strip the fast-math type wrapper.
fromFastMath :: Exp (FastMath.Number flags a) -> Exp a
fromFastMath = lift1 FastMath.mvDenumber

-- | Attach the fast-math type wrapper.
toFastMath :: Exp a -> Exp (FastMath.Number flags a)
toFastMath = lift1 FastMath.mvNumber


minBound, maxBound :: (MultiValue.Bounded a) => Exp a
minBound = lift0 MultiValue.minBound
maxBound = lift0 MultiValue.maxBound


-- Comparison

-- | Compare two expressions with the given LLVM predicate.
cmp :: (MultiValue.Comparison a) => LLVM.CmpPredicate -> Exp a -> Exp a -> Exp Bool
cmp ord = liftM2 (MultiValue.cmp ord)

infix 4 ==*, /=*, <*, <=*, >*, >=*

-- | Comparison operators, marked with a trailing @*@ to avoid Prelude clashes.
(==*), (/=*), (<*), (>=*), (>*), (<=*) ::
   (MultiValue.Comparison a) => Exp a -> Exp a -> Exp Bool
(==*) = cmp LLVM.CmpEQ
(/=*) = cmp LLVM.CmpNE
(<*)  = cmp LLVM.CmpLT
(>=*) = cmp LLVM.CmpGE
(>*)  = cmp LLVM.CmpGT
(<=*) = cmp LLVM.CmpLE

min, max :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
min = liftM2 A.min
max = liftM2 A.max

-- | Clamp a value into the range from @l@ to @u@, computed as @max l (min u x)@.
limit :: (MultiValue.Real a) => (Exp a, Exp a) -> Exp a -> Exp a
limit (l,u) = max l . min u

fraction :: (MultiValue.Fraction a) => Exp a -> Exp a
fraction = liftM MultiValue.fraction


-- Booleans and bit operations

true, false :: Exp Bool
true = cons True
false = cons False

infixr 3 &&*
(&&*) :: Exp Bool -> Exp Bool -> Exp Bool
(&&*) = liftM2 MultiValue.and

infixr 2 ||*
(||*) :: Exp Bool -> Exp Bool -> Exp Bool
(||*) = liftM2 MultiValue.or

not :: Exp Bool -> Exp Bool
not = liftM MultiValue.inv


{- |
Like 'ifThenElse' but computes both alternative expressions
and then uses LLVM's efficient @select@ instruction.
-}
select :: (MultiValue.Select a) => Exp Bool -> Exp a -> Exp a -> Exp a
select = liftM3 MultiValue.select

{- |
Conditional built with the control-flow combinator 'C.ifThenElse':
contrary to 'select', each alternative is generated as its own branch.
-}
ifThenElse :: (MultiValue.C a) => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse ec ex ey =
   unique
      (do MultiValue.Cons c <- unExp ec
          C.ifThenElse c (unExp ex) (unExp ey))


complement :: (MultiValue.Logic a) => Exp a -> Exp a
complement = liftM MultiValue.inv

infixl 7 .&.*
(.&.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
(.&.*) = liftM2 MultiValue.and

infixl 5 .|.*
(.|.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
(.|.*) = liftM2 MultiValue.or

infixl 6 `xor`
xor :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
xor = liftM2 MultiValue.xor


-- | Combine a presence flag and a value into an optional value.
toMaybe :: Exp Bool -> Exp a -> Exp (Maybe a)
toMaybe = lift2 MultiValue.toMaybe

{- |
Eliminate an optional value:
the result is @j a@ if the value @a@ is present and @n@ otherwise.
The choice is made with 'C.ifThenElse'.
-}
maybe :: (MultiValue.C b) => Exp b -> (Exp a -> Exp b) -> Exp (Maybe a) -> Exp b
maybe n j =
   liftM
      (\m -> do
         let (MultiValue.Cons b, a) = MultiValue.splitMaybe m
         C.ifThenElse b (unliftM1 j a) (unExp n))


-- Haskell 98 numeric classes

instance
   (MultiValue.PseudoRing a, MultiValue.Real a, MultiValue.IntegerConstant a) =>
      Num (Exp a) where
   fromInteger = fromInteger'
   (+) = add
   (-) = sub
   negate = neg
   (*) = mul
   abs = liftM MultiValue.abs
   signum = liftM MultiValue.signum

instance
   (MultiValue.Field a, MultiValue.Real a, MultiValue.RationalConstant a) =>
      Fractional (Exp a) where
   fromRational = fromRational'
   (/) = fdiv

instance
   (MultiValue.Transcendental a, MultiValue.Real a,
    MultiValue.RationalConstant a) =>
      Floating (Exp a) where
   pi = unique MultiValue.pi
   sin = liftM MultiValue.sin
   cos = liftM MultiValue.cos
   -- the right-hand side is this module's 'sqrt', not the class method
   sqrt = sqrt
   (**) = pow
   exp = liftM MultiValue.exp
   log = liftM MultiValue.log

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"

   sinh x = (exp x - exp (-x)) / 2
   cosh x = (exp x + exp (-x)) / 2
   -- asinh is odd, so evaluate it on |x| and restore the sign.
   -- The naive log (x + sqrt (x*x+1)) cancels to log 0 for large negative x.
   asinh x = signum x * log (abs x + sqrt (x*x + 1))
   acosh x = log (x + sqrt (x*x - 1))
   atanh x = (log (1 + x) - log (1 - x)) / 2


-- NumericPrelude classes

{- |
We do not require a NumericPrelude superclass,
thus LLVM-only types like vectors can also be instances.
-}
instance (MultiValue.Additive a) => Additive.C (Exp a) where
   zero = zero
   (+) = add
   (-) = sub
   negate = neg

instance
   (MultiValue.PseudoRing a, MultiValue.IntegerConstant a) =>
      Ring.C (Exp a) where
   one = one
   (*) = mul
   fromInteger = fromInteger'

{-
This instance is enough for Module here.
The difference from Module instances on Haskell tuples is
that LLVM vectors cannot be nested.
-}
instance
   (a ~ MultiValue.Scalar v, MultiValue.PseudoModule v,
    MultiValue.IntegerConstant a) =>
      Module.C (Exp a) (Exp v) where
   (*>) = liftM2 MultiValue.scale

instance
   (MultiValue.Field a, MultiValue.RationalConstant a) =>
      Field.C (Exp a) where
   (/) = fdiv
   fromRational' = fromRational' . Field.fromRational'

instance
   (MultiValue.Transcendental a, MultiValue.RationalConstant a) =>
      Algebraic.C (Exp a) where
   sqrt = sqrt
   root n x = pow x (recip $ fromInteger' n)
   x^/r = pow x (Field.fromRational' r)

-- | The full circle constant, @2 * pi@.
tau :: (MultiValue.Transcendental a, MultiValue.RationalConstant a) => Exp a
tau = mul (fromInteger' 2) Trans.pi

instance
   (MultiValue.Transcendental a, MultiValue.RationalConstant a) =>
      Trans.C (Exp a) where
   pi = unique MultiValue.pi
   sin = liftM MultiValue.sin
   cos = liftM MultiValue.cos
   (**) = pow
   exp = liftM MultiValue.exp
   log = liftM MultiValue.log

   asin _ = error "LLVM missing intrinsic: asin"
   acos _ = error "LLVM missing intrinsic: acos"
   atan _ = error "LLVM missing intrinsic: atan"

instance
   (MultiValue.Real a, MultiValue.PseudoRing a, MultiValue.IntegerConstant a) =>
      Absolute.C (Exp a) where
   abs = liftM MultiValue.abs
   signum = liftM MultiValue.signum