{-# LANGUAGE ScopedTypeVariables, BangPatterns, TypeOperators #-}

-- ---------------------------------------------------------------------------
-- |
-- Copyright   : (c) 2008-2011 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Non-portable (scoped type variables, bang patterns)
--
-- This module provides a radix sort for a subclass of unboxed arrays. The
-- radix class gives information on
--   * the number of passes needed for the data type
--
--   * the size of the auxiliary arrays
--
--   * how to compute the pass-k radix of a value
--
-- Radix sort is not a comparison sort, so it is able to achieve O(n) run
-- time, though it also uses O(n) auxiliary space. In addition, there is a
-- constant space overhead of 2*size*sizeOf(Int) for the sort, so it is not
-- advisable to use this sort for large numbers of very small arrays.
--
-- A standard example (upon which one could base their own Radix instance)
-- is Word32:
--
--   * We choose to sort on r = 8 bits at a time
--
--   * A Word32 has b = 32 bits total
--
--   Thus, b/r = 4 passes are required, 2^r = 256 elements are needed in an
--   auxiliary array, and the radix function is:
--
--    > radix k e = (e `shiftR` (k*8)) .&. 256

import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable

import Data.Vector.Algorithms.Common

import Data.Bits
import Data.Int
import Data.Word

import Foreign.Storable

-- | The number of passes necessary to sort an array of es
passes :: e -> Int
-- | The size of an auxiliary array
size   :: e -> Int
-- | The radix function parameterized by the current pass
radix  :: Int -> e -> Int

passes _ = sizeOf (undefined :: Int)
{-# INLINE passes #-}
size _ = 256
{-# INLINE size #-}
radix 0 e = e .&. 255
| i == passes e - 1 = radix' (e `xor` minBound)
where radix' e = (e `shiftR` (i `shiftL` 3)) .&. 255

passes _ = 1
{-# INLINE passes #-}
size _ = 256
{-# INLINE size #-}
radix _ e = 255 .&. fromIntegral e `xor` 128

passes _ = 2
{-# INLINE passes #-}
size _ = 256
{-# INLINE size #-}
radix 0 e = fromIntegral (e .&. 255)
radix 1 e = fromIntegral (((e `xor` minBound) `shiftR` 8) .&. 255)

passes _ = 4
{-# INLINE passes #-}
size _ = 256
{-# INLINE size #-}
radix 0 e = fromIntegral (e .&. 255)
radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
radix 3 e = fromIntegral (((e `xor` minBound) `shiftR` 24) .&. 255)

passes _ = 8
{-# INLINE passes #-}
size _ = 256
{-# INLINE size #-}
radix 0 e = fromIntegral (e .&. 255)
radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
radix 7 e = fromIntegral (((e `xor` minBound) `shiftR` 56) .&. 255)

passes _ = sizeOf (undefined :: Word)
{-# INLINE passes #-}
size _ = 256
{-# INLINE size #-}
radix 0 e = fromIntegral (e .&. 255)
radix i e = fromIntegral ((e `shiftR` (i `shiftL` 3)) .&. 255)

passes _ = 1
{-# INLINE passes #-}
size _ = 256
{-# INLINE size #-}

passes _ = 2
{-# INLINE passes #-}
size   _ = 256
{-# INLINE size #-}
radix 0 e = fromIntegral (e .&. 255)
radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)

passes _ = 4
{-# INLINE passes #-}
size   _ = 256
{-# INLINE size #-}
radix 0 e = fromIntegral (e .&. 255)
radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)

passes _ = 8
{-# INLINE passes #-}
size   _ = 256
{-# INLINE size #-}
radix 0 e = fromIntegral (e .&. 255)
radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
radix 7 e = fromIntegral ((e `shiftR` 56) .&. 255)

passes ~(i, j) = passes i + passes j
{-# INLINE passes #-}
size   ~(i, j) = size i `max` size j
{-# INLINE size #-}
radix k ~(i, j) | k < passes j = radix k j
| otherwise    = radix (k - passes j) i

-- | Sorts an array based on the Radix instance.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
=> v (PrimState m) e -> m ()
sort arr = sortBy (passes e) (size e) radix arr
where
e :: e
e = undefined
{-# INLINE sort #-}

-- requires the number of passes to fully sort the array,
-- the size of of auxiliary arrays necessary (should be
-- one greater than the maximum value returned by the radix
-- function), and a radix function, which takes the pass
-- and an element, and returns the relevant radix.
sortBy :: (PrimMonad m, MVector v e)
=> Int               -- ^ the number of passes
-> Int               -- ^ the size of auxiliary arrays
-> (Int -> e -> Int) -- ^ the radix function
-> v (PrimState m) e -- ^ the array to be sorted
-> m ()
sortBy passes size rdx arr = do
tmp    <- new (length arr)
count  <- new size
radixLoop passes rdx arr tmp count
{-# INLINE sortBy #-}

=> Int                          -- passes
-> (Int -> e -> Int)            -- radix function
-> v (PrimState m) e            -- array to sort
-> v (PrimState m) e            -- temporary array
-> PV.MVector (PrimState m) Int -- radix count array
-> m ()
radixLoop passes rdx src dst count = go False 0
where
len = length src
go swap k
| k < passes = if swap
then body rdx dst src count k >> go (not swap) (k+1)
else body rdx src dst count k >> go (not swap) (k+1)
| otherwise  = when swap (unsafeCopy src dst)

body :: (PrimMonad m, MVector v e)
=> (Int -> e -> Int)            -- radix function
-> v (PrimState m) e            -- source array
-> v (PrimState m) e            -- destination array
-> PV.MVector (PrimState m) Int -- radix count
-> Int                          -- current pass
-> m ()
body rdx src dst count k = do
countLoop (rdx k) src count
accumulate count
moveLoop k rdx src dst count
{-# INLINE body #-}

=> PV.MVector (PrimState m) Int -> m ()
accumulate count = go 0 0
where
len = length count
go i acc
| i < len   = do ci <- unsafeRead count i
unsafeWrite count i acc
go (i+1) (acc + ci)
| otherwise = return ()
{-# INLINE accumulate #-}

moveLoop :: (PrimMonad m, MVector v e)
=> Int -> (Int -> e -> Int) -> v (PrimState m) e
-> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = go 0
where
len = length src
go i
| i < len    = do srci <- unsafeRead src i
pf   <- inc prefix (rdx k srci)
unsafeWrite dst pf srci
go (i+1)
| otherwise  = return ()
{-# INLINE moveLoop #-}