{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Vectors of multi-values.

A vector whose elements are tuples is represented as a tuple of vectors,
one per component (see 'zip' and 'unzip').
For primitive element types the representation is a single LLVM vector
wrapped in 'Value' (see 'consPrim' and 'deconsPrim').
-}
module LLVM.Extra.Multi.Vector (
   T(Cons), consPrim, deconsPrim,
   C(..),
   Value(Value),
   map, zip, zip3, unzip, unzip3,
   replicate, iterate,
   lift1,
   modify, assemble, dissect, dissectList,
   reverse, rotateUp, rotateDown,
   shiftUp, shiftDown,
   shiftUpMultiZero, shiftDownMultiZero,

   undefPrimitive,
   shuffleMatchPrimitive,
   extractPrimitive,
   insertPrimitive,

   shuffleMatchTraversable,
   insertTraversable,
   extractTraversable,

   Additive(..),
   PseudoRing(..),
   Field(..),
   PseudoModule(..),
   Real(..),
   Fraction(..),
   Algebraic(..),
   Transcendental(..),
   FloatingComparison(..),
   Comparison(..),
   Logic(..),
   ) where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Class as Class
import LLVM.Extra.Multi.Value (Repr, )

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core (valueOf, value, IsPrimitive, CodeGenFunction, )

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Data.Traversable as Trav
import qualified Data.NonEmpty as NonEmpty
import qualified Data.List as List
import Data.Traversable (mapM, sequence, )
import Data.NonEmpty ((!:), )
import Data.Function (flip, (.), ($), )
import Data.Tuple (snd, )
import Data.Maybe (maybe, )
import Data.List (take, (++), )
import Data.Word (Word32, )
import Data.Bool (Bool, )

import qualified Control.Applicative as App
import qualified Control.Monad.HT as Monad
import Control.Monad.HT ((<=<), )
import Control.Monad (Monad, foldM, fmap, (>>), (=<<), )
import Control.Applicative (liftA2, )

import Prelude (Float, Double, Integer, Int, Rational, fromIntegral, (-), error, )


-- | A vector of @n@ elements of type @a@,
-- stored in the representation that 'Repr' assigns to @a@ for @'Value' n@.
newtype T n a = Cons (Repr (Value n) a)

-- | A plain LLVM vector of @n@ primitive elements.
newtype Value n a = Value (PrimValue n a)


-- | Wrap an LLVM vector of a primitive element type.
consPrim ::
   (Repr (Value n) a ~ Value n a) =>
   LLVM.Value (LLVM.Vector n a) -> T n a
consPrim vec = Cons (Value vec)

-- | Unwrap the LLVM vector of a primitive element type.
-- Inverse of 'consPrim'.
deconsPrim ::
   (Repr (Value n) a ~ Value n a) =>
   T n a -> LLVM.Value (LLVM.Vector n a)
deconsPrim (Cons (Value vec)) = vec


instance (TypeNum.Positive n, C a) => Class.Undefined (T n a) where
   undefTuple = undef

instance (TypeNum.Positive n, C a) => Class.Zero (T n a) where
   zeroTuple = zero

-- The right-hand sides refer to the methods of class 'C' below;
-- the methods of 'Phi' are only imported qualified (as @Loop.*@).
instance (TypeNum.Positive n, C a) => Phi (T n a) where
   phis = phis
   addPhis = addPhis


-- | Number of elements, taken from the type-level size @n@.
-- The vector argument itself is never inspected.
size :: TypeNum.Positive n => T n a -> Int
size = fromSingleton TypeNum.singleton
   where
      fromSingleton ::
         TypeNum.Positive m => TypeNum.Singleton m -> T m b -> Int
      fromSingleton sing _vec = TypeNum.integralFromSingleton sing


-- | Combine two vectors into a vector of pairs (pure, emits no code).
zip :: T n a -> T n b -> T n (a,b)
zip (Cons xs) (Cons ys) = Cons (xs, ys)

