{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module LLVM.DSL.Expression.Vector where

import qualified LLVM.DSL.Expression as Expr
import LLVM.DSL.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Vector as MultiValueVec
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Vector.Instance as MultiVectorInst
import qualified LLVM.Extra.Multi.Vector as MultiVector
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Core as LLVM

import qualified Data.Tuple.HT as Tuple

import Prelude hiding (replicate, take, zip, fst, snd, min, max)


-- | Embed a concrete 'LLVM.Vector' value into the expression language
-- by way of 'MultiValueVec.cons'.
cons ::
   (LLVM.Positive n, MultiVector.C a) =>
   LLVM.Vector n a -> Exp (LLVM.Vector n a)
cons vec = Expr.lift0 (MultiValueVec.cons vec)

-- | Project every element of a vector of pairs onto its first component.
fst ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n a)
fst pairs = Expr.lift1 MultiValueVec.fst pairs

-- | Project every element of a vector of pairs onto its second component.
snd ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n b)
snd pairs = Expr.lift1 MultiValueVec.snd pairs

-- | Exchange the two components of every pair in a vector of pairs.
swap ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n (a,b)) -> Exp (LLVM.Vector n (b,a))
swap pairs = Expr.lift1 MultiValueVec.swap pairs

-- | Transform the first components of a vector of pairs with @f@,
-- passing the second components through unchanged.
--
-- This works on the representation directly. The representation of a
-- vector of pairs is split into its two parts, and only the first part
-- is rewrapped and handed to @f@.
mapFst ::
   (Exp (LLVM.Vector n a0) -> Exp (LLVM.Vector n a1)) ->
   Exp (LLVM.Vector n (a0,b)) -> Exp (LLVM.Vector n (a1,b))
mapFst f =
   Expr.liftReprM
      (\(firsts, seconds) ->
         Expr.unliftM1 f (MultiValue.Cons firsts) >>=
            \(MultiValue.Cons firsts') -> return (firsts', seconds))

-- | Transform the second components of a vector of pairs with @f@,
-- passing the first components through unchanged.
--
-- This works on the representation directly. The representation of a
-- vector of pairs is split into its two parts, and only the second part
-- is rewrapped and handed to @f@.
mapSnd ::
   (Exp (LLVM.Vector n b0) -> Exp (LLVM.Vector n b1)) ->
   Exp (LLVM.Vector n (a,b0)) -> Exp (LLVM.Vector n (a,b1))
mapSnd f =
   Expr.liftReprM
      (\(firsts, seconds) ->
         Expr.unliftM1 f (MultiValue.Cons seconds) >>=
            \(MultiValue.Cons seconds') -> return (firsts, seconds'))


-- | Project every element of a vector of triples onto its first component.
fst3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n a)
fst3 triples = Expr.lift1 MultiValueVec.fst3 triples

-- | Project every element of a vector of triples onto its second component.
snd3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n b)
snd3 triples = Expr.lift1 MultiValueVec.snd3 triples

-- | Project every element of a vector of triples onto its third component.
thd3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n (a,b,c)) -> Exp (LLVM.Vector n c)
thd3 triples = Expr.lift1 MultiValueVec.thd3 triples


-- | Pair up two equally long vectors element by element.
zip ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n b) ->
   Exp (LLVM.Vector n (a,b))
zip xs ys = Expr.lift2 MultiValueVec.zip xs ys

-- | Combine three equally long vectors element by element into a vector
-- of triples.
zip3 ::
   (LLVM.Positive n, MultiVector.C a, MultiVector.C b, MultiVector.C c) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n b) -> Exp (LLVM.Vector n c) ->
   Exp (LLVM.Vector n (a,b,c))
zip3 xs ys zs = Expr.lift3 MultiValueVec.zip3 xs ys zs


-- | Build an @n@-element vector from a single scalar, as
-- 'MultiValueVec.replicate' does.
replicate ::
   (LLVM.Positive n, MultiVector.C a) =>
   Exp a -> Exp (LLVM.Vector n a)
replicate x = Expr.liftM MultiValueVec.replicate x

-- | Build an @n@-element vector from a start value and a step function,
-- delegating to 'MultiValueVec.iterate'.
--
-- The step expression function is turned into code generation via
-- 'Expr.unliftM1'. The element order is presumably
-- @[x, f x, f (f x), ...]@; confirm against 'MultiValueVec.iterate'.
iterate ::
   (LLVM.Positive n, MultiVector.C a) =>
   (Exp a -> Exp a) -> Exp a -> Exp (LLVM.Vector n a)
iterate step start =
   Expr.liftM (MultiValueVec.iterate (Expr.unliftM1 step)) start

