{-# LANGUAGE ScopedTypeVariables #-}

module Data.Record.Anon.Internal.Plugin.TC.EquivClasses (
    constructEquivClasses
  , canonicalize
  ) where

import Data.Bifunctor
import Data.Foldable (toList)
import Data.Graph (Graph, Vertex)
import Data.Map (Map)
import Data.Set (Set)

import qualified Data.Graph as Graph
import qualified Data.Map   as Map
import qualified Data.Set   as Set

-- | Given a list of pairs of equivalent values, map every value to the
-- canonical representative of its equivalence class.
--
-- Two values are equivalent when they are connected, in either direction,
-- through a chain of the given pairs (the connected components of the graph
-- whose edges are the pairs). The canonical representative of each class is
-- its smallest element according to 'Ord'. Values that do not occur in any
-- pair are absent from the result.
--
-- Example with two classes:
--
-- >>> constructEquivClasses [(1, 2), (4, 5), (2, 3)]
-- fromList [(1,1),(2,1),(3,1),(4,4),(5,4)]
--
-- Adding one element that connects both classes:
--
-- >>> constructEquivClasses [(1, 2), (4, 5), (2, 3), (3, 4)]
-- fromList [(1,1),(2,1),(3,1),(4,1),(5,1)]
--
-- No pairs means no classes:
--
-- >>> constructEquivClasses ([] :: [(Int, Int)])
-- fromList []
constructEquivClasses :: forall a. Ord a => [(a, a)] -> Map a a
constructEquivClasses equivs =
    Map.unions $ map (pickCanonical . map fromVertex . toList) $
      Graph.components graph
  where
    -- Every value mentioned in any pair, deduplicated and sorted.
    allValues :: Set a
    allValues = Set.fromList $ concatMap (\(x, y) -> [x, y]) equivs

    -- Bijection between values and the vertices @1 .. Set.size allValues@.
    -- Both maps are built once here rather than inside 'toVertex' and
    -- 'fromVertex', so a lookup costs O(log n) instead of rebuilding a whole
    -- map per call. The key lists are strictly ascending (set order and
    -- @[1 ..]@ respectively), as 'Map.fromDistinctAscList' requires.
    vertexOf :: Map a Vertex
    vertexOf = Map.fromDistinctAscList $ zip (Set.toAscList allValues) [1 ..]

    valueOf :: Map Vertex a
    valueOf = Map.fromDistinctAscList $ zip [1 ..] (Set.toAscList allValues)

    -- Values passed to 'toVertex' come from 'equivs' (hence are in
    -- 'allValues'), and vertices passed to 'fromVertex' come from 'graph'
    -- (whose bounds are exactly the vertex range), so neither lookup fails.
    toVertex :: a -> Vertex
    toVertex x =
        Map.findWithDefault (error "toVertex: impossible") x vertexOf

    fromVertex :: Vertex -> a
    fromVertex v =
        Map.findWithDefault (error "fromVertex: impossible") v valueOf

    -- One edge per pair. 'Graph.components' traverses edges in either
    -- direction, so the orientation of each pair does not matter.
    graph :: Graph
    graph = Graph.buildG (1, Set.size allValues) $
              map (bimap toVertex toVertex) equivs

    -- Given a previously established equivalence class, map each of its
    -- values to the class minimum. The class comes from a graph component
    -- and so is never empty, which makes 'minimum' safe here.
    pickCanonical :: [a] -> Map a a
    pickCanonical cls = Map.fromList [ (x, canonical) | x <- cls ]
      where
        canonical = minimum cls

-- | Resolve a value to its canonical representative using a mapping such as
-- the one returned by 'constructEquivClasses'.
--
-- A value with no entry in the mapping is treated as its own representative
-- and is returned unchanged.
canonicalize :: Ord a => Map a a -> a -> a
canonicalize canon x =
    case Map.lookup x canon of
      Just representative -> representative
      Nothing             -> x