-- | Combine three vectors into a vector of triples (pure, emits no code).
zip3 :: T n a -> T n b -> T n c -> T n (a,b,c)
zip3 (Cons xs) (Cons ys) (Cons zs) = Cons (xs, ys, zs)

-- | Split a vector of pairs into its two component vectors.
unzip :: T n (a,b) -> (T n a, T n b)
unzip (Cons pair) =
   case pair of
      (xs, ys) -> (Cons xs, Cons ys)

-- | Split a vector of triples into its three component vectors.
unzip3 :: T n (a,b,c) -> (T n a, T n b, T n c)
unzip3 (Cons triple) =
   case triple of
      (xs, ys, zs) -> (Cons xs, Cons ys, Cons zs)


-- | Element types that can be stored in a vector 'T'.
class (MultiValue.C a) => C a where
   -- | Vector whose elements are all undefined.
   undef :: (TypeNum.Positive n) => T n a
   -- | Vector whose elements are all zero.
   zero :: (TypeNum.Positive n) => T n a
   -- | Create phi node(s) for the vector in the given basic block
   -- (forwarded to 'Loop.phis' for primitive element types).
   phis ::
      (TypeNum.Positive n) =>
      LLVM.BasicBlock -> T n a -> LLVM.CodeGenFunction r (T n a)
   -- | Forwarded to 'Loop.addPhis' for primitive element types
   -- and component-wise for tuples.
   addPhis ::
      (TypeNum.Positive n) =>
      LLVM.BasicBlock -> T n a -> T n a -> LLVM.CodeGenFunction r ()
   -- | Permute the elements according to a constant index vector.
   -- For tuples every component vector is shuffled with the same mask.
   shuffleMatch ::
      (TypeNum.Positive n) =>
      LLVM.ConstValue (LLVM.Vector n Word32) -> T n a ->
      CodeGenFunction r (T n a)
   -- | Read the element at a run-time index.
   extract ::
      (TypeNum.Positive n) =>
      LLVM.Value Word32 -> T n a ->
      CodeGenFunction r (MultiValue.T a)
   -- | Replace the element at a run-time index.
   insert ::
      (TypeNum.Positive n) =>
      LLVM.Value Word32 -> MultiValue.T a ->
      T n a -> CodeGenFunction r (T n a)

instance C Bool where
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive
   shuffleMatch = shuffleMatchPrimitive
   extract = extractPrimitive
   insert = insertPrimitive

instance C Float where
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive
   shuffleMatch = shuffleMatchPrimitive
   extract = extractPrimitive
   insert = insertPrimitive

instance C Double where
   undef = undefPrimitive
   zero = zeroPrimitive
   phis = phisPrimitive
   addPhis = addPhisPrimitive
   shuffleMatch = shuffleMatchPrimitive
   extract = extractPrimitive
   insert = insertPrimitive


-- | 'undef' for primitive element types: an undefined LLVM vector.
undefPrimitive ::
   (TypeNum.Positive n, IsPrimitive a, Repr (Value n) a ~ Value n a) =>
   T n a
undefPrimitive = consPrim (LLVM.value LLVM.undef)

-- | 'zero' for primitive element types: an all-zero LLVM vector.
zeroPrimitive ::
   (TypeNum.Positive n, IsPrimitive a, Repr (Value n) a ~ Value n a) =>
   T n a
zeroPrimitive = consPrim (LLVM.value LLVM.zero)

-- | 'phis' for primitive element types.
phisPrimitive ::
   (TypeNum.Positive n, IsPrimitive a, Repr (Value n) a ~ Value n a) =>
   LLVM.BasicBlock -> T n a -> LLVM.CodeGenFunction r (T n a)
phisPrimitive block vec = fmap consPrim (Loop.phis block (deconsPrim vec))

-- | 'addPhis' for primitive element types.
addPhisPrimitive ::
   (TypeNum.Positive n, IsPrimitive a, Repr (Value n) a ~ Value n a) =>
   LLVM.BasicBlock -> T n a -> T n a -> LLVM.CodeGenFunction r ()
