{-# options_ghc -Wno-incomplete-uni-patterns #-}

-- | This module implements a generalized version of the SMAWK algorithm
-- for computing row minima in totally monotone matrices.
--
-- Rows and columns need not be numbers, or even ordered: columns are compared
-- by their position in the input (occurrence order).
--
-- Unlike @Map@-based implementations, this runs in linear time.
module Data.Smawk
  ( smawk
  , smawk1
  ) where

import Control.Monad.Trans.State.Strict
import qualified Data.Foldable as Foldable
import Data.Maybe (fromJust)
import Data.Semigroup (Min(..),Arg(..))
import Data.Semigroup.Foldable (Foldable1)
import Data.List.NonEmpty (nonEmpty)
import Data.Primitive.Array (indexArray)
import GHC.Exts as Exts

-- | Split a list into the elements at even positions (0, 2, 4, ...) and the
-- elements at odd positions (1, 3, 5, ...), each in their original order.
--
-- The irrefutable pattern keeps the fold lazy, so this also works on
-- infinite lists.
--
-- >>> collate "abcde"
-- ("ace","bd")
collate :: [a] -> ([a],[a])
collate = Prelude.foldr (\a ~(evens, odds) -> (a : odds, evens)) ([], [])

-- | Merge two lists by alternating their elements, starting with the first
-- list. Once either list is exhausted, the remainder of the other is appended.
-- This is the inverse of 'collate':
--
-- @'uncurry' 'interleave' . 'collate' = id@
--
-- >>> interleave "ace" "bd"
-- "abcde"
interleave :: [a] -> [a] -> [a]
interleave (a:as) bs = a : interleave bs as
interleave [] bs = bs

-- |
-- /O(|rows| + |cols|)/.
--
-- Computes __row__ minima in totally monotone matrices using the SMAWK algorithm.
--
-- Returns 'Nothing' if there are no columns; otherwise this is 'smawk1'
-- applied to the (non-empty) list of columns.
smawk
  :: (Traversable f, Foldable g, Ord a)
  => f r -- ^ rows (in any desired ascending order)
  -> g c -- ^ columns (in any desired ascending order)
  -> (r -> c -> a) -- ^ a monotone matrix
  -> Maybe (f c) -- ^ for each row, a column at which that row attains its minimum
smawk rs cs0 m = (\cs -> smawk1 rs cs m) <$> nonEmpty (Foldable.toList cs0)
{-# inline smawk #-}

-- |
-- /O(|rows| + |cols|)/.
--
-- Computes __row__ minima in totally monotone matrices using the SMAWK algorithm.
--
-- Columns are identified internally by their position in the input, so
-- neither rows nor columns need an 'Ord' instance; only the matrix entries do.
smawk1
  :: (Traversable f, Foldable1 g, Ord a)
  => f r -- ^ rows (in any desired ascending order)
  -> g c -- ^ columns (in any desired ascending order)
  -> (r -> c -> a) -- ^ a monotone matrix
  -> f c -- ^ for each row, a column at which that row attains its minimum
smawk1 rs0 cs0 m =
    evalState (traverse refill rs0) $
      go (Foldable.toList rs0) [0 .. length raws - 1]
  where
    -- The columns, indexable in O(1) by their position in cs0.
    raws = Exts.fromList $ Foldable.toList cs0

    -- Replace each row, in traversal order, with the column whose index 'go'
    -- produced for it. 'go' yields exactly one index per row, so the lazy
    -- pattern always matches.
    refill _ = state $ \ ~(x:xs) -> (indexArray raws x, xs)

    -- @go rs cs@: for each row of @rs@ (in order), the index into 'raws' of a
    -- minimum of that row over the ascending, non-empty column indices @cs@.
    go [] _ = []
    go rs cs = interleave broken minima
      where
        m' r c = m r (indexArray raws c)

        -- Rows as an array so REDUCE can look up row @t-1@ in O(1).
        -- (A zipper yields the same complexity, with worse constants.)
        rs' = Exts.fromList rs
        n = length rs'

        -- REDUCE: shrink the candidate columns to at most @n@.
        -- The second argument is a stack (top first) of surviving columns and
        -- @t@ is its height; the top of the stack is tested on row @t-1@.
        reduce [] ys _ = ys
        reduce (x:xs) [] _ = reduce xs [x] 1
        reduce xxs@(x:xs) yys@(y:ys) t
          -- The incoming column x is strictly better than the stack top y on
          -- row t-1: pop y and retry x against the new top.
          | ri <- indexArray rs' (t - 1), m' ri y > m' ri x = reduce xxs ys (t - 1)
          -- Otherwise push x, unless the stack already holds n columns,
          -- in which case x is discarded.
          | t /= n = reduce xs (x : yys) (t + 1)
          | otherwise = reduce xs yys t

        rcs' = reduce cs [] 0
        -- Surviving columns in ascending order.
        cs' = reverse rcs'
        -- Largest surviving column. cs is non-empty, and reduce always keeps
        -- at least one column, so rcs' is non-empty.
        l : _ = rcs'

        -- Solve the odd-position rows recursively on the reduced columns.
        (es, os) = collate rs
        minima = go os cs'

        -- Each even-position row only needs the columns between the minima of
        -- its neighbouring odd rows (the last is bounded above by l).
        broken = zipWith skim es $ path cs' (minima <> repeat l)

        -- Index of the smallest entry of row @a@ over column indices @bs@;
        -- on ties 'Arg''s 'min' keeps the earlier index. @bs@ is never empty
        -- here, because each range produced by 'path' starts at or below its bound.
        skim a bs =
          case getMin $ fromJust $ foldMap (\i -> Just $ Min $ Arg (m' a i) i) bs of
            Arg _ i -> i

        -- Cut ascending @xs@ into consecutive ranges, each ending at the next
        -- bound @y@ (inclusive); that bound also starts the following range.
        path xs ~(y:ys) = case span (<= y) xs of
          (as, bs) -> as : path (y : bs) ys
{-# inlinable smawk1 #-}