{-# LANGUAGE FlexibleInstances #-}

module HaskellWorks.Data.Simd.Comparison.Stock
  ( CmpEqWord8s(..)
  ) where

import Data.Word
import HaskellWorks.Data.AtIndex
import HaskellWorks.Data.Bits.BitWise
import HaskellWorks.Data.Simd.Internal.Bits
import HaskellWorks.Data.Simd.Internal.Broadword

import qualified Data.ByteString                    as BS
import qualified Data.Vector.Storable               as DVS
import qualified HaskellWorks.Data.ByteString       as BS
import qualified HaskellWorks.Data.Simd.ChunkString as CS
import qualified HaskellWorks.Data.Vector.AsVector8 as V

-- | Containers whose bytes can be compared against a single 'Word8'.
--
-- @cmpEqWord8s w8 x@ yields a value of the same type as @x@.  The instances
-- in this module produce a bitmask with one output bit per input byte
-- (64 input bytes map to one 64-bit output word in the
-- @DVS.Vector Word64@ instance), presumably set where the byte equals @w8@.
class CmpEqWord8s a where
  cmpEqWord8s
    :: Word8  -- ^ byte value every input byte is compared against
    -> a      -- ^ input to scan
    -> a      -- ^ comparison result, in the same container type as the input

-- | Byte-vector instance: reinterprets the bytes as 'Word64' words, delegates
-- to the @DVS.Vector Word64@ instance, and reinterprets the resulting words
-- back as bytes.  No copying is done in either direction.
--
-- NOTE(review): 'DVS.unsafeCast' computes the new length by dividing the
-- byte count by the element size, so when @length v@ is not a multiple of 8
-- the trailing @length v `mod` 8@ bytes are never examined.  Presumably
-- callers pass buffers padded to a multiple of 8 (or 64) bytes; confirm.
instance CmpEqWord8s (DVS.Vector Word8) where
  cmpEqWord8s w8 v = DVS.unsafeCast (cmpEqWord8s w8 (DVS.unsafeCast v :: DVS.Vector Word64))
  {-# INLINE cmpEqWord8s #-}

-- | Word-vector instance: the core comparison loop.
--
-- The input is consumed in groups of eight words (64 bytes).  Each group
-- produces one output word, so the result has @(length v + 7) `div` 8@
-- elements.  Output word @i@ is built from input words @8*i .. 8*i + 7@:
--
--   * each input word is XOR-ed with @iw@ ('fillWord64' of @w8@, presumably
--     @w8@ replicated into all eight bytes), which zeroes matching bytes;
--   * 'testWord8s' condenses the XOR result (presumably one bit per
--     non-zero byte; confirm against "HaskellWorks.Data.Simd.Internal.Bits");
--   * the result for input word @8*i + k@ is shifted left by @8*k@ bits and
--     all eight are OR-ed together;
--   * the combined word is complemented with 'comp', so set bits mark
--     matching bytes.
--
-- NOTE(review): in the final, partial group, words past the end of @v@ are
-- read as 0.  When @w8 == 0@ those padding bytes therefore compare equal and
-- their bits are set in the last output word.  Callers that care should
-- mask them off.
instance CmpEqWord8s (DVS.Vector Word64) where
  cmpEqWord8s w8 v = DVS.constructN ((DVS.length v + 7) `div` 8) go
    where
      -- Comparison pattern derived from w8, XOR-ed into every input word.
      iw = fillWord64 w8

      -- Produce the next output word; @end u@ is how many words have already
      -- been constructed, i.e. the index of the word being built.
      go :: DVS.Vector Word64 -> Word64
      go u =
        let vi = end u * 8  -- index of the first input word of this group
        in  if vi + 8 <= end v
              -- Whole group is in bounds: use unchecked indexing.
              then pack8 (v !!!) vi
              -- Tail group: out-of-range words are read as 0.
              else pack8 (atIndexOr 0 v) vi

      -- Test input words vi .. vi + 7 (read through @fetch@), place the result
      -- for word vi + k at bit offset 8 * k, and complement the combination.
      pack8 fetch vi =
        let t k = testWord8s (fetch (vi + k) .^. iw)
        in  comp (   (t 7 .<. 56) .|. (t 6 .<. 48)
                 .|. (t 5 .<. 40) .|. (t 4 .<. 32)
                 .|. (t 3 .<. 24) .|. (t 2 .<. 16)
                 .|. (t 1 .<.  8) .|.  t 0
                 )
      {-# INLINE pack8 #-}
  {-# INLINE cmpEqWord8s #-}

-- | A list of word vectors: each vector is compared independently, and list
-- order and length are preserved.
instance CmpEqWord8s [DVS.Vector Word64] where
  cmpEqWord8s w8 = map (cmpEqWord8s w8)
  {-# INLINE cmpEqWord8s #-}

-- | A list of byte vectors: each vector is compared independently, and list
-- order and length are preserved.
instance CmpEqWord8s [DVS.Vector Word8] where
  cmpEqWord8s w8 = map (cmpEqWord8s w8)
  {-# INLINE cmpEqWord8s #-}

-- | A list of 'BS.ByteString' chunks.  Each chunk is viewed as a byte vector
-- ('V.asVector8'), run through the byte-vector instance, and turned back into
-- a 'BS.ByteString' ('BS.toByteString').  Chunk order is preserved.
instance CmpEqWord8s [BS.ByteString] where
  cmpEqWord8s w8 vs = [ BS.toByteString (cmpEqWord8s w8 (V.asVector8 bs)) | bs <- vs ]
  {-# INLINE cmpEqWord8s #-}

-- | A 'CS.ChunkString' is split into its 'BS.ByteString' chunks
-- ('BS.toByteStrings'), the chunk list is compared, and the results are
-- reassembled with 'CS.toChunkString'.
instance CmpEqWord8s CS.ChunkString where
  cmpEqWord8s w8 cs = CS.toChunkString (cmpEqWord8s w8 (BS.toByteStrings cs))
  {-# INLINE cmpEqWord8s #-}