{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
module LLVM.Extra.Vector (
   size, sizeInTuple,
   replicate, iterate, assemble,

   shuffle,
   rotateUp, rotateDown, reverse,
   shiftUp, shiftDown,
   shiftUpMultiZero, shiftDownMultiZero,
   ShuffleMatch (shuffleMatch),
   shuffleMatchTraversable,

   Access (insert, extract),
   insertTraversable,
   extractTraversable,

   insertChunk, modify,
   map, mapChunks, zipChunksWith,
   chop, concat, select,
   signedFraction,
   cumulate1, umul32to64,
   Arithmetic
      (sum, sumToPair, sumInterleavedToPair,
       cumulate, dotProduct, mul),
   Real
      (min, max, abs,
       truncate, floor, fraction),
   ) where

import qualified LLVM.Extra.Extension.X86 as X86
import qualified LLVM.Extra.Extension as Ext

import qualified LLVM.Extra.Monad as M
import qualified LLVM.Extra.Arithmetic as A

import qualified LLVM.Core as LLVM
import LLVM.Util.Loop (Phi, )
import LLVM.Core
   (Value, ConstValue, valueOf, value, constOf, undef,
    Vector, shufflevector, insertelement, extractelement, constVector,
    IsConst, IsArithmetic, IsFloating,
    IsPrimitive, IsPowerOf2,
    CodeGenFunction, )

import Data.TypeLevel.Num (D2, )
import qualified Data.TypeLevel.Num as TypeNum
import Control.Monad.HT ((<=<), )
import Control.Monad (liftM2, liftM3, foldM, )
import Data.Tuple.HT (uncurry3, )
import qualified Data.List.HT as ListHT
import qualified Data.List as List

import Control.Applicative (liftA2, )
import qualified Control.Applicative as App
import qualified Data.Traversable as Trav

-- import qualified Data.Bits as Bit
import Data.Int  (Int8, Int16, Int32, Int64, )
import Data.Word (Word8, Word16, Word32, Word64, )

import Prelude hiding
          (Real, truncate, floor, round,
           map, zipWith, iterate, replicate, reverse, concat, sum, )


-- * target independent functions

-- | Number of elements of a vector.
-- Only the type is consulted; the vector value itself is never inspected.
size ::
   (TypeNum.Nat n) =>
   Value (Vector n a) -> Int
size v =
   TypeNum.toInt (lengthType v)
   where
      lengthType :: Value (Vector n a) -> n
      lengthType _ = undefined

{- |
Manually assemble a vector of equal values.
Better use ScalarOrVector.replicate.
-}
replicate ::
   (Access n a va) =>
   a -> CodeGenFunction r va
replicate = replicateCore undefined

-- | Worker for 'replicate'.
-- Only the type of the first argument matters: it fixes the vector size @n@.
-- ('replicate' passes 'undefined' here.)
replicateCore ::
   (Access n a va) =>
   n -> a -> CodeGenFunction r va
replicateCore n x =
   assemble (List.replicate (TypeNum.toInt n) x)

{- |
Construct a vector out of single elements.
Starting from an undefined vector,
list element @k@ is inserted at vector position @k@.

You must assert that the length of the list matches the vector size.
-}
assemble ::
   (Access n a va) =>
   [a] -> CodeGenFunction r va
assemble xs =
   foldM step LLVM.undefTuple (List.zip [0..] xs)
   where
      step acc (pos, elm) = insert (valueOf pos) elm acc
{- sends GHC into an infinite loop
   foldM (\(k,x) -> insert (valueOf k) x) LLVM.undefTuple .
   List.zip [0..]
-}

{- |
Write the elements of chunk @x@ into @v@,
at consecutive positions starting at position @k@.

Chunk elements that would land beyond the end of @v@ are skipped.
An 'insertelement' with an out-of-range index makes the whole result vector
undefined. This happens e.g. in 'mapChunks' when the chunk type is larger
than the vector type.
-}
insertChunk ::
   (Access m a ca, Access n a va) =>
   Int -> ca ->
   va -> CodeGenFunction r va
insertChunk k x v =
   M.chain
      (List.zipWith
         (\i j -> \w ->
             extract (valueOf i) x >>= \e ->
             insert (valueOf j) e w)
         (take (sizeInTuple x) [0..])
         -- only the target positions that exist in v
         (take (sizeInTuple v - k) [fromIntegral k ..]))
      v

{- |
Build a vector whose element @k@ is the @k@-fold application of @f@ to @x@,
that is @[x, f x, f (f x), ...]@.
-}
iterate ::
   (Access n a va) =>
   (a -> CodeGenFunction r a) ->
   a -> CodeGenFunction r va
iterate f x =
   fmap snd (iterateCore f x LLVM.undefTuple)

-- | Worker for 'iterate'.
-- Fills all positions of @v0@ in ascending order and returns the filled
-- vector together with the next iterate, which was not inserted.
iterateCore ::
   (Access n a va) =>
   (a -> CodeGenFunction r a) ->
   a -> va ->
   CodeGenFunction r (a, va)
iterateCore f x0 v0 =
   foldM step (x0, v0) (take (sizeInTuple v0) [0..])
   where
      step (cur, vec) pos = do
         next <- f cur
         vec' <- insert (valueOf pos) cur vec
         return (next, vec')

{- |
Manually implement vector shuffling using insertelement and extractelement.
In contrast to LLVM's built-in instruction it supports distinct vector sizes,
but it allows only one input vector
(or a tuple of vectors, but we cannot shuffle between them).
Result element @k@ is the element of @x@ at index @i!!k@.
-}
shuffle ::
   (Access m a ca, Access n a va) =>
   va ->
   ConstValue (Vector m Word32) ->
   CodeGenFunction r ca
shuffle x i =
   mapM pick (take (size (value i)) [0..]) >>= assemble
   where
      pick k =
         extractelement (value i) (valueOf k) >>= \src -> extract src x


-- | Number of elements of a vector or of a tuple of equally sized vectors.
-- Only the type is consulted; the argument is never inspected.
sizeInTuple :: ShuffleMatch n v => v -> Int
sizeInTuple v =
   TypeNum.toInt (chunkLength v)
   where
      chunkLength :: ShuffleMatch m w => w -> m
      chunkLength _ = undefined

{- |
Rotate one element towards the higher elements:
element @k@ moves to position @k+1@, and the top element moves to position 0.

I don't want to call it rotateLeft or rotateRight,
because there is no preferred layout for the vector elements.
In Intel's instruction manual vector
elements are indexed like the bits,
that is from right to left.
However, when working with Haskell lists and enumeration syntax,
the start index is left.
-}
rotateUp ::
   (ShuffleMatch n v) =>
   v -> CodeGenFunction r v
rotateUp x =
   let n = sizeInTuple x
       indices = (fromIntegral n - 1) : [0..]
   in  shuffleMatch (constVector (List.map constOf indices)) x

-- | Rotate one element towards the lower elements:
-- element @k@ moves to position @k-1@, and element 0 moves to the top.
rotateDown ::
   (ShuffleMatch n v) =>
   v -> CodeGenFunction r v
rotateDown x =
   let n = sizeInTuple x
       indices = List.take (n - 1) [1..] ++ [0]
   in  shuffleMatch (constVector (List.map constOf indices)) x

-- | Reverse the order of the vector elements.
reverse ::
   (ShuffleMatch n v) =>
   v -> CodeGenFunction r v
reverse x =
   let n = sizeInTuple x
       indices = List.reverse (List.take n [0..])
   in  shuffleMatch (constVector (List.map constOf indices)) x

{- |
Shift all elements one position up and put @x0@ into position 0.
The first result component is the former top element,
which was shifted out.
-}
shiftUp ::
   (Access n a v) =>
   a -> v -> CodeGenFunction r (a, v)
shiftUp x0 x = do
   let top = LLVM.valueOf (fromIntegral (sizeInTuple x) - 1)
   shifted <-
      shuffleMatch (constVector (undef : List.map constOf [0..])) x
   out <- extract top x
   v <- insert (value LLVM.zero) x0 shifted
   return (out, v)

{- |
Shift all elements one position down and put @x0@ into the top position.
The first result component is the former element 0,
which was shifted out.
-}
shiftDown ::
   (Access n a v) =>
   a -> v -> CodeGenFunction r (a, v)
shiftDown x0 x = do
   let lastPos = fromIntegral (sizeInTuple x) - 1
   shifted <-
      shuffleMatch
         (constVector
            (List.map constOf (List.take (sizeInTuple x - 1) [1..]) ++ [undef]))
         x
   out <- extract (value LLVM.zero) x
   v <- insert (LLVM.valueOf lastPos) x0 shifted
   return (out, v)

-- | Shift all elements @k@ positions up and fill the lowest @k@ positions with zero.
shiftUpMultiZero ::
   (IsPrimitive a, IsPowerOf2 n) =>
   Int ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
shiftUpMultiZero k x =
   let n = fromIntegral (sizeInTuple x)
       -- indices below n select from the zero vector (first operand),
       -- indices from n on select from x (second operand)
       indices = take k [0..] ++ [n ..]
   in  LLVM.shufflevector (LLVM.value LLVM.zero) x
          (constVector (List.map constOf indices))

-- | Shift all elements @k@ positions down and fill the vacated top positions with zero.
shiftDownMultiZero ::
   (IsPrimitive a, IsPowerOf2 n) =>
   Int ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
shiftDownMultiZero k x =
   let -- indices running past the end of x select from the zero vector
       indices = [fromIntegral k ..]
   in  LLVM.shufflevector x (LLVM.value LLVM.zero)
          (constVector (List.map constOf indices))


{- |
A vector, or a tuple of vectors of equal size @n@,
that can be permuted with a constant index vector.
The size @n@ is determined by the type @v@.
For a single vector, 'shuffleMatch' is LLVM's @shufflevector@
with an undefined second operand.
For tuples, every component is shuffled with the same indices.
-}
class
   (LLVM.IsPowerOf2 n, Phi v) =>
      ShuffleMatch n v | v -> n where
   shuffleMatch ::
      ConstValue (Vector n Word32) -> v -> CodeGenFunction r v

-- | Apply 'shuffleMatch' with the same index vector
-- to every element of a traversable container.
shuffleMatchTraversable ::
   (ShuffleMatch n v, Trav.Traversable f) =>
   ConstValue (Vector n Word32) -> f v -> CodeGenFunction r (f v)
shuffleMatchTraversable is v =
   Trav.mapM (shuffleMatch is) v


{- |
Allows working on records of vectors as if they were vectors of records.
This is a reasonable approach for records of different element types,
since processor vectors can only be built from elements of the same type.
It also makes sense for, say, a chunked stereo signal.
In this case we would work on @Stereo (Value a)@.

@insert k x v@ stores element @x@ at index @k@ of @v@;
@extract k v@ reads the element at index @k@ of @v@.
-}
class
   (ShuffleMatch n v) =>
      Access n a v | v -> a n, a n -> v where
   insert :: Value Word32 -> a -> v -> CodeGenFunction r v
   extract :: Value Word32 -> v -> CodeGenFunction r a

-- | 'insert' at the same index for containers of elements and vectors.
-- Elements and vectors are combined with the container's 'liftA2'.
insertTraversable ::
   (Access n a v, Trav.Traversable f, App.Applicative f) =>
   Value Word32 -> f a -> f v -> CodeGenFunction r (f v)
insertTraversable n a v =
   Trav.sequence (liftA2 (insert n) a v)

-- | 'extract' the same index from every vector in a container.
extractTraversable ::
   (Access n a v, Trav.Traversable f) =>
   Value Word32 -> f v -> CodeGenFunction r (f a)
extractTraversable n v =
   Trav.mapM (extract n) v


-- | A single LLVM vector is shuffled by @shufflevector@
-- with an undefined second operand.
instance
   (LLVM.IsPowerOf2 n, LLVM.IsPrimitive a) =>
      ShuffleMatch n (Value (Vector n a)) where
   shuffleMatch is v = shufflevector v (value undef) is

-- | Element access on a single LLVM vector maps directly
-- to @insertelement@ and @extractelement@.
instance
   (LLVM.IsPowerOf2 n, LLVM.IsPrimitive a) =>
      Access n (Value a) (Value (Vector n a)) where
   insert  k a v = insertelement v a k
   extract k v   = extractelement v k


-- | A pair of vectors is shuffled component-wise with the same indices.
instance
   (ShuffleMatch n v0, ShuffleMatch n v1) =>
      ShuffleMatch n (v0, v1) where
   shuffleMatch is (v0,v1) = do
      w0 <- shuffleMatch is v0
      w1 <- shuffleMatch is v1
      return (w0, w1)

-- | A pair of vectors behaves like a vector of pairs.
instance
   (Access n a0 v0, Access n a1 v1) =>
      Access n (a0, a1) (v0, v1) where
   insert k (a0,a1) (v0,v1) = do
      w0 <- insert k a0 v0
      w1 <- insert k a1 v1
      return (w0, w1)
   extract k (v0,v1) = do
      e0 <- extract k v0
      e1 <- extract k v1
      return (e0, e1)


-- | A triple of vectors is shuffled component-wise with the same indices.
instance
   (ShuffleMatch n v0, ShuffleMatch n v1, ShuffleMatch n v2) =>
      ShuffleMatch n (v0, v1, v2) where
   shuffleMatch is (v0,v1,v2) = do
      w0 <- shuffleMatch is v0
      w1 <- shuffleMatch is v1
      w2 <- shuffleMatch is v2
      return (w0, w1, w2)

-- | A triple of vectors behaves like a vector of triples.
instance
   (Access n a0 v0, Access n a1 v1, Access n a2 v2) =>
      Access n (a0, a1, a2) (v0, v1, v2) where
   insert k (a0,a1,a2) (v0,v1,v2) = do
      w0 <- insert k a0 v0
      w1 <- insert k a1 v1
      w2 <- insert k a2 v2
      return (w0, w1, w2)
   extract k (v0,v1,v2) = do
      e0 <- extract k v0
      e1 <- extract k v1
      e2 <- extract k v2
      return (e0, e1, e2)


-- | Replace the element at index @k@ by the result of applying @f@ to it.
modify ::
   (Access n a va) =>
   Value Word32 ->
   (a -> CodeGenFunction r a) ->
   (va -> CodeGenFunction r va)
modify k f v = do
   elm <- extract k v
   elm' <- f elm
   insert k elm' v

{- |
Like LLVM.Util.Loop.mapVector but the loop is unrolled,
which is faster since it can be packed by the code generator.
Every element is extracted, transformed by @f@,
and inserted at the same position of an initially undefined result.
-}
map ::
   (Access n a va, Access n b vb) =>
   (a -> CodeGenFunction r b) ->
   (va -> CodeGenFunction r vb)
map f a =
   foldM step LLVM.undefTuple (take (sizeInTuple a) [0..])
   where
      step acc pos = do
         x <- extract (valueOf pos) a
         y <- f x
         insert (valueOf pos) y acc

{- |
Apply a chunk operation to every chunk of size @m@ of the vector
(as produced by 'chop'),
and write each resulting chunk back to the position its input came from.
-}
mapChunks ::
   (Access m a ca, Access m b cb,
    Access n a va, Access n b vb) =>
   (ca -> CodeGenFunction r cb) ->
   (va -> CodeGenFunction r vb)
mapChunks f a =
   foldM step LLVM.undefTuple (List.zip (chop a) [0..])
   where
      step acc (getChunk, k) = do
         chunkIn <- getChunk
         chunkOut <- f chunkIn
         insertChunk (k * sizeInTuple chunkIn) chunkOut acc

-- | Chunk-wise combination of two vectors; see 'mapChunks'.
zipChunksWith ::
   (Access m a ca, Access m b cb, Access m c cc,
    Access n a va, Access n b vb, Access n c vc) =>
   (ca -> cb -> CodeGenFunction r cc) ->
   (va -> vb -> CodeGenFunction r vc)
zipChunksWith f a b =
   mapChunks (\(x, y) -> f x y) (a, b)


{- |
Element-wise map with an optional chunk-wise implementation @g@
taken from an extension.
@f@ is the element-wise fallback.
NOTE(review): assumes 'Ext.run' uses the extension code when the
extension is available and the fallback otherwise.
-}
mapAuto ::
   (Access m a ca, Access m b cb,
    Access n a va, Access n b vb) =>
   (a -> CodeGenFunction r b) ->
   Ext.T (ca -> CodeGenFunction r cb) ->
   (va -> CodeGenFunction r vb)
mapAuto f g a =
   Ext.run (map f a) (Ext.with g (\chunkOp -> mapChunks chunkOp a))

-- | Two-argument version of 'mapAuto'.
zipAutoWith ::
   (Access m a ca, Access m b cb, Access m c cc,
    Access n a va, Access n b vb, Access n c vc) =>
   (a -> b -> CodeGenFunction r c) ->
   Ext.T (ca -> cb -> CodeGenFunction r cc) ->
   (va -> vb -> CodeGenFunction r vc)
zipAutoWith f g a b =
   mapAuto (\(x, y) -> f x y) (fmap (\h (x, y) -> h x y) g) (a, b)


{- |
Sum of the products of the first @n@ corresponding elements.
Ideally on ix86 with SSE41 this would be translated to 'dpps'.
-}
dotProductPartial ::
   (LLVM.IsPowerOf2 n, LLVM.IsPrimitive a, LLVM.IsArithmetic a) =>
   Int ->
   Value (Vector n a) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value a)
dotProductPartial n x y =
   A.mul x y >>= sumPartial n

-- | Sum of the first @n@ vector elements, added from left to right.
-- @n@ must be at least 1.
sumPartial ::
   (LLVM.IsPowerOf2 n, LLVM.IsPrimitive a, LLVM.IsArithmetic a) =>
   Int ->
   Value (Vector n a) ->
   CodeGenFunction r (Value a)
sumPartial n x =
   foldl1
      {- quite the same as (+) using LLVM.Arithmetic instances,
         but requires less type constraints -}
      (M.liftR2 A.add)
      [LLVM.extractelement x (valueOf k) | k <- take n [0..]]


{- |
Split a vector (or tuple of vectors) of size @n@ into chunks of size @m@.
The chunk size @m@ is determined by the result type.
Each chunk is obtained by 'shuffle' with consecutive element indices.
If the target vector type is a native type
then the chop operation produces no actual machine instruction. (nop)
If the vector cannot be evenly divided into chunks
the last chunk will be padded with undefined values.
NOTE(review): the short index list of the last chunk is extended to the
chunk size by 'constVector'; confirm that this really yields undefined
elements rather than repeated indices.
-}
chop ::
   (Access m a ca, Access n a va) =>
   va -> [CodeGenFunction r ca]
chop = chopCore undefined

-- | Worker for 'chop'.
-- Only the type of the first argument matters: it fixes the chunk size @m@.
-- ('chop' passes 'undefined' here.)
chopCore ::
   (Access m a ca, Access n a va) =>
   m -> va -> [CodeGenFunction r ca]
chopCore m x =
   List.map (shuffle x . constVector) $
   ListHT.sliceVertical (TypeNum.toInt m) $
   List.map constOf $
   take (sizeInTuple x) [0..]

{- |
The target size is determined by the type.
If the chunk list provides more data, the exceeding data is dropped.
If the chunk list provides too little data,
the remaining target elements stay undefined.
-}
concat ::
   (Access m a ca, Access n a va) =>
   [ca] -> CodeGenFunction r va
concat = concatInto LLVM.undefTuple

{- |
Worker for 'concat'.
Copies the chunk elements, in order, to consecutive positions of @v0@,
starting at position 0.
Only as many elements are inserted as fit into @v0@:
an 'insertelement' with an out-of-range index
would make the whole result vector undefined.
-}
concatInto ::
   (Access m a ca, Access n a va) =>
   va -> [ca] -> CodeGenFunction r va
concatInto v0 xs =
   foldM
      (\v (c, (i, j)) -> do
         x <- extract (valueOf i) c
         insert (valueOf j) x v)
      v0 $
   -- drop every element whose target position j lies outside v0
   List.take (sizeInTuple v0) $
   List.concatMap (\(c, ijs) -> List.map ((,) c) ijs) $
   List.zip xs $
   -- pair each source index i within a chunk with its target index j
   List.map (List.zip [0..]) $
   ListHT.sliceVertical (chunkSize xs) [0..]
   where
      -- chunk size taken from the element type; no list element is evaluated
      chunkSize :: (ShuffleMatch k c) => [c] -> Int
      chunkSize =
         let sz :: (ShuffleMatch k c) => k -> [c] -> Int
             sz t _ = TypeNum.toInt t
         in  sz undefined


-- | Elements 0 and 1 of a vector.
getLowestPair ::
   Value (Vector n a) ->
   CodeGenFunction r (Value a, Value a)
getLowestPair x = do
   e0 <- extractelement x (valueOf 0)
   e1 <- extractelement x (valueOf 1)
   return (e0, e1)


-- | Add the lower @m@ elements and the next @m@ elements of @v@ element-wise.
-- The result has size @m@; the first argument only carries @m@ at the type level.
_reduceAddInterleaved ::
   (IsArithmetic a, IsPrimitive a,
    IsPowerOf2 n, IsPowerOf2 m, TypeNum.Mul D2 m n) =>
   m ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector m a))
