{-# LANGUAGE TypeFamilies, FlexibleContexts, UndecidableInstances,
        StandaloneDeriving, FlexibleInstances, MultiParamTypeClasses #-}
module Text.XML.Expat.Chunked.Iterator (Iterator (..)) where

import Data.List.Class (List(..), ListItem(..), foldrL)

import Control.Applicative (Applicative(..))
import Control.Monad (MonadPlus(..), ap, liftM)
-- import Control.Monad.Cont.Class (MonadCont(..))
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.Reader.Class (MonadReader(..))
import Control.Monad.State.Class (MonadState(..))
import Control.Monad.Trans (MonadTrans(..), MonadIO(..))
import Data.Monoid (Monoid(..))


-- A monadic list: running the wrapped action in m produces a ListItem,
-- i.e. either Nil or a Cons holding one element and the rest of the
-- iterator.
newtype Iterator m a = Iterator
  { runIterator :: m (ListItem (Iterator m) a)
  }

-- Standalone-derived instances. Each one is available whenever the wrapped
-- value of type m (ListItem (Iterator m) a) has the corresponding instance.
deriving instance Eq (m (ListItem (Iterator m) a)) => Eq (Iterator m a)
deriving instance Ord (m (ListItem (Iterator m) a)) => Ord (Iterator m a)
deriving instance Read (m (ListItem (Iterator m) a)) => Read (Iterator m a)
deriving instance Show (m (ListItem (Iterator m) a)) => Show (Iterator m a)

{-
instance (Monad m, MonadPlus (Iterator m)) => List (Iterator m) where
    type ItemM (Iterator m) = m
    runList = runIterator
    joinL m = Iterator $ do
        it <- m
        runIterator it
-}

-- Right fold whose result is a list rather than an action in the item monad.
-- Used by the mappend, fmap and (>>=) definitions in this module.
--
-- consFunc combines an element with the already-folded remainder;
-- nilFunc is the result substituted for the end of the input list.
foldrL' :: List l => (a -> l b -> l b) -> l b -> l a -> l b
foldrL' consFunc nilFunc list =
  joinL (foldrL step (return nilFunc) list)
  where
    -- foldrL supplies the folded remainder as an action in the item monad;
    -- joinL turns it back into a list before consFunc receives it.
    step x rest = return (consFunc x (joinL rest))

-- Prepend an element to an Iterator.
-- The generic cons is deliberately not used: per the original note it would
-- loop forever here (mplus for Iterator is itself defined through this
-- function).
cons :: Monad m => a -> Iterator m a -> Iterator m a
cons x rest = Iterator (return (Cons x rest))

-- Iterators form a monoid under concatenation.
instance Monad m => Monoid (Iterator m a) where
  -- The empty iterator: its action immediately yields Nil.
  mempty = Iterator (return Nil)
  -- Re-cons every element of the first iterator in front of the second.
  mappend front back = foldrL' cons back front

-- Map a function over every element by rebuilding the iterator with cons.
instance Monad m => Functor (Iterator m) where
  fmap func xs = foldrL' (\x rest -> cons (func x) rest) mempty xs

-- List-like monad: return builds a one-element iterator and (>>=) maps the
-- continuation over every element, then concatenates the results.
instance Monad m => Monad (Iterator m) where
  return x = Iterator (return (Cons x mempty))
  -- fmap produces an iterator of iterators; folding it with mappend
  -- flattens it in order.
  xs >>= f = foldrL' mappend mempty $ fmap f xs

-- Applicative defined entirely in terms of the Monad instance.
instance Monad m => Applicative (Iterator m) where
  pure x = return x
  fs <*> xs = fs `ap` xs

-- MonadPlus reuses the Monoid structure: mzero is the empty iterator and
-- mplus is concatenation.
instance Monad m => MonadPlus (Iterator m) where
  mzero = mempty
  mplus front back = front `mappend` back

-- Lift an action of the underlying monad to an iterator that yields the
-- action's result as its single element.
instance MonadTrans Iterator where
  lift action = Iterator (liftM (\x -> Cons x mempty) action)

-- Iterator as an instance of the generic List class, with m as item monad.
instance Monad m => List (Iterator m) where
  type ItemM (Iterator m) = m
  -- Unwrapping the newtype exposes the next ListItem.
  runList = runIterator
  -- Run the outer action, then continue with the iterator it produced.
  joinL action = Iterator (action >>= runIterator)

-- YUCK:
-- I can't believe I'm doing this,
-- for compatibility with mtl's Iterator.
-- I hate auto-lift instances: every class has to be lifted through every
-- transformer by hand, so the amount of code grows as O(N^2). DRY!!

-- Lift an IO action into the underlying monad first, then into Iterator.
instance MonadIO m => MonadIO (Iterator m) where
  liftIO action = lift (liftIO action)

{-
-- TODO: understand and verify this instance :)
instance MonadCont m => MonadCont (Iterator m) where
  callCC f =
    Iterator $ callCC thing
    where
      thing c = runIterator . f $ Iterator . c . (`Cons` mempty)
-}

-- Error handling lifted from the underlying monad.
instance MonadError e m => MonadError e (Iterator m) where
  throwError err = lift (throwError err)
  -- The handler wraps only the action that produces the first ListItem of m;
  -- the tail stored inside a Cons is passed through unchanged.
  catchError m handler =
    Iterator (catchError (runList m) (\err -> runList (handler err)))

-- Reader environment lifted from the underlying monad.
instance MonadReader s m => MonadReader s (Iterator m) where
  ask = lift ask
  -- local is applied to the action that produces the first ListItem; the
  -- tail stored inside a Cons is not re-wrapped here.
  local f iter = Iterator (local f (runList iter))

-- State access lifted from the underlying monad; each operation becomes a
-- single-element iterator via lift.
instance MonadState s m => MonadState s (Iterator m) where
  get = lift get
  put newState = lift (put newState)