-- | Turn an @n@-element vector into an @m@-element vector using
-- 'MultiValueVec.take'.
--
-- Which elements are kept, and what fills any surplus when @m > n@, is
-- decided by 'MultiValueVec.take'; presumably it keeps the leading
-- elements.
take ::
   (LLVM.Positive n, LLVM.Positive m, MultiVector.Select a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector m a)
take v = Expr.liftM MultiValueVec.take v

-- | Turn an @n@-element vector into an @m@-element vector using
-- 'MultiValueVec.takeRev'.
--
-- This is presumably the counterpart of 'take' that works from the end
-- of the vector; confirm against 'MultiValueVec.takeRev'.
takeRev ::
   (LLVM.Positive n, LLVM.Positive m, MultiVector.Select a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector m a)
takeRev v = Expr.liftM MultiValueVec.takeRev v


-- | Lift 'MultiVector.cumulate' to expressions.
--
-- Given a scalar start value and a vector, this yields a scalar together
-- with a vector of the same length. The vector is converted to and from
-- the 'MultiVector' form around the call. Presumably the vector holds
-- running sums seeded with the start value and the scalar is the final
-- total; confirm against 'MultiVector.cumulate'.
cumulate ::
   (LLVM.Positive n, MultiVector.Additive a) =>
   Exp a -> Exp (LLVM.Vector n a) -> (Exp a, Exp (LLVM.Vector n a))
cumulate a0 v0 =
   Expr.unzip
      (Expr.liftM2
         (\start vec ->
            uncurry MultiValue.zip . Tuple.mapSnd MultiVectorInst.toMultiValue
               <$> MultiVector.cumulate start
                      (MultiVectorInst.fromMultiValue vec))
         a0 v0)


-- | Compare two vectors element by element with the given predicate,
-- producing a vector of 'Bool' results.
cmp ::
   (LLVM.Positive n, MultiVector.Comparison a) =>
   LLVM.CmpPredicate ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n Bool)
cmp predicate x y = Expr.liftM2 (MultiValueVec.cmp predicate) x y

-- | Choose per element between two vectors according to a 'Bool' mask,
-- via 'MultiValueVec.select'.
--
-- Presumably a 'True' lane takes the second argument and a 'False' lane
-- takes the third; confirm against 'MultiValueVec.select'.
select ::
   (LLVM.Positive n, MultiVector.Select a) =>
   Exp (LLVM.Vector n Bool) ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a)
select mask xs ys = Expr.liftM3 MultiValueVec.select mask xs ys


-- | Lift 'A.min' to vector expressions.
min ::
   (LLVM.Positive n, MultiVector.Real a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a)
min x y = Expr.liftM2 A.min x y

-- | Lift 'A.max' to vector expressions.
max ::
   (LLVM.Positive n, MultiVector.Real a) =>
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a)
max x y = Expr.liftM2 A.max x y

-- | Clamp a vector between a lower and an upper bound vector, computing
-- @max lower (min upper x)@.
--
-- The lower bound is applied last, so it takes precedence in any lane
-- where it exceeds the upper bound.
limit ::
   (LLVM.Positive n, MultiVector.Real a) =>
   (Exp (LLVM.Vector n a), Exp (LLVM.Vector n a)) ->
   Exp (LLVM.Vector n a) -> Exp (LLVM.Vector n a)
limit (lower, upper) = \x -> max lower (min upper x)


-- | Convert an integer expression to a floating-point expression via
-- 'MultiValueVec.fromIntegral'.
--
-- Input and output must have the same 'LLVM.ShapeOf' (presumably scalar
-- vs. vector of the same size), so this covers scalars and vectors alike.
fromIntegral ::
   (MultiValueVec.NativeInteger i ir, MultiValueVec.NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   Exp i -> Exp a
fromIntegral n = Expr.liftM MultiValueVec.fromIntegral n

-- | Convert a floating-point expression to an integer expression via
-- 'MultiValueVec.truncateToInt'.
--
-- Presumably this rounds toward zero, as the name suggests; confirm. Input
-- and output must share the same 'LLVM.ShapeOf'.
truncateToInt ::
   (MultiValueVec.NativeInteger i ir, MultiValueVec.NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   Exp a -> Exp i
truncateToInt x = Expr.liftM MultiValueVec.truncateToInt x

-- | Split a floating-point expression into an integer part and a
-- floating-point part, using 'MultiValueVec.splitFractionToInt'.
--
-- The paired result is separated with 'Expr.unzip'. Presumably the
-- second component is the fractional remainder; confirm the rounding
-- convention against 'MultiValueVec.splitFractionToInt'.
splitFractionToInt ::
   (MultiValueVec.NativeInteger i ir, MultiValueVec.NativeFloating a ar,
    LLVM.ShapeOf ir ~ LLVM.ShapeOf ar) =>
   Exp a -> (Exp i, Exp a)
splitFractionToInt x =
   Expr.unzip (Expr.liftM MultiValueVec.splitFractionToInt x)