_reduceAddInterleaved tm v = do
   let m = TypeNum.toInt tm
       lowerIdx = take m [0..]
       upperIdx = take m [fromIntegral m ..]
   lower <- shuffle v (constVector (List.map constOf lowerIdx))
   upper <- shuffle v (constVector (List.map constOf upperIdx))
   A.add lower upper

-- | Sum of all vector elements via repeated vector halving.
sumGeneric ::
   (IsArithmetic a, IsPrimitive a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value a)
sumGeneric v = do
   reduced <- reduceSumInterleaved 1 v
   extractelement reduced (valueOf 0)

-- | Sums of the lower and of the upper half of the vector.
-- Element @k@ of the lower half is interleaved with element @k@ of the
-- upper half, and then the interleaved pairs are summed.
sumToPairGeneric ::
   (Arithmetic a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value a, Value a)
sumToPairGeneric v =
   let half = div (size v) 2
       perm = concatMap (\k -> [k, k + half]) (take half [0..])
       mask = constVector (List.map (constOf . fromIntegral) perm)
   in  sumInterleavedToPair =<< shufflevector v (value undef) mask

{- |
We partition a vector of size n into chunks of size m
and add these chunks using vector additions.
We do this by repeated halving of the vector,
since this way we do not need assumptions about the native vector size.
Afterwards, element k (for k < m) holds the sum of all elements
whose index is congruent to k modulo m.
m must be a power of two that is not larger than n.
Otherwise code generation stops with an error
instead of halving forever.

We reduce the vector size only virtually,
that is we maintain the vector size and fill with undefined values.
This is reasonable
since LLVM-2.5 and LLVM-2.6 do not allow shuffling between vectors of different size
and because it likes to do computations on Vector D2 Float
in MMX registers on ix86 CPU's,
which interacts badly with FPU usage.
Since we fill the vector with undefined values,
LLVM actually treats the vectors like vectors of smaller size.
-}
reduceSumInterleaved ::
   (IsArithmetic a, IsPrimitive a, IsPowerOf2 n) =>
   Int ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
