-- Alfred-Margaret: Fast Aho-Corasick string searching
-- Copyright 2022 Channable
--
-- Licensed under the 3-clause BSD license, see the LICENSE file in the
-- repository root.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | An efficient implementation of the Aho-Corasick string matching algorithm.
-- See http://web.stanford.edu/class/archive/cs/cs166/cs166.1166/lectures/02/Small02.pdf
-- for a good explanation of the algorithm.
--
-- The memory layout of the automaton, and the function that steps it, were
-- optimized to the point where string matching compiles roughly to a loop over
-- the code units in the input text, that keeps track of the current state.
-- Lookup of the next state is either just an array index (for the root state),
-- or a linear scan through a small array (for non-root states). The pointer
-- chases that are common for traversing Haskell data structures have been
-- eliminated.
--
-- The construction of the automaton has not been optimized that much, because
-- construction time is usually negligible in comparison to matching time.
-- Therefore construction is a two-step process, where first we build the
-- automaton as int maps, which are convenient for incremental construction.
-- Afterwards we pack the automaton into unboxed vectors.
--
-- This module is a rewrite of the previous version which used an older version of
-- the 'text' package which in turn used UTF-16 internally.
module Data.Text.AhoCorasick.Automaton
    ( AcMachine (..)
    , CaseSensitivity (..)
    , CodeUnitIndex (..)
    , Match (..)
    , Next (..)
    , build
    , debugBuildDot
    , runLower
    , runText
    , runWithCase
    , needleCasings
    ) where

