{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# lANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.AmericanFlag
-- Copyright   : (c) 2011 Dan Doel
-- Maintainer  : Dan Doel <dan.doel@gmail.com>
-- Stability   : Experimental
-- Portability : Non-portable (FlexibleContexts, ScopedTypeVariables)
--
-- This module implements American flag sort: an in-place, unstable bucket
-- sort. In contrast to radix sort, the values are inspected in big-endian
-- order, and buckets are sorted via recursive splitting. This makes it
-- sensible for sorting strings in lexicographic order (provided indexing
-- is fast).
--
-- The algorithm works as follows: at each stage, the array is looped over,
-- counting the number of elements for each bucket. Then, starting at the
-- beginning of the array, elements are permuted in place to reside in the
-- proper bucket, following chains until they reach back to the current
-- base index. Finally, each bucket is sorted recursively. This lends itself
-- well to the aforementioned variable-length strings, and so the algorithm
-- takes a stopping predicate, which is given a representative of the stripe,
-- rather than running for a set number of iterations.

module Data.Vector.Algorithms.AmericanFlag ( sort
                                           , sortBy
                                           , terminate
                                           , Lexicographic(..)
                                           ) where

import Prelude hiding (read, length)

import Control.Monad
import Control.Monad.Primitive

import Data.Proxy

import Data.Word
import Data.Int
import Data.Bits

import qualified Data.ByteString as B

import Data.Vector.Generic.Mutable
import qualified Data.Vector.Primitive.Mutable as PV

import qualified Data.Vector.Unboxed.Mutable as U

import Data.Vector.Algorithms.Common

import qualified Data.Vector.Algorithms.Insertion as I

import Foreign.Storable

-- | The methods of this class specify the information necessary to sort
-- arrays using the default ordering. The name 'Lexicographic' is meant
-- to convey that 'index' should return results in a similar way to indexing
-- into a string.
class Lexicographic e where
  -- | Computes the length of a representative of a stripe. It should take @n@
  -- passes to sort values of extent @n@. The extent may not be uniform across
  -- all values of the type.
  extent    :: e -> Int

  -- | The size of the bucket array necessary for sorting values of type @e@.
  -- The result of 'index' is used as an (unchecked) position in arrays of
  -- this length, so it is expected to lie in @[0, size)@.
  size      :: Proxy e -> Int
  -- | Determines which bucket a given element should inhabit for a
  -- particular iteration. The first argument is the zero-based pass number.
  index     :: Int -> e -> Int

-- A 'Word8' is a single radix-256 digit: one pass, 256 buckets, and the
-- bucket is the value itself whatever the pass number.
instance Lexicographic Word8 where
  extent _ = 1
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index _ n = fromIntegral n
  {-# INLINE index #-}

-- A 'Word16' is two radix-256 digits, most significant byte first.
-- Passes outside @[0, 1]@ map every value to bucket 0.
instance Lexicographic Word16 where
  extent _ = 2
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 1 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

-- A 'Word32' is four radix-256 digits, most significant byte first.
-- Passes outside @[0, 3]@ map every value to bucket 0.
instance Lexicographic Word32 where
  extent _ = 4
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ (n `shiftR` 24) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 3 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

-- A 'Word64' is eight radix-256 digits, most significant byte first.
-- Passes outside @[0, 7]@ map every value to bucket 0.
instance Lexicographic Word64 where
  extent _ = 8
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ (n `shiftR` 56) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 48) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 40) .&. 255
  index 3 n = fromIntegral $ (n `shiftR` 32) .&. 255
  index 4 n = fromIntegral $ (n `shiftR` 24) .&. 255
  index 5 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 6 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 7 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

