module Data.Array.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where
import Control.Monad
import Control.Monad.ST
import Data.Array.Vector
import Data.Array.Vector.Algorithms.Common
import Data.Bits
import Data.Int
import Data.Word
import Foreign.Storable
-- | Element types that can be radix sorted.
--
-- A key is viewed as 'passes' digits. Digit 0 is the least significant and
-- is processed first. Every digit must lie in the range @[0, size e)@,
-- because the sorting loops index their bucket arrays (allocated with
-- 'size' entries) directly by the digit value.
class UA e => Radix e where
  -- | Number of digit passes needed to sort this type. The argument is only
  -- a type witness: 'sort' calls this on 'undefined', so instances must not
  -- force it.
  passes :: e -> Int

  -- | Number of buckets a single digit may fall into. The argument is only a
  -- type witness, exactly as for 'passes'.
  size :: e -> Int

  -- | @radix k e@ is the @k@-th digit of @e@, with @0 <= k < passes e@.
  radix :: Int -> e -> Int
-- | Native 'Int': one pass per byte of the machine word.
--
-- On the most significant byte (the last pass) the sign bit is flipped by
-- @xor minBound@, so negative numbers land in lower buckets than
-- non-negative ones. Digit 0 goes through the general branch with a shift
-- of 0, which gives the plain low byte.
instance Radix Int where
  passes _ = sizeOf (undefined :: Int)
  size _ = 256
  radix i e
    | i == passes e - 1 = byteAt (e `xor` minBound)
    | otherwise = byteAt e
    where
      -- Byte @i@ of @x@. Masking with 255 makes the arithmetic shift's sign
      -- extension irrelevant.
      byteAt x = (x `shiftR` (8 * i)) .&. 255
-- | 'Int8' needs a single pass. Masking to the low byte and then flipping
-- bit 7 maps -128..127 onto 0..255 in order.
instance Radix Int8 where
  passes _ = 1
  size _ = 256
  radix _ e = (fromIntegral e .&. 255) `xor` 128
-- | 'Int16' takes two byte passes. The high byte has its sign bit flipped so
-- that negative values sort first. Only digits 0 and 1 are defined.
instance Radix Int16 where
  passes _ = 2
  size _ = 256
  radix k e = case k of
      0 -> lowByte e
      1 -> lowByte ((e `xor` minBound) `shiftR` 8)
    where
      lowByte :: Int16 -> Int
      lowByte x = fromIntegral (x .&. 255)
-- | 'Int32' takes four byte passes, least significant first. The top byte
-- (digit 3) has its sign bit flipped so that negative values sort first.
-- Digits outside 0..3 are undefined.
instance Radix Int32 where
  passes _ = 4
  size _ = 256
  radix k e
    | k == 3 = byteAt 24 (e `xor` minBound)
    | k >= 0 && k < 3 = byteAt (8 * k) e
    where
      byteAt :: Int -> Int32 -> Int
      byteAt s x = fromIntegral ((x `shiftR` s) .&. 255)
-- | 'Int64' takes eight byte passes, least significant first. The top byte
-- (digit 7) has its sign bit flipped so that negative values sort first.
-- Digits outside 0..7 are undefined.
instance Radix Int64 where
  passes _ = 8
  size _ = 256
  radix k e
    | k == 7 = byteAt 56 (e `xor` minBound)
    | k >= 0 && k < 7 = byteAt (8 * k) e
    where
      byteAt :: Int -> Int64 -> Int
      byteAt s x = fromIntegral ((x `shiftR` s) .&. 255)
-- | Native 'Word': one pass per byte of the machine word, least significant
-- first. No sign handling is needed.
instance Radix Word where
  passes _ = sizeOf (undefined :: Word)
  size _ = 256
  radix k e = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | 'Word8' is already a single byte, so its only digit is the value itself.
instance Radix Word8 where
  passes _ = 1
  size _ = 256
  radix _ w = fromIntegral w
-- | 'Word16' takes two byte passes, low byte first. Only digits 0 and 1 are
-- defined.
instance Radix Word16 where
  passes _ = 2
  size _ = 256
  radix k e = case k of
      0 -> lowByte e
      1 -> lowByte (e `shiftR` 8)
    where
      lowByte :: Word16 -> Int
      lowByte x = fromIntegral (x .&. 255)
-- | 'Word32' takes four byte passes, least significant first. Digits outside
-- 0..3 are undefined.
instance Radix Word32 where
  passes _ = 4
  size _ = 256
  radix k e
    | k >= 0 && k < 4 = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | 'Word64' takes eight byte passes, least significant first. Digits outside
-- 0..7 are undefined.
instance Radix Word64 where
  passes _ = 8
  size _ = 256
  radix k e
    | k >= 0 && k < 8 = fromIntegral ((e `shiftR` (8 * k)) .&. 255)
-- | Pairs are sorted lexicographically. The second component's digits come
-- first (least significant), followed by the first component's, so after
-- all the stable passes the first component is the primary key.
--
-- The lazy patterns keep 'passes', 'size' and 'radix' from forcing the pair.
-- This matters because 'sort' calls 'passes' and 'size' on 'undefined'.
instance (Radix i, Radix j) => Radix (i :*: j) where
  passes ~(hi :*: lo) = passes hi + passes lo
  size ~(hi :*: lo) = size hi `max` size lo
  radix k ~(hi :*: lo)
    | k < passes lo = radix k lo
    -- Rebase the pass index so the first component sees 0 .. passes hi - 1.
    | otherwise = radix (k - passes lo) hi