addPhisPrimitive block phi new =
   Loop.addPhis block (deconsPrim phi) (deconsPrim new)

-- | 'shuffleMatch' for primitive element types:
-- one @shufflevector@ with an undefined second operand.
shuffleMatchPrimitive ::
   (TypeNum.Positive n, IsPrimitive a,
    Repr LLVM.Value a ~ LLVM.Value a,
    Repr (Value n) a ~ Value n a) =>
   LLVM.ConstValue (LLVM.Vector n Word32) -> T n a ->
   CodeGenFunction r (T n a)
shuffleMatchPrimitive mask vec =
   fmap consPrim
      (LLVM.shufflevector (deconsPrim vec) (value LLVM.undef) mask)

-- | 'extract' for primitive element types: one @extractelement@.
extractPrimitive ::
   (TypeNum.Positive n, IsPrimitive a,
    Repr LLVM.Value a ~ LLVM.Value a,
    Repr (Value n) a ~ Value n a) =>
   LLVM.Value Word32 -> T n a ->
   CodeGenFunction r (MultiValue.T a)
extractPrimitive k vec =
   fmap MultiValue.Cons (LLVM.extractelement (deconsPrim vec) k)

-- | 'insert' for primitive element types: one @insertelement@.
insertPrimitive ::
   (TypeNum.Positive n, IsPrimitive a,
    -- this constraint is accepted, but does not help
    -- Repr f a ~ f a,
    Repr LLVM.Value a ~ LLVM.Value a,
    Repr (Value n) a ~ Value n a) =>
   LLVM.Value Word32 -> MultiValue.T a ->
   T n a -> CodeGenFunction r (T n a)
insertPrimitive k (MultiValue.Cons x) vec =
   fmap consPrim (LLVM.insertelement (deconsPrim vec) x k)


-- Pairs and triples are handled component-wise.
instance (C a, C b) => C (a,b) where
   undef = zip undef undef
   zero = zip zero zero
   phis block v =
      case unzip v of
         (va, vb) -> Monad.lift2 zip (phis block va) (phis block vb)
   addPhis block v w =
      case (unzip v, unzip w) of
         ((va, vb), (wa, wb)) ->
            addPhis block va wa >> addPhis block vb wb
   shuffleMatch mask v =
      case unzip v of
         (va, vb) ->
            Monad.lift2 zip (shuffleMatch mask va) (shuffleMatch mask vb)
   extract k v =
      case unzip v of
         (va, vb) ->
            Monad.lift2 MultiValue.zip (extract k va) (extract k vb)
   insert k x v =
      case (MultiValue.unzip x, unzip v) of
         ((xa, xb), (va, vb)) ->
            Monad.lift2 zip (insert k xa va) (insert k xb vb)

instance (C a, C b, C c) => C (a,b,c) where
   undef = zip3 undef undef undef
   zero = zip3 zero zero zero
   phis block v =
      case unzip3 v of
         (va, vb, vc) ->
            Monad.lift3 zip3
               (phis block va) (phis block vb) (phis block vc)
   addPhis block v w =
      case (unzip3 v, unzip3 w) of
         ((va, vb, vc), (wa, wb, wc)) ->
            addPhis block va wa >>
            addPhis block vb wb >>
            addPhis block vc wc
   shuffleMatch mask v =
      case unzip3 v of
         (va, vb, vc) ->
            Monad.lift3 zip3
               (shuffleMatch mask va)
               (shuffleMatch mask vb)
               (shuffleMatch mask vc)
   extract k v =
      case unzip3 v of
         (va, vb, vc) ->
            Monad.lift3 MultiValue.zip3
               (extract k va) (extract k vb) (extract k vc)
   insert k x v =
      case (MultiValue.unzip3 x, unzip3 v) of
         ((xa, xb, xc), (va, vb, vc)) ->
            Monad.lift3 zip3
               (insert k xa va) (insert k xb vb) (insert k xc vc)


