{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Expressions that generate LLVM code.

An 'Exp' wraps a code generator that yields a 'MultiValue.T'.
This module lifts the operations of "LLVM.Extra.Multi.Value" to expressions
and provides tuple handling, comparisons, logic and arithmetic on 'Exp'.
-}
module Data.Array.Knead.Expression where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Monad as LMonad
import qualified LLVM.Core as LLVM
import LLVM.Extra.Multi.Value (PatternTuple, Decomposed, Atom, atom, )

import qualified Control.Monad as Monad

import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.Complex (Complex((:+)))

import Prelude hiding
   (fst, snd, min, max, zip, unzip, zip3, unzip3,
    curry, uncurry, pi, maybe)


{- |
An expression whose value is produced by generated LLVM code.
The wrapped generator is polymorphic in the @r@ parameter of
'LLVM.CodeGenFunction', so one expression can be used in any such context.
-}
newtype Exp a = Exp {unExp :: forall r. LLVM.CodeGenFunction r (MultiValue.T a)}


{- |
Value representations to which pure 'MultiValue.T' functions can be lifted.
'MultiValue.T' itself is an instance (lifting is the identity),
and so is 'Exp'.
-}
class Value val where
   lift0 :: MultiValue.T a -> val a
   lift1 :: (MultiValue.T a -> MultiValue.T b) -> val a -> val b
   lift2 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c) ->
      val a -> val b -> val c
   lift3 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c -> MultiValue.T d) ->
      val a -> val b -> val c -> val d
   lift4 ::
      (MultiValue.T a -> MultiValue.T b -> MultiValue.T c -> MultiValue.T d ->
       MultiValue.T e) ->
      val a -> val b -> val c -> val d -> val e

instance Value MultiValue.T where
   lift0 = id
   lift1 = id
   lift2 = id
   lift3 = id
   lift4 = id

-- For 'Exp' the pure function is applied to the results
-- of the operand code generators.
instance Value Exp where
   lift0 a = Exp (return a)
   lift1 f (Exp a) = Exp (Monad.liftM f a)
   lift2 f (Exp a) (Exp b) = Exp (Monad.liftM2 f a b)
   lift3 f (Exp a) (Exp b) (Exp c) = Exp (Monad.liftM3 f a b c)
   lift4 f (Exp a) (Exp b) (Exp c) (Exp d) = Exp (Monad.liftM4 f a b c d)


{- |
Lift a code-generating unary function to expressions.
The operand's code is generated first and its result is passed to @f@.
-}
liftM ::
   (forall r. MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)) ->
   (Exp a -> Exp b)
liftM f (Exp a) = Exp (f =<< a)

{- |
Lift a code-generating binary function to expressions.
The operand generators are combined with 'LMonad.liftR2'.
-}
liftM2 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b -> LLVM.CodeGenFunction r (MultiValue.T c)) ->
   (Exp a -> Exp b -> Exp c)
liftM2 f (Exp a) (Exp b) = Exp (LMonad.liftR2 f a b)

{- |
Lift a code-generating ternary function to expressions.
The operand generators are combined with 'LMonad.liftR3'.
-}
liftM3 ::
   (forall r.
    MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
    LLVM.CodeGenFunction r (MultiValue.T d)) ->
   (Exp a -> Exp b -> Exp c -> Exp d)
liftM3 f (Exp a) (Exp b) (Exp c) = Exp (LMonad.liftR3 f a b c)

{- |
Counterpart of 'liftM':
run a function on expressions as a code generator on 'MultiValue.T'.
The argument is wrapped with 'lift0'.
-}
unliftM1 ::
   (Exp a -> Exp b) ->
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T b)
unliftM1 f ix = unExp (f (lift0 ix))

{- | Binary version of 'unliftM1'. -}
unliftM2 ::
   (Exp a -> Exp b -> Exp c) ->
   MultiValue.T a -> MultiValue.T b -> LLVM.CodeGenFunction r (MultiValue.T c)
unliftM2 f ix jx = unExp (f (lift0 ix) (lift0 jx))

{- | Ternary version of 'unliftM1'. -}
unliftM3 ::
   (Exp a -> Exp b -> Exp c -> Exp d) ->
   MultiValue.T a -> MultiValue.T b -> MultiValue.T c ->
   LLVM.CodeGenFunction r (MultiValue.T d)
unliftM3 f ix jx kx = unExp (f (lift0 ix) (lift0 jx) (lift0 kx))


