module Data.Tree.Knuth.Forest where
import Prelude hiding (foldr, elem)
import Data.Semigroup
import Data.Foldable hiding (elem)
import Data.Witherable
import qualified Data.Set.Class as Sets
import Control.Applicative
import Control.Monad
import Data.Data
import GHC.Generics
import Control.DeepSeq
import Test.QuickCheck
-- | A forest in Knuth's \"first child / next sibling\" binary encoding.
--
-- A 'Fork' stores one node value together with two links: the forest of
-- that node's children and the forest of the siblings that follow it.
-- 'Nil' is the empty forest.
--
-- NOTE(review): the record selectors are partial — applying 'kNode',
-- 'kChildren' or 'kSiblings' to 'Nil' is a runtime error.
data KnuthForest a
  = Fork { kNode     :: a
           -- ^ Value stored at this node.
         , kChildren :: KnuthForest a
           -- ^ Children of this node, themselves encoded as a forest.
         , kSiblings :: KnuthForest a
           -- ^ Siblings that follow this node.
         }
  | Nil
    -- ^ The empty forest.
  -- Foldable is deliberately absent from this list: the module defines it
  -- by hand, and the derived Traversable relies on that instance.
  deriving (Show, Eq, Functor, Traversable, Generic, Data, Typeable)
-- | Deep evaluation using the class's default 'rnf'.
--
-- NOTE(review): this empty instance relies on the 'Generic'-based default
-- of @rnf@ (presumably deepseq >= 1.4, where it forces every field); with
-- older deepseq the default only reaches weak head normal form — confirm
-- the deepseq version in use.
instance NFData a => NFData (KnuthForest a)
-- | Random forests for property tests.
--
-- Generation is driven by QuickCheck's size parameter: each 'Fork' hands
-- half of the remaining size to its child forest and half to its sibling
-- forest, and size 0 always yields 'Nil'. Sizing is required: choosing
-- 'Nil' or 'Fork' with equal odds and no size bound produces on average one
-- non-empty sub-forest per node (a critical branching process), whose
-- expected total size is unbounded.
--
-- 'shrink' proposes the empty forest, each immediate sub-forest, and forks
-- whose node value or sub-forests have been shrunk.
instance Arbitrary a => Arbitrary (KnuthForest a) where
  arbitrary = sized genForest
    where
      genForest n
        | n <= 0    = return Nil
        | otherwise =
            oneof [ return Nil
                  , liftA3 Fork arbitrary (genForest half) (genForest half)
                  ]
        where
          half = n `div` 2
  shrink Nil = []
  shrink (Fork x xc xs) =
    Nil : xc : xs : [ Fork x' xc' xs' | (x', xc', xs') <- shrink (x, xc, xs) ]
-- | Structural ordering: 'Nil' sorts before every 'Fork'. Two forks are
-- compared lexicographically by node value, then by their sibling forests,
-- then by their child forests. The result is 'EQ' exactly when the derived
-- 'Eq' says the forests are equal.
instance Ord a => Ord (KnuthForest a) where
  compare Nil       Nil       = EQ
  compare Nil       (Fork {}) = LT
  compare (Fork {}) Nil       = GT
  compare (Fork x xc xs) (Fork y yc ys) =
    mconcat [compare x y, compare xs ys, compare xc yc]
-- | Zip-like application. A function node is applied to the value node at
-- the same position; child links are zipped with child links and sibling
-- links with sibling links. Any position where either side is 'Nil'
-- becomes 'Nil'.
--
-- NOTE(review): because 'pure' builds a single node, the identity law does
-- not hold, e.g. @pure id <*> Fork 1 (pure 2) Nil == Fork 1 Nil Nil@.
-- '<*>' also disagrees with 'ap' obtained from the 'Monad' instance. This is
-- left unchanged because 'return' is tied to 'pure' — confirm the intended
-- semantics.
instance Applicative KnuthForest where
  pure x = Fork x Nil Nil
  lhs <*> rhs = case (lhs, rhs) of
    (Fork f fc fs, Fork x xc xs) -> Fork (f x) (fc <*> xc) (fs <*> xs)
    _                            -> Nil