-- | Element types whose vectors can be built from an 'Integer' constant.
class (C a) => IntegerConstant a where
   fromInteger' :: (TypeNum.Positive n) => Integer -> T n a

-- | Element types whose vectors can be built from a 'Rational' constant.
class (IntegerConstant a) => RationalConstant a where
   fromRational' :: (TypeNum.Positive n) => Rational -> T n a

instance IntegerConstant Float where
   fromInteger' i = Cons (Value (LLVM.value (SoV.constFromInteger i)))

instance IntegerConstant Double where
   fromInteger' i = Cons (Value (LLVM.value (SoV.constFromInteger i)))

instance RationalConstant Float where
   fromRational' q = Cons (Value (LLVM.value (SoV.constFromRational q)))

instance RationalConstant Double where
   fromRational' q = Cons (Value (LLVM.value (SoV.constFromRational q)))

instance
   (TypeNum.Positive n, IntegerConstant a) =>
      A.IntegerConstant (T n a) where
   fromInteger' = fromInteger'

instance
   (TypeNum.Positive n, RationalConstant a) =>
      A.RationalConstant (T n a) where
   fromRational' = fromRational'


-- | Apply a code-generating function to the element at index @k@
-- and write the result back to the same index.
modify ::
   (TypeNum.Positive n, C a) =>
   LLVM.Value Word32 ->
   (MultiValue.T a -> CodeGenFunction r (MultiValue.T a)) ->
   (T n a -> CodeGenFunction r (T n a))
modify k f v = do
   old <- extract k v
   new <- f old
   insert k new v

{- |
Build a vector from a list: list element @i@ is inserted at index @i@,
starting from 'undef', so indices not covered by the list stay undefined.

NOTE(review): callers should pass at most 'size' elements;
surplus elements are inserted at out-of-range indices,
for which LLVM's @insertelement@ yields a poison value.
-}
assemble ::
   (TypeNum.Positive n, C a) =>
   [MultiValue.T a] -> CodeGenFunction r (T n a)
assemble elems = foldM put undef (List.zip [0 ..] elems)
   where
      put acc (k, x) = insert (valueOf k) x acc

-- | Extract all elements in index order.
dissect ::
   (TypeNum.Positive n, C a) =>
   T n a -> LLVM.CodeGenFunction r [MultiValue.T a]
dissect vec = sequence (dissectList vec)

-- | One extraction action per element, in index order, not yet run.
dissectList ::
   (TypeNum.Positive n, C a) =>
   T n a -> [LLVM.CodeGenFunction r (MultiValue.T a)]
dissectList vec =
   List.map (\k -> extract (LLVM.valueOf k) vec) (take (size vec) [0 ..])

-- | Apply a function to every element,
-- by extracting all elements and re-assembling the results.
map ::
   (TypeNum.Positive n, C a, C b) =>
   (MultiValue.T a -> CodeGenFunction r (MultiValue.T b)) ->
   (T n a -> CodeGenFunction r (T n b))
map f vec = (assemble <=< mapM f) =<< dissect vec

-- | Vector with every element set to the given value
-- (one insertion per element).
replicate ::
   (TypeNum.Positive n, C a) =>
   MultiValue.T a -> CodeGenFunction r (T n a)
replicate = replicateCore TypeNum.singleton

replicateCore ::
   (TypeNum.Positive n, C a) =>
   TypeNum.Singleton n -> MultiValue.T a -> CodeGenFunction r (T n a)
replicateCore n x =
   assemble (List.replicate (TypeNum.integralFromSingleton n) x)

{- |
@iterate f x@ yields the vector @[x, f x, f (f x), ...]@.

Note that @f@ is applied once per element, i.e. @n@ times;
the result of the last application is discarded.
-}
iterate ::
   (TypeNum.Positive n, C a) =>
   (MultiValue.T a -> CodeGenFunction r (MultiValue.T a)) ->
   MultiValue.T a -> CodeGenFunction r (T n a)
