{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Test.Vector where

import qualified LLVM.Extra.ScalarOrVectorPrivate as SoVPriv
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.VectorAlt as VectorAlt
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.ExecutionEngine as EE
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Base.Proxy (Proxy(Proxy))

import Foreign.Ptr (FunPtr)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.Bits as Bits
import Data.Word (Word8, Word16, Word32)
import Data.Int (Int8, Int32)

import qualified Test.QuickCheck as QC
import qualified Test.QuickCheck.Monadic as QCMon

import Control.Applicative (liftA2, pure)

import qualified Prelude as P
import Prelude hiding (min, max)


-- | Four-element LLVM vector, left unsaturated so it can be applied to any
-- element type.
type V4 = LLVM.Vector TypeNum.D4

-- | Five-element LLVM vector (used below to hold five 3-bit lanes).
type V5 = LLVM.Vector TypeNum.D5

-- Element-type specialisations of four-element vectors.
type V4Word32 = LLVM.Vector TypeNum.D4 Word32
type V4Int32  = LLVM.Vector TypeNum.D4 Int32
type V4Float  = LLVM.Vector TypeNum.D4 Float

-- | Turns a raw function pointer into a callable Haskell function, as
-- provided by a @foreign import ccall "dynamic"@ declaration.
type Importer f = FunPtr f -> f

-- | Compile one LLVM function into a fresh module targeting the host triple,
-- obtain a callable Haskell function for it via the execution engine, and
-- write the module's bitcode to @test-vector.bc@. The same path is used on
-- every call; presumably this is kept for debugging.
generateFunction ::
   EE.ExecutionFunction f =>
   Importer f -> LLVM.CodeGenModule (LLVM.Function f) -> IO f
generateFunction imprt code = do
   mdl <- LLVM.newModule
   llvmFunc <- LLVM.defineModule mdl (LLVM.setTarget LLVM.hostTriple >> code)
   callable <-
      EE.runEngineAccessWithModule mdl (EE.getExecutionFunction imprt llvmFunc)
   LLVM.writeBitcodeToFile "test-vector.bc" mdl
   return callable


-- | Wrap a JIT-compiled test entry point, which takes an input pointer and an
-- output pointer, as a Haskell IO action (a @dynamic@ foreign import).
foreign import ccall safe "dynamic" derefTestCasePtr ::
   Importer (LLVM.Ptr inp -> LLVM.Ptr out -> IO ())

-- | Build an externally visible LLVM function that loads its argument from
-- the input pointer, runs @codegen@ on it, and stores the result through the
-- output pointer.
modul ::
   (Memory.C linp, Memory.Struct linp ~ minp, LLVM.IsType minp,
    Memory.C lout, Memory.Struct lout ~ mout, LLVM.IsType mout) =>
   (linp -> LLVM.CodeGenFunction () lout) ->
   LLVM.CodeGenModule (LLVM.Function (LLVM.Ptr minp -> LLVM.Ptr mout -> IO ()))
modul codegen =
   LLVM.createFunction LLVM.ExternalLinkage body
   where
      body inPtr outPtr = do
         arg <- Memory.load inPtr
         result <- codegen arg
         Memory.store result outPtr
         LLVM.ret ()

-- | Compile @codegen@ into native code (via 'modul' and 'generateFunction')
-- and turn it into a QuickCheck property. For each random input, the input
-- is marshalled into memory, the compiled code writes its result into a
-- freshly allocated output slot, and @predicate@ must accept the
-- (input, output) pair. Compilation happens once, before the property runs.
run ::
   (Marshal.C inp, Marshal.Struct inp ~ minp, LLVM.IsType minp,
    Marshal.C out, Marshal.Struct out ~ mout, LLVM.IsType mout,
    Tuple.ValueOf inp ~ linp, Tuple.ValueOf out ~ lout) =>
   (Show inp, QC.Arbitrary inp) =>
   (linp -> LLVM.CodeGenFunction () lout) ->
   (inp -> out -> Bool) ->
   IO QC.Property
run codegen predicate = do
   compiled <- generateFunction derefTestCasePtr (modul codegen)
   let callCompiled input =
          Marshal.with input $ \inPtr ->
          Marshal.alloca $ \outPtr ->
             compiled inPtr outPtr >> Marshal.peek outPtr
   return $ QC.property $ \x ->
      QCMon.monadicIO $
         QCMon.assert . predicate x =<< QCMon.run (callCompiled x)


-- | Identity restricted to four-element vectors. It pins the vector length
-- in otherwise polymorphic test predicates.
vec4 :: V4 a -> V4 a
vec4 v = v


-- | Test a unary operation on four 'Int32' lanes against the reference
-- function @fun@, applied lane by lane.
unop ::
   (LLVM.Value V4Int32 -> LLVM.CodeGenFunction () (LLVM.Value V4Int32)) ->
   (Int32 -> Int32) ->
   IO QC.Property
unop codegen fun = run codegen agrees
   where
      agrees input output = fmap fun (vec4 input) == vec4 output

-- | Test a unary operation on four 'Float' lanes against the reference
-- function @fun@, applied lane by lane. Results are compared with exact
-- 'Float' equality.
unopFloat ::
   (LLVM.Value V4Float -> LLVM.CodeGenFunction () (LLVM.Value V4Float)) ->
   (Float -> Float) ->
   IO QC.Property
unopFloat codegen fun = run codegen agrees
   where
      agrees input output = fmap fun (vec4 input) == vec4 output


-- | Test a lane-wise binary operation on pairs of four-element vectors. The
-- expected result is @fun@ combined lane by lane via 'liftA2'.
binop ::
   ((TypeNum.D4 TypeNum.:*: LLVM.SizeOf a) ~ size, TypeNum.Natural size,
    QC.Arbitrary a, Show a, Eq a,
    Marshal.Vector TypeNum.D4 a, Tuple.VectorValueOf TypeNum.D4 a ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (a -> a -> a) ->
   IO QC.Property
binop codegen fun = run (uncurry codegen) agrees
   where
      agrees (x, y) z = liftA2 fun (vec4 x) (vec4 y) == vec4 z

-- | 'binop' specialised to four 'Int32' lanes.
binopInt ::
   (LLVM.Value V4Int32 ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (Int32 -> Int32 -> Int32) ->
   IO QC.Property
binopInt codegen reference = binop codegen reference


-- Narrow LLVM integer element types: 'LLVM.IntN' (paired with the signed
-- saturation functions in 'tests') and 'LLVM.WordN' (paired with the
-- unsigned ones), in 2- and 3-bit widths.
type Int2  = LLVM.IntN  TypeNum.D2
type Word2 = LLVM.WordN TypeNum.D2
type Int3  = LLVM.IntN  TypeNum.D3
type Word3 = LLVM.WordN TypeNum.D3

-- | Decompose @x@ into @n@ digits in base @modu@, least-significant digit
-- first (in the vector's traversal order). Digits beyond the @n@-th are
-- discarded.
vectorise ::
   (TypeNum.Positive n, Integral a) =>
   Integer -> a -> LLVM.Vector n Integer
vectorise modu x =
   snd (Trav.mapAccumL step (toInteger x) (pure ()))
   where
      -- Carry the quotient to the next lane and emit the remainder.
      step rest _ = divMod rest modu

-- | Unpack the low @n@ base-@modu@ digits of an integer into a vector of
-- signed @d@-bit integers, least-significant digit first.
--
-- Each digit is read as a two's-complement value: if its top bit
-- (@modu \`shiftR\` 1@) is set, @modu@ is subtracted from it. This sign test
-- is only meaningful when @modu@ is a power of two (callers pass @2^d@).
unpackInts ::
   (TypeNum.Positive n, TypeNum.Positive d, Integral a) =>
   Integer -> a -> LLVM.Vector n (LLVM.IntN d)
unpackInts modu = fmap (LLVM.IntN . signExtend) . vectorise modu
   where
      -- Mask of the sign bit within one digit; shared by all lanes.
      signBit = Bits.shiftR modu 1
      -- Digits from 'vectorise' are already 'Integer', so no conversion is
      -- needed here.
      signExtend digit =
         if digit Bits..&. signBit /= 0
            then digit - modu
            else digit

-- | Unpack the low @n@ base-@modu@ digits of an integer into a vector of
-- unsigned @d@-bit words, least-significant digit first.
unpackWords ::
   (TypeNum.Positive n, TypeNum.Positive d, Integral a) =>
   Integer -> a -> LLVM.Vector n (LLVM.WordN d)
unpackWords modu x = fmap LLVM.WordN (vectorise modu x)

-- | A byte viewed as four signed 2-bit lanes (base 4 digits).
unpackInt2 :: Word8 -> V4 Int2
unpackInt2 packed = unpackInts 4 packed

-- | A byte viewed as four unsigned 2-bit lanes (base 4 digits).
unpackWord2 :: Word8 -> V4 Word2
unpackWord2 packed = unpackWords 4 packed

-- | The low 15 bits of a 'Word16' viewed as five signed 3-bit lanes
-- (base 8 digits). The top bit is ignored.
unpackInt3 :: Word16 -> V5 Int3
unpackInt3 packed = unpackInts 8 packed

-- | The low 15 bits of a 'Word16' viewed as five unsigned 3-bit lanes
-- (base 8 digits). The top bit is ignored.
unpackWord3 :: Word16 -> V5 Word3
unpackWord3 packed = unpackWords 8 packed

-- | Test a lane-wise binary operation on four 2-bit lanes packed into a byte.
-- The compiled code bitcasts each input byte to a four-lane vector, applies
-- @codegen@, and bitcasts the result back to a byte. The expected value is
-- @fun@ applied lane-wise to the Haskell-side unpacking given by
-- @unpackBits@.
binopV4I2 ::
   (Eq a, LLVM.IsPrimitive a, LLVM.IsSized a, LLVM.SizeOf a ~ TypeNum.D2,
    LLVM.Value (V4 a) ~ v) =>
   (Word8 -> V4 a) ->
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (a -> a -> a) ->
   IO QC.Property
binopV4I2 unpackBits codegen fun = run viaVector agrees
   where
      viaVector (packedX, packedY) = do
         lanesX <- LLVM.bitcast packedX
         lanesY <- LLVM.bitcast packedY
         LLVM.bitcast =<< codegen lanesX lanesY
      agrees (packedX, packedY) packedZ =
         liftA2 fun (unpackBits packedX) (unpackBits packedY)
            == unpackBits packedZ

-- | Code yielding a 15-bit unsigned LLVM value. Used as an annotation to fix
-- the intermediate width in 'binopV5I3'.
type Code15 r = LLVM.CodeGenFunction r (LLVM.Value (LLVM.WordN TypeNum.D15))

-- | Test a lane-wise binary operation on five 3-bit lanes held in the low 15
-- bits of a 'Word16'. The inputs are truncated to 15 bits and bitcast to a
-- five-lane vector. The result is bitcast back to 15 bits and then
-- zero-extended to 16 bits. The expected value is @fun@ applied lane-wise to
-- the Haskell-side unpacking given by @unpackBits@.
binopV5I3 ::
   (Eq a, LLVM.IsPrimitive a, LLVM.IsSized a, LLVM.SizeOf a ~ TypeNum.D3,
    LLVM.Value (V5 a) ~ v) =>
   (Word16 -> V5 a) ->
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (a -> a -> a) ->
   IO QC.Property
binopV5I3 unpackBits codegen fun = run viaVector agrees
   where
      viaVector (packedX, packedY) = do
         lanesX <- LLVM.bitcast =<< (LLVM.trunc packedX :: Code15 r)
         lanesY <- LLVM.bitcast =<< (LLVM.trunc packedY :: Code15 r)
         lanesZ <- codegen lanesX lanesY
         narrowZ <- LLVM.bitcast lanesZ :: Code15 r
         LLVM.zext narrowZ
      agrees (packedX, packedY) packedZ =
         liftA2 fun (unpackBits packedX) (unpackBits packedY)
            == unpackBits packedZ

-- | 'binop' specialised to four signed byte lanes.
binopInt8 ::
   (LLVM.Value (V4 Int8) ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (Int8 -> Int8 -> Int8) ->
   IO QC.Property
binopInt8 codegen reference = binop codegen reference

-- | 'binop' specialised to four unsigned byte lanes.
binopWord8 ::
   (LLVM.Value (V4 Word8) ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (Word8 -> Word8 -> Word8) ->
   IO QC.Property
binopWord8 codegen reference = binop codegen reference


-- | Saturating addition for bounded integral types, computed through
-- 'Integer' conversions by 'addSatMan'.
addSat :: (Bounded a, Integral a) => a -> a -> a
addSat x y = addSatMan (toInteger, fromInteger) x y

-- | Saturating subtraction for bounded integral types, computed through
-- 'Integer' conversions by 'subSatMan'.
subSat :: (Bounded a, Integral a) => a -> a -> a
subSat x y = subSatMan (toInteger, fromInteger) x y

-- | Saturating addition and subtraction for bounded types, given an explicit
-- conversion pair to and from 'Integer'. In 'tests' these are used with
-- 'convertIntN' and 'convertWordN' for the narrow LLVM integer types.
addSatMan, subSatMan ::
   (Bounded a) => (a -> Integer, Integer -> a) -> a -> a -> a
addSatMan conv = opSat (+) conv
subSatMan conv = opSat (-) conv

-- | Conversion pair between 'LLVM.IntN' and 'Integer'. The proxy only
-- selects the bit width @d@.
convertIntN :: Proxy d -> (LLVM.IntN d -> Integer, Integer -> LLVM.IntN d)
convertIntN Proxy = (unwrap, LLVM.IntN)
   where unwrap (LLVM.IntN i) = i

-- | Conversion pair between 'LLVM.WordN' and 'Integer'. The proxy only
-- selects the bit width @d@.
convertWordN :: Proxy d -> (LLVM.WordN d -> Integer, Integer -> LLVM.WordN d)
convertWordN Proxy = (unwrap, LLVM.WordN)
   where unwrap (LLVM.WordN w) = w

-- | Saturating binary operation. @op@ is applied to the 'Integer' images of
-- @x@ and @y@, and the exact result is clamped into
-- @[minBound, maxBound]@ of the operand type before converting back with
-- @fromIntg@.
opSat ::
   (Bounded a) =>
   (Integer -> Integer -> Integer) ->
   (a -> Integer, Integer -> a) ->
   a -> a -> a
opSat op (toIntg, fromIntg) x y = fromIntg clamped
   where
      lower   = toIntg (minBound `asTypeOf` x)
      upper   = toIntg (maxBound `asTypeOf` x)
      exact   = op (toIntg x) (toIntg y)
      clamped = P.max lower (P.min upper exact)


-- | Fractional part measured from the floor, @x - floor x@. For negative
-- inputs this is non-negative, e.g. @fraction (-0.25) == 0.75@.
fraction :: RealFrac a => a -> a
fraction x = x - whole
   where whole = fromInteger (floor x)


-- | Produce two named test cases from one reference function @f@: one driven
-- by the @intrinsic@ implementation (suffix @.intrinsic@) and one by the
-- @fallback@ implementation (suffix @.fallback@), in that order.
split :: String -> (a -> b -> c) -> (a,a) -> b -> [(String, c)]
split name driver (intrinsic, fallback) f =
   [ (name ++ suffix, driver impl f)
   | (suffix, impl) <- [(".intrinsic", intrinsic), (".fallback", fallback)]
   ]

-- | All vector tests as (name, property) pairs, in execution order.
-- Saturating operations go through 'split': each is checked once with the
-- 'SoV' function (suffix ".intrinsic") and once with the corresponding
-- 'SoVPriv' function (suffix ".fallback"), against a Haskell reference.
tests :: [(String, IO QC.Property)]
tests =
   concat
      [ [ ("abs", unop Vector.abs P.abs)
        , ("signum", unop Vector.signum P.signum)
        , ("Alt.abs", unop VectorAlt.abs P.abs)
        ]

      , [ ("min", binopInt Vector.min P.min)
        , ("max", binopInt Vector.max P.max)
        , ("Alt.min", binopInt VectorAlt.min P.min)
        , ("Alt.max", binopInt VectorAlt.max P.max)
        ]

      -- saturating arithmetic on four 8-bit lanes
      , split "addSat.Word8" binopWord8 (SoV.addSat, SoVPriv.uaddSat) addSat
      , split "subSat.Word8" binopWord8 (SoV.subSat, SoVPriv.usubSat) subSat
      , split "addSat.Int8"  binopInt8  (SoV.addSat, SoVPriv.saddSat) addSat
      , split "subSat.Int8"  binopInt8  (SoV.subSat, SoVPriv.ssubSat) subSat

      -- saturating arithmetic on five 3-bit lanes
      , split "addSat.Word3"
           (binopV5I3 unpackWord3) (SoV.addSat, SoVPriv.uaddSat)
           (addSatMan $ convertWordN TypeNum.d3)
      , split "subSat.Word3"
           (binopV5I3 unpackWord3) (SoV.subSat, SoVPriv.usubSat)
           (subSatMan $ convertWordN TypeNum.d3)
      , split "addSat.Int3"
           (binopV5I3 unpackInt3) (SoV.addSat, SoVPriv.saddSat)
           (addSatMan $ convertIntN TypeNum.d3)
      , split "subSat.Int3"
           (binopV5I3 unpackInt3) (SoV.subSat, SoVPriv.ssubSat)
           (subSatMan $ convertIntN TypeNum.d3)

      -- saturating arithmetic on four 2-bit lanes
      , split "addSat.Word2"
           (binopV4I2 unpackWord2) (SoV.addSat, SoVPriv.uaddSat)
           (addSatMan $ convertWordN TypeNum.d2)
      , split "subSat.Word2"
           (binopV4I2 unpackWord2) (SoV.subSat, SoVPriv.usubSat)
           (subSatMan $ convertWordN TypeNum.d2)
      , split "addSat.Int2"
           (binopV4I2 unpackInt2) (SoV.addSat, SoVPriv.saddSat)
           (addSatMan $ convertIntN TypeNum.d2)
      , split "subSat.Int2"
           (binopV4I2 unpackInt2) (SoV.subSat, SoVPriv.ssubSat)
           (subSatMan $ convertIntN TypeNum.d2)

      -- horizontal operations
      , [ ("sum",
             run Vector.sum (\x y -> Fold.sum (vec4 x) == (y::Int32)))
        , ("cumulate",
             run
                (uncurry Vector.cumulate)
                (\(x0,xv) (y0,yv) ->
                   scanl (+) x0 (Fold.toList (vec4 xv))
                   ==
                   Fold.toList (vec4 yv) ++ [y0::Int32]))
        , ("dot",
             run
                (uncurry Vector.dotProduct)
                (\(x,y) z ->
                   Fold.sum (liftA2 (*) (vec4 x) (vec4 y))  ==  (z::Int32)))
        ]

      -- rounding on Float lanes
      , [ ("truncate", unopFloat Vector.truncate (fromInteger . P.truncate))
        , ("floor", unopFloat Vector.floor (fromInteger . P.floor))
        , ("fraction", unopFloat Vector.fraction fraction)
        ]

      , [ ("floorLogical", unopFloat VectorAlt.floor (fromInteger . P.floor))
        , ("fractionLogical", unopFloat VectorAlt.fraction fraction)
        , ("floorSelect", unopFloat VectorAlt.floorSelect (fromInteger . P.floor))
        , ("fractionSelect", unopFloat VectorAlt.fractionSelect fraction)
        ]
      ]