reduceSumInterleaved m x0 =
   let go ::
          (IsArithmetic a, IsPrimitive a, IsPowerOf2 n) =>
          Int ->
          Value (Vector n a) ->
          CodeGenFunction r (Value (Vector n a))
       go n x
          | n == m = return x
          -- n only ever shrinks, so once it is below m it can never reach m
          | n < m =
               error $
                  "Vector.reduceSumInterleaved: cannot reduce a vector of size " ++
                  show (size x0) ++ " to " ++ show m ++ " interleaved sums"
          | otherwise =
               let n2 = div n 2
               in  go n2
                      =<< A.add x
                      =<< shufflevector x (value undef)
                             (constVector $ List.map constOf (take n2 [fromIntegral n2 ..])
                                 ++ List.repeat undef)
   in  go (size x0) x0

{- |
Prefix sums with start value @a@, exclusive of the current element:
element k of the result vector is @a@ plus the sum of the elements
0 to k-1 of @x@.
The first result component is @a@ plus the sum of all elements of @x@.

'_cumulateSimple' works element by element;
'cumulateGeneric' uses the logarithmic 'cumulate1'.
-}
cumulateGeneric, _cumulateSimple ::
   (IsArithmetic a, IsPrimitive a, IsPowerOf2 n) =>
   Value a -> Value (Vector n a) ->
   CodeGenFunction r (Value a, Value (Vector n a))
