{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GADTs                 #-}

module Generics.MultiRec.Rewriting.Machinery where

import Generics.MultiRec
import Generics.MultiRec.HZip
import Generics.MultiRec.Rewriting.Rules
import Generics.MultiRec.Any

import qualified Data.Map as M
import Control.Monad.State

-----------------------------------------------------------------------------
-- Class synonym for shorter names
-----------------------------------------------------------------------------
-- | Constraint synonym bundling everything the rewriting functions in this
-- module require of a multirec family @phi@:
--
-- * 'Fam' — conversion between terms and their generic representation
--   ('from' / 'to').
-- * 'EqS' — index equality ('eqS'), used to safely unpack terms stored in
--   'Any'.
-- * 'HZip' and 'HFunctor' on the pattern functor @PF phi@ — presumably what
--   'combine' / 'geq'' and 'hmap' respectively need (confirm against the
--   HZip module).
--
-- NOTE(review): no instance of this class is declared in this chunk;
-- presumably each family supplies @instance Rewrite MyFam@ (or a catch-all
-- instance exists elsewhere) — confirm.
class (Fam phi, EqS phi, HZip phi (PF phi), HFunctor phi (PF phi))
      => Rewrite phi

-----------------------------------------------------------------------------
-- Actual rewriting
-----------------------------------------------------------------------------
-- | Try to apply a rule to a term, at the root only.
--
-- The rule's left-hand side is matched against the whole @term@. On
-- success, the right-hand side is instantiated with the resulting
-- substitution and returned in 'Just'. If matching fails, the result is
-- 'Nothing'.
rewriteM :: Rewrite phi => Rule phi a -> a -> Maybe a
rewriteM (Rule p (lhs :~> rhs)) term =
  -- 'fmap' instead of @>>= return .@: instantiating the RHS cannot fail.
  fmap (\s -> inst s p rhs) (match p lhs term)

-- | Match a scheme against a term, starting from an empty substitution.
--
-- Returns the final substitution built up by 'matchM', which binds the
-- scheme's metavariables to the subterms they matched. Match failure is
-- reported through the monad @m@.
match :: (Monad m, Rewrite phi) =>
         phi ix -> Scheme phi ix -> ix -> m (Subst phi)
match p pat = flip execStateT M.empty . matchM p pat . I0

-- | Match a scheme against a wrapped term, threading the substitution
-- built so far through the state.
--
-- * A metavariable that is not yet in the substitution is bound to the
--   current term.
-- * A metavariable that is already bound is checked against the current
--   term with 'checkEqual'. This means a repeated metavariable must match
--   equal terms.
-- * Any other scheme node is handed to 'combine' together with the term's
--   generic representation ('from'), recursing via 'matchM'.
matchM :: (Monad m, Rewrite phi)
          => phi ix -> Scheme phi ix -> I0 ix -> StateT (Subst phi) m ()
matchM p scheme (I0 e) =
  case scheme of
    HIn (R r)       -> combine matchM p r (from p e)
    HIn (L (K var)) -> gets (M.lookup var) >>= maybe (bindFresh var) (checkEqual p e)
  where
    -- First occurrence: record the current term for this metavariable.
    bindFresh v = modify (M.insert v (Any p e))

-- | Check a term against the term already bound to a metavariable.
--
-- The bound term is stored existentially in 'Any', so its index is first
-- compared with the current index using 'eqS'. If the indices differ, the
-- check fails in @m@ with @"checkEqual"@. If they agree, the comparison
-- is delegated to 'geq'', which presumably fails in @m@ when the terms
-- differ (confirm against the HZip module).
--
-- NOTE(review): calling 'fail' under only a @Monad m@ constraint relies on
-- pre-MonadFail GHC (< 8.8). Newer compilers require @MonadFail m@ here
-- and in 'matchM' / 'match'.
checkEqual :: (Monad m, Rewrite phi)
           => phi ix -> ix -> Any phi -> m ()
checkEqual p e (Any q bound) =
  case eqS p q of
    Just Refl -> geq' p (I0 e) (I0 bound)
    Nothing   -> fail "checkEqual"

-- | Instantiate a scheme with a substitution, producing a concrete term.
--
-- Each metavariable is replaced by the term bound to it in @s@. Every other
-- node has its children instantiated with 'hmap' and is converted back to
-- a term with 'to'.
--
-- Calls 'error' in two cases:
--
-- * the scheme mentions a metavariable that @s@ does not bind (e.g. a
--   right-hand-side metavariable that the left-hand side never bound in
--   'rewriteM');
-- * the bound term has a different index than the one expected here.
inst :: Rewrite phi =>
        Subst phi -> phi ix -> Scheme phi ix -> ix
inst s p scheme
  = case scheme of
      HIn (L (K x)) ->
        case M.lookup x s of
          Just (Any p' e)
            -> case eqS p p' of
                 Just Refl -> e
                 Nothing   -> error "Coerce error in inst"
          -- Without this branch an unbound metavariable crashed with an
          -- uninformative non-exhaustive-pattern error.
          Nothing -> error "Unbound metavariable in inst"
      HIn (R r) -> to p $ hmap (\q -> I0 . inst s q) p r

-- | A substitution mapping each metavariable to the term bound to it during
-- matching. Terms are wrapped in 'Any' together with their index witness,
-- because different metavariables may stand for terms of different index
-- types in the family @phi@. 'eqS' recovers the index when the term is
-- used.
type Subst phi = M.Map Metavar (Any phi)