{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, GeneralizedNewtypeDeriving, FlexibleContexts #-}
module Data.Monoid.Applicative 
    ( module Control.Applicative
    , module Data.Monoid.Reducer
    , Traversal(Traversal,getTraversal)
    , Alternate(Alternate,getAlternate)
    , TraversalWith(TraversalWith,getTraversalWith)
    ) where

import Control.Functor.Pointed (Pointed, point)
import Control.Applicative (Applicative, (*>), pure, Alternative, empty, (<|>), liftA2)
import Data.Monoid.Reducer

-- | A 'Monoid' that runs 'Applicative' actions one after another and throws
-- away their results: 'mempty' is @pure ()@ and 'mappend' is '*>'.
newtype Traversal f = Traversal
    { getTraversal :: f () -- ^ the sequenced action, result discarded
    }

-- Sequence the two wrapped actions left to right; the identity is the
-- action that just returns ().
instance Applicative f => Monoid (Traversal f) where
    mempty = Traversal $ pure ()
    mappend (Traversal earlier) (Traversal later) = Traversal $ earlier *> later

-- Any action @f a@ can be fed into a 'Traversal': its effects are kept in
-- order and its result is replaced by ().
instance Applicative f => Reducer (f a) (Traversal f) where
    unit = Traversal . (*> pure ())
    cons x (Traversal rest) = Traversal (x *> rest)
    snoc (Traversal acc) x = Traversal ((acc *> x) *> pure ())


-- Rewrite rules for the case where the action already has type @f ()@: there
-- the trailing @*> pure ()@ added by 'unit' and 'snoc' should be redundant.
-- NOTE(review): both rules have a class method on the left-hand side, which
-- GHC warns may never fire; confirm they are actually effective.
{-# RULES "unitTraversal" unit = Traversal #-}
{-# RULES "snocTraversal" snoc = snocTraversal #-}
-- | Append an @f ()@ action to the end of a 'Traversal' using plain 'mappend',
-- without the extra @*> pure ()@ that the general 'snoc' adds.
--
-- The context is @Reducer (f ()) (Traversal f)@ rather than @Applicative f@,
-- presumably so the right-hand side of the @snocTraversal@ rule typechecks
-- from the constraint its left-hand side provides. The body relies on
-- 'Monoid' being a superclass of 'Reducer' to reach 'mappend'.
snocTraversal :: Reducer (f ()) (Traversal f) => Traversal f -> f () -> Traversal f
snocTraversal acc action = acc `mappend` Traversal action

-- | Wraps an 'Alternative' so it can be used as a 'Monoid' whose 'mappend'
-- is '<|>' and whose 'mempty' is 'empty'.
newtype Alternate f a = Alternate
    { getAlternate :: f a -- ^ the wrapped alternative value
    }
    deriving ( Eq, Ord, Show, Read
             , Functor, Applicative, Alternative )

-- Choice ('<|>') is the monoid operation and 'empty' is its identity.
instance Alternative f => Monoid (Alternate f a) where
    mempty = Alternate empty
    mappend x y = Alternate (getAlternate x <|> getAlternate y)

-- A single @f a@ is wrapped as-is; 'cons' and 'snoc' combine it with the
-- accumulator using '<|>' on the corresponding side.
instance Alternative f => Reducer (f a) (Alternate f a) where
    unit x = Alternate x
    cons x (Alternate acc) = Alternate (x <|> acc)
    snoc (Alternate acc) x = Alternate (acc <|> x)

-- Lift 'point' of the underlying functor through the wrapper.
instance Pointed f => Pointed (Alternate f) where
    point x = Alternate (point x)

-- | Like 'Traversal', but the results of the sequenced actions are kept and
-- combined with their own 'Monoid' instead of being discarded.
newtype TraversalWith f n = TraversalWith
    { getTraversalWith :: f n -- ^ the wrapped action
    }

-- Run both actions and 'mappend' their results; the identity is an action
-- that simply returns 'mempty'.
instance (Applicative f, Monoid n) => Monoid (TraversalWith f n) where
    mempty = TraversalWith $ pure mempty
    mappend (TraversalWith x) (TraversalWith y) = TraversalWith $ liftA2 mappend x y

-- An action @f n@ is wrapped directly. Only 'unit' is given here; 'cons' and
-- 'snoc' fall back to whatever defaults the 'Reducer' class provides.
instance (Applicative f, Monoid n) => Reducer (f n) (TraversalWith f n) where
    unit action = TraversalWith action

-- Map over the result of the wrapped action.
instance Functor f => Functor (TraversalWith f) where
    fmap g (TraversalWith action) = TraversalWith (fmap g action)

-- Lift 'point' of the underlying functor through the wrapper.
instance Pointed f => Pointed (TraversalWith f) where
    point x = TraversalWith (point x)