{-# LANGUAGE MultiParamTypeClasses,
             FlexibleInstances,
             GeneralizedNewtypeDeriving,
             NoMonomorphismRestriction #-}

module Data.FingerTree.PSQueue
  (Binding (..), PSQ, size, Data.FingerTree.PSQueue.null, Data.FingerTree.PSQueue.lookup,
   empty, singleton, alter, Data.FingerTree.PSQueue.delete, adjust, adjustWithKey,
   update, updateWithKey, toList, keys, fromList, fromAscList, minView,
   findMin, deleteMin, range, atMost, Data.FingerTree.PSQueue.foldr,
   Data.FingerTree.PSQueue.foldl) where

import qualified Data.Foldable as F
import qualified Data.FingerTree as FT
import Data.FingerTree (FingerTree, ViewL (..), (<|), (|>),
                        Measured (..), split, viewl, (><))
import Data.Monoid
import Data.Ord
import Data.List

-- | The binding with the smallest priority in a (sub)tree, or 'PMax' for an
-- empty one. 'PMax' behaves like an infinitely large priority.
data Prio k a = Prio (Binding k a) | PMax
  deriving (Eq, Ord, Show)

-- | Keeps the operand with the smaller priority; on a tie the left operand
-- wins, so the measure of a tree is its leftmost minimum-priority binding.
-- NOTE(review): GHC >= 8.4 also requires a 'Semigroup' instance for every
-- 'Monoid' instance; confirm which compiler this module targets.
instance (Ord a) => Monoid (Prio k a) where
  mempty = PMax
  mappend l r = case (l, r) of
    (PMax, _) -> r
    (_, PMax) -> l
    (Prio (_ :-> pl), Prio (_ :-> pr))
      | pl <= pr  -> l
      | otherwise -> r

-- | The last key of a (sub)tree, or 'NoKey' for an empty one. The derived
-- 'Ord' places 'NoKey' below every 'Key'.
data Key a = NoKey | Key a
  deriving (Eq, Ord, Show)

-- | Keeps the right operand unless it is 'NoKey', so the measure of a
-- sequence is its rightmost key.
instance Monoid (Key k) where
  mempty = NoKey
  mappend l r = case r of
    NoKey -> l
    _     -> r

-- | @geqKey k m@ is 'True' when the key annotation @m@ (the last key of a
-- prefix of the key-ordered tree) is at or beyond @k@. An empty prefix
-- ('NoKey') is never at or beyond any key. The predicate is monotone along
-- the tree, which is what 'split' requires to locate the position of @k@.
geqKey :: Ord k => k -> Key k -> Bool
geqKey _ NoKey    = False
geqKey k (Key k') = k' >= k

-- | A key paired with its priority, written @key :-> priority@. The derived
-- 'Ord' compares keys first and priorities second.
data Binding k p = k :-> p
  deriving (Eq, Ord, Show)

-- | The annotation cached at every node of the finger tree.
data KPS k p = KPS { kpsKey  :: !(Key k)    -- ^ last key of the subtree
                   , kpsPrio :: !(Prio k p) -- ^ leftmost minimum-priority binding
                   , kpsSize :: !(Sum Int)  -- ^ number of bindings
                   }
  deriving (Show)

-- | Annotations are equal when their key components are equal.
instance Eq k => Eq (KPS k p) where
  KPS ka _ _ == KPS kb _ _ = ka == kb

-- | Annotations are ordered by their key component only.
instance Ord k => Ord (KPS k p) where
  compare a b = kpsKey a `compare` kpsKey b

-- | Combines annotations component-wise.
instance (Ord p) => Monoid (KPS k p) where
  mempty = KPS { kpsKey = mempty, kpsPrio = mempty, kpsSize = mempty }
  mappend (KPS k1 p1 n1) (KPS k2 p2 n2) =
    KPS (k1 `mappend` k2) (p1 `mappend` p2) (n1 `mappend` n2)


-- | A single binding is measured as its own key, itself as the minimum, and
-- a size of one.
instance (Ord k, Ord p) => Measured (KPS k p) (Binding k p) where
  measure b@(k :-> _) = KPS { kpsKey = Key k, kpsPrio = Prio b, kpsSize = Sum 1 }

-- | A priority search queue: a finger tree of bindings kept in ascending key
-- order, where every node caches a 'KPS' annotation (last key, leftmost
-- minimum-priority binding, size). The key component supports searching by
-- key and the priority component supports finding the minimum. 'Eq', 'Ord'
-- and 'Show' come from the underlying 'FingerTree'; 'Measured' exposes the
-- annotation of the whole queue.
newtype PSQ k p = PSQ (FingerTree (KPS k p) (Binding k p))
  deriving (Eq, Ord, Show, Measured (KPS k p))

-- | O(1). The number of bindings in a queue, read from the cached annotation.
size :: (Ord k, Ord p) => PSQ k p -> Int
size q = getSum (kpsSize (measure q))

-- | O(1). Test whether a queue has no bindings.
null :: (Ord k, Ord p) => PSQ k p -> Bool
null (PSQ t) = case viewl t of
  EmptyL -> True
  _ :< _ -> False

-- | O(log n). Determine whether a key is in the queue and, if so, return its
-- priority.
lookup :: (Ord k, Ord p) => k -> PSQ k p -> Maybe p
lookup k (PSQ q) =
  case viewl atOrAfter of
    (k' :-> p) :< _ | k == k' -> Just p
    _                         -> Nothing
  where
    -- The tree is in ascending key order, so k (if present) is the first
    -- binding whose key is >= k.
    (_, atOrAfter) = split ((>= Key k) . kpsKey) q

-- | O(1). The queue with no bindings.
empty :: (Ord k, Ord p) => PSQ k p
empty = PSQ mempty

-- | O(1). Build a queue holding exactly one key/priority binding.
singleton :: (Ord k, Ord p) => k -> p -> PSQ k p
singleton k p = PSQ . FT.singleton $ k :-> p

-- | O(log n). Alter a priority search queue such that
-- @lookup k (alter f k q) = f (lookup k q)@. This can be used to insert,
-- delete, or update a priority in a queue.
alter :: (Ord k, Ord p) => (Maybe p -> Maybe p) -> k -> PSQ k p -> PSQ k p
alter f k (PSQ q) =
  PSQ $ case viewl v of
    -- Every key is below k: k is absent and belongs at the very end.
    EmptyL -> case f Nothing of
      Nothing -> q
      Just p  -> q |> (k :-> p)
    (k' :-> p') :< v'
      -- k is present: replace (or drop) its binding, storing the NEW priority.
      | k == k'   -> case f (Just p') of
          Nothing -> u >< v'
          Just p  -> u >< ((k :-> p) <| v')
      -- k is absent and belongs just before k'.
      | otherwise -> case f Nothing of
          Nothing -> q
          Just p  -> u >< ((k :-> p) <| v)
  where
    -- The tree is kept in ascending key order and each annotation carries the
    -- last key of its prefix, so this splits at the first key >= k (the same
    -- predicate 'lookup' and 'range' use).
    (u, v) = split ((>= Key k) . kpsKey) q

-- | O(log n). Remove a key and its priority from a queue; a queue that does
-- not contain the key is returned unchanged.
delete :: (Ord k, Ord p) => k -> PSQ k p -> PSQ k p
delete k = alter (\_ -> Nothing) k

-- | O(log n). Apply a function to the priority of a key, if that key is in
-- the queue; otherwise the queue is unchanged.
adjust :: (Ord k, Ord p) => (p -> p) -> k -> PSQ k p -> PSQ k p
adjust f k = alter (maybe Nothing (Just . f)) k

-- | O(log n). Like 'adjust', but the function also receives the key whose
-- priority is being changed.
adjustWithKey :: (Ord k, Ord p) => (k -> p -> p) -> k -> PSQ k p -> PSQ k p
adjustWithKey f k = alter (fmap (f k)) k

-- | O(log n). If the key is present, replace its priority with the result of
-- the function, or delete the key when the function returns 'Nothing'.
update :: (Ord k, Ord p) => (p -> Maybe p) -> k -> PSQ k p -> PSQ k p
update f k = alter (maybe Nothing f) k

-- | O(log n). Like 'update', but the function also receives the key whose
-- priority is being updated or deleted.
updateWithKey :: (Ord k, Ord p) => (k -> p -> Maybe p) -> k -> PSQ k p -> PSQ k p
updateWithKey f k = alter (maybe Nothing (f k)) k

-- | O(n). Flatten a queue into its list of bindings, in ascending key order.
toList :: (Ord k, Ord p) => PSQ k p -> [Binding k p]
toList (PSQ q) = F.foldr (:) [] q

-- | O(n). The keys of a queue, in ascending order.
keys :: (Ord k, Ord p) => PSQ k p -> [k]
keys q = [k | k :-> _ <- toList q]

-- | O(n log n). Construct a queue from a list of bindings. If a key occurs
-- more than once, the binding that appears last in the list is kept (as with
-- @Data.Map.fromList@), so the result holds at most one binding per key.
fromList :: (Ord k, Ord p) => [Binding k p] -> PSQ k p
fromList = PSQ . FT.fromList . keepLast . sortBy (comparing bindingKey)
  where
    bindingKey (k :-> _) = k
    -- 'sortBy' is stable, so within a run of equal keys the input order is
    -- preserved; keeping the last element of each run keeps the last
    -- occurrence from the input.
    keepLast (b@(k :-> _) : rest@((k' :-> _) : _))
      | k == k'   = keepLast rest
      | otherwise = b : keepLast rest
    keepLast bs = bs

-- | O(n). Construct a queue from a list of bindings already in ascending key
-- order. Does not check that the list is sorted or that its keys are
-- distinct; the other operations assume both.
fromAscList :: (Ord k, Ord p) => [Binding k p] -> PSQ k p
fromAscList bs = PSQ (FT.fromList bs)

-- | O(log n). Split a queue into its minimum-priority binding and the
-- remaining queue, or 'Nothing' if the queue is empty.
minView :: (Ord k, Ord p) => PSQ k p -> Maybe (Binding k p, PSQ k p)
minView (PSQ q) =
  case viewl rest of
    EmptyL     -> Nothing
    b :< after -> Just (b, PSQ (before >< after))
  where
    -- The cached minimum of the whole tree; the split point is the first
    -- prefix whose own minimum is exactly that binding.
    target         = kpsPrio (measure q)
    (before, rest) = split (\m -> kpsPrio m == target) q

-- | O(1). The binding with minimum priority, read from the cached annotation,
-- or 'Nothing' for an empty queue. Ties go to the binding with the smallest
-- key.
findMin :: (Ord k, Ord p) => PSQ k p -> Maybe (Binding k p)
findMin q =
  case kpsPrio (measure q) of
    Prio b -> Just b
    PMax   -> Nothing

-- | O(log n). Remove the minimum-priority binding; an empty queue is returned
-- unchanged.
deleteMin :: (Ord k, Ord p) => PSQ k p -> PSQ k p
deleteMin q = maybe q snd (minView q)

-- | O(log n). The expression @range (l,u) q@ selects the bindings of @q@
-- whose keys @k@ satisfy @l <= k@ and @k <= u@.
range :: (Ord k, Ord p) => (k, k) -> PSQ k p -> PSQ k p
range (l, u) (PSQ q) = PSQ within
  where
    -- drop everything below l, then keep everything up to and including u
    (_, fromLow) = split (\m -> kpsKey m >= Key l) q
    (within, _)  = split (\m -> kpsKey m > Key u) fromLow

-- | O(r (log n)), where r is the number of bindings returned. Finds all the
-- bindings in a queue whose priority is strictly less than the given value,
-- in ascending key order. Bindings whose priority equals the given value are
-- not included.
-- NOTE(review): despite the name, the comparison is strict (@<@); confirm
-- this is intended.
atMost :: (Ord k, Ord p) => p -> PSQ k p -> [Binding k p]
atMost p (PSQ q) =
  case viewl v of
    EmptyL    -> []
    (b :< v') -> b : atMost p (PSQ v')
  where
    -- True once a prefix contains some binding below the bound; an empty
    -- prefix ('PMax') contains none, so the predicate is total.
    below PMax                = False
    below (Prio (_ :-> prio)) = prio < p
    -- v starts at the leftmost binding whose priority is below the bound.
    (_, v) = split (below . kpsPrio) q

-- | Right fold over the bindings of a queue, in ascending key order.
foldr :: (Ord k, Ord p) => (Binding k p -> b -> b) -> b -> PSQ k p -> b
foldr f z (PSQ q) = F.foldr f z q

-- | Left fold over the bindings of a queue, in ascending key order. Like
-- 'Prelude.foldl', it is lazy in the accumulator.
foldl :: (Ord k, Ord p) => (b -> Binding k p -> b) -> b -> PSQ k p -> b
foldl f z (PSQ q) = F.foldl f z q