module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common (Comparison)
import Data.Bits
import Data.Int
import Data.Word
import Foreign.Storable
-- | Types whose values can be broken into a sequence of small integer keys,
-- one key per sorting pass.
class Radix a where
  -- | Number of passes required for this type. 'sort' calls this with
  -- 'undefined', so the argument must only be used for its type.
  passes :: a -> Int

  -- | Length of the auxiliary counting arrays. 'sort' calls this with
  -- 'undefined' as well. Every result of 'radix' is used as an unchecked
  -- index into arrays of this length.
  size :: a -> Int

  -- | @radix k x@ is the bucket index of @x@ during pass number @k@
  -- (passes are numbered from 0).
  radix :: Int -> a -> Int
-- | 'Int' is processed one byte per pass, least significant byte first.
-- The number of passes is the byte width of 'Int' on the current platform.
instance Radix Int where
  passes _ = sizeOf (undefined :: Int)
  size _ = 256
  radix 0 e = e .&. 255
  radix i e
    -- On the final (most significant) pass the sign bit is flipped so that
    -- negative values fall into lower buckets than non-negative ones.
    | i == passes e - 1 = radix' (e `xor` minBound)
    | otherwise = radix' e
    where
      -- Byte number i of x; (i `shiftL` 3) is the bit offset i * 8.
      radix' x = (x `shiftR` (i `shiftL` 3)) .&. 255
-- | 'Int8' needs a single pass over 256 buckets.
instance Radix Int8 where
  passes _ = 1
  size _ = 256
  -- Widen to 'Int', keep the low eight bits, then flip bit 7 (the sign bit
  -- of the original byte) so that -128 maps to bucket 0 and 127 to 255.
  radix _ e = (fromIntegral e .&. 255) `xor` 128
-- | 'Int16' is processed in two byte-sized passes, low byte first.
instance Radix Int16 where
  passes _ = 2
  size _ = 256
  -- Only passes 0 and 1 are defined; any other pass number is a pattern
  -- match failure.
  radix pass e =
    case pass of
      0 -> lowByte e
      -- The sign bit is flipped before extracting the high byte.
      1 -> lowByte ((e `xor` minBound) `shiftR` 8)
    where
      lowByte w = fromIntegral (w .&. 255)
-- | 'Int32' is processed in four byte-sized passes, low byte first.
instance Radix Int32 where
  passes _ = 4
  size _ = 256
  -- Only passes 0 through 3 are defined; any other pass number is a
  -- pattern match failure.
  radix pass e =
    case pass of
      0 -> byteAt 0 e
      1 -> byteAt 8 e
      2 -> byteAt 16 e
      -- The sign bit is flipped before extracting the top byte.
      3 -> byteAt 24 (e `xor` minBound)
    where
      -- The byte of w that starts at the given bit offset.
      byteAt offset w = fromIntegral ((w `shiftR` offset) .&. 255)
-- | 'Int64' is processed in eight byte-sized passes, low byte first.
instance Radix Int64 where
  passes _ = 8
  size _ = 256
  -- Only passes 0 through 7 are defined; any other pass number is a
  -- pattern match failure.
  radix pass e =
    case pass of
      0 -> byteAt 0 e
      1 -> byteAt 8 e
      2 -> byteAt 16 e
      3 -> byteAt 24 e
      4 -> byteAt 32 e
      5 -> byteAt 40 e
      6 -> byteAt 48 e
      -- The sign bit is flipped before extracting the top byte.
      7 -> byteAt 56 (e `xor` minBound)
    where
      -- The byte of w that starts at the given bit offset.
      byteAt offset w = fromIntegral ((w `shiftR` offset) .&. 255)
-- | 'Word' is processed one byte per pass, least significant byte first.
-- The number of passes is the byte width of 'Word' on the current platform.
instance Radix Word where
  passes _ = sizeOf (undefined :: Word)
  size _ = 256
  radix pass w
    | pass == 0 = lowByte w
    -- pass * 8 is the bit offset of the requested byte.
    | otherwise = lowByte (w `shiftR` (pass * 8))
    where
      lowByte x = fromIntegral (x .&. 255)
-- | 'Word8' needs a single pass; each value is its own bucket index.
instance Radix Word8 where
  passes _ = 1
  size _ = 256
  radix _ w = fromIntegral w
-- | 'Word16' is processed in two byte-sized passes, low byte first.
instance Radix Word16 where
  passes _ = 2
  size _ = 256
  -- Only passes 0 and 1 are defined; any other pass number is a pattern
  -- match failure.
  radix pass w =
    case pass of
      0 -> lowByte w
      1 -> lowByte (w `shiftR` 8)
    where
      lowByte x = fromIntegral (x .&. 255)
-- | 'Word32' is processed in four byte-sized passes, low byte first.
instance Radix Word32 where
  passes _ = 4
  size _ = 256
  -- Only passes 0 through 3 are defined; any other pass number is a
  -- pattern match failure.
  radix pass w =
    case pass of
      0 -> byteAt 0
      1 -> byteAt 8
      2 -> byteAt 16
      3 -> byteAt 24
    where
      -- The byte of w that starts at the given bit offset.
      byteAt offset = fromIntegral ((w `shiftR` offset) .&. 255)
-- | 'Word64' is processed in eight byte-sized passes, low byte first.
instance Radix Word64 where
  passes _ = 8
  size _ = 256
  -- Only passes 0 through 7 are defined; any other pass number is a
  -- pattern match failure.
  radix pass w =
    case pass of
      0 -> byteAt 0
      1 -> byteAt 8
      2 -> byteAt 16
      3 -> byteAt 24
      4 -> byteAt 32
      5 -> byteAt 40
      6 -> byteAt 48
      7 -> byteAt 56
    where
      -- The byte of w that starts at the given bit offset.
      byteAt offset = fromIntegral ((w `shiftR` offset) .&. 255)