_cumulateSimple a x =
   foldM
      (\(a0,y0) k -> do
         -- a0 is the running sum before element k
         a1 <- A.add a0 =<< extract (valueOf k) x
         y1 <- insert (valueOf k) a0 y0
         return (a1,y1))
      (a, LLVM.undefTuple)
      (take (sizeInTuple x) $ [0..])

cumulateGeneric =
   cumulateFrom1 cumulate1

{- |
Turn a prefix-sum function @cum@ into the start-value variant above.
@a@ is shifted in at the bottom and @cum@ is applied.
The element shifted out at the top is then added to the last result element,
which gives the total.
This assumes that @cum@ computes sums that include the current element,
as 'cumulate1' does.
-}
cumulateFrom1 ::
   (IsArithmetic a, IsPrimitive a, IsPowerOf2 n) =>
   (Value (Vector n a) ->
    CodeGenFunction r (Value (Vector n a))) ->
   Value a -> Value (Vector n a) ->
   CodeGenFunction r (Value a, Value (Vector n a))
cumulateFrom1 cum a x0 = do
   (b,x1) <- shiftUp a x0
   y <- cum x1
   z <- A.add b =<< extract (valueOf (fromIntegral (sizeInTuple x0) - 1)) y
   return (z,y)


{- |
Prefix sums including the current element:
element k of the result is the sum of elements 0 to k of @x@.
Needs (log n) vector additions.
-}
cumulate1 ::
   (IsArithmetic a, IsPrimitive a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
cumulate1 x =
   foldM addShifted x distances
   where
      -- 1, 2, 4, ... as long as smaller than the vector size
      distances = takeWhile (< sizeInTuple x) (List.iterate (2*) 1)
      addShifted y k = shiftUpMultiZero k y >>= A.add y


-- | @x - truncate x@, thus the fraction has the same sign as @x@.
signedFraction ::
   (IsFloating a, IsConst a, Real a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
signedFraction x =
   truncate x >>= A.sub x

-- | Target independent 'floor' based on a vector comparison.
floorGeneric ::
   (IsFloating a, IsConst a, Real a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
floorGeneric x = floorLogical A.fcmp x

{- |
On LLVM-2.6 and X86 this produces branch-free
but even slower code than 'fractionSelect',
since the comparison to booleans and
back to a floating point number is translated literally
to elementwise comparison, conversion to a 0 or -1 byte
and then to a floating point number.
-}
fractionGeneric ::
   (IsFloating a, IsConst a, Real a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
fractionGeneric x = fractionLogical A.fcmp x


{- |
LLVM.select on boolean vectors cannot be translated to X86 code in LLVM-2.6,
thus I code my own version that calls select on all elements.
This is slow but works.
When this issue is fixed, this function will be replaced by LLVM.select.
-}
select ::
   (LLVM.IsFirstClass a, IsPrimitive a, IsPowerOf2 n,
    LLVM.CmpRet a Bool) =>
   Value (Vector n Bool) ->
   Value (Vector n a) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
select b x y =
   map (\(cond, onTrue, onFalse) -> LLVM.select cond onTrue onFalse) (b, x, y)

{- |
'floor' implemented using 'select'.
This will need jumps.
Where truncation went above @x@, one is subtracted.
-}
_floorSelect ::
   (Num a, IsFloating a, IsConst a, Real a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
_floorSelect x = do
   xr <- truncate x
   keep <- A.fcmp LLVM.FPOLE xr x
   one <- replicate (valueOf 1)
   lowered <- A.sub xr one
   select keep xr lowered

{- |
'fraction' implemented using 'select'.
This will need jumps.
Negative signed fractions are moved into the range [0,1) by adding one.
-}
_fractionSelect ::
   (Num a, IsFloating a, IsConst a, Real a, IsPowerOf2 n) =>
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
_fractionSelect x = do
   xf <- signedFraction x
   nonNeg <- A.fcmp LLVM.FPOGE xf (value LLVM.zero)
   one <- replicate (valueOf 1)
   wrapped <- A.add xf one
   select nonNeg xf wrapped


{- |
Another implementation of 'select',
this time in terms of binary logical operations.
The selecting integers must be
(-1) for selecting an element from the first operand
and 0 for selecting an element from the second operand.
This leads to optimal code.

On SSE41 this could be done with blendvps or blendvpd.
-}
selectLogical ::
   (LLVM.IsFirstClass a, IsPrimitive a,
    LLVM.IsInteger i, IsPrimitive i,
--    LLVM.IsSized a sa, LLVM.IsSized i si, sa :==: si, si :==: sa,
--    LLVM.IsSized a s, LLVM.IsSized i s,
    LLVM.IsSized (Vector n a) s, LLVM.IsSized (Vector n i) s,
    IsPowerOf2 n) =>
   Value (Vector n i) ->
   Value (Vector n a) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
selectLogical b x y = do
--   bneg <- A.xor b
   -- the complemented mask selects the elements of the second operand
   bneg <- LLVM.inv b
   -- operands are reinterpreted as integer vectors of the same total size
   xm <- A.and b    =<< LLVM.bitcastUnify x
   ym <- A.and bneg =<< LLVM.bitcastUnify y
   LLVM.bitcastUnify =<< A.or xm ym


{- |
'floor' from 'truncate' and an integer comparison mask.
The mask is converted to floating point and added.
This presumes that @cmp@ yields (-1) where the truncated value
exceeds @x@, and 0 elsewhere.
-}
floorLogical ::
   (IsFloating a, IsConst a, Real a,
    IsPrimitive i, LLVM.IsInteger i, IsPowerOf2 n) =>
   (LLVM.FPPredicate ->
    Value (Vector n a) ->
    Value (Vector n a) ->
    CodeGenFunction r (Value (Vector n i))) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
floorLogical cmp x = do
   xr <- truncate x
   roundedUp <- cmp LLVM.FPOGT xr x
   offset <- LLVM.sitofp roundedUp
   A.add xr offset

{- |
'fraction' from 'signedFraction' and an integer comparison mask.
The mask is converted to floating point and subtracted.
This presumes that @cmp@ yields (-1) where the signed fraction is negative,
so that one is added there.
-}
fractionLogical ::
   (IsFloating a, IsConst a, Real a,
    IsPrimitive i, LLVM.IsInteger i, IsPowerOf2 n) =>
   (LLVM.FPPredicate ->
    Value (Vector n a) ->
    Value (Vector n a) ->
    CodeGenFunction r (Value (Vector n i))) ->
   Value (Vector n a) ->
   CodeGenFunction r (Value (Vector n a))
fractionLogical cmp x = do
   xf <- signedFraction x
   negMask <- cmp LLVM.FPOLT xf (value LLVM.zero)
   offset <- LLVM.sitofp negMask
   A.sub xf offset


{- |
Build a chunk-wise minimum from a comparison extension.
@cmp x y@ is expected to yield an integer mask that is (-1) where @x > y@.
Where the mask is set, the element of @y@ is chosen;
elsewhere, the element of @x@ is chosen (via 'selectLogical').
Passing the comparison with flipped arguments yields the maximum.
-}
orderBy ::
   (IsPowerOf2 m,
    LLVM.IsFirstClass a, IsPrimitive a,
    LLVM.IsInteger i, IsPrimitive i,
    LLVM.IsSized (Vector m a) s, LLVM.IsSized (Vector m i) s) =>
   Ext.T (Value (Vector m a) -> Value (Vector m a) -> CodeGenFunction r (Value (Vector m i))) ->
   Ext.T (Value (Vector m a) -> Value (Vector m a) -> CodeGenFunction r (Value (Vector m a)))
orderBy cmp =
   Ext.with cmp $ \pcmpgt x y ->
      pcmpgt x y >>= \b -> selectLogical b y x

{- |
Element-wise minimum or maximum of vectors of arbitrary size.
There are three alternative implementations:
the scalar function @byScalar@ applied element by element,
a comparison extension @byCmp@ combined with 'selectLogical' on chunks,
and a dedicated chunk instruction @byChunk@.
NOTE(review): assumes 'Ext.run' prefers its extension argument when available,
so the later alternatives take precedence.
-}
order ::
   (IsPowerOf2 n, IsPowerOf2 m,
    LLVM.IsFirstClass a, IsPrimitive a,
    LLVM.IsInteger i, IsPrimitive i,
    LLVM.IsSized (Vector m a) s, LLVM.IsSized (Vector m i) s) =>
   (Value a -> Value a -> CodeGenFunction r (Value a)) ->
   Ext.T (Value (Vector m a) -> Value (Vector m a) -> CodeGenFunction r (Value (Vector m i))) ->
   Ext.T (Value (Vector m a) -> Value (Vector m a) -> CodeGenFunction r (Value (Vector m a))) ->
   (Value (Vector n a) -> Value (Vector n a) -> CodeGenFunction r (Value (Vector n a)))
order byScalar byCmp byChunk x y =
   map (uncurry byScalar) (x,y)
   `Ext.run`
   (Ext.with byCmp $ \pcmpgt ->
      mapChunks (\(cx,cy) ->
         pcmpgt cx cy >>= \b -> selectLogical b cy cx) (x,y))
{-
This is not nice, because selectLogical uses bitcast
and bitcast requires ugly type constraints for equal vector sizes.
Thus we restrict selectLogical to chunks and thus monomorphic types.
   (Ext.with byCmp $ \pcmpgt -> do
       b <- mapChunks (uncurry pcmpgt) (x,y)
       selectLogical b y x)
-}
   `Ext.run`
   (Ext.with byChunk $ \psel ->
       zipChunksWith psel x y)


-- * target independent functions with target dependent optimizations

{- |
The order of addition is chosen for maximum efficiency.
We do not try to prevent cancellation.
-}
class (IsArithmetic a, IsPrimitive a) => Arithmetic a where
   -- | Sum of all vector elements.
   sum ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value a)
   sum = sumGeneric

   {- |
   The first result value is the sum of all vector elements from 0 to @div n 2 - 1@
   and the second result value is the sum of vector elements from @div n 2@ to @n-1@.
   n must be at least D2.
   -}
   sumToPair ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value a, Value a)
   sumToPair = sumToPairGeneric

   {- |
   Treat the vector as concatenation of pairs and all these pairs are added.
   Useful for stereo signal processing.
   n must be at least D2.
   -}
   sumInterleavedToPair ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value a, Value a)
   sumInterleavedToPair v =
      getLowestPair =<< reduceSumInterleaved 2 v

   {- |
   Prefix sums with a start value.
   With the default implementation, element k of the result vector
   is the start value plus the sum of elements 0 to k-1.
   The first result component is the start value plus the sum of all elements.
   -}
   cumulate ::
      (IsPowerOf2 n) =>
      Value a -> Value (Vector n a) ->
      CodeGenFunction r (Value a, Value (Vector n a))
   cumulate = cumulateGeneric

   -- | Sum of the element-wise products.
   dotProduct ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      Value (Vector n a) ->
      CodeGenFunction r (Value a)
   dotProduct x y =
      dotProductPartial (size x) x y

   -- | Element-wise multiplication.
   mul ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))
   mul = A.mul

