module Data.Vector.Algorithms.Intro
       ( 
         sort
       , sortBy
       , sortByBounds 
         
       , select
       , selectBy
       , selectByBounds
         
       , partialSort
       , partialSortBy
       , partialSortByBounds
       , Comparison
       ) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import Data.Bits
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison)
import qualified Data.Vector.Algorithms.Insertion as I
import qualified Data.Vector.Algorithms.Optimal   as O
import qualified Data.Vector.Algorithms.TriHeap   as H
-- | Sorts the whole mutable vector in ascending order, using the element
-- type's 'Ord' instance as the comparison.
sort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> m ()
sort vec = sortBy compare vec
-- | Sorts the whole mutable vector according to the supplied comparison.
-- Equivalent to 'sortByBounds' over the index range @[0, length vec)@.
sortBy :: (PrimMonad m, MVector v e) => Comparison e -> v (PrimState m) e -> m ()
sortBy cmp vec = sortByBounds cmp vec 0 n
  where
    n = length vec
-- | Sorts the half-open index range @[l, u)@ of a mutable vector with the
-- given comparison.
--
-- Ranges of length 2, 3 and 4 go straight to the fixed-size sorters from
-- the @Optimal@ module. Longer ranges use 'introsort' with a recursion
-- budget of @'ilg' len@. Ranges shorter than 2 (including empty or inverted
-- bounds) are left untouched.
sortByBounds :: (PrimMonad m, MVector v e)
             => Comparison e -> v (PrimState m) e -> Int -> Int -> m ()
sortByBounds cmp a l u
  | len < 2   = return ()
  | len == 2  = O.sort2ByOffset cmp a l
  | len == 3  = O.sort3ByOffset cmp a l
  | len == 4  = O.sort4ByOffset cmp a l
  | otherwise = introsort cmp a (ilg len) l u
  where
    -- Number of elements in the half-open range [l, u).
    len = u - l
-- | @introsort cmp a depth l u@ sorts the half-open range @[l, u)@.
--
-- The range is partitioned quicksort-style, recursing on both sides. Each
-- level consumes one unit of @depth@. When the budget reaches 0, the current
-- subrange is handed to the TriHeap sort. Subranges shorter than 'threshold'
-- are not partitioned further. They are finished by the single
-- insertion-sort pass over all of @[l, u)@ that runs afterwards.
introsort :: (PrimMonad m, MVector v e)
          => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
introsort cmp a i l u = go i l u >> I.sortByBounds cmp a l u
  where
    -- Depth budget exhausted: fall back to the heap-based sort for this
    -- subrange, bounding the worst case.
    go 0 lo hi = H.sortByBounds cmp a lo hi
    go d lo hi
      | hi - lo < threshold = return ()
      | otherwise = do
          -- NOTE(review): sort3ByIndex is assumed to order the elements at
          -- c, lo and hi-1 so that the median of the three lands at lo.
          -- It is read back as the pivot on the next line.
          O.sort3ByIndex cmp a c lo (hi - 1)
          p <- unsafeRead a lo
          mid <- partitionBy cmp a p (lo + 1) hi
          -- Put the pivot between the two partitions. mid - 1 is the last
          -- slot of the "<= pivot" side.
          unsafeSwap a lo (mid - 1)
          go (d - 1) mid hi
          go (d - 1) lo (mid - 1)
      where
        c = (lo + hi) `div` 2
-- | Rearranges the whole vector so that its @k@ least elements, according
-- to the element type's 'Ord' instance, occupy the first @k@ positions.
-- See 'selectBy'.
select :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
select vec k = selectBy compare vec k
-- | Rearranges the whole vector so that its @k@ least elements, according
-- to the supplied comparison, occupy the first @k@ positions. Equivalent to
-- 'selectByBounds' over the range @[0, length vec)@.
selectBy :: (PrimMonad m, MVector v e)
         => Comparison e -> v (PrimState m) e -> Int -> m ()
selectBy cmp vec k = selectByBounds cmp vec k 0 n
  where
    n = length vec
-- | @selectByBounds cmp a k l u@ rearranges the half-open range @[l, u)@ so
-- that its @k@ least elements (by @cmp@) occupy @[l, l + k)@. Here @k@ is a
-- count relative to @l@. This code does not sort the elements inside the
-- prefix.
--
-- Each partition step recurses only into the side containing the boundary
-- @l + k@. When the depth budget from 'ilg' runs out, the current subrange
-- is handed to the TriHeap selection.
--
-- An empty range, or a @k@ outside @[0, u - l]@, is a no-op. Previously
-- those inputs led to reads outside the range: with an empty range the first
-- partition read index @u - 1@, and an out-of-range @k@ drove the recursion
-- into an empty subrange.
selectByBounds :: (PrimMonad m, MVector v e)
               => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
selectByBounds cmp a k l u
  | len <= 0 || k < 0 || k > len = return ()
  | otherwise                    = go (ilg len) l (l + k) u
  where
    len = u - l
    -- pos is the absolute index that must end up as the boundary between the
    -- selected prefix and the rest; [lo, hi) is the current subrange.
    -- NOTE(review): H.selectByBounds is assumed to take a count relative to
    -- lo, matching the (pos - lo) passed here.
    go 0 lo pos hi = H.selectByBounds cmp a (pos - lo) lo hi
    go n lo pos hi = do
      O.sort3ByIndex cmp a c lo (hi - 1)
      p <- unsafeRead a lo
      mid <- partitionBy cmp a p (lo + 1) hi
      unsafeSwap a lo (mid - 1)
      -- [lo, mid - 1) <= pivot, pivot at mid - 1, [mid, hi) >= pivot.
      -- If pos is mid or mid - 1, the boundary already falls on the pivot.
      if pos > mid
        then go (n - 1) mid pos hi
        else when (pos < mid - 1) $ go (n - 1) lo pos (mid - 1)
      where
        c = (lo + hi) `div` 2