-- | Sorts a mutable array in place. The element type's 'Radix' instance
-- supplies the pass count, the bucket count and the digit function.
sort :: forall e s. Radix e => MUArr e s -> ST s ()
sort arr = sortBy (passes witness) (size witness) radix arr
  where
    -- A value of the element type that is only ever used for its type. It is
    -- never evaluated, so 'passes' and 'size' must not force their argument.
    witness = elemWitness arr

    elemWitness :: MUArr e' s' -> e'
    elemWitness _ = undefined
-- | Least-significant-digit radix sort of a mutable array, in place, with
-- caller-supplied parameters.
--
-- Allocates a scratch array the same length as the input, plus two
-- bucket-sized 'Int' arrays for the per-pass histogram and the bucket start
-- offsets.
sortBy
  :: (UA e)
  => Int                 -- ^ number of digit passes
  -> Int                 -- ^ number of buckets per pass
  -> (Int -> e -> Int)   -- ^ digit of an element for a given pass; must be
                         --   in @[0, buckets)@
  -> MUArr e s           -- ^ array to sort
  -> ST s ()
sortBy numPasses buckets rdx arr = do
  scratch <- newMU (lengthMU arr)
  histogram <- newMU buckets
  offsets <- newMU buckets
  radixLoop numPasses rdx arr scratch histogram offsets
-- | Runs passes @0 .. passes - 1@, alternating which of @src@ and @dst@ is
-- read and which is written.
--
-- Pass 0 moves @src@ into @dst@, pass 1 moves it back, and so on. If an odd
-- number of passes ran, the sorted data ends up in @dst@ and is copied back
-- to @src@. NOTE(review): this assumes 'mcopyMU' takes the source array
-- first; confirm against its definition in Common.
radixLoop
  :: (UA e)
  => Int                 -- ^ number of passes
  -> (Int -> e -> Int)   -- ^ digit function
  -> MUArr e s           -- ^ array being sorted
  -> MUArr e s           -- ^ scratch array of the same length
  -> MUArr Int s         -- ^ histogram buffer
  -> MUArr Int s         -- ^ bucket-offset buffer
  -> ST s ()
radixLoop numPasses rdx src dst count prefix = step 0 src dst
  where
    len = lengthMU src

    step k from to
      | k >= numPasses = when (odd k) (mcopyMU dst src 0 0 len)
      | otherwise = body rdx from to count prefix k >> step (k + 1) to from
-- | One counting-sort pass on digit @k@: builds a histogram of the digits,
-- turns it into bucket start offsets, then scatters @src@ into @dst@. The
-- scatter walks @src@ front to back, so the pass is stable.
body
  :: (UA e)
  => (Int -> e -> Int)   -- ^ digit function
  -> MUArr e s           -- ^ source array
  -> MUArr e s           -- ^ destination array
  -> MUArr Int s         -- ^ histogram buffer
  -> MUArr Int s         -- ^ bucket-offset buffer
  -> Int                 -- ^ pass (digit) index
  -> ST s ()
body rdx src dst count prefix k =
  sequence_
    [ zero count                      -- clear the previous pass's histogram
    , countLoop k rdx src count       -- count elements per bucket
    , writeMU prefix 0 0              -- bucket 0 starts at index 0
    , prefixLoop count prefix         -- starts of buckets 1 .. n-1
    , moveLoop k rdx src dst prefix   -- stable scatter into dst
    ]
-- | Sets every element of an 'Int' array to 0, from index 0 upwards.
zero :: MUArr Int s -> ST s ()
zero a = mapM_ (\ix -> writeMU a ix 0) [0 .. lengthMU a - 1]
-- | Adds one to @count[d]@ for each element of @src@, where @d@ is that
-- element's digit for pass @k@.
countLoop :: (UA e) => Int -> (Int -> e -> Int) -> MUArr e s -> MUArr Int s -> ST s ()
countLoop k rdx src count =
  forM_ [0 .. lengthMU src - 1] $ \ix -> do
    digit <- rdx k <$> readMU src ix
    inc count digit
-- | Fills @prefix[i]@, for each @i >= 1@, with the exclusive prefix sum
-- @count[0] + ... + count[i-1]@, i.e. the first output index of bucket @i@.
--
-- Index 0 is not written here; 'body' sets it to 0 before calling this.
prefixLoop :: MUArr Int s -> MUArr Int s -> ST s ()
prefixLoop count prefix = go 1 0
  where
    len = lengthMU count

    -- @running@ is the sum of @count[0 .. i-2]@. Each new total is written
    -- to the unboxed array before recursing, so no thunk chain builds up.
    go i running
      | i < len = do
          c <- readMU count (i - 1)
          let running' = running + c
          writeMU prefix i running'
          go (i + 1) running'
      | otherwise = return ()
-- | Scatters @src@ into @dst@ by digit @k@. Each element goes to the current
-- offset of its bucket, and that offset is then bumped. Walking @src@ in
-- index order keeps equal digits in their original relative order.
moveLoop :: (UA e) => Int -> (Int -> e -> Int) -> MUArr e s -> MUArr e s -> MUArr Int s -> ST s ()
moveLoop k rdx src dst prefix =
  forM_ [0 .. lengthMU src - 1] $ \ix -> do
    x <- readMU src ix
    slot <- inc prefix (rdx k x)
    writeMU dst slot x
-- | Post-increment: adds one to @arr[i]@ and returns the value it held
-- before the increment.
inc :: MUArr Int s -> Int -> ST s Int
inc arr i = do
  old <- readMU arr i
  writeMU arr i (old + 1)
  return old