-- | Euler tour trees.
--
-- A rooted tree is stored as the sequence of vertices visited by an Euler
-- tour of it, and that sequence is kept in a finger tree whose measure
-- tracks the vertices and edges it contains. Operations that can fail
-- (unknown vertex or edge, overlapping vertex sets, empty tour) report the
-- failure through the caller's 'MonadPlus'.
module Data.EulerTourTree
  ( EulerTourTree
    -- * Construction
  , empty
  , singleton
  , fromTree
    -- * Conversion
  , toTree
    -- * Queries
  , root
  , member
  , size
    -- * Modification
  , cutEdge
  , splice
  , reroot
  ) where
import           Control.Applicative             hiding (empty)
import           Control.Applicative.Combinators
import           Control.Monad
import           Control.Monad.State.Lazy
import           Control.Monad.Trans.Maybe
import           Data.FingerTree                 hiding (empty, null, singleton)
import qualified Data.FingerTree                 as FingerTree
import           Data.Foldable
import           Data.List.Unique
import           Data.Maybe
import           Data.Monoid
import           Data.Set                        (Set)
import qualified Data.Set                        as Set
import           Data.Tree
import           Debug.Trace
-- | Run 'FingerTree.search' and lift its outcome into any 'MonadPlus'.
--
-- Yields the @(left, element, right)@ triple when the search reports a
-- 'Position'; every other search result becomes 'mzero'.
searchM :: MonadPlus m => Measured v a => (v -> v -> Bool) -> FingerTree v a -> m (FingerTree v a, a, FingerTree v a)
searchM relation = located . FingerTree.search relation
  where
    located (Position l x r) = return (l, x, r)
    located _                = mzero
-- | Drop the rightmost element of a finger tree; a tree with no rightmost
-- element is returned unchanged.
initSafe :: Measured v a => FingerTree v a -> FingerTree v a
initSafe tree
  | rest :> _ <- viewr tree = rest
  | otherwise               = tree
-- | Drop the leftmost element of a finger tree; a tree with no leftmost
-- element is returned unchanged.
tailSafe :: Measured v a => FingerTree v a -> FingerTree v a
tailSafe tree
  | _ :< rest <- viewl tree = rest
  | otherwise               = tree
-- | Element of the tour sequence: one occurrence of a vertex in the tour.
newtype EulerTourNode node = EulerTourNode node
  deriving (Eq, Ord, Show)
-- | Summary (finger-tree measure) of a contiguous stretch of the tour.
data EulerTourMonoid node = EulerTourMonoid
  (First node)        -- first vertex of the stretch, if any
  (Set (node, node))  -- edges between adjacent tour elements, stored as (smaller, larger)
  (Last node)         -- last vertex of the stretch, if any
  (Set node)          -- every vertex occurring in the stretch
  (Sum Int)           -- number of tour elements in the stretch
  deriving Show
-- | Concatenating two stretches merges their summaries component-wise and,
-- when both stretches are non-empty, records the edge that joins the last
-- vertex of the left stretch to the first vertex of the right one.
--
-- NOTE(review): base >= 4.11 also requires a 'Semigroup' instance for any
-- 'Monoid'; confirm which GHC/base this module targets.
instance Ord node => Monoid (EulerTourMonoid node) where
  mempty = EulerTourMonoid mempty mempty mempty mempty mempty
  mappend (EulerTourMonoid firstL edgesL lastL nodesL countL) (EulerTourMonoid firstR edgesR lastR nodesR countR) =
    EulerTourMonoid (firstL <> firstR) (edgesL <> junction <> edgesR) (lastL <> lastR) (nodesL <> nodesR) (countL <> countR)
    where
      -- Edge created at the seam; stored with the smaller endpoint first.
      junction = case (getLast lastL, getFirst firstR) of
        (Just l, Just f) -> Set.singleton (min l f, max l f)
        _                -> mempty
-- | A single tour element starts and ends at its own vertex, contains no
-- edge, and counts as one element.
instance Ord node => Measured (EulerTourMonoid node) (EulerTourNode node) where
  measure (EulerTourNode vertex) =
    EulerTourMonoid (First here) Set.empty (Last here) (Set.singleton vertex) (Sum 1)
    where here = Just vertex
-- | First vertex recorded in a summary, or 'mzero' when the summary
-- records none.
firstVertex :: MonadPlus m => Ord node => EulerTourMonoid node -> m node
firstVertex (EulerTourMonoid start _ _ _ _) = case getFirst start of
  Just vertex -> return vertex
  Nothing     -> mzero
-- | The set of vertices recorded in a summary.
allNodes :: Ord node => EulerTourMonoid node -> Set node
allNodes summary = case summary of
  EulerTourMonoid _ _ _ vertices _ -> vertices