-- | Pairs are sorted by running all passes of the second component first,
-- followed by all passes of the first component, which makes the first
-- component the more significant key.
--
-- The tuple patterns are lazy because 'passes' and 'size' are called with
-- 'undefined' by 'sort'.
instance (Radix i, Radix j) => Radix (i, j) where
  passes ~(i, j) = passes i + passes j
  size ~(i, j) = size i `max` size j
  radix k ~(i, j)
    | k < passes j = radix k j
    -- Renumber the remaining passes from 0 for the first component.
    | otherwise = radix (k - passes j) i
-- | Sort a mutable vector whose element type has a 'Radix' instance, using
-- that instance to supply the pass count, bucket count and key function for
-- 'sortBy'.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
     => v (PrimState m) e -> m ()
sort arr = sortBy (passes witness) (size witness) radix arr
  where
    -- Never evaluated; it only selects the 'Radix' instance for @e@.
    witness = undefined :: e
-- | Radix sort with explicitly supplied parameters. A temporary vector of the
-- same length as the input and two counting arrays are allocated; the sorted
-- result is left in the input vector.
sortBy :: (PrimMonad m, MVector v e)
       => Int               -- ^ number of passes
       -> Int               -- ^ length of the counting arrays (bucket count)
       -> (Int -> e -> Int) -- ^ bucket index of an element for a given pass
       -> v (PrimState m) e -- ^ vector to sort
       -> m ()
sortBy nPasses nBuckets rdx arr =
  new (length arr) >>= \tmp ->
  new nBuckets >>= \count ->
  new nBuckets >>= \prefix ->
  radixLoop nPasses rdx arr tmp count prefix
-- | Run the given number of passes, alternating the direction of data
-- movement between the two vectors on every pass. If the data ends up in the
-- second vector after the last pass, it is copied back into the first.
radixLoop :: (PrimMonad m, MVector v e)
          => Int                           -- ^ number of passes
          -> (Int -> e -> Int)             -- ^ bucket index function
          -> v (PrimState m) e             -- ^ input vector, also receives the result
          -> v (PrimState m) e             -- ^ scratch vector
          -> PV.MVector (PrimState m) Int  -- ^ counting array
          -> PV.MVector (PrimState m) Int  -- ^ prefix-sum array
          -> m ()
radixLoop nPasses rdx src dst count prefix = go False 0
  where
    -- swap is True when the current data lives in dst rather than src.
    go swap k
      | k >= nPasses = when swap (unsafeCopy src dst)
      | swap = body rdx dst src count prefix k >> go False (k + 1)
      | otherwise = body rdx src dst count prefix k >> go True (k + 1)
-- | A single pass: count the elements per bucket, turn the counts into
-- starting offsets, then move every element from the source vector to its
-- slot in the destination vector.
body :: (PrimMonad m, MVector v e)
     => (Int -> e -> Int)             -- ^ bucket index function
     -> v (PrimState m) e             -- ^ source vector
     -> v (PrimState m) e             -- ^ destination vector
     -> PV.MVector (PrimState m) Int  -- ^ counting array
     -> PV.MVector (PrimState m) Int  -- ^ prefix-sum array
     -> Int                           -- ^ pass number
     -> m ()
body rdx src dst count prefix k =
  set count 0
    >> countLoop k rdx src count
    -- prefixLoop only fills indices from 1 upwards, so seed index 0 here.
    >> unsafeWrite prefix 0 0
    >> prefixLoop count prefix
    >> moveLoop k rdx src dst prefix
-- | For every element of the source vector, increment the counter of the
-- bucket that the element falls into during pass @k@.
countLoop :: (PrimMonad m, MVector v e)
          => Int -> (Int -> e -> Int)
          -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
countLoop k rdx src count = step 0
  where
    n = length src
    step ix = when (ix < n) $ do
      x <- unsafeRead src ix
      -- inc returns the previous count, which is not needed here.
      _ <- inc count (rdx k x)
      step (ix + 1)
-- | Fill the prefix array with running sums of the counts: for every
-- @i >= 1@, @prefix[i]@ becomes @count[0] + ... + count[i-1]@. Index 0 of
-- the prefix array is not written here.
prefixLoop :: (PrimMonad m)
           => PV.MVector (PrimState m) Int -> PV.MVector (PrimState m) Int
           -> m ()
prefixLoop count prefix = go 1 0
  where
    len = length count
    -- acc is the sum of count[0 .. i-2], i.e. the value stored at prefix[i-1].
    go i acc
      | i < len = do
          ci <- unsafeRead count (i - 1)
          let acc' = acc + ci
          unsafeWrite prefix i acc'
          go (i + 1) acc'
      | otherwise = return ()
-- | Walk the source vector front to back and write each element to the next
-- free slot of its bucket in the destination vector, advancing that bucket's
-- offset in the prefix array as it goes.
moveLoop :: (PrimMonad m, MVector v e)
         => Int -> (Int -> e -> Int) -> v (PrimState m) e
         -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = step 0
  where
    n = length src
    step ix
      | ix >= n = return ()
      | otherwise =
          unsafeRead src ix >>= \x ->
          -- inc yields the offset before the increment: the slot for x.
          inc prefix (rdx k x) >>= \slot ->
          unsafeWrite dst slot x >> step (ix + 1)
-- | Add one to the value at the given index (no bounds check) and return the
-- value that was stored there before the increment.
inc :: (PrimMonad m) => PV.MVector (PrimState m) Int -> Int -> m Int
inc arr i = do
  old <- unsafeRead arr i
  unsafeWrite arr i (old + 1)
  return old