{-
Float has specialised implementations based on the X86 instructions
haddps (horizontal add) and dpps (dot product).
They work on chunks of four Floats,
and are thus only used for vectors with at least four elements.
-}
instance Arithmetic Float where
   sum x =
      Ext.runWhen (size x >= 4) (sumGeneric x) $
      Ext.with X86.haddps $ \haddp ->
          {-
          We can make use of the following facts:
          SSE3 has Float vectors of size 4,
          there is an instruction for horizontal add.
          -}
          do chunkSum <-
                foldl1 (M.liftR2 A.add) $ chop x
             y <- haddp chunkSum (value undef)
             z <- haddp y        (value undef)
{-
             y <- haddp chunkSum chunkSum
             z <- haddp y y
-}
             extractelement z (valueOf 0)

   sumToPair x =
      {-
      The fallback must yield the two half sums for every vector size.
      'getLowestPair' alone is only correct for vectors of size 2;
      for larger vectors it would return two single elements.
      -}
      Ext.runWhen (size x >= 4) (sumToPairGeneric x) $
      Ext.with X86.haddps $ \haddp ->
          let {-
              reduce ::
                 [CodeGenFunction r (Value (Vector D4 Float))] ->
                 [CodeGenFunction r (Value (Vector D4 Float))]
              -}
              -- every haddp halves the number of chunks
              reduce [] = []
              reduce [_] = error "vector must have size power of two"
              reduce (x0:x1:xs) =
                 M.liftR2 haddp x0 x1 : reduce xs
              go []  = error "vector must not be empty"
              go [c] =
                 getLowestPair
                    =<< flip haddp (value undef)
                    =<< c
              go cs  = go (reduce cs)
          in  go $ chop x

