module LLVM.Extra.Multi.Value.Vector (
   cons,
   fst, snd,
   fst3, snd3, thd3,
   zip, zip3,
   unzip, unzip3,

   swap,
   mapFst, mapSnd,
   mapFst3, mapSnd3, mapThd3,

   extract, insert,
   replicate,
   dissect,
   select,
   cmp,
   take, takeRev,
   ) where

import qualified LLVM.Extra.Multi.Vector.Instance as Inst
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Multi.Value.Private as MultiValue

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Data.Tuple.HT as TupleHT
import qualified Data.Tuple as Tuple
import Data.Word (Word32)

import Prelude (Bool, fmap, (.))


-- | Lift a Haskell-side 'LLVM.Vector' into a vector multi-value.
--
-- The vector is first converted by 'MultiVector.cons' and the result is
-- rewrapped with 'Inst.toMultiValue'.
cons ::
   (TypeNum.Positive n, MultiVector.C a) =>
   LLVM.Vector n a -> MultiValue.T (LLVM.Vector n a)
cons vec = Inst.toMultiValue (MultiVector.cons vec)

-- | First components of a vector of pairs.
--
-- Applies the pair projection to the wrapped representation through
-- 'MultiValue.lift1'; the result is not in 'LLVM.CodeGenFunction'.
fst :: MultiValue.T (LLVM.Vector n (a,b)) -> MultiValue.T (LLVM.Vector n a)
fst = MultiValue.lift1 (\(a, _) -> a)

-- | Second components of a vector of pairs.
--
-- Applies the pair projection to the wrapped representation through
-- 'MultiValue.lift1'; the result is not in 'LLVM.CodeGenFunction'.
snd :: MultiValue.T (LLVM.Vector n (a,b)) -> MultiValue.T (LLVM.Vector n b)
snd = MultiValue.lift1 (\(_, b) -> b)

-- | Exchange the two components of a vector of pairs.
--
-- 'TupleHT.swap' is applied to the wrapped representation through
-- 'MultiValue.lift1'.
swap :: MultiValue.T (LLVM.Vector n (a,b)) -> MultiValue.T (LLVM.Vector n (b,a))
swap pairs = MultiValue.lift1 TupleHT.swap pairs

-- | Transform the first components of a vector of pairs and leave the
-- second components untouched.
--
-- The input is split with 'unzip', the function is applied to the first
-- part, and the parts are joined again with 'zip'.
mapFst ::
   (MultiValue.T (LLVM.Vector n a0) -> MultiValue.T (LLVM.Vector n a1)) ->
   MultiValue.T (LLVM.Vector n (a0,b)) -> MultiValue.T (LLVM.Vector n (a1,b))
mapFst f pairs =
   -- The let pattern is lazy, so the split is only forced on demand.
   let (as, bs) = unzip pairs
   in  zip (f as) bs

-- | Transform the second components of a vector of pairs and leave the
-- first components untouched.
--
-- The input is split with 'unzip', the function is applied to the second
-- part, and the parts are joined again with 'zip'.
mapSnd ::
   (MultiValue.T (LLVM.Vector n b0) -> MultiValue.T (LLVM.Vector n b1)) ->
   MultiValue.T (LLVM.Vector n (a,b0)) -> MultiValue.T (LLVM.Vector n (a,b1))
mapSnd f pairs =
   -- The let pattern is lazy, so the split is only forced on demand.
   let (as, bs) = unzip pairs
   in  zip as (f bs)


-- | First components of a vector of triples, projected out of the wrapped
-- representation through 'MultiValue.lift1'.
fst3 :: MultiValue.T (LLVM.Vector n (a,b,c)) -> MultiValue.T (LLVM.Vector n a)
fst3 = MultiValue.lift1 (\(a, _, _) -> a)

-- | Second components of a vector of triples, projected out of the wrapped
-- representation through 'MultiValue.lift1'.
snd3 :: MultiValue.T (LLVM.Vector n (a,b,c)) -> MultiValue.T (LLVM.Vector n b)
snd3 = MultiValue.lift1 (\(_, b, _) -> b)

-- | Third components of a vector of triples, projected out of the wrapped
-- representation through 'MultiValue.lift1'.
thd3 :: MultiValue.T (LLVM.Vector n (a,b,c)) -> MultiValue.T (LLVM.Vector n c)
thd3 = MultiValue.lift1 (\(_, _, c) -> c)

