module Control.Monad.Free.Zip (zipFree, zipFree_) where

import Control.Monad.Free
import Control.Monad.State
import Data.Foldable
import Data.Traversable as T

-- | Zip two free-monad values that have the same structure, combining the
-- values at corresponding 'Pure' leaves with the monadic function @f@ and
-- rebuilding a tree of the results.
--
-- Two 'Impure' layers are considered compatible when they compare equal
-- after their contents have been replaced by @()@.  Any other combination
-- (a 'Pure' against an 'Impure', or layers of different shape) is reported
-- through 'fail' in @m@.
--
-- NOTE(review): 'fail' is called under a plain 'Monad' constraint; newer
-- compiler/base releases expose 'fail' through @MonadFail@ instead --
-- confirm against the targeted compiler.
zipFree :: (Traversable f, Eq (f ()), Monad m) => (a -> b -> m c) -> Free f a -> Free f b -> m (Free f c)
zipFree f left right = case (left, right) of
  (Pure a, Pure b) -> liftM Pure (f a b)
  (Impure layerA, Impure layerB)
    | shapeOf layerA == shapeOf layerB ->
        liftM Impure (unsafeZipWithG (zipFree f) layerA layerB)
  _ -> fail "zipFree: not the same structure"
  where
    -- Erase the contents of one layer so that only its structure is compared.
    shapeOf layer = fmap (const ()) layer

-- | Walk two free-monad values that have the same structure and run the
-- monadic action @f@ on every pair of corresponding 'Pure' leaves,
-- discarding the results.
--
-- Two 'Impure' layers are considered compatible when they compare equal
-- after their contents have been replaced by @()@; their children are then
-- paired up in 'toList' order.  Any other combination is reported through
-- 'fail' in @m@.
--
-- NOTE(review): 'fail' is called under a plain 'Monad' constraint; newer
-- compiler/base releases expose 'fail' through @MonadFail@ instead --
-- confirm against the targeted compiler.
zipFree_ :: (Traversable f, Eq (f ()), Monad m) => (a -> b -> m ()) -> Free f a -> Free f b -> m ()
zipFree_ f (Pure a) (Pure b) = f a b
zipFree_ f (Impure a) (Impure b)
  | fmap (const ()) a == fmap (const ()) b = zipWithM_ (zipFree_ f) (toList a) (toList b)
-- The message names this function (not 'zipFree') so that a failure can be
-- traced to the right entry point.
zipFree_ _ _ _ = fail "zipFree_: not the same structure"


-- | Zip two containers element-wise with a monadic function, producing a
-- result with the shape of the second container.
--
-- The elements of @t1@ are flattened with 'toList' and handed out one at a
-- time while @t2@ is traversed with 'T.mapM'; elements of @t1@ that are
-- left over once @t2@ is exhausted are ignored.
--
-- Precondition (hence \"unsafe\"): @t1@ must contain at least as many
-- elements as @t2@.  Running out of @t1@ elements is treated as a
-- programming error and raises 'error' with a descriptive message, rather
-- than relying on an implicit pattern-match failure in the monad.
unsafeZipWithG :: (Traversable t1, Traversable t2, Monad m) => (a -> b -> m c) ->  t1 a -> t2 b -> m(t2 c)
unsafeZipWithG f t1 t2 = evalStateT (T.mapM pairWithNext t2) (toList t1)
  where
    -- The state holds the not-yet-consumed elements of t1.  The step is
    -- written against the 'StateT' constructor with an exhaustive 'case'
    -- so that nothing beyond 'Monad m' is required.
    pairWithNext y = StateT $ \pending ->
      case pending of
        -- Consume one element; the remainder becomes the new state.
        x : rest -> liftM (\z -> (z, rest)) (f x y)
        [] -> error "unsafeZipWithG: first structure has fewer elements than the second"