{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE CPP #-}
module Std.Data.Vector.Sort (
  
    mergeSort
  , mergeSortBy
  , mergeTileSize
  , insertSort
  , insertSortBy
  , Down(..)
  , radixSort
  , Radix(..)
  , RadixDown(..)
  ) where
import           Control.Monad.ST
import           Data.Bits
import           Data.Int
import           Data.Ord               (Down (..))
import           Data.Primitive         (sizeOf)
import           Data.Primitive.Types   (Prim (..))
import           Data.Word
import           Prelude                hiding (splitAt)
import           Std.Data.Array
import           Std.Data.Vector.Base
import           Std.Data.Vector.Extra
import           Std.Data.PrimArray.Cast
-- | Sort a vector into ascending order according to its 'Ord' instance.
--
-- This is @'mergeSortBy' 'compare'@; see that function for the algorithm.
-- Elements that compare 'EQ' keep their original relative order.
mergeSort :: forall v a. (Vec v a, Ord a) => v a -> v a
{-# INLINABLE mergeSort #-}
mergeSort vec = mergeSortBy compare vec
-- | Stable merge sort using a caller-supplied comparison.
--
-- Vectors of at most 'mergeTileSize' elements are handed to 'insertSortBy'.
-- Larger vectors are sorted bottom-up:
--
--   1. the input is cut into tiles of 'mergeTileSize' elements, and each tile
--      is insertion-sorted into the first work array;
--   2. adjacent sorted runs are merged pairwise, ping-ponging between two
--      work arrays and doubling the run length on every pass, until a single
--      run covers the whole vector.
--
-- On ties the element from the left run is emitted first, so elements that
-- compare 'EQ' keep their original relative order.
mergeSortBy :: forall v a. Vec v a => (a -> a -> Ordering) -> v a -> v a
{-# INLINE mergeSortBy #-}
mergeSortBy cmp v@(Vec _ _ l)
    | l <= mergeTileSize = insertSortBy cmp v
    | otherwise = runST (do
        -- Two scratch arrays of length l: every merge pass reads one of
        -- them and writes the other.
        w1 <- newArr l
        w2 <- newArr l
        firstPass v 0 w1
        w <- mergePass w1 w2 mergeTileSize
        return $! fromArr w 0 l)
  where
    -- Insertion-sort consecutive tiles of 'vec' into 'marr', the tile that
    -- starts at input position i landing at marr index i.
    firstPass !vec !i !marr
        | i >= l     = return ()
        | otherwise = do
            let (tile, rest) = splitAt mergeTileSize vec
            insertSortToMArr cmp tile i marr
            firstPass rest (i+mergeTileSize) marr
    -- Merge runs of length blockSiz from src into target, then swap the two
    -- arrays and double the run length; stop once one run spans all l
    -- elements and return the array that holds it.
    mergePass !src !target !blockSiz
        | blockSiz >= l = unsafeFreezeArr src
        | otherwise     = do
            mergeLoop src target blockSiz 0
            mergePass target src (blockSiz*2)
    -- One merge pass starting at index i. A trailing segment of at most
    -- blockSiz elements has no right-hand partner and is copied unchanged.
    mergeLoop !src !target !blockSiz !i
        | i >= l-blockSiz =
            if i >= l
            then return ()
            else copyMutableArr target i src i (l-i)
        | otherwise = do
            let !mergeEnd = min (i+blockSiz+blockSiz) l
            mergeBlock src target (i+blockSiz) mergeEnd i (i+blockSiz) i
            mergeLoop src target blockSiz mergeEnd
    -- Merge src[i .. leftEnd) with src[j .. rightEnd) into target starting
    -- at k. Both ranges are non-empty on entry (guaranteed by mergeLoop).
    -- Once one side runs out, the rest of the other is bulk-copied.
    mergeBlock !src !target !leftEnd !rightEnd !i !j !k = do
        leftVal  <- readArr src i
        rightVal <- readArr src j
        case rightVal `cmp` leftVal of
            -- Only a strictly smaller right element goes first; ties fall
            -- through to the left branch, which keeps the sort stable.
            LT -> do
                writeArr target k rightVal
                let !j' = j + 1
                    !k' = k + 1
                if j' >= rightEnd
                then copyMutableArr target k' src i (leftEnd - i)
                else mergeBlock src target leftEnd rightEnd i j' k'
            _ -> do
                writeArr target k leftVal
                let !i' = i + 1
                    !k' = k + 1
                if i' >= leftEnd
                then copyMutableArr target k' src j (rightEnd - j)
                else mergeBlock src target leftEnd rightEnd i' j k'
-- | Tile length used by 'mergeSortBy'.
--
-- Vectors of at most this many elements are insertion-sorted directly.
-- Larger vectors are first cut into tiles of this length, and each tile is
-- insertion-sorted before the merge passes begin.
mergeTileSize :: Int
{-# INLINE mergeTileSize #-}
mergeTileSize = 16
-- | Sort a vector into ascending order with insertion sort.
--
-- This is @'insertSortBy' 'compare'@. Each element may be shifted past
-- every earlier one, so the cost is quadratic in the worst case; it is
-- intended for short vectors.
insertSort :: (Vec v a, Ord a) => v a -> v a
{-# INLINE insertSort #-}
insertSort xs = insertSortBy compare xs
-- | Stable insertion sort using a caller-supplied comparison.
--
-- Empty and single-element vectors are rebuilt directly with 'empty' and
-- 'singleton'. Otherwise a fresh array of the same length is created and
-- filled by 'insertSortToMArr'.
insertSortBy :: Vec v a => (a -> a -> Ordering) -> v a -> v a
{-# INLINE insertSortBy #-}
insertSortBy _   (Vec _ _ 0)   = empty
insertSortBy _   (Vec arr s 1) = case indexArr' arr s of (# x #) -> singleton x
insertSortBy cmp v@(Vec _ _ l) = create l (insertSortToMArr cmp v 0)
-- | Insertion-sort the elements of a vector into a mutable array.
--
-- The input elements are written to @marr@ at indices
-- @[moff, moff + length)@, and that window is kept sorted as each element
-- arrives. Nothing outside the window is read or written. An element only
-- moves past a neighbour that compares strictly greater, so equal elements
-- keep their input order.
insertSortToMArr  :: Vec v a
                  => (a -> a -> Ordering)
                  -> v a            -- ^ elements to sort
                  -> Int            -- ^ start offset in the target array
                  -> MArray v s a   -- ^ target array
                  -> ST s ()
{-# INLINE insertSortToMArr #-}
insertSortToMArr cmp (Vec arr s l) moff marr = outer s
  where
    !stop  = s + l
    !delta = moff - s
    -- Feed the input elements, in order, to 'sink'.
    outer !ix
        | ix < stop = case indexArr' arr ix of
            (# elm #) -> do
                sink elm (ix + delta)
                outer (ix + 1)
        | otherwise = return ()
    -- Shift strictly greater neighbours one slot right until elm's place is
    -- found, then store elm there.
    sink !elm !pos
        | pos > moff = do
            prev <- readArr marr (pos - 1)
            if cmp elm prev == LT
                then writeArr marr pos prev >> sink elm (pos - 1)
                else writeArr marr pos elm
        | otherwise = writeArr marr moff elm
-- | Types that 'radixSort' can sort.
--
-- 'radixSort' performs @'passes' x@ counting passes, each distributing the
-- elements over @'bucketSize' x@ buckets.
--
-- * The first pass takes its digit from 'radixLSB'.
-- * When there is more than one pass, an intermediate pass @i@ (for
--   @0 < i < passes - 1@) uses @'radix' i@.
-- * The final pass uses 'radixMSB'.
--
-- Every digit is used as an index into the bucket counters, so it must lie
-- in @[0, bucketSize x)@. 'radixSort' calls 'passes' and 'bucketSize' on
-- 'undefined', so they must not inspect their argument.
class Radix a where
    -- | Number of counting passes 'radixSort' makes.
    passes :: a -> Int
    -- | Number of distinct digit values (buckets) per pass.
    bucketSize :: a -> Int
    -- | Digit for the first, least significant pass.
    radixLSB :: a -> Int
    -- | Digit for intermediate pass @i@.
    radix :: Int -> a -> Int
    -- | Digit for the last, most significant pass. Not consulted when
    -- 'passes' is 1.
    radixMSB :: a -> Int
-- Single pass over the one byte, with the sign bit flipped so that -128
-- maps to digit 0 and 127 maps to digit 255.
instance Radix Int8 where
    {-# INLINE bucketSize #-}
    bucketSize _ = 256
    {-# INLINE passes #-}
    passes _ = 1
    {-# INLINE radixLSB #-}
    radixLSB x = (fromIntegral x .&. 255) `xor` 128
    {-# INLINE radix #-}
    radix _ x = radixLSB x
    {-# INLINE radixMSB #-}
    radixMSB x = radixLSB x
-- Shared Radix method definitions for signed multi-byte types, expanded by
-- CPP inside an instance body. The expansion ends up on a single line,
-- which is why the methods are separated by explicit semicolons and why no
-- comments can appear between the continued lines below.
--
-- There is one pass per byte (sizeOf T), with 256 buckets per pass.
-- radixLSB takes byte 0 and radix i takes byte i. radixMSB first XORs with
-- minBound, which flips the sign bit, so negative values get the lower
-- top-byte digits.
#define MULTI_BYTES_INT_RADIX(T) \
    {-# INLINE bucketSize #-}; \
    bucketSize _ = 256; \
    {-# INLINE passes #-}; \
    passes _ = sizeOf (undefined :: T); \
    {-# INLINE radixLSB #-}; \
    radixLSB a = fromIntegral (255 .&. a); \
    {-# INLINE radix #-}; \
    radix i a = fromIntegral (a `unsafeShiftR` (i `unsafeShiftL` 3)) .&. 255; \
    {-# INLINE radixMSB #-}; \
    radixMSB a = fromIntegral ((a `xor` minBound) `unsafeShiftR` ((passes a-1) `unsafeShiftL` 3)) .&. 255
-- Signed multi-byte instances; their methods come from the signed macro
-- MULTI_BYTES_INT_RADIX defined above.
instance Radix Int where MULTI_BYTES_INT_RADIX(Int)
instance Radix Int16 where MULTI_BYTES_INT_RADIX(Int16)
instance Radix Int32 where MULTI_BYTES_INT_RADIX(Int32)
instance Radix Int64 where MULTI_BYTES_INT_RADIX(Int64)
-- A Word8 is its own single digit: one pass over 256 buckets, with no
-- transformation needed.
instance Radix Word8 where
    {-# INLINE bucketSize #-}
    bucketSize _ = 256
    {-# INLINE passes #-}
    passes _ = 1
    {-# INLINE radixLSB #-}
    radixLSB w = fromIntegral w
    {-# INLINE radix #-}
    radix _ w = radixLSB w
    {-# INLINE radixMSB #-}
    radixMSB w = radixLSB w
-- Unsigned counterpart of MULTI_BYTES_INT_RADIX, expanded by CPP inside an
-- instance body (hence the explicit semicolons and no inline comments).
-- It is identical to the signed macro except that radixMSB takes the top
-- byte as-is, with no sign-bit flip. For an unsigned T the raw top byte
-- already orders correctly.
#define MULTI_BYTES_WORD_RADIX(T) \
    {-# INLINE bucketSize #-}; \
    bucketSize _ = 256; \
    {-# INLINE passes #-}; \
    passes _ = sizeOf (undefined :: T); \
    {-# INLINE radixLSB #-}; \
    radixLSB a = fromIntegral (255 .&. a); \
    {-# INLINE radix #-}; \
    radix i a = fromIntegral (a `unsafeShiftR` (i `unsafeShiftL` 3)) .&. 255; \
    {-# INLINE radixMSB #-}; \
    radixMSB a = fromIntegral (a `unsafeShiftR` ((passes a-1) `unsafeShiftL` 3)) .&. 255
-- Unsigned multi-byte instances. They use the unsigned macro
-- MULTI_BYTES_WORD_RADIX, whose radixMSB takes the top byte directly.
-- No sign-bit flip is needed because minBound is 0 for these types.
instance Radix Word where MULTI_BYTES_WORD_RADIX(Word)
instance Radix Word16 where MULTI_BYTES_WORD_RADIX(Word16)
instance Radix Word32 where MULTI_BYTES_WORD_RADIX(Word32)
instance Radix Word64 where MULTI_BYTES_WORD_RADIX(Word64)
-- | Wrapper that makes 'radixSort' produce descending order.
--
-- Every digit of the wrapped value is mirrored within
-- @[0, bucketSize)@, reversing the order at every pass.
newtype RadixDown a = RadixDown a
    deriving (Show, Eq, Prim)

instance Radix a => Radix (RadixDown a) where
    {-# INLINE bucketSize #-}
    bucketSize (RadixDown x) = bucketSize x
    {-# INLINE passes #-}
    passes (RadixDown x) = passes x
    {-# INLINE radixLSB #-}
    radixLSB (RadixDown x) = bucketSize x - 1 - radixLSB x
    {-# INLINE radix #-}
    radix pass (RadixDown x) = bucketSize x - 1 - radix pass x
    {-# INLINE radixMSB #-}
    radixMSB (RadixDown x) = bucketSize x - 1 - radixMSB x
-- | Sort a vector with an LSD (least-significant-digit-first) radix sort.
--
-- The sort makes @'passes' (undefined :: a)@ counting passes, each with
-- @'bucketSize' (undefined :: a)@ buckets.
--
-- * The first pass reads the input vector and scatters it into a work
--   array, using 'radixLSB'.
-- * Further passes ping-pong between two work arrays, using @'radix' i@,
--   and the final pass uses 'radixMSB'.
--
-- Every pass scans its source front to back, so elements with equal digits
-- keep their relative order and the sort is stable. Wrap elements in
-- 'RadixDown' to sort in descending order.
radixSort :: forall v a. (Vec v a, Radix a) => v a -> v a
{-# INLINABLE radixSort #-}
radixSort (Vec _ _ 0) = empty
radixSort (Vec arr s 1) = case indexArr' arr s of (# x #) -> singleton x
radixSort (Vec arr s l) = runST (do
        -- Bucket counters. After accumBucket, bucket[r] holds the next
        -- output index for digit r.
        bucket <- newArrWith buktSiz 0 :: ST s (MutablePrimArray s Int)
        w1 <- newArr l
        firstCountPass bucket s
        accumBucket bucket 0 0
        firstMovePass s bucket w1
        w <- if passSiz == 1
            then unsafeFreezeArr w1
            else do
                w2 <- newArr l
                radixLoop w1 w2 bucket 1
        return $! fromArr w 0 l)
  where
    passSiz = passes (undefined :: a)
    buktSiz = bucketSize (undefined :: a)
    !end = s + l
    -- Histogram of first-pass digits (radixLSB) over arr[s .. end).
    {-# INLINABLE firstCountPass #-}
    firstCountPass !bucket !i
        | i >= end  = return ()
        | otherwise = case indexArr' arr i of
            (# x #) -> do
                let !r = radixLSB x
                c <- readArr bucket r
                writeArr bucket r (c+1)
                firstCountPass bucket (i+1)
    -- Replace the counts with exclusive prefix sums, so each bucket holds
    -- the starting output index for its digit.
    {-# INLINABLE accumBucket #-}
    accumBucket !bucket !i !acc
        | i >= buktSiz = return ()
        | otherwise = do
            c <- readArr bucket i
            writeArr bucket i acc
            accumBucket bucket (i+1) (acc+c)
    -- Scatter arr[s .. end) into w according to the first-pass digit.
    {-# INLINABLE firstMovePass #-}
    firstMovePass !i !bucket !w
        | i >= end  = return ()
        | otherwise = case indexArr' arr i of
            (# x #) -> do
                let !r = radixLSB x
                c <- readArr bucket r
                writeArr bucket r (c+1)
                writeArr w c x
                firstMovePass (i+1) bucket w
    -- Passes 1 .. passSiz-1. Each pass reads w1 and writes w2, and the two
    -- arrays swap roles between passes. The last pass uses radixMSB and
    -- returns the frozen result.
    {-# INLINABLE radixLoop #-}
    radixLoop !w1 !w2 !bucket !pass
        | pass >= passSiz-1 = do
            setArr bucket 0 buktSiz 0
            lastCountPass w1 bucket 0
            accumBucket bucket 0 0
            lastMovePass w1 bucket w2 0
            unsafeFreezeArr w2
        | otherwise = do
            setArr bucket 0 buktSiz 0
            countPass w1 bucket pass 0
            accumBucket bucket 0 0
            movePass w1 bucket pass w2 0
            radixLoop w2 w1 bucket (pass+1)
    {-# INLINABLE countPass #-}
    countPass !marr !bucket !pass !i
        | i >= l  = return ()
        | otherwise = do
                x <- readArr marr i
                let !r = radix pass x
                c <- readArr bucket r
                writeArr bucket r (c+1)
                countPass marr bucket pass (i+1)
    {-# INLINABLE movePass #-}
    movePass !src !bucket !pass !target !i
        | i >= l  = return ()
        | otherwise = do
                x <- readArr src i
                let !r = radix pass x
                c <- readArr bucket r
                writeArr bucket r (c+1)
                writeArr target c x
                movePass src bucket pass target (i+1)
    {-# INLINABLE lastCountPass #-}
    lastCountPass !marr !bucket !i
        | i >= l  = return ()
        | otherwise = do
                x <- readArr marr i
                let !r = radixMSB x
                c <- readArr bucket r
                writeArr bucket r (c+1)
                lastCountPass marr bucket (i+1)
    {-# INLINABLE lastMovePass #-}
    lastMovePass !src !bucket !target !i
        | i >= l  = return ()
        | otherwise = do
                x <- readArr src i
                let !r = radixMSB x
                c <- readArr bucket r
                writeArr bucket r (c+1)
                writeArr target c x
                lastMovePass src bucket target (i+1)