{-
Copyright (C) Stilo International plc, 2019
This source code is unpublished proprietary information of Stilo
Corporation. The copyright notice above does not evidence any
actual or intended publication of such source code.
This file may not be redistributed as source, either by itself,
or as part of software derived from this file.
-}
{-# LANGUAGE GADTs, FlexibleContexts, FlexibleInstances, MultiParamTypeClasses, ScopedTypeVariables, UndecidableInstances #-}
module Data.PriorityQueue (PQueue, Branching, Pruned,
branchable, prune, pruneAbove, pruneAlternativesAbove, mapWithCost, filter, mapMaybe, foldPeers,
canonical, pruneSubsets, strip, stripCommon, stripCost,
cost, leastCost, withCost) where
import Control.Applicative (Applicative(..), Alternative(..))
import Data.Coerce (coerce)
import Data.Foldable (Foldable(fold))
import Data.Monoid (Monoid(mempty, mappend), Alt(Alt, getAlt))
import Data.Semigroup (Semigroup((<>)))
import Prelude hiding (filter)
-- | Phantom tag for queues whose '<|>' keeps every alternative, merging them in cost order
-- (see the 'Alternative' instance for @PQueue Branching c@).
data Branching
-- | Phantom tag for queues whose '<|>' keeps only the cheapest value; 'prune' produces such queues.
data Pruned
-- | A queue of values of type @a@ ordered by their cost @c@. The phantom tag @t@ ('Branching' or 'Pruned') selects
-- how '<|>' combines alternatives. The tail of every node is a lazy field, so a queue can be consumed incrementally.
data PQueue t c a = Costly !c (PQueue t c a)        -- ^ every value in the rest of the queue costs an extra @c@
                  | Free !(Ground a) (PQueue t c a) -- ^ values available at the current accumulated cost, then the rest
                  | Empty String                    -- ^ no more values; the string is a diagnostic (e.g. from 'fail')
                  deriving Show
-- | A non-empty binary tree of values that share the same cost ("peers"), kept in left-to-right order.
data Ground a = Leaf a                       -- ^ a single value
              | Peer !(Ground a) !(Ground a) -- ^ two groups of peers, the left one first
              deriving Show
-- | Visits the peers left to right, following the 'Peer' tree structure.
instance Foldable Ground where
  foldMap f = go
    where
      go (Leaf x) = f x
      go (Peer l r) = go l <> go r
-- | Applies the function to every peer, preserving the tree shape.
instance Functor Ground where
  fmap f = go
    where
      go (Leaf x) = Leaf (f x)
      go (Peer l r) = Peer (go l) (go r)
-- | Grafts the argument tree under every function leaf, so each function is applied to every argument peer.
instance Applicative Ground where
  pure = Leaf
  gf <*> gx = case gf of
    Leaf f -> fmap f gx
    Peer l r -> Peer (l <*> gx) (r <*> gx)
-- | Visits the stored values in queue order, ignoring their costs.
instance Foldable (PQueue t c) where
  foldMap f = go
    where
      go (Costly _ rest) = go rest
      go (Free g rest) = foldMap f g <> go rest
      go Empty{} = mempty
-- | Maps every stored value, leaving costs and the queue structure untouched.
instance Functor (PQueue t c) where
  fmap f queue = case queue of
    Costly c rest -> Costly c (fmap f rest)
    Free g rest -> Free (fmap f g) (fmap f rest)
    Empty msg -> Empty msg
-- | Pairs every function in the left queue with every argument in the right queue. The cost layers of the two sides
-- are combined with '<>' (presumably cost addition). The partial products are merged with '<|>', which is why an
-- 'Alternative' instance is required.
instance (Alternative (PQueue t c), Semigroup c) => Applicative (PQueue t c) where
  -- Both sides are deferred: charge both costs at once.
  Costly c1 q1 <*> Costly c2 q2 = Costly (c1 <> c2) (q1 <*> q2)
  -- Only one side is deferred (the other is Free or Empty): every result pays that side's cost.
  Costly c q1 <*> q2 = Costly c (q1 <*> q2)
  q1 <*> Costly c q2 = Costly c (q1 <*> q2)
  -- Free values on both sides: (f + q1) * (a + q2) = f*a + f*q2 + q1*a + q1*q2.
  Free f q1 <*> Free a q2 = Free (f <*> a) (mapPeers (f <*>) q2 <|> mapPeers (<*> a) q1 <|> q1 <*> q2)
    -- mapPeers transforms every Ground of a queue, keeping its cost layers and terminal message.
    where mapPeers f (Free g q) = Free (f g) (mapPeers f q)
          mapPeers f (Costly c q) = Costly c (mapPeers f q)
          mapPeers _ (Empty msg) = Empty msg
  Empty msg <*> _ = Empty msg
  _ <*> Empty msg = Empty msg
  -- A single value at zero cost.
  pure a = Free (Leaf a) (Empty "")
  {-# INLINABLE (<*>) #-}
-- | 'Branching' choice keeps every alternative. The two queues are merged in cost order. Free values from both sides
-- at the same point become 'Peer's, with the left operand's values first.
instance (Num c, Ord c, Semigroup c) => Alternative (PQueue Branching c) where
  -- Both sides deferred: emit the smaller cost first and charge the other side only the difference.
  Costly c1 q1 <|> Costly c2 q2 = {-# SCC "AltB.compare" #-}
     case compare c1 c2
     of LT -> Costly c1 (q1 <|> Costly (c2 - c1) q2)
        GT -> Costly c2 (Costly (c1 - c2) q1 <|> q2)
        EQ -> Costly c1 (q1 <|> q2)
  -- Free values on both sides share the current cost: keep them all as peers.
  Free a q1 <|> Free b q2 = Free (Peer a b) (q1 <|> q2)
  -- Free values on one side come before the other side's deferred (Costly) or absent values.
  Free a q1 <|> q2 = Free a (q1 <|> q2)
  q1 <|> Free a q2 = Free a (q1 <|> q2)
  Empty{} <|> pq = pq
  pq <|> Empty{} = pq
  empty = Empty "empty"
  {-# INLINABLE (<|>) #-}
-- | 'Pruned' choice keeps only the cheapest value. Cost layers are merged as for 'Branching', but the first free
-- value reached on either side is kept and everything else in both operands is discarded. On a tie the left operand
-- wins.
instance (Num c, Ord c, Semigroup c) => Alternative (PQueue Pruned c) where
  -- Both sides deferred: emit the smaller cost first and charge the other side only the difference.
  Costly c1 q1 <|> Costly c2 q2 = {-# SCC "AltP.compare" #-}
     case compare c1 c2
     of LT -> Costly c1 (q1 <|> Costly (c2 - c1) q2)
        GT -> Costly c2 (Costly (c1 - c2) q1 <|> q2)
        EQ -> Costly c1 (q1 <|> q2)
  -- The first free value wins; the rest of both queues is dropped.
  Free a _ <|> _ = Free a (Empty "")
  _ <|> Free a _ = Free a (Empty "")
  Empty{} <|> pq = pq
  pq <|> Empty{} = pq
  empty = Empty "empty"
  {-# INLINABLE (<|>) #-}
-- | Runs the continuation on every stored value. The results for the peers of one 'Ground' are combined with '<|>'
-- (following the tree shape), and then with the results for the rest of the queue. A 'Costly' layer is charged on
-- everything bound under it.
instance (Semigroup c, Alternative (PQueue t c)) => Monad (PQueue t c) where
  queue >>= k = case queue of
    Costly c rest -> Costly c (rest >>= k)
    Free g rest -> getAlt (foldMap (Alt . k) g) <|> (rest >>= k)
    Empty msg -> Empty msg
  {-# INLINABLE (>>=) #-}
-- | Failing yields an empty queue that records the failure message.
instance (Semigroup c, Alternative (PQueue t c)) => MonadFail (PQueue t c) where
  fail msg = Empty msg
-- | @withCost k q@ charges an extra cost of @k@ on every value in @q@.
--
-- A zero cost returns @q@ unchanged, and a negative cost is rejected with an error.
withCost :: (Semigroup c, Num c, Ord c) => c -> PQueue t c a -> PQueue t c a
withCost c q
  | c == 0 = q
  | c < 0 = error "The cost must be non-negative!"
  | otherwise = Costly c q
{-# INLINE withCost #-}
-- | Fold together all stored values that share the same priority.
--
-- Each run of adjacent 'Free' nodes, which are values at the same accumulated cost, is collapsed into one 'Free' node
-- holding a single 'Leaf'. Peers inside one node are combined following the 'Peer' tree shape (see
-- 'foldGroundPeers'). The combined values of adjacent nodes are then joined right-nested, as
-- @f v1 (f v2 (f v3 ...))@. Cost layers and the terminal 'Empty' are preserved.
foldPeers :: (a -> a -> a) -> PQueue t c a -> PQueue t c a
foldPeers _ q@Empty{} = q
foldPeers f (Costly c q) = Costly c (foldPeers f q)
foldPeers f (Free g q) = Free (Leaf a'') q''
  -- Lazy pattern binding: the combined value and the new tail are computed only on demand. Because the recursive
  -- result is already folded, a following 'Free' node always holds a single 'Leaf'.
  where (a'', q'') = case foldPeers f q
                     of Free (Leaf b) q' -> (f a' b, q')
                        q' -> (a', q')
        a' = foldGroundPeers f g
-- | Combine all peers of a 'Ground' tree with the given function, following the tree's 'Peer' structure.
foldGroundPeers :: (a -> a -> a) -> Ground a -> a
foldGroundPeers combine = go
  where
    go (Leaf x) = x
    go (Peer l r) = combine (go l) (go r)
-- | Imposes the given cost on the current computation branch.
--
-- > cost k = withCost k (pure ())
--
-- As with 'withCost', a zero cost adds no 'Costly' layer, and a negative cost is rejected with the same error.
cost :: (Semigroup c, Num c, Ord c) => c -> PQueue Branching c ()
cost 0 = pure ()
cost k
  | k > 0 = Costly k (pure ())
  -- Without this clause a negative cost would fail with an uninformative non-exhaustive pattern error.
  | otherwise = error "The cost must be non-negative!"
-- | Relax the 'Pruned' phantom tag, so the queue can be used wherever a 'Branching' (or any other) queue is expected.
-- The tag is a phantom parameter, so the conversion is a zero-cost 'coerce'.
branchable :: PQueue Pruned c a -> PQueue t c a
branchable queue = coerce queue
-- | Prune away every stored value whose total cost exceeds the given budget. Values costing exactly the budget are
-- kept. A negative budget yields an empty queue tagged "pruned".
pruneAbove :: (Semigroup c, Num c, Ord c) => c -> PQueue t c a -> PQueue t c a
pruneAbove budget queue
  | budget < 0 = Empty "pruned"
  | otherwise = case queue of
      Costly c rest ->
        let remaining = budget - c
        in if remaining < 0
             then Empty "pruned"
             else Costly c (pruneAbove remaining rest)
      Free g rest -> Free g (pruneAbove budget rest)
      Empty{} -> queue
{-# INLINABLE pruneAbove #-}
-- | Prune away all stored values more expensive than the given cost, provided a cheaper alternative value exists.
--
-- The 'Costly' layers in front of the first value are walked, and their costs are subtracted from the budget.
-- If the budget drops to zero or below before any value is reached, the rest of the queue is returned unpruned.
-- Otherwise the first value (with its peers) is kept, and 'pruneAbove' drops every later value whose total cost
-- exceeds the original budget.
--
-- NOTE(review): when the cheapest value costs exactly the budget, the queue is returned unpruned, even though such a
-- value would survive 'pruneAbove' and could shadow dearer ones. When the cheapest value costs more than the budget,
-- every value is kept, not only the cheapest. Confirm that both boundaries are intended.
pruneAlternativesAbove :: (Semigroup c, Num c, Ord c) => c -> PQueue t c a -> PQueue t c a
pruneAlternativesAbove k q
  | k <= 0 = q
-- Still looking for the cheapest value: spend the budget on this cost layer.
pruneAlternativesAbove k (Costly c q) = Costly c (pruneAlternativesAbove (k - c) q)
-- Found the cheapest value within budget: keep it and prune the rest against the remaining budget.
pruneAlternativesAbove k (Free a q) = Free a (pruneAbove k q)
pruneAlternativesAbove _ q@Empty{} = q
{-# INLINABLE pruneAlternativesAbove #-}
-- | Keep only the cheapest stored value, making the queue 'Pruned'.
--
-- The cost layers in front of the first value are preserved. Of that value's peers only the leftmost survives, and
-- the rest of the queue is discarded.
prune :: PQueue t c a -> PQueue Pruned c a
prune queue = case queue of
  Costly c rest -> Costly c (prune rest)
  Free g _ -> Free (Leaf (firstPeer g)) (Empty "")
  Empty msg -> Empty msg
  where
    firstPeer :: Ground b -> b
    firstPeer (Leaf x) = x
    firstPeer (Peer l _) = firstPeer l
-- | Minimize the queue structure by merging every run of directly nested 'Costly' layers into a single layer, whose
-- cost is the '<>'-combination of theirs.
--
-- The whole spine is rebuilt, but the result is still produced lazily, one node at a time. Forcing a node evaluates
-- only the run of 'Costly' layers it merges.
canonical :: Semigroup c => PQueue t c a -> PQueue t c a
canonical queue = case queue of
  Costly c1 (Costly c2 rest) -> canonical (Costly (c1 <> c2) rest)
  Costly c rest -> Costly c (canonical rest)
  Free g rest -> Free g (canonical rest)
  Empty{} -> queue
-- | Remove from the queue every value for which the predicate returns 'False'.
--
-- Costs of the surviving values are unchanged. A 'Free' node whose peers are all removed disappears entirely.
filter :: (a -> Bool) -> PQueue t c a -> PQueue t c a
filter keep = go
  where
    go (Costly c rest) = Costly c (go rest)
    go (Free g rest) = case keepGround g of
      Just g' -> Free g' (go rest)
      Nothing -> go rest
    go q@Empty{} = q
    -- The surviving peers of a tree, or Nothing if none survive.
    keepGround g@(Leaf x)
      | keep x = Just g
      | otherwise = Nothing
    keepGround (Peer l r) = case keepGround l of
      Nothing -> keepGround r
      Just l' -> case keepGround r of
        Nothing -> Just l'
        Just r' -> Just (Peer l' r')
-- | Map every value through the function, removing those it maps to 'Nothing'.
--
-- Costs of the surviving values are unchanged. A 'Free' node whose peers are all removed disappears entirely.
mapMaybe :: (a -> Maybe b) -> PQueue t c a -> PQueue t c b
mapMaybe f = go
  where
    go (Costly c rest) = Costly c (go rest)
    go (Free g rest) = case pickGround g of
      Just g' -> Free g' (go rest)
      Nothing -> go rest
    go (Empty msg) = Empty msg
    -- The mapped surviving peers of a tree, or Nothing if none survive.
    pickGround (Leaf x) = Leaf <$> f x
    pickGround (Peer l r) = case pickGround l of
      Nothing -> pickGround r
      Just l' -> case pickGround r of
        Nothing -> Just l'
        Just r' -> Just (Peer l' r')
-- | Assuming the stored values belong to a cancellative monoid, prune away all extraneous values and factors.
--
-- The values are visited in queue order while an accumulator is threaded through. The second argument is its
-- initial value (the monoid null). For each value @b@, @unionDiff acc b@ returns 'Nothing' to drop @b@. Otherwise it
-- returns @Just (acc', d)@, which replaces @b@ by its difference @d@ and continues with the enlarged @acc'@.
--
-- > fold (pruneSubsets plusDiff mempty pq) == fold pq
-- > where plusDiff u a
-- >         | gcd u a == u = Nothing
-- >         | d <- a - gcd u a = Just (u <> d, d)
pruneSubsets :: (a -> b -> Maybe (a, b)) -> a -> PQueue t c b -> PQueue t c b
pruneSubsets unionDiff = go
  where
    go set (Costly c rest) = Costly c (go set rest)
    go set (Free g rest) = case pruneGroundSubsets unionDiff set g of
      Just (set', g') -> Free g' (go set' rest)
      Nothing -> go set rest
    go _ q@Empty{} = q
-- | The 'Ground'-level step of 'pruneSubsets'. It threads the accumulator through the peers left to right and
-- returns the final accumulator with the surviving (diffed) peers, or 'Nothing' if every peer was dropped.
pruneGroundSubsets :: (a -> b -> Maybe (a, b)) -> a -> Ground b -> Maybe (a, Ground b)
pruneGroundSubsets unionDiff = go
  where
    go set (Leaf l) = case unionDiff set l of
      Just (set', l') -> Just (set', Leaf l')
      Nothing -> Nothing
    go set (Peer g1 g2) = case go set g1 of
      Nothing -> go set g2
      Just (set', g1') -> case go set' g2 of
        Just (set'', g2') -> Just (set'', Peer g1' g2')
        Nothing -> Just (set', g1')
-- | Factor out the cost shared by all the queues. Returns the common cost paired with the queues that have that cost
-- stripped off. The common cost (the GCD of the penalties) is the cost of the cheapest value across all the queues.
-- It is given as a 'Pruned' queue whose only value is 'id'.
--
-- > gcd <*> rest == f
-- > where (gcd, rest) = stripCommon f
stripCommon :: (Ord c, Num c, Functor f, Foldable f, Alternative (PQueue t c)) =>
               f (PQueue t c a) -> (PQueue Pruned c (a -> a), f (PQueue t c a))
stripCommon queues = (common, fmap (strip common) queues)
  where
    -- Merge all queues, keep only the cheapest path, and forget its value.
    common = fmap (const id) (prune (getAlt (foldMap Alt queues)))
-- | Subtract the cost carried by the first queue's leading 'Costly' layers from the cost of every value in the second
-- queue. Anything after those layers in the first queue is ignored.
strip :: (Ord c, Num c) => PQueue Pruned c a -> PQueue t c b -> PQueue t c b
strip common q = case common of
  Costly c rest -> stripCost c (strip rest q)
  _ -> q
-- | Subtract the given cost from the cost of every value in the queue.
--
-- The cost is peeled off the leading 'Costly' layers. A layer more expensive than what remains is shortened, and
-- cheaper layers are consumed entirely. An empty queue stays empty. Reaching a value before the cost is exhausted
-- means some value is cheaper than the stripped cost: a zero cost leaves the queue as it is, and any other cost is
-- an error.
stripCost :: (Ord c, Num c) => c -> PQueue t c a -> PQueue t c a
stripCost c (Costly c' q)
  | c < c' = Costly (c' - c) q
  | c > c' = stripCost (c - c') q
  | otherwise = q
-- This case used to be missing, so it failed with a non-exhaustive pattern error.
stripCost c q@Free{}
  | c == 0 = q
  | otherwise = error "stripCost: cannot strip a non-zero cost from a value that costs nothing"
stripCost _ q@Empty{} = q
-- | Returns 'Just' the cost of the cheapest value in the queue, or 'Nothing' if the queue holds no values. That cost
-- is the '<>'-combination of the 'Costly' layers in front of the first 'Free' node, right-nested and ending in
-- 'mempty'.
leastCost :: Monoid c => PQueue t c a -> Maybe c
leastCost queue = case queue of
  Costly c rest -> fmap (c <>) (leastCost rest)
  Free{} -> Just mempty
  Empty{} -> Nothing
-- | Maps each value in the queue, passing its cost as the first argument. The cost is the '<>'-combination of the
-- enclosing 'Costly' layers, outermost first, ending in 'mempty'.
mapWithCost :: Monoid c => (c -> a -> b) -> PQueue t c a -> PQueue t c b
mapWithCost f queue = case queue of
  Costly c rest -> Costly c (mapWithCost (\c' -> f (c <> c')) rest)
  Free g rest -> Free (fmap (f mempty) g) (mapWithCost f rest)
  Empty msg -> Empty msg