{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
-- | QuickCheck properties that compile vector operations from the
-- @LLVM.Extra@ modules, run them through the LLVM execution engine and
-- compare their results with reference implementations in plain Haskell.
module Test.Vector where

import qualified LLVM.Extra.ScalarOrVectorPrivate as SoVPriv
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.VectorAlt as VectorAlt
import qualified LLVM.Extra.Vector as Vector
import qualified LLVM.Extra.Memory as Memory
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.ExecutionEngine as EE
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum
import Type.Base.Proxy (Proxy(Proxy))

import Foreign.Ptr (FunPtr)

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.Bits as Bits
import Data.Word (Word8, Word16, Word32)
import Data.Int (Int8, Int32)

import qualified Test.QuickCheck as QC
import qualified Test.QuickCheck.Monadic as QCMon

import Control.Applicative (liftA2, pure)

import qualified Prelude as P
import Prelude hiding (min, max)


type V4 = LLVM.Vector TypeNum.D4
type V5 = LLVM.Vector TypeNum.D5

type V4Word32 = V4 Word32
type V4Int32 = V4 Int32
type V4Float = V4 Float


-- | Turns a raw function pointer into a callable Haskell function.
type Importer func = FunPtr func -> func

-- | Define a module for the host target from @code@ and obtain the
-- resulting function through the execution engine via @imprt@.
-- As a side effect the module is written as bitcode to @test-vector.bc@
-- (relative path, overwritten on every call; presumably kept for
-- debugging).
generateFunction ::
   EE.ExecutionFunction f =>
   Importer f -> LLVM.CodeGenModule (LLVM.Function f) -> IO f
generateFunction imprt code = do
   m <- LLVM.newModule
   func <- LLVM.defineModule m (LLVM.setTarget LLVM.hostTriple >> code)
   fn <- EE.runEngineAccessWithModule m (EE.getExecutionFunction imprt func)
   LLVM.writeBitcodeToFile "test-vector.bc" m
   return fn

-- | Importer for the generated test functions, which take an input
-- pointer and an output pointer.
foreign import ccall safe "dynamic"
   derefTestCasePtr :: Importer (LLVM.Ptr inp -> LLVM.Ptr out -> IO ())

-- | Wrap @codegen@ into an externally linked function that loads its
-- argument from the first pointer, applies @codegen@ and stores the
-- result through the second pointer.
modul ::
   (Memory.C linp, Memory.Struct linp ~ minp, LLVM.IsType minp,
    Memory.C lout, Memory.Struct lout ~ mout, LLVM.IsType mout) =>
   (linp -> LLVM.CodeGenFunction () lout) ->
   LLVM.CodeGenModule (LLVM.Function (LLVM.Ptr minp -> LLVM.Ptr mout -> IO ()))
modul codegen =
   LLVM.createFunction LLVM.ExternalLinkage $ \xPtr yPtr -> do
      arg <- Memory.load xPtr
      res <- codegen arg
      Memory.store res yPtr
      LLVM.ret ()

-- | Compile @codegen@ once, then build a property that marshals a random
-- input into memory, calls the compiled function, reads the output back
-- and checks @predicate input output@.
run ::
   (Marshal.C inp, Marshal.Struct inp ~ minp, LLVM.IsType minp,
    Marshal.C out, Marshal.Struct out ~ mout, LLVM.IsType mout,
    Tuple.ValueOf inp ~ linp, Tuple.ValueOf out ~ lout) =>
   (Show inp, QC.Arbitrary inp) =>
   (linp -> LLVM.CodeGenFunction () lout) ->
   (inp -> out -> Bool) -> IO QC.Property
run codegen predicate = do
   funIO <- generateFunction derefTestCasePtr (modul codegen)
   return $ QC.property $ \x -> QCMon.monadicIO $ do
      y <- QCMon.run $
         Marshal.with x $ \xPtr ->
         Marshal.alloca $ \yPtr ->
            funIO xPtr yPtr >> Marshal.peek yPtr
      QCMon.assert (predicate x y)

-- | Identity restricted to four-element vectors; used to pin types.
vec4 :: V4 a -> V4 a
vec4 v = v

-- | Check an element-wise unary 'Int32' vector operation against @fun@.
unop ::
   (LLVM.Value V4Int32 -> LLVM.CodeGenFunction () (LLVM.Value V4Int32)) ->
   (Int32 -> Int32) -> IO QC.Property
unop codegen fun =
   run codegen $ \x y -> vec4 y == fmap fun (vec4 x)

-- | Check an element-wise unary 'Float' vector operation against @fun@.
unopFloat ::
   (LLVM.Value V4Float -> LLVM.CodeGenFunction () (LLVM.Value V4Float)) ->
   (Float -> Float) -> IO QC.Property
unopFloat codegen fun =
   run codegen $ \x y -> vec4 y == fmap fun (vec4 x)

-- | Check an element-wise binary operation on four-element vectors
-- against @fun@ applied pointwise.
binop ::
   ((TypeNum.D4 TypeNum.:*: LLVM.SizeOf a) ~ size, TypeNum.Natural size,
    QC.Arbitrary a, Show a, Eq a,
    Marshal.Vector TypeNum.D4 a, Tuple.VectorValueOf TypeNum.D4 a ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) -> (a -> a -> a) -> IO QC.Property
binop codegen fun =
   run (uncurry codegen) $ \(x,y) z ->
      vec4 z == liftA2 fun (vec4 x) (vec4 y)

-- | 'binop' specialised to 'Int32' elements.
binopInt ::
   (LLVM.Value V4Int32 ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (Int32 -> Int32 -> Int32) -> IO QC.Property
binopInt = binop


-- | Narrow integer types used for the packed-vector tests.
type Int2 = LLVM.IntN TypeNum.D2
type Int3 = LLVM.IntN TypeNum.D3
type Word2 = LLVM.WordN TypeNum.D2
type Word3 = LLVM.WordN TypeNum.D3

-- | Split @x@ into base-@modu@ digits, one per vector element, filling
-- elements in traversal order starting with the lowest digit.
-- Any part of @x@ that does not fit into the vector is dropped.
vectorise ::
   (TypeNum.Positive n, Integral a) => Integer -> a -> LLVM.Vector n Integer
vectorise modu x =
   snd $ Trav.mapAccumL (\rest () -> divMod rest modu) (toInteger x) (pure ())

-- | Like 'unpackWords', but read each digit as a two's complement value:
-- a digit with the bit @modu/2@ set is shifted down by @modu@
-- (intended for power-of-two @modu@).
unpackInts ::
   (TypeNum.Positive n, TypeNum.Positive d, Integral a) =>
   Integer -> a -> LLVM.Vector n (LLVM.IntN d)
unpackInts modu = fmap toSigned . vectorise modu
   where
      signBit = Bits.shiftR modu 1
      toSigned digit =
         LLVM.IntN $
            if digit Bits..&. signBit == 0
               then digit
               else digit - modu

-- | Split @x@ into unsigned base-@modu@ digits, one per vector element.
unpackWords ::
   (TypeNum.Positive n, TypeNum.Positive d, Integral a) =>
   Integer -> a -> LLVM.Vector n (LLVM.WordN d)
unpackWords modu x = fmap LLVM.WordN (vectorise modu x)

-- | Decode a byte as four signed 2-bit elements.
unpackInt2 :: Word8 -> V4 Int2
unpackInt2 = unpackInts 4

-- | Decode a byte as four unsigned 2-bit elements.
unpackWord2 :: Word8 -> V4 Word2
unpackWord2 = unpackWords 4

-- | Decode the low 15 bits of a 'Word16' as five signed 3-bit elements.
unpackInt3 :: Word16 -> V5 Int3
unpackInt3 = unpackInts 8

-- | Decode the low 15 bits of a 'Word16' as five unsigned 3-bit elements.
unpackWord3 :: Word16 -> V5 Word3
unpackWord3 = unpackWords 8

-- | Check a binary operation on vectors of four 2-bit elements.
-- Both 'Word8' inputs are bitcast to such vectors and the result is
-- bitcast back; @unpackBits@ decodes all three bytes so the result can
-- be compared with @fun@ applied pointwise.
binopV4I2 ::
   (Eq a, LLVM.IsPrimitive a, LLVM.IsSized a, LLVM.SizeOf a ~ TypeNum.D2,
    LLVM.Value (V4 a) ~ v) =>
   (Word8 -> V4 a) ->
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (a -> a -> a) -> IO QC.Property
binopV4I2 unpackBits codegen fun =
   run
      (\(x,y) -> do
         vx <- LLVM.bitcast x
         vy <- LLVM.bitcast y
         LLVM.bitcast =<< codegen vx vy)
      (\(x,y) z ->
         unpackBits z == liftA2 fun (unpackBits x) (unpackBits y))

-- | Code producing a 15-bit word, used to annotate intermediate results.
type Code15 r = LLVM.CodeGenFunction r (LLVM.Value (LLVM.WordN TypeNum.D15))

-- | Check a binary operation on vectors of five 3-bit elements.
-- The 'Word16' inputs are truncated to 15 bits and bitcast to vectors;
-- the result is bitcast to 15 bits and zero-extended back to 'Word16'.
binopV5I3 ::
   (Eq a, LLVM.IsPrimitive a, LLVM.IsSized a, LLVM.SizeOf a ~ TypeNum.D3,
    LLVM.Value (V5 a) ~ v) =>
   (Word16 -> V5 a) ->
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (a -> a -> a) -> IO QC.Property
binopV5I3 unpackBits codegen fun =
   run
      (\(x,y) -> do
         vx <- LLVM.bitcast =<< (LLVM.trunc x :: Code15 r)
         vy <- LLVM.bitcast =<< (LLVM.trunc y :: Code15 r)
         vz <- codegen vx vy
         LLVM.zext =<< (LLVM.bitcast vz :: Code15 r))
      (\(x,y) z ->
         unpackBits z == liftA2 fun (unpackBits x) (unpackBits y))

-- | 'binop' specialised to 'Int8' elements.
binopInt8 ::
   (LLVM.Value (V4 Int8) ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (Int8 -> Int8 -> Int8) -> IO QC.Property
binopInt8 = binop

-- | 'binop' specialised to 'Word8' elements.
binopWord8 ::
   (LLVM.Value (V4 Word8) ~ v) =>
   (v -> v -> LLVM.CodeGenFunction () v) ->
   (Word8 -> Word8 -> Word8) -> IO QC.Property
binopWord8 = binop


-- | Reference saturating addition and subtraction for bounded integral
-- types: compute exactly in 'Integer', then clamp to the type's range.
addSat, subSat :: (Bounded a, Integral a) => a -> a -> a
addSat = opSat (+) (toInteger, fromInteger)
subSat = opSat (-) (toInteger, fromInteger)

-- | Saturating addition and subtraction for any bounded type, given
-- conversions to and from 'Integer'.
addSatMan, subSatMan ::
   (Bounded a) => (a -> Integer, Integer -> a) -> a -> a -> a
addSatMan = opSat (+)
subSatMan = opSat (-)

-- | 'Integer' conversions for 'LLVM.IntN'; the proxy selects the width.
convertIntN :: Proxy d -> (LLVM.IntN d -> Integer, Integer -> LLVM.IntN d)
convertIntN Proxy = (unwrap, LLVM.IntN)
   where unwrap (LLVM.IntN n) = n

-- | 'Integer' conversions for 'LLVM.WordN'; the proxy selects the width.
convertWordN :: Proxy d -> (LLVM.WordN d -> Integer, Integer -> LLVM.WordN d)
convertWordN Proxy = (unwrap, LLVM.WordN)
   where unwrap (LLVM.WordN n) = n

-- | Apply @op@ in 'Integer' and clamp the result to
-- @[minBound, maxBound]@ of the operands' type.
opSat ::
   (Bounded a) =>
   (Integer -> Integer -> Integer) ->
   (a -> Integer, Integer -> a) -> a -> a -> a
opSat op (toIntg, fromIntg) x y =
   fromIntg $ clamp $ op (toIntg x) (toIntg y)
   where
      lower = toIntg (minBound `asTypeOf` x)
      upper = toIntg (maxBound `asTypeOf` x)
      clamp = P.max lower . P.min upper

-- | Fractional part relative to 'floor', so the result is non-negative.
fraction :: RealFrac a => a -> a
fraction x =
   let whole = floor x
   in  x - fromInteger whole

-- | Produce two named tests from one driver: one for the intrinsic-based
-- implementation and one for the fallback implementation.
split :: String -> (a -> b -> c) -> (a,a) -> b -> [(String, c)]
split name driver (intrinsic, fallback) f =
   [ (name ++ ".intrinsic", driver intrinsic f)
   , (name ++ ".fallback",  driver fallback f)
   ]

-- | All named properties of this module.  Running an entry's action
-- compiles its code before the property is returned.
tests :: [(String, IO QC.Property)]
tests =
   concat
      [ [ ("abs", unop Vector.abs P.abs)
        , ("signum", unop Vector.signum P.signum)
        , ("Alt.abs", unop VectorAlt.abs P.abs)
        , ("min", binopInt Vector.min P.min)
        , ("max", binopInt Vector.max P.max)
        , ("Alt.min", binopInt VectorAlt.min P.min)
        , ("Alt.max", binopInt VectorAlt.max P.max)
        ]
      , split "addSat.Word8" binopWord8
           (SoV.addSat, SoVPriv.uaddSat) addSat
      , split "subSat.Word8" binopWord8
           (SoV.subSat, SoVPriv.usubSat) subSat
      , split "addSat.Int8" binopInt8
           (SoV.addSat, SoVPriv.saddSat) addSat
      , split "subSat.Int8" binopInt8
           (SoV.subSat, SoVPriv.ssubSat) subSat
      , split "addSat.Word3" (binopV5I3 unpackWord3)
           (SoV.addSat, SoVPriv.uaddSat)
           (addSatMan $ convertWordN TypeNum.d3)
      , split "subSat.Word3" (binopV5I3 unpackWord3)
           (SoV.subSat, SoVPriv.usubSat)
           (subSatMan $ convertWordN TypeNum.d3)
      , split "addSat.Int3" (binopV5I3 unpackInt3)
           (SoV.addSat, SoVPriv.saddSat)
           (addSatMan $ convertIntN TypeNum.d3)
      , split "subSat.Int3" (binopV5I3 unpackInt3)
           (SoV.subSat, SoVPriv.ssubSat)
           (subSatMan $ convertIntN TypeNum.d3)
      , split "addSat.Word2" (binopV4I2 unpackWord2)
           (SoV.addSat, SoVPriv.uaddSat)
           (addSatMan $ convertWordN TypeNum.d2)
      , split "subSat.Word2" (binopV4I2 unpackWord2)
           (SoV.subSat, SoVPriv.usubSat)
           (subSatMan $ convertWordN TypeNum.d2)
      , split "addSat.Int2" (binopV4I2 unpackInt2)
           (SoV.addSat, SoVPriv.saddSat)
           (addSatMan $ convertIntN TypeNum.d2)
      , split "subSat.Int2" (binopV4I2 unpackInt2)
           (SoV.subSat, SoVPriv.ssubSat)
           (subSatMan $ convertIntN TypeNum.d2)
      , [ ("sum",
             run Vector.sum
                (\x y -> Fold.sum (vec4 x) == (y::Int32)))
        , ("cumulate",
             run (uncurry Vector.cumulate)
                (\(x0,xv) (y0,yv) ->
                   scanl (+) x0 (Fold.toList (vec4 xv))
                   ==
                   (Fold.toList (vec4 yv) ++ [y0::Int32])))
        , ("dot",
             run (uncurry Vector.dotProduct)
                (\(x,y) z ->
                   Fold.sum (liftA2 (*) (vec4 x) (vec4 y)) == (z::Int32)))
        , ("truncate", unopFloat Vector.truncate (fromInteger . P.truncate))
        , ("floor", unopFloat Vector.floor (fromInteger . P.floor))
        , ("fraction", unopFloat Vector.fraction fraction)
        , ("floorLogical", unopFloat VectorAlt.floor (fromInteger . P.floor))
        , ("fractionLogical", unopFloat VectorAlt.fraction fraction)
        , ("floorSelect",
             unopFloat VectorAlt.floorSelect (fromInteger . P.floor))
        , ("fractionSelect", unopFloat VectorAlt.fractionSelect fraction)
        ]
      ]