module Data.AhoCorasick
( ACMachine
, State(..)
, Match(..)
, construct
, constructWithValues
, root
, run
, step
, renderGraph
) where
import Control.Applicative
import Control.Monad
import Data.Hashable (Hashable)
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as Map
import Data.List
import Data.Maybe
import Data.Vector (Vector)
import qualified Data.Vector as V
import GHC.Generics (Generic)
data ACMachine a v = ACMachine (Goto a) Failure (Output v)
type Goto a = Vector (HashMap a State)
type Failure = Vector State
type Output v = Vector [(Int, v)]
type GotoMap a = HashMap State (HashMap a State)
type FailureMap = HashMap State State
type OutputMap v = HashMap State [(Int, v)]
newtype State = State Int
deriving (Eq, Generic)
instance Hashable State
data Match v = Match
{ matchPos :: Int
, matchValue :: v
} deriving (Show)
construct :: (Eq a, Hashable a) => [[a]] -> ACMachine a [a]
construct ps = ACMachine (toGotoArray n gotoMap) (toFailureArray n failureMap) (toOutputArray n outputMap)
where
(m, gotoMap) = buildGoto ps
n = m + 1
failureMap = buildFailure gotoMap
outputMap = buildOutput pvs gotoMap failureMap
pvs = zip ps ps
constructWithValues :: (Eq a, Hashable a) => [([a], v)] -> ACMachine a v
constructWithValues pvs = ACMachine (toGotoArray n gotoMap) (toFailureArray n failureMap) (toOutputArray n outputMap)
where
(m, gotoMap) = buildGoto ps
n = m + 1
failureMap = buildFailure gotoMap
outputMap = buildOutput pvs gotoMap failureMap
ps = map fst pvs
toGotoArray :: Int -> GotoMap a -> Goto a
toGotoArray n m = V.generate n (fromMaybe Map.empty . flip Map.lookup m . State)
toFailureArray :: Int -> FailureMap -> Failure
toFailureArray n m = V.generate n (fromMaybe (error "failure: ") . flip Map.lookup m . State)
toOutputArray :: Int -> OutputMap v -> Output v
toOutputArray n m = V.generate n (fromMaybe [] . flip Map.lookup m . State)
root :: State
root = State 0
run :: (Eq a, Hashable a) => ACMachine a v -> [a] -> [Match v]
run acm = go root . zip [1..]
where
go _ [] = []
go s ((i, x):ixs) = map toMatch vs ++ go s' ixs
where
toMatch (l, v) = Match { matchPos = i l + 1, matchValue = v }
(s', vs) = step acm x s
step :: (Eq a, Hashable a) => ACMachine a v -> a -> State -> (State, [(Int, v)])
step (ACMachine g f o) x s = (s', output s')
where
s' = head $ mapMaybe (flip goto x) $ iterate failure s
goto (State 0) x' = (Map.lookup x' $ g V.! 0) <|> Just root
goto (State i) x' = Map.lookup x' $ g V.! i
failure (State i) = f V.! i
output (State i) = o V.! i
buildOutput :: (Eq a, Hashable a) => [([a], v)] -> GotoMap a -> FailureMap -> OutputMap v
buildOutput pvs gotoMap failureMap = foldl' build o0 $ tail $ toBFList gotoMap
where
build o s = foldl' (\a (_, s') -> Map.insertWith (flip (++)) s' (lookupDefault [] (failure s') a) a) o ts
where
ts = Map.toList $ lookupDefault Map.empty s gotoMap
failure = fromMaybe (error "failure: ") . flip Map.lookup failureMap
o0 = Map.fromList $ fromJust $ mapM toKV pvs
toKV (p, v) = do
s <- finalState gotoMap root p
return (s, [(length p, v)])
finalState :: (Eq a, Hashable a) => GotoMap a -> State -> [a] -> Maybe State
finalState m = foldM (\s x -> Map.lookup s m >>= Map.lookup x)
buildGoto :: (Eq a, Hashable a) => [[a]] -> (Int, GotoMap a)
buildGoto = foldl' (flip extend) (0, Map.empty)
extend :: (Eq a, Hashable a) => [a] -> (Int, GotoMap a) -> (Int, GotoMap a)
extend = go root
where
go _ [] nm = nm
go s (x:xs) nm@(n, m) = case Map.lookup x sm of
Nothing -> go s' xs (n', m')
where
s' = State n'
n' = n + 1
sm' = Map.insert x s' sm
m' = Map.insert s sm' m
Just s' -> go s' xs nm
where
sm = fromMaybe Map.empty $ Map.lookup s m
buildFailure :: (Eq a, Hashable a) => GotoMap a -> FailureMap
buildFailure m = foldl' build Map.empty $ toBFList m
where
build f s = foldl' (\a (x, s') -> Map.insert s' (failureState f s x) a) f ts
where
ts = Map.toList $ lookupDefault Map.empty s m
failureState _ (State 0) _ = root
failureState f s x = head $ mapMaybe (flip goto x) $ iterate failure (failure s)
where
failure = fromMaybe (error "failure: ") . flip Map.lookup f
goto (State 0) x = (Map.lookup root m >>= Map.lookup x) <|> Just root
goto s x = Map.lookup s m >>= Map.lookup x
toBFList :: GotoMap a -> [State]
toBFList m = ss0
where
ss0 = root : go 1 ss0
go 0 _ = []
go n (s:ss) = case Map.lookup s m of
Nothing -> go (n 1) ss
Just sm -> children ++ go (n 1 + Map.size sm) ss
where
children = Map.elems sm
go _ _ = error "toBFList: invalid state"
toBFList' :: Goto a -> [State]
toBFList' v = ss0
where
ss0 = root : go 1 ss0
go 0 _ = []
go n ((State i):ss) = children ++ go (n 1 + Map.size sm) ss
where
sm = v V.! i
children = Map.elems sm
go _ _ = error "toBFList: invalid state"
lookupDefault :: (Eq k, Hashable k) => v -> k -> HashMap k v -> v
lookupDefault def k m = fromMaybe def $ Map.lookup k m
renderGraph :: ACMachine Char [Char] -> String
renderGraph (ACMachine g f o) =
graph "digraph" $ statements [
attr "graph" [("rankdir", "LR")]
, statements $ map state (toBFList' g)
, statements $ map stateWithOutput $ filter (not . null . snd) $ zip (map State [0..]) (V.toList o)
, statements $ map (\s@(State i) -> statements $ map (uncurry $ transEdge s) $ Map.toList $ g V.! i) (toBFList' g)
, statements $ map (\s@(State i) -> failEdge s $ f V.! i) (tail $ toBFList' g)
]
where
statements = intercalate " "
graph typ body = typ ++ " { " ++ body ++ " }"
attr typ attrList = typ ++ " " ++ "[" ++ intercalate "," (map kvStr attrList) ++ "];"
node nid attrList = nid ++ " " ++ "[" ++ intercalate "," (map kvStr attrList) ++ "];"
kvStr (k, v) = k ++ "=" ++ v
state s@(State 0) = node (stateID s) [("shape", "doublecircle")]
state s = node (stateID s) [("shape", "circle")]
stateWithOutput (s, xs) = node (stateID s) [("label", "<" ++ tableHTML (stateID s) ("{" ++ intercalate "," (map snd xs) ++ "}") ++ ">"), ("shape", "none")]
tableHTML row1 row2 = "<table cellborder=\"0\"><tr><td>" ++ row1 ++ "</td></tr><tr><td>" ++ row2 ++ "</td></tr></table>"
stateID (State 0) = "Root"
stateID (State n) = 'S' : show n
transEdge s x s' = stateID s ++ " -> " ++ stateID s' ++ " [label=\"" ++ [x] ++ "\"];"
failEdge s s' = stateID s ++ " -> " ++ stateID s' ++ " [style=dashed, constraint=false];"