module Data.Array.Knead.Expression.Vector where

import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Vector as MultiValueVec
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Core as LLVM

import Prelude hiding (replicate, zip, fst, snd)


-- | Turn a plain 'LLVM.Vector' value into a vector expression.
--
-- The vector is first converted with 'MultiValueVec.cons' and the result
-- is then injected into 'Exp' by 'Expr.lift0'.
cons ::
   (LLVM.Positive n, MultiVector.C a) =>
   LLVM.Vector n a -> Exp (LLVM.Vector n a)
cons vec = Expr.lift0 (MultiValueVec.cons vec)

-- | Expression-level projection from a vector of pairs
-- to the vector of their first components.
--
-- This is 'MultiValueVec.fst' lifted with 'Expr.lift1'.
fst ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n a)
fst pairs = Expr.lift1 MultiValueVec.fst pairs

-- | Expression-level projection from a vector of pairs
-- to the vector of their second components.
--
-- This is 'MultiValueVec.snd' lifted with 'Expr.lift1'.
snd ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n b)
snd pairs = Expr.lift1 MultiValueVec.snd pairs

-- | Exchange the two components of a vector-of-pairs expression,
-- turning a vector of @(a,b)@ into a vector of @(b,a)@.
--
-- This is 'MultiValueVec.swap' lifted with 'Expr.lift1'.
swap ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n (b,a))
swap pairs = Expr.lift1 MultiValueVec.swap pairs

-- | Apply an expression-level function to the first component
-- of a vector-of-pairs expression.
--
-- The pair representation is taken apart, the given function is run
-- on the first part (via 'Expr.unliftM1') and the result is paired
-- with the unchanged second part.
mapFst ::
   (Exp (LLVM.Vector n a0) -> Exp (LLVM.Vector n a1)) ->
   Exp (LLVM.Vector n (a0,b)) -> Exp (LLVM.Vector n (a1,b))
mapFst f pairs =
   Expr.liftM
      (MultiValue.liftM
         (\(x0,y) ->
            -- wrap the raw first part so that 'f' can be applied to it,
            -- then unwrap the result again
            Expr.unliftM1 f (MultiValue.Cons x0) >>=
               \(MultiValue.Cons x1) -> return (x1,y)))
      pairs

-- | Apply an expression-level function to the second component
-- of a vector-of-pairs expression.
--
-- The pair representation is taken apart, the given function is run
-- on the second part (via 'Expr.unliftM1') and the result is paired
-- with the unchanged first part.
mapSnd ::
   (Exp (LLVM.Vector n b0) -> Exp (LLVM.Vector n b1)) ->
   Exp (LLVM.Vector n (a,b0)) -> Exp (LLVM.Vector n (a,b1))
mapSnd f pairs =
   Expr.liftM
      (MultiValue.liftM
         (\(x,y0) ->
            -- wrap the raw second part so that 'f' can be applied to it,
            -- then unwrap the result again
            Expr.unliftM1 f (MultiValue.Cons y0) >>=
               \(MultiValue.Cons y1) -> return (x,y1)))
      pairs


-- | Expression-level projection from a vector of triples
-- to the vector of their first components.
--
-- This is 'MultiValueVec.fst3' lifted with 'Expr.lift1'.
fst3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n a)
fst3 triples = Expr.lift1 MultiValueVec.fst3 triples

-- | Expression-level projection from a vector of triples
-- to the vector of their second components.
--
-- This is 'MultiValueVec.snd3' lifted with 'Expr.lift1'.
snd3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n b)
snd3 triples = Expr.lift1 MultiValueVec.snd3 triples

-- | Expression-level projection from a vector of triples
-- to the vector of their third components.
--
-- This is 'MultiValueVec.thd3' lifted with 'Expr.lift1'.
thd3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n c)
thd3 triples = Expr.lift1 MultiValueVec.thd3 triples


-- | Combine two vector expressions of the same size @n@
-- into a single expression of a vector of pairs.
--
-- This is 'MultiValueVec.zip' lifted with 'Expr.lift2'.
zip ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n b) ->
   Exp (LLVM.Vector n (a,b))
zip xs ys = Expr.lift2 MultiValueVec.zip xs ys

-- | Combine three vector expressions of the same size @n@
-- into a single expression of a vector of triples.
--
-- This is 'MultiValueVec.zip3' lifted with 'Expr.lift3'.
zip3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n b) -> Exp (LLVM.Vector n c) ->
   Exp (LLVM.Vector n (a,b,c))
zip3 xs ys zs = Expr.lift3 MultiValueVec.zip3 xs ys zs


-- | Build a vector expression of size @n@ from a scalar expression.
--
-- This is 'MultiValueVec.replicate' lifted with 'Expr.liftM'.
-- NOTE(review): presumably every element of the result equals the scalar;
-- confirm against 'MultiValueVec.replicate'.
replicate ::
   (LLVM.Positive n, MultiVector.C a) =>
   Exp a -> Exp (LLVM.Vector n a)
replicate scalar = Expr.liftM MultiValueVec.replicate scalar

-- | Convert a vector expression of size @n@ into one of size @m@.
--
-- This is 'MultiValueVec.take' lifted with 'Expr.liftM'.
-- NOTE(review): presumably the leading elements are kept;
-- which elements are selected and how @m > n@ is handled
-- is determined by 'MultiValueVec.take' - confirm there.
take ::
   (LLVM.Positive n, LLVM.Positive m, MultiVector.Select a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector m a)
take vec = Expr.liftM MultiValueVec.take vec

-- | Convert a vector expression of size @n@ into one of size @m@.
--
-- This is 'MultiValueVec.takeRev' lifted with 'Expr.liftM'.
-- NOTE(review): presumably the trailing elements are kept;
-- which elements are selected and how @m > n@ is handled
-- is determined by 'MultiValueVec.takeRev' - confirm there.
takeRev ::
   (LLVM.Positive n, LLVM.Positive m, MultiVector.Select a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector m a)
takeRev vec = Expr.liftM MultiValueVec.takeRev vec


-- | Compare two vector expressions using the given 'LLVM.CmpPredicate'
-- and yield a vector expression of 'Bool's of the same size.
--
-- This is 'MultiValueVec.cmp' applied to the predicate
-- and lifted with 'Expr.liftM2'.
-- NOTE(review): presumably the comparison is element-wise - confirm
-- against 'MultiValueVec.cmp'.
cmp ::
   (LLVM.Positive n, MultiVector.Comparison a) =>
   LLVM.CmpPredicate ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n Bool)
cmp predicate lhs rhs =
   Expr.liftM2 (MultiValueVec.cmp predicate) lhs rhs

-- | Combine two vector expressions into one
-- under control of a vector expression of 'Bool's.
--
-- This is 'MultiValueVec.select' lifted with 'Expr.liftM3'.
-- NOTE(review): presumably the choice is made per element and
-- 'True' picks from the first of the two data vectors - confirm
-- against 'MultiValueVec.select'.
select ::
   (LLVM.Positive n, MultiVector.Select a) =>
   Exp (LLVM.Vector n Bool) ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a)
select mask onTrue onFalse =
   Expr.liftM3 MultiValueVec.select mask onTrue onFalse