-- A 'Word' is split into radix-256 digits, most significant byte first.
-- The number of digits comes from 'sizeOf', and the shift for each pass is
-- derived from that same byte count, so 'extent' and 'index' stay in
-- agreement whatever the platform's word width is. (With fixed shifts of
-- 56, 48, ... a 4-byte 'Word' would land in bucket 0 on every pass.)
-- Passes outside @[0, extent - 1]@ map every value to bucket 0.
instance Lexicographic Word where
  extent _ = sizeOf (0 :: Word)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index i n
    | i >= 0 && i < bytes =
        fromIntegral $ (n `shiftR` (8 * (bytes - 1 - i))) .&. 255
    | otherwise = 0
    where
      -- width of a 'Word' in bytes on this platform
      bytes = sizeOf (0 :: Word)
  {-# INLINE index #-}

-- An 'Int8' is a single radix-256 digit. The value is widened to 'Int',
-- masked down to its low byte, and the top bit of that byte is flipped so
-- that negative values receive smaller bucket numbers than non-negative
-- ones.
instance Lexicographic Int8 where
  extent _ = 1
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  -- Parenthesised explicitly; '.&.' binds tighter than 'xor' anyway.
  index _ n = (255 .&. fromIntegral n) `xor` 128
  {-# INLINE index #-}

-- An 'Int16' is two radix-256 digits, most significant byte first. The
-- sign bit is flipped (xor with 'minBound') before extracting the leading
-- byte so that negative values sort before non-negative ones.
-- Passes outside @[0, 1]@ map every value to bucket 0.
instance Lexicographic Int16 where
  extent _ = 2
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ ((n `xor` minBound) `shiftR` 8) .&. 255
  index 1 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

-- An 'Int32' is four radix-256 digits, most significant byte first. The
-- sign bit is flipped (xor with 'minBound') before extracting the leading
-- byte so that negative values sort before non-negative ones.
-- Passes outside @[0, 3]@ map every value to bucket 0.
instance Lexicographic Int32 where
  extent _ = 4
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ ((n `xor` minBound) `shiftR` 24) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 3 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

-- An 'Int64' is eight radix-256 digits, most significant byte first. The
-- sign bit is flipped (xor with 'minBound') before extracting the leading
-- byte so that negative values sort before non-negative ones.
-- Passes outside @[0, 7]@ map every value to bucket 0.
instance Lexicographic Int64 where
  extent _ = 8
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index 0 n = fromIntegral $ ((n `xor` minBound) `shiftR` 56) .&. 255
  index 1 n = fromIntegral $ (n `shiftR` 48) .&. 255
  index 2 n = fromIntegral $ (n `shiftR` 40) .&. 255
  index 3 n = fromIntegral $ (n `shiftR` 32) .&. 255
  index 4 n = fromIntegral $ (n `shiftR` 24) .&. 255
  index 5 n = fromIntegral $ (n `shiftR` 16) .&. 255
  index 6 n = fromIntegral $ (n `shiftR` 8) .&. 255
  index 7 n = fromIntegral $ n .&. 255
  index _ _ = 0
  {-# INLINE index #-}

-- An 'Int' is split into radix-256 digits, most significant byte first.
-- The number of digits comes from 'sizeOf', and the shift for each pass is
-- derived from that same byte count, so 'extent' and 'index' stay in
-- agreement whatever the platform's word width is. (With fixed shifts of
-- 56, 48, ... a 4-byte 'Int' would never be bucketed by its real bytes.)
--
-- The sign bit is flipped (xor with 'minBound') so that negative values
-- sort before non-negative ones. 'minBound' has only the top bit set, so
-- this changes the leading digit only; the lower digits are those of @n@.
-- Passes outside @[0, extent - 1]@ map every value to bucket 0.
instance Lexicographic Int where
  extent _ = sizeOf (0 :: Int)
  {-# INLINE extent #-}
  size _ = 256
  {-# INLINE size #-}
  index i n
    | i >= 0 && i < bytes =
        ((n `xor` minBound) `shiftR` (8 * (bytes - 1 - i))) .&. 255
    | otherwise = 0
    where
      -- width of an 'Int' in bytes on this platform
      bytes = sizeOf (0 :: Int)
  {-# INLINE index #-}

-- A 'B.ByteString' has one digit per byte. Bucket 0 is reserved for
-- "the string has already ended at this position", and byte value @w@ goes
-- to bucket @w + 1@; hence 257 buckets rather than 256.
instance Lexicographic B.ByteString where
  extent = B.length
  {-# INLINE extent #-}
  size _ = 257
  {-# INLINE size #-}
  index i b
    | i >= B.length b = 0
    | otherwise       = fromIntegral (B.index b i) + 1
  {-# INLINE index #-}

-- A pair is treated as the digits of its first component followed by the
-- digits of its second component. The bucket array must be large enough
-- for whichever component needs more buckets.
--
-- Passes below @extent a@ consult the first component. Later passes
-- consult the second component, re-based so that its own digits are
-- numbered from 0; passing the absolute pass number through would read the
-- wrong digit of @b@ whenever @extent a > 0@.
--
-- NOTE(review): when the first component has a non-uniform extent,
-- elements of one stripe may switch to the second component on different
-- passes — confirm whether such pairs are meant to be supported.
instance (Lexicographic a, Lexicographic b) => Lexicographic (a, b) where
  extent (a, b) = extent a + extent b
  {-# INLINE extent #-}
  size _ = size (Proxy :: Proxy a) `max` size (Proxy :: Proxy b)
  {-# INLINE size #-}
  index i (a, b)
    | i >= ea   = index (i - ea) b
    | otherwise = index i a
    where
      ea = extent a
  {-# INLINE index #-}

-- An 'Either' spends its first digit on the constructor ('Left' goes to
-- bucket 0, 'Right' to bucket 1) and then continues with the digits of the
-- payload, shifted down by one pass.
instance (Lexicographic a, Lexicographic b) => Lexicographic (Either a b) where
  extent (Left  a) = 1 + extent a
  extent (Right b) = 1 + extent b
  {-# INLINE extent #-}
  size _ = size (Proxy :: Proxy a) `max` size (Proxy :: Proxy b)
  {-# INLINE size #-}
  index 0 (Left  _) = 0
  index 0 (Right _) = 1
  index n (Left  a) = index (n - 1) a
  index n (Right b) = index (n - 1) b
  {-# INLINE index #-}

-- | Given a representative of a stripe and an index number, this
-- function determines whether to stop sorting: it answers 'True' once the
-- index number has reached the 'extent' of the representative.
terminate :: Lexicographic e => e -> Int -> Bool
terminate e i = i >= extent e
{-# INLINE terminate #-}

-- | Sorts an array using the default ordering. Both 'Lexicographic' and
-- 'Ord' are necessary because the algorithm falls back to insertion sort
-- for sufficiently small arrays.
--
-- This is 'sortBy' instantiated with 'compare', 'terminate', and the
-- 'size' and 'index' methods of the element type.
sort :: forall e m v. (PrimMonad m, MVector v e, Lexicographic e, Ord e)
     => v (PrimState m) e -> m ()
sort v = sortBy compare terminate (size p) index v
 where
 -- Only used to select the 'Lexicographic' instance for 'size'; the type
 -- variable @e@ is the one bound by the @forall@ above.
 p :: Proxy e
 p = Proxy
{-# INLINABLE sort #-}

-- | A fully parameterized version of the sorting algorithm. Again, this
-- function takes both radix information and a comparison, because the
-- algorithm falls back to insertion sort for small arrays.
--
-- An empty array is returned untouched. Otherwise two auxiliary arrays of
-- @buckets@ elements are allocated, the first-pass bucket counts are
-- gathered, and the main loop is run over the whole array.
sortBy :: (PrimMonad m, MVector v e)
       => Comparison e       -- ^ a comparison for the insertion sort fallback
       -> (e -> Int -> Bool) -- ^ determines whether a stripe is complete
       -> Int                -- ^ the number of buckets necessary
       -> (Int -> e -> Int)  -- ^ the big-endian radix function
       -> v (PrimState m) e  -- ^ the array to be sorted
       -> m ()
sortBy cmp stop buckets radix v
  | length v == 0 = return ()
  | otherwise     = do count <- new buckets
                       pile  <- new buckets
                       countLoop (radix 0) v count
                       flagLoop cmp stop radix count pile v
{-# INLINE sortBy #-}

-- The main loop of the sort. Starting at pass 0 on the whole array, each
-- stripe is either handed to insertion sort (when shorter than
-- 'threshold') or permuted into buckets for the current pass and then
-- split into sub-stripes that are processed with the next pass number.
--
-- On entry the count array is expected to hold the bucket counts of the
-- stripe for the current pass: 'sortBy' fills it via 'countLoop' before
-- the first call, and 'countStripe' is run before each recursive call.
flagLoop :: (PrimMonad m, MVector v e)
         => Comparison e
         -> (e -> Int -> Bool)           -- stopping predicate
         -> (Int -> e -> Int)            -- radix function
         -> PV.MVector (PrimState m) Int -- auxiliary count array
         -> PV.MVector (PrimState m) Int -- auxiliary pile array
         -> v (PrimState m) e            -- source array
         -> m ()
flagLoop cmp stop radix count pile v = go 0 v
  where
    -- Decide whether a stripe needs any more work, using its first element
    -- as the representative. Note that the predicate is asked about
    -- @pass - 1@, not @pass@, so with 'terminate' one pass is still run at
    -- @pass == extent e@.
    go pass stripe = do
      e <- unsafeRead stripe 0
      unless (stop e $ pass - 1) $ go' pass stripe

    -- Sort one stripe on the given pass.
    go' pass stripe
      | len < threshold = I.sortByBounds cmp stripe 0 len
      | otherwise       = do accumulate count pile
                             permute (radix pass) count pile stripe
                             recurse 0
      where
        len   = length stripe
        ppass = pass + 1

        -- Walk the stripe left to right. 'countStripe' yields the index
        -- @j@ at which the sub-stripe starting at @i@ ends; that slice is
        -- sorted on the next pass and the walk resumes at @j@.
        recurse i
          | i < len   = do j <- countStripe (radix ppass) (radix pass) count stripe i
                           go ppass (unsafeSlice i (j - i) stripe)
                           recurse j
          | otherwise = return ()
{-# INLINE flagLoop #-}

-- Turns per-bucket counts into running totals, in place. For every bucket
-- @i@ of the count array:
--
--   * @pile[i]@ receives the sum of the counts of buckets @0 .. i-1@
--   * @count[i]@ is overwritten with the sum of the counts of buckets
--     @0 .. i@
--
-- Only the first @length count@ entries of the pile array are written.
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int
           -> PV.MVector (PrimState m) Int
           -> m ()
accumulate count pile = loop 0 0
  where
    len = length count

    -- @acc@ is the total of all counts strictly before bucket @i@.
    loop i acc
      | i < len   = do ci <- unsafeRead count i
                       let acc' = acc + ci
                       unsafeWrite pile i acc
                       unsafeWrite count i acc'
                       loop (i + 1) acc'
      | otherwise = return ()
{-# INLINE accumulate #-}

-- | Rearrange the source array in place so that each element ends up in the
-- pile selected by its radix, displacing elements along chains until a chain
-- closes back on the index it started from.
--
-- NOTE(review): the reads of @count@ at @r-1@ treat it as running totals
-- (so that entry is where pile @r@ begins), and @pile@ as the next unfilled
-- slot of each pile. Both are prepared by the caller; confirm there.
-- NOTE(review): 'inc' is defined outside this block; the code relies on it
-- returning the slot's previous value while bumping it by one.
permute :: (PrimMonad m, MVector v e)
        => (e -> Int)                       -- ^ radix function
        -> PV.MVector (PrimState m) Int     -- ^ count array
        -> PV.MVector (PrimState m) Int     -- ^ pile array
        -> v (PrimState m) e                -- ^ source array
        -> m ()
permute rdx count pile v = go 0
  where
    len = length v

    -- Walk the array from index i, settling whatever is found there.
    go i
      | i < len = do
          e <- unsafeRead v i
          let r = rdx e
          -- p: current fill position of e's pile.
          p <- unsafeRead pile r
          -- m: value of the count slot just below r (0 for the first pile).
          m <- if r > 0
                 then unsafeRead count (r - 1)
                 else return 0
          case () of
            _ -- if the current element is already in the right pile,
              -- go to the end of the pile
              | m <= i && i < p -> go p
              -- if the current element happens to be in the right
              -- pile, bump the pile counter and go to the next element
              | i == p          -> unsafeWrite pile r (p + 1) >> go (i + 1)
              -- otherwise follow the chain
              | otherwise       -> follow i e p >> go (i + 1)
      | otherwise = return ()

    -- follow i e j: e is homeless and wants slot j; i is the index where the
    -- chain began. Whatever is evicted from j becomes the next homeless
    -- element, until the eviction target is i itself.
    follow i e j = do
      en <- unsafeRead v j
      let r = rdx en
      p <- inc pile r
      if p == j
        -- if the target happens to be in the right pile, don't move it.
        then follow i e (j + 1)
        else do
          unsafeWrite v j e
          if i == p
            -- the chain has closed: the evicted element fills the start slot
            then unsafeWrite v i en
            else follow i en p
{-# INLINE permute #-}

-- | Zero the count array, then scan forward from @lo@ tallying the radix of
-- each element for as long as consecutive elements share the stripe value of
-- the element at @lo@.
--
-- The result is the index of the first element whose stripe value differs,
-- or the length of the array if the stripe runs to the end; i.e. the stripe
-- occupies @[lo, result)@.
--
-- The element at @lo@ is read with 'unsafeRead' before any bounds test, so
-- the caller must supply a valid index.
-- NOTE(review): 'inc' is defined outside this block; it is used here only
-- for its effect on the count slot, and its result is discarded.
countStripe :: (PrimMonad m, MVector v e)
            => (e -> Int)                   -- ^ radix function
            -> (e -> Int)                   -- ^ stripe function
            -> PV.MVector (PrimState m) Int -- ^ count array
            -> v (PrimState m) e            -- ^ source array
            -> Int                          -- ^ starting position
            -> m Int                        -- ^ end of stripe: [lo,hi)
countStripe rdx str count v lo = do
  set count 0
  e <- unsafeRead v lo
  go (str e) e (lo + 1)
  where
    len = length v

    -- go s e i: s is the stripe value being matched, e the element that has
    -- been accepted but not yet tallied, i the next index to inspect.
    go !s e i =
      inc count (rdx e) >>
        if i < len
          then do
            en <- unsafeRead v i
            if str en == s
              then go s en (i + 1)
              -- the element at i belongs to a different stripe; stop before it
              else return i
          else return len
{-# INLINE countStripe #-}

-- | Fixed size cutoff shared by this module.
--
-- NOTE(review): presumably the length below which the sort driver stops
-- bucketing and hands off to a simpler sort; confirm at the use site.
threshold :: Int
threshold = 25