{-# language BangPatterns #-}
{-# language BlockArguments #-}
{-# language DuplicateRecordFields #-}
{-# language MagicHash #-}
{-# language NamedFieldPuns #-}
{-# language RankNTypes #-}
{-# language TupleSections #-}
{-# language TypeApplications #-}
{-# language UnboxedTuples #-}

-- This is broken out into a separate module to make it easier
-- to dump core and investigate performance issues.
module Data.Bytes.Search
  ( findIndices
  , replace
  , isInfixOf
  ) where

import Prelude hiding (length,takeWhile,dropWhile,null,foldl,foldr,elem,replicate,any,all,readFile,map)

import Control.Monad.ST.Run (runByteArrayST,runPrimArrayST)
import Data.Bits((.&.),(.|.),shiftL,finiteBitSize)
import Data.Bytes.Pure (length,unsafeIndex,unsafeHead)
import Data.Bytes.Types (Bytes(Bytes,array,offset))
import Data.Primitive (ByteArray,PrimArray)
import GHC.Exts (Int(I#))
import GHC.Word (Word32)

import qualified Data.Bytes.Byte as Byte
import qualified Data.Bytes.Pure as Pure
import qualified Data.Bytes.Types as Types
import qualified Data.Primitive as PM

-- Implementation Notes
-- =====================
-- For karp rabin, there are some easy performance improvements
-- left on the table. The main optimization that has been done is making
-- sure that there is no unnecessary boxing of Int, Word32, or Bytes
-- going on. Here are some other things that have not been done:
--
-- * The hash is currently a Word32. It would be better to use either
--   Word or Word64 for this. We would need for hashKey to be different.
-- * In several places, we track an index into a Bytes. This index gets
--   repeatedly added to the base offset as we loop over the bytes. We
--   could instead track the true offset instead of repeatedly
--   recalculating it.

-- | Replace every non-overlapping occurrence of @needle@ in
-- @haystack@ with @replacement@.
--
-- Calls 'errorWithoutStackTrace' when @needle@ is empty. An empty
-- @haystack@ yields an empty result.
replace ::
     Bytes -- ^ needle, must not be empty
  -> Bytes -- ^ replacement
  -> Bytes -- ^ haystack
  -> Bytes
{-# noinline replace #-}
-- Implementation note: there is a lot of room to improve the performance
-- of this function.
replace !needle !replacement !haystack@Bytes{array=haystackArray,offset=haystackIndex,length=haystackLength}
  | Pure.length needle == 0 = errorWithoutStackTrace "Data.Bytes.replace: needle of length zero"
  | Pure.length haystack == 0 = Pure.empty
  -- Single byte replaced by a single byte: the output has the same
  -- length as the input, so a plain byte-wise map is enough.
  | Pure.length needle == 1, Pure.length replacement == 1 =
      let !needle0 = unsafeIndex needle 0
          !replacement0 = unsafeIndex replacement 0
       in Pure.map (\w -> if w == needle0 then replacement0 else w) haystack
  -- General case: locate every match first, then assemble the output
  -- in a single pass. The index modifier is 0, so the match positions
  -- are offsets into haystackArray, which is what replaceIndices expects.
  | otherwise =
      let !hp = rollingHash needle
          !ixs = findIndicesKarpRabin 0 hp needle haystackArray haystackIndex haystackLength
       in Pure.fromByteArray (replaceIndices ixs replacement (Pure.length needle) haystackArray haystackIndex haystackLength)


-- This is an internal function because it deals explicitly with
-- an offset into a byte array.
--
-- Builds a fresh byte array in which every match of the pattern has been
-- replaced by the replacement bytes.
--
-- Arguments, in order:
-- * ixs: offsets (into the haystack byte array, not relative to ix0) at
--   which each match begins, in ascending order
-- * replacement: bytes written in place of every match
-- * patLen: length of the pattern that was matched
-- * haystack: the byte array being searched
-- * ix0: offset into haystack where the searched region begins
-- * len0: length of the searched region
--
-- Example:
-- * haystack len: 39
-- * ixs: 7, 19, 33
-- * patLen: 5
-- * replacement: foo (len 3)
-- We want to perform these copies:
-- * src[0,7] -> dst[0,7]
-- * foo -> dst[7,3]
-- * src[12,7] -> dst[10,7]
-- * foo -> dst[17,3]
-- * src[24,9] -> dst[20,9]
-- * foo -> dst[29,3]
-- * src[38,1] -> dst[32,1]
replaceIndices :: PrimArray Int -> Bytes -> Int -> ByteArray -> Int -> Int -> ByteArray
replaceIndices !ixs !replacement !patLen !haystack !ix0 !len0 = runByteArrayST $ do
  let !ixsLen = PM.sizeofPrimArray ixs
  -- Change in output size caused by a single replacement (may be negative).
  let !delta = Pure.length replacement - patLen
  dst <- PM.newByteArray (len0 + ixsLen * delta)
  -- ixIx: position in ixs of the next match to handle
  -- prevSrcIx: offset in haystack of the first byte not yet copied
  let applyReplacement !ixIx !prevSrcIx = if ixIx < ixsLen
        then do
          let !srcMatchIx = PM.indexPrimArray ixs ixIx
          -- Accumulated size change from the replacements made so far.
          let !drift = ixIx * delta
          let !dstIx = srcMatchIx + drift - ix0
          -- Copy the untouched bytes between the previous match and this one.
          Pure.unsafeCopy dst (prevSrcIx + drift - ix0)
            Bytes{array=haystack,offset=prevSrcIx,length=srcMatchIx - prevSrcIx}
          -- Write the replacement where the match used to be.
          Pure.unsafeCopy dst dstIx replacement
          applyReplacement (ixIx + 1) (srcMatchIx + patLen)
        else do
          -- No matches remain: copy the tail of the haystack and finish.
          let !drift = ixIx * delta
          Pure.unsafeCopy dst (prevSrcIx + drift - ix0)
            Bytes{array=haystack,offset=prevSrcIx,length=(len0 + ix0) - prevSrcIx}
          PM.unsafeFreezeByteArray dst
  applyReplacement 0 ix0

-- | Find locations of non-overlapping instances of @needle@ within @haystack@.
--
-- The reported indices are relative to the start of the @haystack@ slice:
-- the slice offset is negated and passed as the index modifier so that
-- slicing is invisible in the results.
--
-- Calls 'errorWithoutStackTrace' when @needle@ is empty. An empty
-- @haystack@ yields an empty array.
findIndices ::
     Bytes -- ^ needle
  -> Bytes -- ^ haystack
  -> PrimArray Int
findIndices needle Bytes{array,offset=off,length=len}
  | needleLen == 0 = errorWithoutStackTrace "Data.Bytes.findIndices: needle with length zero"
  | len == 0 = mempty
  | otherwise =
      let !hp = rollingHash needle
       in findIndicesKarpRabin (negate off) hp needle array off len
  where
  needleLen = Pure.length needle

-- Precondition: Haystack has non-zero length
-- Precondition: Pattern has non-zero length
-- Uses karp rabin to search.
--
-- Repeatedly runs karpRabin on the not-yet-searched suffix of the haystack,
-- recording where each match begins and resuming just past the match, so
-- the reported matches never overlap.
--
-- NOTE(review): an earlier remark here suggested having karpRabin "return
-- a single index instead of two slices"; karpRabin is declared in this
-- module as returning a single Int, so that remark looks stale.
findIndicesKarpRabin ::
     Int -- Output index modifier. Set to negated initial index to make slicing invisible in results.
  -> Word32 -- Hash to search for (must agree with pattern)
  -> Bytes -- Pattern to search for
  -> ByteArray
  -> Int -- initial index
  -> Int -- length
  -> PrimArray Int
findIndicesKarpRabin !ixModifier !hp !pat !haystack !ix0 !len0 = runPrimArrayST $ do
  let !patLen = Pure.length pat
  -- Non-overlapping matches cannot outnumber len0 / patLen. The buffer is
  -- allocated at that bound plus one and shrunk to the real count at the end.
  let dstLen = 1 + quot len0 patLen
  dst <- PM.newPrimArray dstLen
  -- ix: offset in haystack where the remaining search region begins
  -- len: length of the remaining search region
  -- ixIx: number of matches recorded so far (next write position in dst)
  let go !ix !len !ixIx = case karpRabin hp pat Bytes{array=haystack,offset=ix,length=len} of
        (-1) -> do
          PM.shrinkMutablePrimArray dst ixIx
          PM.unsafeFreezePrimArray dst
        skipCount -> do
          -- skipCount counts the bytes up to and including the match, so
          -- the match itself begins patLen bytes before that point.
          let !matchStart = ix + skipCount - patLen
          PM.writePrimArray dst ixIx (matchStart + ixModifier)
          go (ix + skipCount) (len - skipCount) (ixIx + 1)
  go ix0 len0 0

-- Search for the first occurrence of a pattern inside a haystack.
--
-- Output: Negative one means match not found. Other negative
-- numbers should not occur. Zero may occur (it is returned for an empty
-- pattern). Positive number means the number of bytes to skip to make it
-- past the match.
--
-- Strategy, chosen by pattern length:
-- * 0 bytes: trivially matches, returns 0
-- * 1 byte: single-byte scan via Byte.elemIndexLoop#
-- * fits in a machine word: compare a sliding Word-sized window (shift)
-- * anything longer: karpRabin
breakSubstring :: Bytes -- ^ String to search for
               -> Bytes -- ^ String to search in
               -> Int
breakSubstring !pat !haystack@(Bytes _ off0 _) =
  case lp of
    0 -> 0
    1 -> case Byte.elemIndexLoop# (unsafeHead pat) haystack of
      (-1#) -> (-1)
      -- NOTE(review): off is presumably an offset into the underlying
      -- array, which is why off0 is subtracted -- confirm against
      -- Byte.elemIndexLoop#. The extra 1 steps past the matched byte.
      off -> 1 + (I# off) - off0
    _ -> if lp * 8 <= finiteBitSize (0 :: Word)
      then shift haystack
      else karpRabin (rollingHash pat) pat haystack
  where
  lp = length pat
  -- Sliding-window search for patterns that fit in a single Word: the
  -- last lp bytes seen are packed into a Word and compared to the
  -- packed pattern.
  {-# INLINE shift #-}
  shift :: Bytes -> Int
  shift !src
      | length src < lp = (-1)
      | otherwise       = search (intoWord $ Pure.unsafeTake lp src) lp
    where
    -- Pack bytes into a Word, first byte most significant.
    intoWord :: Bytes -> Word
    intoWord = Pure.foldl' (\w b -> (w `shiftL` 8) .|. fromIntegral b) 0
    wp   = intoWord pat
    -- Keeps only the low lp bytes of the window.
    -- NOTE(review): when lp equals the word size in bytes this relies on
    -- the shift producing 0 so that the subtraction wraps to all ones --
    -- confirm.
    mask = (1 `shiftL` (8 * lp)) - 1
    -- w: packed window covering src[i-lp .. i-1]
    -- i: number of bytes consumed so far
    search :: Word -> Int -> Int
    search !w !i
        | w == wp         = i
        | length src <= i = (-1)
        | otherwise       = search w' (i + 1)
      where
      b  = fromIntegral (Pure.unsafeIndex src i)
      -- Slide the window one byte: drop the oldest byte, append the new one.
      w' = mask .&. ((w `shiftL` 8) .|. b)

-- Polynomial hash of a byte sequence, with hashKey as the multiplier and
-- Word32 wrap-around arithmetic. Supplies the pattern hash and the initial
-- window hash for the karp rabin search (karpRabin and its callers).
rollingHash :: Bytes -> Word32
{-# inline rollingHash #-}
rollingHash = Pure.foldl' (\h b -> h * hashKey + fromIntegral b) 0

-- Multiplier used by rollingHash and by the rolling update in karpRabin.
-- Both must use the same value for pattern and window hashes to agree.
hashKey :: Word32
{-# inline hashKey #-}
hashKey = 2891336453

-- Search src for the first occurrence of pat using a rolling hash.
--
-- Precondition: Length of bytes is greater than or equal to 1.
-- Precondition: Rolling hash agrees with pattern.
-- Output: Negative one means match not found. Other negative
-- numbers should not occur. Zero should not occur. Positive number
-- means the number of bytes to skip to make it past the match.
karpRabin :: Word32 -> Bytes -> Bytes -> Int
karpRabin !hp !pat !src
    | length src < lp = (-1)
    | otherwise = search (rollingHash $ Pure.unsafeTake lp src) lp
  where
  lp :: Int
  !lp = Pure.length pat
  -- hashKey raised to the pattern length: the weight carried by the byte
  -- that leaves the window once the hash has been multiplied by hashKey.
  m :: Word32
  !m = hashKey ^ lp
  get :: Int -> Word32
  get !ix = fromIntegral (Pure.unsafeIndex src ix)
  -- hs: hash of the window src[i-lp .. i-1]
  -- i: number of bytes consumed so far (one past the end of the window)
  search :: Word32 -> Int -> Int
  search !hs !i
      -- A hash hit is only a candidate; the bytes are compared to rule
      -- out collisions.
      | hp == hs && eqBytesNoShortCut pat (Pure.unsafeTake lp (Pure.unsafeDrop (i - lp) src)) = i
      | length src <= i                    = (-1)
      | otherwise                          = search hs' (i + 1)
    where
    -- Roll the window forward by one byte: bring in src[i] and remove
    -- the contribution of src[i-lp].
    hs' = hs * hashKey +
          get i -
          m * get (i - lp)

-- | Is the first argument an infix of the second argument?
--
-- An empty first argument is an infix of everything.
--
-- Patterns too long to fit in a machine word are searched with the
-- Rabin-Karp algorithm: expected time @O(n+m)@, worst-case @O(nm)@.
-- Shorter patterns use a single-byte scan or a word-sized sliding window.
isInfixOf :: Bytes -- ^ String to search for
          -> Bytes -- ^ String to search in
          -> Bool
isInfixOf p s = Pure.null p || breakSubstring p s >= 0


-- Byte-wise equality of two slices.
--
-- Precondition: both arguments have the same length
-- Skips the pointer equality check and the length check. Only the length
-- of the first argument is consulted; the second argument's length is
-- ignored.
eqBytesNoShortCut :: Bytes -> Bytes -> Bool
{-# inline eqBytesNoShortCut #-}
eqBytesNoShortCut (Bytes arr1 off1 len1) (Bytes arr2 off2 _) =
  PM.compareByteArrays arr1 off1 arr2 off2 len1 == EQ