{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
-- | Element-wise operations on vector-valued 'Exp' expressions.
--
-- Most functions here lift the corresponding operation from
-- "LLVM.Extra.Multi.Value.Vector" (or "LLVM.Extra.Arithmetic") to the
-- 'Exp' level via the lifting helpers of "LLVM.DSL.Expression".
--
-- Several names clash with "Prelude" ('fst', 'snd', 'zip', 'min', ...);
-- the Prelude versions are hidden below.  Client code will presumably want
-- to import this module qualified.
module LLVM.DSL.Expression.Vector where

import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Vector as MultiValueVec
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector.Instance as MultiVectorInst
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM

import qualified Data.Tuple.HT as Tuple

-- Hide every Prelude name that this module redefines, so that an unqualified
-- use inside this module (as 'min' and 'max' in 'limit') refers to the local
-- vector version instead of being an ambiguous occurrence.
import Prelude hiding
   (replicate, iterate, take, zip, zip3, fst, snd, min, max, fromIntegral)


-- | Lift a Haskell-side 'LLVM.Vector' value into an expression
-- (via 'MultiValueVec.cons').
cons :: (LLVM.Positive n, MultiVector.C a) => LLVM.Vector n a -> Exp (LLVM.Vector n a)
cons = Expr.lift0 . MultiValueVec.cons

-- | First component of every lane of a vector of pairs.
fst ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n a)
fst = Expr.lift1 MultiValueVec.fst

-- | Second component of every lane of a vector of pairs.
snd ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n b)
snd = Expr.lift1 MultiValueVec.snd

-- | Exchange the two components in every lane of a vector of pairs.
swap ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n (b,a))
swap = Expr.lift1 MultiValueVec.swap

-- | Transform the first-component vector of a vector of pairs with an
-- expression-level function; the second component is passed through as is.
mapFst ::
   (Exp (LLVM.Vector n a0) -> Exp (LLVM.Vector n a1)) ->
   Exp (LLVM.Vector n (a0,b)) -> Exp (LLVM.Vector n (a1,b))
mapFst f =
   -- Work on the representation: re-wrap the first part as a MultiValue
   -- so that 'f' can be applied to it, then unwrap the result again.
   Expr.liftReprM
      (\(a0,b) -> do
         MultiValue.Cons a1 <- Expr.unliftM1 f $ MultiValue.Cons a0
         return (a1,b))

-- | Transform the second-component vector of a vector of pairs with an
-- expression-level function; the first component is passed through as is.
mapSnd ::
   (Exp (LLVM.Vector n b0) -> Exp (LLVM.Vector n b1)) ->
   Exp (LLVM.Vector n (a,b0)) -> Exp (LLVM.Vector n (a,b1))
mapSnd f =
   -- Same representation round-trip as in 'mapFst', on the second part.
   Expr.liftReprM
      (\(a,b0) -> do
         MultiValue.Cons b1 <- Expr.unliftM1 f $ MultiValue.Cons b0
         return (a,b1))


-- | First component of every lane of a vector of triples.
fst3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n a)
fst3 = Expr.lift1 MultiValueVec.fst3

-- | Second component of every lane of a vector of triples.
snd3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n b)
snd3 = Expr.lift1 MultiValueVec.snd3

-- | Third component of every lane of a vector of triples.
thd3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n c)
thd3 = Expr.lift1 MultiValueVec.thd3

-- | Pair two equally sized vectors lane by lane.
zip ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n b) ->
   Exp (LLVM.Vector n (a,b))
zip = Expr.lift2 MultiValueVec.zip

-- | Combine three equally sized vectors lane by lane into triples.
zip3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n b) ->
   Exp (LLVM.Vector n c) -> Exp (LLVM.Vector n (a,b,c))
zip3 = Expr.lift3 MultiValueVec.zip3


-- | Fill all lanes of a vector with the given scalar
-- (via 'MultiValueVec.replicate').
replicate ::
   (LLVM.Positive n, MultiVector.C a) =>
   Exp a -> Exp (LLVM.Vector n a)
replicate = Expr.liftM MultiValueVec.replicate

-- | Build a vector from a start value and a step function, via
-- 'MultiValueVec.iterate'.  Presumably the lanes are
-- @x, f x, f (f x), ...@ as with 'Prelude.iterate' — confirm there.
iterate ::
   (LLVM.Positive n, MultiVector.C a) =>
   (Exp a -> Exp a) -> Exp a -> Exp (LLVM.Vector n a)
iterate f = Expr.liftM (MultiValueVec.iterate (Expr.unliftM1 f))