iterate f x = fmap snd (iterateCore f x Class.undefTuple)

-- | Fill each index of the given vector in turn, starting with @x0@;
-- returns the next iterate together with the filled vector.
iterateCore ::
   (TypeNum.Positive n, C a) =>
   (MultiValue.T a -> CodeGenFunction r (MultiValue.T a)) ->
   MultiValue.T a -> T n a ->
   CodeGenFunction r (MultiValue.T a, T n a)
iterateCore f x0 v0 =
   foldM step (x0, v0) (take (size v0) [0 ..])
   where
      step (x, v) k = Monad.lift2 (,) (f x) (insert (valueOf k) x v)


-- * re-ordering of elements

-- | Constant vector built from a non-empty list via 'LLVM.constCyclicVector'
-- (presumably cycled or truncated to @n@ entries).
constCyclicVector ::
   (LLVM.IsConst a, TypeNum.Positive n) =>
   NonEmpty.T [] a -> LLVM.ConstValue (LLVM.Vector n a)
constCyclicVector xs = LLVM.constCyclicVector (fmap LLVM.constOf xs)

{- |
Rotate one element towards the higher elements.

I don't want to call it rotateLeft or rotateRight,
because there is no preferred layout for the vector elements.
In Intel's instruction manual vector elements are indexed like the bits,
that is from right to left.
However, when working with Haskell list and enumeration syntax,
the start index is left.
-}
rotateUp ::
   (TypeNum.Positive n, C a) =>
   T n a -> CodeGenFunction r (T n a)
-- Mask (n-1, 0, 1, ..., n-2): result element i is taken from element i-1,
-- element 0 from the last element.  The index list is infinite;
-- presumably only its first n entries end up in the mask.
rotateUp x =
   shuffleMatch
      (constCyclicVector $ (fromIntegral (size x) - 1) !: [0..])
      x

-- | Rotate one element towards the lower elements; inverse of 'rotateUp'.
rotateDown ::
   (TypeNum.Positive n, C a) =>
   T n a -> CodeGenFunction r (T n a)
-- Mask (1, 2, ..., n-1, 0).
rotateDown x =
   shuffleMatch
      (constCyclicVector $ NonEmpty.snoc (List.take (size x - 1) [1..]) 0)
      x

-- | Reverse the order of the elements.
reverse ::
   (TypeNum.Positive n, C a) =>
   T n a -> CodeGenFunction r (T n a)
reverse x =
   shuffleMatch
      -- Mask (n-1, ..., 1, 0).  The index list is built as a non-empty list
      -- directly: n is positive, so 0 is always a valid first index and
      -- no partial 'error' fallback is needed.
      (constCyclicVector $ NonEmpty.reverse $
         0 !: List.take (size x - 1) [1..])
      x

{- |
Move every element one position up.
The given element enters at index 0.
The former last element, which drops out, is returned
together with the shifted vector.
-}
shiftUp ::
   (TypeNum.Positive n, C a) =>
   MultiValue.T a -> T n a -> CodeGenFunction r (MultiValue.T a, T n a)
shiftUp x0 x = do
   -- index 0 is left undefined by the shuffle and filled with x0 below
   y <-
      shuffleMatch
         (LLVM.constCyclicVector $
            LLVM.undef !: List.map LLVM.constOf [0..])
         x
   Monad.lift2 (,)
      (extract (LLVM.valueOf (fromIntegral (size x) - 1)) x)
      (insert (value LLVM.zero) x0 y)

{- |
Move every element one position down.
The given element enters at the last index.
The former element 0, which drops out, is returned
together with the shifted vector.
-}
shiftDown ::
   (TypeNum.Positive n, C a) =>
   MultiValue.T a -> T n a -> CodeGenFunction r (MultiValue.T a, T n a)
