module Data.Bag where
import qualified Control.Applicative as A
import Control.Monad
import Data.Foldable
import Data.Monoid

-- | A binary tree that stores elements only at its leaves.  There is no
-- empty constructor, so every 'BagTree' holds at least one element.
data BagTree a
    = Leaf a                         -- ^ a single element
    | Branch (BagTree a) (BagTree a) -- ^ two sub-trees side by side
    deriving Show

-- | A possibly-empty bag: either 'Empty' or a non-empty 'BagTree'.
data Bag a
    = Bag (BagTree a) -- ^ a bag holding at least one element
    | Empty           -- ^ the bag with no elements
    deriving Show

-- | The bag containing no elements; an alias for the 'Empty' constructor.
empty :: Bag a
empty = Empty

-- | A bag holding exactly one element, stored in a single 'Leaf'.
unit :: a -> Bag a
unit = Bag . Leaf

-- | Combine two bags.
--
-- An 'Empty' operand is absorbed and the other bag is returned unchanged.
-- Two non-empty bags are joined under a new 'Branch', left operand first.
merge :: Bag a -> Bag a -> Bag a
merge lhs rhs = case (lhs, rhs) of
    (Empty, _)             -> rhs
    (_, Empty)             -> lhs
    (Bag left, Bag right)  -> Bag (Branch left right)

-- | Flatten a tree of trees: each outer 'Leaf' is replaced by the inner tree
-- it holds, and the outer branching structure is kept as-is.
btconcat :: BagTree (BagTree a) -> BagTree a
btconcat tree = case tree of
    Leaf inner     -> inner
    Branch lt rt   -> Branch (btconcat lt) (btconcat rt)

-- | Flatten a bag of bags into a single bag.
--
-- Each outer leaf is replaced by the inner bag it holds, and sibling
-- results are combined with 'merge'.  As a result, empty inner bags
-- disappear, and the outer branching shape is kept wherever both sides are
-- non-empty.  The result is 'Empty' when the outer bag is empty or every
-- inner bag is empty.
bconcat :: Bag (Bag a) -> Bag a
bconcat Empty = Empty
bconcat (Bag outer) = flatten outer
  where
    -- Delegates the Empty-absorbing case analysis to 'merge' rather than
    -- duplicating it.  The helper is named 'flatten' so it does not shadow
    -- Prelude's 'concat'.
    flatten :: BagTree (Bag b) -> Bag b
    flatten (Leaf inner)       = inner
    flatten (Branch left right) = merge (flatten left) (flatten right)

-- | Bags combine with 'merge'.  Since GHC 8.4, 'Semigroup' is a superclass
-- of 'Monoid', so this instance is required for 'Monoid' to compile.
instance Semigroup (Bag a) where
    (<>) = merge

-- | 'empty' is the identity for 'merge'.  'mappend' is left at its default,
-- which is '(<>)'.
instance Monoid (Bag a) where
    mempty = empty

-- | Folds visit leaves left to right: the left sub-tree of a 'Branch' is
-- combined before the right one.
instance Foldable BagTree where
    foldMap f = go
      where
        go (Leaf x)       = f x
        go (Branch lt rt) = go lt `mappend` go rt

-- | An 'Empty' bag folds to 'mempty'.  A non-empty bag folds exactly like
-- its underlying 'BagTree'.
instance Foldable Bag where
    foldMap f bag = case bag of
        Empty  -> mempty
        Bag tr -> foldMap f tr

-- | Mapping applies the function at every leaf and keeps the tree's shape.
instance Functor BagTree where
    fmap f = go
      where
        go (Leaf x)       = Leaf (f x)
        go (Branch lt rt) = Branch (go lt) (go rt)

-- | Mapping over a bag maps its underlying tree; 'Empty' stays empty.
instance Functor Bag where
    fmap f bag = case bag of
        Empty  -> empty
        Bag tr -> Bag (fmap f tr)

-- | 'pure' builds a single 'Leaf'; '<*>' is derived from '>>=' via 'ap'.
-- Since GHC 7.10, 'Applicative' is a superclass of 'Monad', so this
-- instance is required.
instance Applicative BagTree where
    pure  = Leaf
    (<*>) = ap

-- | Binding replaces every leaf with the tree produced for its element,
-- keeping the outer branching structure.  'return' is left at its default,
-- which is 'pure'.
instance Monad BagTree where
    m >>= f = btconcat (fmap f m)

-- | 'pure' builds a one-element bag; '<*>' is derived from '>>=' via 'ap'.
-- Since GHC 7.10, 'Applicative' is a superclass of 'Monad', so this
-- instance is required.
instance Applicative Bag where
    pure  = unit
    (<*>) = ap

-- | Binding maps every element to a bag and flattens the result with
-- 'bconcat'; inner bags that come out empty are dropped.  'return' is left
-- at its default, which is 'pure'.
instance Monad Bag where
    m >>= f = bconcat (fmap f m)

-- | Choice on bags is 'merge', and the empty bag is the identity.
-- Since GHC 7.10, 'Alternative' is a superclass of 'MonadPlus', so this
-- instance is required.  'Control.Applicative' is imported qualified (as
-- @A@) so its 'empty' does not clash with this module's top-level 'empty'.
instance A.Alternative Bag where
    empty = Empty
    (<|>) = merge

-- | Same behavior as the 'A.Alternative' instance: 'mzero' is the empty bag
-- and 'mplus' is 'merge'.
instance MonadPlus Bag where
    mzero = empty
    mplus = merge

-- | Two trees are equal when 'toList' yields the same element sequence.
-- Branching shape is ignored, but element order is not.
--
-- NOTE(review): this equality is order-sensitive; if a bag is meant to be
-- an unordered multiset, this may need to ignore order — confirm intent.
instance Eq a => Eq (BagTree a) where
    (==) lhs rhs = toList lhs == toList rhs

-- | Two bags are equal when 'toList' yields the same element sequence.
-- Because a 'BagTree' always holds at least one element, 'Empty' equals
-- only 'Empty'.
--
-- NOTE(review): this equality is order-sensitive (for example,
-- @merge x y@ and @merge y x@ may differ); confirm whether multiset
-- equality was intended.
instance Eq a => Eq (Bag a) where
    (==) lhs rhs = toList lhs == toList rhs