{- | Minimum of two expressions, computed with 'A.min'. -}
min :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
min = liftM2 A.min

{- | Maximum of two expressions, computed with 'A.max'. -}
max :: (MultiValue.Real a) => Exp a -> Exp a -> Exp a
max = liftM2 A.max


{- | Combine two values into a pair. -}
zip :: (Value val) => val a -> val b -> val (a, b)
zip = lift2 MultiValue.zip

{- | Combine three values into a triple. -}
zip3 :: (Value val) => val a -> val b -> val c -> val (a, b, c)
zip3 = lift3 MultiValue.zip3

{- | Combine four values into a quadruple. -}
zip4 :: (Value val) => val a -> val b -> val c -> val d -> val (a, b, c, d)
zip4 = lift4 MultiValue.zip4

{- | Split a pair into its components. -}
unzip :: (Value val) => val (a, b) -> (val a, val b)
unzip ab = (lift1 MultiValue.fst ab, lift1 MultiValue.snd ab)

{- | Split a triple into its components. -}
unzip3 :: (Value val) => val (a, b, c) -> (val a, val b, val c)
unzip3 abc =
   (lift1 MultiValue.fst3 abc, lift1 MultiValue.snd3 abc, lift1 MultiValue.thd3 abc)

{- |
Split a quadruple into its components
by matching on the 'MultiValue.Cons' representation of the tuple.
-}
unzip4 :: (Value val) => val (a, b, c, d) -> (val a, val b, val c, val d)
unzip4 abcd =
   (lift1 (\(MultiValue.Cons (a,_,_,_)) -> MultiValue.Cons a) abcd,
    lift1 (\(MultiValue.Cons (_,b,_,_)) -> MultiValue.Cons b) abcd,
    lift1 (\(MultiValue.Cons (_,_,c,_)) -> MultiValue.Cons c) abcd,
    lift1 (\(MultiValue.Cons (_,_,_,d)) -> MultiValue.Cons d) abcd)

{- | First component of a pair. -}
fst :: (Value val) => val (a, b) -> val a
fst = lift1 MultiValue.fst

{- | Second component of a pair. -}
snd :: (Value val) => val (a, b) -> val b
snd = lift1 MultiValue.snd

{- | Apply a function to the first component of a pair expression. -}
mapFst :: (Exp a -> Exp b) -> Exp (a, c) -> Exp (b, c)
mapFst f = modify (atom, atom) $ TupleHT.mapFst f

{- | Apply a function to the second component of a pair expression. -}
mapSnd :: (Exp b -> Exp c) -> Exp (a, b) -> Exp (a, c)
mapSnd f = modify (atom, atom) $ TupleHT.mapSnd f

{- | Exchange the components of a pair. -}
swap :: (Value val) => val (a, b) -> val (b, a)
swap = lift1 MultiValue.swap

{- | Turn a function on pair expressions into a curried function. -}
curry :: (Exp (a,b) -> c) -> (Exp a -> Exp b -> c)
curry f = Tuple.curry (f . Tuple.uncurry zip)

{- | Turn a curried function into a function on pair expressions. -}
uncurry :: (Exp a -> Exp b -> c) -> (Exp (a,b) -> c)
uncurry f = Tuple.uncurry f . unzip

{- | First component of a triple. -}
fst3 :: (Value val) => val (a,b,c) -> val a
fst3 = lift1 MultiValue.fst3

{- | Second component of a triple. -}
snd3 :: (Value val) => val (a,b,c) -> val b
snd3 = lift1 MultiValue.snd3

{- | Third component of a triple. -}
thd3 :: (Value val) => val (a,b,c) -> val c
thd3 = lift1 MultiValue.thd3

{- | Apply a function to the first component of a triple expression. -}
mapFst3 :: (Exp a0 -> Exp a1) -> Exp (a0,b,c) -> Exp (a1,b,c)
mapFst3 f = modify (atom, atom, atom) $ TupleHT.mapFst3 f

{- | Apply a function to the second component of a triple expression. -}
mapSnd3 :: (Exp b0 -> Exp b1) -> Exp (a,b0,c) -> Exp (a,b1,c)
mapSnd3 f = modify (atom, atom, atom) $ TupleHT.mapSnd3 f

{- | Apply a function to the third component of a triple expression. -}
mapThd3 :: (Exp c0 -> Exp c1) -> Exp (a,b,c0) -> Exp (a,b,c1)
mapThd3 f = modify (atom, atom, atom) $ TupleHT.mapThd3 f


