{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

{-# OPTIONS_GHC -Wall #-}

-- | Sort primitive arrays with a stable sorting algorithm. All functions
-- in this module are marked as @INLINABLE@, so they will specialize
-- when used in a monomorphic setting.
module Data.Primitive.Sort
  ( -- * Immutable
    sort
  , sortUnique
  , sortTagged
  , sortUniqueTagged
    -- * Mutable
  , sortMutable
  , sortUniqueMutable
  , sortTaggedMutable
  , sortUniqueTaggedMutable
  ) where

import Control.Monad.ST
import Control.Applicative
import GHC.ST (ST(..))
import GHC.IO (IO(..))
import GHC.Int (Int(..))
import Control.Monad
import GHC.Prim
import Control.Concurrent (getNumCapabilities)
import Data.Primitive.Contiguous (Contiguous,Mutable,Element)
import qualified Data.Primitive.Contiguous as C

-- | Sort an immutable array with a stable sort. Duplicate elements are
-- preserved. The argument is copied into a fresh mutable buffer first,
-- so it is never modified.
--
-- >>> sort ([5,6,7,9,5,4,5,7] :: Array Int)
-- fromListN 8 [4,5,5,5,6,7,7,9]
sort :: (Contiguous arr, Element arr a, Ord a)
  => arr a
  -> arr a
{-# INLINABLE sort #-}
sort !src = runST $ do
  let !n = C.size src
  -- Private mutable copy: sortMutable is allowed to clobber its argument.
  buf <- C.new n
  C.copy buf 0 src 0 n
  sortMutable buf >>= C.unsafeFreeze

-- | Sort a tagged immutable array. Each element from the @keys@ array is
-- paired up with the element from the @values@ array at the matching
-- index. The sort permutes the @values@ array so that each value ends up
-- in the same position as its corresponding key. The two argument arrays
-- should be of the same length, but if one is shorter than the other,
-- the longer one will be truncated so that the lengths match.
--
-- >>> sortTagged ([5,6,7,5,5,7] :: Array Int) ([1,2,3,4,5,6] :: Array Int)
-- (fromListN 6 [5,5,5,6,7,7],fromListN 6 [1,4,5,2,3,6])
--
-- Since the sort is stable, the values corresponding to a key that
-- appears multiple times have their original order preserved.
sortTagged :: forall k v karr varr. (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => karr k -- ^ keys
  -> varr v -- ^ values
  -> (karr k,varr v)
{-# INLINABLE sortTagged #-}
sortTagged !src !srcTags = runST $ do
  -- Only the common prefix of both inputs is copied; this is where the
  -- longer array gets truncated.
  let !n = min (C.size src) (C.size srcTags)
  keys <- C.new n
  C.copy keys 0 src 0 n
  vals <- C.new n
  C.copy vals 0 srcTags 0 n
  (sortedKeys, sortedVals) <- sortTaggedMutableN n keys vals
  liftA2 (,) (C.unsafeFreeze sortedKeys) (C.unsafeFreeze sortedVals)

-- | Sort a tagged immutable array, keeping only a single copy of each
-- duplicate key along with the last value from @values@ that
-- corresponded to it. The two argument arrays
-- should be of the same length, but if one is shorter than the other,
-- the longer one will be truncated so that the lengths match.
--
-- >>> sortUniqueTagged ([5,6,7,5,5,7] :: Array Int) ([1,2,3,4,5,6] :: Array Int)
-- (fromListN 3 [5,6,7],fromListN 3 [5,2,6])
sortUniqueTagged :: forall k v karr varr. (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => karr k -- ^ keys
  -> varr v -- ^ values
  -> (karr k,varr v)
{-# INLINABLE sortUniqueTagged #-}
sortUniqueTagged !src !srcTags = runST $ do
  let !n = min (C.size src) (C.size srcTags)
  keys <- C.new n
  C.copy keys 0 src 0 n
  vals <- C.new n
  C.copy vals 0 srcTags 0 n
  -- The sort is stable, so within a run of equal keys the input order of
  -- the values survives; the dedup pass then keeps the run's last value.
  (sortedKeys, sortedVals) <- sortTaggedMutableN n keys vals
  (uniqKeys, uniqVals) <- uniqueTaggedMutableN n sortedKeys sortedVals
  liftA2 (,) (C.unsafeFreeze uniqKeys) (C.unsafeFreeze uniqVals)

-- | Sort the mutable array with a stable sort. This operation preserves
-- duplicate elements. The argument may either be modified in-place, or
-- another array may be allocated and returned. The argument
-- may not be reused after being passed to this function.
sortMutable :: (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a
  -> ST s (Mutable arr s a)
{-# INLINABLE sortMutable #-}
sortMutable !dst = do
  len <- C.sizeMutable dst
  if len < threshold
    then insertionSortRange dst 0 len
    else do
      -- The merge sort alternates between dst and a scratch buffer, so
      -- the scratch buffer has to start out as an exact copy of dst.
      work <- C.cloneMutable dst 0 len
      caps <- unsafeEmbedIO getNumCapabilities
      splitMergeParallel dst work (threadBudget caps len) 0 len
  return dst
  where
  -- Never more threads than capabilities, and at least 20000 elements
  -- per thread. I cannot understand why, but GHC's runtime does better
  -- when we let this schedule 8 times as many threads as we have
  -- capabilities. However, we only get this benefit when we actually
  -- have more than one capability.
  threadBudget :: Int -> Int -> Int
  threadBudget caps n
    | usable == 1 = 1
    | otherwise = usable * 8
    where
    usable = min caps (unsafeQuot n 20000)

-- | Sort an array of a key type @k@, rearranging the values of
-- type @v@ according to the element they correspond to in the
-- key array. If the two arrays differ in length, the longer one is
-- first resized down to the length of the shorter one. The argument
-- arrays may not be reused after they are passed to the function.
sortTaggedMutable :: (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k
  -> Mutable varr s v
  -> ST s (Mutable karr s k, Mutable varr s v)
{-# INLINABLE sortTaggedMutable #-}
sortTaggedMutable !dst0 !dstTags0 =
  alignArrays dst0 dstTags0 >>= \(!keys, !vals, !n) ->
    sortTaggedMutableN n keys vals

-- | Make a key array and a value array agree in length by resizing the
-- longer one down to the length of the shorter one. Returns the
-- (possibly resized) arrays together with their common length.
alignArrays :: (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k
  -> Mutable varr s v
  -> ST s (Mutable karr s k, Mutable varr s v,Int)
{-# INLINABLE alignArrays #-}
alignArrays dst0 dstTags0 = do
  keyLen <- C.sizeMutable dst0
  valLen <- C.sizeMutable dstTags0
  case compare keyLen valLen of
    EQ -> return (dst0, dstTags0, keyLen)
    LT -> do
      vals <- C.resize dstTags0 keyLen
      return (dst0, vals, keyLen)
    GT -> do
      keys <- C.resize dst0 valLen
      return (keys, dstTags0, valLen)

-- | Sort an array of keys, rearranging the values alongside them, and
-- then keep only a single copy of each duplicate key together with the
-- last value that corresponded to it. If the arrays differ in length,
-- the longer one is resized down to match the shorter one. The argument
-- arrays may not be reused after they are passed to this function.
sortUniqueTaggedMutable :: (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k -- ^ keys
  -> Mutable varr s v -- ^ values
  -> ST s (Mutable karr s k, Mutable varr s v)
{-# INLINABLE sortUniqueTaggedMutable #-}
sortUniqueTaggedMutable dst0 dstTags0 = do
  (!keys, !vals, !n) <- alignArrays dst0 dstTags0
  (!sortedKeys, !sortedVals) <- sortTaggedMutableN n keys vals
  uniqueTaggedMutableN n sortedKeys sortedVals

-- | Stable sort of the first @len@ keys, permuting the values in step
-- with them. Both arrays must hold at least @len@ elements. The arrays
-- are sorted in place and returned.
sortTaggedMutableN :: (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Int
  -> Mutable karr s k
  -> Mutable varr s v
  -> ST s (Mutable karr s k, Mutable varr s v)
{-# INLINABLE sortTaggedMutableN #-}
sortTaggedMutableN !len !dst !dstTags
  | len < thresholdTagged = do
      insertionSortTaggedRange dst dstTags 0 len
      return (dst,dstTags)
  | otherwise = do
      -- Scratch copies for the alternating merge sort; they must start
      -- with the same contents as the destination arrays.
      work <- C.cloneMutable dst 0 len
      workTags <- C.cloneMutable dstTags 0 len
      caps <- unsafeEmbedIO getNumCapabilities
      -- Same heuristic as the untagged sort: at least 20000 elements per
      -- thread, oversubscribed 8x when more than one thread is used.
      let perThreadFloor = 20000
          threadCap = min caps (unsafeQuot len perThreadFloor)
          threads = if threadCap == 1 then 1 else threadCap * 8
      splitMergeParallelTagged dst work dstTags workTags threads 0 len
      return (dst,dstTags)

-- | Sort an immutable array. Only a single copy of each duplicated
-- element is preserved. The argument is copied first, so it is never
-- modified.
--
-- >>> sortUnique ([5,6,7,9,5,4,5,7] :: Array Int)
-- fromListN 5 [4,5,6,7,9]
sortUnique :: (Contiguous arr, Element arr a, Ord a)
  => arr a -> arr a
{-# INLINABLE sortUnique #-}
sortUnique src = runST $ do
  let n = C.size src
  buf <- C.new n
  C.copy buf 0 src 0 n
  sortUniqueMutable buf >>= C.unsafeFreeze

-- | Sort a mutable array, keeping only a single copy of each duplicated
-- element. This operation may run in-place, or it may
-- need to allocate a new array, so the argument may not be reused
-- after this function is applied to it.
sortUniqueMutable :: (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a
  -> ST s (Mutable arr s a)
{-# INLINABLE sortUniqueMutable #-}
sortUniqueMutable = sortMutable >=> uniqueMutable

-- | Discards adjacent equal elements from an array, keeping the first
-- element of every run of equal elements. This operation
-- may run in-place, or it may need to allocate a new array, so the
-- argument may not be reused after this function is applied to it.
uniqueMutable :: forall arr s a. (Contiguous arr, Element arr a, Eq a)
  => Mutable arr s a -> ST s (Mutable arr s a)
{-# INLINABLE uniqueMutable #-}
uniqueMutable !marr = do
  !len <- C.sizeMutable marr
  if len <= 1
    then return marr
    else do
      !headVal <- C.read marr 0
      -- Find the first index whose element equals its predecessor.
      -- Everything before it is already duplicate-free and never moves.
      let scan :: a -> Int -> ST s Int
          scan !prev !ix
            | ix >= len = return ix
            | otherwise = do
                cur <- C.read marr ix
                if cur == prev then return ix else scan cur (ix + 1)
      firstDup <- scan headVal 1
      if firstDup == len
        then return marr
        else do
          -- Compact the tail: each element that differs from the one
          -- before it is moved down to the next free slot.
          let compact :: a -> Int -> Int -> ST s Int
              compact !prev !from !to
                | from >= len = return to
                | otherwise = do
                    cur <- C.read marr from
                    if cur == prev
                      then compact cur (from + 1) to
                      else do
                        C.write marr to cur
                        compact cur (from + 1) (to + 1)
          !dupVal <- C.read marr firstDup
          !newLen <- compact dupVal (firstDup + 1) firstDup
          C.resize marr newLen

-- | Discard adjacent equal keys from the first @len@ slots of the key
-- array, moving values along with their keys. For each run of equal
-- keys, the first key is kept and it is paired with the value of the
-- last key in the run. Returns the (possibly resized) key and value
-- arrays.
uniqueTaggedMutableN :: forall karr varr s k v. (Contiguous karr, Element karr k, Eq k, Contiguous varr, Element varr v)
  => Int
  -> Mutable karr s k
  -> Mutable varr s v
  -> ST s (Mutable karr s k, Mutable varr s v)
{-# INLINABLE uniqueTaggedMutableN #-}
uniqueTaggedMutableN !len !marr !marrTags
  | len <= 1 = return (marr,marrTags)
  | otherwise = do
      !k0 <- C.read marr 0
      -- First index whose key equals its predecessor, or len if none.
      let scan :: k -> Int -> ST s Int
          scan !prev !ix
            | ix >= len = return ix
            | otherwise = do
                cur <- C.read marr ix
                if cur == prev then return ix else scan cur (ix + 1)
      firstDup <- scan k0 1
      if firstDup == len
        then return (marr,marrTags)
        else do
          -- The run containing firstDup is kept at firstDup - 1; give that
          -- slot the later value.
          C.read marrTags firstDup >>= C.write marrTags (firstDup - 1)
          let compact :: k -> Int -> Int -> ST s Int
              compact !prev !from !to
                | from >= len = return to
                | otherwise = do
                    cur <- C.read marr from
                    if cur == prev
                      then do
                        -- Same run: overwrite the value of the kept key.
                        C.read marrTags from >>= C.write marrTags (to - 1)
                        compact cur (from + 1) to
                      else do
                        -- New run: move the value and the key down.
                        C.read marrTags from >>= C.write marrTags to
                        C.write marr to cur
                        compact cur (from + 1) (to + 1)
          !dupKey <- C.read marr firstDup
          !newLen <- compact dupKey (firstDup + 1) firstDup
          liftA2 (,) (C.resize marr newLen) (C.resize marrTags newLen)

-- | Run an 'IO' action inside 'ST' by reinterpreting its state-token
-- function. This is unsafe: the action's effects are not sequenced by
-- the 'ST' state thread in any checked way. In this module it is only
-- used to read the number of capabilities.
unsafeEmbedIO :: IO a -> ST s a
unsafeEmbedIO (IO action) = ST (unsafeCoerce# action)

-- | Halve an 'Int', rounding toward zero.
half :: Int -> Int
half = (`unsafeQuot` 2)

-- | Parallel top-down merge sort of @[start, end)@ with the sorted result
-- left in @arr@. Callers set @work@ up as a copy of @arr@; the recursion
-- swaps the two roles at each level. While @level > 1@ the two halves
-- are sorted concurrently and then merged with 'mergeParallel' using
-- @level@ threads. Once @level@ drops to 1 (or below), the sequential
-- 'splitMerge' takes over.
splitMergeParallel :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a -- source and destination
  -> Mutable arr s a -- work array
  -> Int -- spark limit, should be power of two
  -> Int -- start
  -> Int -- end
  -> ST s ()
{-# INLINABLE splitMergeParallel #-}
splitMergeParallel !arr !work !level !start !end
  | level <= 1 = splitMerge arr work start end
  | end - start < threshold = insertionSortRange arr start end
  | otherwise = do
      let !mid = unsafeQuot (end + start) 2
          !subLevel = half level
      -- Sort both halves into `work` (using `arr` as scratch), then
      -- merge them from `work` back into `arr`.
      tandem
        (splitMergeParallel work arr subLevel start mid)
        (splitMergeParallel work arr subLevel mid end)
      mergeParallel work arr level start mid end

-- | Tagged variant of 'splitMergeParallel': sorts the keys in
-- @[start, end)@ and permutes the tags alongside them. Callers set the
-- work arrays up as copies of the source arrays; the sorted keys and
-- tags end up in @arr@ / @arrTags@.
--
-- Like the untagged version, a range shorter than 'thresholdTagged' is
-- insertion-sorted directly instead of being split further. This means
-- 'mergeParallelTagged' is never asked to partition a tiny (or empty)
-- range into @level@ chunks; for an empty range it would read one past
-- the end.
splitMergeParallelTagged :: forall karr varr s k v. (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k -- source and destination
  -> Mutable karr s k -- work array
  -> Mutable varr s v -- source and destination tags
  -> Mutable varr s v -- work tags
  -> Int -- spark limit, should be power of two
  -> Int -- start
  -> Int -- end
  -> ST s ()
{-# INLINABLE splitMergeParallelTagged #-}
splitMergeParallelTagged !arr !work !arrTags !workTags !level !start !end
  | level <= 1 = splitMergeTagged arr work arrTags workTags start end
  | end - start < thresholdTagged = insertionSortTaggedRange arr arrTags start end
  | otherwise = do
      let !mid = unsafeQuot (end + start) 2
          !levelDown = half level
      -- Sort both halves into the work arrays (using arr/arrTags as
      -- scratch), then merge them back into arr/arrTags.
      tandem
        (splitMergeParallelTagged work arr workTags arrTags levelDown start mid)
        (splitMergeParallelTagged work arr workTags arrTags levelDown mid end)
      mergeParallelTagged work arr workTags arrTags level start mid end

-- | Sequential top-down merge sort of @[start, end)@ with the result left
-- in @arr@. The roles of @arr@ and @work@ swap at each level of
-- recursion, so both must hold the same contents in that range on entry.
-- Ranges of at most 'threshold' elements are finished with insertion
-- sort.
splitMerge :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a -- source and destination
  -> Mutable arr s a -- work array
  -> Int -- start
  -> Int -- end
  -> ST s ()
{-# INLINABLE splitMerge #-}
splitMerge !arr !work !start !end
  | size < 2 = return ()
  | size <= threshold = insertionSortRange arr start end
  | otherwise = do
      let !mid = unsafeQuot (end + start) 2
      splitMerge work arr start mid
      splitMerge work arr mid end
      mergeNonContiguous work arr start mid mid end start
  where
  size = end - start

-- | Tagged variant of 'splitMerge': sequential top-down merge sort of the
-- keys in @[start, end)@ that moves each tag together with its key. The
-- sorted keys and tags end up in @arr@ / @arrTags@; source and work
-- arrays swap roles at each level of recursion.
splitMergeTagged :: (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k -- source and destination
  -> Mutable karr s k -- work array
  -> Mutable varr s v
  -> Mutable varr s v
  -> Int -- start
  -> Int -- end
  -> ST s ()
{-# INLINABLE splitMergeTagged #-}
splitMergeTagged !arr !work !arrTags !workTags !start !end
  | width < 2 = return ()
  | width <= thresholdTagged = insertionSortTaggedRange arr arrTags start end
  | otherwise = do
      let !mid = unsafeQuot (end + start) 2
      splitMergeTagged work arr workTags arrTags start mid
      splitMergeTagged work arr workTags arrTags mid end
      mergeNonContiguousTagged work arr workTags arrTags start mid mid end start
  where
  width = end - start

-- | Merge the sorted runs @src[start, mid)@ and @src[mid, end)@ into
-- @dst[start, end)@. The work is spread over up to @threads@ forked
-- threads, and this function blocks until all of them have finished.
--
-- The output is partitioned by value. Probe elements are read at evenly
-- spaced positions of the left run. For each probe, 'findIndexOfGtElem'
-- finds where the elements @<=@ the probe end in both runs, and those two
-- sub-runs are merged by one thread into the matching slice of @dst@.
-- Whatever is left is handed to a final thread. The chunks cover
-- disjoint, increasing ranges of values, and each chunk's merge prefers
-- the left run on ties, so the overall merge is stable.
--
-- Precondition: threads is greater than 0 and @start < end@.
mergeParallel :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a -- source
  -> Mutable arr s a -- dest
  -> Int -- threads
  -> Int -- start
  -> Int -- middle
  -> Int -- end
  -> ST s ()
{-# INLINABLE mergeParallel #-}
mergeParallel !src !dst !threads !start !mid !end = do
  !lock <- newLock
  let go :: Int -- previous A end
         -> Int -- previous B end
         -> Int -- how many chunks have we already iterated over
         -> ST s Int
      go !prevEndA !prevEndB !ix =
        if | prevEndA == mid && prevEndB == end -> return ix
           | prevEndA == mid -> do
               -- Left run exhausted: one thread copies the rest of the right run.
               forkST_ $ do
                 let !startA = mid
                     !endA = mid
                     !startB = prevEndB
                     !endB = end
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguous src dst startA endA startB endB startDst
                 putLock lock
               go mid end (ix + 1)
           | prevEndB == end -> do
               -- Right run exhausted: one thread copies the rest of the left run.
               forkST_ $ do
                 let !startA = prevEndA
                     !endA = mid
                     !startB = end
                     !endB = end
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguous src dst startA endA startB endB startDst
                 putLock lock
               go mid end (ix + 1)
           | ix >= threads - 1 -> do
               -- Out of threads: the final one merges everything left over.
               -- (`>=` rather than `==` so that threads == 1 also stops here.)
               forkST_ $ do
                 let !startA = prevEndA
                     !endA = mid
                     !startB = prevEndB
                     !endB = end
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguous src dst startA endA startB endB startDst
                 putLock lock
               return (ix + 1)
           | otherwise -> do
               -- We use the left half for this lookup. We could instead
               -- use both halves and take the median.
               !endElem <- C.read src (start + chunk * (ix + 1))
               !endA <- findIndexOfGtElem src (endElem :: a) prevEndA mid
               !endB <- findIndexOfGtElem src endElem prevEndB end
               forkST_ $ do
                 let !startA = prevEndA
                     !startB = prevEndB
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguous src dst startA endA startB endB startDst
                 putLock lock
               go endA endB (ix + 1)
  -- The first chunk holds everything <= the first probe.
  !endElem <- C.read src (start + chunk)
  !endA <- findIndexOfGtElem src (endElem :: a) start mid
  !endB <- findIndexOfGtElem src endElem mid end
  forkST_ $ do
    let !startA = start
        !startB = mid
        !startDst = (startA - start) + (startB - mid) + start
    mergeNonContiguous src dst startA endA startB endB startDst
    putLock lock
  total <- go endA endB 1
  -- Each forked merge fills the lock exactly once; wait for all of them.
  replicateM_ total (takeLock lock)
  where
  -- Spacing between probes. The probes stay inside the left run
  -- @[start, mid)@, so their values are non-decreasing and every chunk
  -- takes a share of the left run. Spacing them over the whole of
  -- @[start, end)@ would put the later probes in the right run, where
  -- their values are unrelated to the earlier probes. Those chunks can
  -- then come out empty, leaving the work on fewer threads.
  !chunk = unsafeQuot (mid - start) threads

-- | Tagged variant of 'mergeParallel'. It merges the sorted key runs
-- @src[start, mid)@ and @src[mid, end)@ into @dst[start, end)@ and moves
-- each tag from @srcTags@ to @dstTags@ together with its key. The work
-- is spread over up to @threads@ forked threads, and this function
-- blocks until all of them have finished.
--
-- The partitioning scheme is the same as 'mergeParallel': probes from
-- the left run split both runs by value, and each chunk is merged by
-- 'mergeNonContiguousTagged' on its own thread.
--
-- Precondition: threads is greater than 0 and @start < end@.
mergeParallelTagged :: forall karr varr s k v. (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k -- source
  -> Mutable karr s k -- dest
  -> Mutable varr s v -- source tags
  -> Mutable varr s v -- dest tags
  -> Int -- threads
  -> Int -- start
  -> Int -- middle
  -> Int -- end
  -> ST s ()
{-# INLINABLE mergeParallelTagged #-}
mergeParallelTagged !src !dst !srcTags !dstTags !threads !start !mid !end = do
  !lock <- newLock
  let go :: Int -- previous A end
         -> Int -- previous B end
         -> Int -- how many chunks have we already iterated over
         -> ST s Int
      go !prevEndA !prevEndB !ix =
        if | prevEndA == mid && prevEndB == end -> return ix
           | prevEndA == mid -> do
               -- Left run exhausted: one thread copies the rest of the right run.
               forkST_ $ do
                 let !startA = mid
                     !endA = mid
                     !startB = prevEndB
                     !endB = end
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguousTagged src dst srcTags dstTags startA endA startB endB startDst
                 putLock lock
               go mid end (ix + 1)
           | prevEndB == end -> do
               -- Right run exhausted: one thread copies the rest of the left run.
               forkST_ $ do
                 let !startA = prevEndA
                     !endA = mid
                     !startB = end
                     !endB = end
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguousTagged src dst srcTags dstTags startA endA startB endB startDst
                 putLock lock
               go mid end (ix + 1)
           | ix >= threads - 1 -> do
               -- Out of threads: the final one merges everything left over.
               -- (`>=` rather than `==` so that threads == 1 also stops here.)
               forkST_ $ do
                 let !startA = prevEndA
                     !endA = mid
                     !startB = prevEndB
                     !endB = end
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguousTagged src dst srcTags dstTags startA endA startB endB startDst
                 putLock lock
               return (ix + 1)
           | otherwise -> do
               -- We use the left half for this lookup. We could instead
               -- use both halves and take the median.
               !endElem <- C.read src (start + chunk * (ix + 1))
               !endA <- findIndexOfGtElem src (endElem :: k) prevEndA mid
               !endB <- findIndexOfGtElem src endElem prevEndB end
               forkST_ $ do
                 let !startA = prevEndA
                     !startB = prevEndB
                     !startDst = (startA - start) + (startB - mid) + start
                 mergeNonContiguousTagged src dst srcTags dstTags startA endA startB endB startDst
                 putLock lock
               go endA endB (ix + 1)
  -- The first chunk holds every key <= the first probe.
  !endElem <- C.read src (start + chunk)
  !endA <- findIndexOfGtElem src (endElem :: k) start mid
  !endB <- findIndexOfGtElem src endElem mid end
  forkST_ $ do
    let !startA = start
        !startB = mid
        !startDst = (startA - start) + (startB - mid) + start
    mergeNonContiguousTagged src dst srcTags dstTags startA endA startB endB startDst
    putLock lock
  total <- go endA endB 1
  -- Each forked merge fills the lock exactly once; wait for all of them.
  replicateM_ total (takeLock lock)
  where
  -- Spacing between probes. The probes stay inside the left run
  -- @[start, mid)@, so their values are non-decreasing and every chunk
  -- takes a share of the left run. Spacing them over the whole of
  -- @[start, end)@ would put the later probes in the right run, where
  -- their values are unrelated to the earlier probes. Those chunks can
  -- then come out empty, leaving the work on fewer threads.
  !chunk = unsafeQuot (mid - start) threads

-- | 'quot' without the zero-divisor and overflow checks; rounds toward
-- zero. The divisor must be non-zero.
unsafeQuot :: Int -> Int -> Int
unsafeQuot (I# n) (I# d) = I# (n `quotInt#` d)

-- | Binary search over the sorted slice @[start, end)@ for the first
-- index whose element is strictly greater than the needle. If the needle
-- is greater than or equal to everything in the slice, this returns the
-- end index, which may be out of bounds for the array. Callers of this
-- function should be able to handle that.
findIndexOfGtElem :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a -> a -> Int -> Int -> ST s Int
{-# INLINABLE findIndexOfGtElem #-}
findIndexOfGtElem !v !needle !start !end = search start end
  where
  -- The answer always lies in [lo, hi].
  search :: Int -> Int -> ST s Int
  search !lo !hi
    | lo >= hi = return lo
    | otherwise = do
        let !probe = lo + half (hi - lo)
        !val <- C.read v probe
        if | val == needle -> gallopToGtIndex v needle (probe + 1) hi
           | val < needle -> search (probe + 1) hi
           | otherwise -> search lo probe

-- | Find the first index in the sorted slice @[start, end)@ whose element
-- is strictly greater than @val@, returning @end@ if there is none.
--
-- This is a galloping (exponential) search. It probes at doubling
-- distances from @start@ until it sees an element greater than @val@ or
-- runs off the end, then finishes with a binary search inside the last
-- bracket. The cost is therefore logarithmic in the distance to the
-- answer rather than linear. That matters when a long run of elements
-- equal to @val@ follows @start@, which is exactly how
-- 'findIndexOfGtElem' calls it.
gallopToGtIndex :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a -> a -> Int -> Int -> ST s Int
{-# INLINABLE gallopToGtIndex #-}
gallopToGtIndex !v !val !start !end = gallop start 1
  where
  -- Invariant: every element in [start, lo) is <= val.
  gallop :: Int -> Int -> ST s Int
  gallop !lo !step
    | lo >= end = return end
    | otherwise = do
        let !probe = min (lo + step - 1) (end - 1)
        !a <- C.read v probe
        if a > val
          then bisect lo probe
          else gallop (probe + 1) (step * 2)
  -- First index in [lo, hi) holding an element > val, or hi if none.
  -- It is only called with v[hi] known to be > val.
  bisect :: Int -> Int -> ST s Int
  bisect !lo !hi
    | lo >= hi = return lo
    | otherwise = do
        let !m = lo + half (hi - lo)
        !a <- C.read v m
        if a > val
          then bisect lo m
          else bisect (m + 1) hi

-- | Stable merge of the sorted runs @src[startA, endA)@ and
-- @src[startB, endB)@ into @dst@, starting at index @startDst@. When two
-- elements compare equal, the one from run A is written first. Either run
-- may be empty.
mergeNonContiguous :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a -- source
  -> Mutable arr s a -- dest
  -> Int -- start A
  -> Int -- end A
  -> Int -- start B
  -> Int -- end B
  -> Int -- start destination
  -> ST s ()
{-# INLINABLE mergeNonContiguous #-}
mergeNonContiguous !src !dst !startA !endA !startB !endB !startDst
  | startB < endB = afterA startA startB startDst
  | startA < endA = afterB startA startB startDst
  | otherwise = return ()
  where
  -- Both cursors are in bounds: emit the smaller head element.
  pick :: Int -> Int -> Int -> ST s ()
  pick ixA ixB ixDst = do
    !a <- C.read src ixA
    !b <- C.read src ixB
    if (a :: a) <= b
      then do
        C.write dst ixDst a
        afterA (ixA + 1) ixB (ixDst + 1)
      else do
        C.write dst ixDst b
        afterB ixA (ixB + 1) (ixDst + 1)
  -- Only cursor A may have moved, so B is still known to be in bounds.
  -- Once A is exhausted, the rest of B is copied verbatim.
  afterA :: Int -> Int -> Int -> ST s ()
  afterA !ixA !ixB !ixDst
    | ixA < endA = pick ixA ixB ixDst
    | otherwise = C.copyMutable dst ixDst src ixB (endB - ixB)
  -- Mirror image of afterA.
  afterB :: Int -> Int -> Int -> ST s ()
  afterB !ixA !ixB !ixDst
    | ixB < endB = pick ixA ixB ixDst
    | otherwise = C.copyMutable dst ixDst src ixA (endA - ixA)

-- | Tagged variant of 'mergeNonContiguous'. It merges the sorted key runs
-- @src[startA, endA)@ and @src[startB, endB)@ into @dst@ starting at
-- @startDst@, and moves each tag from @srcTags@ to the same position of
-- @dstTags@ as its key. On ties the key from run A is written first.
mergeNonContiguousTagged :: forall karr varr k v s. (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k -- source
  -> Mutable karr s k -- dest
  -> Mutable varr s v -- source tags
  -> Mutable varr s v -- dest tags
  -> Int -- start A
  -> Int -- end A
  -> Int -- start B
  -> Int -- end B
  -> Int -- start destination
  -> ST s ()
{-# INLINABLE mergeNonContiguousTagged #-}
mergeNonContiguousTagged !src !dst !srcTags !dstTags !startA !endA !startB !endB !startDst
  | startB < endB = afterA startA startB startDst
  | startA < endA = afterB startA startB startDst
  | otherwise = return ()
  where
  -- Both cursors are in bounds: move the smaller head key and its tag.
  pick :: Int -> Int -> Int -> ST s ()
  pick ixA ixB ixDst = do
    !a <- C.read src ixA
    !b <- C.read src ixB
    if a <= b
      then do
        C.write dst ixDst a
        (C.read srcTags ixA :: ST s v) >>= C.write dstTags ixDst
        afterA (ixA + 1) ixB (ixDst + 1)
      else do
        C.write dst ixDst b
        (C.read srcTags ixB :: ST s v) >>= C.write dstTags ixDst
        afterB ixA (ixB + 1) (ixDst + 1)
  -- Only cursor A may have moved; when A is exhausted, copy the rest of B.
  afterA :: Int -> Int -> Int -> ST s ()
  afterA !ixA !ixB !ixDst
    | ixA < endA = pick ixA ixB ixDst
    | otherwise = copyRest ixB endB ixDst
  -- Mirror image of afterA.
  afterB :: Int -> Int -> Int -> ST s ()
  afterB !ixA !ixB !ixDst
    | ixB < endB = pick ixA ixB ixDst
    | otherwise = copyRest ixA endA ixDst
  -- Copy keys and tags in [from, to) of the source to the output.
  copyRest :: Int -> Int -> Int -> ST s ()
  copyRest !from !to !ixDst = do
    C.copyMutable dst ixDst src from (to - from)
    C.copyMutable dstTags ixDst srcTags from (to - from)

-- | Cut-off below which ranges are sorted with insertion sort instead of
-- being split further by the merge sort. The exact comparison (@<@ or
-- @<=@) varies by call site.
threshold :: Int
threshold = 16

-- | The insertion-sort cut-off used by the tagged (key/value) sort.
thresholdTagged :: Int
thresholdTagged = 16

-- | Stable in-place insertion sort of @arr[start, end)@.
insertionSortRange :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a
  -> Int -- start
  -> Int -- end
  -> ST s ()
{-# INLINABLE insertionSortRange #-}
insertionSortRange !arr !start !end = loop start
  where
  -- Invariant: arr[start, ix) is already sorted.
  loop :: Int -> ST s ()
  loop !ix
    | ix >= end = return ()
    | otherwise = do
        !x <- C.read arr ix
        insertElement arr (x :: a) start ix
        loop (ix + 1)
    
-- | Insert @a@ into the sorted slice @arr[start, end)@ by shifting the
-- elements greater than it one slot to the right. The slot at @end@ is
-- overwritten, so it must be where @a@ came from (or otherwise be free).
-- Equal elements already in the slice stay in front of @a@, which keeps
-- insertion sort stable.
insertElement :: forall arr s a. (Contiguous arr, Element arr a, Ord a)
  => Mutable arr s a
  -> a
  -> Int
  -> Int
  -> ST s ()
{-# INLINABLE insertElement #-}
insertElement !arr !a !start !end = go end
  where
  -- Shift arr[ix, end) right by one and drop `a` into the gap at ix.
  place :: Int -> ST s ()
  place !ix = do
    C.copyMutable arr (ix + 1) arr ix (end - ix)
    C.write arr ix a
  go :: Int -> ST s ()
  go !ix
    | ix <= start = place ix
    | otherwise = do
        !b <- C.read arr (ix - 1)
        if b <= a then place ix else go (ix - 1)

-- | Stable in-place insertion sort of the keys in @karr[start, end)@. The
-- value at each index of @varr@ moves together with its key.
insertionSortTaggedRange :: forall karr varr s k v. (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k
  -> Mutable varr s v
  -> Int -- start
  -> Int -- end
  -> ST s ()
{-# INLINABLE insertionSortTaggedRange #-}
insertionSortTaggedRange !karr !varr !start !end = loop start
  where
  -- Invariant: the keys in karr[start, ix) are already sorted.
  loop :: Int -> ST s ()
  loop !ix
    | ix >= end = return ()
    | otherwise = do
        !key <- C.read karr ix
        !val <- C.read varr ix
        insertElementTagged karr varr key val start ix
        loop (ix + 1)
    
-- | Insert the key/value pair @(a, v)@ into the sorted key slice
-- @karr[start, end)@, shifting greater keys (and their values in @varr@)
-- one slot to the right. Index @end@ of both arrays is overwritten, so
-- it must be where the pair came from. Equal keys stay in front of the
-- new one, which keeps the tagged insertion sort stable.
insertElementTagged :: forall karr varr s k v. (Contiguous karr, Element karr k, Ord k, Contiguous varr, Element varr v)
  => Mutable karr s k
  -> Mutable varr s v
  -> k
  -> v
  -> Int
  -> Int
  -> ST s ()
{-# INLINABLE insertElementTagged #-}
insertElementTagged !karr !varr !a !v !start !end = go end
  where
  -- Shift keys and values in [ix, end) right by one, then store the pair at ix.
  place :: Int -> ST s ()
  place !ix = do
    C.copyMutable karr (ix + 1) karr ix (end - ix)
    C.write karr ix a
    C.copyMutable varr (ix + 1) varr ix (end - ix)
    C.write varr ix v
  go :: Int -> ST s ()
  go !ix
    | ix <= start = place ix
    | otherwise = do
        !b <- C.read karr (ix - 1)
        if b <= a then place ix else go (ix - 1)


-- | Fork the given 'ST' action onto a new thread (via the 'fork#' primop)
-- and return immediately, discarding the thread id and the action's
-- result. Nothing here waits for the thread or reports exceptions from
-- it. Within this module, each forked action ends with a 'putLock' that
-- the parent later waits for with 'takeLock'.
forkST_ :: ST s a -> ST s ()
forkST_ action = ST $ \s1 -> case forkST# action s1 of
  (# s2, _ #) -> (# s2, () #)

-- | 'fork#' with its type coerced so that it accepts an 'ST' action and
-- works in an arbitrary state thread @s@. 'ST' is a newtype over a
-- state-passing function, so the runtime value handed to 'fork#' is that
-- function.
-- NOTE(review): this relies on 'fork#' accepting that representation of
-- the action; confirm against the GHC version in use.
forkST# :: a -> State# s -> (# State# s, ThreadId# #)
forkST# = unsafeCoerce# fork#

-- | A one-slot signal built on an 'MVar#' holding @()@. 'putLock' fills
-- it, and 'takeLock' blocks until it is full and then empties it.
data Lock s = Lock (MVar# s ())

-- | Create an empty lock.
newLock :: ST s (Lock s)
newLock = ST $ \st0 -> case newMVar# st0 of
  (# st1, mv #) -> (# st1, Lock mv #)

-- | Block until the lock is full, then empty it.
takeLock :: Lock s -> ST s ()
takeLock (Lock mv) = ST $ \st -> takeMVar# mv st

-- | Fill the lock. This blocks if the lock is already full.
putLock  :: Lock s -> ST s ()
putLock (Lock mv) = ST $ \st0 ->
  case putMVar# mv () st0 of
    st1 -> (# st1, () #)

-- | Execute the first computation on the current thread and the second
--   one on a newly forked thread, in parallel. Blocks until both are
--   finished.
tandem :: ST s () -> ST s () -> ST s ()
tandem a b = do
  done <- newLock
  -- The forked computation signals completion by filling the lock.
  forkST_ (b *> putLock done)
  a
  takeLock done

-- $setup
--
-- These are to make doctest work correctly.
--
-- >>> :set -XOverloadedLists
-- >>> import Data.Primitive.Array (Array)
--