{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GADTs                 #-}

module Generics.MultiRec.HZip where

import Generics.MultiRec
import Control.Monad (liftM, liftM2, zipWithM)

-- | Lockstep traversal of two values of the same pattern functor.
--
-- @hzipM f a b@ walks @a@ and @b@ together, calls @f@ on every pair of
-- values found at matching recursive ('I') positions, and rebuilds the
-- structure from the results inside the monad @m@. Instances such as the
-- ones for 'K' and ':+:' report a shape mismatch with 'fail'.
class HZip phi f where
  hzipM
    :: Monad m
    => (forall ix. El phi ix => phi ix -> r ix -> s ix -> m (t ix))
    -> f r ix
    -> f s ix
    -> m (f t ix)

-- | Recursive position: hand both values to the user-supplied function,
-- together with the family witness obtained from 'proof', and re-wrap the
-- result in 'I'.
instance El phi xi => HZip phi (I xi) where
  hzipM f (I a) (I b) = do
    c <- f proof a b
    return (I c)

-- | Constant leaves are compared with '=='. An equal pair keeps the left
-- value; an unequal pair makes the zip 'fail'.
instance Eq a => HZip phi (K a) where
  hzipM _ (K l) (K r) =
    if l == r
      then return (K l)
      else fail "zip failed in K"

-- | Unit carries no data, so once both sides are evaluated to 'U' the zip
-- succeeds with 'U'.
instance HZip phi U where
  hzipM _ U U = return U

-- | Sums zip only when both sides picked the same alternative; mixing 'L'
-- and 'R' makes the zip 'fail'.
instance (HZip phi a, HZip phi b) => HZip phi (a :+: b) where
  hzipM f l r = case (l, r) of
    (L x, L y) -> liftM L (hzipM f x y)
    (R x, R y) -> liftM R (hzipM f x y)
    _          -> fail "zip failed"

-- | Products zip component-wise; the left components are zipped (and their
-- effects run) before the right components.
instance (HZip phi a, HZip phi b) => HZip phi (a :*: b) where
  hzipM f (a1 :*: b1) (a2 :*: b2) = do
    a' <- hzipM f a1 a2
    b' <- hzipM f b1 b2
    return (a' :*: b')

-- | Tagged structures: strip 'Tag' from both sides, zip the contents, and
-- re-apply 'Tag' to the result.
instance HZip phi f => HZip phi (f :>: xi) where
  hzipM f (Tag a) (Tag b) = hzipM f a b >>= return . Tag

-- | Constructor wrappers: strip 'C' from both sides, zip the contents, and
-- re-apply 'C' to the result.
instance HZip phi f => HZip phi (C c f) where
  hzipM f (C a) (C b) = hzipM f a b >>= return . C

-- | Lists under a composition are zipped element-wise, left to right.
--
-- Unlike 'zipWithM', which silently drops the tail of the longer list,
-- lists of different lengths make the zip 'fail'. This keeps the instance
-- consistent with the other shape checks (e.g. 'L' vs 'R' in sums). It
-- also stops generic equality from treating @[a, b]@ and @[a]@ as equal.
instance HZip phi f => HZip phi ([] :.: f) where
  hzipM f (D xs) (D ys) = liftM D (go xs ys)
    where
      -- Walk both lists in lockstep; a length mismatch is a shape mismatch.
      go []       []       = return []
      go (l : ls) (r : rs) = liftM2 (:) (hzipM f l r) (go ls rs)
      go _        _        = fail "zip failed in list"

-- | Like 'hzipM', but the per-position function is pure.
--
-- Each result of @f@ is lifted into the monad with 'return'. The monad is
-- still needed because the zip itself can fail on mismatched shapes. The
-- @phi ix@ argument is not inspected; it only constrains the types @phi@
-- and @ix@.
hzip :: (HZip phi f, Monad m) =>
        (forall ix. El phi ix => phi ix -> r ix -> s ix -> t ix) ->
        phi ix -> f r ix -> f s ix -> m (f t ix)
hzip f _ = hzipM (\w x -> return . f w x)

-- | Pure, partial zip: runs 'hzip' in 'Maybe'.
--
-- Calls 'error' with "generic zip failed" when the zip fails (for example
-- on a shape mismatch reported by one of the instances).
hzip' :: (HZip phi f) =>
         (forall ix. El phi ix => phi ix -> r ix -> s ix -> t ix) ->
         phi ix -> f r ix -> f s ix -> f t ix
hzip' f p a b =
  maybe (error "generic zip failed") id $
    hzip (\w x y -> f w x y) p a b

-- | Traverse two structures in lockstep purely for the effects of @f@.
--
-- Each @m ()@ produced by @f@ is turned into a @K0 ()@ placeholder so that
-- 'hzipM' can rebuild a structure. That structure is then discarded and
-- the overall result is @()@. The @phi ix@ argument is not inspected.
combine :: forall phi f r r' m ix. (Monad m, HZip phi f) =>
           (forall ix. El phi ix => phi ix -> r ix -> r' ix -> m ()) ->
           phi ix -> f r ix -> f r' ix -> m ()
combine f _ x y = hzipM discard x y >> return ()
  where
    -- Run @f@ for its effects only and return a unit placeholder.
    discard :: forall j z. El phi j => phi j -> r j -> r' j -> m (K0 () z)
    discard w a b = do
      f w a b
      return (K0 ())

-- | Generic equality: 'True' exactly when 'geq'' succeeds in the 'Maybe'
-- monad on the two values.
geq :: (Fam phi, HZip phi (PF phi)) => phi ix -> ix -> ix -> Bool
geq ix x y = case geq' ix (I0 x) (I0 y) of
  Just _  -> True
  Nothing -> False

-- | Generic equality in an arbitrary monad.
--
-- Both values are converted to their pattern functor with 'from' and
-- compared in lockstep via 'combine'. Recursive positions are compared by
-- calling 'geq'' again. It succeeds with @()@, or fails through the
-- 'HZip' instances' 'fail' when a mismatch is detected (e.g. differing
-- constructors or unequal 'K' leaves).
geq' :: (Monad m, Fam phi, HZip phi (PF phi))
        => phi ix -> I0 ix -> I0 ix -> m ()
geq' p (I0 a) (I0 b) = combine geq' p a' b'
  where
    a' = from p a
    b' = from p b