-----------------------------------------------------------------------------
-- Copyright 2016, Ideas project team. This file is distributed under the
-- terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Ideas.Common.Rewriting.Unification
   ( unify, match, matchExtended, matchList
   , Match, SymbolMatch
   , unificationTests
   ) where

import Control.Monad
import Data.Maybe
import Ideas.Common.Rewriting.AC (pairingsMatchA)
import Ideas.Common.Rewriting.Substitution
import Ideas.Common.Rewriting.Term
import Ideas.Utils.TestSuite
import qualified Data.Map as M

-----------------------------------------------------------
-- Unification (in both ways)

-- | Two-way unification: meta-variables on either side may be bound.
-- Returns a substitution built from 'singletonSubst' bindings, or 'Nothing'
-- when the terms cannot be unified (different symbols, lists of different
-- lengths, unequal non-compound terms, or a failing occurs check).
unify :: Term -> Term -> Maybe Substitution
unify term1 term2 =
   case (term1, term2) of
      (TMeta i, TMeta j) | i == j ->
         return emptySubst
      -- occurs check: never bind a meta-variable to a term containing it
      (TMeta i, _) | not (i `hasMetaVar` term2) ->
         return (singletonSubst i term2)
      (_, TMeta j) | not (j `hasMetaVar` term1) ->
         return (singletonSubst j term1)
      (TCon s xs, TCon t ys) | s == t ->
         rec xs ys
      (TList xs, TList ys) ->
         rec xs ys
      _ | term1 == term2 ->
         return emptySubst
      _ -> Nothing
 where
   -- Unify argument lists pairwise. Each partial unifier is applied to the
   -- remaining arguments on both sides before they are unified, and the
   -- results are combined with '@@'.
   rec [] [] = return emptySubst
   rec (x:xs) (y:ys) = do
      s1 <- unify x y
      s2 <- rec (map (s1 |->) xs) (map (s1 |->) ys)
      return (s2 @@ s1)
   -- argument lists of different lengths do not unify
   rec _ _ = Nothing

-- | One-way matching: only meta-variables of the first term (the pattern)
-- are bound. A meta-variable in the second term is treated as fixed: it is
-- matched only by the same meta-variable, or bound to a pattern
-- meta-variable. Failure is reported with 'mzero'.
match :: MonadPlus m => Term -> Term -> m Substitution
match term1 term2 =
   case (term1, term2) of
      (TMeta i, TMeta j) | i == j ->
         return emptySubst
      -- occurs check: never bind a meta-variable to a term containing it
      (TMeta i, _) | not (i `hasMetaVar` term2) ->
         return (singletonSubst i term2)
      -- a non-variable pattern can never match a meta-variable
      (_, TMeta _) ->
         mzero
      (TCon s xs, TCon t ys) | s == t ->
         rec xs ys
      (TList xs, TList ys) ->
         rec xs ys
      _ | term1 == term2 ->
         return emptySubst
      _ -> mzero
 where
   -- Match argument lists pairwise. Bindings found so far are applied to
   -- the remaining pattern arguments only, and the two parts must be
   -- 'composable' before they are joined with '@@'.
   -- 'mzero' is used instead of 'fail': 'MonadPlus' does not imply
   -- 'MonadFail' (GHC >= 8.8), so 'fail' would not type-check here.
   rec [] [] = return emptySubst
   rec (x:xs) (y:ys) = do
      s1 <- match x y
      s2 <- rec (map (s1 |->) xs) ys
      guard (composable s1 s2)
      return (s1 @@ s2)
   rec _ _ = mzero

-----------------------------------------------------------
-- Matching (or: one-way unification)

-- | A matcher takes a pattern (first argument) and a value (second argument)
-- and returns the substitutions under which they match; an empty list
-- means there is no match.
type Match a = a -> a -> [Substitution]
-- | A symbol-specific matcher, as stored in the table consulted by 'matchA'.
-- It receives the matcher to use for sub-terms, the argument list of the
-- pattern's top-level application, and the term to match against.
type SymbolMatch = Match Term -> [Term] -> Term -> [Substitution]

-- | Matches the pattern @x@ against the term @y@ using 'matchA'. When @x@ is
-- an application of an associative binary symbol @s@, the pattern is also
-- tried with an extra meta-variable combined on its left, on its right, and
-- on both sides (via @s@). Those two variables are numbered from
-- 'nextMetaVar' of the pattern, presumably so they do not clash with the
-- pattern's own variables. Each result pairs the substitution with the
-- terms bound to the left and right extra variables, if any.
matchExtended
   :: M.Map Symbol SymbolMatch -> Term -> Term
   -> [(Substitution, Maybe Term, Maybe Term)]
matchExtended sm x y = do
   extend <- variants
   sub    <- matchA sm (extend x) y
   return (sub, lookupVar leftVar sub, lookupVar rightVar sub)
 where
   leftVar  = nextMetaVar x
   rightVar = leftVar + 1
   -- ways to widen the pattern; the plain pattern is always tried first
   variants =
      case x of
         TCon s [_, _] | isAssociative s ->
            let addLeft  t = binary s (TMeta leftVar) t
                addRight t = binary s t (TMeta rightVar)
            in [id, addRight, addLeft, addLeft . addRight]
         _ -> [id]