-- | Does the summary record the given vertex?
vertexMember :: Ord node => node -> EulerTourMonoid node -> Bool
vertexMember node = Set.member node . allNodes
-- | Does the summary record the given edge? The endpoints may be supplied
-- in either order; they are normalised to (smaller, larger) before lookup.
edgeMember :: Ord node => (node, node) -> EulerTourMonoid node -> Bool
edgeMember (u, v) (EulerTourMonoid _ edges _ _ _) = key `Set.member` edges
  where key = (min u v, max u v)
-- | Number of distinct vertices recorded in a summary. This is the size of
-- the vertex set; the element counter of the summary is not consulted.
tourSize :: EulerTourMonoid node -> Int
tourSize summary = case summary of
  EulerTourMonoid _ _ _ vertices _ -> Set.size vertices
-- | A tree represented by its Euler tour, stored as a finger tree of
-- 'EulerTourNode's measured by 'EulerTourMonoid'.
--
-- The constructor carries the @Ord node@ dictionary, so pattern matching on
-- it brings that constraint into scope.
data EulerTourTree node where
  EulerTourTree :: Ord node => FingerTree (EulerTourMonoid node) (EulerTourNode node) -> EulerTourTree node
-- | A whole tour tree is measured by the edge set, vertex set and element
-- count of its underlying tour.
instance Ord node => Measured (Set (node, node), Set node, Sum Int) (EulerTourTree node) where
  measure (EulerTourTree tour) =
    let EulerTourMonoid _ edgeSet _ vertexSet elementCount = measure tour
    in (edgeSet, vertexSet, elementCount)
-- | Folds over the vertices of the tree obtained from 'toTree'; a tour that
-- cannot be converted folds to 'mempty'.
instance Foldable EulerTourTree where
  foldMap f etTree@(EulerTourTree _) = case toTree etTree of
    Just rose -> foldMap f rose
    Nothing   -> mempty
-- Standalone-derived instances; they delegate to the instances of the
-- wrapped finger tree.
-- NOTE(review): these presumably compare/show the stored tour itself, so two
-- tours of the same tree that differ in root or child order would not be
-- equal — confirm against the FingerTree instances.
deriving instance Eq node => Eq (EulerTourTree node)
deriving instance Ord node => Ord (EulerTourTree node)
deriving instance Show node => Show (EulerTourTree node)
-- | The tour tree with no vertices.
empty :: Ord node => EulerTourTree node
empty = EulerTourTree mempty
-- | The tour tree consisting of a single vertex.
singleton :: Ord node => node -> EulerTourTree node
singleton = EulerTourTree . FingerTree.singleton . EulerTourNode
-- | Build a tour tree from a rose tree.
--
-- Fails (via 'guard') when the labels of the input tree are not pairwise
-- distinct. The tour of a subtree is its vertex, followed by the tour of
-- each child with the vertex appended again after it.
fromTree :: MonadPlus m => Ord node => Tree node -> m (EulerTourTree node)
fromTree tree = EulerTourTree (tour tree) <$ guard (allUnique (toList tree))
  where
    tour (Node vertex children) = here <| foldMap visit children
      where
        here = EulerTourNode vertex
        -- Walk down into the child, then come back up to this vertex.
        visit child = tour child |> here
-- | Constraint synonym for the tour parser used by 'toTree': a monad whose
-- state is the not-yet-consumed part of the tour and which supports failure
-- and choice through 'MonadPlus'.
type Parser node m = (MonadState (FingerTree (EulerTourMonoid node) (EulerTourNode node)) m, MonadPlus m)
-- | Rebuild the rose tree described by a tour.
--
-- The tour is consumed by a small parser running in 'StateT' over the
-- caller's 'MonadPlus'. An empty tour yields 'mzero'. Elements left over
-- once the root's subtree has been parsed are discarded, because only the
-- result of 'evalStateT' is kept.
toTree :: MonadPlus m => Ord node => EulerTourTree node -> m (Tree node)
toTree (EulerTourTree fingerTree) = evalStateT parser fingerTree where
  -- A subtree is its vertex followed by zero or more child subtrees, each
  -- terminated by a further occurrence of that same vertex.
  parser = do
    EulerTourNode node <- anyToken
    forest <- parser `endBy` try (token $ EulerTourNode node)
    return $ Node node forest
  -- Pop the leftmost element of the remaining tour; fail when none is left.
  anyToken :: Ord node => Parser node m => m (EulerTourNode node)
  anyToken = do
    tree <- get
    case viewl tree of
      node :< tree' -> put tree' >> return node
      _             -> mzero
  -- Pop the next element and require it to equal @x@.
  token x = do
    t <- anyToken
    guard (t == x)
    return t
  -- Run @f@; as the alternative, put back the state saved beforehand and fail.
  try f = do
    a <- get
    f <|> (put a >> mzero)
