{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | This module provides a variant of a mutable binary min heap that is used elsewhere to implement
--   Dijkstra\'s algorithm. It is unlikely that there are other uses for this specific
--   implementation. The binary heap in this module uses the standard array-as-binary-heap
--   approach where the 0-index item in the array is unused, the 1-index item is the
--   root, and the @n@ element has its left child at @2n@ and its right child at @2n + 1@.
--   The following additions (which are uncommon) have been made:
--
--   * This heap only supports 'Int' elements but is polymorphic in the priority data type.
--   * When the heap is initialized, it is given an 'Int'. This represents and exclusive upper bound
--     on the allowed elements. For example, if you pass 40, then you can only push 0 through 39 as
--     elements.
--   * Most of the functions in this module take an extra 'Int' (right after the 'RawHeap' argument).
--     This 'Int' tells us the number of items currently in the heap. In some cases, this argument
--     is not even used, but it is present so that LiquidHaskell can provide extra bound-checking assurances.
--     For example, if we initialize the heap with @new 90@, and then push three elements, we do not
--     want to be able to read the 83rd element in the 'rawHeapPriorities' and 'rawHeapElements'
--     arrays. Even though the this index is technically in bounds, the element stored there is not
--     actually in the heap. Many places where this bounding number is passed around to a function
--     should be eliminated by the inliner and do not affect runtime performance. The @ModelD@ module
--     provides a much more usable heap implementation where the currenty heap size is stored in a 'MutVar'.
--     It is not implemented as a 'MutVar' in here because LiquidHaskell cannot (to my knowledge)
--     use mutable values for meaningful proofs.
--   * This heap implements decrease-key as a part of 'push'. If you push an already existing element
--     onto the heap, the priority of the existing one and the priority of the one you are attempting
--     to push will be combined with the 'Monoid' instance. (Note: this could definitely be weakened
--     to 'Semigroup'). At the moment, only bubble up is attempted after this operation, so if this
--     causes the priority to increase, the heap becomes invalid (but not in a way that causes
--     segfaults).
--
--   As a result of the last constraint, the 'Monoid' instance and 'Ord' instance of the priority type
--   must obey these additional laws:
--
--   > mappend a b ≤ a
--   > mappend a b ≤ b
--   > mempty ≥ c
--
--   In more colloquial terms, the monoidal append of two priorities must be less than or equal
--   to the smaller of the two. Additionally, 'mempty' must be the largest priority.

module Data.Heap.Mutable.ModelC where

import           Control.Monad
import           Control.Monad.Primitive
import           Data.Bits                   (unsafeShiftL, unsafeShiftR)
import           Data.Coerce
import           Data.Primitive.Array
import qualified Data.Primitive.Array        as A
import           Data.Primitive.ByteArray
import           Data.Primitive.MutVar
import           Data.Primitive.Types        (sizeOf#)
import           Data.Vector                 (MVector, Vector)
import qualified Data.Vector                 as V
import qualified Data.Vector.Mutable         as MV
import           Data.Vector.Unboxed         (Unbox)
import qualified Data.Vector.Unboxed         as U
import qualified Data.Vector.Unboxed.Mutable as MU
import           Data.Word
import           Debug.Trace
import           GHC.Types                   (Int (..))


{-@ type Positive = {n:Int | n > 0} @-}

{-@ data RawHeap s p = RawHeap
      { rawHeapBound         :: Nat
      , rawHeapPriorities    :: (MutableArray s p)
      , rawHeapElements      :: (MutableByteArray s)
      , rawHeapInvertedIndex :: (MutableByteArray s)
      }
@-}
data RawHeap s p = RawHeap
  { rawHeapBound         :: !Int -- ^ This bound is exclusive
  , rawHeapPriorities    :: !(MutableArray s p) -- ^ Binary tree of priorities
  , rawHeapElements      :: !(MutableByteArray s) -- ^ Binary tree of elements
  , rawHeapInvertedIndex :: !(MutableByteArray s) -- ^ Lookup binary tree index by element, used for increase and decrease priority
  }

{-@ assume readHeapPriority :: PrimMonad m
      => h:RawHeap (PrimState m) p -> bound:{bound:Nat | bound <= rawHeapBound h}
      -> ix:{v:Nat | v > 0 && v <= bound } -> m p @-}
readHeapPriority :: PrimMonad m => RawHeap (PrimState m) p -> Int -> Int -> m p
readHeapPriority (RawHeap _ priorities _ _) _ ix = readArray priorities ix

{-@ assume readHeapElement :: PrimMonad m => h:RawHeap (PrimState m) p
      -> bound:{bound:Nat | bound <= rawHeapBound h}
      -> ix:{v:Positive | v <= bound }
      -> m {e:Nat | e < rawHeapBound h} @-}
readHeapElement :: PrimMonad m => RawHeap (PrimState m) p -> Int -> Int -> m Int
readHeapElement (RawHeap _ _ elements _) _ ix = readByteArray elements ix

{-@ assume writeHeapPriority :: PrimMonad m
      => h:RawHeap (PrimState m) p -> bound:{bound:Nat | bound <= rawHeapBound h}
      -> ix:{v:Nat | v > 0 && v <= bound } -> p -> m () @-}
writeHeapPriority :: PrimMonad m => RawHeap (PrimState m) p -> Int -> Int -> p -> m ()
writeHeapPriority (RawHeap _ priorities _ _) _ = writeArray priorities

{-@ assume writeHeapElement :: PrimMonad m
      => h:RawHeap (PrimState m) p
      -> bound:{bound:Nat | bound <= rawHeapBound h}
      -> ix:{v:Positive | v <= bound }
      -> e:{v:Nat | v < rawHeapBound h }
      -> m () @-}
writeHeapElement :: PrimMonad m => RawHeap (PrimState m) p -> Int -> Int -> Int -> m ()
writeHeapElement (RawHeap _ _ elements _) _ = writeByteArray elements

{-@ assume readHeapInvertedIndex :: PrimMonad m
      => h:RawHeap (PrimState m) p
      -> bound:{bound:Nat | bound <= rawHeapBound h }
      -> ix:{v:Nat | v < rawHeapBound h }
      -> m {x:Positive|x < bound}
@-}
readHeapInvertedIndex :: PrimMonad m => RawHeap (PrimState m) p -> Int -> Int -> m Int
readHeapInvertedIndex (RawHeap _ _ _ invIndex) _ ix = readByteArray invIndex ix

{-@ assume writeHeapInvertedIndex :: PrimMonad m
      => h:RawHeap (PrimState m) p
      -> bound:{bound:Nat | bound <= rawHeapBound h }
      -> e:{v:Nat | v < rawHeapBound h }
      -> ix:{v:Nat | v <= bound }
      -> m ()
@-}
writeHeapInvertedIndex :: PrimMonad m => RawHeap (PrimState m) p -> Int -> Int -> Int -> m ()
writeHeapInvertedIndex (RawHeap _ _ _ invIndex) _ = writeByteArray invIndex

{-@ swapHeap :: (Ord p, Monoid p, PrimMonad m)
      => h:RawHeap (PrimState m) p
      -> bound:{bound:Nat | bound <= rawHeapBound h}
      -> ix1:{ix1:Positive | ix1 <= bound}
      -> ix2:{ix2:Positive | ix2 <= bound}
      -> m () @-}
swapHeap :: PrimMonad m => RawHeap (PrimState m) p -> Int -> Int -> Int -> m ()
swapHeap h bound ix1 ix2 = do
  a <- readHeapElement h bound ix1
  b <- readHeapElement h bound ix2
  writeHeapElement h bound ix1 b
  writeHeapElement h bound ix2 a
  c <- readHeapPriority h bound ix1
  d <- readHeapPriority h bound ix2
  writeHeapPriority h bound ix1 d
  writeHeapPriority h bound ix2 c
  writeHeapInvertedIndex h bound a ix2
  writeHeapInvertedIndex h bound b ix1

{-@ pop :: (PrimMonad m, Ord p)
      => h:RawHeap (PrimState m) p
      -> bound:{bound:Nat | bound <= rawHeapBound h}
      -> m ({k:Nat | if bound = 0 then k = 0 else k = bound - 1},Maybe (p,Int))
@-}
pop :: (PrimMonad m, Ord p) => RawHeap (PrimState m) p -> Int -> m (Int,Maybe (p,Int))
pop h currentSize = if currentSize > 0
  then do
    let newSize = currentSize - 1
    priority <- readHeapPriority h currentSize 1
    element <- readHeapElement h currentSize 1
    writeHeapInvertedIndex h currentSize element 0
    if (newSize > 0)
      then do
        lastPriority <- readHeapPriority h currentSize currentSize
        lastElement <- readHeapElement h currentSize currentSize
        writeHeapPriority h currentSize 1 lastPriority
        writeHeapElement h currentSize 1 lastElement
        writeHeapInvertedIndex h newSize lastElement 1
        bubbleDown h newSize
      else return ()
    return (newSize, Just (priority,element))
  else return (currentSize,Nothing)

{-@ bubbleDown :: (Ord p, PrimMonad m)
      => h:RawHeap (PrimState m) p
      -> bound:{bound:Positive | bound <= rawHeapBound h}
      -> m ()
@-}
bubbleDown :: forall p m. (Ord p, PrimMonad m)
  => RawHeap (PrimState m) p
  -> Int
  -> m ()
bubbleDown h currentSizeX = go currentSizeX 1 where
  {-@ go :: bnd:{bnd:Positive | bnd <= rawHeapBound h} -> ix:Positive -> m () / [bnd - ix] @-}
  go :: Int -> Int -> m ()
  go !currentSize !ix = do
    let leftChildIx = ix + ix
        rightChildIx = leftChildIx + 1
    if rightChildIx > currentSize
      then if leftChildIx == currentSize
        then do
          let childIx = leftChildIx
          myPriority <- readHeapPriority h currentSize ix
          childPriority <- readHeapPriority h currentSize childIx
          if childPriority < myPriority
            then do
              myElement <- readHeapElement h currentSize ix
              childElement <- readHeapElement h currentSize childIx
              swapHeap h currentSize ix childIx
              -- go childIx -- not needed here bc we know there will not be further children
            else return ()
        else return ()
      else do
        myPriority <- readHeapPriority h currentSize ix
        leftChildPriority <- readHeapPriority h currentSize leftChildIx
        rightChildPriority <- readHeapPriority h currentSize rightChildIx
        let (childIx,childPriority) = if leftChildPriority < rightChildPriority
              then (leftChildIx,leftChildPriority)
              else (rightChildIx,rightChildPriority)
        if childPriority < myPriority
          then do
            myElement <- readHeapElement h currentSize ix
            childElement <- readHeapElement h currentSize childIx
            swapHeap h currentSize ix childIx
            go currentSize childIx
          else return ()

{-@ unsafePush :: (Ord p, Monoid p, PrimMonad m)
               => p -> e:Nat -> oldSize:Nat
               -> {h:RawHeap (PrimState m) p | e < rawHeapBound h && oldSize <= rawHeapBound h}
               -> m {newSize:Nat | newSize > 0} @-}
unsafePush :: forall m p k. (Ord p, Monoid p, PrimMonad m)
  => p -> Int -> Int -> RawHeap (PrimState m) p -> m Int
unsafePush priority element currentSize h@(RawHeap bound _ _ _) = do
  existingElemIndex <- readHeapInvertedIndex h currentSize element
  if existingElemIndex == 0
    then if currentSize < rawHeapBound h
      then do
        let newSize = currentSize + 1
        appendElem priority element newSize h
        return newSize
      else error "unsafePush: This cannot ever happen (2)"
    else if currentSize > 0
      then do
        combineElem priority element currentSize existingElemIndex h
        return currentSize
      else error "unsafePush: This cannot ever happen (1)"

{-@ appendElem :: (Ord p, Monoid p, PrimMonad m)
               => p -> e:Nat -> newSize:{newSize:Nat|newSize > 0}
               -> {h:RawHeap (PrimState m) p | e < rawHeapBound h && newSize <= rawHeapBound h}
               -> m () @-}
appendElem :: (Ord p, PrimMonad m)
  => p -> Int -> Int -> RawHeap (PrimState m) p -> m ()
appendElem priority element currentSize h = do
  writeHeapPriority h currentSize currentSize priority
  writeHeapElement h currentSize currentSize element
  writeHeapInvertedIndex h currentSize element currentSize
  bubbleUp currentSize currentSize h

{-@ combineElem :: (Ord p, Monoid p, PrimMonad m)
                => p -> e:Nat -> sz:Positive -> ix:{ix:Positive | ix <= sz}
                -> {h:RawHeap (PrimState m) p | e < rawHeapBound h && sz <= rawHeapBound h}
                -> m () @-}
combineElem :: (Monoid p, Ord p, PrimMonad m)
  => p -> Int -> Int -> Int -> RawHeap (PrimState m) p -> m ()
combineElem priority element currentSize existingIndex h = do
  existingPriority <- readHeapPriority h currentSize existingIndex
  let newPriority = mappend priority existingPriority
  writeHeapPriority h currentSize existingIndex newPriority
  bubbleUp currentSize existingIndex h

{-@ bubbleUp :: (Ord p, PrimMonad m)
             => sz:Positive
             -> ix:{ix:Positive|ix <= sz}
             -> {h:RawHeap (PrimState m) p | sz <= rawHeapBound h}
             -> m () @-}
bubbleUp :: (Ord p, PrimMonad m)
         => Int
         -> Int
         -> RawHeap (PrimState m) p
         -> m ()
bubbleUp currentSize startIx h = go startIx where
  go !ix = do
    let parentIx = ix `div` 2 -- getParentIndex ix, make this more efficient, use shifting
    if parentIx > 0
      then do
        myPriority <- readHeapPriority h currentSize ix
        parentPriority <- readHeapPriority h currentSize parentIx
        if myPriority < parentPriority
          then do
            swapHeap h currentSize ix parentIx
            go parentIx
          else return ()
      else return ()

{-@ new :: (PrimMonad m, Monoid p) => bound:Nat
      -> m ({n:Int| n = 0},{h:RawHeap (PrimState m) p | rawHeapBound h = bound}) @-}
new :: (PrimMonad m, Monoid p) => Int -> m (Int,RawHeap (PrimState m) p)
new bound = do
  let boundPlusOne = bound + 1
  priorities <- newArray boundPlusOne mempty
  elements <- newByteArray (boundPlusOne * (I# (sizeOf# (undefined :: Int))))
  invertedIndex <- newByteArray (bound * (I# (sizeOf# (undefined :: Int))))
  setByteArray invertedIndex 0 bound (0 :: Int)
  return (0,RawHeap bound priorities elements invertedIndex)