{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

module HaskellWorks.Data.Simd.Comparison.Avx2 where

import Control.Monad
import Data.Word

import qualified Data.ByteString                         as BS
import qualified Data.Vector                             as DV
import qualified Data.Vector.Storable                    as DVS
import qualified Foreign.ForeignPtr                      as F
import qualified Foreign.Marshal.Unsafe                  as F
import qualified Foreign.Ptr                             as F
import qualified HaskellWorks.Data.ByteString            as BS
import qualified HaskellWorks.Data.Simd.ChunkString      as CS
import qualified HaskellWorks.Data.Simd.Internal.Foreign as F
import qualified HaskellWorks.Data.Vector.AsVector8      as V
import qualified HaskellWorks.Data.Vector.Storable       as DVS

{- HLINT ignore "Redundant do"        -}

-- | Types whose contents can be compared against a single 'Word8' value.
--
-- Each instance decides what the comparison result looks like via 'Target'.
-- NOTE(review): judging by the instances in this module the result is
-- presumably a bitmask with one bit per input byte — confirm against the
-- underlying C implementation.
class CmpEqWord8s a where
  -- | The result type produced by 'cmpEqWord8s' for an input of type @a@.
  type Target a
  -- | Compare the contents of the second argument against the byte given as
  -- the first argument.
  cmpEqWord8s :: Word8 -> a -> Target a

-- | Compare every byte of a storable byte vector against a single byte using
-- the AVX2 foreign routine.
--
-- The input may have any length.  Whole 64-byte chunks are handed to the
-- foreign routine directly from the source buffer; a trailing partial chunk
-- (if any) is copied into a zero-padded 64-byte scratch vector and processed
-- as one additional chunk.
--
-- The returned vector has @(length v + 7) `div` 8@ bytes.
instance CmpEqWord8s (DVS.Vector Word8) where
  type Target (DVS.Vector Word8) = DVS.Vector Word8
  cmpEqWord8s w8 v = F.unsafeLocalState $ do
    tgtFptr <- F.mallocForeignPtrBytes bufLen
    F.withForeignPtr srcFptr $ \srcPtr ->
      F.withForeignPtr tgtFptr $ \tgtPtr -> do
        -- Aligned prefix: 'w64sLen' chunks of 64 source bytes, each producing
        -- 8 bytes of output.
        F.avx2Cmpeq8
          (fromIntegral w8)
          (F.castPtr tgtPtr)
          (fromIntegral w64sLen)
          (srcPtr `F.plusPtr` srcOffset)
        when (disalignment /= 0) $ do
          -- Unaligned tail: pad it out to exactly one full 64-byte chunk so
          -- the foreign routine never reads past the end of the buffer.
          -- NOTE(review): assumes 'DVS.padded' @n@ right-pads with zeros up
          -- to at least @n@ bytes — confirm against hw-prim.
          let ending                  = DVS.padded 64 (DVS.drop alignment v)
              (endFptr, endOffset, _) = DVS.unsafeToForeignPtr ending
          F.withForeignPtr endFptr $ \endPtr ->
            -- The length argument counts 64-byte chunks (as in the call
            -- above), so the tail is exactly one chunk.  'bufLen' reserves
            -- room for precisely this one extra 8-byte output word.
            F.avx2Cmpeq8
              (fromIntegral w8)
              (tgtPtr `F.plusPtr` (w64sLen * 8))
              1
              (endPtr `F.plusPtr` endOffset)
        return $ DVS.unsafeFromForeignPtr tgtFptr 0 tgtLen
    where
      (srcFptr, srcOffset, srcLen) = DVS.unsafeToForeignPtr v
      -- Bytes allocated for the output: enough for every whole chunk plus
      -- one full 8-byte word for a partial tail.
      bufLen        = (srcLen + 63) `div` 8
      -- Bytes actually exposed in the result.
      tgtLen        = (srcLen + 7) `div` 8
      -- Number of complete 64-byte chunks in the source.
      w64sLen       = srcLen `div` 64
      -- Byte offset in the source at which the partial tail starts.
      alignment     = w64sLen * 64
      -- Number of source bytes in the partial tail (0 when aligned).
      disalignment  = srcLen - alignment
  {-# INLINE cmpEqWord8s #-}

-- | Compare every byte of a 'Word64' vector against a single byte using the
-- AVX2 foreign routine.
--
-- The vector is reinterpreted as bytes.  Its size in bytes must be a multiple
-- of 64; otherwise an 'error' is raised.  The result contains one 'Word64'
-- per 64 input bytes.
instance CmpEqWord8s (DVS.Vector Word64) where
  type Target (DVS.Vector Word64) = DVS.Vector Word64
  cmpEqWord8s w8 v
    | leftover /= 0 = error $ "Unaligned byte string: " <> show leftover
    | otherwise     = F.unsafeLocalState $ do
        outFptr <- F.mallocForeignPtrBytes byteCount
        F.withForeignPtr inFptr $ \inPtr ->
          F.withForeignPtr outFptr $ \outPtr -> do
            F.avx2Cmpeq8
              (fromIntegral w8)
              (F.castPtr outPtr)
              (fromIntegral chunkCount)
              (inPtr `F.plusPtr` inOffset)
            return $ DVS.unsafeFromForeignPtr outFptr 0 chunkCount
    where
      -- View the input as raw bytes and take its buffer apart.
      (inFptr, inOffset, byteCount) =
        DVS.unsafeToForeignPtr (DVS.unsafeCast v :: DVS.Vector Word8)
      -- Number of complete 64-byte chunks.
      chunkCount = byteCount `div` 64
      -- Bytes left over after the last complete chunk; must be zero.
      leftover   = byteCount - chunkCount * 64
  {-# INLINE cmpEqWord8s #-}

-- | Apply the 'DVS.Vector' 'Word64' comparison to each vector of the list
-- independently.
instance CmpEqWord8s [DVS.Vector Word64] where
  type Target [DVS.Vector Word64] = [DVS.Vector Word64]
  cmpEqWord8s w8 vs = map (cmpEqWord8s w8) vs
  {-# INLINE cmpEqWord8s #-}

-- | Apply the 'DVS.Vector' 'Word8' comparison to each vector of the list
-- independently.
instance CmpEqWord8s [DVS.Vector Word8] where
  type Target [DVS.Vector Word8] = [DVS.Vector Word8]
  cmpEqWord8s w8 vs = map (cmpEqWord8s w8) vs
  {-# INLINE cmpEqWord8s #-}

-- | Compare each strict 'BS.ByteString' of the list independently: every
-- element is viewed as a byte vector, compared, and converted back to a
-- 'BS.ByteString'.
instance CmpEqWord8s [BS.ByteString] where
  type Target [BS.ByteString] = [BS.ByteString]
  cmpEqWord8s w8 vs =
    [ BS.toByteString (cmpEqWord8s w8 (V.asVector8 bs))
    | bs <- vs
    ]
  {-# INLINE cmpEqWord8s #-}

-- | Compare a 'CS.ChunkString' by converting it to its list of strict
-- byte strings, comparing that list, and rebuilding a 'CS.ChunkString' from
-- the results.
instance CmpEqWord8s CS.ChunkString where
  type Target CS.ChunkString = CS.ChunkString
  cmpEqWord8s w8 = CS.toChunkString . compareChunks . BS.toByteStrings
    where
      -- The list-of-byte-strings comparison, fixed to the byte being sought.
      compareChunks :: [BS.ByteString] -> [BS.ByteString]
      compareChunks = cmpEqWord8s w8
  {-# INLINE cmpEqWord8s #-}

-- | Types whose contents can be compared against several 'Word8' values in
-- one call.
--
-- Each instance decides what the combined result looks like via
-- 'CmpEqWord8sParaTarget'.
class CmpEqWord8sPara a where
  -- | The result type produced by 'cmpEqWord8sPara' for an input of type @a@.
  type CmpEqWord8sParaTarget a
  -- | Compare the contents of the second argument against each of the bytes
  -- in the vector given as the first argument.
  cmpEqWord8sPara :: DVS.Vector Word8 -> a -> CmpEqWord8sParaTarget a

-- | Compare every byte of a 'Word64' vector against each byte of @w8s@ in a
-- single call to the AVX2 foreign routine.
--
-- The input vector is reinterpreted as bytes; its size in bytes must be a
-- multiple of 64, otherwise an 'error' is raised.  The result is a boxed
-- vector with one entry per element of @w8s@ (in the same order), each entry
-- holding one 'Word64' per 64 input bytes.
instance CmpEqWord8sPara (DVS.Vector Word64) where
  type CmpEqWord8sParaTarget (DVS.Vector Word64) = DV.Vector (DVS.Vector Word64)
  cmpEqWord8sPara w8s v
    | disalignment /= 0 = error $ "Unaligned byte string: " <> show disalignment
    | otherwise         = F.unsafeLocalState $ do
        -- One contiguous output buffer shared by all comparands.
        tgtFptr <- F.mallocForeignPtrBytes (srcLength * comparandCount)
        F.withForeignPtr srcFptr $ \srcPtr ->
          F.withForeignPtr tgtFptr $ \tgtPtr -> do
            -- Start address of each comparand's slice of the output buffer.
            -- 'F.plusPtr' counts bytes: each slice is @w64sLen * 8@ bytes,
            -- which equals @DVS.length v@ when the input is 64-byte aligned.
            let tgtsPtrsV :: DVS.Vector (F.Ptr Word8)
                tgtsPtrsV = DVS.constructN comparandCount $ \t ->
                  tgtPtr `F.plusPtr` (DVS.length t * DVS.length v)
            let (w8sFptr, w8sOffset, w8sLen) = DVS.unsafeToForeignPtr w8s
            -- 'tgtsPtrsV' is freshly built above, so its offset is not used.
            let (tgtsPtrsFptr, _, _) = DVS.unsafeToForeignPtr tgtsPtrsV
            F.withForeignPtr w8sFptr $ \w8sPtr ->
              F.withForeignPtr tgtsPtrsFptr $ \tgtsPtrsPtr ->
                F.avx2Cmpeq8Para
                  -- Honour the vector's offset so a sliced @w8s@ is read from
                  -- its own first element rather than the start of the
                  -- underlying buffer.
                  (w8sPtr `F.plusPtr` w8sOffset)
                  (fromIntegral w8sLen)
                  (F.castPtr tgtsPtrsPtr)
                  (fromIntegral w64sLen)
                  (srcPtr `F.plusPtr` srcOffset)

            -- Expose the shared buffer as one vector and cut it into one
            -- 'w64sLen'-sized slice per comparand.
            let tgtV = DVS.unsafeFromForeignPtr tgtFptr 0 (w64sLen * comparandCount)
            return $ DV.constructN comparandCount $ \t ->
              DVS.take w64sLen (DVS.drop (DV.length t * w64sLen) tgtV)
    where
      -- View the input as raw bytes and take its buffer apart.
      (srcFptr, srcOffset, srcLength) =
        DVS.unsafeToForeignPtr (DVS.unsafeCast v :: DVS.Vector Word8)
      -- Number of bytes to compare against.
      comparandCount = DVS.length w8s
      -- Number of complete 64-byte chunks in the source.
      w64sLen        = srcLength `div` 64
      -- Bytes left over after the last complete chunk; must be zero.
      disalignment   = srcLength - w64sLen * 64
  {-# INLINE cmpEqWord8sPara #-}