-- | The root of the tree, i.e. the first vertex of its tour; 'mzero' for an
-- empty tour.
root :: MonadPlus m => Ord node => EulerTourTree node -> m node
root (EulerTourTree tour) = (firstVertex . measure) tour
-- | Is the given vertex part of the tree? Answered from the cached measure
-- of the tour.
member :: Ord node => node -> EulerTourTree node -> Bool
member node (EulerTourTree tour) = node `vertexMember` measure tour
-- | Number of distinct vertices in the tree, taken from the cached measure
-- of the tour.
size :: Ord node => EulerTourTree node -> Int
size (EulerTourTree tour) = (tourSize . measure) tour
-- | Split a tree in two by removing one edge.
--
-- Returns @(inside, outside)@: @inside@ is the tour of the subtree hanging
-- below the removed edge, @outside@ is the tour of what remains around it.
-- Fails with 'mzero' when the edge does not occur in the tour. The edge's
-- endpoints may be given in either order.
cutEdge :: MonadPlus m
        => Ord node
        => EulerTourTree node  -- ^ tree to split
        -> (node, node)        -- ^ edge to remove
        -> m (EulerTourTree node, EulerTourTree node)  -- ^ (subtree below the edge, remainder)
cutEdge (EulerTourTree tree) e = do
  -- @node@ is the element whose arrival first puts the edge into the prefix,
  -- i.e. the first occurrence of the endpoint on the far side of the edge.
  (left, node, tree') <- searchM firstCrossing tree
  -- Searching from @node@ onwards, @node'@ is the element just before the
  -- edge is crossed for the last time; @middle@ therefore already begins
  -- with @node@ (or is empty when the subtree is a single vertex).
  (middle, node', right) <- searchM lastCrossing $ node <| tree'
  -- The subtree's tour runs from @node@ through @node'@ inclusive. Appending
  -- @node'@ closes it; prepending @node@ again would duplicate the opening
  -- vertex and drop the closing one.
  let inside = middle |> node'
      -- @left@ ends on the near endpoint and @right@ starts on it again, so
      -- one of the two copies is dropped when gluing them back together.
      outside = left >< tailSafe right
  return (EulerTourTree inside, EulerTourTree outside)
  where firstCrossing before _ = edgeMember e before
        lastCrossing _ after = not (edgeMember e after)
-- | Attach one tree below a vertex of another.
--
-- Fails with 'mzero' when the two trees share a vertex or when the
-- attachment vertex does not occur in the host. The invitee's tour is
-- inserted at the last occurrence of the attachment vertex; an empty
-- invitee leaves the host's tour unchanged.
splice :: MonadPlus m
       => Ord node
       => EulerTourTree node           -- ^ tree to attach (invitee)
       -> node                         -- ^ host vertex to attach under
       -> EulerTourTree node           -- ^ host tree
       -> m (EulerTourTree node)
splice (EulerTourTree inviteeTree) node (EulerTourTree hostTree) = do
  guard (Set.null sharedVertices)
  (prefix, _, suffix) <- searchM atLastVisit hostTree
  return . EulerTourTree $ prefix >< detour >< (anchor <| suffix)
  where
    anchor = EulerTourNode node
    sharedVertices = allNodes (measure inviteeTree) `Set.intersection` allNodes (measure hostTree)
    -- Step from the anchor into the invitee; nothing to insert if it is empty.
    detour
      | FingerTree.null inviteeTree = inviteeTree
      | otherwise                   = anchor <| inviteeTree
    -- Becomes true exactly at the last occurrence of the anchor vertex.
    atLastVisit seen pending = vertexMember node seen && not (vertexMember node pending)
-- | Make the given vertex the root of the tree by rotating its tour.
--
-- Fails with 'mzero' when the vertex does not occur in the tour. When the
-- vertex is already the root the tree is returned unchanged.
reroot :: MonadPlus m
       => Ord node
       => node                 -- ^ vertex that becomes the new root
       -> EulerTourTree node   -- ^ tree to re-root
       -> m (EulerTourTree node)
reroot node (EulerTourTree tree) = do
  -- @left@ is everything before the first occurrence of @node@.
  (left, _, tree') <- searchM firstVisit tree
  if FingerTree.null left
    -- @node@ opens the tour, so it is already the root. Rotating would
    -- append a spurious extra copy of it (nothing precedes it and nothing
    -- follows its last occurrence), so keep the tour as it is.
    then return (EulerTourTree tree)
    else do
      -- @middle@ runs from the first occurrence of @node@ up to, but not
      -- including, its last occurrence; @right@ is what follows the last one.
      (middle, _, right) <- searchM lastVisit (etNode <| tree')
      -- @right@ ends on the old root and @left@ starts on it, so drop one
      -- copy before joining them, then close the tour on the new root.
      return $ EulerTourTree $ middle >< (etNode <| initSafe right) >< (left |> etNode)
  where firstVisit before _ = vertexMember node before
        lastVisit _ after = not (vertexMember node after)
        etNode = EulerTourNode node