shiftDown x0 x = do
   -- the last index is left undefined by the shuffle and filled with x0 below
   y <-
      shuffleMatch
         (LLVM.constCyclicVector $
            NonEmpty.snoc
               (List.map LLVM.constOf $ List.take (size x - 1) [1..])
               LLVM.undef)
         x
   Monad.lift2 (,)
      (extract (value LLVM.zero) x)
      (insert (LLVM.valueOf (fromIntegral (size x) - 1)) x0 y)

{- |
Shift the elements up by @n@ positions.
The vacated low positions are filled with zero,
and elements moved past the top are dropped.
A non-positive @n@ leaves the vector unchanged.
Implemented by dissecting and re-assembling, i.e. one extraction
and one insertion per element.
-}
shiftUpMultiZero ::
   (TypeNum.Positive n, C a, Class.ValueTuple a ~ al, Class.Zero al) =>
   Int -> T n a -> LLVM.CodeGenFunction r (T n a)
shiftUpMultiZero n v =
   assemble . take (size v) . (List.replicate n MultiValue.zero ++)
      =<< dissect v

{- |
Shift the elements down by @n@ positions.
The vacated high positions are filled with zero.
A non-positive @n@ leaves the vector unchanged.
Implemented by dissecting and re-assembling, i.e. one extraction
and one insertion per element.
-}
shiftDownMultiZero ::
   (TypeNum.Positive n, C a, Class.ValueTuple a ~ al, Class.Zero al) =>
   Int -> T n a -> LLVM.CodeGenFunction r (T n a)
shiftDownMultiZero n v =
   assemble . take (size v) . (++ List.repeat MultiValue.zero) . List.drop n
      =<< dissect v


-- * method implementations based on Traversable

-- | 'shuffleMatch' applied to every vector in a traversable container.
shuffleMatchTraversable ::
   (TypeNum.Positive n, C a, Trav.Traversable f) =>
   LLVM.ConstValue (LLVM.Vector n Word32) ->
   f (T n a) -> CodeGenFunction r (f (T n a))
shuffleMatchTraversable is v =
   Trav.mapM (shuffleMatch is) v

{- |
'insert' applied to corresponding elements and vectors of two containers.

NOTE(review): the pairing is done with 'liftA2',
so this presumably assumes an Applicative instance that combines
positionally (e.g. a fixed-shape container), not a cartesian one like lists.
-}
insertTraversable ::
   (TypeNum.Positive n, C a, Trav.Traversable f, App.Applicative f) =>
   LLVM.Value Word32 -> f (MultiValue.T a) ->
   f (T n a) -> CodeGenFunction r (f (T n a))
insertTraversable n a v =
   Trav.sequence (liftA2 (insert n) a v)

-- | 'extract' applied to every vector in a traversable container.
extractTraversable ::
   (TypeNum.Positive n, C a, Trav.Traversable f) =>
   LLVM.Value Word32 -> f (T n a) ->
   CodeGenFunction r (f (MultiValue.T a))
extractTraversable n v =
   Trav.mapM (extract n) v


-- | Raw LLVM vector of @n@ primitive elements.
type PrimValue n a = LLVM.Value (LLVM.Vector n a)

-- | Apply a pure function to the underlying representation.
lift1 :: (Repr (Value n) a -> Repr (Value n) b) -> T n a -> T n b
lift1 f (Cons a) = Cons $ f a

_liftM0 ::
   (Monad m) =>
   m (Repr (Value n) a) ->
   m (T n a)
_liftM0 f = Monad.lift Cons f

-- | Wrap the result of an action that produces a raw primitive vector.
liftM0 ::
   (Monad m, Repr (Value n) a ~ Value n a) =>
   m (PrimValue n a) ->
   m (T n a)
liftM0 f = Monad.lift consPrim f

-- | Lift a unary operation on raw primitive vectors.
liftM ::
   (Monad m, Repr (Value n) a ~ Value n a, Repr (Value n) b ~ Value n b) =>
   (PrimValue n a -> m (PrimValue n b)) ->
   T n a -> m (T n b)
