-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2019 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An efficient implementation of the Boyer-Moore string search algorithm.
-- http://www-igm.univ-mlv.fr/~lecroq/string/node14.html#SECTION00140
-- https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string-search_algorithm
--
-- This is the case-insensitive variant of the algorithm which, unlike the
-- case-sensitive variant, has to be aware of the Unicode code points that the
-- bytes represent.
--
module Data.Text.BoyerMooreCI.Automaton
    ( Automaton
    , CaseSensitivity (..)
    , CodeUnitIndex (..)
    , Next (..)
    , buildAutomaton
    , patternLength
    , patternText
    , runText

      -- Exposed for testing
    , minimumSkipForCodePoint
    ) where

import Control.DeepSeq (NFData)
import Control.Monad.ST (runST)
import Data.Hashable (Hashable (..))
import Data.Text.Internal (Text (..))
import GHC.Generics (Generic)

#if defined(HAS_AESON)
import qualified Data.Aeson as AE
#endif

import Data.Text.CaseSensitivity (CaseSensitivity (..))
import Data.Text.Utf8 (BackwardsIter (..), CodePoint, CodeUnitIndex (..))
import Data.TypedByteArray (Prim, TypedByteArray)

import qualified Data.Char as Char
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Text as Text
import qualified Data.Text.Utf8 as Utf8
import qualified Data.TypedByteArray as TBA

-- | What the match callback of 'runText' wants to happen next.
data Next a
  = Done !a
    -- ^ Stop searching and return this value as the final result.
  | Step !a
    -- ^ Keep searching, continuing with this accumulator.


-- | A Boyer-Moore automaton is built from lookup tables that allow skipping
-- through the haystack. This allows for sub-linear matching in some cases, as we
-- do not have to look at every input character.
--
-- NOTE: Unlike the AcMachine, a Boyer-Moore automaton only returns
-- non-overlapping matches. This means that a Boyer-Moore automaton is not a
-- 100% drop-in replacement for Aho-Corasick.
--
-- Returning overlapping matches would degrade the performance to /O(nm)/ in
-- pathological cases like finding @aaaa@ in @aaaaa....aaaaaa@, as for each
-- match it would scan back the whole /m/ characters of the pattern.
data Automaton = Automaton
  { automatonPattern :: !(TypedByteArray CodePoint)
    -- ^ The pattern, stored as one entry per code point.
  , automatonPatternHash :: !Int
    -- ^ Hash of the pattern 'Text', computed once in 'buildAutomaton'.
  , automatonSuffixTable :: !SuffixTable
    -- ^ Good-suffix shifts, one entry per pattern code point.
  , automatonBadCharLookup :: !BadCharLookup
    -- ^ Bad-character shifts, looked up by the lowercased haystack code point.
  , automatonMinPatternSkip :: !CodeUnitIndex
    -- ^ Fewest bytes that any case variation of the pattern can occupy.
  }
  deriving stock (Generic, Show)
  deriving anyclass (NFData)

-- | Hashes through the pattern hash cached at construction time. Equal
-- patterns give equal hashes, which keeps this consistent with 'Eq'.
instance Hashable Automaton where
  hashWithSalt salt automaton = hashWithSalt salt (automatonPatternHash automaton)

-- | Automata are compared by their pattern only. Every other field is derived
-- from the pattern in 'buildAutomaton'.
instance Eq Automaton where
  Automaton { automatonPattern = lhs } == Automaton { automatonPattern = rhs } = lhs == rhs

#if defined(HAS_AESON)
-- | Decodes a JSON string as a pattern and builds the automaton from it.
instance AE.FromJSON Automaton where
  parseJSON = fmap buildAutomaton . AE.parseJSON

-- | Encodes the automaton as the JSON string of its pattern.
instance AE.ToJSON Automaton where
  toJSON automaton = AE.toJSON (patternText automaton)
#endif

-- | Build a case-insensitive Boyer-Moore automaton for the given pattern.
--
-- The pattern is stored as given; it is not lowercased here. 'runText'
-- lowercases each haystack code point before comparing it with the pattern, so
-- callers should pass an already-lowercased pattern.
buildAutomaton :: Text -> Automaton
buildAutomaton pattern_ =
  let codePoints = TBA.fromList (Text.unpack pattern_)
  in Automaton
       { automatonPattern = codePoints
       , automatonPatternHash = hash pattern_
       , automatonSuffixTable = buildSuffixTable codePoints
       , automatonBadCharLookup = buildBadCharLookup codePoints
       , automatonMinPatternSkip = minimumSkipForVector codePoints
       }

