{-# LANGUAGE CPP #-}
module Control.Monad.Free.Zip (zipFree, zipFree_) where

import Control.Monad.Free
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Foldable
import Data.Traversable as T
import Prelude hiding (fail)

-- | Zip the outermost layers of two free-monad values.
--
-- When both arguments are 'Impure' and their layers have the same shape
-- (compared with every child replaced by @()@), the children are paired
-- up, combined with the supplied function, and the resulting layer is
-- wrapped back in 'Impure'.
--
-- Every other combination of arguments, including a shape mismatch, is
-- reported through 'fail' in @m@.
zipFree
  :: (Traversable f, Eq (f ()), MonadFail m)
  => (Free f a -> Free f b -> m (Free f c))
  -> Free f a
  -> Free f b
  -> m (Free f c)
zipFree f (Impure a) (Impure b)
  -- Erase the children so that only the shape of each layer is compared.
  | fmap (const ()) a == fmap (const ()) b = fmap Impure (unsafeZipWithG f a b)
zipFree _ _ _ = fail "zipFree: structure mismatch"

-- | Effect-only variant of the one-layer zip: run an action on each pair of
-- corresponding children and discard the results.
--
-- Both arguments must be 'Impure' with identically shaped layers (the
-- layers are compared after replacing every child with @()@); anything
-- else is reported through 'fail' in @m@.
zipFree_
  :: (Traversable f, Eq (f ()), MonadFail m)
  => (Free f a -> Free f b -> m ()) -> Free f a -> Free f b -> m ()
zipFree_ f (Impure lhs) (Impure rhs)
  | sameShape = zipWithM_ f (toList lhs) (toList rhs)
  where
    -- True when the two layers agree once their children are erased.
    sameShape = skeleton lhs == skeleton rhs
    -- Replace every child with @()@, keeping only the layer's shape.
    skeleton layer = fmap (const ()) layer
zipFree_ _ _ _ = fail "zipFree_: structure mismatch"


-- | Zip two containers element-wise inside a monad, producing a result
-- with the shape of the second container.
--
-- The elements of the first container are flattened to a list and consumed
-- one at a time while the second container is traversed. The function is
-- \"unsafe\" because it does not check that the shapes agree:
--
-- * surplus elements of the first container are silently ignored;
-- * if the first container runs out early, the computation fails through
--   'fail' in @m@ with a message naming this function.
unsafeZipWithG
  :: (Traversable t1, Traversable t2, Monad m, MonadFail m)
  => (a -> b -> m c) -> t1 a -> t2 b -> m (t2 c)
unsafeZipWithG f t1 t2 = evalStateT (T.mapM step t2) (toList t1)
  where
    -- The state holds the elements of the first container not yet paired.
    step y = do
      pending <- get
      case pending of
        x : rest -> do
          put rest
          lift (f x y)
        -- Explicit failure instead of relying on a refutable pattern, so
        -- the error message says what actually went wrong.
        [] -> lift (fail "unsafeZipWithG: first container has fewer elements than the second")