{-# LANGUAGE ExistentialQuantification, MultiParamTypeClasses,
       FunctionalDependencies, FlexibleInstances, UndecidableInstances #-}

-----------------------------------------------------------------------------

-- Copyright 2018, Ideas project team. This file is distributed under the

-- terms of the Apache License 2.0. For more information, see the files

-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.

-----------------------------------------------------------------------------

-- |

-- Maintainer  :  bastiaan.heeren@ou.nl

-- Stability   :  provisional

-- Portability :  portable (depends on ghc)

--

-----------------------------------------------------------------------------



module Ideas.Common.Rewriting.RewriteRule

   ( -- * Supporting type class

     Different(..)

     -- * Rewrite rules and specs

   , RewriteRule, ruleSpecTerm, RuleSpec(..)

     -- * Compiling rewrite rules

   , makeRewriteRule, termRewriteRule, RuleBuilder(..)

     -- * Using rewrite rules

   , showRewriteRule

   , metaInRewriteRule, renumberRewriteRule

   , symbolMatcher, symbolBuilder

   ) where



import Data.Maybe

import Ideas.Common.Classes

import Ideas.Common.Environment

import Ideas.Common.Id

import Ideas.Common.Rewriting.Substitution

import Ideas.Common.Rewriting.Term

import Ideas.Common.Rewriting.Unification

import Ideas.Common.View hiding (match)

import Ideas.Utils.Uniplate (descend)

import qualified Data.IntSet as IS

import qualified Data.Map as M



------------------------------------------------------

-- Rewrite rules and specs



infixl 1 :~>

-- | A rule specification: a left-hand side and a right-hand side, joined by
-- the infix constructor ':~>'.
data RuleSpec a = a :~> a deriving Show

-- | Mapping a function over a specification transforms both of its sides.
instance Functor RuleSpec where
   fmap g (lhs :~> rhs) = g lhs :~> g rhs



-- | A rewrite rule over values of type @a@. The rule itself is stored as a
-- specification on terms; the remaining fields describe how values of type
-- @a@ are shown and converted, and how matching/building is customised.
data RewriteRule a = R

   { ruleId         :: Id -- ^ identifier of the rule (see the 'HasId' instance)

   , ruleSpecTerm   :: RuleSpec Term -- ^ left- and right-hand side as terms

   , ruleShow       :: a -> String -- ^ used by 'showRewriteRule' to render both sides

   , ruleTermView   :: View Term a -- ^ view used to convert between @a@ and 'Term'

   , ruleMatchers   :: M.Map Symbol SymbolMatch -- ^ per-symbol matchers, handed to @matchExtended@

   , ruleBuilders   :: M.Map Symbol ([Term] -> Term) -- ^ per-symbol constructors applied to the rewritten term

   }



-- | Showing a rewrite rule delegates to 'showId'; the specification itself
-- is not printed.
instance Show (RewriteRule a) where
   show rule = showId rule



-- | The identifier of a rewrite rule is the one stored in its 'ruleId' field.
instance HasId (RewriteRule a) where
   getId rule = ruleId rule
   -- Only the identifier is replaced; all other fields are kept.
   changeId f rule = rule { ruleId = f (ruleId rule) }



------------------------------------------------------

-- Compiling a rewrite rule



-- | Types for which a pair of sample values is available. 'buildFunction'
-- applies a rule-building function to both values and compares the results:
-- every position where the results are not equal becomes a meta variable.
-- The two values should therefore not be equal.
class Different a where

   -- | The pair of sample values.
   different :: (a, a)



-- | For lists, the two sample values are the empty list and a singleton list
-- holding the first sample value of the element type.
instance Different a => Different [a] where
   different = ([], [x])
    where
      (x, _) = different



-- Sample values for numeric types, booleans and characters.
instance Different Int      where different = (0, 1)
instance Different Integer  where different = (0, 1)
instance Different Double   where different = (0, 1)
instance Different Float    where different = (0, 1)
instance Different Rational where different = (0, 1)
instance Different Bool     where different = (True, False)
instance Different Char     where different = ('a', 'b')

-- For terms, the numerals 0 and 1 serve as sample values.
instance Different Term     where different = (TNum 0, TNum 1)



-- | Pairs: the first sample pairs up the first sample values of both
-- components, the second sample pairs up the second sample values.
instance (Different a, Different b) => Different (a, b) where
   different = ((p1, q1), (p2, q2))
    where
      (p1, p2) = different
      (q1, q2) = different



-- | Triples: component-wise combination of the first sample values and of
-- the second sample values.
instance (Different a, Different b, Different c) => Different (a, b, c) where
   different = ((p1, q1, r1), (p2, q2, r2))
    where
      (p1, p2) = different
      (q1, q2) = different
      (r1, r2) = different



-- | Types @t@ from which a rule specification on terms can be built. The
-- type @a@ of the values being rewritten is determined by @t@.
class (IsTerm a, Show a) => RuleBuilder t a | t -> a where

   -- | Build the specification. The 'Int' is the number used for the next
   -- meta variable: the instance for functions uses it for its own argument
   -- and passes on its successor; 'makeRewriteRule' starts counting at 0.
   buildRuleSpec  :: Int -> t -> RuleSpec Term



-- | Base case: a specification without arguments. Both sides are converted
-- to terms; the meta-variable counter is not needed.
instance (IsTerm a, Show a) => RuleBuilder (RuleSpec a) a where
   buildRuleSpec _ spec = fmap toTerm spec



-- | Inductive case: a function that expects one more argument. The argument
-- is abstracted into meta variable @i@ by 'buildFunction'; the remainder of
-- the builder continues with the next number.
instance (Different a, RuleBuilder t b) => RuleBuilder (a -> t) b where
   buildRuleSpec i f = buildFunction i (\x -> buildRuleSpec (i + 1) (f x))



-- | Turn a function producing rule specifications into a single
-- specification. The function is applied to both sample values of the
-- argument type, and the two results are merged side by side with 'fill':
-- positions on which the results disagree become meta variable @n@.
buildFunction :: Different a => Int -> (a -> RuleSpec Term) -> RuleSpec Term
buildFunction n f = merge (f sample1) (f sample2)
 where
   (sample1, sample2) = different
   -- left-hand sides are combined with each other, and so are the
   -- right-hand sides
   merge (l1 :~> r1) (l2 :~> r2) = fill n l1 l2 :~> fill n r1 r2



-- | Merge two terms into one. Constructor applications with the same symbol
-- and the same number of arguments, and lists of the same length, are merged
-- element by element. Everywhere else, equal terms are kept as they are and
-- unequal terms are replaced by meta variable @i@.
fill :: Int -> Term -> Term -> Term
fill i = go
 where
   go (TCon s xs) (TCon t ys)
      | s == t, sameLength xs ys = TCon s (zipWith go xs ys)
   go (TList xs) (TList ys)
      | sameLength xs ys = TList (zipWith go xs ys)
   -- also reached when one of the guards above does not hold
   go x y
      | x == y    = x
      | otherwise = TMeta i

   sameLength ps qs = length ps == length qs



-- | Apply a rule specification to a term. For every way in which the
-- left-hand side matches the term (as reported by @matchExtended@ using the
-- given symbol matchers), the result contains the rewritten term paired with
-- the terms that were bound to the variables in the substitution's domain.
--
-- A match may come with an optional term to the left and/or to the right;
-- when present, these are put back around the instantiated right-hand side
-- with the top-level symbol of the left-hand side. Finally, the symbol
-- builders (if any) replace the plain 'TCon' constructor for their symbol.
buildSpec :: M.Map Symbol SymbolMatch
          -> M.Map Symbol ([Term] -> Term)
          -> RuleSpec Term -> Term -> [(Term, [Term])]
buildSpec sm sb (lhs :~> rhs) a =
   [ (result, bound)
   | (sub, ml, mr) <- matchExtended sm lhs a
   , let headSym   = maybe (error "buildSpec") fst (getFunction lhs)
         withLeft  = maybe id (binary headSym) ml
         withRight = maybe id (flip (binary headSym)) mr
         result    = applyBuilders (withLeft (withRight (sub |-> rhs)))
         bound     = mapMaybe (`lookupVar` sub) (IS.toList (dom sub))
   ]
 where
   -- without registered builders there is no need to traverse the term
   applyBuilders
      | M.null sb = id
      | otherwise = rebuild

   rebuild (TCon s xs) = fromMaybe (TCon s) (M.lookup s sb) (map rebuild xs)
   rebuild other       = other



-- | Create a rewrite rule from an identifier and a rule builder (a
-- 'RuleSpec', or a function that returns a rule builder). Meta variables are
-- numbered from 0; the rule has no symbol matchers or symbol builders.
makeRewriteRule :: (IsId n, RuleBuilder f a) => n -> f -> RewriteRule a
makeRewriteRule s f = termRewriteRule s (buildRuleSpec 0 f)



-- | Create a rewrite rule from an identifier and a specification that is
-- already given on terms. Values are shown with 'show' and converted with
-- 'termView'; no symbol matchers or symbol builders are registered.
termRewriteRule :: (IsId n, IsTerm a, Show a) => n -> RuleSpec Term -> RewriteRule a
termRewriteRule s spec = R
   { ruleId       = newId s
   , ruleSpecTerm = spec
   , ruleShow     = show
   , ruleTermView = termView
   , ruleMatchers = M.empty
   , ruleBuilders = M.empty
   }



-- | Register a matcher for a symbol in a rewrite rule. A matcher that was
-- registered earlier for the same symbol is replaced.
symbolMatcher :: Symbol -> SymbolMatch -> RewriteRule a -> RewriteRule a
symbolMatcher s f r = r { ruleMatchers = extended }
 where
   extended = M.insert s f (ruleMatchers r)



-- | Register a builder for a symbol in a rewrite rule. A builder that was
-- registered earlier for the same symbol is replaced.
symbolBuilder :: Symbol -> ([Term] -> Term) -> RewriteRule a -> RewriteRule a
symbolBuilder s f r = r { ruleBuilders = extended }
 where
   extended = M.insert s f (ruleBuilders r)



------------------------------------------------------

-- Using a rewrite rule



-- | Applying a rewrite rule returns all results of 'applyRewriteRule',
-- without their environments.
instance Apply RewriteRule where
   applyAll r a = map fst (applyRewriteRule r a)



-- | Apply a rewrite rule to a value. The value is converted to a term, the
-- rule's specification is applied with 'buildSpec', and every resulting term
-- is converted back with 'fromTermRR' (results for which that conversion
-- yields nothing are dropped). Each result comes with an environment in
-- which the terms returned by 'buildSpec' are bound to references named
-- @1@, @2@, ... in that order.
applyRewriteRule :: RewriteRule a -> a -> [(a, Environment)]
applyRewriteRule r a =
   [ (b, env)
   | (out, xs) <- rewrite (toTermRR r a)
   , let env = mconcat (zipWith bind xs [1 :: Int ..])
   , b <- fromTermRR r out
   ]
 where
   rewrite  = buildSpec (ruleMatchers r) (ruleBuilders r) (ruleSpecTerm r)
   -- the reference is named after the position of the term in the list
   bind t k = singleBinding (makeRef (show k)) t



-----------------------------------------------------------

-- Pretty-print a rewriteRule



-- | Pretty-print a rewrite rule with the rule's own show function. The meta
-- variables of the specification are first replaced by the variables @a@,
-- @b@, @c@, ... (in ascending order of meta-variable number). The two sides
-- are separated by @~>@ if the first argument is 'True', and by @/~>@
-- otherwise. The result is 'Nothing' if one of the two sides cannot be
-- converted back from its term.
showRewriteRule :: Bool -> RewriteRule a -> Maybe String
showRewriteRule sound r =
   render <$> fromTermRR r (sub |-> lhs) <*> fromTermRR r (sub |-> rhs)
 where
   lhs :~> rhs = ruleSpecTerm r

   arrow
      | sound     = "~>"
      | otherwise = "/~>"

   render x y = ruleShow r x ++ " " ++ arrow ++ " " ++ ruleShow r y

   metas = IS.toList (IS.union (metaVarSet lhs) (metaVarSet rhs))
   sub   = listToSubst (zip metas [ TVar [c] | c <- ['a' ..] ])



------------------------------------------------------



-- some helpers

-- | The meta variables of the left-hand side, followed by the meta variables
-- of the right-hand side (as returned by 'metaVars' for each side).
metaInRewriteRule :: RewriteRule a -> [Int]
metaInRewriteRule r = concatMap metaVars [lhs, rhs]
 where
   lhs :~> rhs = ruleSpecTerm r



-- | Add the given number to every meta variable in both sides of the rule's
-- specification.
renumberRewriteRule :: Int -> RewriteRule a -> RewriteRule a
renumberRewriteRule n r = r { ruleSpecTerm = fmap shift (ruleSpecTerm r) }
 where
   shift term =
      case term of
         TMeta i -> TMeta (i + n)
         _       -> descend shift term



-- | Convert a value to a term with the term view stored in the rule.
toTermRR :: RewriteRule a -> a -> Term
toTermRR r = build (ruleTermView r)



-- | Convert a term back to a value with the term view stored in the rule,
-- using 'matchM' so that the outcome is delivered in the monad @m@.
fromTermRR :: Monad m => RewriteRule a -> Term -> m a
fromTermRR r = matchM (ruleTermView r)