{-# LANGUAGE ScopedTypeVariables #-}

module Data.Record.Anon.Internal.Plugin.TC.EquivClasses (
    constructEquivClasses
  , canonicalize
  ) where

import Data.Bifunctor
import Data.Foldable (toList)
import Data.Graph (Graph, Vertex)
import Data.Map (Map)
import Data.Set (Set)

import qualified Data.Graph as Graph
import qualified Data.Map   as Map
import qualified Data.Set   as Set

-- | Given a list of pairs of equivalent values, map every value to a
-- canonical value
--
-- The pairs are treated as edges of an undirected graph. Each connected
-- component is one equivalence class, and every member of a class is mapped
-- to the smallest ('minimum') member of that class. Values that do not occur
-- in any pair are absent from the result; 'canonicalize' maps those to
-- themselves.
--
-- Example with two classes:
--
-- >>> constructEquivClasses [(1, 2), (4, 5), (2, 3)]
-- fromList [(1,1),(2,1),(3,1),(4,4),(5,4)]
--
-- Adding one element that connects both classes:
--
-- >>> constructEquivClasses [(1, 2), (4, 5), (2, 3), (3, 4)]
-- fromList [(1,1),(2,1),(3,1),(4,1),(5,1)]
constructEquivClasses :: forall a. Ord a => [(a, a)] -> Map a a
constructEquivClasses equivs =
     Map.unions $ map (pickCanonical . map fromVertex . toList) $
       Graph.components graph
  where
    -- Every value mentioned in any pair, without duplicates.
    allValues :: Set a
    allValues = Set.fromList $ concatMap (\(x, y) -> [x, y]) equivs

    -- Lookup tables between values and graph vertices: vertex @i@ (1-based)
    -- is the @i@-th smallest element of 'allValues'. They are bound once here
    -- instead of being rebuilt inside 'toVertex' / 'fromVertex' on every call.
    -- Both key lists are strictly ascending, so 'Map.fromDistinctAscList'
    -- applies.
    valueToVertex :: Map a Vertex
    valueToVertex = Map.fromDistinctAscList $ zip (Set.toAscList allValues) [1..]

    vertexToValue :: Map Vertex a
    vertexToValue = Map.fromDistinctAscList $ zip [1..] (Set.toAscList allValues)

    toVertex   :: a -> Vertex
    fromVertex :: Vertex -> a

    -- The 'error' defaults are unreachable: every value in 'equivs' is in
    -- 'allValues', and 'graph' only has vertices @1 .. Set.size allValues@.
    toVertex   a = Map.findWithDefault (error "toVertex: impossible")   a valueToVertex
    fromVertex v = Map.findWithDefault (error "fromVertex: impossible") v vertexToValue

    -- One edge per equivalent pair; 'Graph.components' ignores edge direction.
    graph :: Graph
    graph = Graph.buildG (1, Set.size allValues) $
              map (bimap toVertex toVertex) equivs

    -- Given a previously established equivalence class, map every member to
    -- the smallest member of the class. The class comes from a component tree,
    -- which always contains at least its root, so 'minimum' is safe here.
    pickCanonical :: [a] -> Map a a
    pickCanonical cls = Map.fromList $ zip cls (repeat (minimum cls))

-- | Look up the canonical representative of a value
--
-- Values that are not keys of the given map are their own representative.
canonicalize :: Ord a => Map a a -> a -> a
canonicalize canon x = maybe x id (Map.lookup x canon)