-- | Matcher driven by a table of symbol-specific matchers. The second term
-- should not contain meta-variables. A pattern meta-variable is bound to
-- the whole term, and lists are matched element-wise. For an application:
-- if its symbol has an entry in the table, that entry is used; otherwise an
-- associative symbol uses 'associativeMatch'. Everything else falls back to
-- 'defaultMatch'.
matchA :: M.Map Symbol SymbolMatch -> Match Term
matchA sm = go
 where
   go (TMeta i) term        = [singletonSubst i term]
   go (TList ps) (TList ts) = matchList go ps ts
   go pat term              = maybe fallback bySymbol (getFunction pat)
    where
      fallback = defaultMatch go pat term
      bySymbol (s, args) =
         case M.lookup s sm of
            Just custom -> custom go args term
            Nothing
               | isAssociative s -> associativeMatch s go args term
               | otherwise       -> fallback

-- | Purely structural matching. Two applications match when their symbols
-- are equal and their arguments match pairwise under @f@. Two lists match
-- element-wise under @f@. Any other pair matches, with the empty
-- substitution, only if the terms are equal.
defaultMatch :: Match Term -> Match Term
defaultMatch f x y =
   case (x, y) of
      (TCon s xs, TCon t ys) -> if s == t then matchList f xs ys else []
      (TList xs, TList ys)   -> matchList f xs ys
      _                      -> [ emptySubst | x == y ]

-- | Matches two term lists position by position with @f@ and combines the
-- per-position alternatives with 'products'. Lists of different lengths
-- produce no matches.
matchList :: Match Term -> Match [Term]
matchList f as bs =
   maybe (fail "matchList: lengths differ") products (safeZipWith f as bs)

-- | Like 'zipWith', but returns 'Nothing' when the lists have different
-- lengths instead of silently truncating the longer one.
safeZipWith :: (a -> b -> c) -> [a] -> [b] -> Maybe [c]
safeZipWith f = go
 where
   go (a:as) (b:bs) = (f a b :) <$> go as bs
   go []     []     = Just []
   go _      _      = Nothing

-- | Combines per-position alternatives. It picks one substitution from each
-- inner list in every possible way and keeps only the combinations that
-- '@+@' can join. An empty input yields just the empty substitution.
products :: [[Substitution]] -> [Substitution]
products []            = [emptySubst]
products (here : more) =
   [ joined | s <- here, t <- rest, joined <- maybeToList (s @+@ t) ]
 where
   -- computed once and shared by every choice from 'here'
   rest = products more

-- | Matching modulo associativity of @s@. Both the pattern arguments and the
-- term are flattened into the operands of (nested) @s@-applications, and
-- the results are handed to 'pairingsMatchA' together with @make@. @make@
-- receives a pattern operand and a list of term operands. A meta-variable
-- is bound to those operands rebuilt with @s@; any other operand must face
-- exactly one term operand, which it matches with @f@.
associativeMatch :: Symbol -> SymbolMatch
associativeMatch s f as b =
   concatMap products (pairingsMatchA make (concatMap flatten as) (flatten b))
 where
   make :: Term -> [Term] -> [Substitution]
   make (TMeta i) group = [singletonSubst i (construct group)]
   make pat [single]    = f pat single
   make _ _             = []

   -- operands of possibly nested s-applications; other terms stay whole
   flatten term =
      case isFunction s term of
         Just args -> concatMap flatten args
         Nothing   -> [term]

   -- rebuild a right-nested s-application from its operands
   construct []  = error "associativeMatch: empty list"
   construct ts  = foldr1 (binary s) ts

-----------------------------------------------------------
--- * Test unification properties

-- | Property tests relating 'unify' and 'match' to substitution
-- application ('|->').
unificationTests :: TestSuite
unificationTests = suite "Unification"
   [ -- a unifier, whenever one is found, makes both terms equal
     useProperty "unify" $ \a b ->
        case unify a b of
           Just s  -> (s |-> a) == (s |-> b)
           Nothing -> True
     -- a term must always unify with an instance of itself
   , useProperty "unify-succeed" $ \a s ->
        let b = s |-> a in
        case unify a b of
           Just s2 -> (s2 |-> a) == (s2 |-> b)
           Nothing -> False
     -- a match, whenever one is found, turns the pattern into the term
   , useProperty "match" $ \a b ->
        case match a b of
           Just s  -> (s |-> a) == b
           Nothing -> True
     -- matching a term against an instance of itself. Failure is tolerated
     -- because 'match' can legitimately reject this case, e.g. when s maps
     -- a meta-variable to a term that still contains it (occurs check)
   , useProperty "match-succeed" $ \a s ->
        let b = s |-> a in
        case match a b of
           Just s2 -> (s2 |-> a) == b
           Nothing -> True
   ]