{-
The haddps based implementation cumulate is slower than the generic one.
However, one day the x86 processors may implement a cumulative sum
which we could employ with this frame.

   cumulate a x =
      Ext.runWhen (size x >= 4) (cumulateGeneric a x) $
      Ext.with X86.cumulate1s $ \cumulate1s -> do
         (b,ys) <-
            foldr
               (\chunk0 cont a0 -> do
                  (a1,chunk1) <- cumulateFrom1 cumulate1s a0 =<< chunk0
                  fmap (mapSnd (chunk1:)) (cont a1))
               (\a0 -> return (a0,[]))
               (chop x)
               a
         y <- concat ys
         return (b,y)
-}

   dotProduct x y =
      {-
      dpps presumably works on chunks of four Floats, like haddps.
      For smaller vectors a chunk would not consist of the elements
      of x and y alone, thus use the generic code for them, as 'sum' does.
      -}
      Ext.runWhen (size x >= 4) (sum =<< A.mul x y) $
      Ext.with X86.dpps $ \dpp ->
         foldl1 (M.liftR2 A.add) $
         List.zipWith
            (\mx my -> do
               cx <- mx
               cy <- my
               -- mask 0xF1: multiply all four lanes, store the sum in lane 0
               flip extractelement (valueOf 0)
                =<< dpp cx cy (valueOf 0xF1))
            (chop x)
            (chop y)

