{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_HADDOCK hide, prune #-}
-- |
-- Module         : Data.ByteString.Search.Internal.Utils
-- Copyright      : Daniel Fischer
-- Licence        : BSD3
-- Maintainer     : Daniel Fischer <daniel.is.fischer@googlemail.com>
-- Stability      : Provisional
-- Portability    : non-portable
--
-- Author         : Daniel Fischer
--
-- Utilities for several searching algorithms.

module Data.ByteString.Search.Internal.Utils ( kmpBorders
                                             , automaton
                                             , ldrop
                                             , ltake
                                             , lsplit
                                             , release
                                             , keep
                                             , strictify
                                             ) where

import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeIndex)

import Data.Array.Base (unsafeRead, unsafeWrite, unsafeAt)
import Data.Array.ST
import Data.Array.Unboxed
import Control.Monad (when)

import Data.Bits
import Data.Word (Word8)

------------------------------------------------------------------------------
--                              Preprocessing                               --
------------------------------------------------------------------------------

-- automaton builds the transition table of a deterministic finite
-- automaton for the pattern, derived from the KMP table 'kmpBorders'.
--
-- The table is a flat array of (patLen + 1) * 256 entries. The entry at
-- index state * 256 + byte (the multiplication is done as 'shiftL' 8) is
-- the state reached from 'state' after reading 'byte'. States run from 0
-- to patLen. From state k (k < patLen) the byte pat[k] leads to state
-- k + 1, so state k can be read as "the first k bytes of the pattern have
-- just been matched". Entries that are never written keep the initial
-- value 0, i.e. they lead back to state 0.
--
-- NOTE(review): the pattern must be non-empty. For an empty pattern,
-- 'patAt 0' and the lookups into 'bord' index out of bounds through
-- unchecked accessors. Callers presumably rule this out -- confirm.
{-# INLINE automaton #-}
automaton :: S.ByteString -> UArray Int Int
automaton !pat = runSTUArray (do
    let !patLen = S.length pat
        {-# INLINE patAt #-}
        patAt !i = fromIntegral (unsafeIndex pat i)
        !bord = kmpBorders pat
    aut <- newArray (0, (patLen + 1)*256 - 1) 0
    -- Row 0: only the first pattern byte leaves the start state.
    unsafeWrite aut (patAt 0) 1
    -- Fill rows 1 .. patLen in increasing order. For each row, walk the
    -- border chain j, bord!j, bord!(bord!j), ... until it drops below 0,
    -- setting the transition on pat[j] to j + 1 unless it is already set.
    -- Widths strictly decrease along the chain, so the first write for a
    -- given byte records the longest prefix that byte can extend.
    let loop !state = do
            let !base = state `shiftL` 8
                inner j
                    | j < 0     = if state == patLen
                                    then return aut
                                    else loop (state+1)
                    | otherwise = do
                        let !i = base + patAt j
                        s <- unsafeRead aut i
                        -- Every written value is j + 1 >= 1, so 0 means
                        -- "not set yet"; never overwrite a longer match.
                        when (s == 0) (unsafeWrite aut i (j+1))
                        inner (unsafeAt bord j)
            -- There is no pat[patLen], so the last row starts its chain at
            -- the widest border of the whole pattern instead of at itself.
            if state == patLen
                then inner (unsafeAt bord state)
                else inner state
    loop 1)

-- kmpBorders computes the (optimised) Knuth-Morris-Pratt failure table of
-- the pattern, indexed 0 .. patLen:
--
--   * entry 0 is -1;
--   * for 0 < i < patLen, entry i is the width b of the widest border of
--     the length-i prefix with pat[b] /= pat[i]. That is the widest border
--     that can NOT be extended to a border of the next longer prefix. The
--     entry is -1 if no border (not even the empty one) qualifies;
--   * entry patLen is the width of the widest proper border of the whole
--     pattern.
--
-- Every entry i is smaller than i, so following entries from any index
-- ends at -1. An empty pattern yields the one-element table [-1].
{-# INLINE kmpBorders #-}
kmpBorders :: S.ByteString -> UArray Int Int
kmpBorders pat = runSTUArray (do
    let !patLen = S.length pat
        {-# INLINE patAt #-}
        patAt :: Int -> Word8
        patAt i = unsafeIndex pat i
    ar <- newArray_ (0, patLen)
    unsafeWrite ar 0 (-1)
    -- dec w j: starting from a border of width j, follow the table until a
    -- border that the byte w extends is found (or the chain is exhausted at
    -- -1), and return the width of the extended border.
    let dec w j
            | j < 0 || w == patAt j = return $! j+1
            | otherwise = unsafeRead ar j >>= dec w
        -- bordLoop i j: j is the widest proper border of the length-(i-1)
        -- prefix (-1 before the first step). Extending it by pat[i-1] gives
        -- j', the widest proper border of the length-i prefix.
        bordLoop !i !j
            | patLen < i    = return ar
            | otherwise     = do
                let !w = patAt (i-1)
                j' <- dec w j
                -- If pat[j'] == pat[i], the border j' is extensible, so
                -- reuse the entry already computed for j'. The last entry
                -- (i == patLen) has no following byte and stores j' as is.
                if i < patLen && patAt j' == patAt i
                    then unsafeRead ar j' >>= unsafeWrite ar i
                    else unsafeWrite ar i j'
                bordLoop (i+1) j'
    bordLoop 1 (-1))

------------------------------------------------------------------------------
--                             Helper Functions                             --
------------------------------------------------------------------------------

-- strictify gathers all chunks of a lazy ByteString into one strict
-- ByteString.
{-# INLINE strictify #-}
strictify :: L.ByteString -> S.ByteString
strictify lazyStr = S.concat (L.toChunks lazyStr)

-- ldrop n chunks drops the first n bytes from a list of strict ByteStrings.
-- Chunks that end at or before byte n are discarded. The first chunk that
-- extends past byte n is trimmed with S.drop, and the chunks after it are
-- returned untouched. Dropping at least as many bytes as exist yields [].
{-# INLINE ldrop #-}
ldrop :: Int -> [S.ByteString] -> [S.ByteString]
ldrop = go
  where
    -- n counts the bytes that still have to be dropped.
    go _ [] = []
    go n (!chunk : rest) =
        let !len = S.length chunk
        in if n < len
               then S.drop n chunk : rest
               else go (n - len) rest

-- ltake n chunks keeps the first n bytes of a list of strict ByteStrings.
-- Chunks shorter than the remaining count are passed through whole. The
-- first chunk that is at least as long ends the result, cut with S.take.
-- Note that a count <= 0 on a non-empty list yields [S.empty], not [].
{-# INLINE ltake #-}
ltake :: Int -> [S.ByteString] -> [S.ByteString]
ltake = go
  where
    go _ [] = []
    go !n (!chunk : rest)
      | n <= len  = [S.take n chunk]
      | otherwise = chunk : go (n - len) rest
      where
        !len = S.length chunk

-- lsplit n chunks splits a list of strict ByteStrings after the first n
-- bytes, returning (front, back). Whole chunks go to the front while the
-- remaining count exceeds their length. A chunk of exactly the remaining
-- length goes to the front whole. Otherwise the chunk is cut in two with
-- S.take / S.drop, and the untouched rest of the list follows in back.
-- The front is built lazily, so its leading chunks are available before
-- the split point has been located.
{-# INLINE lsplit #-}
lsplit :: Int -> [S.ByteString] -> ([S.ByteString], [S.ByteString])
lsplit = go
  where
    go _ [] = ([], [])
    go !n (!chunk : rest)
      | n < len   = ([S.take n chunk], S.drop n chunk : rest)
      | n == len  = ([chunk], rest)
      | otherwise =
          -- lazy match: the front can be consumed before the recursion runs
          let (front, back) = go (n - len) rest
          in (chunk : front, back)
      where
        !len = S.length chunk


-- release is used to stop the zipper in lazySearcher from holding on to
-- the leading part of the searched string, which keeps consumption of that
-- string lazy and streaming. deep is the number of bytes the retained past
-- must cover. The result is the shortest prefix of the chunk list whose
-- total length is at least deep ([] when deep <= 0). The whole spine of
-- the result is forced before it is returned (presumably so that no thunk
-- keeps the remaining chunks alive). It is an error if the chunks run out
-- before deep bytes are covered.
{-# INLINE release #-}
release :: Int -> [S.ByteString] -> [S.ByteString]
release !deep chunks
    | deep <= 0 = []
    | otherwise =
        case chunks of
          [] -> error "stringsearch.release could not find enough past!"
          (!x : xs) ->
              let rest = release (deep - S.length x) xs
              in rest `seq` (x : rest)

-- keep is the counterpart of release for callers that must not forget the
-- past: the part no longer needed for matching still has to be kept for
-- breaking, splitting and replacing. keep deep chunks returns (front, back).
-- front is the shortest prefix of chunks covering at least deep bytes
-- ([] when deep < 1); back is the rest of the list. Unlike release, the
-- pair and its front are produced lazily. It is an error if the chunks run
-- out before deep bytes are covered; the error message names the function,
-- in the same format as release's.
-- (Historically the names ended up the wrong way round: "keep" would suit
-- release better and vice versa.)
{-# INLINE keep #-}
keep :: Int -> [S.ByteString] -> ([S.ByteString],[S.ByteString])
keep !deep xs
    | deep < 1    = ([],xs)
keep deep (!x:xs) = let (!p,d) = keep (deep - S.length x) xs in (x:p,d)
keep _ [] = error "stringsearch.keep could not find enough past!"