{-# LANGUAGE MultiParamTypeClasses, TypeSynonymInstances, StandaloneDeriving, GeneralizedNewtypeDeriving #-}

module Language.Haskell.TH.Unification (subTerm, Term(..), UnifT, Explicit(..), unify, solveUnification) where

import Control.Applicative (Applicative)
import Control.Monad
import Control.Monad.Error
import Control.Monad.State.Strict
import Data.Map hiding (map)

-- | A term built from binary applications of a function symbol @f@,
-- atomic leaves of type @a@, and unification variables of type @v@.
data Term f v a
  = App f (Term f v a) (Term f v a)  -- ^ a function symbol applied to exactly two subterms
  | Atom a                           -- ^ an atomic leaf, compared with '=='
  | Var v                            -- ^ a unification variable
  deriving (Eq, Show)

-- | A variable-free term; solved variables are reported in this form.
data Explicit f a
  = AppE f (Explicit f a) (Explicit f a)  -- ^ a function symbol applied to exactly two subterms
  | AtomE a                               -- ^ an atomic leaf
  deriving (Eq, Show)

-- | The result of solving: each bound variable mapped to its explicit value.
type Solution f v a = Map v (Explicit f a)

-- | An equation between two terms that the solver must make hold.
-- The constructor is the operator @:==:@, written infix at use sites.
data Constraint f v a = (:==:) (Term f v a) (Term f v a)

-- | A list of pending equations.
type Constraints f v a = [Constraint f v a]

-- | The unification monad transformer. It keeps the constraints registered
-- with 'unify' in a 'StateT' layer over an 'ErrorT' 'String' layer, which
-- 'solveUnification' unwraps into an @Either String@ result.
newtype UnifT f v a m x = UnifT (StateT (Constraints f v a) (ErrorT String m) x)

-- Since GHC 7.10 (Functor-Applicative-Monad proposal), every 'Monad'
-- instance needs 'Functor' and 'Applicative' superclass instances. They are
-- derived here through the underlying transformer stack, like 'Monad'.
deriving instance (Functor m) => Functor (UnifT f v a m)
deriving instance (Functor m, Monad m) => Applicative (UnifT f v a m)
deriving instance (Monad m) => Monad (UnifT f v a m)
deriving instance (Monad m) => MonadState (Constraints f v a) (UnifT f v a m)

-- | Lift an action of the base monad into 'UnifT'.
instance MonadTrans (UnifT f v a) where
  -- Lift into the 'ErrorT' layer first, then into the 'StateT' layer.
  lift action = UnifT (lift (lift action))

-- | Register the constraint that two terms must be equal. Nothing is
-- solved at this point. The constraint is pushed onto the front of the
-- pending list, which 'solveUnification' later solves.
unify :: (Monad m) => Term f v a -> Term f v a -> UnifT f v a m ()
unify lhs rhs = modify (\pending -> (lhs :==: rhs) : pending)

-- | Run a unification computation without solving it. Returns the
-- constraints it registered, or @Left@ with a message if the computation
-- itself fails. (Not part of this module's export list.)
runUnification :: (Ord v, Eq f, Eq a, Monad m) => UnifT f v a m x -> m (Either String (Constraints f v a))
runUnification (UnifT action) = runErrorT $ execStateT action noConstraints
  where
    noConstraints = []

-- | Run a unification computation, then solve every constraint it
-- registered. On success, returns the computation's own result paired with
-- the solution map. Otherwise returns @Left@ with the failure message.
solveUnification :: (Ord v, Eq f, Eq a, Monad m) => UnifT f v a m x -> m (Either String (x, Solution f v a))
solveUnification (UnifT body) = runErrorT (evalStateT solved [])
  where
    solved = do
      result <- body
      pending <- get
      solution <- solve pending
      return (result, solution)

-- | Solve a list of equality constraints by syntactic unification.
--
-- * Each variable binding is substituted eagerly into the constraints that
--   remain.
-- * Applications with equal function symbols are decomposed into their
--   argument equations.
--
-- Fails (via 'fail') on:
--
-- * mismatched atoms
-- * mismatched function symbols
-- * an application matched against an atom
-- * a variable that would have to contain itself (occurs check)
--
-- NOTE(review): every variable on the right-hand side of a binding must
-- itself be bound by the remaining constraints. Otherwise 'subTerm' looks it
-- up with 'Data.Map.!' and the stored solution entry errors when forced.
solve :: (Ord v, Eq f, Eq a, Monad m) => Constraints f v a -> m (Solution f v a)
solve [] = return empty
solve (constr : constrs) = case constr of
  Var x :==: Var y
    | x == y -> solve constrs
  -- A failed guard above falls through: distinct variables bind below.
  Var x :==: t -> bind x t
  t :==: Var y -> bind y t
  Atom a :==: Atom b
    | a == b -> solve constrs
    | otherwise -> fail "Mismatched atoms"
  App f1 x1 y1 :==: App f2 x2 y2
    | f1 /= f2 -> fail "Mismatched functions"
    | otherwise -> solve ([x1 :==: x2, y1 :==: y2] ++ constrs)
  _ -> fail "Function matched to atom"
  where
    -- Bind variable x to term t. A t that strictly contains x has no finite
    -- solution, so reject it rather than record a cyclic binding (which
    -- could crash or yield a value that violates the constraint). Otherwise
    -- solve the rest with x replaced by t, then record x's explicit value.
    bind x t
      | x `occursIn` t = fail "Occurs check failed"
      | otherwise = subSol x t `liftM` solve (substitute x t constrs)

    -- True when the variable appears anywhere inside the term.
    occursIn v (Var w) = v == w
    occursIn v (App _ l r) = occursIn v l || occursIn v r
    occursIn _ (Atom _) = False

-- | Replace every occurrence of variable @v@ by term @t@ on both sides of
-- each constraint. The inserted copies of @t@ are not themselves rewritten.
substitute :: (Ord v, Eq f, Eq a) => v -> Term f v a -> Constraints f v a -> Constraints f v a
substitute v t = map rewriteBoth
  where
    rewriteBoth (lhs :==: rhs) = rewrite lhs :==: rewrite rhs
    rewrite term = case term of
      Var w | w == v -> t
      App g l r -> App g (rewrite l) (rewrite r)
      _ -> term

-- | Turn a term into an explicit one by replacing each variable with its
-- value from the solution. Partial: a variable missing from the solution
-- makes 'Data.Map.!' raise an error.
subTerm :: Ord v => Solution f v a -> Term f v a -> Explicit f a
subTerm sol = go
  where
    go term = case term of
      Atom c -> AtomE c
      Var x -> sol ! x
      App g l r -> AppE g (go l) (go r)

-- | Add variable @v@ to the solution. Its value is @t@ made explicit by
-- looking up each variable @t@ mentions in @sol@. The value is stored
-- lazily, so a variable missing from @sol@ only errors when the entry is
-- forced.
subSol :: (Ord v, Eq f, Eq a) => v -> Term f v a -> Solution f v a -> Solution f v a
subSol v t sol = insert v explicitValue sol
  where
    explicitValue = subTerm sol t
	
-- test :: UnifT Char String String IO ()
-- test = do	App 'f' (App 'g' (Var "A") (Var "A")) (Var "A") `unify`
-- 			App 'f' (Var "B") (Atom "xyz")