{- |
Lift a pattern-based transformation on 'MultiValue.T' components
(see 'MultiValue.modify') to any 'Value'.
-}
modifyMultiValue ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (Decomposed MultiValue.T pattern -> a) ->
   val tuple -> val (MultiValue.Composed a)
modifyMultiValue p f = lift1 $ MultiValue.modify p f

{- | Two-argument version of 'modifyMultiValue', using 'MultiValue.modify2'. -}
modifyMultiValue2 ::
   (Value val,
    MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA -> patternB ->
   (Decomposed MultiValue.T patternA -> Decomposed MultiValue.T patternB -> a) ->
   val tupleA -> val tupleB -> val (MultiValue.Composed a)
modifyMultiValue2 pa pb f = lift2 $ MultiValue.modify2 pa pb f

{- |
Like 'modifyMultiValue', but the transformation may generate code.
Implemented with 'MultiValue.modifyF' and restricted to 'Exp'.
-}
modifyMultiValueM ::
   (MultiValue.Compose a,
    MultiValue.Decompose pattern,
    MultiValue.PatternTuple pattern ~ tuple) =>
   pattern ->
   (forall r. Decomposed MultiValue.T pattern -> LLVM.CodeGenFunction r a) ->
   Exp tuple -> Exp (MultiValue.Composed a)
modifyMultiValueM p f = liftM (MultiValue.modifyF p f)

{- | Two-argument version of 'modifyMultiValueM', using 'MultiValue.modifyF2'. -}
modifyMultiValueM2 ::
   (MultiValue.Compose a,
    MultiValue.Decompose patternA,
    MultiValue.Decompose patternB,
    MultiValue.PatternTuple patternA ~ tupleA,
    MultiValue.PatternTuple patternB ~ tupleB) =>
   patternA -> patternB ->
   (forall r.
    Decomposed MultiValue.T patternA -> Decomposed MultiValue.T patternB ->
    LLVM.CodeGenFunction r a) ->
   Exp tupleA -> Exp tupleB -> Exp (MultiValue.Composed a)
modifyMultiValueM2 pa pb f = liftM2 (MultiValue.modifyF2 pa pb f)


{- |
Structures of expressions (plain 'Exp', tuples, 'Complex')
that can be merged into a single expression.
-}
class Compose multituple where
   type Composed multituple
   {- | A nested 'zip'. -}
   compose :: multituple -> Exp (Composed multituple)

{- |
Patterns describing how an expression of a structured type
is split into a structure of component expressions.
-}
class (Composed (Decomposed Exp pattern) ~ PatternTuple pattern) => Decompose pattern where
   {- | Analogous to 'MultiValue.decompose'. -}
   decompose :: pattern -> Exp (PatternTuple pattern) -> Decomposed Exp pattern

{- |
Analogous to 'MultiValue.modify':
split an expression according to a pattern,
transform the components and compose the result again.
-}
modify ::
   (Compose a, Decompose pattern) =>
   pattern ->
   (Decomposed Exp pattern -> a) ->
   Exp (PatternTuple pattern) -> Exp (Composed a)
modify p f = compose . f . decompose p

{- | Two-argument version of 'modify', analogous to 'MultiValue.modify2'. -}
modify2 ::
   (Compose a, Decompose patternA, Decompose patternB) =>
   patternA -> patternB ->
   (Decomposed Exp patternA -> Decomposed Exp patternB -> a) ->
   Exp (PatternTuple patternA) -> Exp (PatternTuple patternB) -> Exp (Composed a)
modify2 pa pb f a b = compose $ f (decompose pa a) (decompose pb b)

instance Compose (Exp a) where
   type Composed (Exp a) = a
   compose = id

instance Decompose (Atom a) where
   decompose _ = id

instance Compose () where
   type Composed () = ()
   compose = cons

instance Decompose () where
   decompose _ _ = ()

instance (Compose a, Compose b) => Compose (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   compose = Tuple.uncurry zip . TupleHT.mapPair (compose, compose)

instance (Decompose pa, Decompose pb) => Decompose (pa,pb) where
   decompose (pa,pb) =
      TupleHT.mapPair (decompose pa, decompose pb) . unzip

instance (Compose a, Compose b, Compose c) => Compose (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   compose =
      TupleHT.uncurry3 zip3 . TupleHT.mapTriple (compose, compose, compose)

instance (Decompose pa, Decompose pb, Decompose pc) => Decompose (pa,pb,pc) where
   decompose (pa,pb,pc) =
      TupleHT.mapTriple (decompose pa, decompose pb, decompose pc) . unzip3

instance (Compose a, Compose b, Compose c, Compose d) => Compose (a,b,c,d) where
   type Composed (a,b,c,d) = (Composed a, Composed b, Composed c, Composed d)
   compose (a,b,c,d) = zip4 (compose a) (compose b) (compose c) (compose d)

instance
   (Decompose pa, Decompose pb, Decompose pc, Decompose pd) =>
      Decompose (pa,pb,pc,pd) where
   decompose (pa,pb,pc,pd) x =
      case unzip4 x of
         (a,b,c,d) ->
            (decompose pa a, decompose pb b, decompose pc c, decompose pd d)

instance (Compose a) => Compose (Complex a) where
   type Composed (Complex a) = Complex (Composed a)
   compose (r:+i) = consComplex (compose r) (compose i)

instance (Decompose p) => Decompose (Complex p) where
   decompose (pr:+pi) =
      Tuple.uncurry (:+) .
      TupleHT.mapPair (decompose pr, decompose pi) .
      deconsComplex


{- |
Build a complex expression from a real and an imaginary part.

You can construct complex numbers this way, but they will not make you happy,
because the numeric operations require a RealFloat instance
that we could only provide with lots of undefined methods
(also in its superclasses).
You may either define your own arithmetic or use the NumericPrelude type classes.
-}
consComplex :: Exp a -> Exp a -> Exp (Complex a)
consComplex = lift2 MultiValue.consComplex

{- | Split a complex expression into its real and imaginary parts. -}
deconsComplex :: Exp (Complex a) -> (Exp a, Exp a)
deconsComplex c = (lift1 MultiValue.realPart c, lift1 MultiValue.imagPart c)


{- | Embed a Haskell constant as an expression. -}
cons :: (MultiValue.C a) => a -> Exp a
cons = lift0 . MultiValue.cons

{- | The unit expression. -}
unit :: Exp ()
unit = cons ()

{- | The value provided by 'MultiValue.zero'. -}
zero :: (MultiValue.C a) => Exp a
zero = lift0 MultiValue.zero

{- | Sum, via 'MultiValue.add'. -}
add :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
add = liftM2 MultiValue.add

{- | Difference, via 'MultiValue.sub'. -}
sub :: (MultiValue.Additive a) => Exp a -> Exp a -> Exp a
sub = liftM2 MultiValue.sub

{- | Product, via 'MultiValue.mul'. -}
mul :: (MultiValue.PseudoRing a) => Exp a -> Exp a -> Exp a
mul = liftM2 MultiValue.mul

{- | Square of a value, computed by multiplying it with itself. -}
sqr :: (MultiValue.PseudoRing a) => Exp a -> Exp a
sqr = liftM $ \x -> MultiValue.mul x x

{- | Square root, via 'MultiValue.sqrt'. -}
sqrt :: (MultiValue.Algebraic a) => Exp a -> Exp a
sqrt = liftM MultiValue.sqrt

{- | Integer division, via 'MultiValue.idiv'. -}
idiv :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
idiv = liftM2 MultiValue.idiv

{- | Integer remainder, via 'MultiValue.irem'. -}
irem :: (MultiValue.Integral a) => Exp a -> Exp a -> Exp a
irem = liftM2 MultiValue.irem

{- | Shift left, via 'MultiValue.shl'. -}
shl :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shl = liftM2 MultiValue.shl

{- | Shift right, via 'MultiValue.shr'. -}
shr :: (MultiValue.BitShift a) => Exp a -> Exp a -> Exp a
shr = liftM2 MultiValue.shr

{- | Constant expression from an 'Integer'. -}
fromInteger' :: (MultiValue.IntegerConstant a) => Integer -> Exp a
fromInteger' = lift0 . MultiValue.fromInteger'

{- | Constant expression from a 'Rational'. -}
fromRational' :: (MultiValue.RationalConstant a) => Rational -> Exp a
fromRational' = lift0 . MultiValue.fromRational'


{- | Compare two expressions with the given LLVM comparison predicate. -}
cmp ::
   (MultiValue.Comparison a) =>
   LLVM.CmpPredicate -> Exp a -> Exp a -> Exp Bool
cmp ord = liftM2 $ MultiValue.cmp ord

infix 4 ==*, /=*, <*, <=*, >*, >=*

{- |
Comparisons using the predicates
'LLVM.CmpEQ', 'LLVM.CmpNE', 'LLVM.CmpLT', 'LLVM.CmpGE', 'LLVM.CmpGT', 'LLVM.CmpLE'.
-}
(==*), (/=*), (<*), (>=*), (>*), (<=*) ::
   (MultiValue.Comparison a) => Exp a -> Exp a -> Exp Bool
(==*) = cmp LLVM.CmpEQ
(/=*) = cmp LLVM.CmpNE
(<*)  = cmp LLVM.CmpLT
(>=*) = cmp LLVM.CmpGE
(>*)  = cmp LLVM.CmpGT
(<=*) = cmp LLVM.CmpLE


{- | Boolean constants. -}
true, false :: Exp Bool
true = cons True
false = cons False

infixr 3 &&*

{- | Logical conjunction, via 'MultiValue.and'. -}
(&&*) :: Exp Bool -> Exp Bool -> Exp Bool
(&&*) = liftM2 MultiValue.and

infixr 2 ||*

{- | Logical disjunction, via 'MultiValue.or'. -}
(||*) :: Exp Bool -> Exp Bool -> Exp Bool
(||*) = liftM2 MultiValue.or

{- | Logical negation, via 'MultiValue.inv'. -}
not :: Exp Bool -> Exp Bool
not = liftM MultiValue.inv


{- |
Like 'ifThenElse' but computes both alternative expressions
and then uses LLVM's efficient @select@ instruction.
-}
select :: (MultiValue.Select a) => Exp Bool -> Exp a -> Exp a -> Exp a
select = liftM3 MultiValue.select

{- |
Conditional expression.
Unlike 'select', the two alternatives are passed as separate
code generators to 'C.ifThenElse'.
-}
ifThenElse :: (MultiValue.C a) => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse ec ex ey =
   Exp (do
      MultiValue.Cons c <- unExp ec
      C.ifThenElse c (unExp ex) (unExp ey))


{- | Complement, via 'MultiValue.inv'. -}
complement :: (MultiValue.Logic a) => Exp a -> Exp a
complement = liftM MultiValue.inv

infixl 7 .&.*

{- | Conjunction of two values, via 'MultiValue.and'. -}
(.&.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
(.&.*) = liftM2 MultiValue.and

infixl 5 .|.*

{- | Disjunction of two values, via 'MultiValue.or'. -}
(.|.*) :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
(.|.*) = liftM2 MultiValue.or

infixl 6 `xor`

{- | Exclusive or, via 'MultiValue.xor'. -}
xor :: (MultiValue.Logic a) => Exp a -> Exp a -> Exp a
xor = liftM2 MultiValue.xor


{- | Pair a presence flag with a value to form an optional value. -}
toMaybe :: Exp Bool -> Exp a -> Exp (Maybe a)
toMaybe = lift2 MultiValue.toMaybe

{- |
Eliminate an optional value.
The flag obtained from 'MultiValue.splitMaybe' chooses between
applying @j@ to the payload and returning @n@,
using 'C.ifThenElse'.
-}
maybe :: (MultiValue.C b) => Exp b -> (Exp a -> Exp b) -> Exp (Maybe a) -> Exp b
maybe n j = liftM $ \m -> do
   let (MultiValue.Cons b, a) = MultiValue.splitMaybe m
   C.ifThenElse b (unliftM1 j a) (unExp n)


-- Numeric literals and arithmetic on expressions,
-- implemented with the corresponding MultiValue operations.
instance
   (MultiValue.PseudoRing a, MultiValue.Real a, MultiValue.IntegerConstant a) =>
      Num (Exp a) where
   fromInteger n = lift0 (MultiValue.fromInteger' n)
   (+) = liftM2 MultiValue.add
   (-) = liftM2 MultiValue.sub
   negate = liftM MultiValue.neg
   (*) = liftM2 MultiValue.mul
   abs = liftM MultiValue.abs
   signum = liftM MultiValue.signum

-- Fractional literals and division; division uses MultiValue.fdiv.
instance
   (MultiValue.Field a, MultiValue.Real a, MultiValue.RationalConstant a) =>
      Fractional (Exp a) where
   fromRational n = lift0 (MultiValue.fromRational' n)
   (/) = liftM2 MultiValue.fdiv