-- | Transform the first components of a vector of triples; the second and
-- third components pass through unchanged.
--
-- Implemented as 'unzip3', apply the function to the first part, 'zip3'.
mapFst3 ::
   (MultiValue.T (LLVM.Vector n a0) -> MultiValue.T (LLVM.Vector n a1)) ->
   MultiValue.T (LLVM.Vector n (a0,b,c)) ->
   MultiValue.T (LLVM.Vector n (a1,b,c))
mapFst3 f triples =
   -- The let pattern is lazy, so the split is only forced on demand.
   let (as, bs, cs) = unzip3 triples
   in  zip3 (f as) bs cs

-- | Transform the second components of a vector of triples; the first and
-- third components pass through unchanged.
--
-- Implemented as 'unzip3', apply the function to the second part, 'zip3'.
mapSnd3 ::
   (MultiValue.T (LLVM.Vector n b0) -> MultiValue.T (LLVM.Vector n b1)) ->
   MultiValue.T (LLVM.Vector n (a,b0,c)) ->
   MultiValue.T (LLVM.Vector n (a,b1,c))
mapSnd3 f triples =
   -- The let pattern is lazy, so the split is only forced on demand.
   let (as, bs, cs) = unzip3 triples
   in  zip3 as (f bs) cs

-- | Transform the third components of a vector of triples; the first and
-- second components pass through unchanged.
--
-- Implemented as 'unzip3', apply the function to the third part, 'zip3'.
mapThd3 ::
   (MultiValue.T (LLVM.Vector n c0) -> MultiValue.T (LLVM.Vector n c1)) ->
   MultiValue.T (LLVM.Vector n (a,b,c0)) ->
   MultiValue.T (LLVM.Vector n (a,b,c1))
mapThd3 f triples =
   -- The let pattern is lazy, so the split is only forced on demand.
   let (as, bs, cs) = unzip3 triples
   in  zip3 as bs (f cs)


-- | Combine two vector multi-values into one vector of pairs.
--
-- Both arguments are unwrapped and their contents are paired under a
-- single 'MultiValue.Cons'.
zip ::
   MultiValue.T (LLVM.Vector n a) ->
   MultiValue.T (LLVM.Vector n b) ->
   MultiValue.T (LLVM.Vector n (a,b))
zip va vb =
   case (va, vb) of
      (MultiValue.Cons a, MultiValue.Cons b) -> MultiValue.Cons (a, b)

-- | Combine three vector multi-values into one vector of triples.
--
-- All arguments are unwrapped and their contents are grouped under a
-- single 'MultiValue.Cons'.
zip3 ::
   MultiValue.T (LLVM.Vector n a) ->
   MultiValue.T (LLVM.Vector n b) ->
   MultiValue.T (LLVM.Vector n c) ->
   MultiValue.T (LLVM.Vector n (a,b,c))
zip3 va vb vc =
   case (va, vb, vc) of
      (MultiValue.Cons a, MultiValue.Cons b, MultiValue.Cons c) ->
         MultiValue.Cons (a, b, c)

-- | Split a vector of pairs into a pair of vectors.
--
-- The wrapped representation is a Haskell pair; each half is rewrapped in
-- its own 'MultiValue.Cons'.
unzip ::
   MultiValue.T (LLVM.Vector n (a,b)) ->
   (MultiValue.T (LLVM.Vector n a),
    MultiValue.T (LLVM.Vector n b))
unzip (MultiValue.Cons pair) =
   case pair of
      (a, b) -> (MultiValue.Cons a, MultiValue.Cons b)

-- | Split a vector of triples into a triple of vectors.
--
-- The wrapped representation is a Haskell triple; each part is rewrapped
-- in its own 'MultiValue.Cons'.
unzip3 ::
   MultiValue.T (LLVM.Vector n (a,b,c)) ->
   (MultiValue.T (LLVM.Vector n a),
    MultiValue.T (LLVM.Vector n b),
    MultiValue.T (LLVM.Vector n c))
unzip3 (MultiValue.Cons triple) =
   case triple of
      (a, b, c) -> (MultiValue.Cons a, MultiValue.Cons b, MultiValue.Cons c)