liftM f a = Monad.lift consPrim $ f (deconsPrim a)

-- | Lift a binary operation on raw primitive vectors.
liftM2 ::
   (Monad m, Repr (Value n) a ~ Value n a, Repr (Value n) b ~ Value n b,
    Repr (Value n) c ~ Value n c) =>
   (PrimValue n a -> PrimValue n b -> m (PrimValue n c)) ->
   T n a -> T n b -> m (T n c)
liftM2 f a b = Monad.lift consPrim $ f (deconsPrim a) (deconsPrim b)


-- | Element-wise addition, subtraction and negation.
class (MultiValue.Additive a, C a) => Additive a where
   add, sub ::
      (TypeNum.Positive n) =>
      T n a -> T n a -> LLVM.CodeGenFunction r (T n a)
   neg ::
      (TypeNum.Positive n) =>
      T n a -> LLVM.CodeGenFunction r (T n a)

instance Additive Float where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance Additive Double where
   add = liftM2 LLVM.add
   sub = liftM2 LLVM.sub
   neg = liftM LLVM.neg

instance (TypeNum.Positive n, Additive a) => A.Additive (T n a) where
   zero = zero
   add = add
   sub = sub
   neg = neg


-- | Element-wise multiplication.
class (MultiValue.PseudoRing a, Additive a) => PseudoRing a where
   mul ::
      (TypeNum.Positive n) =>
      T n a -> T n a -> LLVM.CodeGenFunction r (T n a)

instance PseudoRing Float where
   mul = liftM2 LLVM.mul

instance PseudoRing Double where
   mul = liftM2 LLVM.mul

instance (TypeNum.Positive n, PseudoRing a) => A.PseudoRing (T n a) where
   mul = mul


-- | Element-wise floating-point division.
class (MultiValue.Field a, PseudoRing a) => Field a where
   fdiv ::
      (TypeNum.Positive n) =>
      T n a -> T n a -> LLVM.CodeGenFunction r (T n a)

instance Field Float where
   fdiv = liftM2 LLVM.fdiv

instance Field Double where
   fdiv = liftM2 LLVM.fdiv

instance (TypeNum.Positive n, Field a) => A.Field (T n a) where
   fdiv = fdiv


type instance A.Scalar (T n a) = T n (MultiValue.Scalar a)

-- | Scale each element by the corresponding element of a scalar vector.
class (MultiValue.PseudoModule v, PseudoRing (MultiValue.Scalar v), Additive v) =>
         PseudoModule v where
   scale ::
      (TypeNum.Positive n) =>
      T n (MultiValue.Scalar v) -> T n v -> LLVM.CodeGenFunction r (T n v)

instance PseudoModule Float where
   scale = liftM2 A.mul

instance PseudoModule Double where
   scale = liftM2 A.mul

instance (TypeNum.Positive n, PseudoModule a) => A.PseudoModule (T n a) where
   scale = scale


-- | Element-wise 'A.min', 'A.max', 'A.abs' and 'A.signum'.
class (MultiValue.Real a, Additive a) => Real a where
   min, max ::
      (TypeNum.Positive n) =>
      T n a -> T n a -> LLVM.CodeGenFunction r (T n a)
   abs, signum ::
      (TypeNum.Positive n) =>
      T n a -> LLVM.CodeGenFunction r (T n a)

instance Real Float where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance Real Double where
   min = liftM2 A.min
   max = liftM2 A.max
   abs = liftM A.abs
   signum = liftM A.signum

instance (TypeNum.Positive n, Real a) => A.Real (T n a) where
   min = min
   max = max
   abs = abs
   signum = signum


-- | Element-wise 'A.truncate' and 'A.fraction'.
class (MultiValue.Fraction a, Real a) => Fraction a where
   truncate, fraction ::
      (TypeNum.Positive n) =>
      T n a -> LLVM.CodeGenFunction r (T n a)

instance Fraction Float where
   truncate = liftM A.truncate
   fraction = liftM A.fraction

