module Generics.MultiRec.HZip where
import Generics.MultiRec
import Control.Monad (liftM, liftM2, zipWithM)
-- | Generic zipping of two values that share the same pattern-functor
-- shape @f@ within the family @phi@.
--
-- 'hzipM' walks both structures in lockstep. At every recursive position it
-- combines the two children with the supplied function, which also receives
-- the family witness for that position. A shape mismatch is reported through
-- the monad; the instances in this module use 'fail' for that.
class HZip phi f where
  hzipM
    :: Monad m
    => (forall j. El phi j => phi j -> a j -> b j -> m (c j))
    -> f a ix
    -> f b ix
    -> m (f c ix)
-- | A recursive position: pass both children, together with the family
-- witness obtained from 'proof', to the combining function and re-wrap the
-- result in 'I'.
instance El phi xi => HZip phi (I xi) where
  hzipM step (I a) (I b) = step proof a b >>= return . I
-- | A constant field: succeeds (keeping the left value) only when both
-- constants are equal according to 'Eq'; otherwise fails with
-- @"zip failed in K"@. The combining function is never consulted.
instance Eq a => HZip phi (K a) where
  hzipM _ (K a) (K b) =
    if a == b then return (K a) else fail "zip failed in K"
-- | The unit functor: two units always zip to a unit.
instance HZip phi U where
  hzipM _ U U = return U
-- | Sums: both sides must use the same injection. Mixing 'L' with 'R'
-- fails with @"zip failed"@.
instance (HZip phi a, HZip phi b) => HZip phi (a :+: b) where
  hzipM step lhs rhs = case (lhs, rhs) of
    (L u, L v) -> hzipM step u v >>= return . L
    (R u, R v) -> hzipM step u v >>= return . R
    _          -> fail "zip failed"
-- | Products: zip the left components, then the right components (in that
-- order), and pair the two results back up.
instance (HZip phi a, HZip phi b) => HZip phi (a :*: b) where
  hzipM step (l1 :*: r1) (l2 :*: r2) = do
    l <- hzipM step l1 l2
    r <- hzipM step r1 r2
    return (l :*: r)
-- | Type-index tags: unwrap both 'Tag's, zip the contents, and re-tag.
instance HZip phi f => HZip phi (f :>: xi) where
  hzipM step (Tag u) (Tag v) = do
    w <- hzipM step u v
    return (Tag w)
-- | Constructor wrappers: unwrap both 'C's, zip the fields, and re-wrap.
instance HZip phi f => HZip phi (C c f) where
  hzipM step (C u) (C v) = do
    w <- hzipM step u v
    return (C w)
-- | Lists under a composition ('D') are zipped element-wise, left to right.
--
-- Lists of different lengths are a structural mismatch and fail with
-- @"zip failed in []"@, just as mismatched injections fail in the sum
-- instance. This deliberately avoids 'zipWithM', which silently truncates to
-- the shorter list: generic equality built on this class would then report a
-- list as equal to any extension of it.
instance HZip phi f => HZip phi ([] :.: f) where
  hzipM f (D xs0) (D ys0) = liftM D (go xs0 ys0)
    where
      go (x : xs) (y : ys) = liftM2 (:) (hzipM f x y) (go xs ys)
      go []       []       = return []
      go _        _        = fail "zip failed in []"
-- | Variant of 'hzipM' taking a pure combining function. The combiner itself
-- cannot fail, so any failure comes from a shape mismatch between the two
-- structures. The witness argument is not used by the implementation.
hzip :: (HZip phi f, Monad m) =>
        (forall ix. El phi ix => phi ix -> r ix -> s ix -> t ix) ->
        phi ix -> f r ix -> f s ix -> m (f t ix)
hzip f _ lhs rhs = hzipM (\w u v -> return $ f w u v) lhs rhs
-- | Like 'hzip', but run in 'Maybe' and unwrapped. This function is partial:
-- a shape mismatch between the two structures aborts with
-- @error "generic zip failed"@.
hzip' :: (HZip phi f) =>
         (forall ix. El phi ix => phi ix -> r ix -> s ix -> t ix) ->
         phi ix -> f r ix -> f s ix -> f t ix
hzip' f p a b = maybe (error "generic zip failed") id zipped
  where
    -- The use site under 'maybe' fixes the monad to 'Maybe'.
    zipped = hzip (\q u v -> f q u v) p a b
-- | Walks two structures in lockstep. At every recursive position it runs the
-- monadic action @f@ only for its effects, and the zipped result is thrown
-- away. The witness argument is not used by the implementation.
combine :: forall phi f r r' m ix. (Monad m, HZip phi f) =>
           (forall ix. El phi ix => phi ix -> r ix -> r' ix -> m ()) ->
           phi ix -> f r ix -> f r' ix -> m ()
combine f _ lhs rhs = hzipM visit lhs rhs >> return ()
  where
    -- 'hzipM' needs some result functor. 'K0 ()' carries no information,
    -- so each position just runs @f@ and yields a placeholder value.
    visit :: forall j b. El phi j => phi j -> r j -> r' j -> m (K0 () b)
    visit w u v = do
      f w u v
      return (K0 ())
-- | Generic equality. Two values of the same family member are considered
-- equal exactly when 'geq'' succeeds on them in the 'Maybe' monad.
geq :: (Fam phi, HZip phi (PF phi)) => phi ix -> ix -> ix -> Bool
geq ix x y = case geq' ix (I0 x) (I0 y) of
  Just _  -> True
  Nothing -> False
-- | Worker for generic equality. It converts both values to their generic
-- representation with 'from' and walks them together with 'combine',
-- recursing into itself at every recursive position. Any shape mismatch
-- detected by the 'HZip' instances surfaces as a failure in @m@.
geq' :: (Monad m, Fam phi, HZip phi (PF phi))
     => phi ix -> I0 ix -> I0 ix -> m ()
geq' w (I0 a) (I0 b) = combine geq' w lhs rhs
  where
    lhs = from w a
    rhs = from w b