{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE ScopedTypeVariables        #-}

module Generics.MultiRec.Transformations.ZipChildren where

import Generics.MultiRec hiding ( show, foldM )
import Control.Monad.State hiding ( foldM, mapM )
import Data.Foldable ( toList )

import Generics.MultiRec.CountIs
import Generics.MultiRec.Transformations.Path
import Generics.MultiRec.Transformations.Children
import Generics.MultiRec.Transformations.MemoTable

--------------------------------------------------------------------------------
-- Zip immediate children
--------------------------------------------------------------------------------

-- | Zip the immediate children of two values @a@ and @b@ of family index @ix@.
--
-- Both values are converted to their pattern-functor representation with
-- 'from' and traversed in lock-step by 'zipChildren'. For every pair of
-- corresponding children, the callback receives:
--
-- * the family witness for the child's index;
-- * a 'Path' to the child;
-- * the left child and the right child.
--
-- The callback results are collected in traversal order.
zipChildrenM :: (Monad m, Fam phi, ZipChildren phi (PF phi))
             => phi ix
             -> (forall t. (Lookup phi (Ixs phi) t, Children phi (PF phi) t, GetChildrenTable phi (Ixs phi) t,  Eq t) => phi t -> Path phi t ix -> t -> t -> m a)
             -> ix
             -> ix
             -> m [a]
{-
zipChildrenM p f a b = zipChildren p (\p w (I0 l) (I0 r) -> f p w l r)
                         (\z -> Push p z Empty)
                         (from p a) (from p b)
-}
-- The path-building function below passes @error "oops"@ as the family
-- witness of 'Push'. Forcing that witness raises the "oops" error.
-- NOTE(review): the commented-out variant above passed @p@ instead; confirm
-- why it was replaced before relying on the witness stored in the path.
zipChildrenM p f a b =
  zipChildren p
              -- Recursive children arrive wrapped in 'I0'; unwrap them for @f@.
              (\witness path (I0 lhs) (I0 rhs) -> f witness path lhs rhs)
              -- Start every path at a single 'Push' on top of 'Empty'.
              (\dir -> Push (error "oops") dir Empty)
              (from p a)
              (from p b)

-- | Lock-step traversal of two pattern-functor values of the same shape,
-- pairing up the children found at recursive positions.
--
-- Each instance handles one representation type. It descends into both
-- arguments simultaneously and extends the path-building function with the
-- direction it took. The callback is invoked on each pair of corresponding
-- children, and the callback results are concatenated in traversal order.
class ZipChildren phi (f :: (* -> *) -> * -> *) where
  zipChildren
    :: Monad m
    => phi ix
       -- ^ family witness for the index of the two values
    -> (forall t. ( Lookup phi (Ixs phi) t
                  , Children phi (PF phi) t
                  , GetChildrenTable phi (Ixs phi) t
                  , Eq t )
                  => phi t -> Path phi t ix -> r t -> r t -> m a)
       -- ^ callback receiving a child's family witness, the path to it,
       --   and the left and right child (still wrapped in @r@)
    -> (forall t. Dir f t ix -> Path phi t ix)
       -- ^ turns a direction inside @f@ into a path
    -> f r ix
       -- ^ left value
    -> f r ix
       -- ^ right value
    -> m [a]
       -- ^ callback results, in traversal order

-- Recursive position: this is where the callback actually runs.
-- It receives the family witness 'proof' (available through the 'El phi xi'
-- constraint) and the path ending in 'CId'. Its single result becomes a
-- singleton list.
instance ( Lookup phi (Ixs phi) xi, El phi xi, Children phi (PF phi) xi
         , GetChildrenTable phi (Ixs phi) xi, Eq xi) => ZipChildren phi (I xi) where
  zipChildren _ f w (I l) (I r) = do
    result <- f proof (w CId) l r
    return [result]

-- Constant fields contain no recursive positions, so nothing is zipped.
instance ZipChildren phi (K a) where
  zipChildren _witness _callback _mkPath _lhs _rhs = return []

-- Nullary constructors carry no fields at all, so nothing is zipped.
instance ZipChildren phi U where
  zipChildren _witness _callback _mkPath _lhs _rhs = return []

-- Sums: descend only when both sides chose the same alternative, extending
-- the path with 'CL' or 'CR'. When the two alternatives differ there are no
-- corresponding children to pair up, so nothing is zipped. The 'K' and 'U'
-- instances likewise return [] when there is nothing to pair; this fallback
-- also avoids a pattern-match failure.
instance (ZipChildren phi f, ZipChildren phi g) => ZipChildren phi (f :+: g) where
  zipChildren p f w (L l) (L r) = zipChildren p f (w . CL) l r
  zipChildren p f w (R l) (R r) = zipChildren p f (w . CR) l r
  zipChildren _ _ _ _     _     = return []

-- Products: zip the left factors first, then the right factors, and
-- concatenate the two result lists in that order.
instance (ZipChildren phi f, ZipChildren phi g, CountIs g)
    => ZipChildren phi (f :*: g) where
  zipChildren p f w (l1 :*: l2) (r1 :*: r2) = do
      fromFirst  <- zipChildren p f (\z -> w (C1 z siblingOfFirst))  l1 r1
      fromSecond <- zipChildren p f (\z -> w (C2 siblingOfSecond z)) l2 r2
      return (fromFirst ++ fromSecond)
    where
      -- Placeholders for the sibling factor stored in the 'C1'/'C2'
      -- directions. Forcing either one raises an error. The commented-out
      -- originals (fmap (const ()) x / y) suggest the sibling with its
      -- contents erased was intended — NOTE(review): confirm.
      siblingOfSecond = error "nullX" -- fmap (const ()) x
      siblingOfFirst  = error "nullY" -- fmap (const ()) y

-- Constructor wrapper: unwrap both sides and record the 'CC' direction.
instance (ZipChildren phi f, Constructor c) => ZipChildren phi (C c f) where
  zipChildren p f w (C l) (C r) =
    zipChildren p f (\dir -> w (CC dir)) l r

-- Index tag: unwrap both sides and record the 'CTag' direction.
instance ZipChildren phi f => ZipChildren phi (f :>: ix) where
  zipChildren p f w (Tag l) (Tag r) =
    zipChildren p f (\dir -> w (CTag dir)) l r
{-
instance (Traversable t, ZipChildren phi f) => ZipChildren phi (t :.: f) where
  zipChildren p f w (D l) (D r) = liftM concat $ sequence $
                                  zipWith3 (\i -> zipChildren p f (w . TrvI i))
                                  [0..] (toList l) (toList r)
-}
-- Optional composition ('Maybe' :.: f): there is a pair of children to zip
-- only when both sides are 'Just'. In that case we descend with the path
-- extended by 'CCM'. In every other combination nothing is zipped.
-- Matching directly replaces the original zipWith3 over toList, whose index
-- argument was unused and forced type defaulting of [0..].
instance (ZipChildren phi f) => ZipChildren phi (Maybe :.: f) where
  zipChildren p f w (D (Just l)) (D (Just r)) = zipChildren p f (w . CCM) l r
  zipChildren _ _ _ _            _            = return []

-- List composition ([] :.: f): pair up elements position by position.
-- Pairing stops at the shorter list. Each element is zipped with a 'CCL'
-- direction, and the per-element results are concatenated in list order.
instance (ZipChildren phi f) => ZipChildren phi ([] :.: f) where
  zipChildren p f w (D l) (D r) = liftM concat (sequence perElement)
    where
      perElement =
        [ zipChildren p f (\x -> w (CCL (preceding i) x following)) lc rc
        | (i, lc, rc) <- zip3 [0 ..] l r ]
      -- Placeholder elements before/after the current one in the 'CCL'
      -- direction. Only the length of 'preceding' (the element index) is
      -- meaningful here; forcing any placeholder raises an error.
      preceding i = replicate i (error "oops3")
      following   = error "oops2"