-- | Choice is forest concatenation. 'empty' is the empty forest, and '<|>'
-- attaches the right forest after the last top-level sibling of the left
-- one (see 'union').
instance Alternative KnuthForest where
  empty     = Nil
  xs <|> ys = union xs ys
-- | Bind replaces every node @x@ with the forest @f x@ and concatenates the
-- results with 'union'. The order is: the node's own result, then the
-- results for its siblings, then the results for its children. The
-- per-node results are joined at top level, so parent/child links of the
-- input forest are not kept.
--
-- An accumulator (@rest@) is threaded through the traversal instead of
-- nesting 'union' calls left-associatively. 'union' rebuilds the sibling
-- spine of its left argument, so left-nested unions re-copy ever-growing
-- prefixes and become quadratic on long sibling chains. The resulting
-- forest is the same.
--
-- 'return' is not defined here; the class default ('pure') is used.
--
-- NOTE(review): @m >>= return@ flattens @m@ (children become siblings), so
-- the right-identity law does not hold — confirm the intended semantics.
instance Monad KnuthForest where
  m >>= f = go m Nil
    where
      -- Invariant: go t rest == (t >>= f) `union` rest
      go Nil            rest = rest
      go (Fork x xc xs) rest = f x `union` go xs (go xc rest)
-- | Same zero and sum as the 'Alternative' instance: the empty forest, and
-- concatenation of forests via 'union'.
instance MonadPlus KnuthForest where
  mzero       = Nil
  mplus xs ys = xs `union` ys
-- | '<>' concatenates forests: the right forest is attached after the last
-- top-level sibling of the left one (see 'union').
instance Semigroup (KnuthForest a) where
  xs <> ys = xs `union` ys
-- | 'Nil' is the identity for forest concatenation.
--
-- 'mappend' is defined canonically as '(<>)' so it cannot drift from the
-- 'Semigroup' instance (this is what GHC's -Wnoncanonical-monoid-instances
-- asks for).
instance Monoid (KnuthForest a) where
  mempty  = Nil
  mappend = (<>)
-- | Folds nodes in pre-order: a node first, then everything in its child
-- forest, then its following siblings.
--
-- This is the order in which the derived 'Traversable' instance visits the
-- fields of 'Fork' (node, children, siblings). The Traversable laws require
-- this agreement: 'foldMap' must coincide with 'foldMapDefault'.
--
-- Applying @f@ to the node before recursing keeps the fold lazy.
-- Short-circuiting consumers such as 'any', 'null' or a lazily consumed
-- 'toList' can therefore stop early instead of first walking the child spine.
instance Foldable KnuthForest where
  foldr _ acc Nil            = acc
  foldr f acc (Fork x xc xs) = f x (foldr f (foldr f acc xs) xc)
-- | Keeps every 'Just' payload, descending through both the child and the
-- sibling link of each kept node.
--
-- NOTE(review): a 'Nothing' node yields 'Nil'. That discards the node, its
-- children, and all of its following siblings; for example, a leading
-- 'Nothing' empties the whole forest. Confirm this truncation is intended
-- rather than splicing out only the node.
instance Witherable KnuthForest where
  catMaybes Nil                   = Nil
  catMaybes (Fork Nothing _ _)    = Nil
  catMaybes (Fork (Just x) xc xs) = Fork x (catMaybes xc) (catMaybes xs)
-- Instances for the "Data.Set.Class" interfaces.
--
-- Every method delegates to the top-level function of the same name in this
-- module. "Data.Set.Class" is imported qualified, so an unqualified name on
-- a right-hand side resolves to the local definition rather than to the
-- class method. This applies to the second @union@ in @union = union@, and
-- to each of the others. None of these bindings is self-recursive.
instance Sets.HasUnion (KnuthForest a) where
  union = union
-- Needs 'Eq' because the local 'intersection' compares node values.
instance Eq a => Sets.HasIntersection (KnuthForest a) where
  intersection = intersection
