-- | Sorting, selection and partial sorting of mutable unboxed arrays
-- ('MUArr') with an in-place ternary heap, in which the node at relative
-- index @i@ has children @3i+1@, @3i+2@ and @3i+3@.
--
-- The heap keeps the greatest element, according to the supplied
-- 'Comparison', at its root. The @...ByBounds@ functions take a half-open
-- range: @l@ is the first index included and @u@ the first index excluded.
module Data.Array.Vector.Algorithms.TriHeap
(
-- * Sorting
sort
, sortBy
, sortByBounds
-- * Selection
, select
, selectBy
, selectByBounds
-- * Partial sorting
, partialSort
, partialSortBy
, partialSortByBounds
-- * Heap primitives
, heapify
, pop
, popTo
, sortHeap
-- * Types
, Comparison
) where
import Control.Monad
import Control.Monad.ST
import Data.Array.Vector
import Data.Array.Vector.Algorithms.Common
import qualified Data.Array.Vector.Algorithms.Optimal as O
-- | Sort a whole array in place into ascending order according to the
-- element type's 'Ord' instance.
sort :: (UA e, Ord e) => MUArr e s -> ST s ()
sort arr = sortByBounds compare arr 0 (lengthMU arr)
-- | Sort a whole array in place using a custom comparison.
sortBy :: (UA e) => Comparison e -> MUArr e s -> ST s ()
sortBy cmp arr = sortByBounds cmp arr start end
  where
    start = 0
    end   = lengthMU arr
-- | Sort the half-open slice @[l, u)@ of an array in place using a custom
-- comparison.
--
-- Slices of at most four elements go straight to the fixed-size sorts from
-- the Optimal module. Longer slices are heap sorted down to their first four
-- positions, which are then finished with the four-element sort.
sortByBounds :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> ST s ()
sortByBounds cmp a l u
  | len < 2   = return ()
  | len == 2  = O.sort2ByOffset cmp a l
  | len == 3  = O.sort3ByOffset cmp a l
  | len == 4  = O.sort4ByOffset cmp a l
  | otherwise = do
      heapify cmp a l u
      -- Move the maxima to the back until only the four smallest elements
      -- remain (unordered) in [l, l + 4) ...
      sortHeap cmp a l (l + 4) u
      -- ... and order those with the optimal four-element sort.
      O.sort4ByOffset cmp a l
  where
    len = u - l
-- | Move the @k@ smallest elements (by 'Ord') of the whole array to its
-- front, in no particular order.
select :: (UA e, Ord e) => MUArr e s -> Int -> ST s ()
select arr k = selectByBounds compare arr k 0 (lengthMU arr)
-- | Move the @k@ smallest elements, according to a custom comparison, of the
-- whole array to its front, in no particular order.
selectBy :: (UA e) => Comparison e -> MUArr e s -> Int -> ST s ()
selectBy cmp arr k = selectByBounds cmp arr k 0 end
  where
    end = lengthMU arr
-- | Move the @k@ smallest elements (according to @cmp@) of the half-open
-- slice @[l, u)@ into positions @[l, l + k)@, in no particular order.
--
-- Nothing is done when @k <= 0@ (previously a negative @k@ made the scan
-- read below index @l@) or when the slice holds fewer than @k@ elements.
selectByBounds :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> Int -> ST s ()
selectByBounds cmp a k l u
  | k <= 0     = return ()
  | l + k <= u = heapify cmp a l m >> go (u - 1)
  | otherwise  = return ()
  where
    -- [l, m) holds a max-heap of the k smallest elements seen so far.
    m = l + k
    -- Scan the remaining elements from the back. Any element smaller than
    -- the heap's root takes the root's place in the heap, and the root is
    -- written to the scanned slot.
    go i
      | i < m     = return ()
      | otherwise = do
          top <- readMU a l
          x   <- readMU a i
          case cmp x top of
            LT -> popTo cmp a l m i
            _  -> return ()
          go (i - 1)
-- | Sort the @k@ smallest elements (by 'Ord') of the whole array into
-- ascending order at its front.
partialSort :: (UA e, Ord e) => MUArr e s -> Int -> ST s ()
partialSort arr k = partialSortByBounds compare arr k 0 (lengthMU arr)
-- | Sort the @k@ smallest elements, according to a custom comparison, of the
-- whole array into ascending order at its front.
partialSortBy :: (UA e) => Comparison e -> MUArr e s -> Int -> ST s ()
partialSortBy cmp arr k = partialSortByBounds cmp arr k 0 end
  where
    end = lengthMU arr