-- | Sorts the @k@ least elements of the whole vector, according to the
-- element type's 'Ord' instance, into its first @k@ positions.
-- See 'partialSortBy'.
partialSort :: (PrimMonad m, MVector v e, Ord e) => v (PrimState m) e -> Int -> m ()
partialSort vec k = partialSortBy compare vec k
-- | Sorts the @k@ least elements of the whole vector, according to the
-- supplied comparison, into its first @k@ positions. Equivalent to
-- 'partialSortByBounds' over the range @[0, length vec)@.
partialSortBy :: (PrimMonad m, MVector v e)
              => Comparison e -> v (PrimState m) e -> Int -> m ()
partialSortBy cmp vec k = partialSortByBounds cmp vec k 0 n
  where
    n = length vec
-- | @partialSortByBounds cmp a k l u@ sorts the @k@ least elements (by
-- @cmp@) of the half-open range @[l, u)@ into @[l, l + k)@. Here @k@ is a
-- count relative to @l@.
--
-- Each step partitions the current subrange.
--
-- * Left side entirely inside the prefix: it is fully sorted with
--   'introsort' and the search continues to the right.
-- * Prefix ends exactly at the partition point: only the left side is
--   sorted.
-- * Otherwise: the search narrows to the left side.
--
-- Every step consumes one unit of the 'ilg' depth budget. When the budget
-- is exhausted, the current subrange is finished by the TriHeap partial
-- sort.
--
-- An empty range or @k <= 0@ is a no-op. A @k@ larger than the range is
-- clamped, so the whole range is sorted.
partialSortByBounds :: (PrimMonad m, MVector v e)
                    => Comparison e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
partialSortByBounds cmp a k l u
  | len <= 0 || k <= 0 = return ()
  | otherwise          = go (ilg len) l (l + min k len) u
  where
    len = u - l
    -- pos is the absolute end of the prefix to sort; [lo, hi) is the
    -- current subrange. Invariant: pos <= hi, and every element of
    -- [hi, u) is >= every element of [lo, hi). So finishing [lo, hi) alone
    -- is enough.
    -- NOTE(review): H.partialSortByBounds is assumed to take a count relative
    -- to lo, matching the (pos - lo) passed here.
    go 0 lo pos hi = H.partialSortByBounds cmp a (pos - lo) lo hi
    go n lo pos hi
      | lo == pos = return ()
      | otherwise = do
          O.sort3ByIndex cmp a c lo (hi - 1)
          p <- unsafeRead a lo
          mid <- partitionBy cmp a p (lo + 1) hi
          unsafeSwap a lo (mid - 1)
          -- [lo, mid - 1) <= pivot, pivot at mid - 1, [mid, hi) >= pivot.
          case compare pos mid of
            GT -> do introsort cmp a (n - 1) lo (mid - 1)
                     go (n - 1) mid pos hi
            EQ -> introsort cmp a (n - 1) lo pos
            -- Decrement here too, so a run of left-narrowing steps still
            -- reaches the heap fallback.
            LT -> go (n - 1) lo pos (mid - 1)
      where
        c = (lo + hi) `div` 2
-- | @partitionBy cmp a p l u@ partitions the half-open range @[l, u)@ around
-- the pivot value @p@. It returns an index @mid@ such that every element of
-- @[l, mid)@ compares @<= p@ and every element of @[mid, u)@ compares
-- @>= p@.
--
-- The scan runs inward from both ends. It swaps a left element that is not
-- @< p@ with a right element that is not @> p@.
partitionBy :: forall m v e. (PrimMonad m, MVector v e)
            => Comparison e -> v (PrimState m) e -> e -> Int -> Int -> m Int
partitionBy cmp a = partUp
  where
    -- The local signatures refer to the outer m and e. That needs
    -- ScopedTypeVariables together with the explicit forall above.
    -- NOTE(review): confirm the extension is enabled, either via GHC2021 or
    -- a LANGUAGE pragma outside this view.

    -- Scan upward from l while elements are strictly below the pivot.
    -- Here u is an exclusive bound.
    partUp :: e -> Int -> Int -> m Int
    partUp p l u
      | l < u = do
          e <- unsafeRead a l
          case cmp e p of
            LT -> partUp p (l + 1) u
            _  -> partDown p l (u - 1)
      | otherwise = return l

    -- Scan downward from u while elements are strictly above the pivot.
    -- Here u is an inclusive index. On a mismatch, swap with the element
    -- stuck at l and resume the upward scan.
    partDown :: e -> Int -> Int -> m Int
    partDown p l u
      | l < u = do
          e <- unsafeRead a u
          case cmp p e of
            LT -> partDown p l (u - 1)
            _  -> unsafeSwap a l u >> partUp p (l + 1) u
      | otherwise = return l
-- | Recursion-depth budget for the introsort/introselect loops:
-- @2 * floor (logBase 2 m)@ for positive @m@.
--
-- Non-positive inputs yield 0. The previous shift loop returned -2 for 0,
-- which meant a depth counter counting down from it never reached the
-- fallback case. For negative arguments the loop never terminated, because
-- an arithmetic right shift of a negative number sticks at -1.
ilg :: Int -> Int
ilg m
  | m <= 0    = 0
  | otherwise = 2 * (finiteBitSize m - 1 - countLeadingZeros m)
-- | Subranges shorter than this are not partitioned further by
-- 'introsort'. They are left in place for the insertion-sort pass that
-- 'introsort' runs over its whole range once partitioning is done.
threshold :: Int
threshold = 18