-- Needs 'Eq' because the local 'difference' compares node values.
instance Eq a => Sets.HasDifference (KnuthForest a) where
  difference = difference
instance Sets.HasSize (KnuthForest a) where
  size = size
-- The empty set is the empty forest.
instance Sets.HasEmpty (KnuthForest a) where
  empty = Nil
instance Sets.HasSingleton a (KnuthForest a) where
  singleton = singleton
instance Eq a => Sets.HasDelete a (KnuthForest a) where
  delete = delete
-- | Number of nodes in the forest, counting every node reachable through
-- either the child link or the sibling link.
size :: KnuthForest a -> Int
size = count 0
  where
    -- count n t is n plus the number of nodes in t. The running total is
    -- forced at each node with 'seq'.
    count n Nil            = n
    count n (Fork _ xc xs) =
      let n' = n + 1
      in  n' `seq` count (count n' xc) xs
-- | Whether the value occurs anywhere in the forest. At each node, the node
-- itself is tested first, then its following siblings, then its children.
elem :: Eq a => a -> KnuthForest a -> Bool
elem needle = search
  where
    search Nil            = False
    search (Fork y yc ys) = needle == y || search ys || search yc
-- | @isSubforestOf xss yss@ holds when @xss@ is 'Nil', or when @xss@ equals
-- some forest reachable from @yss@ by following child and sibling links
-- (@yss@ itself included). Candidates are tried outermost first: a forest
-- itself, then its sibling forest, then its child forest.
isSubforestOf :: Eq a => KnuthForest a -> KnuthForest a -> Bool
isSubforestOf Nil    _        = True
isSubforestOf needle haystack = go haystack
  where
    go Nil              = False
    go t@(Fork _ tc ts) = needle == t || go ts || go tc
-- | Same answer as 'isSubforestOf' on finite forests, but the search is
-- innermost first. At every forest, the child forest is searched, then the
-- sibling forest, and only then is the forest itself compared for equality.
--
-- The recursion deliberately stays in this function, not in
-- 'isSubforestOf', so the innermost-first order holds at every level
-- rather than only at the top.
isSubforestOf' :: Eq a => KnuthForest a -> KnuthForest a -> Bool
isSubforestOf' Nil    _        = True
isSubforestOf' needle haystack = go haystack
  where
    go Nil              = False
    go t@(Fork _ tc ts) = go tc || go ts || needle == t
-- | @isProperSubforestOf xss yss@: 'Nil' is accepted unconditionally.
-- Otherwise @xss@ must be a sub-forest (per 'isSubforestOf') of the child
-- forest of the first node of @yss@.
--
-- NOTE(review): the siblings of that first node are never searched, and
-- @isProperSubforestOf Nil Nil@ is 'True' — confirm both are intended.
isProperSubforestOf :: Eq a => KnuthForest a -> KnuthForest a -> Bool
isProperSubforestOf xss yss = case (xss, yss) of
  (Nil, _)         -> True
  (_, Fork _ yc _) -> isSubforestOf xss yc
  (_, Nil)         -> False
-- | Variant of 'isProperSubforestOf' that searches with the innermost-first
-- 'isSubforestOf''. 'Nil' is accepted unconditionally; otherwise @xss@ must
-- be a sub-forest of the child forest of the first node of @yss@.
--
-- NOTE(review): as with the unprimed version, siblings of the first node
-- are not searched, and @Nil@ counts as a proper sub-forest of @Nil@.
isProperSubforestOf' :: Eq a => KnuthForest a -> KnuthForest a -> Bool
isProperSubforestOf' xss yss = case (xss, yss) of
  (Nil, _)         -> True
  (_, Fork _ yc _) -> isSubforestOf' xss yc
  (_, Nil)         -> False
-- | Whether the value appears among the top-level nodes of the forest.
-- Only the sibling chain is walked; children are not searched.
isSiblingOf :: Eq a => a -> KnuthForest a -> Bool
isSiblingOf x = walk
  where
    walk Nil           = False
    walk (Fork y _ ys) = x == y || walk ys
