-- | Equational reasoning built on top of congruence closure.

{-# LANGUAGE CPP, TupleSections #-}
module Test.QuickSpec.Reasoning.NaiveEquationalReasoning where

#include "errors.h"

import Test.QuickSpec.Term
import Test.QuickSpec.Equation
import Test.QuickSpec.Reasoning.CongruenceClosure(CC)
import qualified Test.QuickSpec.Reasoning.CongruenceClosure as CC
import Data.Map(Map)
import qualified Data.Map as Map
import Data.IntMap(IntMap)
import qualified Data.IntMap as IntMap
import Control.Monad
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Strict
import qualified Control.Monad.Trans.State.Strict as S
import Test.QuickSpec.Utils
import Test.QuickSpec.Utils.Typed
import Test.QuickSpec.Utils.Typeable
import Data.Ord
import Data.List

data Context = Context {
  rel :: CC.S,
  maxDepth :: Int,
  universe :: IntMap Universe
  }

type Universe = IntMap [Int]

type EQ = ReaderT (Int, IntMap Universe) CC

initial :: Int -> [Symbol] -> [Tagged Term] -> Context
initial d syms ts =
  let n = 1+maximum (0:map index syms)
      (universe, rel) =
        CC.runCC (CC.initial n) $
          forM (partitionBy (witnessType . tag) ts) $ \xs@(x:_) ->
            fmap (witnessType (tag x),) (createUniverse (map erase xs))
      univMap = Map.fromList universe

  in Context rel d . IntMap.fromList $ [
    (index sym,
     Map.findWithDefault IntMap.empty (symbolType sym) univMap)
    | sym <- syms ]

createUniverse :: [Term] -> CC Universe
createUniverse ts = fmap IntMap.fromList (mapM createTerms tss)
  where tss = partitionBy depth ts
        createTerms ts@(t:_) = fmap (depth t,) (mapM flatten ts)

runEQ :: Context -> EQ a -> (a, Context)
runEQ ctx x = (y, ctx { rel = rel' })
  where (y, rel') = runState (runReaderT x (maxDepth ctx, universe ctx)) (rel ctx)

evalEQ :: Context -> EQ a -> a
evalEQ ctx x = fst (runEQ ctx x)

execEQ :: Context -> EQ a -> Context
execEQ ctx x = snd (runEQ ctx x)

liftCC :: CC a -> EQ a
liftCC x = ReaderT (const x)

(=?=) :: Term -> Term -> EQ Bool
t =?= u = liftCC $ do
  x <- flatten t
  y <- flatten u
  x CC.=?= y

equal :: Equation -> EQ Bool
equal (t :=: u) = t =?= u

(=:=) :: Term -> Term -> EQ Bool
t =:= u = unify (t :=: u)

unify :: Equation -> EQ Bool
unify (t :=: u) = do
  (d, ctx) <- ask
  b <- t =?= u
  unless b $
    forM_ (substs t ctx d ++ substs u ctx d) $ \s -> liftCC $ do
      t' <- subst s t
      u' <- subst s u
      t' CC.=:= u'
  return b

type Subst = Symbol -> Int

substs :: Term -> IntMap Universe -> Int -> [Subst]
substs t univ d = map lookup (sequence (map choose vars))
  where vars = map (maximumBy (comparing snd)) .
               partitionBy fst .
               holes $ t

        choose (x, n) =
          let m = IntMap.findWithDefault (ERROR "empty universe")
                  (index x) univ in
          [ (x, t)
          | d' <- [0..d-n],
            t <- IntMap.findWithDefault [] d' m ]

        lookup ss =
          let m = IntMap.fromList [ (index x, y) | (x, y) <- ss ]
          in \x -> IntMap.findWithDefault (index x) (index x) m

subst :: Subst -> Term -> CC Int
subst s (Var x) = return (s x)
subst s (Const x) = return (index x)
subst s (App f x) = do
  f' <- subst s f
  x' <- subst s x
  f' CC.$$ x'

flatten :: Term -> CC Int
flatten = subst index

get :: EQ CC.S
get = liftCC S.get

put :: CC.S -> EQ ()
put x = liftCC (S.put x)

rep :: Term -> EQ Int
rep x = liftCC (flatten x >>= CC.rep)