-- | Equational reasoning built on top of congruence closure.

{-# LANGUAGE TupleSections #-}
module Test.QuickSpec.Reasoning.NaiveEquationalReasoning where

import Test.QuickSpec.Term
import Test.QuickSpec.Equation
import Test.QuickSpec.Reasoning.CongruenceClosure(CC)
import qualified Test.QuickSpec.Reasoning.CongruenceClosure as CC
import Data.Map(Map)
import qualified Data.Map as Map
import Data.IntMap(IntMap)
import qualified Data.IntMap as IntMap
import Control.Monad
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State.Strict
import qualified Control.Monad.Trans.State.Strict as S
import Test.QuickSpec.Utils
import Test.QuickSpec.Utils.Typed
import Test.QuickSpec.Utils.Typeable
import Data.Ord
import Data.List

-- | Everything the equational reasoner carries between calls.
data Context = Context
  { rel :: CC.S
    -- ^ Congruence-closure state. This is the only field that 'runEQ'
    -- replaces after running a computation.
  , universe :: Map TypeRep Universe
    -- ^ One 'Universe' per type, as assembled by 'initial'.
  , maxDepth :: Int
    -- ^ Depth bound supplied to 'initial'; 'runEQ' exposes it to 'EQ'
    -- computations, where '=:=' hands it to 'substs'.
  }

-- | The terms of a single type, grouped by depth. 'createUniverse' keys the
-- map by 'depth' and stores, for each depth, the 'Int's that 'flatten'
-- returned for the terms of that depth.
type Universe = IntMap [Int]

-- | 'CC' computations that can additionally read the per-type universes and
-- the maximum depth. 'runEQ' supplies both from a 'Context'.
type EQ = ReaderT (Map TypeRep Universe, Int) CC

-- | Build the starting 'Context' from a depth bound and a list of tagged terms.
--
-- The terms are grouped by the type of their tag (@witnessType . tag@), and a
-- 'Universe' is created for each group inside a fresh congruence closure.
-- The final congruence-closure state becomes 'rel', the per-type universes
-- become 'universe', and the depth bound is stored unchanged as 'maxDepth'.
initial :: Int -> [Tagged Term] -> Context
initial d ts = Context ccState (Map.fromList universes) d
  where
    -- One more than the largest symbol index occurring in any term
    -- (1 when no symbols occur at all); this is what 'CC.initial' receives.
    size = 1 + maximum (0 : concatMap (map index . symbols . erase) ts)

    groups = partitionBy (witnessType . tag) ts

    (universes, ccState) =
      CC.runCC (CC.initial size) (mapM universeFor groups)

    -- Pair the group's type with its universe. The pattern relies on
    -- 'partitionBy' never producing an empty group -- TODO confirm.
    universeFor grp@(g:_) =
      fmap ((,) (witnessType (tag g))) (createUniverse (map erase grp))

-- | Flatten every term into the congruence closure and collect the resulting
-- 'Int's in a 'Universe', keyed by the depth of the terms they came from.
createUniverse :: [Term] -> CC Universe
createUniverse terms = do
  levels <- forM (partitionBy depth terms) $ \level@(t:_) -> do
    -- All terms in a level share the same depth, so the first one names it.
    -- (Relies on 'partitionBy' never producing an empty group -- TODO confirm.)
    ids <- mapM flatten level
    return (depth t, ids)
  return (IntMap.fromList levels)

-- | Run an 'EQ' computation against a 'Context'.
--
-- The computation reads @('universe' ctx, 'maxDepth' ctx)@ and starts from the
-- congruence-closure state @'rel' ctx@. The result is returned together with
-- the context whose 'rel' has been replaced by the final state.
runEQ :: Context -> EQ a -> (a, Context)
runEQ ctx x =
  let env = (universe ctx, maxDepth ctx)
      (result, finalRel) = runState (runReaderT x env) (rel ctx)
  in (result, ctx { rel = finalRel })

-- | Like 'runEQ', but keep only the computation's result and discard the
-- updated context.
evalEQ :: Context -> EQ a -> a
evalEQ ctx = fst . runEQ ctx

-- | Like 'runEQ', but keep only the updated context and discard the
-- computation's result.
execEQ :: Context -> EQ a -> Context
execEQ ctx = snd . runEQ ctx

-- | Embed a plain congruence-closure action in 'EQ'; the reader environment
-- is ignored.
liftCC :: CC a -> EQ a
liftCC = ReaderT . const

-- | Flatten both terms into the congruence closure and compare the results
-- with 'CC.=?=' (presumably: are they currently known to be equal -- confirm
-- against "Test.QuickSpec.Reasoning.CongruenceClosure").
(=?=) :: Term -> Term -> EQ Bool
t =?= u = liftCC $
  flatten t >>= \lhs ->
  flatten u >>= \rhs ->
  lhs CC.=?= rhs

-- | '=?=' applied to the two sides of an equation.
unifiable :: Equation -> EQ Bool
unifiable eqn = case eqn of
  lhs :=: rhs -> lhs =?= rhs

-- | Record the equation @t = u@ in the congruence closure.
--
-- The result is whether the two terms already compared equal under '=?='
-- /before/ anything was added. When they did not, every substitution produced
-- by 'substs' for @t@, followed by every one produced for @u@, is applied to
-- both sides and the two instances are related with 'CC.=:='.
(=:=) :: Term -> Term -> EQ Bool
t =:= u = do
  (univ, bound) <- ask
  known <- t =?= u
  unless known $
    mapM_ assertInstance (substs t univ bound ++ substs u univ bound)
  return known
  where
    -- Instantiate both sides with the same substitution and relate them.
    assertInstance s = liftCC $ do
      lhs <- subst s t
      rhs <- subst s u
      lhs CC.=:= rhs

-- | '=:=' applied to the two sides of an equation.
unify :: Equation -> EQ Bool
unify eqn = case eqn of
  lhs :=: rhs -> lhs =:= rhs

-- | A substitution: the 'Int' that stands for each symbol. 'subst' applies it
-- to variables only; 'flatten' uses 'index' itself as the identity substitution.
type Subst = Symbol -> Int

-- | Enumerate the substitutions used to instantiate a term.
--
-- For every distinct symbol reported by 'holes', only the occurrence with the
-- largest paired 'Int' @n@ is kept (presumably the depth at which the variable
-- occurs -- confirm against 'holes'). That symbol may then be replaced by any
-- universe element of its type stored at a depth in @[0 .. d - n]@. The result
-- contains one 'Subst' per combination of choices; symbols that are not holes
-- of the term are mapped to their own 'index'.
--
-- A type missing from the universe map raises the \"empty universe\" error,
-- but only once the candidates for a symbol of that type are demanded.
substs :: Term -> Map TypeRep Universe -> Int -> [Subst]
substs t univ d = map toSubst (mapM candidates deepest)
  where
    -- One entry per symbol: its occurrence with the largest paired number.
    deepest =
      [ maximumBy (comparing snd) occs | occs <- partitionBy fst (holes t) ]

    -- Every (symbol, element) pairing allowed by the depth bound.
    candidates (v, n) =
      [ (v, e)
      | k <- [0 .. d - n]
      , e <- IntMap.findWithDefault [] k byDepth ]
      where
        byDepth =
          Map.findWithDefault
            (error "Test.QuickSpec.Reasoning.NaiveEquationalReasoning.substs: empty universe")
            (symbolType v) univ

    -- The table is bound outside the lambda so that it is built once per
    -- substitution rather than once per symbol lookup.
    toSubst chosen = \sym -> IntMap.findWithDefault (index sym) (index sym) table
      where
        table = IntMap.fromList [ (index v, e) | (v, e) <- chosen ]

-- | Translate a term into a congruence-closure element under a substitution.
--
-- Variables are looked up in the substitution, constants map to their own
-- 'index', and an application is built from its translated function and
-- argument with 'CC.$$'.
subst :: Subst -> Term -> CC Int
subst s = go
  where
    go (Var v) = return (s v)
    go (Const c) = return (index c)
    go (App fun arg) = do
      fun' <- go fun
      arg' <- go arg
      fun' CC.$$ arg'

-- | Translate a term into a congruence-closure element without instantiating
-- anything: every variable is mapped to its own 'index'.
flatten :: Term -> CC Int
flatten t = subst index t

-- | Read the current congruence-closure state.
--
-- The body refers to the state monad's @get@ through the @S@ qualifier
-- because this definition shares its name with the @get@ that the unqualified
-- import of "Control.Monad.Trans.State.Strict" brings into scope.
get :: EQ CC.S
get = liftCC S.get

-- | Replace the current congruence-closure state. The state monad's @put@ is
-- reached through the @S@ qualifier because this definition shares its name.
put :: CC.S -> EQ ()
put = liftCC . S.put

-- | Flatten a term and pass the resulting element to 'CC.rep' (presumably
-- yielding the representative of its equivalence class -- confirm against
-- "Test.QuickSpec.Reasoning.CongruenceClosure").
rep :: Term -> EQ Int
rep x = liftCC $ do
  i <- flatten x
  CC.rep i