-- | Whether the value is an immediate child of any top-level node of the
-- forest. Grandchildren and deeper nodes are not considered.
isChildOf :: Eq a => a -> KnuthForest a -> Bool
isChildOf x = walk
  where
    walk Nil            = False
    walk (Fork _ yc ys) = isSiblingOf x yc || walk ys
-- | Whether the value is the root of the forest's first tree or occurs
-- anywhere beneath it (an inclusive descendant). Siblings of that root are
-- not considered.
--
-- The whole child forest is searched with 'elem'. Following only the child
-- link would visit just the leftmost path (first child, its first child,
-- ...) and miss every later child and its subtree.
isDescendantOf :: Eq a => a -> KnuthForest a -> Bool
isDescendantOf _ Nil           = False
isDescendantOf x (Fork y yc _) = x == y || elem x yc
-- | Whether the value occurs strictly beneath the root of the forest's first
-- tree, i.e. anywhere in that root's child forest, including later children
-- and their subtrees. The root itself and its siblings are not considered.
isProperDescendantOf :: Eq a => a -> KnuthForest a -> Bool
isProperDescendantOf _ Nil           = False
isProperDescendantOf x (Fork _ yc _) = elem x yc
-- | A forest holding exactly one node, with no children and no siblings.
singleton :: a -> KnuthForest a
singleton x = Fork { kNode = x, kChildren = Nil, kSiblings = Nil }
-- | Removes nodes equal to the given value.
--
-- A matching node is replaced by 'Nil'. That drops the node together with
-- its children and all of its following siblings. Non-matching nodes are
-- kept, and both of their links are searched.
--
-- NOTE(review): because following siblings are dropped, deleting the first
-- top-level node empties the whole forest. Confirm this is intended for the
-- 'Sets.HasDelete' instance.
delete :: Eq a => a -> KnuthForest a -> KnuthForest a
delete x = prune
  where
    prune Nil = Nil
    prune (Fork y yc ys)
      | x == y    = Nil
      | otherwise = Fork y (prune yc) (prune ys)
-- | Concatenates two forests: the second forest is attached after the last
-- top-level sibling of the first. Only the first forest's sibling spine is
-- rebuilt; child forests are shared unchanged.
union :: KnuthForest a -> KnuthForest a -> KnuthForest a
union xss yss = append xss
  where
    append Nil             = yss
    append (Fork x xc Nil) = Fork x xc yss
    append (Fork x xc xs)  = Fork x xc (append xs)
-- | Position-wise intersection. Both forests are walked in lockstep, and a
-- node is kept while the values at the same position are equal; the kept
-- value is taken from the second forest. A position where the values differ,
-- or where either side is 'Nil', becomes 'Nil', which cuts off that
-- branch's children and following siblings.
--
-- NOTE(review): this is a structural (common-prefix) intersection, not a
-- set intersection. An element present in both forests at different
-- positions is not kept. Confirm this is what 'Sets.HasIntersection'
-- callers expect.
intersection :: Eq a => KnuthForest a -> KnuthForest a -> KnuthForest a
intersection xss yss = case xss of
  Nil -> Nil
  Fork x xc xs -> case yss of
    Nil -> Nil
    Fork y yc ys
      | x == y    -> Fork y (intersection xc yc) (intersection xs ys)
      | otherwise -> Nil
-- | Removes from the first forest every node whose value occurs anywhere in
-- the second forest.
--
-- Removal follows the same model as 'delete': a matching node becomes
-- 'Nil', dropping it together with its children and following siblings.
-- Membership is tested against all of @yss@ via 'elem'. Comparing only
-- against the first node of @yss@ would silently ignore the rest of that
-- forest.
--
-- Subtracting the empty forest returns the first forest unchanged (shared,
-- not copied).
difference :: Eq a => KnuthForest a -> KnuthForest a -> KnuthForest a
difference Nil _   = Nil
difference xss Nil = xss
difference xss yss = prune xss
  where
    prune Nil = Nil
    prune (Fork x xc xs)
      | x `elem` yss = Nil
      | otherwise    = Fork x (prune xc) (prune xs)