-- | Finds all non-overlapping matches of the pattern in the text, scanning left
-- to right.
--
-- For every match, the callback receives the current accumulator and the first
-- and last (inclusive) code unit index of the match. Both indices are relative
-- to the start of the text. The callback returns 'Step' to continue searching
-- or 'Done' to stop. An empty pattern yields the seed unchanged.
--
-- Haystack code points are lowercased before comparison; the pattern is not.
runText
  :: forall a
  .  a
  -> (a -> CodeUnitIndex -> CodeUnitIndex -> Next a)
  -> Automaton
  -> Text
  -> a
{-# INLINE runText #-}
runText seed f automaton (Text haystackArr haystackOffset haystackLen)
  | TBA.null needle = seed
  | otherwise = placePattern seed textStart (textStart + minSkip - 1)
  where
    Automaton needle _ suffixTable badCharTable minSkip = automaton

    -- Positions in the pattern are counted in code points; positions in the
    -- haystack are counted in code units (bytes) of the underlying array.

    -- First byte of the text inside the underlying array. Match bounds passed
    -- to the callback are made relative to this offset.
    textStart :: CodeUnitIndex
    textStart = CodeUnitIndex haystackOffset

    -- Last byte of the text inside the underlying array; never read past it.
    textEnd :: CodeUnitIndex
    textEnd = CodeUnitIndex (haystackOffset + haystackLen - 1)

    -- Outer loop, entered every time the pattern is moved.
    -- The pattern's last code point is aligned with the haystack code point
    -- that contains byte 'end'. 'floorIx' is the lowest byte we may read: the
    -- text start at first, later the byte right after the previous match.
    placePattern :: a -> CodeUnitIndex -> CodeUnitIndex -> a
    {-# INLINE placePattern #-}
    placePattern !acc !floorIx !end
      | end > textEnd = acc
      | otherwise =
          -- 'end' may point into the middle of a code point, so the alignment
          -- is moved to the last byte of that code point.
          let !cursor = Utf8.unsafeIndexAnywhereInCodePoint' haystackArr end
          in compareBackwards acc floorIx (backwardsIterEndOfChar cursor) cursor (TBA.length needle - 1)

    -- Inner loop: compare the lowercased haystack code point under 'cursor'
    -- with pattern code point 'needleIx', moving from right to left.
    compareBackwards :: a -> CodeUnitIndex -> CodeUnitIndex -> BackwardsIter -> Int -> a
    compareBackwards !acc !floorIx !end !cursor !needleIx
      | hayChar /= TBA.unsafeIndex needle needleIx = shiftAfterMismatch
      | needleIx == 0 = emitMatch
      -- Comparing further would read before 'floorIx': the pattern would start
      -- before the text or overlap the previous match. Try the next alignment.
      | prevEnd < floorIx = placePattern acc floorIx (end + 1)
      | otherwise =
          compareBackwards acc floorIx end
            (Utf8.unsafeIndexEndOfCodePoint' haystackArr prevEnd)
            (needleIx - 1)
      where
        !hayChar = Utf8.lowerCodePoint (backwardsIterChar cursor)

        -- Last byte of the code point just left of the cursor.
        prevEnd = backwardsIterNext cursor

        -- All pattern code points matched.
        emitMatch =
          let !matchFrom = prevEnd + 1 - textStart
              !matchTo = end - textStart
          in case f acc matchFrom matchTo of
               Done final -> final
               -- Resume strictly after this match, so matches never overlap.
               Step acc' -> placePattern acc' (end + 1) (end + minSkip)

        -- Move the pattern as far as both the bad-character rule and the
        -- suffix rule allow. The suffix rule always moves by at least one byte.
        shiftAfterMismatch =
          let !byBadChar = backwardsIterEndOfChar cursor + badCharLookup badCharTable hayChar
              !bySuffix = end + suffixLookup suffixTable needleIx
          in placePattern acc floorIx (max byBadChar bySuffix)

-- | Length of the automaton's pattern in UTF-8 code units (bytes). This takes
-- O(n) time because the pattern 'Text' is rebuilt first.
--
-- This is /not/ necessarily the byte length of a match. Matching is
-- case-insensitive, and case variants can differ in UTF-8 length (e.g. @Ⱥ@ is
-- 2 bytes while @ⱥ@ is 3). Use the indices passed to the 'runText' callback to
-- find the extent of a match.
patternLength :: Automaton -> CodeUnitIndex
patternLength automaton = Utf8.lengthUtf8 (patternText automaton)

-- | The pattern the automaton was built from, reassembled from the stored code
-- points. Allocates a fresh 'Text'; O(n).
patternText :: Automaton -> Text
patternText automaton = Text.pack (TBA.toList (automatonPattern automaton))


-- | Number of bytes that we can skip in the haystack if we want to skip no more
-- than one pattern code point.
--
-- It must always be a low (safe) estimate, otherwise the algorithm can miss
-- matches. It must account for any upper/lower case variation of the code
-- point that may occur in the haystack. In most cases this is simply the UTF-8
-- length of the given code point:
--
--     minimumSkipForCodePoint 'a' == 1
--     minimumSkipForCodePoint 'д' == 2
--     minimumSkipForCodePoint 'ⓟ' == 3
--     minimumSkipForCodePoint '🎄' == 4
--
minimumSkipForCodePoint :: CodePoint -> CodeUnitIndex
minimumSkipForCodePoint cp
  | code < 0x80 = 1
  | code < 0x800 = 2
  -- ⱥ (U+2C65) and ⱦ (U+2C66) are 3 UTF-8 bytes, but their upper case forms
  -- Ⱥ and Ⱦ are only 2 bytes.
  | code == 0x2C65 || code == 0x2C66 = 2
  | code < 0x10000 = 3
  | otherwise = 4
  where
    code = Char.ord cp


-- | Number of bytes of the shortest case variation of the given needle: the
-- sum of 'minimumSkipForCodePoint' over its code points. Needles are assumed
-- to be lower case.
--
--     minimumSkipForVector (TBA.fromList "ab..cd") == 6
--     minimumSkipForVector (TBA.fromList "aⱥ💩") == 7
--
minimumSkipForVector :: TypedByteArray CodePoint -> CodeUnitIndex
minimumSkipForVector = TBA.foldr addSkip 0
  where
    addSkip cp total = total + minimumSkipForCodePoint cp


-- | The suffix table has one entry per code point (not byte!) of the pattern.
-- Each entry is the number of bytes (not code points!) the pattern's alignment
-- may jump ahead when a mismatch happens at that code point.
newtype SuffixTable = SuffixTable (TypedByteArray CodeUnitIndex)
  deriving stock (Generic)
  deriving anyclass (NFData)

-- | Renders the table as @SuffixTable (TBA.fromList [...])@. This is an
-- expression that rebuilds the table, as 'Show' output conventionally is.
--
-- The previous rendering used @TBA.toList@, which does not accept a list.
-- 'showsPrec' is implemented, rather than only 'show', so the output is
-- parenthesised when it appears as an argument of another constructor.
instance Show SuffixTable where
  showsPrec prec (SuffixTable table) =
    showParen (prec > 10) $
      showString "SuffixTable (TBA.fromList "
        . shows (TBA.toList table)
        . showString ")"

-- | Look up the byte shift for a mismatch at the given pattern code point index.
suffixLookup :: SuffixTable -> Int -> CodeUnitIndex
{-# INLINE suffixLookup #-}
suffixLookup (SuffixTable table) patternIx = indexTable table patternIx

-- | Build the suffix table for a pattern (the Boyer-Moore "good suffix" rule).
--
-- Entry @p@ is the number of bytes the alignment may move forward when pattern
-- code point @p@ mismatches after every code point to its right has matched.
-- Byte counts come from 'minimumSkipForCodePoint', so they are low estimates
-- that account for case variations in the haystack.
--
-- An empty pattern produces an empty table.
buildSuffixTable :: TypedByteArray CodePoint -> SuffixTable
buildSuffixTable pattern_ = runST $ do
  let
    patLen = TBA.length pattern_
    wholePatternSkip = minimumSkipForVector pattern_

  table <- TBA.newTypedByteArray patLen

  let
    -- Case 1: For each position of the pattern we record the shift that would align the pattern so
    -- that it starts at the longest suffix that is at the same time a prefix, if a mismatch would
    -- happen at that position.
    --
    -- Suppose the length of the pattern is n, a mismatch occurs at position i in the pattern and j
    -- in the haystack, then we know that pattern[i+1..n] == haystack[j+1..j+n-i]. That is, we know
    -- that the part of the haystack that we already matched is a suffix of the pattern.
    -- If the pattern happens to have a prefix that is equal to or a shorter suffix of that matched
    -- suffix, we can shift the pattern to the right so that the pattern starts at the longest
    -- suffix that we have seen that coincides with a prefix of the pattern.
    --
    -- Consider the pattern `ababa`. Then we get
    --
    -- p:              0  1  2  3  4
    -- Pattern:        a  b  a  b  a
    -- lastSkipBytes:              5   not touched by init1
    -- lastSkipBytes:           4  5   "a" == "a" so if we get a mismatch here we can skip
    --                                            everything but the length of "a"
    -- lastSkipBytes:        4  4  5   "ab" /= "ba" so keep skip value
    -- lastSkipBytes:     2  4  4  5   "aba" == "aba"
    -- lastSkipBytes:  2  2  4  4  5   "abab" /= "baba"
    init1 lastSkipBytes p
      | p >= 0 = do
          let skipBytes = case suffixIsPrefix pattern_ (p + 1) of
                Nothing -> lastSkipBytes
                -- Skip the whole pattern _except_ the bytes for the suffix(==prefix)
                Just nonSkippableBytes -> wholePatternSkip - nonSkippableBytes
          TBA.writeTypedByteArray table p skipBytes
          init1 skipBytes (p - 1)
      | otherwise = pure ()

    -- Case 2: We also have to account for the fact that the matching suffix of the pattern might
    -- occur again somewhere within the pattern. In that case, we may not shift as far as if it was
    -- a prefix. That is why the `init2` loop is run after `init1`, potentially overwriting some
    -- entries with smaller shifts.
    init2 p skipBytes
      | p < patLen - 1 = do
          -- If we find a suffix that ends at p, we can skip everything _after_ p.
          let skipBytes' = skipBytes - minimumSkipForCodePoint (TBA.unsafeIndex pattern_ p)
          case substringIsSuffix pattern_ p of
            Nothing -> pure ()
            Just suffixLen -> TBA.writeTypedByteArray table (patLen - 1 - suffixLen) skipBytes'
          init2 (p + 1) skipBytes'
      | otherwise = pure ()

  init1 (wholePatternSkip - 1) (patLen - 1)
  init2 0 wholePatternSkip

  -- A mismatch on the last pattern code point leaves no matched suffix, so the
  -- suffix rule only guarantees progress by one byte.
  -- An empty pattern has an empty table, where index @patLen - 1 == -1@ would
  -- be an out-of-bounds write, so the write is skipped in that case.
  if patLen > 0
    then TBA.writeTypedByteArray table (patLen - 1) (CodeUnitIndex 1)
    else pure ()

  SuffixTable <$> TBA.unsafeFreezeTypedByteArray table

-- | Checks whether the suffix of the pattern that starts at code point @pos@
-- is also a prefix of the pattern.
--
-- If it is, returns the minimum number of bytes that suffix occupies (summing
-- 'minimumSkipForCodePoint'); otherwise returns 'Nothing'. An empty suffix
-- (@pos@ equal to the pattern length) gives @Just 0@.
--
-- For example, @suffixIsPrefix \"aabbaa\" 4 == Just 2@.
suffixIsPrefix :: TypedByteArray CodePoint -> Int -> Maybe CodeUnitIndex
suffixIsPrefix pattern_ pos = compareFrom 0 (CodeUnitIndex 0)
  where
    remaining = TBA.length pattern_ - pos

    compareFrom !i !bytesSoFar
      | i >= remaining = Just bytesSoFar
      | prefixChar /= TBA.unsafeIndex pattern_ (pos + i) = Nothing
      | otherwise = compareFrom (i + 1) (bytesSoFar + minimumSkipForCodePoint prefixChar)
      where
        prefixChar = TBA.unsafeIndex pattern_ i

-- | Length, in code points (not bytes), of the longest suffix of the pattern
-- that also occurs as a substring ending at position @pos@.
--
-- Returns 'Nothing' when not even the last code point matches. It also returns
-- 'Nothing' when the comparison runs through the start of the pattern, because
-- then the suffix is also a prefix, which 'suffixIsPrefix' already covers.
--
-- For example, @substringIsSuffix \"abaacbbaac\" 4 == Just 4@: the substring
-- \"baac\" ends at position 4 and is the longest suffix that does so, with a
-- length of 4 characters.
--
-- For a string like "abaacaabcbaac", when we detect at pos=4 that baac==baac,
-- it means that if we get a mismatch before the "baac" suffix, we can skip the
-- "aabcbaac" characters _after_ the "baac" substring. So we can put
-- (minimumSkipForVector "aabcbaac") at that point in the suffix table.
--
--   substringIsSuffix (TBA.fromList "ababa") 0 == Nothing  -- a == a, but not a proper substring
--   substringIsSuffix (TBA.fromList "ababa") 1 == Nothing  -- b /= a
--   substringIsSuffix (TBA.fromList "ababa") 2 == Nothing  -- aba == aba, but not a proper substring
--   substringIsSuffix (TBA.fromList "ababa") 3 == Nothing  -- b /= a
--   substringIsSuffix (TBA.fromList "ababa") 4 == Nothing  -- ababa == ababa, but not a proper substring
--   substringIsSuffix (TBA.fromList "baba") 0 == Nothing  -- b /= a
--   substringIsSuffix (TBA.fromList "baba") 1 == Nothing  -- ba == ba, but not a proper substring
--   substringIsSuffix (TBA.fromList "abaacaabcbaac") 4 == Just 4  -- baac == baac
--   substringIsSuffix (TBA.fromList "abaacaabcbaac") 8 == Just 1  -- c == c
--
substringIsSuffix :: TypedByteArray CodePoint -> Int -> Maybe Int
substringIsSuffix pattern_ pos = countMatching 0
  where
    lastIx = TBA.length pattern_ - 1

    -- Do the code points k places before @pos@ and k places before the end agree?
    sameAt k = TBA.unsafeIndex pattern_ (pos - k) == TBA.unsafeIndex pattern_ (lastIx - k)

    countMatching k
      | k > pos = Nothing  -- prefix==suffix, so already covered by suffixIsPrefix
      | sameAt k = countMatching (k + 1)
      | k == 0 = Nothing  -- Nothing matched
      | otherwise = Just k


-- | The bad char table tells us how many bytes we may skip ahead when
-- encountering a certain character in the input string. For example, if a
-- character does not occur in the pattern at all, we can skip ahead until
-- after that character.
data BadCharLookup = BadCharLookup
  { badCharLookupTable :: {-# UNPACK #-} !(TypedByteArray CodeUnitIndex)
    -- ^ Direct-indexed shifts for code points below 'badCharTableSize'.
  , badCharLookupMap :: !(HashMap.HashMap CodePoint CodeUnitIndex)
    -- ^ Shifts for the remaining code points that have an entry.
  , badCharLookupDefault :: !CodeUnitIndex
    -- ^ Shift for code points at or above 'badCharTableSize' that are not in the map.
  }
  deriving stock (Generic, Show)
  deriving anyclass (NFData)

-- | Number of entries in the fixed-size lookup-table of the bad char table.
--
-- Code points whose 'fromEnum' is below this value (U+0000 to U+00FF) are looked up in a dense
-- array; all other code points go through a hash map.
badCharTableSize :: Int
{-# INLINE badCharTableSize #-}
badCharTableSize = 256

-- | Lookup an entry in the bad char table.
--
-- Returns the number of bytes we may skip ahead when the given code point is encountered in the
-- haystack. Code points below 'badCharTableSize' are read from the dense table. All other code
-- points are looked up in the map, and those absent from it get the default skip.
badCharLookup :: BadCharLookup -> CodePoint -> CodeUnitIndex
{-# INLINE badCharLookup #-}
badCharLookup (BadCharLookup bclTable bclMap bclDefault) char
  -- 'fromEnum' of a code point is never negative, so this guard alone keeps the index within
  -- the table, which has 'badCharTableSize' entries.
  | intChar < badCharTableSize = indexTable bclTable intChar
  | otherwise = HashMap.lookupDefault bclDefault char bclMap
  where
    intChar :: Int
    intChar = fromEnum char



-- | Build the bad character lookup for a pattern.
--
-- Every code point that occurs in the pattern, except at its last position, gets a skip distance.
-- That distance is the default skip minus the 'minimumSkipForCodePoint' of every pattern code point
-- up to and including its rightmost such occurrence. Later occurrences overwrite earlier ones.
-- NOTE(review): assuming 'minimumSkipForVector' is the sum of 'minimumSkipForCodePoint' over the
-- pattern, this equals the minimum number of bytes that follow that occurrence. Confirm.
buildBadCharLookup :: TypedByteArray CodePoint -> BadCharLookup
buildBadCharLookup pattern_ = runST $ do

  let
    -- Skip distance for code points that do not occur in the pattern (or only at its end).
    defaultSkip :: CodeUnitIndex
    defaultSkip = minimumSkipForVector pattern_

  -- Initialize the table with the maximum skip distance, 'defaultSkip', computed over the whole
  -- pattern. This applies to all characters that are not part of the pattern.
  table <- TBA.replicate badCharTableSize defaultSkip

  let
    -- Fill the bad character table based on the rightmost occurrence of a character in the pattern.
    -- Note that there is also a variant of Boyer-Moore that records all positions (see Wikipedia),
    -- but that requires even more storage space.
    -- Also note that we exclude the last character of the pattern when building the table.
    -- This is because
    --
    -- 1. If the last character does not occur anywhere else in the pattern and we encounter it
    --    during a mismatch, we can advance the pattern to just after that character:
    --
    --    Haystack: aaadcdabcdbb
    --    Pattern:    abcd
    --
    --    In the above example, we would match `d` and `c`, but then fail because `d` != `b`.
    --    Since `d` only occurs at the very last position of the pattern, we can shift to
    --
    --    Haystack: aaadcdabcdbb
    --    Pattern:      abcd
    --
    -- 2. If it does occur anywhere else in the pattern, we can only shift as far as it's necessary
    --    to align it with the haystack:
    --
    --    Haystack: aaadddabcdbb
    --    Pattern:    adcd
    --
    --    We match `d`, and then there is a mismatch `d` != `c`, which allows us to shift only up to:
    --
    --    Haystack: aaadddabcdbb
    --    Pattern:     adcd
    --
    -- 'skipBytes' is the running skip distance. It shrinks by each code point's minimum skip as we
    -- walk the pattern left to right. Code points below 'badCharTableSize' are written into the
    -- mutable 'table'. All others are accumulated in 'badCharMap'.
    fillTable !badCharMap !skipBytes = \case
      [] -> pure badCharMap
      [_] -> pure badCharMap  -- The last pattern character doesn't count.
      (!patChar : !patChars) ->
        let
          skipBytes' = skipBytes - minimumSkipForCodePoint patChar
          patCode = fromEnum patChar
        in
          if patCode < badCharTableSize
            then do
              TBA.writeTypedByteArray table patCode skipBytes'
              fillTable badCharMap skipBytes' patChars
            else
              fillTable (HashMap.insert patChar skipBytes' badCharMap) skipBytes' patChars

  badCharMap <- fillTable HashMap.empty defaultSkip (TBA.toList pattern_)

  -- 'table' is not written to after this point, so freezing it in place is safe.
  tableFrozen <- TBA.unsafeFreezeTypedByteArray table

  pure BadCharLookup
    { badCharLookupTable = tableFrozen
    , badCharLookupMap = badCharMap
    , badCharLookupDefault = defaultSkip
    }


-- Helper functions for easily toggling the safety of this module

-- | Read from a lookup table at the specified index.
--
-- All table reads in this module go through this helper, so switching between checked and
-- unchecked indexing only requires changing this one definition. It currently delegates to
-- 'TBA.unsafeIndex'. NOTE(review): presumably that performs no bounds check, so callers must
-- keep the index within the table. Confirm against "Data.TypedByteArray".
indexTable :: Prim a => TypedByteArray a -> Int -> a
{-# INLINE indexTable #-}
indexTable = TBA.unsafeIndex