{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{- |
Multi-values: LLVM values wrapped in the newtype 'T'.
The LLVM-side representation of each Haskell type @a@ is chosen by the
associated type family 'Repr', so that, for example, a Haskell pair is
represented by a pair of representations.
-}
module LLVM.Extra.Multi.Value where

import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Class as Class
import qualified LLVM.Core as LLVM
import qualified LLVM.Util.Loop as Loop
import LLVM.Util.Loop (Phi, )

import Foreign.StablePtr (StablePtr, )
import Foreign.Ptr (Ptr, FunPtr, )

import qualified Control.Monad.HT as Monad
import Control.Monad (Monad, return, fmap, (>>), )
import Data.Functor (Functor, )

import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.Function (id, (.), ($), )
import Data.Tuple.HT (uncurry3, )
import Data.Bool (Bool, )
import Data.Word (Word8, Word16, Word32, Word64, )
import Data.Int (Int8, Int16, Int32, Int64, )

import Prelude (Float, Double, Integer, Rational, )


-- | A value of Haskell type @a@ held in its LLVM representation,
-- i.e. @'Repr' 'LLVM.Value' a@.
newtype T a = Cons (Repr LLVM.Value a)

-- | Haskell types that have a multi-value representation.
class C a where
   -- | How @a@ is represented when every leaf is wrapped in @f@.
   type Repr (f :: * -> *) a :: *
   -- | Embed a Haskell value as a multi-value.
   cons :: a -> T a
   -- | An undefined multi-value and an all-zero multi-value, respectively.
   undef, zero :: T a
   -- | Create phi nodes in the given basic block for every component
   -- of the multi-value.
   phis :: LLVM.BasicBlock -> T a -> LLVM.CodeGenFunction r (T a)
   -- | Register incoming values from the given basic block on the phi
   -- nodes held in the first multi-value argument.
   addPhis :: LLVM.BasicBlock -> T a -> T a -> LLVM.CodeGenFunction r ()

-- The scalar instances below all use a single LLVM value as
-- representation and delegate to the @*Primitive@ helpers.

instance C Bool where
   type Repr f Bool = f Bool
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Float where
   type Repr f Float = f Float
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Double where
   type Repr f Double = f Double
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Word8 where
   type Repr f Word8 = f Word8
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Word16 where
   type Repr f Word16 = f Word16
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive
instance C Word32 where
   type Repr f Word32 = f Word32
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Word64 where
   type Repr f Word64 = f Word64
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Int8 where
   type Repr f Int8 = f Int8
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Int16 where
   type Repr f Int16 = f Int16
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Int32 where
   type Repr f Int32 = f Int32
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C Int64 where
   type Repr f Int64 = f Int64
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance (LLVM.IsType a) => C (Ptr a) where
   -- Do we also have to convert the pointer target type?
   type Repr f (Ptr a) = f (Ptr a)
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance (LLVM.IsFunction a) => C (FunPtr a) where
   type Repr f (FunPtr a) = f (FunPtr a)
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive

instance C (StablePtr a) where
   type Repr f (StablePtr a) = f (StablePtr a)
   cons = consPrimitive
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive


-- * Helpers for types represented by a single LLVM value

-- | 'cons' for single-value representations: embeds the constant
-- via 'LLVM.valueOf'.
consPrimitive ::
   (LLVM.IsConst al, LLVM.Value al ~ Repr LLVM.Value a) =>
   al -> T a
consPrimitive = Cons . LLVM.valueOf

-- | 'undef' and 'zero' for single-value representations,
-- built from 'LLVM.undef' and 'LLVM.zero'.
undefPrimitive, zeroPrimitive ::
   (LLVM.IsType al, LLVM.Value al ~ Repr LLVM.Value a) =>
   T a
undefPrimitive = Cons $ LLVM.value LLVM.undef
zeroPrimitive = Cons $ LLVM.value LLVM.zero

-- | 'phis' for single-value representations; delegates to 'Loop.phis'.
phisPrimitive ::
   (LLVM.IsFirstClass al, LLVM.Value al ~ Repr LLVM.Value a) =>
   LLVM.BasicBlock -> T a -> LLVM.CodeGenFunction r (T a)
phisPrimitive bb (Cons a) = fmap Cons $ Loop.phis bb a

-- | 'addPhis' for single-value representations; delegates to 'Loop.addPhis'.
addPhisPrimitive ::
   (LLVM.IsFirstClass al, LLVM.Value al ~ Repr LLVM.Value a) =>
   LLVM.BasicBlock -> T a -> T a -> LLVM.CodeGenFunction r ()
addPhisPrimitive bb (Cons a) (Cons b) = Loop.addPhis bb a b


-- * Unit: represented by @()@, so no LLVM code is emitted

instance C () where
   type Repr f () = ()
   cons = consUnit
   undef = undefUnit
   zero = zeroUnit
   phis = phisUnit
   addPhis = addPhisUnit

consUnit :: (Repr LLVM.Value a ~ ()) => a -> T a
consUnit _ = Cons ()

undefUnit :: (Repr LLVM.Value a ~ ()) => T a
undefUnit = Cons ()

zeroUnit :: (Repr LLVM.Value a ~ ()) => T a
zeroUnit = Cons ()

phisUnit ::
   (Repr LLVM.Value a ~ ()) =>
   LLVM.BasicBlock -> T a -> LLVM.CodeGenFunction r (T a)
phisUnit _bb (Cons ()) = return $ Cons ()

addPhisUnit ::
   (Repr LLVM.Value a ~ ()) =>
   LLVM.BasicBlock -> T a -> T a -> LLVM.CodeGenFunction r ()
addPhisUnit _bb (Cons ()) (Cons ()) = return ()


-- * Tuples: represented component-wise

instance (C a, C b) => C (a,b) where
   type Repr f (a, b) = (Repr f a, Repr f b)
   cons (a,b) = zip (cons a) (cons b)
   undef = zip undef undef
   zero = zip zero zero
   phis bb a =
      case unzip a of
         (a0,a1) ->
            Monad.lift2 zip (phis bb a0) (phis bb a1)
   addPhis bb a b =
      case (unzip a, unzip b) of
         ((a0,a1), (b0,b1)) ->
            addPhis bb a0 b0 >>
            addPhis bb a1 b1

instance (C a, C b, C c) => C (a,b,c) where
   type Repr f (a, b, c) = (Repr f a, Repr f b, Repr f c)
   cons (a,b,c) = zip3 (cons a) (cons b) (cons c)
   undef = zip3 undef undef undef
   zero = zip3 zero zero zero
   phis bb a =
      case unzip3 a of
         (a0,a1,a2) ->
            Monad.lift3 zip3 (phis bb a0) (phis bb a1) (phis bb a2)
   addPhis bb a b =
      case (unzip3 a, unzip3 b) of
         ((a0,a1,a2), (b0,b1,b2)) ->
            addPhis bb a0 b0 >>
            addPhis bb a1 b1 >>
            addPhis bb a2 b2

instance (C a, C b, C c, C d) => C (a,b,c,d) where
   type Repr f (a, b, c, d) = (Repr f a, Repr f b, Repr f c, Repr f d)
   cons (a,b,c,d) = zip4 (cons a) (cons b) (cons c) (cons d)
   undef = zip4 undef undef undef undef
   zero = zip4 zero zero zero zero
   phis bb a =
      case unzip4 a of
         (a0,a1,a2,a3) ->
            Monad.lift4 zip4
               (phis bb a0) (phis bb a1) (phis bb a2) (phis bb a3)
   addPhis bb a b =
      case (unzip4 a, unzip4 b) of
         ((a0,a1,a2,a3), (b0,b1,b2,b3)) ->
            addPhis bb a0 b0 >>
            addPhis bb a1 b1 >>
            addPhis bb a2 b2 >>
            addPhis bb a3 b3


-- * Pair and triple accessors (pure; no LLVM code emitted)

-- | First component of a pair multi-value.
fst :: T (a,b) -> T a
fst (Cons (a,_b)) = Cons a

-- | Second component of a pair multi-value.
snd :: T (a,b) -> T b
snd (Cons (_a,b)) = Cons b

-- | Turn a function on a pair multi-value into a two-argument function.
curry :: (T (a,b) -> c) -> (T a -> T b -> c)
curry f a b = f $ zip a b

-- | Turn a two-argument function into a function on a pair multi-value.
uncurry :: (T a -> T b -> c) -> (T (a,b) -> c)
uncurry f = Tuple.uncurry f . unzip

-- | Map over the first component of a pair multi-value.
mapFst :: (T a0 -> T a1) -> T (a0,b) -> T (a1,b)
mapFst f = Tuple.uncurry zip . TupleHT.mapFst f . unzip

-- | Map over the second component of a pair multi-value.
mapSnd :: (T b0 -> T b1) -> T (a,b0) -> T (a,b1)
mapSnd f = Tuple.uncurry zip . TupleHT.mapSnd f . unzip

-- | Exchange the components of a pair multi-value.
swap :: T (a,b) -> T (b,a)
swap = Tuple.uncurry zip . TupleHT.swap . unzip

-- | First component of a triple multi-value.
fst3 :: T (a,b,c) -> T a
fst3 (Cons (a,_b,_c)) = Cons a

-- | Second component of a triple multi-value.
snd3 :: T (a,b,c) -> T b
snd3 (Cons (_a,b,_c)) = Cons b

-- | Third component of a triple multi-value.
thd3 :: T (a,b,c) -> T c
thd3 (Cons (_a,_b,c)) = Cons c

-- | Map over the first component of a triple multi-value.
mapFst3 :: (T a0 -> T a1) -> T (a0,b,c) -> T (a1,b,c)
mapFst3 f = uncurry3 zip3 . TupleHT.mapFst3 f . unzip3

-- | Map over the second component of a triple multi-value.
mapSnd3 :: (T b0 -> T b1) -> T (a,b0,c) -> T (a,b1,c)
mapSnd3 f = uncurry3 zip3 . TupleHT.mapSnd3 f . unzip3

-- | Map over the third component of a triple multi-value.
mapThd3 :: (T c0 -> T c1) -> T (a,b,c0) -> T (a,b,c1)
mapThd3 f = uncurry3 zip3 . TupleHT.mapThd3 f . unzip3


-- | Combine component multi-values into a tuple multi-value.
zip :: T a -> T b -> T (a,b)
zip (Cons a) (Cons b) = Cons (a,b)

zip3 :: T a -> T b -> T c -> T (a,b,c)
zip3 (Cons a) (Cons b) (Cons c) = Cons (a,b,c)

zip4 :: T a -> T b -> T c -> T d -> T (a,b,c,d)
zip4 (Cons a) (Cons b) (Cons c) (Cons d) = Cons (a,b,c,d)

-- | Split a tuple multi-value into its component multi-values.
unzip :: T (a,b) -> (T a, T b)
unzip (Cons (a,b)) = (Cons a, Cons b)

unzip3 :: T (a,b,c) -> (T a, T b, T c)
unzip3 (Cons (a,b,c)) = (Cons a, Cons b, Cons c)

unzip4 :: T (a,b,c,d) -> (T a, T b, T c, T d)
unzip4 (Cons (a,b,c,d)) = (Cons a, Cons b, Cons c, Cons d)


-- * Nested composition and decomposition

-- | Haskell tuples of multi-values that can be zipped into one multi-value.
class Compose multituple where
   type Composed multituple
   {- |
   A nested 'zip'.
   -}
   compose :: multituple -> T (Composed multituple)

-- | Patterns describing how deep to split a tuple multi-value.
class
   (Composed (Decomposed T pattern) ~ PatternTuple pattern) =>
      Decompose pattern where
   {- |
   A nested 'unzip'.
   Since it is not obvious how deep to decompose nested tuples,
   you must provide a pattern of the decomposed tuple.
   E.g.

   > f :: MultiValue ((a,b),(c,d)) ->
   >      ((MultiValue a, MultiValue b), MultiValue (c,d))
   > f = decompose ((atom,atom),atom)
   -}
   decompose :: pattern -> T (PatternTuple pattern) -> Decomposed T pattern

-- | The Haskell tuple that a pattern decomposes into, with leaves wrapped in @f@.
type family Decomposed (f :: * -> *) pattern
-- | The tuple type that a pattern matches.
type family PatternTuple pattern

{- |
A combination of 'compose' and 'decompose'
that lets you operate on tuple multivalues as Haskell tuples.
-}
modify ::
   (Compose a, Decompose pattern) =>
   pattern ->
   (Decomposed T pattern -> a) ->
   T (PatternTuple pattern) -> T (Composed a)
modify p f = compose . f . decompose p

-- | Like 'modify', but for a function of two decomposed arguments.
modify2 ::
   (Compose a, Decompose patternA, Decompose patternB) =>
   patternA ->
   patternB ->
   (Decomposed T patternA -> Decomposed T patternB -> a) ->
   T (PatternTuple patternA) ->
   T (PatternTuple patternB) ->
   T (Composed a)
modify2 pa pb f a b = compose $ f (decompose pa a) (decompose pb b)

-- | Like 'modify', but the function returns its result inside a functor.
modifyF ::
   (Compose a, Decompose pattern, Functor f) =>
   pattern ->
   (Decomposed T pattern -> f a) ->
   T (PatternTuple pattern) -> f (T (Composed a))
modifyF p f = fmap compose . f . decompose p

-- | Like 'modify2', but the function returns its result inside a functor.
modifyF2 ::
   (Compose a, Decompose patternA, Decompose patternB, Functor f) =>
   patternA ->
   patternB ->
   (Decomposed T patternA -> Decomposed T patternB -> f a) ->
   T (PatternTuple patternA) ->
   T (PatternTuple patternB) ->
   f (T (Composed a))
modifyF2 pa pb f a b = fmap compose $ f (decompose pa a) (decompose pb b)


instance Compose (T a) where
   type Composed (T a) = a
   compose = id

instance Decompose (Atom a) where
   decompose _ = id

type instance Decomposed f (Atom a) = f a
type instance PatternTuple (Atom a) = a

-- | Pattern leaf: stop decomposing at this position.
data Atom a = Atom

-- | The pattern leaf, for use with 'decompose' and the @modify*@ functions.
atom :: Atom a
atom = Atom

instance (Compose a, Compose b) => Compose (a,b) where
   type Composed (a,b) = (Composed a, Composed b)
   compose = Tuple.uncurry zip . TupleHT.mapPair (compose, compose)

instance (Decompose pa, Decompose pb) => Decompose (pa,pb) where
   decompose (pa,pb) =
      TupleHT.mapPair (decompose pa, decompose pb) . unzip

type instance Decomposed f (pa,pb) = (Decomposed f pa, Decomposed f pb)
type instance PatternTuple (pa,pb) = (PatternTuple pa, PatternTuple pb)

instance (Compose a, Compose b, Compose c) => Compose (a,b,c) where
   type Composed (a,b,c) = (Composed a, Composed b, Composed c)
   compose =
      uncurry3 zip3 . TupleHT.mapTriple (compose, compose, compose)

instance
   (Decompose pa, Decompose pb, Decompose pc) =>
      Decompose (pa,pb,pc) where
   decompose (pa,pb,pc) =
      TupleHT.mapTriple (decompose pa, decompose pb, decompose pc) . unzip3

type instance Decomposed f (pa,pb,pc) =
   (Decomposed f pa, Decomposed f pb, Decomposed f pc)
type instance PatternTuple (pa,pb,pc) =
   (PatternTuple pa, PatternTuple pb, PatternTuple pc)

instance
   (Compose a, Compose b, Compose c, Compose d) =>
      Compose (a,b,c,d) where
   type Composed (a,b,c,d) = (Composed a, Composed b, Composed c, Composed d)
   compose (a,b,c,d) =
      zip4 (compose a) (compose b) (compose c) (compose d)

instance
   (Decompose pa, Decompose pb, Decompose pc, Decompose pd) =>
      Decompose (pa,pb,pc,pd) where
   decompose (pa,pb,pc,pd) x =
      case unzip4 x of
         (a,b,c,d) ->
            (decompose pa a, decompose pb b,
             decompose pc c, decompose pd d)

type instance Decomposed f (pa,pb,pc,pd) =
   (Decomposed f pa, Decomposed f pb, Decomposed f pc, Decomposed f pd)
type instance PatternTuple (pa,pb,pc,pd) =
   (PatternTuple pa, PatternTuple pb, PatternTuple pc, PatternTuple pd)


-- * Lifting functions on representations to multi-values

-- | Apply a pure function to the underlying representation.
lift1 :: (Repr LLVM.Value a -> Repr LLVM.Value b) -> T a -> T b
lift1 f (Cons a) = Cons $ f a

-- | Wrap the representation produced by a monadic action.
liftM0 ::
   (Monad m) =>
   m (Repr LLVM.Value a) ->
   m (T a)
liftM0 f = Monad.lift Cons f

-- | Lift a monadic unary operation on representations.
liftM ::
   (Monad m) =>
   (Repr LLVM.Value a -> m (Repr LLVM.Value b)) ->
   T a -> m (T b)
liftM f (Cons a) = Monad.lift Cons $ f a

-- | Lift a monadic binary operation on representations.
liftM2 ::
   (Monad m) =>
   (Repr LLVM.Value a -> Repr LLVM.Value b -> m (Repr LLVM.Value c)) ->
   T a -> T b -> m (T c)
liftM2 f (Cons a) (Cons b) = Monad.lift Cons $ f a b

-- | Lift a monadic ternary operation on representations.
liftM3 ::
   (Monad m) =>
   (Repr LLVM.Value a -> Repr LLVM.Value b -> Repr LLVM.Value c ->
    m (Repr LLVM.Value d)) ->
   T a -> T b -> T c -> m (T d)
liftM3 f (Cons a) (Cons b) (Cons c) = Monad.lift Cons $ f a b c


-- * Instances of classes from other llvm-extra modules

instance (C a) => Class.Zero (T a) where
   zeroTuple = zero

instance (C a) => Class.Undefined (T a) where
   undefTuple = undef

instance (C a) => Phi (T a) where
   phis = phis
   addPhis = addPhis


-- * Numeric constants

-- | Multi-values that can be built from an 'Integer' constant.
class (C a) => IntegerConstant a where
   fromInteger' :: Integer -> T a

-- | Multi-values that can be built from a 'Rational' constant.
class (IntegerConstant a) => RationalConstant a where
   fromRational' :: Rational -> T a

instance IntegerConstant Float where
   fromInteger' = Cons . LLVM.value . SoV.constFromInteger

instance IntegerConstant Double where
   fromInteger' = Cons . LLVM.value . SoV.constFromInteger

instance RationalConstant Float where
   fromRational' = Cons . LLVM.value . SoV.constFromRational

instance RationalConstant Double where
   fromRational' = Cons . LLVM.value . SoV.constFromRational

instance (IntegerConstant a) => A.IntegerConstant (T a) where
   fromInteger' = fromInteger'

instance (RationalConstant a) => A.RationalConstant (T a) where
   fromRational' = fromRational'


-- * Arithmetic

-- | Addition, subtraction and negation of multi-values.
class (C a) => Additive a where
   add :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   sub :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   neg :: T a -> LLVM.CodeGenFunction r (T a)

instance Additive Float where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance Additive Double where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance Additive Word32 where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance Additive Word64 where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance Additive Int32 where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance Additive Int64 where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance (Additive a) => A.Additive (T a) where
   zero = zero
   add = add
   sub = sub
   neg = neg


-- | Multiplication of multi-values.
class (Additive a) => PseudoRing a where
   mul :: T a -> T a -> LLVM.CodeGenFunction r (T a)

instance PseudoRing Float where
   mul = liftM2 LLVM.mul

instance PseudoRing Double where
   mul = liftM2 LLVM.mul

instance PseudoRing Word32 where
   mul = liftM2 LLVM.mul

instance PseudoRing Word64 where
   mul = liftM2 LLVM.mul

instance PseudoRing Int32 where
   mul = liftM2 LLVM.mul

instance PseudoRing Int64 where
   mul = liftM2 LLVM.mul

instance (PseudoRing a) => A.PseudoRing (T a) where
   mul = mul


-- | Division of multi-values (via 'LLVM.fdiv').
class (PseudoRing a) => Field a where
   fdiv :: T a -> T a -> LLVM.CodeGenFunction r (T a)

instance Field Float where
   fdiv = liftM2 LLVM.fdiv

instance Field Double where
   fdiv = liftM2 LLVM.fdiv

instance (Field a) => A.Field (T a) where
   fdiv = fdiv


-- | Scalar type associated with a (possibly vector) element type.
type family Scalar vector :: *
type instance Scalar Float = Float
type instance Scalar Double = Double
type instance A.Scalar (T a) = T (Scalar a)

-- | Scaling of a multi-value by a scalar multi-value.
class (PseudoRing (Scalar v), Additive v) => PseudoModule v where
   scale :: T (Scalar v) -> T v -> LLVM.CodeGenFunction r (T v)

instance PseudoModule Float where
   scale = liftM2 A.mul

instance PseudoModule Double where
   scale = liftM2 A.mul

instance (PseudoModule a) => A.PseudoModule (T a) where
   scale = scale


-- | Ordering-based operations, delegating to "LLVM.Extra.Arithmetic".
class (Additive a) => Real a where
   min :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   max :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   abs :: T a -> LLVM.CodeGenFunction r (T a)
   signum :: T a -> LLVM.CodeGenFunction r (T a)

instance Real Float where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance Real Double where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance Real Word32 where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance Real Word64 where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance Real Int32 where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance Real Int64 where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance (Real a) => A.Real (T a) where
   min = min
   max = max
   abs = abs
   signum = signum


-- | Splitting into integral and fractional parts, delegating to
-- "LLVM.Extra.Arithmetic".
class (Real a) => Fraction a where
   truncate :: T a -> LLVM.CodeGenFunction r (T a)
   fraction :: T a -> LLVM.CodeGenFunction r (T a)

instance Fraction Float where
   truncate = liftM A.truncate
   fraction = liftM A.fraction

instance Fraction Double where
   truncate = liftM A.truncate
   fraction = liftM A.fraction

instance (Fraction a) => A.Fraction (T a) where
   truncate = truncate
   fraction = fraction


-- | Square root, delegating to "LLVM.Extra.Arithmetic".
class Field a => Algebraic a where
   sqrt :: T a -> LLVM.CodeGenFunction r (T a)

instance Algebraic Float where
   sqrt = liftM A.sqrt

instance Algebraic Double where
   sqrt = liftM A.sqrt

instance (Algebraic a) => A.Algebraic (T a) where
   sqrt = sqrt


-- | Transcendental functions, delegating to "LLVM.Extra.Arithmetic".
class Algebraic a => Transcendental a where
   pi :: LLVM.CodeGenFunction r (T a)
   sin, cos, exp, log :: T a -> LLVM.CodeGenFunction r (T a)
   pow :: T a -> T a -> LLVM.CodeGenFunction r (T a)

instance Transcendental Float where
   pi = liftM0 A.pi
   sin = liftM A.sin
   cos = liftM A.cos
   exp = liftM A.exp
   log = liftM A.log
   pow = liftM2 A.pow

instance Transcendental Double where
   pi = liftM0 A.pi
   sin = liftM A.sin
   cos = liftM A.cos
   exp = liftM A.exp
   log = liftM A.log
   pow = liftM2 A.pow

instance (Transcendental a) => A.Transcendental (T a) where
   pi = pi
   sin = sin
   cos = cos
   exp = exp
   log = log
   pow = pow


-- * Selection and comparison

-- | Choose between two multi-values according to a 'Bool' multi-value.
-- Tuples are selected component-wise.
class (C a) => Select a where
   select :: T Bool -> T a -> T a -> LLVM.CodeGenFunction r (T a)

instance Select Float where
   select = liftM3 LLVM.select

instance Select Double where
   select = liftM3 LLVM.select

instance Select Word32 where
   select = liftM3 LLVM.select

instance Select Word64 where
   select = liftM3 LLVM.select

instance Select Int32 where
   select = liftM3 LLVM.select

instance Select Int64 where
   select = liftM3 LLVM.select

instance (Select a, Select b) => Select (a,b) where
   select b =
      modifyF2 (atom,atom) (atom,atom) $ \(a0,b0) (a1,b1) ->
         Monad.lift2 (,)
            (select b a0 a1)
            (select b b0 b1)

instance (Select a, Select b, Select c) => Select (a,b,c) where
   select b =
      modifyF2 (atom,atom,atom) (atom,atom,atom) $
         \(a0,b0,c0) (a1,b1,c1) ->
            Monad.lift3 (,,)
               (select b a0 a1)
               (select b b0 b1)
               (select b c0 c1)

instance (Select a) => C.Select (T a) where
   select b = select (Cons b)


-- | Comparison of multi-values, producing a 'Bool' multi-value.
class (Real a) => Comparison a where
   {- |
   It must hold

   > max x y  ==  do gt <- cmp CmpGT x y; select gt x y
   -}
   cmp :: LLVM.CmpPredicate -> T a -> T a -> LLVM.CodeGenFunction r (T Bool)

instance Comparison Float where
   cmp = liftM2 . LLVM.cmp

instance Comparison Double where
   cmp = liftM2 . LLVM.cmp

instance (Comparison a) => A.Comparison (T a) where
   type CmpResult (T a) = T Bool
   cmp = cmp


-- | Floating-point comparison with an explicit 'LLVM.FPPredicate'.
class (Comparison a) => FloatingComparison a where
   fcmp :: LLVM.FPPredicate -> T a -> T a -> LLVM.CodeGenFunction r (T Bool)

instance FloatingComparison Float where
   fcmp = liftM2 . LLVM.fcmp

-- Previously missing: every other class with a 'Float' instance in this
-- module also has a 'Double' instance, and 'Comparison' 'Double'
-- (the superclass) is already provided above.
instance FloatingComparison Double where
   fcmp = liftM2 . LLVM.fcmp

instance (FloatingComparison a) => A.FloatingComparison (T a) where
   fcmp = fcmp


-- * Bitwise logic

-- | Bitwise/boolean operations on multi-values.
class Logic a where
   and :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   or :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   xor :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   inv :: T a -> LLVM.CodeGenFunction r (T a)

instance Logic Bool where
   and = liftM2 LLVM.and
   or = liftM2 LLVM.or
   xor = liftM2 LLVM.xor
   inv = liftM LLVM.inv

instance Logic a => A.Logic (T a) where
   and = and
   or = or
   xor = xor
   inv = inv


-- * Integer division

-- | Integer quotient and remainder (via 'LLVM.idiv' and 'LLVM.irem').
class (PseudoRing a) => Integral a where
   idiv :: T a -> T a -> LLVM.CodeGenFunction r (T a)
   irem :: T a -> T a -> LLVM.CodeGenFunction r (T a)

instance Integral Word32 where
   idiv = liftM2 LLVM.idiv
   irem = liftM2 LLVM.irem

instance Integral Word64 where
   idiv = liftM2 LLVM.idiv
   irem = liftM2 LLVM.irem

instance Integral Int32 where
   idiv = liftM2 LLVM.idiv
   irem = liftM2 LLVM.irem

instance Integral Int64 where
   idiv = liftM2 LLVM.idiv
   irem = liftM2 LLVM.irem