-- | Convert to a vector with a possibly different lane count @m@, via
-- 'MultiValueVec.take'.  Presumably the leading lanes are kept; how lanes
-- beyond @n@ are filled when @m > n@ is decided there — TODO confirm.
take ::
   (LLVM.Positive n, LLVM.Positive m, MultiVector.Select a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector m a)
take = Expr.liftM MultiValueVec.take

-- | Like 'take', but via 'MultiValueVec.takeRev' — presumably keeping the
-- trailing lanes instead of the leading ones; confirm there.
takeRev ::
   (LLVM.Positive n, LLVM.Positive m, MultiVector.Select a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector m a)
takeRev = Expr.liftM MultiValueVec.takeRev


-- | Cumulate the lanes of @v0@ starting from the scalar @a0@, via
-- 'MultiVector.cumulate', and return its scalar and vector results as two
-- separate expressions.
--
-- Presumably the vector holds the running sums and the scalar the
-- accumulator after the last lane — TODO confirm against
-- "LLVM.Extra.Multi.Vector".
cumulate ::
   (LLVM.Positive n, MultiVector.Additive a) =>
   Exp a -> Exp (LLVM.Vector n a) -> (Exp a, Exp (LLVM.Vector n a))
cumulate a0 v0 =
   Expr.unzip $
   Expr.liftM2
      -- 'MultiVector.cumulate' works on the MultiVector representation:
      -- convert the vector in, convert the resulting vector back out, and
      -- pair both results so that 'Expr.unzip' can split them.
      (\a v ->
         fmap (uncurry MultiValue.zip . Tuple.mapSnd MultiVectorInst.toMultiValue) $
         MultiVector.cumulate a $ MultiVectorInst.fromMultiValue v)
      a0 v0


-- | Lane-wise comparison of two vectors with the given predicate,
-- producing a vector of 'Bool'.
cmp ::
   (LLVM.Positive n, MultiVector.Comparison a) =>
   LLVM.CmpPredicate ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) ->
   Exp (LLVM.Vector n Bool)
cmp ord = Expr.liftM2 (MultiValueVec.cmp ord)

-- | Lane-wise choice between two vectors controlled by a 'Bool' mask, via
-- 'MultiValueVec.select'.  Presumably a 'True' lane takes the value from
-- the first vector and a 'False' lane from the second, as with LLVM's
-- @select@ instruction.
select ::
   (LLVM.Positive n, MultiVector.Select a) =>
   Exp (LLVM.Vector n Bool) ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) ->
   Exp (LLVM.Vector n a)
select = Expr.liftM3 MultiValueVec.select


-- | Lane-wise minimum / maximum (via 'A.min' / 'A.max').
min, max ::
   (LLVM.Positive n, MultiVector.Real a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a)
min = Expr.liftM2 A.min
max = Expr.liftM2 A.max

-- | Clamp every lane into the range @[l, u]@, computed as
-- @max l (min u x)@.
--
-- If @l > u@ in some lane, that lane yields @l@, because the lower bound
-- is applied last.
limit ::
   (LLVM.Positive n, MultiVector.Real a) =>
   (Exp (LLVM.Vector n a), Exp (LLVM.Vector n a)) ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a)
limit (l,u) = max l . min u


-- | Lane-wise conversion from an integer type @i@ to a floating type @a@
-- (via 'MultiValueVec.fromIntegral').  The 'LLVM.ShapeOf' equality
-- requires both LLVM representations to have the same shape.
fromIntegral ::
   (MultiValueVec.NativeInteger i ir, MultiValueVec.NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   Exp i -> Exp a
fromIntegral = Expr.liftM MultiValueVec.fromIntegral

-- | Lane-wise conversion from a floating type @a@ to an integer type @i@
-- (via 'MultiValueVec.truncateToInt'; presumably rounding toward zero, as
-- the name suggests — confirm there).
truncateToInt ::
   (MultiValueVec.NativeInteger i ir, MultiValueVec.NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   Exp a -> Exp i
truncateToInt = Expr.liftM MultiValueVec.truncateToInt

-- | Split every lane into an integer part and a fractional remainder (via
-- 'MultiValueVec.splitFractionToInt'), returned as two separate
-- expressions: the integer part first, the fraction second.
splitFractionToInt ::
   (MultiValueVec.NativeInteger i ir, MultiValueVec.NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   Exp a -> (Exp i, Exp a)
splitFractionToInt = Expr.unzip . Expr.liftM MultiValueVec.splitFractionToInt