-- This file is part of the 'union-find-array' library. It is licensed
-- under an MIT license. See the accompanying 'LICENSE' file for details.
--
-- Authors: Bertram Felgenhauer

{-# LANGUAGE GeneralizedNewtypeDeriving, RankNTypes, FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-- |
-- Monadic interface for creating a disjoint set data structure.
--
module Control.Monad.Union (
  UnionM,
  Union (..),
  MonadUnion (..),
  Node,
  run,
  run',
) where

import Control.Monad.Union.Class
import qualified Data.Union.ST as US
import Data.Union.Type (Node (..), Union (..))

import Prelude hiding (lookup)
import Control.Monad.State
import Control.Monad.ST
import Control.Monad.Fix
import Control.Applicative
import Control.Arrow (first)

data UState s l = UState {
    next   :: !Int,
    forest :: US.UnionST s l
}

-- | Union find monad.
newtype UnionM l a = U {
    runU :: (forall s . StateT (UState s l) (ST s) a)
}

instance Monad (UnionM l) where
    return x =  U (return x)
    f >>= b = U (runU f >>= runU . b)

instance Functor (UnionM l) where
    fmap = liftM

instance Applicative (UnionM l) where
    pure = return
    (<*>) = ap

instance MonadFix (UnionM l) where
    mfix a = U (mfix (runU . a))

-- | Run a union find computation.
run :: UnionM l a -> a
run a = runST $ do
    u <- US.new 1 undefined
    evalStateT (runU a) UState{ next = 0, forest = u }

-- | Run a union find computation; also return the final disjoint set forest
-- for querying.
run' :: UnionM l a -> (Union l, a)
run' a = runST $ do
    u <- US.new 1 undefined
    (a, s) <- runStateT (runU a) UState{ next = 0, forest = u }
    f <- US.unsafeFreeze (forest s)
    return (f, a)

instance MonadUnion l (UnionM l) where
    -- Add a new node, with a given label.
    new l = U $ do
        u <- get
        let size = US.size (forest u)
            n    = next u
        if (size <= next u) then do
            forest' <- lift $ US.grow (forest u) (2*size)
            lift $ US.annotate forest' n l
            put u{ forest = forest', next = n + 1 }
         else do
            lift $ US.annotate (forest u) n l
            put u{ next = n + 1 }
        return (Node n)

    -- Find the node representing a given node, and its label.
    lookup (Node n) = U $ do
        dsf <- gets forest
        first Node <$> lift (US.lookup dsf n)

    -- Merge two sets. The first argument is a function that takes the labels
    -- of the corresponding sets' representatives and computes a new label for
    -- the joined set. Returns Nothing if the given nodes are in the same set
    -- already.
    merge f (Node n) (Node m) = U $ do
        dsf <- gets forest
        lift $ US.merge dsf f n m

    -- Re-label a node.
    annotate (Node n) l = U $ do
        dsf <- gets forest
        lift $ US.annotate dsf n l

    -- Flatten the disjoint set forest for faster lookups.
    flatten = U $ do
        dsf <- gets forest
        lift $ US.flatten dsf