module Database.Schema.Migrations.CycleDetection
    ( hasCycle
    )
where

import Data.Graph.Inductive.Graph
    ( Graph(..)
    , Node
    , nodes
    , edges
    )

import Control.Monad.State ( State, evalState, gets, get, put )
import Control.Monad ( forM )

import Data.Maybe ( fromJust )
import Data.List ( findIndex )

-- | Traversal colour attached to each node during cycle detection.
--
-- Every node starts out 'White'; 'visit' marks a node 'Gray' when it begins
-- processing it and 'Black' once it has finished without finding a cycle.
data Mark = White | Gray | Black
-- | Association list pairing every node of the graph with its current 'Mark'.
type CycleDetectionState = [(Node, Mark)]

-- | Determine whether a graph contains a cycle.
--
-- Cycle detection algorithm taken from
-- http://www.cs.berkeley.edu/~kamil/teaching/sp03/041403.pdf
--
-- Every node of the graph is initially marked 'White'; the traversal
-- performed by @hasCycle'@ is then evaluated against that initial marking
-- and its boolean result is returned.
hasCycle :: Graph g => g a b -> Bool
hasCycle g = evalState (hasCycle' g) initialMarks
  where
    -- No node has been visited yet, so all of them start out White.
    initialMarks = [(n, White) | n <- nodes g]

-- | Look up the current 'Mark' of node @n@ in the traversal state.
--
-- Calls 'error' if @n@ has no entry in the state. 'hasCycle' seeds the state
-- with every node of the graph, so a missing entry indicates a programming
-- error rather than a condition callers are expected to handle.
getMark :: Int -> State CycleDetectionState Mark
getMark n = gets (maybe missing id . lookup n)
  where
    missing = error ("getMark: node " ++ show n ++ " has no mark in the cycle detection state")

-- | Replace the element at position @index@ (zero-based) of @elems@ with
-- @val@, leaving every other element and the length of the list unchanged.
--
-- Calls 'error' when @index@ does not refer to an existing element, i.e. when
-- it is negative or greater than or equal to the length of @elems@.
replace :: [a] -> Int -> a -> [a]
replace elems index val
    | index < 0 = error "replacement index negative"
    | otherwise =
        -- A single split yields the prefix before the target and the suffix
        -- starting at it; an empty suffix means the index is past the end.
        case splitAt index elems of
          (before, _ : after) -> before ++ val : after
          (_, []) -> error "replacement index too large"

-- | Set the 'Mark' of node @n@ to @mark@ in the traversal state.
--
-- The first state entry whose node equals @n@ is overwritten in place, so the
-- order and number of entries are preserved. Calls 'error' if @n@ has no
-- entry in the state.
setMark :: Int -> Mark -> State CycleDetectionState ()
setMark n mark = do
  st <- get
  case findIndex ((== n) . fst) st of
    Just index -> put (replace st index (n, mark))
    Nothing ->
      error ("setMark: node " ++ show n ++ " has no mark in the cycle detection state")

-- | Stateful driver of the cycle search.
--
-- Walks over every node of the graph in the order given by 'nodes'. A node
-- that is still 'White' is handed to 'visit'; a node carrying any other mark
-- contributes 'False'. The overall result is 'True' when at least one of
-- those per-node results is 'True'.
hasCycle' :: Graph g => g a b -> State CycleDetectionState Bool
hasCycle' g = do
  results <- forM (nodes g) $ \n -> do
    m <- getMark n
    case m of
      White -> visit g n
      -- Already reached through an earlier start node; nothing new to do.
      _ -> return False
  return (or results)

-- | Explore the graph from node @n@ and report whether a cycle was found.
--
-- @n@ is marked 'Gray' on entry. Each target @v@ of an edge @(n, v)@ is then
-- examined: a 'Gray' target yields 'True', a 'White' target is visited
-- recursively, and a 'Black' target yields 'False'. If none of the targets
-- yields 'True', @n@ is marked 'Black' and the result is 'False'; otherwise
-- the result is 'True' and @n@ keeps its 'Gray' mark.
visit :: Graph g => g a b -> Node -> State CycleDetectionState Bool
visit g n = do
  setMark n Gray
  results <- forM successors $ \node -> do
    m <- getMark node
    case m of
      -- Reaching a node that is still being processed closes a cycle.
      Gray -> return True
      White -> visit g node
      Black -> return False
  if or results
    then return True
    else do
      setMark n Black
      return False
  where
    -- Targets of all edges leaving n, in the order 'edges' reports them.
    successors = [v | (u, v) <- edges g, u == n]