{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module HaskellWorks.Data.Simd.Comparison.Avx2 where
import Control.Monad
import Data.Word
import qualified Data.ByteString as BS
import qualified Data.Vector as DV
import qualified Data.Vector.Storable as DVS
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Marshal.Unsafe as F
import qualified Foreign.Ptr as F
import qualified HaskellWorks.Data.ByteString as BS
import qualified HaskellWorks.Data.Simd.ChunkString as CS
import qualified HaskellWorks.Data.Simd.Internal.Foreign as F
import qualified HaskellWorks.Data.Vector.AsVector8 as V
import qualified HaskellWorks.Data.Vector.Storable as DVS
-- | Types whose bytes can be compared for equality against a single 'Word8'.
-- Each instance chooses its own result representation via 'Target'.
class CmpEqWord8s a where
  -- | @cmpEqWord8s w8 x@ compares the bytes of @x@ with @w8@.
  cmpEqWord8s :: Word8 -> a -> Target a
  -- | Result type produced by 'cmpEqWord8s' for input type @a@.
  type Target a
-- | Byte-wise equality of an arbitrary-length 'Word8' vector against @w8@.
--
-- The result holds one bit per input byte, packed into
-- @(length v + 7) `div` 8@ bytes.  Whole 64-byte chunks are processed directly
-- from the source buffer by 'F.avx2Cmpeq8'.  A trailing partial chunk
-- (1..63 bytes) is first copied into a 64-byte buffer, so exactly one more
-- chunk can be processed without reading past the end of the input.
instance CmpEqWord8s (DVS.Vector Word8) where
  type Target (DVS.Vector Word8) = DVS.Vector Word8

  cmpEqWord8s w8 v = F.unsafeLocalState $ do
    tgtFptr <- F.mallocForeignPtrBytes bufLen
    F.withForeignPtr srcFptr $ \srcPtr ->
      F.withForeignPtr tgtFptr $ \tgtPtr -> do
        -- The size argument counts 64-byte input chunks, i.e. Word64 output
        -- words (the Word64 instance passes w64sLen in the same position).
        _ <- F.avx2Cmpeq8
          (fromIntegral w8)
          (F.castPtr tgtPtr)
          (fromIntegral w64sLen)
          (srcPtr `F.plusPtr` srcOffset)
        when (disalignment /= 0) $ do
          -- Pad the tail to exactly one 64-byte chunk.  The filler (w8 + 1,
          -- wrapping at 255) can never equal w8, so padding sets no bits.
          let ending = DVS.drop alignment v DVS.++ DVS.replicate (64 - disalignment) (w8 + 1)
          DVS.unsafeWith ending $ \endPtr ->
            -- Exactly one chunk, written to the Word64 slot after the whole chunks.
            void $ F.avx2Cmpeq8
              (fromIntegral w8)
              (tgtPtr `F.plusPtr` (w64sLen * 8))
              1
              (F.castPtr endPtr)
    return $ DVS.unsafeFromForeignPtr tgtFptr 0 tgtLen
    where (srcFptr, srcOffset, srcLen) = DVS.unsafeToForeignPtr v
          -- Number of whole 64-byte chunks in the input.
          w64sLen :: Int
          w64sLen = srcLen `div` 64
          -- Bytes covered by the whole chunks.
          alignment :: Int
          alignment = w64sLen * 64
          -- Trailing bytes not covered by a whole chunk (0..63).
          disalignment :: Int
          disalignment = srcLen - alignment
          -- One Word64 (8 bytes) per chunk, counting a partial tail chunk.
          bufLen :: Int
          bufLen = ((srcLen + 63) `div` 64) * 8
          -- Bytes exposed to the caller: one bit per input byte.
          tgtLen :: Int
          tgtLen = (srcLen + 7) `div` 8
  {-# INLINE cmpEqWord8s #-}
-- | Byte-wise equality of a 'Word64' vector (viewed as bytes) against @w8@.
--
-- The input's byte length must be a multiple of 64.  Otherwise this calls
-- 'error' with @"Unaligned byte string: "@ followed by the leftover byte count.
-- The result has one 'Word64' per 64 input bytes, i.e. @length v `div` 8@ words.
instance CmpEqWord8s (DVS.Vector Word64) where
  type Target (DVS.Vector Word64) = DVS.Vector Word64

  cmpEqWord8s w8 v
    | disalignment /= 0 = error $ "Unaligned byte string: " <> show disalignment
    | otherwise = F.unsafeLocalState $ do
        -- Exactly the bytes the result exposes: w64sLen words of 8 bytes.
        tgtFptr <- F.mallocForeignPtrBytes (w64sLen * 8)
        F.withForeignPtr srcFptr $ \srcPtr ->
          F.withForeignPtr tgtFptr $ \tgtPtr ->
            void $ F.avx2Cmpeq8
              (fromIntegral w8)
              (F.castPtr tgtPtr)
              (fromIntegral w64sLen)
              (F.castPtr srcPtr `F.plusPtr` srcOffset)
        return $ DVS.unsafeFromForeignPtr tgtFptr 0 w64sLen
    where (srcFptr, srcOffset, srcLength) = DVS.unsafeToForeignPtr (DVS.unsafeCast v :: DVS.Vector Word8)
          -- Number of whole 64-byte chunks (= Word64 output words).
          w64sLen :: Int
          w64sLen = srcLength `div` 64
          -- Bytes left over after the whole chunks; must be 0.
          disalignment :: Int
          disalignment = srcLength - w64sLen * 64
  {-# INLINE cmpEqWord8s #-}
-- | Element-wise lifting of the 'DVS.Vector' 'Word64' instance over a list.
instance CmpEqWord8s [DVS.Vector Word64] where
  type Target [DVS.Vector Word64] = [DVS.Vector Word64]

  cmpEqWord8s w8 chunks = map (cmpEqWord8s w8) chunks
  {-# INLINE cmpEqWord8s #-}
-- | Element-wise lifting of the 'DVS.Vector' 'Word8' instance over a list.
instance CmpEqWord8s [DVS.Vector Word8] where
  type Target [DVS.Vector Word8] = [DVS.Vector Word8]

  cmpEqWord8s w8 chunks = [cmpEqWord8s w8 chunk | chunk <- chunks]
  {-# INLINE cmpEqWord8s #-}
-- | Compares each 'BS.ByteString' in a list.  Each element is viewed as a
-- 'Word8' vector, compared, and the result converted back to a 'BS.ByteString'.
instance CmpEqWord8s [BS.ByteString] where
  type Target [BS.ByteString] = [BS.ByteString]

  cmpEqWord8s w8 chunks = map compareChunk chunks
    where compareChunk chunk = BS.toByteString (cmpEqWord8s w8 (V.asVector8 chunk))
  {-# INLINE cmpEqWord8s #-}
-- | Compares a 'CS.ChunkString' by delegating to the list-of-'BS.ByteString'
-- instance and reassembling the resulting chunks.
instance CmpEqWord8s CS.ChunkString where
  type Target CS.ChunkString = CS.ChunkString

  cmpEqWord8s w8 cs = CS.toChunkString (cmpEqWord8s w8 (BS.toByteStrings cs))
  {-# INLINE cmpEqWord8s #-}
-- | Types whose bytes can be compared against several comparand bytes in one
-- call.  Each instance chooses its result representation via
-- 'CmpEqWord8sParaTarget'.
class CmpEqWord8sPara a where
  -- | @cmpEqWord8sPara w8s x@ compares the bytes of @x@ against every byte of @w8s@.
  cmpEqWord8sPara :: DVS.Vector Word8 -> a -> CmpEqWord8sParaTarget a
  -- | Result type produced by 'cmpEqWord8sPara' for input type @a@.
  type CmpEqWord8sParaTarget a
-- | Compares the bytes of a 'Word64' vector against several comparand bytes in
-- one call to 'F.avx2Cmpeq8Para', returning one result vector per comparand.
--
-- The input's byte length must be a multiple of 64.  Otherwise this calls
-- 'error' with @"Unaligned byte string: "@ followed by the leftover byte count.
-- Each result vector has @length v `div` 8@ words (one 'Word64' per 64 input
-- bytes).  Result @i@ is the output region whose address was passed to the
-- kernel as target pointer @i@.  NOTE(review): presumably that is the result
-- for @w8s ! i@ — confirm against the C implementation.
instance CmpEqWord8sPara (DVS.Vector Word64) where
  type CmpEqWord8sParaTarget (DVS.Vector Word64) = DV.Vector (DVS.Vector Word64)

  cmpEqWord8sPara w8s v
    | disalignment /= 0 = error $ "Unaligned byte string: " <> show disalignment
    | otherwise = F.unsafeLocalState $ do
        -- One shared buffer holding numW8s consecutive regions, each exactly
        -- regionBytes long (w64sLen Word64 words).
        tgtFptr <- F.mallocForeignPtrBytes (regionBytes * numW8s)
        F.withForeignPtr srcFptr $ \srcPtr ->
          F.withForeignPtr tgtFptr $ \tgtPtr -> do
            -- Start address of region i; offsets are in bytes.
            let tgtPtrs :: DVS.Vector (F.Ptr Word8)
                tgtPtrs = DVS.generate numW8s $ \i -> tgtPtr `F.plusPtr` (i * regionBytes)
            -- unsafeWith yields a pointer to the vector's first element, so any
            -- slice offset is accounted for.
            DVS.unsafeWith w8s $ \w8sPtr ->
              DVS.unsafeWith tgtPtrs $ \tgtPtrsPtr ->
                F.avx2Cmpeq8Para
                  (F.castPtr w8sPtr)
                  (fromIntegral numW8s)
                  (F.castPtr tgtPtrsPtr)
                  (fromIntegral w64sLen)
                  (srcPtr `F.plusPtr` srcOffset)
        let tgtV :: DVS.Vector Word64
            tgtV = DVS.unsafeFromForeignPtr tgtFptr 0 (w64sLen * numW8s)
        -- Split the shared buffer back into one vector per region.
        return $ DV.generate numW8s $ \i -> DVS.take w64sLen (DVS.drop (i * w64sLen) tgtV)
    where (srcFptr, srcOffset, srcLength) = DVS.unsafeToForeignPtr (DVS.unsafeCast v :: DVS.Vector Word8)
          -- Number of comparand bytes (= number of result vectors).
          numW8s :: Int
          numW8s = DVS.length w8s
          -- Number of whole 64-byte chunks (= Word64 words per result).
          w64sLen :: Int
          w64sLen = srcLength `div` 64
          -- Bytes left over after the whole chunks; must be 0.
          disalignment :: Int
          disalignment = srcLength - w64sLen * 64
          -- Size in bytes of one result region.
          regionBytes :: Int
          regionBytes = w64sLen * 8
  {-# INLINE cmpEqWord8sPara #-}