-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2019 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An efficient implementation of the Aho-Corasick string matching algorithm.
-- See http://web.stanford.edu/class/archive/cs/cs166/cs166.1166/lectures/02/Small02.pdf
-- for a good explanation of the algorithm.
--
-- The memory layout of the automaton, and the function that steps it, were
-- optimized to the point where string matching compiles roughly to a loop over
-- the code units in the input text, that keeps track of the current state.
-- Lookup of the next state is either just an array index (for the root state),
-- or a linear scan through a small array (for non-root states). The pointer
-- chases that are common for traversing Haskell data structures have been
-- eliminated.
--
-- The construction of the automaton has not been optimized that much, because
-- construction time is usually negligible in comparison to matching time.
-- Therefore construction is a two-step process, where first we build the
-- automaton as int maps, which are convenient for incremental construction.
-- Afterwards we pack the automaton into unboxed vectors.
module Data.Text.AhoCorasick.Automaton
    ( AcMachine (..)
    , CaseSensitivity (..)
    , CodeUnitIndex (..)
    , Match (..)
    , Next (..)
    , build
    , debugBuildDot
    , runLower
    , runText
    ) where

import Prelude hiding (length)

import Control.DeepSeq (NFData)
import Data.Bits (shiftL, shiftR, (.&.), (.|.))
import Data.Foldable (foldl')
import Data.IntMap.Strict (IntMap)
import Data.Text.Internal (Text (..))
import Data.Word (Word64)
import GHC.Generics (Generic)

import qualified Data.IntMap.Strict as IntMap
import qualified Data.List as List
import qualified Data.Vector as Vector

import Data.Text.CaseSensitivity (CaseSensitivity (..))
import Data.Text.Utf16 (CodeUnit, CodeUnitIndex (..), indexTextArray, lowerCodeUnit)
import Data.TypedByteArray (Prim, TypedByteArray)

import qualified Data.TypedByteArray as TBA

-- | A numbered state in the Aho-Corasick automaton. State 0 is the root
-- state; further states are numbered consecutively as they are allocated
-- during construction of the trie.
type State = Int

-- | A transition is a pair of (code unit, next state). The code unit is 16 bits,
-- and the state index is 32 bits. We pack these together as a manually unlifted
-- tuple, because an unboxed Vector of tuples is a tuple of vectors, but we want
-- the elements of the tuple to be adjacent in memory. (The Word64 still needs
-- to be unpacked in the places where it is used.) The input occupies the least
-- significant 32 bits: the code unit itself is stored in bits 0 through 15,
-- and bit 16 is the wildcard flag, so a wildcard (the "failure" transition)
-- has the value 2^16 in its lower half. Bit 17 through 31 (starting from zero,
-- both bounds inclusive) are always 0. The goto state is stored in the most
-- significant 32 bits.
--
--  Bit 63 (most significant)                 Bit 0 (least significant)
--  |                                                                 |
--  v                                                                 v
-- |<--       goto state         -->|<-- zeros   -->| |<--   input  -->|
-- |SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSS|000000000000000|W|IIIIIIIIIIIIIIII|
--                                                   |
--                                                   Wildcard bit (bit 16)
--
type Transition = Word64

-- | A single match reported while running the automaton: where the match
-- ends in the haystack, and the payload of the needle that matched.
data Match v = Match
  { matchPos   :: {-# UNPACK #-} !CodeUnitIndex
  -- ^ The code unit index past the last code unit of the match. Note that this
  -- is not a code *point* (Haskell `Char`) index; with 16-bit code units a
  -- code point might be encoded as two code units.
  , matchValue :: v
  -- ^ The payload associated with the matched needle.
  } deriving (Show, Eq)

-- | An Aho-Corasick automaton, packed into flat arrays. All arrays that are
-- indexed by state use the state number (see `State`) as the index.
data AcMachine v = AcMachine
  { machineValues :: !(Vector.Vector [v])
  -- ^ For every state, the values associated with its needles. If the state is
  -- not a match state, the list is empty.
  , machineTransitions :: !(TypedByteArray Transition)
  -- ^ A packed vector of transitions. For every state, there is a slice of this
  -- vector that starts at the offset given by `machineOffsets`, and ends at the
  -- first wildcard transition.
  , machineOffsets :: !(TypedByteArray Int)
  -- ^ For every state, the index into `machineTransitions` where the transition
  -- list for that state starts.
  , machineRootAsciiTransitions :: !(TypedByteArray Transition)
  -- ^ A lookup table for transitions from the root state, an optimization to
  -- avoid having to walk all transitions, at the cost of using a bit of
  -- additional memory.
  } deriving (Generic)

-- No methods are given, so the class defaults are used. `AcMachine` derives
-- `Generic` above; presumably the default `rnf` is the Generic-based one.
instance NFData v => NFData (AcMachine v)

-- | The wildcard value is 2^16, one more than the maximal 16-bit code unit, so
-- it can never be confused with a real code unit. It is polymorphic in its
-- numeric type, so it can be used directly as a `Transition` bit mask.
wildcard :: Integral a => a
wildcard = 0x10000

-- | Extract the code unit from a transition, which is stored in its 16 least
-- significant bits. The special wildcard transition will return 0.
transitionCodeUnit :: Transition -> CodeUnit
transitionCodeUnit t = fromIntegral (t .&. 0xffff)

-- | Extract the goto state from a transition, which is stored in its 32 most
-- significant bits.
transitionState :: Transition -> State
transitionState t = fromIntegral (t `shiftR` 32)

-- | Test if the transition is not for a specific code unit, but the wildcard
-- transition to take if nothing else matches. This checks whether the wildcard
-- bit (bit 16) is set.
transitionIsWildcard :: Transition -> Bool
transitionIsWildcard t = (t .&. wildcard) == wildcard

-- | Pack a transition on the given input code unit to the given goto state.
-- The state goes into the upper 32 bits, the code unit into the lower bits;
-- the wildcard bit is left unset.
newTransition :: CodeUnit -> State -> Transition
newTransition input state =
  let
    input64 = fromIntegral input :: Word64
    state64 = fromIntegral state :: Word64
  in
    (state64 `shiftL` 32) .|. input64

-- | Pack a wildcard ("failure") transition to the given state. The state goes
-- into the upper 32 bits; of the lower bits, only the wildcard bit is set.
newWildcardTransition :: State -> Transition
newWildcardTransition state =
  let
    state64 = fromIntegral state :: Word64
  in
    (state64 `shiftL` 32) .|. wildcard

-- | Pack transitions for each state into one contiguous array. In order to find
-- the transitions for a specific state, we also produce a vector of start
-- indices. All transition lists are terminated by a wildcard transition, so
-- there is no need to record the length.
--
-- The input holds one transition list per state, in state order. The offsets
-- are the running sum of the list lengths starting at 0, so the offsets array
-- has one more element than there are states; the final element is the total
-- number of packed transitions.
packTransitions :: [[Transition]] -> (TypedByteArray Transition, TypedByteArray Int)
packTransitions transitions =
  let
    packed = TBA.fromList $ concat transitions
    offsets = TBA.fromList $ scanl (+) 0 $ fmap List.length transitions
  in
    (packed, offsets)

-- | Construct an Aho-Corasick automaton for the given needles.
-- Takes a list of code units rather than `Text`, to allow mapping the code
-- units before construction, for example to lowercase individual code points,
-- rather than doing proper case folding (which might change the number of code
-- units).
--
-- Every needle is paired with the payload value that is reported when that
-- needle matches.
build :: [([CodeUnit], v)] -> AcMachine v
build needlesWithValues =
  let
    -- Construct the Aho-Corasick automaton using IntMaps, which are a suitable
    -- representation when building the automaton. We use int maps rather than
    -- hash maps to ensure that the iteration order is the same as that of a
    -- vector.
    (numStates, transitionMap, initialValueMap) = buildTransitionMap needlesWithValues
    fallbackMap = buildFallbackMap transitionMap
    valueMap = buildValueMap transitionMap fallbackMap initialValueMap

    -- Convert the map of transitions, and the map of fallback states, into a
    -- list of transition lists, where every transition list is terminated by
    -- a wildcard transition to the fallback state. Because we prepend to a
    -- list that starts out as just the wildcard transition, the wildcard ends
    -- up last.
    prependTransition :: [Transition] -> Int -> State -> [Transition]
    prependTransition ts input state = newTransition (fromIntegral input) state : ts

    makeTransitions :: State -> IntMap State -> [Transition]
    makeTransitions fallback ts =
      IntMap.foldlWithKey' prependTransition [newWildcardTransition fallback] ts

    -- Both maps are iterated in ascending state order, so the fallback state
    -- and the outgoing transitions that get zipped belong to the same state.
    transitionsList :: [[Transition]]
    transitionsList =
      zipWith makeTransitions (IntMap.elems fallbackMap) (IntMap.elems transitionMap)

    -- Pack the transition lists into one contiguous array, and build the lookup
    -- table for the transitions from the root state.
    (transitions, offsets) = packTransitions transitionsList
    rootTransitions = buildAsciiTransitionLookupTable $ transitionMap IntMap.! 0
    values = Vector.generate numStates (valueMap IntMap.!)
  in
    AcMachine values transitions offsets rootTransitions

-- | Build the automaton, and format it as Graphviz Dot, for visual debugging.
--
-- Every needle gets its index in the input list as payload. Goto transitions
-- are drawn as edges labelled with the (numeric) input code unit, fallback
-- transitions as dashed edges, and states that report at least one match as
-- double circles.
debugBuildDot :: [[CodeUnit]] -> String
debugBuildDot needles =
  let
    (_numStates, transitionMap, initialValueMap) =
      buildTransitionMap $ zip needles ([0..] :: [Int])
    fallbackMap = buildFallbackMap transitionMap
    valueMap = buildValueMap transitionMap fallbackMap initialValueMap

    -- Format a single edge, with `extra` as its Dot attribute list.
    dotEdge :: String -> State -> State -> String
    dotEdge extra state nextState =
      "  " ++ (show state) ++ " -> " ++ (show nextState) ++ " [" ++ extra ++ "];"

    dotFallbackEdge :: [String] -> State -> State -> [String]
    dotFallbackEdge edges state nextState =
      (dotEdge "style = dashed" state nextState) : edges

    dotTransitionEdge :: State -> [String] -> Int -> State -> [String]
    dotTransitionEdge state edges input nextState =
      (dotEdge ("label = \"" ++ show input ++ "\"") state nextState) : edges

    -- Prepend the edges for all outgoing goto transitions of `state`.
    prependTransitionEdges :: [String] -> State -> [String]
    prependTransitionEdges edges state =
      IntMap.foldlWithKey' (dotTransitionEdge state) edges (transitionMap IntMap.! state)

    -- Only states that have at least one value are marked as match states.
    dotMatchState :: [String] -> State -> [Int] -> [String]
    dotMatchState edges _ [] = edges
    dotMatchState edges state _ = ("  " ++ show state ++ " [shape = doublecircle];") : edges

    -- Accumulate the lines in three passes: goto edges, then fallback edges,
    -- then match state markers. Every pass prepends to the result so far.
    dot0 = foldBreadthFirst prependTransitionEdges [] transitionMap
    dot1 = IntMap.foldlWithKey' dotFallbackEdge dot0 fallbackMap
    dot2 = IntMap.foldlWithKey' dotMatchState dot1 valueMap
  in
    -- Set rankdir = "LR" to prefer a left-to-right graph, rather than top to
    -- bottom. I have dual widescreen monitors and I don't use them in portrait
    -- mode. Reverse the instructions because order affects node lay-out, and by
    -- prepending we built up a reversed list.
    unlines $ ["digraph {", "  rankdir = \"LR\";"] ++ (reverse dot2) ++ ["}"]

-- Different int maps that are used during construction of the automaton. The
-- transition map represents the trie of states, the fallback map contains the
-- fallback (or "failure" or "suffix") edge for every state.
--
-- * TransitionMap: for every state, a map from input code unit (as Int) to
--   the state to go to on that input.
-- * FallbackMap: for every state, the state to fall back to.
-- * ValuesMap: for a state, the payload values to report in that state.
type TransitionMap = IntMap (IntMap State)
type FallbackMap = IntMap State
type ValuesMap v = IntMap [v]

-- | Build the trie of the Aho-Corasick state machine for all input needles.
--
-- Returns the number of states in the trie, the goto transitions for every
-- state, and for every state where a needle ends, the payload values of the
-- needles that end there. States are numbered consecutively from 0 (the root),
-- in order of allocation.
buildTransitionMap :: forall v. [([CodeUnit], v)] -> (Int, TransitionMap, ValuesMap v)
buildTransitionMap =
  let
    -- Walk the trie from `state` along the remaining code units of the needle,
    -- threading through the state count and the maps built so far.
    go :: State
      -> (Int, TransitionMap, ValuesMap v)
      -> ([CodeUnit], v)
      -> (Int, TransitionMap, ValuesMap v)

    -- End of the current needle, insert the associated payload value.
    -- If a needle occurs multiple times, then at this point we will merge
    -- their payload values, so the needle is reported twice, possibly with
    -- different payload values.
    go !state (!numStates, transitions, values) ([], v) =
      (numStates, transitions, IntMap.insertWith (++) state [v] values)

    -- Follow the edge for the given input from the current state, creating it
    -- if it does not exist.
    go !state (!numStates, transitions, values) (!input : needleTail, vs) =
      let
        transitionsFromState = transitions IntMap.! state
      in
        case IntMap.lookup (fromIntegral input) transitionsFromState of
          Just nextState ->
            go nextState (numStates, transitions, values) (needleTail, vs)
          Nothing ->
            let
              -- Allocate a new state, and insert a transition to it.
              -- Also insert an empty transition map for it.
              nextState = numStates
              transitionsFromState' =
                IntMap.insert (fromIntegral input) nextState transitionsFromState
              transitions'
                = IntMap.insert state transitionsFromState'
                $ IntMap.insert nextState IntMap.empty
                $ transitions
            in
              go nextState (numStates + 1, transitions', values) (needleTail, vs)

    -- Initially, the root state (state 0) exists, and it has no transitions
    -- to anywhere.
    stateInitial :: State
    stateInitial = 0

    initialTransitions :: TransitionMap
    initialTransitions = IntMap.singleton stateInitial IntMap.empty

    initialValues :: ValuesMap v
    initialValues = IntMap.empty

    -- Every needle is inserted starting from the root.
    insertNeedle = go stateInitial
  in
    -- The state count starts at 1, because the root state already exists.
    foldl' insertNeedle (1, initialTransitions, initialValues)

-- | Size of the ascii transition lookup table: one entry for each of the
-- first 128 code units.
asciiCount :: Integral a => a
asciiCount = 128

-- | Build a lookup table for the first 128 code units, that can be used for
-- O(1) lookup of a transition, rather than doing a linear scan over all
-- transitions. The fallback goes back to the initial state, state 0.
--
-- The argument is the map of outgoing transitions of a single state, keyed by
-- input code unit. Entry @i@ of the table is the transition for code unit @i@
-- if the map has one, and a wildcard transition to state 0 otherwise.
buildAsciiTransitionLookupTable :: IntMap State -> TypedByteArray Transition
buildAsciiTransitionLookupTable transitions = TBA.generate asciiCount $ \i ->
  case IntMap.lookup i transitions of
    Just state -> newTransition (fromIntegral i) state
    Nothing -> newWildcardTransition 0

-- | Traverse the state trie in breadth-first order, starting at the root
-- (state 0), and fold the visited states into an accumulator. The accumulator
-- is forced at every step.
foldBreadthFirst :: (a -> State -> a) -> a -> TransitionMap -> a
foldBreadthFirst f seed transitions = go [0] [] seed
  where
    -- For the traversal, we keep a queue of states to visit. Every iteration we
    -- take one off the front, and all states reachable from there get added to
    -- the back. Rather than using a list for this, we use the functional
    -- amortized queue to avoid O(n²) append. This makes a measurable difference
    -- when the backlog can grow large. In one of our benchmark inputs for
    -- example, we have roughly 160 needles that are 10 characters each (but
    -- with some shared prefixes), and the backlog size grows to 148 during
    -- construction. Construction time goes down from ~0.80 ms to ~0.35 ms by
    -- using the amortized queue.
    -- See also section 3.1.1 of Purely Functional Data Structures by Okasaki
    -- https://www.cs.cmu.edu/~rwh/theses/okasaki.pdf.
    --
    -- The first list is the front of the queue, the second list is the back
    -- of the queue, stored in reverse.
    go [] [] !acc = acc
    go [] revBacklog !acc = go (reverse revBacklog) [] acc
    go (state : backlog) revBacklog !acc =
      let
        -- Note that the backlog never contains duplicates, because we traverse
        -- a trie that only branches out. For every state, there is only one
        -- path from the root that leads to it.
        extra = IntMap.elems $ transitions IntMap.! state
      in
        go backlog (extra ++ revBacklog) (f acc state)

-- | Determine the fallback transition for every state, by traversing the
-- transition trie breadth-first. The breadth-first order ensures that when we
-- compute the fallbacks for the children of a state, the fallback of that
-- state itself is already in the map.
buildFallbackMap :: TransitionMap -> FallbackMap
buildFallbackMap transitions =
  let
    -- Suppose that in state `state`, there is a transition for input `input`
    -- to state `nextState`, and we already know the fallback for `state`. Then
    -- this function returns the fallback state for `nextState`.
    getFallback :: FallbackMap -> State -> Int -> State
    -- All the states after the root state (state 0) fall back to the root state.
    getFallback _ 0 _ = 0
    getFallback fallbacks !state !input =
      let
        fallback = fallbacks IntMap.! state
        transitionsFromFallback = transitions IntMap.! fallback
      in
        case IntMap.lookup input transitionsFromFallback of
          Just st -> st
          -- The fallback state has no transition for this input either, so
          -- retry from the fallback state.
          Nothing -> getFallback fallbacks fallback input

    -- Record the fallback for `nextState`, the state reached from `state` on
    -- input `input`.
    insertFallback :: State -> FallbackMap -> Int -> State -> FallbackMap
    insertFallback !state fallbacks !input !nextState =
      IntMap.insert nextState (getFallback fallbacks state input) fallbacks

    -- Record the fallbacks for all states directly reachable from `state`.
    insertFallbacks :: FallbackMap -> State -> FallbackMap
    insertFallbacks fallbacks !state =
      IntMap.foldlWithKey' (insertFallback state) fallbacks (transitions IntMap.! state)
  in
    -- The root state falls back to itself.
    foldBreadthFirst insertFallbacks (IntMap.singleton 0 0) transitions

-- | Determine which matches to report at every state, by traversing the
-- transition trie breadth-first, and appending all the matches from a fallback
-- state to the matches for the current state.
--
-- The result has an entry for every state visited by the traversal; states
-- that report nothing map to the empty list.
buildValueMap :: forall v. TransitionMap -> FallbackMap -> ValuesMap v -> ValuesMap v
buildValueMap transitions fallbacks valuesInitial =
  let
    insertValues :: ValuesMap v -> State -> ValuesMap v
    insertValues values !state =
      let
        -- Looked up in the map under construction, so these already include
        -- the values that the fallback state got from its own fallback.
        fallbackValues = values IntMap.! (fallbacks IntMap.! state)
        -- The values of the state itself come first, then those of the
        -- fallback state.
        valuesForState = case IntMap.lookup state valuesInitial of
          Just vs -> vs ++ fallbackValues
          Nothing -> fallbackValues
      in
        IntMap.insert state valuesForState values
  in
    -- Seed the map with an empty list for the root, because the root falls
    -- back to itself and is the first state to be visited.
    foldBreadthFirst insertValues (IntMap.singleton 0 []) transitions

-- Define aliases for array indexing so we can turn bounds checks on and off
-- in one place. We ran this code with `Vector.!` (bounds-checked indexing) in
-- production for two months without failing the bounds check, so we have turned
-- the check off for performance now.
--
-- `at` indexes a boxed vector without a bounds check; the caller is
-- responsible for passing an index that is in range.
at :: forall a. Vector.Vector a -> Int -> a
at = Vector.unsafeIndex

-- | Alias for indexing into a `TypedByteArray` via `TBA.unsafeIndex`. Like
-- `at`, this is the single place to switch to a checked variant if needed; the
-- caller is responsible for passing a valid index.
{-# INLINE uAt #-}
uAt :: Prim a => TypedByteArray a -> Int -> a
uAt = TBA.unsafeIndex

-- | Result of handling a match: stepping the automaton can exit early by
-- returning a `Done`, or it can continue with a new accumulator with `Step`.
-- Both constructors hold the accumulator in a strict field.
data Next a
  = Done !a
  -- ^ Stop matching; the wrapped accumulator becomes the final result.
  | Step !a
  -- ^ Keep matching, carrying the wrapped accumulator forward.

-- | Run the automaton, possibly lowercasing the input text on the fly if case
-- insensitivity is desired. See also `lowerCodeUnit` and `runLower`.
--
-- The arguments are, in order: the case sensitivity mode, the initial
-- accumulator, the function that folds a match into the accumulator (and
-- decides whether to continue with `Step` or stop with `Done`), the automaton,
-- and the text to search. The result is the final accumulator.
--
-- The position reported in a `Match` is relative to the start of the text (the
-- internal offset of the `Text` is subtracted), and is taken after the code
-- unit that completed the match has been consumed.
--
-- WARNING: Run benchmarks when modifying this function; its performance is
-- fragile. It took many days to discover the current formulation which compiles
-- to fast code; removing the wrong bang pattern could cause a 10% performance
-- regression.
{-# INLINE runWithCase #-}
runWithCase
  :: forall a v
  .  CaseSensitivity
  -> a
  -> (a -> Match v -> Next a)
  -> AcMachine v
  -> Text
  -> a
runWithCase caseSensitivity seed f machine text =
  let
    Text u16data !initialOffset !initialRemaining = text
    !values = machineValues machine
    !transitions = machineTransitions machine
    !offsets = machineOffsets machine
    !rootAsciiTransitions = machineRootAsciiTransitions machine
    !stateInitial = 0

    -- NOTE: All of the arguments are strict here, because we want to compile
    -- them down to unpacked variables on the stack, or even registers.
    -- The INLINE / NOINLINE annotations here were added to fix a regression we
    -- observed when going from GHC 8.2 to GHC 8.6, and this particular
    -- combination of INLINE and NOINLINE is the fastest one. Removing them
    -- increases the benchmark running time by about 9%.

    -- Read the next code unit (if any input remains) and follow the edge for
    -- it from the current state. When no input remains, the accumulator is
    -- the final result.
    {-# NOINLINE consumeInput #-}
    consumeInput :: Int -> Int -> a -> State -> a
    consumeInput !offset !remaining !acc !state =
      let
        -- These bindings are lazy, so the array is only read in the branch
        -- below where `remaining` is nonzero.
        inputCodeUnit = fromIntegral $ indexTextArray u16data offset
        -- NOTE: Although doing this match here entangles the automaton a bit
        -- with case sensitivity, doing so is faster than passing in a function
        -- that transforms each code unit.
        casedCodeUnit = case caseSensitivity of
          IgnoreCase -> lowerCodeUnit inputCodeUnit
          CaseSensitive -> inputCodeUnit
      in
        case remaining of
          0 -> acc
          _ -> followEdge (offset + 1) (remaining - 1) acc state casedCodeUnit

    {-# INLINE followEdge #-}
    followEdge :: Int -> Int -> a -> State -> CodeUnit -> a
    followEdge !offset !remaining !acc !state !input =
      let
        !tssOffset = offsets `uAt` state
      in
        -- When we follow an edge, we look in the transition table and do a
        -- linear scan over all transitions until we find the right one, or
        -- until we hit the wildcard transition at the end. For 0 or 1 or 2
        -- transitions that is fine, but the initial state often has more
        -- transitions, so we have a dedicated lookup table for it, that takes
        -- up a bit more space, but provides O(1) lookup of the next state. We
        -- only do this for the first 128 code units (all of ascii).
        if state == stateInitial && input < asciiCount
          then lookupRootAsciiTransition offset remaining acc input
          else lookupTransition offset remaining acc state input tssOffset

    -- Report every value associated with the state we just entered, then
    -- continue consuming input from that state.
    {-# NOINLINE collectMatches #-}
    collectMatches :: Int -> Int -> a -> State -> a
    collectMatches !offset !remaining !acc !state =
      let
        matchedValues = values `at` state
        -- Fold over the matched values. If at any point the user-supplied fold
        -- function returns `Done`, then we early out. Otherwise continue.
        handleMatch !acc' vs = case vs of
          []     -> consumeInput offset remaining acc' state
          v:more -> case f acc' (Match (CodeUnitIndex $ offset - initialOffset) v) of
            Step newAcc -> handleMatch newAcc more
            Done finalAcc -> finalAcc
      in
        handleMatch acc matchedValues

    -- NOTE: there is no `state` argument here, because this case applies only
    -- to the root state `stateInitial`.
    {-# INLINE lookupRootAsciiTransition #-}
    lookupRootAsciiTransition :: Int -> Int -> a -> CodeUnit -> a
    lookupRootAsciiTransition !offset !remaining !acc !input =
      case rootAsciiTransitions `uAt` fromIntegral input of
        t | transitionIsWildcard t -> consumeInput offset remaining acc stateInitial
          | otherwise -> collectMatches offset remaining acc (transitionState t)

    {-# INLINE lookupTransition #-}
    lookupTransition :: Int -> Int -> a -> State -> CodeUnit -> Int -> a
    lookupTransition !offset !remaining !acc !state !input !i =
      case transitions `uAt` i of
        -- There is no transition for the given input. Follow the fallback edge,
        -- and try again from that state, etc. If we are in the base state
        -- already, then nothing matched, so move on to the next input.
        t | transitionIsWildcard t ->
              if state == stateInitial
                then consumeInput offset remaining acc state
                else followEdge offset remaining acc (transitionState t) input

        -- We found the transition, switch to that new state, collecting matches.
        -- NOTE: This comes after wildcard checking, because the code unit of
        -- the wildcard transition is 0, which is a valid input.
        t | transitionCodeUnit t == input ->
              collectMatches offset remaining acc (transitionState t)

        -- The transition we inspected is not for the current input, and it is not
        -- a wildcard either; look at the next transition then.
        _ -> lookupTransition offset remaining acc state input (i + 1)
  in
    consumeInput initialOffset initialRemaining seed stateInitial

-- | Run the automaton over the text as-is (case sensitive), folding every
-- match into the accumulator with the supplied function. This is
-- `runWithCase` specialised to `CaseSensitive`.
--
-- NOTE: To get full advantage of inlining this function, you probably want to
-- compile the compiling module with -fllvm and the same optimization flags as
-- this module.
{-# INLINE runText #-}
runText :: forall a v. a -> (a -> Match v -> Next a) -> AcMachine v -> Text -> a
runText = runWithCase CaseSensitive

-- | Finds all matches in the lowercased text. This function lowercases the text
-- on the fly to avoid allocating a second lowercased text array. Lowercasing is
-- applied to individual code units, so the indexes into the lowercased text can
-- be used to index into the original text. It is still the responsibility of
-- the caller to lowercase the needles. Needles that contain uppercase code
-- points will not match.
--
-- This is `runWithCase` specialised to `IgnoreCase`.
--
-- NOTE: To get full advantage of inlining this function, you probably want to
-- compile the compiling module with -fllvm and the same optimization flags as
-- this module.
{-# INLINE runLower #-}
runLower :: forall a v. a -> (a -> Match v -> Next a) -> AcMachine v -> Text -> a
runLower = runWithCase IgnoreCase