-- These element types use the default methods of 'Arithmetic' unchanged.
instance Arithmetic Double where

instance Arithmetic Int8   where
instance Arithmetic Int16  where
instance Arithmetic Int32  where
instance Arithmetic Int64  where
instance Arithmetic Word8  where
instance Arithmetic Word16 where
instance Arithmetic Word64 where

{-
'mul' for Word32 has three alternatives:
generic multiplication, emulation with pmuludq, and pmulld on whole chunks.
pmuludq only uses elements 0 and 2 of its operands
and yields two 64-bit products.
Thus even and odd elements are multiplied separately,
and the low 32 bits of the four products are merged back.
-}
instance Arithmetic Word32 where
   mul x y =
      A.mul x y
      `Ext.run`
      (Ext.with X86.pmuludq $ \pmul ->
         zipChunksWith
            (\cx cy -> do
               -- elements 0 and 2 stay at the positions used by pmuludq
               evenX <- LLVM.shufflevector cx (value undef)
                  (constVector [constOf 0, undef, constOf 2, undef])
               evenY <- LLVM.shufflevector cy (value undef)
                  (constVector [constOf 0, undef, constOf 2, undef])
               evenZ64 <- pmul evenX evenY
               -- view the two 64-bit products as four 32-bit words
               evenZ <- LLVM.bitcastUnify evenZ64
               -- move elements 1 and 3 to positions 0 and 2
               oddX <- LLVM.shufflevector cx (value undef)
                  (constVector [constOf 1, undef, constOf 3, undef])
               oddY <- LLVM.shufflevector cy (value undef)
                  (constVector [constOf 1, undef, constOf 3, undef])
               oddZ64 <- pmul oddX oddY
               oddZ <- LLVM.bitcastUnify oddZ64
               -- words 0 and 2 of each bitcast are presumably the low halves
               -- (little endian); interleave even and odd results
               LLVM.shufflevector evenZ oddZ
                  (constVector [constOf 0, constOf 4, constOf 2, constOf 6]))
            x y)
      `Ext.run`
      (Ext.with X86.pmulld $ \pmul ->
         zipChunksWith pmul x y)


{- |
Multiply unsigned 32-bit elements to full 64-bit products.
The generic version zero-extends both operands and multiplies.
With pmuludq, the even and odd elements of each chunk are multiplied
separately, and the products are reassembled in element order.
-}
umul32to64 ::
   (IsPowerOf2 n) =>
   Value (Vector n Word32) ->
   Value (Vector n Word32) ->
   CodeGenFunction r (Value (Vector n Word64))
umul32to64 x y =
   (do x64 <- map LLVM.zext x
       y64 <- map LLVM.zext y
       A.mul x64 y64)
   `Ext.run`
   (Ext.with X86.pmuludq $ \pmul ->
      zipChunksWith
         -- save an initial shuffle
         (\cx cy -> do
            evenX <- LLVM.shufflevector cx (value undef)
               (constVector [constOf 0, undef, constOf 2, undef])
            evenY <- LLVM.shufflevector cy (value undef)
               (constVector [constOf 0, undef, constOf 2, undef])
            -- products of elements 0 and 2
            evenZ <- pmul evenX evenY
            oddX <- LLVM.shufflevector cx (value undef)
               (constVector [constOf 1, undef, constOf 3, undef])
            oddY <- LLVM.shufflevector cy (value undef)
               (constVector [constOf 1, undef, constOf 3, undef])
            -- products of elements 1 and 3
            oddZ <- pmul oddX oddY
{-
            LLVM.shufflevector evenZ oddZ
               (constVector [constOf 0, constOf 2, constOf 1, constOf 3])
-}
            -- restore the element order 0, 1, 2, 3
            assemble =<< (sequence $
               extract (valueOf 0) evenZ :
               extract (valueOf 0) oddZ :
               extract (valueOf 1) evenZ :
               extract (valueOf 1) oddZ :
               []))
{-
         -- save the final shuffle
         (\cx cy -> do
            lowerX <- LLVM.shufflevector cx (value undef)
               (constVector [constOf 0, undef, constOf 1, undef])
            lowerY <- LLVM.shufflevector cy (value undef)
               (constVector [constOf 0, undef, constOf 1, undef])
            lowerZ <- pmul lowerX lowerY
            upperX <- LLVM.shufflevector cx (value undef)
               (constVector [constOf 2, undef, constOf 3, undef])
            upperY <- LLVM.shufflevector cy (value undef)
               (constVector [constOf 2, undef, constOf 3, undef])
            upperZ <- pmul upperX upperY
{-
            LLVM.shufflevector lowerZ upperZ
               (constVector [constOf 0, constOf 1, constOf 2, constOf 3])
-}
            concat [lowerZ, upperZ])
-}
         x y)


{- |
Attention:
The rounding and fraction functions only work
for floating point values with maximum magnitude of @maxBound :: Int32@.
This way we save expensive handling of possibly rare cases.
-}
class (Arithmetic a, LLVM.CmpRet a Bool, IsConst a) =>
         Real a where
   -- | Element-wise minimum and maximum.
   min, max ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))

   -- | Element-wise absolute value.
   abs ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))

   -- | Element-wise rounding towards zero, rounding towards minus infinity,
   -- and the non-negative fraction @x - floor x@.
   truncate, floor, fraction ::
      (IsPowerOf2 n) =>
      Value (Vector n a) ->
      CodeGenFunction r (Value (Vector n a))