-- | Sort the @k@ smallest elements (according to @cmp@) of the half-open
-- slice @[l, u)@ into ascending order in positions @[l, l + k)@. The rest
-- of the slice is left in an unspecified order.
--
-- Slices of at most four elements, and slices with no more than @k@
-- elements, are sorted in full. A non-positive @k@ on a longer slice does
-- nothing.
partialSortByBounds :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> Int -> ST s ()
partialSortByBounds cmp a k l u
  | len < 2    = return ()
  | len == 2   = O.sort2ByOffset cmp a l
  | len == 3   = O.sort3ByOffset cmp a l
  | len == 4   = O.sort4ByOffset cmp a l
  | u <= l + k = sortByBounds cmp a l u
  | k <= 0     = return ()
  | k <= 4     = do
      -- sortHeap with m = l + 4 needs a heap strictly longer than four
      -- elements. For a smaller selection it would swap the root with
      -- index l + 4, outside the selection, so sort the k selected
      -- elements directly instead.
      selectByBounds cmp a k l u
      sortByBounds cmp a l (l + k)
  | otherwise  = do
      -- [l, l + k) becomes a max-heap of the k smallest elements ...
      selectByBounds cmp a k l u
      -- ... whose largest k - 4 are moved, sorted, to [l + 4, l + k) ...
      sortHeap cmp a l (l + 4) (l + k)
      -- ... leaving the four smallest to the optimal sort.
      O.sort4ByOffset cmp a l
  where
    len = u - l
-- | Rearrange the half-open slice @[l, u)@ into a ternary max-heap
-- (according to @cmp@) rooted at index @l@. The node at relative index @i@
-- has children @3i+1@, @3i+2@ and @3i+3@.
heapify :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> ST s ()
heapify cmp a l u = loop lastParent
  where
    len = u - l
    -- Relative index of the last node with at least one child (its first
    -- child 3k+1 must be < len). Negative when the slice has < 2 elements.
    lastParent = (len - 2) `div` 3
    -- Sift every parent down, from the last one back to the root, so that
    -- each node is sifted only after all of its subtrees are heaps.
    loop k
      | k < 0     = return ()
      | otherwise = do
          e <- readMU a (l + k)
          siftByOffset cmp a e l k len
          loop (k - 1)
-- | Move the heap root at index @lo@ to index @hi@, and sift the element
-- previously at @hi@ down into the heap @[lo, hi)@. When the heap occupied
-- @[lo, hi]@, this shrinks it by one element (as 'sortHeap' uses it).
pop :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> ST s ()
pop cmp arr lo hi = popTo cmp arr lo hi hi
-- | Replace the root of the heap occupying @[l, u)@. The root (at @l@) is
-- written to index @t@, and the element previously at @t@ is sifted down
-- into the heap from its root.
--
-- With @t == u@ this pops the heap @[l, u]@ down to @[l, u)@ (see 'pop'). With
-- @t@ outside the heap it swaps the heap maximum for an outside element (as
-- 'selectByBounds' uses it).
popTo :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> Int -> ST s ()
popTo cmp a l u t = do
  root <- readMU a l
  x    <- readMU a t
  writeMU a t root
  siftByOffset cmp a x l 0 (u - l)
-- | Given a max-heap occupying @[l, u)@ and a split point @m@ with
-- @l <= m < u@, repeatedly move the heap maximum to the back. Afterwards
-- @[m, u)@ holds the @u - m@ largest elements in ascending order, and
-- @[l, m)@ holds the remaining ones in unspecified order.
--
-- Does nothing when @u <= m@. Previously that case swapped index @l@ with
-- index @m@, which lies outside the heap, corrupting the array.
sortHeap :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> Int -> ST s ()
sortHeap cmp a l m u
  | u <= m    = return ()
  | otherwise = loop (u - 1) >> swap a l m
  where
    -- Pop into positions u-1 down to m+1. The heap then spans [l, m], and
    -- the final swap moves its maximum into m without a re-sift, since the
    -- prefix need not remain a heap.
    loop k
      | m < k     = pop cmp a l k >> loop (k - 1)
      | otherwise = return ()
-- | Sift @val@ down a ternary max-heap of @len@ elements stored from index
-- @off@, starting at relative index @start@.
--
-- Along the path, the greatest child is moved up into the current hole
-- while it compares greater than @val@. Then @val@ is written into the
-- final hole.
siftByOffset :: (UA e) => Comparison e -> MUArr e s -> e -> Int -> Int -> Int -> ST s ()
siftByOffset cmp a val off start len = descend start
  where
    descend hole
      | firstKid >= len = place val
      | otherwise = do
          (kid :*: kidVal) <- maximumChild cmp a off firstKid len
          if cmp val kidVal == LT
            then place kidVal >> descend kid
            else place val
      where
        firstKid = 3 * hole + 1
        place = writeMU a (hole + off)
-- | Return the relative index and value of the greatest of the (up to
-- three) consecutive children beginning at relative index @child1@, in a
-- heap of @len@ elements stored from index @off@. On ties the earlier
-- child wins.
--
-- The first child is read unconditionally, so callers must ensure
-- @child1 < len@.
maximumChild :: (UA e) => Comparison e -> MUArr e s -> Int -> Int -> Int -> ST s (Int :*: e)
maximumChild cmp a off child1 len = do
    e1 <- entry child1
    if child1 + 1 >= len
      then return e1
      else do
        e2 <- entry (child1 + 1)
        let best12 = larger e1 e2
        if child1 + 2 >= len
          then return best12
          else do
            e3 <- entry (child1 + 2)
            return (larger best12 e3)
  where
    -- Pair a relative index with the element stored there.
    entry i = readMU a (i + off) >>= \x -> return (i :*: x)
    -- Keep the left candidate unless it is strictly smaller.
    larger p@(_ :*: x) q@(_ :*: y) = case cmp x y of
      LT -> q
      _  -> p