instance Fraction Double where
   truncate = liftM A.truncate
   fraction = liftM A.fraction

instance (TypeNum.Positive n, Fraction a) => A.Fraction (T n a) where
   truncate = truncate
   fraction = fraction


-- | Element-wise square root.
class (MultiValue.Algebraic a, Field a) => Algebraic a where
   sqrt ::
      (TypeNum.Positive n) =>
      T n a -> LLVM.CodeGenFunction r (T n a)

instance Algebraic Float where
   sqrt = liftM A.sqrt

instance Algebraic Double where
   sqrt = liftM A.sqrt

instance (TypeNum.Positive n, Algebraic a) => A.Algebraic (T n a) where
   sqrt = sqrt


-- | Vector of 'A.pi' and element-wise transcendental functions.
class (MultiValue.Transcendental a, Algebraic a) => Transcendental a where
   pi :: (TypeNum.Positive n) => LLVM.CodeGenFunction r (T n a)
   sin, cos, exp, log ::
      (TypeNum.Positive n) =>
      T n a -> LLVM.CodeGenFunction r (T n a)
   pow ::
      (TypeNum.Positive n) =>
      T n a -> T n a -> LLVM.CodeGenFunction r (T n a)

instance Transcendental Float where
   pi = liftM0 A.pi
   sin = liftM A.sin
   cos = liftM A.cos
   exp = liftM A.exp
   log = liftM A.log
   pow = liftM2 A.pow

instance Transcendental Double where
   pi = liftM0 A.pi
   sin = liftM A.sin
   cos = liftM A.cos
   exp = liftM A.exp
   log = liftM A.log
   pow = liftM2 A.pow

instance (TypeNum.Positive n, Transcendental a) => A.Transcendental (T n a) where
   pi = pi
   sin = sin
   cos = cos
   exp = exp
   log = log
   pow = pow


-- | Element-wise comparison with the given predicate,
-- yielding a vector of 'Bool'.
class (MultiValue.Comparison a, C a) => Comparison a where
   cmp ::
      (TypeNum.Positive n) =>
      LLVM.CmpPredicate -> T n a -> T n a ->
      LLVM.CodeGenFunction r (T n Bool)

instance Comparison Float where
   cmp = liftM2 . LLVM.cmp

instance Comparison Double where
   cmp = liftM2 . LLVM.cmp

instance (TypeNum.Positive n, Comparison a) => A.Comparison (T n a) where
   type CmpResult (T n a) = T n Bool
   cmp = cmp


{- |
Element-wise floating-point comparison with the given predicate,
yielding a vector of 'Bool'.

NOTE(review): unlike 'Comparison', only 'Float' has an instance here.
A 'Double' instance would presumably need a matching
@MultiValue.FloatingComparison Double@ instance. TODO confirm.
-}
class (MultiValue.FloatingComparison a, Comparison a) => FloatingComparison a where
   fcmp ::
      (TypeNum.Positive n) =>
      LLVM.FPPredicate -> T n a -> T n a ->
      LLVM.CodeGenFunction r (T n Bool)

instance FloatingComparison Float where
   fcmp = liftM2 . LLVM.fcmp

instance (TypeNum.Positive n, FloatingComparison a) => A.FloatingComparison (T n a) where
   fcmp = fcmp


-- | Element-wise bitwise logic.
class (MultiValue.Logic a, C a) => Logic a where
   and, or, xor ::
      (TypeNum.Positive n) =>
      T n a -> T n a -> LLVM.CodeGenFunction r (T n a)
   inv ::
      (TypeNum.Positive n) =>
      T n a -> LLVM.CodeGenFunction r (T n a)

instance Logic Bool where
   and = liftM2 LLVM.and
   or = liftM2 LLVM.or
   xor = liftM2 LLVM.xor
   inv = liftM LLVM.inv

instance (TypeNum.Positive n, Logic a) => A.Logic (T n a) where
   and = and
   or = or
   xor = xor
   inv = inv