-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2019 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An efficient implementation of the Boyer-Moore string search algorithm.
-- http://www-igm.univ-mlv.fr/~lecroq/string/node14.html#SECTION00140
-- https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string-search_algorithm
--
-- This module contains an almost 1:1 translation from the C example code in the
-- Wikipedia article.
--
-- The algorithm here can be potentially improved by including the Galil rule
-- (https://en.wikipedia.org/wiki/Boyer%E2%80%93Moore_string-search_algorithm#The_Galil_rule)
module Data.Text.BoyerMoore.Automaton
    ( Automaton
    , CaseSensitivity (..)
    , CodeUnitIndex (..)
    , Next (..)
    , buildAutomaton
    , patternLength
    , patternText
    , runText
    ) where

import Prelude hiding (length)

import Control.DeepSeq (NFData)
import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Hashable (Hashable (..), Hashed, hashed, unhashed)
import GHC.Generics (Generic)

#if defined(HAS_AESON)
import qualified Data.Aeson as AE
#endif

import Data.Text.CaseSensitivity (CaseSensitivity (..))
import Data.Text.Utf8 (CodeUnit, CodeUnitIndex (..), Text)
import Data.TypedByteArray (Prim, TypedByteArray)

import qualified Data.Text.Utf8 as Utf8
import qualified Data.TypedByteArray as TBA

-- | What a match callback wants the search to do next. Both constructors carry
-- the (strict) accumulator value.
data Next a
  = Done !a
    -- ^ Stop searching and return this value as the final result.
  | Step !a
    -- ^ Keep searching, continuing with this value as the accumulator.

-- | A Boyer-Moore automaton: the search pattern plus precomputed lookup tables
-- that allow skipping through the haystack. Thanks to these skips, matching can
-- be sub-linear in some cases, because not every input code unit has to be
-- inspected.
--
-- NOTE: Unlike the AcMachine, a Boyer-Moore automaton only reports
-- non-overlapping matches, so it is not a 100% drop-in replacement for
-- Aho-Corasick.
--
-- Reporting overlapping matches would degrade performance to /O(nm)/ in
-- pathological cases such as finding @aaaa@ in @aaaaa....aaaaaa@, because every
-- match would require scanning back over all /m/ code units of the pattern.
data Automaton = Automaton
  { automatonPattern :: Hashed Text
    -- ^ The pattern, with its hash cached; equality and hashing look only at this.
  , automatonSuffixTable :: SuffixTable
    -- ^ Shift table indexed by the pattern position where a mismatch occurred.
  , automatonBadCharTable :: BadCharTable
    -- ^ Shift table indexed by the mismatching haystack code unit.
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Only the pattern is hashed; the lookup tables are derived from it.
instance Hashable Automaton where
  hashWithSalt salt = hashWithSalt salt . automatonPattern

-- | Two automata are equal exactly when their patterns are equal.
instance Eq Automaton where
  lhs == rhs = automatonPattern lhs == automatonPattern rhs

#if defined(HAS_AESON)
instance AE.FromJSON Automaton where
  -- Decode the pattern text and build a fresh automaton from it.
  parseJSON = fmap buildAutomaton . AE.parseJSON

instance AE.ToJSON Automaton where
  -- Only the pattern text is serialized; the tables are rebuilt when decoding.
  toJSON automaton = AE.toJSON (unhashed (automatonPattern automaton))
#endif

-- | Build an automaton for the given pattern, precomputing both skip tables.
buildAutomaton :: Text -> Automaton
buildAutomaton needle = Automaton
  { automatonPattern = hashed needle
  , automatonSuffixTable = buildSuffixTable needle
  , automatonBadCharTable = buildBadCharTable needle
  }

-- | Find all non-overlapping matches of the automaton's pattern in the text.
--
-- For every match, the callback receives the current accumulator and the index
-- of the /first/ code unit of the match. It answers 'Step' to keep searching or
-- 'Done' to stop immediately. The final accumulator is returned. An empty
-- pattern never matches, so the seed is returned unchanged.
--
-- NOTE: This is unlike Aho-Corasick, which reports the index of the code unit
-- right after a match.
--
-- NOTE: In the UTF-16 version of this module, there is a function
-- 'Data.Text.BoyerMoore.Automaton.runLower' which does lower-case matching. This
-- function does not exist for the UTF-8 version since it is very tricky to skip
-- code points going backwards without preprocessing the whole input first.
--
-- NOTE: To get full advantage of inlining this function, you probably want to
-- compile the calling module with -fllvm and the same optimization flags as
-- this module.
runText
  :: forall a. a
  -> (a -> CodeUnitIndex -> Next a)
  -> Automaton
  -> Text
  -> a
{-# INLINE runText #-}
runText seed f automaton text
  | patLen == 0 = seed
  | otherwise = alignAt seed (patLen - 1)
  where
    -- Called "needle" rather than "pattern", since the latter may be a keyword.
    needle :: Text
    needle = unhashed (automatonPattern automaton)

    suffixTable :: SuffixTable
    suffixTable = automatonSuffixTable automaton

    badCharTable :: BadCharTable
    badCharTable = automatonBadCharTable automaton

    patLen :: CodeUnitIndex
    patLen = Utf8.lengthUtf8 needle

    textLen :: CodeUnitIndex
    textLen = Utf8.lengthUtf8 text

    haystackAt :: CodeUnitIndex -> CodeUnit
    haystackAt = Utf8.unsafeIndexCodeUnit text

    needleAt :: CodeUnitIndex -> CodeUnit
    needleAt = Utf8.unsafeIndexCodeUnit needle

    -- Try the window whose last code unit sits at @windowEnd@. Once that
    -- position is past the end of the haystack, no further match is possible.
    {-# INLINE alignAt #-}
    alignAt :: a -> CodeUnitIndex -> a
    alignAt acc windowEnd
      | windowEnd >= textLen = acc
      | otherwise = compareBack acc windowEnd (patLen - 1)

    -- Compare needle and haystack from right to left. @hIx@ and @nIx@ always
    -- move together.
    compareBack :: a -> CodeUnitIndex -> CodeUnitIndex -> a
    compareBack acc hIx nIx
      -- Every needle code unit matched, so the match starts right after hIx.
      | nIx < 0 =
          case f acc (hIx + 1) of
            Done final -> final
            -- The match covers hIx + 1 .. hIx + patLen, so the first window
            -- that does not overlap it ends at hIx + 2 * patLen.
            Step acc' -> alignAt acc' (hIx + 2 * patLen)
      -- The code units agree: step one position to the left.
      | haystackAt hIx == needleAt nIx =
          compareBack acc (hIx - 1) (nIx - 1)
      -- Mismatch: move right by whichever table allows the larger jump.
      -- The bad char table alone could move the window backwards;
      -- the suffix table guarantees progress.
      | otherwise =
          alignAt acc (hIx + max badCharSkip suffixSkip)
      where
        -- Aligns the mismatching haystack code unit with its rightmost
        -- occurrence in the needle.
        badCharSkip :: CodeUnitIndex
        badCharSkip = badCharLookup badCharTable (haystackAt hIx)

        suffixSkip :: CodeUnitIndex
        suffixSkip = suffixLookup suffixTable nIx

-- | Length of the pattern, measured in UTF-8 code units (bytes).
patternLength :: Automaton -> CodeUnitIndex
patternLength automaton = Utf8.lengthUtf8 (patternText automaton)

-- | The pattern the automaton was built from.
patternText :: Automaton -> Text
patternText = unhashed . automatonPattern

-- | For every position of the pattern, the suffix table records how far the
-- search may jump ahead when a mismatch happens at that position.
newtype SuffixTable = SuffixTable (TypedByteArray Int)
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Shift recorded in the suffix table for the given pattern position.
suffixLookup :: SuffixTable -> CodeUnitIndex -> CodeUnitIndex
{-# INLINE suffixLookup #-}
suffixLookup (SuffixTable entries) position =
  CodeUnitIndex (indexTable entries (codeUnitIndex position))

-- | Build the suffix table for a pattern. Entry @i@ is the shift used when a
-- mismatch happens at pattern position @i@; 'runText' takes the larger of it
-- and the bad char shift.
buildSuffixTable :: Text -> SuffixTable
buildSuffixTable needle = runST $ do
  let
    patLen :: CodeUnitIndex
    patLen = Utf8.lengthUtf8 needle

    lastIx :: CodeUnitIndex
    lastIx = patLen - 1

    unitAt :: CodeUnitIndex -> CodeUnit
    unitAt = Utf8.unsafeIndexCodeUnit needle

  table <- TBA.newTypedByteArray (codeUnitIndex patLen)

  let
    -- Store the shift for pattern position @ix@.
    setShift ix shift =
      TBA.writeTypedByteArray table (codeUnitIndex ix) (codeUnitIndex shift)

    -- Case 1: For each position of the pattern, record the shift that would
    -- align the pattern so that it starts at the longest suffix that is also a
    -- prefix, should a mismatch happen at that position.
    --
    -- Suppose the pattern has length n, and a mismatch occurs at position i in
    -- the pattern and j in the haystack. Then pattern[i+1..n] ==
    -- haystack[j+1..j+n-i]: the part of the haystack matched so far is a suffix
    -- of the pattern. If the pattern has a prefix equal to that matched suffix,
    -- or to a shorter suffix of it, we can shift the pattern right so that it
    -- starts at the longest such suffix.
    --
    -- Consider the pattern `ababa`. Then we get
    --
    -- p:                0  1  2  3  4
    -- Pattern:          a  b  a  b  a
    -- lastPrefixIndex:  2  2  4  4  5
    -- table:            6  5  6  5  5
    prefixPass lastPrefixIndex p
      | p < 0 = pure ()
      | otherwise = do
          let prefixIndex
                | isPrefix needle (p + 1) = p + 1
                | otherwise = lastPrefixIndex
          setShift p (prefixIndex + lastIx - p)
          prefixPass prefixIndex (p - 1)

    -- Case 2: The matched suffix might also occur again somewhere inside the
    -- pattern. In that case we may not shift as far as if it were a prefix.
    -- This pass therefore runs after the prefix pass and may overwrite
    -- entries with smaller shifts.
    suffixPass p
      | p >= lastIx = pure ()
      | otherwise = do
          let suffixLen = suffixLength needle p
              target = lastIx - suffixLen
          when (unitAt (p - suffixLen) /= unitAt target) $
            setShift target (lastIx - p + suffixLen)
          suffixPass (p + 1)

  prefixPass lastIx lastIx
  suffixPass 0

  SuffixTable <$> TBA.unsafeFreezeTypedByteArray table


-- | The bad char table tells us how far we may skip ahead when a certain code
-- unit is encountered in the input string. For example, if a code unit does
-- not occur in the pattern at all, we can skip ahead until after it.
data BadCharTable = BadCharTable
  { badCharTableEntries :: {-# UNPACK #-} !(TypedByteArray Int)
    -- ^ Skip distance per code unit value ('badcharTableSize' entries). The
    -- element type should be 'CodeUnitIndex', but there is no unboxed array for
    -- that type, and defining it would be a lot of boilerplate.
  , badCharTablePatternLen :: !CodeUnitIndex
    -- ^ Length of the pattern in code units. Strict, like the entries field,
    -- so that the record never holds an unevaluated thunk.
  }
  deriving stock (Show, Generic)
  deriving anyclass (NFData)

-- | Size of the fixed lookup table in the bad char table: one entry for each
-- possible byte value of a UTF-8 code unit.
badcharTableSize :: Int
{-# INLINE badcharTableSize #-}
badcharTableSize = 0x100

-- | Skip distance recorded in the bad char table for a haystack code unit.
badCharLookup :: BadCharTable -> CodeUnit -> CodeUnitIndex
{-# INLINE badCharLookup #-}
badCharLookup table unit =
  CodeUnitIndex (indexTable (badCharTableEntries table) (fromIntegral unit))

-- | True if the suffix of @needle@ that starts at @pos@ is also a prefix of
-- @needle@. For example, @isPrefix \"aabbaa\" 4 == True@.
isPrefix :: Text -> CodeUnitIndex -> Bool
isPrefix needle pos = agreesFrom 0
  where
    suffixLen :: CodeUnitIndex
    suffixLen = Utf8.lengthUtf8 needle - pos

    -- (||) and (&&) hand back their second argument unchanged, so the
    -- recursive call remains a tail call.
    agreesFrom :: CodeUnitIndex -> Bool
    agreesFrom i =
      i >= suffixLen
        || ( Utf8.unsafeIndexCodeUnit needle i == Utf8.unsafeIndexCodeUnit needle (pos + i)
               && agreesFrom (i + 1)
           )

-- | Length of the longest substring of the pattern that ends at position @pos@
-- and is also a suffix of the pattern. The count never exceeds @pos@.
-- For example, @suffixLength \"abaacbbaac\" 4 == 4@: the substring \"baac\"
-- ends at position 4 and is also the pattern's suffix of length 4.
suffixLength :: Text -> CodeUnitIndex -> CodeUnitIndex
suffixLength needle pos = extend 0
  where
    lastIx :: CodeUnitIndex
    lastIx = Utf8.lengthUtf8 needle - 1

    extend :: CodeUnitIndex -> CodeUnitIndex
    extend len
      | len < pos
      , Utf8.unsafeIndexCodeUnit needle (pos - len) == Utf8.unsafeIndexCodeUnit needle (lastIx - len)
      = extend (len + 1)
      | otherwise = len

-- | Build the bad char table for a pattern.
buildBadCharTable :: Text -> BadCharTable
buildBadCharTable needle = runST $ do
  let
    patLen :: CodeUnitIndex
    patLen = Utf8.lengthUtf8 needle

  -- Every entry starts at the maximum skip distance, the length of the
  -- pattern. Code units that do not occur in the pattern keep this value.
  entries <- TBA.replicate badcharTableSize (codeUnitIndex patLen)

  let
    -- Fill the table from the rightmost occurrence of each code unit in the
    -- pattern. (Another Boyer-Moore variant records all positions, see
    -- Wikipedia, but that requires even more storage space.)
    -- The last code unit of the pattern is deliberately left out, because:
    --
    -- 1. If the last code unit does not occur anywhere else in the pattern and
    --    we encounter it during a mismatch, we can advance the pattern to just
    --    after it:
    --
    --    Haystack: aaadcdabcdbb
    --    Pattern:    abcd
    --
    --    Here we match `d` and `c`, but then fail because `d` != `b`. Since `d`
    --    only occurs at the very last position of the pattern, we can shift to
    --
    --    Haystack: aaadcdabcdbb
    --    Pattern:      abcd
    --
    -- 2. If it does occur elsewhere in the pattern, we can only shift as far as
    --    necessary to align that occurrence with the haystack:
    --
    --    Haystack: aaadddabcdbb
    --    Pattern:    adcd
    --
    --    We match `d`, and then there is a mismatch `d` != `c`, which allows us
    --    to shift only up to:
    --
    --    Haystack: aaadddabcdbb
    --    Pattern:     adcd
    recordFrom !i
      | i >= patLen - 1 = pure ()
      | otherwise = do
          let unit = Utf8.unsafeIndexCodeUnit needle i
          TBA.writeTypedByteArray entries (fromIntegral unit) (codeUnitIndex (patLen - 1 - i))
          recordFrom (i + 1)

  recordFrom 0

  frozen <- TBA.unsafeFreezeTypedByteArray entries
  pure BadCharTable
    { badCharTableEntries = frozen
    , badCharTablePatternLen = patLen
    }


-- Helper functions for easily toggling the safety of this module

-- | Read a lookup-table entry at the given index. All table reads in this
-- module go through here. Replacing 'TBA.unsafeIndex' with a checked variant
-- toggles the safety of every lookup at once.
indexTable :: Prim a => TypedByteArray a -> Int -> a
{-# INLINE indexTable #-}
indexTable entries ix = TBA.unsafeIndex entries ix