{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}

module Text.EditDistance.Linear where

import qualified Data.Array.Base as A(unsafeRead, unsafeWrite)
import qualified Data.Array.ST as A
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Unsafe as BU
import Control.Monad.ST

-- | Levenshtein edit distance between two strict 'BS.ByteString's.
--
-- Bytes are compared one at a time, and an insertion, a deletion and a
-- substitution each cost 1.  Only two rows of the dynamic-programming table
-- are kept, each holding @'BS.length' s2 + 1@ entries.
levenshteinDistance :: BS.ByteString -> BS.ByteString -> Int
levenshteinDistance s1 s2 = runST $ do
  -- Row 0 of the table: reaching the first k bytes of s2 from the empty
  -- prefix of s1 costs k.
  firstRow <- A.newListArray (0, len2) [0 ..]
  -- Uninitialised scratch row; fillRows writes every slot before reading it.
  secondRow <- A.newArray_ (0, len2)
  fillRows 0 firstRow secondRow
  -- The two buffers trade roles after every row, so the parity of len1
  -- decides which of them received the final row.
  let lastRow
        | even len1 = firstRow
        | otherwise = secondRow
  A.unsafeRead lastRow len2
  where
    len1 = BS.length s1
    len2 = BS.length s2

    -- Fills table rows (row + 1) .. len1.  'above' holds the finished row
    -- number 'row'; 'current' is overwritten with row number (row + 1).
    fillRows :: Int -> A.STUArray s Int Int -> A.STUArray s Int Int -> ST s ()
    fillRows !row !above !current
      | row < len1 = do
          let !rowBelow = row + 1
          let !byte1 = BU.unsafeIndex s1 row
          -- Walks columns 1 .. len2; 'left' is the value of the cell that
          -- was written just before this one in the current row.
          let fillCells !col !left
                | col > len2 = pure ()
                | otherwise = do
                    up <- A.unsafeRead above col
                    diag <- A.unsafeRead above (col - 1)
                    let !mismatch = fromEnum (byte1 /= BU.unsafeIndex s2 (col - 1))
                    let !best = min (diag + mismatch) (min (up + 1) (left + 1))
                    A.unsafeWrite current col best
                    fillCells (col + 1) best
          -- Column 0: the first (row + 1) bytes of s1 against the empty
          -- prefix of s2.
          A.unsafeWrite current 0 rowBelow
          fillCells 1 rowBelow
          fillRows rowBelow current above
      | otherwise = pure ()