{-# LANGUAGE FlexibleInstances #-}

module HaskellWorks.Data.Simd.Comparison.Stock
  ( CmpEqWord8s(..)
  ) where

import Data.Word
import HaskellWorks.Data.AtIndex
import HaskellWorks.Data.Bits.BitWise
import HaskellWorks.Data.Simd.Internal.Bits
import HaskellWorks.Data.Simd.Internal.Broadword

import qualified Data.Vector.Storable as DVS

-- | Containers of bytes that can be compared, byte by byte, against a single
-- 'Word8' value.
class CmpEqWord8s a where
  -- | @cmpEqWord8s w8 xs@ compares the bytes held in @xs@ against @w8@ and
  -- returns the outcome in a value of the same type as the input.
  --
  -- The vector instances in this module pack the outcome as one bit per
  -- input byte. Judging by their XOR-with-broadcast followed by a final
  -- complement, a bit is set where the byte equals @w8@ (NOTE(review):
  -- confirm against 'testWord8s').
  cmpEqWord8s :: Word8 -> a -> a

-- | Byte-vector instance: reinterprets the bytes as 64-bit words, delegates
-- to the @'DVS.Vector' 'Word64'@ instance, and reinterprets the resulting
-- mask words as bytes again.
instance CmpEqWord8s (DVS.Vector Word8) where
  cmpEqWord8s w8 v = DVS.unsafeCast (cmpEqWord8s w8 (DVS.unsafeCast padded :: DVS.Vector Word64))
    where
      -- Number of bytes left over after the last complete 64-bit word.
      remainder = DVS.length v `mod` 8

      -- 'DVS.unsafeCast' only keeps as many 'Word64's as fit completely in the
      -- buffer, so a trailing partial word would be dropped without being
      -- compared.  Extend such inputs to a whole number of words.  The filler
      -- byte is @w8 + 1@ ('Word8' arithmetic wraps), which never equals @w8@,
      -- so the filler itself cannot be reported as a match.  Inputs whose
      -- length is already a multiple of 8 are passed through without copying.
      padded
        | remainder == 0 = v
        | otherwise      = v DVS.++ DVS.replicate (8 - remainder) (w8 + 1)
  {-# INLINE cmpEqWord8s #-}

-- | Word-vector instance.  Each output word is assembled from eight
-- consecutive input words: every input word contributes an 8-bit group (one
-- bit per byte), the group of the first input word lands in the lowest 8 bits
-- and the group of the eighth in the highest 8 bits.  The result therefore has
-- @ceiling (length v / 8)@ words.
--
-- When the input length is not a multiple of 8 the last output word is
-- completed with filler words chosen so that their bits come out as zero.
instance CmpEqWord8s (DVS.Vector Word64) where
  cmpEqWord8s w8 v = DVS.constructN ((DVS.length v + 7) `div` 8) go
    where
      -- The word every input word is XORed with (presumably @w8@ broadcast to
      -- all eight bytes of a word); bytes equal to @w8@ become zero bytes.
      iw = fillWord64 w8

      -- Filler for indices past the end of @v@.  XORed with @iw@ it yields a
      -- word with every bit set, i.e. no zero byte, so filler never counts as
      -- a match.  A filler of 0 would be reported as matching when @w8@ is 0.
      -- (Assumes 'testWord8s' yields one bit per non-zero byte, as the 8-bit
      -- shifts and the final 'comp' suggest.)
      padding = comp iw

      -- Produces the next output word; @u@ is the prefix built so far, so its
      -- length is the index of the word being produced.
      go :: DVS.Vector Word64 -> Word64
      go u
        -- All eight source indices vi .. vi + 7 are in range exactly when
        -- vi + 8 <= end v, so the direct indexing path also covers a block
        -- that ends precisely at the end of the input.
        | vi + 8 <= end v = comp (pack8 (fast 0) (fast 1) (fast 2) (fast 3)
                                        (fast 4) (fast 5) (fast 6) (fast 7))
        -- Final, partial block: fall back to defaulted indexing.
        | otherwise       = comp (pack8 (safe 0) (safe 1) (safe 2) (safe 3)
                                        (safe 4) (safe 5) (safe 6) (safe 7))
        where
          -- Index of the first input word feeding this output word.
          vi     = end u * 8
          fast k = testWord8s ((v !!! (vi + k)) .^. iw)
          safe k = testWord8s (atIndexOr padding v (vi + k) .^. iw)

      -- Concatenates eight 8-bit groups into one word, first group lowest.
      pack8 :: Word64 -> Word64 -> Word64 -> Word64
            -> Word64 -> Word64 -> Word64 -> Word64
            -> Word64
      pack8 b0 b1 b2 b3 b4 b5 b6 b7 =
        (b7 .<. 56) .|.
        (b6 .<. 48) .|.
        (b5 .<. 40) .|.
        (b4 .<. 32) .|.
        (b3 .<. 24) .|.
        (b2 .<. 16) .|.
        (b1 .<.  8) .|.
         b0
  {-# INLINE cmpEqWord8s #-}

-- | List instance: applies the @'DVS.Vector' 'Word64'@ comparison to every
-- vector of the list independently, preserving order and list length.
instance CmpEqWord8s [DVS.Vector Word64] where
  cmpEqWord8s w8 vs = map (cmpEqWord8s w8) vs
  {-# INLINE cmpEqWord8s #-}

-- | List instance: applies the @'DVS.Vector' 'Word8'@ comparison to every
-- vector of the list independently, preserving order and list length.
instance CmpEqWord8s [DVS.Vector Word8] where
  cmpEqWord8s w8 vs = map (cmpEqWord8s w8) vs
  {-# INLINE cmpEqWord8s #-}