{-# LANGUAGE BangPatterns #-}

module Data.Seqn.Internal.KMP
  ( Table
  , State
  , build
  , step
  ) where

import Data.Primitive.Array (Array, indexArray, sizeofArray)
import Data.Primitive.PrimArray
  ( PrimArray
  , indexPrimArray
  , newPrimArray
  , readPrimArray
  , runPrimArray
  , writePrimArray
  )

-- Knuth–Morris–Pratt algorithm
-- See https://en.wikipedia.org/wiki/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm
--
-- In Table xa pa,
-- * xa is the pattern.
-- * pa is the prefix function. pa!i is the length of the longest proper prefix
--   of xa[0..i] that is also a suffix of xa[0..i] (i.e. that ends at index i).

-- Precomputed matcher for a single pattern: the pattern itself and its prefix
-- function (built by 'build' with one entry per pattern element).
data Table a = Table
  {-# UNPACK #-} !(Array a)       -- the pattern (xa)
  {-# UNPACK #-} !(PrimArray Int) -- the prefix function of the pattern (pa)

-- Matcher state: how many leading pattern elements the most recently fed
-- elements currently match (always less than the pattern length). The type
-- parameter is phantom (unused in the representation).
newtype State a = State Int

-- Build the KMP table for the pattern xa, paired with the initial matching
-- state (State 0: nothing matched yet).
--
-- Precondition: 0 < length xa. An empty pattern is rejected with
-- error "non-positive length".
build :: Eq a => Array a -> (Table a, State a)
build xa
  | n <= 0 = error "non-positive length"
    -- Force pa only after the precondition has been checked. A strict
    -- where-binding (!pa) is evaluated before the guards, so an empty pattern
    -- would first run writePrimArray pma 0 0 on a length-0 array (an
    -- unchecked out-of-bounds write) before reaching the error above.
  | otherwise = pa `seq` (Table xa pa, State 0)
  where
    n = sizeofArray xa

    -- Prefix function of xa (see the module comment). Entry i depends only on
    -- entries 0..i-1, so one left-to-right pass fills the array.
    pa :: PrimArray Int
    pa = runPrimArray $ do
      pma <- newPrimArray n
      writePrimArray pma 0 0
      for_ 1 (n - 1) $ \i -> do
        let xi = indexArray xa i
            -- j is the length of a candidate border of xa[0..i-1]. Extend it
            -- by xi if the next pattern element matches; otherwise fall back
            -- to the next shorter candidate, pa!(j-1), until none are left.
            go j
              | xi == indexArray xa j = pure (j + 1)
              | j == 0 = pure 0
              | otherwise = readPrimArray pma (j - 1) >>= go
        readPrimArray pma (i - 1) >>= go >>= writePrimArray pma i
      pure pma
{-# INLINABLE build #-}

-- Feed the next element x to the matcher described by the table.
--
-- State j means the j most recently fed elements equal the first j elements
-- of the pattern; j always stays below the pattern length. The Bool is True
-- exactly when x completes an occurrence of the pattern. After a match the
-- state becomes pa!(n-1), the longest proper border of the whole pattern, so
-- overlapping occurrences are reported as well.
step :: Eq a => Table a -> State a -> a -> (Bool, State a)
step (Table xa pa) (State i) x = go i
  where
    n = sizeofArray xa

    -- Try to extend a partial match of length j by x; on mismatch fall back
    -- to the next shorter border, pa!(j-1), until j reaches 0.
    go j
      | indexArray xa j == x =
          -- Both branches force the new Int so no thunk is handed back.
          if j + 1 == n
            then (,) True $! State (indexPrimArray pa j)
            else (,) False $! State (j + 1)
      | j == 0 = (False, State 0)
      | otherwise = go (indexPrimArray pa (j - 1))
{-# INLINABLE step #-}

-- for_ lo hi f runs f i for every i from lo to hi inclusive, in ascending
-- order, discarding the results. Does nothing when lo > hi.
--
-- The loop tests for the last index before incrementing, so it also
-- terminates when hi == maxBound (maxBound + 1 would wrap to minBound and an
-- "i > hi" test would never succeed).
for_ :: Applicative f => Int -> Int -> (Int -> f a) -> f ()
for_ i1 i2 f
  | i1 > i2 = pure ()
  | otherwise = go i1
  where
    -- Invariant: i1 <= i <= i2, so i + 1 is only computed when i < i2.
    go i = f i *> (if i == i2 then pure () else go (i + 1))
{-# INLINE for_ #-}