import Control.DeepSeq (NFData)
import Data.Bits (Bits (shiftL, shiftR, (.&.), (.|.)))
import Data.Char (chr)
import Data.Foldable (foldl')
import Data.IntMap.Strict (IntMap)
import Data.Word (Word32, Word64)
import GHC.Generics (Generic)

import qualified Data.Char as Char
import qualified Data.IntMap.Strict as IntMap
import qualified Data.List as List
import qualified Data.Vector as Vector

import Data.Text.CaseSensitivity (CaseSensitivity (..))
import Data.Text.Utf8 (CodePoint, CodeUnitIndex (CodeUnitIndex), Text (..))
import Data.TypedByteArray (Prim, TypedByteArray)

import qualified Data.Text as Text
import qualified Data.Text.Utf8 as Utf8
import qualified Data.TypedByteArray as TBA

-- TYPES

-- | A numbered state in the Aho-Corasick automaton. State 0 is the root
-- (initial) state.
type State = Int

-- | A transition is a pair of (code point, next state), packed into a single
-- 64-bit word. The code point is 21 bits, and the state index is 32 bits. The
-- code point is stored in the least significant 32 bits, with the special
-- value 2^21 indicating a wildcard; the "failure" transition. Bits 22 through
-- 31 (starting from zero, both bounds inclusive) are always 0. The state index
-- occupies the most significant 32 bits.
--
-- >  Bit 63 (most significant)                 Bit 0 (least significant)
-- >  |                                                                 |
-- >  v                                                                 v
-- > |<--       goto state         -->|<-- 0s -->| |<--     input     -->|
-- > |SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSS|0000000000|W|IIIIIIIIIIIIIIIIIIIII|
-- >                                              |
-- >                                        Wildcard bit (bit 21)
--
-- If you change this representation, make sure to update 'transitionCodeUnit',
-- 'wildcard', 'transitionState', 'transitionIsWildcard', 'newTransition' and
-- 'newWildcardTransition' as well. Those functions form the interface used to
-- construct and read transitions.
type Transition = Word64

-- | An index into the packed transition array ('machineTransitions'), marking
-- where the transition list of a state starts.
type Offset = Word32

-- | A single match reported while running the automaton: where the match ends
-- in the haystack, and the payload of the needle that matched.
data Match v = Match
  { matchPos   :: {-# UNPACK #-} !CodeUnitIndex
  -- ^ The code unit index past the last code unit of the match. Note that this
  -- is not a code *point* (Haskell `Char`) index; a code point might be encoded
  -- as up to four code units.
  , matchValue :: v
  -- ^ The payload associated with the matched needle.
  }

-- | An Aho-Corasick automaton, packed into flat arrays for fast matching.
-- The type parameter @v@ is the payload attached to every needle.
data AcMachine v = AcMachine
  { machineValues               :: !(Vector.Vector [v])
  -- ^ For every state, the values associated with its needles. If the state is
  -- not a match state, the list is empty.
  , machineTransitions          :: !(TypedByteArray Transition)
  -- ^ A packed vector of transitions. For every state, there is a slice of this
  -- vector that starts at the offset given by `machineOffsets`, and ends at the
  -- first wildcard transition.
  , machineOffsets              :: !(TypedByteArray Offset)
  -- ^ For every state, the index into `machineTransitions` where the transition
  -- list for that state starts.
  , machineRootAsciiTransitions :: !(TypedByteArray Transition)
  -- ^ A lookup table for transitions from the root state, an optimization to
  -- avoid having to walk all transitions, at the cost of using a bit of
  -- additional memory.
  } deriving (Generic, Functor)

-- | Allows an automaton to be fully evaluated, provided its payload type can
-- be. The instance body is empty, so the class's default method is used; it
-- relies on the 'Generic' instance derived for 'AcMachine'.
instance NFData v => NFData (AcMachine v)

-- AUTOMATON CONSTRUCTION

-- | The wildcard value is 2^21, one more than the maximal 21-bit code point.
-- It is polymorphic so that it can be used both as a 'Transition' bit mask and
-- as a plain integer without conversions at the use site.
wildcard :: Integral a => a
wildcard = 0x200000

-- | Extract the input code point from a transition, by masking out the low
-- 21 bits. Despite the name, the result is a full code point rather than a
-- single code unit. For the special wildcard transition the low 21 bits are
-- all zero, so this returns the code point 0.
transitionCodeUnit :: Transition -> CodePoint
transitionCodeUnit t = Char.chr $ fromIntegral (t .&. 0x1fffff)

-- | Extract the goto state from a transition. The state index lives in the
-- most significant 32 bits, so it is obtained by shifting the input bits out.
transitionState :: Transition -> State
transitionState t = fromIntegral (t `shiftR` 32)

-- | Test if the transition is not for a specific code point, but the wildcard
-- transition to take if nothing else matches. This checks whether the
-- wildcard bit (bit 21) is set.
transitionIsWildcard :: Transition -> Bool
transitionIsWildcard t = (t .&. wildcard) == wildcard

-- | Construct a transition that, on the given input code point, goes to the
-- given state. The state index is placed in the most significant 32 bits and
-- the code point in the least significant bits.
newTransition :: CodePoint -> State -> Transition
newTransition input state =
  let
    input64 = fromIntegral $ Char.ord input :: Word64
    state64 = fromIntegral state :: Word64
  in
    (state64 `shiftL` 32) .|. input64

-- | Construct the wildcard ("failure") transition to the given state. The
-- state index is placed in the most significant 32 bits; the input part is
-- the 'wildcard' value, which sets only the wildcard bit.
newWildcardTransition :: State -> Transition
newWildcardTransition state =
  let
    state64 = fromIntegral state :: Word64
  in
    (state64 `shiftL` 32) .|. wildcard

-- | Pack transitions for each state into one contiguous array. In order to find
-- the transitions for a specific state, we also produce a vector of start
-- indices. All transition lists are terminated by a wildcard transition, so
-- there is no need to record the length.
--
-- The offsets are the running sum of the list lengths, starting at 0, so the
-- offsets array has one more entry than there are transition lists; the final
-- entry equals the total number of packed transitions.
packTransitions :: [[Transition]] -> (TypedByteArray Transition, TypedByteArray Offset)
packTransitions transitions =
  let
    packed = TBA.fromList $ concat transitions
    offsets = TBA.fromList $ map fromIntegral $ scanl (+) 0 $ fmap List.length transitions
  in
    (packed, offsets)

-- | Construct an Aho-Corasick automaton for the given needles.
-- The automaton uses Unicode code points to match the input.
--
-- Every needle is paired with a payload value, which is reported whenever
-- that needle matches.
build :: [(Text, v)] -> AcMachine v
build needlesWithValues =
  let
    -- Construct the Aho-Corasick automaton using IntMaps, which are a suitable
    -- representation when building the automaton. We use int maps rather than
    -- hash maps to ensure that the iteration order is the same as that of a
    -- vector.
    (numStates, transitionMap, initialValueMap) = buildTransitionMap needlesWithValues
    fallbackMap = buildFallbackMap transitionMap
    valueMap = buildValueMap transitionMap fallbackMap initialValueMap

    -- Convert the map of transitions, and the map of fallback states, into a
    -- list of transition lists, where every transition list is terminated by
    -- a wildcard transition to the fallback state. Because we prepend to a
    -- list that initially holds only the wildcard transition, the wildcard
    -- always ends up last.
    prependTransition ts input state = newTransition (Char.chr input) state : ts
    makeTransitions fallback ts = IntMap.foldlWithKey' prependTransition [newWildcardTransition fallback] ts
    transitionsList = zipWith makeTransitions (IntMap.elems fallbackMap) (IntMap.elems transitionMap)

    -- Pack the transition lists into one contiguous array, and build the lookup
    -- table for the transitions from the root state.
    (transitions, offsets) = packTransitions transitionsList
    rootTransitions = buildAsciiTransitionLookupTable $ transitionMap IntMap.! 0
    values = Vector.generate numStates (valueMap IntMap.!)
  in
    AcMachine values transitions offsets rootTransitions

-- | Build the automaton, and format it as Graphviz Dot, for visual debugging.
--
-- Trie transitions are drawn as edges labelled with their input code point,
-- fallback transitions as dashed edges, and states that report at least one
-- match as double circles. Each needle's payload is its index in the input
-- list.
debugBuildDot :: [Text] -> String
debugBuildDot needles =
  let
    (_numStates, transitionMap, initialValueMap) =
      buildTransitionMap $ zip needles ([0..] :: [Int])
    fallbackMap = buildFallbackMap transitionMap
    valueMap = buildValueMap transitionMap fallbackMap initialValueMap

    -- Format one edge, with @extra@ as the attribute list between brackets.
    dotEdge extra state nextState =
      "  " ++ show state ++ " -> " ++ show nextState ++ " [" ++ extra ++ "];"

    dotFallbackEdge :: [String] -> State -> State -> [String]
    dotFallbackEdge edges state nextState =
      dotEdge "style = dashed" state nextState : edges

    dotTransitionEdge :: State -> [String] -> Int -> State -> [String]
    dotTransitionEdge state edges input nextState =
      dotEdge ("label = \"" ++ showInput input ++ "\"") state nextState : edges

    -- Render the code point of a transition as a label. Double quotes and
    -- backslashes are escaped, because the label is emitted inside a
    -- double-quoted Dot string and would otherwise terminate or corrupt it.
    showInput :: Int -> String
    showInput input = case chr input of
      '"'  -> "\\\""
      '\\' -> "\\\\"
      c    -> [c]

    prependTransitionEdges edges state =
      IntMap.foldlWithKey' (dotTransitionEdge state) edges (transitionMap IntMap.! state)

    -- Only states that report at least one value are marked as match states.
    dotMatchState :: [String] -> State -> [Int] -> [String]
    dotMatchState edges _ [] = edges
    dotMatchState edges state _ = ("  " ++ show state ++ " [shape = doublecircle];") : edges

    dot0 = foldBreadthFirst prependTransitionEdges [] transitionMap
    dot1 = IntMap.foldlWithKey' dotFallbackEdge dot0 fallbackMap
    dot2 = IntMap.foldlWithKey' dotMatchState dot1 valueMap
  in
    -- Set rankdir = "LR" to prefer a left-to-right graph, rather than top to
    -- bottom. Reverse the instructions because order affects node lay-out, and
    -- by prepending we built up a reversed list.
    unlines $ ["digraph {", "  rankdir = \"LR\";"] ++ reverse dot2 ++ ["}"]

-- Different int maps that are used during construction of the automaton. The
-- transition map represents the trie of states, the fallback map contains the
-- fallback (or "failure" or "suffix") edge for every state.

-- | For every state, a map from input code point (as an 'Int') to the state
-- that the trie transitions to on that input.
type TransitionMap = IntMap (IntMap State)
-- | For every state, the state to fall back to when no transition matches.
type FallbackMap = IntMap State
-- | For a state, the payload values of the needles reported at that state.
type ValuesMap v = IntMap [v]

-- | Build the trie of the Aho-Corasick state machine for all input needles.
--
-- Returns the number of states allocated, the trie transitions for every
-- state, and for every state where a needle ends, the payload values of the
-- needles that end there.
buildTransitionMap :: forall v. [(Text, v)] -> (Int, TransitionMap, ValuesMap v)
buildTransitionMap =
  let
    -- Inserts a single needle into the given transition and values map.
    insertNeedle :: (Int, TransitionMap, ValuesMap v) -> (Text, v) -> (Int, TransitionMap, ValuesMap v)
    insertNeedle !acc (!needle, !value) = go stateInitial 0 acc
      where
        !needleLen = Utf8.lengthUtf8 needle

        -- Walk the needle one code point at a time, starting from the root,
        -- following existing transitions and allocating states as needed.
        -- @index@ is a code unit index into the needle.
        go !state !index (!numStates, !transitions, !values)
          -- End of the current needle, insert the associated payload value.
          -- If a needle occurs multiple times, then at this point we will merge
          -- their payload values, so the needle is reported twice, possibly with
          -- different payload values.
          | index >= needleLen = (numStates, transitions, IntMap.insertWith (++) state [value] values)
        go !state !index (!numStates, !transitions, !values) =
          let
            !transitionsFromState = transitions IntMap.! state
            (!codeUnits, !input) = Utf8.unsafeIndexCodePoint needle index
          in
            case IntMap.lookup (Char.ord input) transitionsFromState of
              -- Transition already exists, follow it and continue from there.
              Just !nextState ->
                go nextState (index + codeUnits) (numStates, transitions, values)
              -- Transition for input does not exist at state:
              -- Allocate a new state, and insert a transition to it.
              -- Also insert an empty transition map for it.
              Nothing ->
                let
                  -- States are numbered consecutively, so the current state
                  -- count is the index of the next unused state.
                  !nextState = numStates
                  !transitionsFromState' = IntMap.insert (Char.ord input) nextState transitionsFromState
                  !transitions'
                    = IntMap.insert state transitionsFromState'
                    $ IntMap.insert nextState IntMap.empty transitions
                in
                  go nextState (index + codeUnits) (numStates + 1, transitions', values)

    -- Initially, the root state (state 0) exists, and it has no transitions
    -- to anywhere.
    stateInitial = 0
    initialTransitions = IntMap.singleton stateInitial IntMap.empty
    initialValues = IntMap.empty
  in
    -- The state count starts at 1 to account for the root state.
    foldl' insertNeedle (1, initialTransitions, initialValues)

-- | Size of the ascii transition lookup table: one entry for each of the
-- first 128 code points.
asciiCount :: Integral a => a
asciiCount = 128

-- | Build a lookup table for the first 128 code points, that can be used for
-- O(1) lookup of a transition, rather than doing a linear scan over all
-- transitions. The fallback goes back to the initial state, state 0.
--
-- The argument is the map of transitions out of the root state, keyed by
-- code point. Code points without a transition get a wildcard transition to
-- state 0.
{-# NOINLINE buildAsciiTransitionLookupTable  #-}
buildAsciiTransitionLookupTable :: IntMap State -> TypedByteArray Transition
buildAsciiTransitionLookupTable transitions = TBA.generate asciiCount $ \i ->
  case IntMap.lookup i transitions of
    Just state -> newTransition (Char.chr i) state
    Nothing    -> newWildcardTransition 0

-- | Traverse the state trie in breadth-first order, starting at the root
-- (state 0), folding the visited states into an accumulator with @f@.
foldBreadthFirst :: (a -> State -> a) -> a -> TransitionMap -> a
foldBreadthFirst f seed transitions = go [0] [] seed
  where
    -- For the traversal, we keep a queue of states to visit. Every iteration we
    -- take one off the front, and all states reachable from there get added to
    -- the back. Rather than using a list for this, we use the functional
    -- amortized queue to avoid O(n²) append. This makes a measurable difference
    -- when the backlog can grow large. In one of our benchmark inputs for
    -- example, we have roughly 160 needles that are 10 characters each (but
    -- with some shared prefixes), and the backlog size grows to 148 during
    -- construction. Construction time goes down from ~0.80 ms to ~0.35 ms by
    -- using the amortized queue.
    -- See also section 3.1.1 of Purely Functional Data Structures by Okasaki
    -- https://www.cs.cmu.edu/~rwh/theses/okasaki.pdf.
    --
    -- The first list is the front of the queue, the second list is the back
    -- of the queue in reverse order.
    go [] [] !acc = acc
    -- The front is exhausted: turn the reversed back into the new front.
    go [] revBacklog !acc = go (reverse revBacklog) [] acc
    go (state : backlog) revBacklog !acc =
      let
        -- Note that the backlog never contains duplicates, because we traverse
        -- a trie that only branches out. For every state, there is only one
        -- path from the root that leads to it.
        extra = IntMap.elems $ transitions IntMap.! state
      in
        go backlog (extra ++ revBacklog) (f acc state)

-- | Determine the fallback transition for every state, by traversing the
-- transition trie breadth-first.
--
-- The breadth-first order matters: when the fallback for a state is computed,
-- the fallbacks of the states it is derived from have already been inserted.
buildFallbackMap :: TransitionMap -> FallbackMap
buildFallbackMap transitions =
  let
    -- Suppose that in state `state`, there is a transition for input `input`
    -- to state `nextState`, and we already know the fallback for `state`. Then
    -- this function returns the fallback state for `nextState`.
    getFallback :: FallbackMap -> State -> Int -> State
    -- All the states after the root state (state 0) fall back to the root state.
    getFallback _ 0 _ = 0
    getFallback fallbacks !state !input =
      let
        fallback = fallbacks IntMap.! state
        transitionsFromFallback = transitions IntMap.! fallback
      in
        case IntMap.lookup input transitionsFromFallback of
          Just st -> st
          -- The fallback state has no transition for this input either, so
          -- retry from the fallback of the fallback.
          Nothing -> getFallback fallbacks fallback input

    -- Record the fallback for `nextState`, which is reached from `state` on
    -- `input`.
    insertFallback :: State -> FallbackMap -> Int -> State -> FallbackMap
    insertFallback !state fallbacks !input !nextState =
      IntMap.insert nextState (getFallback fallbacks state input) fallbacks

    -- Record the fallbacks for all states directly reachable from `state`.
    insertFallbacks :: FallbackMap -> State -> FallbackMap
    insertFallbacks fallbacks !state =
      IntMap.foldlWithKey' (insertFallback state) fallbacks (transitions IntMap.! state)
  in
    -- The root state falls back to itself.
    foldBreadthFirst insertFallbacks (IntMap.singleton 0 0) transitions

-- | Determine which matches to report at every state, by traversing the
-- transition trie breadth-first, and appending all the matches from a fallback
-- state to the matches for the current state.
--
-- The result has an entry for every state visited by the traversal; states
-- that report nothing map to the empty list.
buildValueMap :: forall v. TransitionMap -> FallbackMap -> ValuesMap v -> ValuesMap v
buildValueMap transitions fallbacks valuesInitial =
  let
    insertValues :: ValuesMap v -> State -> ValuesMap v
    insertValues values !state =
      let
        -- Looked up in the map under construction (not in `valuesInitial`), so
        -- this already includes whatever the fallback state itself inherited.
        fallbackValues = values IntMap.! (fallbacks IntMap.! state)
        valuesForState = case IntMap.lookup state valuesInitial of
          Just vs -> vs ++ fallbackValues
          Nothing -> fallbackValues
      in
        IntMap.insert state valuesForState values
  in
    foldBreadthFirst insertValues (IntMap.singleton 0 []) transitions

-- Define aliases for array indexing so we can turn bounds checks on and off
-- in one place. We ran this code with `Vector.!` (bounds-checked indexing) in
-- production for two months without failing the bounds check, so we have turned
-- the check off for performance now.

-- | Index into a boxed vector without a bounds check. The caller must make
-- sure that the index is in range.
{-# INLINE at #-}
at :: forall a. Vector.Vector a -> Int -> a
at = Vector.unsafeIndex

-- | Index into a typed byte array using 'TBA.unsafeIndex'. As the name of that
-- function suggests, the caller is responsible for passing an index that is
-- in range.
{-# INLINE uAt #-}
uAt :: Prim a => TypedByteArray a -> Int -> a
uAt = TBA.unsafeIndex

-- RUNNING THE MACHINE

-- | Result of handling a match: stepping the automaton can exit early by
-- returning a `Done`, or it can continue with a new accumulator with `Step`.
-- Both constructors carry the accumulator in a strict field, so it is
-- evaluated to weak head normal form when the result is constructed.
data Next a = Done !a | Step !a

-- | Run the automaton, possibly lowercasing the input text on the fly if case
-- insensitivity is desired. See also `runLower`.
--
-- The code of this function itself is organized as a state machine as well.
-- Each state in the diagram below corresponds to a function defined in
-- `runWithCase`. These functions are written in a way such that GHC identifies them
-- as [join points](https://www.microsoft.com/en-us/research/publication/compiling-without-continuations/).
-- This means that they can be compiled to jumps instead of function calls, which helps performance a lot.
--
-- @
--   ┌─────────────────────────────┐
--   │                             │
-- ┌─▼──────────┐   ┌──────────────┴─┐   ┌──────────────┐
-- │consumeInput├───►lookupTransition├───►collectMatches│
-- └─▲──────────┘   └─▲────────────┬─┘   └────────────┬─┘
--   │                │            │                  │
--   │                └────────────┘                  │
--   │                                                │
--   └────────────────────────────────────────────────┘
-- @
--
-- * @consumeInput@ decodes a code point of up to four code units and possibly lowercases it.
--   It passes this code point to @followCodePoint@, which in turn calls @lookupTransition@.
-- * @lookupTransition@ checks whether the given code point matches any transitions at the given state.
--   If so, it follows the transition and calls @collectMatches@. Otherwise, it follows the fallback transition
--   and calls @followCodePoint@ or @consumeInput@.
-- * @collectMatches@ checks whether the current state is accepting and updates the accumulator accordingly.
--   Afterwards it loops back to @consumeInput@.
--
-- NOTE: @followCodePoint@ is actually inlined into @consumeInput@ by GHC.
-- It is included in the diagram for illustrative reasons only.
--
-- All of these functions have the arguments @offset@, @state@ and @acc@ which encode the current input
-- position and the accumulator, which contains the matches. If you change any of the functions above,
-- make sure to check the Core dumps afterwards that @offset@ and @state@ were turned
-- into unboxed @Int#@ by GHC. If any of them aren't, the program will constantly allocate and deallocate heap space for them.
-- You can nudge GHC in the right direction by using bang patterns on these arguments.
--
-- WARNING: Run benchmarks when modifying this function; its performance is
-- fragile. It took many days to discover the current formulation which compiles
-- to fast code; removing the wrong bang pattern could cause a 10% performance
-- regression.
{-# INLINE runWithCase #-}
runWithCase :: forall a v. CaseSensitivity -> a -> (a -> Match v -> Next a) -> AcMachine v -> Text -> a
runWithCase !caseSensitivity !seed !f !machine !text =
  consumeInput initialOffset seed initialState
  where
    -- The root state of the automaton; matching starts here.
    initialState = 0

    Text !u8data !off !len = text
    AcMachine !values !transitions !offsets !rootAsciiTransitions = machine

    -- Matching runs over the code units in the range [initialOffset, limit).
    !initialOffset = CodeUnitIndex off
    !limit = CodeUnitIndex $ off + len

    -- NOTE: All of the arguments are strict here, because we want to compile
    -- them down to unpacked variables on the stack, or even registers.

    -- When we follow an edge, we look in the transition table and do a
    -- linear scan over all transitions until we find the right one, or
    -- until we hit the wildcard transition at the end. For 0 or 1 or 2
    -- transitions that is fine, but the initial state often has more
    -- transitions, so we have a dedicated lookup table for it, that takes
    -- up a bit more space, but provides O(1) lookup of the next state. We
    -- only do this for the first 128 code units (all of ascii).

    -- | Consume a code unit sequence that constitutes a full code point.
    -- Returns the accumulator once @offset@ has reached @limit@. Otherwise the
    -- code point at @offset@ is decoded and, when matching with 'IgnoreCase',
    -- lowered with 'Utf8.lowerCodePoint' before it is looked up.
    {-# NOINLINE consumeInput #-}
    consumeInput :: CodeUnitIndex -> a -> State -> a
    consumeInput !offset !acc !_state
      | offset >= limit = acc
    consumeInput !offset !acc !state =
      followCodePoint (offset + codeUnits) acc possiblyLoweredCp state

      where
        (!codeUnits, !cp) = Utf8.unsafeIndexCodePoint' u8data offset

        !possiblyLoweredCp = case caseSensitivity of
          CaseSensitive -> cp
          IgnoreCase -> Utf8.lowerCodePoint cp

    -- Dispatch on the current state: the root state uses the dedicated table
    -- for code points below @asciiCount@, everything else scans the
    -- transitions of @state@ starting at its offset in the transition table.
    {-# INLINE followCodePoint #-}
    followCodePoint :: CodeUnitIndex -> a -> CodePoint -> State -> a
    followCodePoint !offset !acc !cp !state
      | state == initialState && Char.ord cp < asciiCount = lookupRootAsciiTransition offset acc cp
      | otherwise = lookupTransition offset acc cp state $ offsets `uAt` state

    -- NOTE: This function can't be inlined since it is self-recursive.
    {-# NOINLINE lookupTransition #-}
    lookupTransition :: CodeUnitIndex -> a -> CodePoint -> State -> Offset -> a
    lookupTransition !offset !acc !cp !state !i
      -- There is no transition for the given input. Follow the fallback edge,
      -- and try again from that state, etc. If we are in the base state
      -- already, then nothing matched, so move on to the next input.
      | transitionIsWildcard t =
        if state == initialState
          then consumeInput offset acc state
          else followCodePoint offset acc cp (transitionState t)
      -- We found the transition for this code point; switch to its target
      -- state and collect the matches there.
      -- NOTE: This comes after wildcard checking, because the code unit of
      -- the wildcard transition is 0, which is a valid input.
      | transitionCodeUnit t == cp =
        collectMatches offset acc (transitionState t)
      -- The transition we inspected is not for the current input, and it is not
      -- a wildcard either; look at the next transition then.
      | otherwise =
        lookupTransition offset acc cp state $ i + 1

      where
        !t = transitions `uAt` fromIntegral i

    -- NOTE: there is no `state` argument here, because this case applies only
    -- to the root state `initialState`.
    {-# INLINE lookupRootAsciiTransition #-}
    lookupRootAsciiTransition !offset !acc !cp
      -- Given code unit does not match at root ==> Repeat at offset from initial state
      | transitionIsWildcard t = consumeInput offset acc initialState
      -- Transition matched!
      | otherwise = collectMatches offset acc $ transitionState t
      where !t = rootAsciiTransitions `uAt` Char.ord cp

    -- Feed every value stored for @state@ to the user-supplied fold function,
    -- then continue with the next input. The reported match position is
    -- @offset@ relative to @initialOffset@.
    {-# NOINLINE collectMatches #-}
    collectMatches !offset !acc !state =
      let
        matchedValues = values `at` state
        -- Fold over the matched values. If at any point the user-supplied fold
        -- function returns `Done`, then we early out. Otherwise continue.
        handleMatch !acc' vs = case vs of
          []     -> consumeInput offset acc' state
          v:more -> case f acc' (Match (offset - initialOffset) v) of
            Step newAcc -> handleMatch newAcc more
            Done finalAcc -> finalAcc
      in
        handleMatch acc matchedValues

-- | Finds all matches in the text, comparing code points as they are.
-- This is 'runWithCase' specialised to 'CaseSensitive'.
--
-- NOTE: To get full advantage of inlining this function, you probably want to
-- compile the compiling module with -fllvm and the same optimization flags as
-- this module.
{-# INLINE runText #-}
runText :: forall a v. a -> (a -> Match v -> Next a) -> AcMachine v -> Text -> a
runText = runWithCase CaseSensitive

-- | Finds all matches in the lowercased text. This function lowercases the input text
-- on the fly to avoid allocating a second lowercased text array. It is still the
-- responsibility of the caller to lowercase the needles. Needles that contain
-- uppercase code points will not match.
-- This is 'runWithCase' specialised to 'IgnoreCase'.
--
-- NOTE: To get full advantage of inlining this function, you probably want to
-- compile the compiling module with -fllvm and the same optimization flags as
-- this module.
{-# INLINE runLower #-}
runLower :: forall a v. a -> (a -> Match v -> Next a) -> AcMachine v -> Text -> a
runLower = runWithCase IgnoreCase


-- | Given a lower case text, this gives all the texts that would lowercase to this one
--
--     needleCasings "abc" == ["abc","abC","aBc","aBC","Abc","AbC","ABc","ABC"]
--     needleCasings "ABC" == []
--     needleCasings "ω1" == ["Ω1","ω1","Ω1"]
--
needleCasings :: Text -> [Text]
needleCasings = map Text.pack . loop . Text.unpack
  where
    -- Combine every code point that 'Utf8.unlowerCodePoint' yields for the
    -- head character with every casing of the remaining characters. If any
    -- character has no alternatives, the whole result is empty.
    loop "" = [""]
    loop (c:cs) = (:) <$> Utf8.unlowerCodePoint c <*> loop cs