{-
Each operation lists a generic implementation first.
Each `Ext.run` then adds an alternative based on an X86 extension.
NOTE(review): assumes 'Ext.run' prefers the extension when it is available.
-}
instance Real Float where
   min = zipAutoWith A.fmin X86.minps
   max = zipAutoWith A.fmax X86.maxps
   abs = mapAuto A.fabs X86.absps
   {-
   An IEEE specific implementation could do some bit manipulation:
   s eeeeeeee mmmmmmmmmmmmmmmmmmmmmmm
   Generate a pure power of two by clearing mantissa:
   s eeeeeeee 00000000000000000000000
   Now subtract 1 in order to get the required bit mask for the mantissa
   s eeeeeeee 11111111110000000000000
   multiply with 2 in order to correct exponent
   and then do bitwise AND of the mask with the original number.
   This method only works for numbers from 1 to 2^23-1,
   that is the range is even smaller
   than that for the rounding via Int32.
   -}
   truncate x =
      -- round trip through Int32, hence the range restriction of 'Real'
      (LLVM.sitofp .
       (id :: Value (Vector n Int32) -> Value (Vector n Int32))
       <=< LLVM.fptosi) x
      `Ext.run`
      -- roundps with rounding mode 3 is used for truncation
      (Ext.with X86.roundps $ \round ->
          mapChunks (flip round (valueOf 3)) x)
   floor x =
      floorGeneric x
      `Ext.run`
      (Ext.with X86.cmpps $ \cmp ->
          mapChunks (floorLogical cmp) x)
{- LLVM-2.6 rearranges the MXCSR manipulations in an invalid way
      `Ext.run`
      (Ext.with2 (X86.withMXCSR (Bit.shiftL 1 13)) X86.cvtps2dq $
          \ with cvtps2dq -> with $
             LLVM.sitofp =<< mapChunks cvtps2dq x)
-}
      `Ext.run`
      -- roundps with rounding mode 1 is used for floor
      (Ext.with X86.roundps $ \round ->
          mapChunks (flip round (valueOf 1)) x)
   fraction x =
      fractionGeneric x
      `Ext.run`
      (Ext.with X86.cmpps $ \cmp ->
          mapChunks (fractionLogical cmp) x)
{-
      `Ext.run`
      (Ext.with2 (X86.withMXCSR (Bit.shiftL 1 13)) X86.cvtps2dq $
          \ with cvtps2dq -> with $
             A.sub x =<< LLVM.sitofp =<< mapChunks cvtps2dq x)
-}
      `Ext.run`
      -- x - floor x, with floor as above
      (Ext.with X86.roundps $ \round ->
          mapChunks (\c -> A.sub c =<< flip round (valueOf 1) c) x)

{-
Same structure as the Float instance,
using the double precision X86 instructions.
The generic truncation makes a round trip through Int64.
-}
instance Real Double where
   min = zipAutoWith A.fmin X86.minpd
   max = zipAutoWith A.fmax X86.maxpd
   abs = mapAuto A.fabs X86.abspd
   truncate x =
      (LLVM.sitofp .
       (id :: Value (Vector n Int64) -> Value (Vector n Int64))
       <=< LLVM.fptosi) x
      `Ext.run`
      -- roundpd with rounding mode 3 is used for truncation
      (Ext.with X86.roundpd $ \round ->
          mapChunks (flip round (valueOf 3)) x)
   floor x =
      floorGeneric x
      `Ext.run`
      (Ext.with X86.cmppd $ \cmp ->
          mapChunks (floorLogical cmp) x)
      `Ext.run`
      -- roundpd with rounding mode 1 is used for floor
      (Ext.with X86.roundpd $ \round ->
          mapChunks (flip round (valueOf 1)) x)
   fraction x =
      fractionGeneric x
      `Ext.run`
      (Ext.with X86.cmppd $ \cmp ->
          mapChunks (fractionLogical cmp) x)
      `Ext.run`
      (Ext.with X86.roundpd $ \round ->
          mapChunks (\c -> A.sub c =<< flip round (valueOf 1) c) x)

{-
Integer instances:
rounding is the identity and the fraction is always zero.
min and max use a scalar fallback, a comparison mask ('order', 'orderBy'),
or a dedicated min/max instruction.
For unsigned types abs is the identity.
-}
instance Real Int8 where
   min = order A.smin X86.pcmpgtb X86.pminsb
   max = order A.smax (fmap flip X86.pcmpgtb) X86.pmaxsb
   abs = mapAuto A.sabs X86.pabsb
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)

instance Real Int16 where
   min = order A.smin X86.pcmpgtw X86.pminsw
   max = order A.smax (fmap flip X86.pcmpgtw) X86.pmaxsw
   abs = mapAuto A.sabs X86.pabsw
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)

instance Real Int32 where
   min = order A.smin X86.pcmpgtd X86.pminsd
   max = order A.smax (fmap flip X86.pcmpgtd) X86.pmaxsd
   abs = mapAuto A.sabs X86.pabsd
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)

-- There is no dedicated 64-bit min/max instruction here,
-- thus min and max are built from a comparison only.
instance Real Int64 where
   min = zipAutoWith A.smin (orderBy X86.pcmpgtq)
   max = zipAutoWith A.smax (orderBy (fmap flip X86.pcmpgtq))
   -- abs x = max x (-x)
   abs = mapAuto A.sabs $
      Ext.with (orderBy (fmap flip X86.pcmpgtq)) $
         \smax x -> smax x =<< LLVM.neg x
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)

instance Real Word8 where
   min = order A.umin X86.pcmpugtb X86.pminub
   max = order A.umax (fmap flip X86.pcmpugtb) X86.pmaxub
   abs = return
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)

instance Real Word16 where
   min = order A.umin X86.pcmpugtw X86.pminuw
   max = order A.umax (fmap flip X86.pcmpugtw) X86.pmaxuw
   abs = return
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)

instance Real Word32 where
   min = order A.umin X86.pcmpugtd X86.pminud
   max = order A.umax (fmap flip X86.pcmpugtd) X86.pmaxud
   abs = return
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)

instance Real Word64 where
   min = zipAutoWith A.umin (orderBy X86.pcmpugtq)
   max = zipAutoWith A.umax (orderBy (fmap flip X86.pcmpugtq))
   abs = return
   truncate = return
   floor = return
   fraction = const $ return (value LLVM.zero)