{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RankNTypes #-}
module Data.DisjointSet.Int.Monadic (
DisjointIntSetMonadic(DisjointIntSetMonadicFixed, DisjointIntSetMonadicVariable),
DisjointIntSet(DisjointIntSet),
newDisjointIntSetFixed,
newDisjointIntSetVariable,
union,
find,
count,
findAndCount,
numSets,
size,
nextInSet,
setToList,
unsafeFreeze,
freeze,
runDisjointIntSet
)
where
import qualified Data.DisjointSet.Int.Monadic.Impl as Impl
import Data.DisjointSet.Int.Monadic.Impl (MVectorT)
import Prelude (
Int, (+), (-), (*), negate,
($),
return,
Monad, (>>),
Bool(True, False),
(<), (>=), (==), (>), (/=), (<=),
mapM_,
max,
pred, succ,
undefined,
minBound,
div,
Maybe,
Foldable,
(.),
(>>=),
(<$>),
Show
)
import Data.Vector.Unboxed.Mutable (
MVector,
new,
unsafeRead, unsafeWrite,
unsafeGrow
)
import qualified Data.Vector.Unboxed
import qualified Data.Vector.Unboxed.Mutable
import Data.Vector.Unboxed (
Vector
)
import Control.Monad.Primitive (
PrimState, PrimMonad
)
import Control.Monad.ST (
ST, runST
)
import Control.Monad.Ref (
MonadRef, Ref, newRef,
readRef, writeRef, modifyRef'
)
import Control.Monad (
when
)
-- | Tag distinguishing growable sets from fixed-capacity ones (presumably;
-- inferred from the constructor names).
--
-- NOTE(review): 'Size' is not exported and nothing in the visible module
-- refers to it; the fixed/variable distinction is actually carried by the
-- two 'DisjointIntSetMonadic' constructors. Candidate for removal.
data Size
  = Variable
  | Fixed
-- | A mutable disjoint-set (union-find) structure over 'Int' elements whose
-- state lives in the monad @m@.
data DisjointIntSetMonadic m where
  -- | Fixed-size variant, created by 'newDisjointIntSetFixed'. The backing
  -- vectors are stored directly and the element count is a plain 'Int'.
  DisjointIntSetMonadicFixed
    :: Monad m
    => MVectorT m -- vector built by Impl.new_count (exposed to Impl as ?v)
    -> MVectorT m -- vector built by Impl.new_set (exposed to Impl as ?set_v)
    -> Ref m Int  -- set count, read by 'numSets'
    -> Int        -- element count, returned by 'size'
    -> DisjointIntSetMonadic m
  -- | Variable-size variant, created by 'newDisjointIntSetVariable'. Every
  -- component sits behind a reference so it can be replaced; 'union' calls
  -- Impl.resize on these references before merging.
  DisjointIntSetMonadicVariable
    :: Monad m
    => Ref m (MVectorT m) -- reference to the Impl.new_count vector
    -> Ref m (MVectorT m) -- reference to the Impl.new_set vector
    -> Ref m Int          -- set count, read by 'numSets'
    -> Ref m Int          -- element count, read by 'size'
    -> DisjointIntSetMonadic m
-- | Immutable snapshot of a 'DisjointIntSetMonadic', produced by 'freeze',
-- 'unsafeFreeze' or 'runDisjointIntSet'.
data DisjointIntSet
  = DisjointIntSet
      (Vector Int) -- frozen counterpart of the Impl.new_count vector
      (Vector Int) -- frozen counterpart of the Impl.new_set vector
      Int          -- number of sets at freeze time
      Int          -- number of elements at freeze time
  deriving Show
-- | Constraint synonym for the monads this module operates in: they must
-- provide primitive mutable state ('PrimMonad', for the unboxed vectors) and
-- mutable references ('MonadRef', for the counters and vector references).
type MonadT m = (PrimMonad m, MonadRef m)
-- | Run an action from "Data.DisjointSet.Int.Monadic.Impl" against a set,
-- handing it the set's state through implicit parameters:
--
-- * @?v@ and @?set_v@: the two backing vectors (dereferenced first for the
--   variable-size constructor);
-- * @?numElems@: the element count at the moment of the call;
-- * @?numSets@: the set count read from its reference at the moment of the
--   call;
-- * @?numSets_ref@: the set-count reference itself (presumably so the action
--   can update the count).
runMonadicIntSetFunc
  :: (MonadT m)
  => DisjointIntSetMonadic m
  -> ( ( ?v :: MVectorT m
       , ?set_v :: MVectorT m
       , ?numElems :: Int
       , ?numSets :: Int
       , ?numSets_ref :: Ref m Int
       )
       => m b
     )
  -> m b
runMonadicIntSetFunc (DisjointIntSetMonadicFixed vec setVec setsRef elemCount) action = do
  sets <- readRef setsRef
  let ?v = vec
      ?set_v = setVec
      ?numElems = elemCount
      ?numSets = sets
      ?numSets_ref = setsRef
  action
runMonadicIntSetFunc (DisjointIntSetMonadicVariable vecRef setVecRef setsRef elemsRef) action = do
  -- Snapshot the current vectors and counters; the references themselves
  -- may be replaced by a later resize.
  vec <- readRef vecRef
  setVec <- readRef setVecRef
  elemCount <- readRef elemsRef
  sets <- readRef setsRef
  let ?v = vec
      ?set_v = setVec
      ?numElems = elemCount
      ?numSets = sets
      ?numSets_ref = setsRef
  action
-- | Merge the sets containing @i1@ and @i2@, returning the 'Bool' produced by
-- Impl.union (presumably 'True' when the two were previously disjoint).
--
-- For a variable-size set, Impl.resize is first called with the larger of
-- the two indices so both become addressable.
-- NOTE(review): assumes Impl.resize takes the largest index that must be
-- valid (not a new length); confirm against Impl.
union :: (MonadT m) => DisjointIntSetMonadic m -> Int -> Int -> m Bool
union set i1 i2 = case set of
  DisjointIntSetMonadicFixed {} -> mergePair
  DisjointIntSetMonadicVariable vecRef setVecRef setsRef elemsRef -> do
    let ?v_ref = vecRef
        ?set_v_ref = setVecRef
        ?numSets_ref = setsRef
        ?numElems_ref = elemsRef
    Impl.resize (max i1 i2)
    mergePair
  where
    -- The actual merge, run with the (possibly just resized) state.
    mergePair = runMonadicIntSetFunc set (Impl.union (i1, i2))
-- | Result of Impl.find for element @idx@, presumably the representative of
-- the set that contains it.
find :: (MonadT m) => DisjointIntSetMonadic m -> Int -> m Int
find set idx =
  runMonadicIntSetFunc set (Impl.find idx)
-- | Result of Impl.count for element @idx@, presumably the number of members
-- in the set containing it ('setToList' uses it as the member count).
count :: (MonadT m) => DisjointIntSetMonadic m -> Int -> m Int
count set idx =
  runMonadicIntSetFunc set (Impl.count idx)
-- | Result of Impl.findAndCount for element @idx@; presumably the pair of
-- what 'find' and 'count' return, computed together.
findAndCount :: (MonadT m) => DisjointIntSetMonadic m -> Int -> m (Int, Int)
findAndCount set idx =
  runMonadicIntSetFunc set (Impl.findAndCount idx)
-- | Current value of the set-count reference held by either constructor.
numSets :: (MonadT m) => DisjointIntSetMonadic m -> m Int
numSets set = case set of
  DisjointIntSetMonadicFixed _ _ setsRef _ -> readRef setsRef
  DisjointIntSetMonadicVariable _ _ setsRef _ -> readRef setsRef
-- | Number of elements: the stored constant for a fixed-size set, the
-- current value of the element-count reference for a variable-size one.
size :: (MonadT m) => DisjointIntSetMonadic m -> m Int
size set = case set of
  DisjointIntSetMonadicVariable _ _ _ elemsRef -> readRef elemsRef
  DisjointIntSetMonadicFixed _ _ _ elemCount -> return elemCount
-- | Create a fixed-size set with @capacity@ elements. Impl is asked for
-- vectors of exactly that size, and the set count starts at @capacity@
-- (every element begins as its own set).
newDisjointIntSetFixed :: (PrimMonad m, MonadRef m) => Int -> m (DisjointIntSetMonadic m)
newDisjointIntSetFixed capacity = do
  let ?size = capacity
      ?array_size = capacity
  counts <- Impl.new_count
  members <- Impl.new_set
  setsRef <- newRef capacity
  return (DisjointIntSetMonadicFixed counts members setsRef capacity)
-- | Create an empty variable-size set: zero elements and zero sets, with
-- Impl asked for backing storage of 1024 entries. 'union' grows it on
-- demand via Impl.resize.
newDisjointIntSetVariable :: (PrimMonad m, MonadRef m) => m (DisjointIntSetMonadic m)
newDisjointIntSetVariable = do
  let ?size = 0
      ?array_size = 1024
  counts <- Impl.new_count
  members <- Impl.new_set
  countsRef <- newRef counts
  membersRef <- newRef members
  setsRef <- newRef 0
  elemsRef <- newRef 0
  return (DisjointIntSetMonadicVariable countsRef membersRef setsRef elemsRef)
-- | Result of Impl.nextInSet for element @idx@: the member that 'setToList'
-- visits after @idx@ when walking a set.
nextInSet :: (MonadT m) => DisjointIntSetMonadic m -> Int -> m Int
nextInSet set idx =
  runMonadicIntSetFunc set (Impl.nextInSet idx)
-- | Members of the set containing @i@, starting with @i@ itself.
--
-- Asks 'count' how many members there are, then follows 'nextInSet' exactly
-- that many times, collecting each visited element.
setToList :: forall m. (MonadT m) => DisjointIntSetMonadic m -> Int -> m [Int]
setToList x i = count x i >>= \total -> walk total i
  where
    -- Collect @remaining@ members, starting at @cur@.
    walk :: Int -> Int -> m [Int]
    walk 0 _ = return []
    walk remaining cur = (cur :) <$> (nextInSet x cur >>= walk (remaining - 1))
-- | Snapshot a mutable set as an immutable 'DisjointIntSet' without copying
-- the backing vectors.
--
-- Both vectors are trimmed to the first @numElems@ entries with an O(1)
-- slice, so the result has the same shape as the one produced by 'freeze'.
-- Previously the variable-size branch exposed the whole backing storage
-- (spare capacity included), which disagreed with the element count stored
-- next to it.
--
-- As with 'Data.Vector.Unboxed.unsafeFreeze', the resulting vectors share
-- storage with the mutable set, so the set must not be modified afterwards.
unsafeFreeze :: (MonadT m) => DisjointIntSetMonadic m -> m DisjointIntSet
unsafeFreeze (DisjointIntSetMonadicFixed v set_v numSets_ref numElems) = do
  v_frozen <- Data.Vector.Unboxed.unsafeFreeze (Data.Vector.Unboxed.Mutable.slice 0 numElems v)
  set_v_frozen <- Data.Vector.Unboxed.unsafeFreeze (Data.Vector.Unboxed.Mutable.slice 0 numElems set_v)
  sets <- readRef numSets_ref
  return (DisjointIntSet v_frozen set_v_frozen sets numElems)
unsafeFreeze (DisjointIntSetMonadicVariable v_ref set_v_ref numSets_ref numElems_ref) = do
  sets <- readRef numSets_ref
  elemCount <- readRef numElems_ref
  v <- readRef v_ref
  set_v <- readRef set_v_ref
  -- Only the first elemCount slots are live; the rest is spare capacity.
  v_frozen <- Data.Vector.Unboxed.unsafeFreeze (Data.Vector.Unboxed.Mutable.slice 0 elemCount v)
  set_v_frozen <- Data.Vector.Unboxed.unsafeFreeze (Data.Vector.Unboxed.Mutable.slice 0 elemCount set_v)
  return (DisjointIntSet v_frozen set_v_frozen sets elemCount)
-- | Snapshot a mutable set as an immutable 'DisjointIntSet' by copying the
-- first @numElems@ entries of each backing vector. The mutable set stays
-- usable afterwards.
freeze :: (MonadT m) => DisjointIntSetMonadic m -> m DisjointIntSet
freeze set = case set of
  DisjointIntSetMonadicFixed vec setVec setsRef elemCount -> do
    vecCopy <- copyPrefix elemCount vec
    setVecCopy <- copyPrefix elemCount setVec
    sets <- readRef setsRef
    return (DisjointIntSet vecCopy setVecCopy sets elemCount)
  DisjointIntSetMonadicVariable vecRef setVecRef setsRef elemsRef -> do
    sets <- readRef setsRef
    elemCount <- readRef elemsRef
    vec <- readRef vecRef
    setVec <- readRef setVecRef
    vecCopy <- copyPrefix elemCount vec
    setVecCopy <- copyPrefix elemCount setVec
    return (DisjointIntSet vecCopy setVecCopy sets elemCount)
  where
    -- Immutable copy of the first @len@ entries of a mutable vector.
    copyPrefix len mv =
      Data.Vector.Unboxed.freeze (Data.Vector.Unboxed.Mutable.slice 0 len mv)
-- | Build a set inside 'ST' and return its immutable snapshot. The snapshot
-- is taken with 'unsafeFreeze', which is safe here because the mutable set
-- cannot escape 'runST'.
runDisjointIntSet :: (forall s. ST s (DisjointIntSetMonadic (ST s))) -> DisjointIntSet
runDisjointIntSet build = runST snapshot
  where
    snapshot :: forall s. ST s DisjointIntSet
    snapshot = do
      set <- build
      unsafeFreeze set
-- | Snapshot the set with 'unsafeFreeze' and apply a pure function to the
-- snapshot.
--
-- NOTE(review): not listed in the module's export list and not used in the
-- visible code; either export it or remove it.
runPure :: (MonadT m) => DisjointIntSetMonadic m -> (DisjointIntSet -> a) -> m a
runPure set k = do
  frozen <- unsafeFreeze set
  return (k frozen)
-- | Snapshot the set with 'unsafeFreeze' and pass the snapshot to a monadic
-- continuation.
--
-- NOTE(review): not listed in the module's export list and not used in the
-- visible code; either export it or remove it.
runMonad :: (MonadT m) => DisjointIntSetMonadic m -> (DisjointIntSet -> m a) -> m a
runMonad set k = do
  frozen <- unsafeFreeze set
  k frozen