-- | Fetch a single element from a vector multi-value.
--
-- The vector is converted with 'Inst.fromMultiValue' and the work is
-- delegated to 'MultiVector.extract', which receives the 'Word32' value
-- as its first argument.
extract ::
   (TypeNum.Positive n, MultiVector.C a) =>
   LLVM.Value Word32 -> MultiValue.T (LLVM.Vector n a) ->
   LLVM.CodeGenFunction r (MultiValue.T a)
extract k = MultiVector.extract k . Inst.fromMultiValue

-- | Put an element into a vector multi-value, yielding the updated vector.
--
-- Delegates to 'MultiVector.insert' (given the 'Word32' value and the
-- element), lifted to multi-values by 'Inst.liftMultiValueM'.
insert ::
   (TypeNum.Positive n, MultiVector.C a) =>
   LLVM.Value Word32 -> MultiValue.T a ->
   MultiValue.T (LLVM.Vector n a) ->
   LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector n a))
insert k a vec =
   let update = MultiVector.insert k a
   in  Inst.liftMultiValueM update vec


-- | Build a vector multi-value from a single element.
--
-- Runs 'MultiVector.replicate' on the element and rewraps its result with
-- 'Inst.toMultiValue'.
replicate ::
   (TypeNum.Positive n, MultiVector.C a) =>
   MultiValue.T a -> LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector n a))
replicate a = fmap Inst.toMultiValue (MultiVector.replicate a)

-- | Convert a vector of size @n@ into a vector of size @m@.
--
-- Delegates to 'MultiVector.take' through 'Inst.liftMultiValueM'.
-- Which elements are kept is decided there (presumably the leading ones;
-- confirm against 'MultiVector.take').
take ::
   (TypeNum.Positive n, TypeNum.Positive m, MultiVector.C a) =>
   MultiValue.T (LLVM.Vector n a) ->
   LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector m a))
take vec = Inst.liftMultiValueM MultiVector.take vec

-- | Convert a vector of size @n@ into a vector of size @m@.
--
-- Delegates to 'MultiVector.takeRev' through 'Inst.liftMultiValueM'.
-- Which elements are kept is decided there (presumably the trailing ones;
-- confirm against 'MultiVector.takeRev').
takeRev ::
   (TypeNum.Positive n, TypeNum.Positive m, MultiVector.C a) =>
   MultiValue.T (LLVM.Vector n a) ->
   LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector m a))
takeRev vec = Inst.liftMultiValueM MultiVector.takeRev vec


-- | Break a vector multi-value into a list of element multi-values.
--
-- The vector is converted with 'Inst.fromMultiValue' and handed to
-- 'MultiVector.dissect'.
dissect ::
   (TypeNum.Positive n, MultiVector.C a) =>
   MultiValue.T (LLVM.Vector n a) -> LLVM.CodeGenFunction r [MultiValue.T a]
dissect vec = MultiVector.dissect (Inst.fromMultiValue vec)

-- | Choose between two vectors under control of a vector of 'Bool'.
--
-- Delegates to 'MultiVector.select' through 'Inst.liftMultiValueM3'.
-- Which operand is taken for a true entry is defined by
-- 'MultiVector.select'.
select ::
   (TypeNum.Positive n, MultiVector.Select a) =>
   MultiValue.T (LLVM.Vector n Bool) ->
   MultiValue.T (LLVM.Vector n a) -> MultiValue.T (LLVM.Vector n a) ->
   LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector n a))
select mask x y = Inst.liftMultiValueM3 MultiVector.select mask x y

-- | Compare two vectors under the given predicate, yielding a vector of
-- 'Bool'.
--
-- Delegates to 'MultiVector.cmp' applied to the predicate, lifted to
-- multi-values by 'Inst.liftMultiValueM2'.
cmp ::
   (TypeNum.Positive n, MultiVector.Comparison a) =>
   LLVM.CmpPredicate ->
   MultiValue.T (LLVM.Vector n a) -> MultiValue.T (LLVM.Vector n a) ->
   LLVM.CodeGenFunction r (MultiValue.T (LLVM.Vector n Bool))
cmp predicate = Inst.liftMultiValueM2 (MultiVector.cmp predicate)