{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Language.Symantic.Typing.Unify where

import Data.Map.Strict (Map)
import Data.Semigroup (Semigroup(..))
import Unsafe.Coerce (unsafeCoerce)
import qualified Data.Map.Strict as Map

import Language.Symantic.Grammar
import Language.Symantic.Typing.Variable
import Language.Symantic.Typing.Kind
import Language.Symantic.Typing.Type
import Language.Symantic.Typing.Show ()

-- * Type 'Subst'
-- | /Type variable substitution/: maps the 'IndexVar' of each bound 'Var'
-- to that 'Var' paired with the 'Type' replacing it.
--
-- WARNING: a 'Subst' MUST be without loops, and fully expanded:
-- 'subst' replaces a bound 'Var' by its 'Type' without substituting
-- again inside that 'Type', so a 'Type' of the range mentioning
-- a 'Var' of the domain would be left partially substituted.
newtype Subst src vs
 =      Subst (Map IndexVar (VT src vs))
deriving instance Source src => Show (Subst src vs)

-- | Combining two 'Subst's is their left-biased 'unionSubst'.
instance Semigroup (Subst src vs) where
	x <> y = unionSubst x y

-- | 'mempty' binds no 'Var' at all.
instance Monoid (Subst src vs) where
	mappend = (<>)
	mempty  = Subst Map.empty

-- | Union of two 'Subst's.
--
-- This is not a unification: when both 'Subst's bind the same 'Var',
-- the two 'Type's are not compared.
--
-- NOTE: the union is left-biased: in case of duplicate 'Var's,
-- the binding from the first 'Subst' given is kept.
--
-- NOTE: the first 'Subst' given is applied (with 'subst') to every 'Type'
-- of the second one. This way each 'Var' directly maps to an expanded 'Type',
-- so the resulting 'Subst' does not have to be applied repeatedly
-- until no more substitution can be done.
unionSubst :: Subst src vs -> Subst src vs -> Subst src vs
unionSubst sx@(Subst x) (Subst y) =
	Subst (Map.union x (fmap expand y))
	where
	-- Rewrite one binding of the second 'Subst' through the first one.
	expand (VT v r) = VT v (subst sx r)

-- * Type 'VT'
-- | A 'Var' and a 'Type' existentialized over their type index.
--
-- Both fields share the same hidden index @t@. Once the 'Var' is checked
-- against another one (e.g. with 'eqVarKi'), the 'Type' can therefore
-- replace that other 'Var', as done by 'subst' and 'lookupSubst'.
data VT src vs = forall t. VT (Var src vs t) (Type src vs t)
deriving instance Source src => Show (VT src vs)

-- | Bind the given 'Var' to the given 'Type', replacing any binding
-- previously stored under the same 'IndexVar'.
--
-- NOTE: unlike 'unionSubst', neither the bindings already in the 'Subst'
-- nor the given 'Type' are rewritten: the caller is responsible
-- for keeping the resulting 'Subst' fully expanded.
insertSubst :: Var src vs v -> Type src vs v -> Subst src vs -> Subst src vs
insertSubst v t (Subst m) =
	Subst (Map.insert key (VT v t) m)
	where key = indexVar v

-- | Find the 'Type' bound to the given 'Var' in a 'Subst'.
--
-- Returns 'Nothing' when nothing is stored under the 'IndexVar' of the 'Var',
-- or when the stored 'Var' is not 'eqVarKi'-equal to the given one.
lookupSubst :: Var src vs v -> Subst src vs -> Maybe (Type src vs v)
lookupSubst v (Subst m) =
	case Map.lookup (indexVar v) m of
	 Just (VT v' t) | Just HRefl <- v `eqVarKi` v' -> Just t
	 _ -> Nothing

-- * Class 'Substable'
-- | Values whose 'Var's can be replaced by 'Type's.
class Substable a where
	-- | Replace every 'Var' bound in the given 'Subst' by its 'Type';
	-- 'Var's not bound in it are left untouched.
	subst ::
	 (src ~ SourceOf a, vs ~ VarsOf a) =>
	 Subst src vs -> a -> a
	-- | Replace every occurrence of the given 'Var' by the given 'Type'.
	--
	-- Like 'substVar', but without the /occurrence check/:
	-- nothing prevents the 'Type' from containing the 'Var' itself.
	substVarUnsafe ::
	 (src ~ SourceOf a, vs ~ VarsOf a) =>
	 Var src vs v -> Type src vs v -> a -> a
instance Substable (Type src vs t) where
	-- Walk the whole 'Type', replacing each 'TyVar' equal to @v@ by @r@.
	substVarUnsafe v r ty =
		case ty of
		 TyConst{} -> ty
		 TyApp src f a ->
			TyApp src (substVarUnsafe v r f) (substVarUnsafe v r a)
		 TyVar _src _n vt
		  | Just HRefl <- v `eqVarKi` vt -> r
		  | otherwise -> ty
		 TyFam src len fam as ->
			TyFam src len fam (substVarUnsafe v r as)
	
	-- Walk the whole 'Type', replacing each bound 'TyVar' by its 'Type'.
	-- The replacement is not substituted again: the 'Subst' is expected
	-- to be fully expanded.
	subst s@(Subst m) ty =
		case ty of
		 TyConst{} -> ty
		 TyApp src f a -> TyApp src (subst s f) (subst s a)
		 TyVar _src _n v ->
			case Map.lookup (indexVar v) m of
			 Nothing -> ty
			 Just (VT vr r)
			  | Just HRefl <- v `eqVarKi` vr -> r
			  | otherwise -> error "[BUG] subst: kind mismatch"
		 TyFam src len fam as -> TyFam src len fam (subst s as)
-- | Substitution in a list of 'Type's is done element-wise.
instance Substable (Types src vs ts) where
	substVarUnsafe v r tys =
		case tys of
		 TypesZ -> TypesZ
		 TypesS t rest ->
			TypesS (substVarUnsafe v r t) (substVarUnsafe v r rest)
	subst s tys =
		case tys of
		 TypesZ -> TypesZ
		 TypesS t rest -> TypesS (subst s t) (subst s rest)

-- | Substitute the given 'Var' by the given 'Type' everywhere in the value.
--
-- Returns 'Nothing' when that 'Type' itself contains the 'Var'
-- (/occurrence check/); otherwise behaves like 'substVarUnsafe'.
substVar ::
 src ~ SourceOf a =>
 vs ~ VarsOf a =>
 Source src => VarOccursIn a => Substable a =>
 Var src vs v -> Type src vs v -> a -> Maybe a
substVar v r t
 | v `varOccursIn` r = Nothing -- occurrence check failed
 | otherwise         = Just (substVarUnsafe v r t)

-- ** Type 'Error_Unify'
-- | Reasons why two 'Type's cannot be unified.
data Error_Unify src
 =   Error_Unify_Var_loop IndexVar (TypeVT src)
     -- ^ /occurrence check/: a 'Var' (given by its 'IndexVar')
     --   is unified with a 'Type' which contains this same 'Var'.
 |   Error_Unify_Const_mismatch (TypeVT src) (TypeVT src)
     -- ^ Two 'TyConst's should be the same, but are different.
 |   Error_Unify_Kind_mismatch (KindK src) (KindK src)
     -- ^ Two 'Kind's should be the same, but are different.
 |   Error_Unify_Kind (Con_Kind src)
     -- ^ Two 'Kind's mismatch, as reported by a 'Con_Kind' error
     --   (injected through the 'ErrorInj' instance below).
 |   Error_Unify_mismatch (TypeVT src) (TypeVT src)
     -- ^ Cannot unify those two 'Type's.

deriving instance Source src => Eq   (Error_Unify src)
deriving instance Source src => Show (Error_Unify src)

-- | Identity injection, so that 'Error_Unify' can be used
-- where an 'ErrorInj' into 'Error_Unify' is required.
instance ErrorInj (Error_Unify src) (Error_Unify src) where
	errorInj = id
-- | A kind error is wrapped into 'Error_Unify_Kind'.
instance ErrorInj (Con_Kind src) (Error_Unify src) where
	errorInj = Error_Unify_Kind

-- | Return the left spine of a 'Type': its head 'Type'
-- (the innermost function of the chain of 'TyApp's)
-- and the 'Type' parameters applied to it, from the left to the right.
--
-- A constrained 'Type' @(q #> t)@ met along the spine is walked as @t@:
-- the constraint @q@ is skipped.
--
-- Every returned 'Type' gets the whole given 'Type' as its source
-- (through 'withSource').
spineTy ::
 forall src vs t.
 Source src =>
 SourceInj (TypeVT src) src =>
 Type src vs t ->
 (TypeT src vs, [TypeT src vs])
spineTy typ = walk [] typ
	where
	-- Re-source a part of the spine with the whole 'Type'.
	here :: forall x. Type src vs x -> TypeT src vs
	here ty = TypeT (ty `withSource` TypeVT typ)
	-- Descend the function side of each 'TyApp',
	-- pushing its argument in front of those already collected.
	walk :: forall x. [TypeT src vs] -> Type src vs x -> (TypeT src vs, [TypeT src vs])
	walk args ty =
		case ty of
		 TyApp _ (TyApp _ (TyConst _ _ c) _q) t
		  | Just HRefl <- proj_ConstKi @(K (#>)) @(#>) c
		  -> walk args t -- NOTE: skip the constraint @q@.
		 TyApp _src f a -> walk (here a : args) f
		 _ -> (here ty, args)

{-
spineTy
 :: Type src ctx ss cs (t::k)
 -> (forall kx (x::kx) xs. Type src ctx ss cs x -> Types src ctx ss cs xs -> ret)
 -> ret
spineTy = go TypesZ
	where
	go :: Types src ctx ss cs hs
	   -> Type src ctx ss cs (t::k)
	   -> (forall kx (x::kx) xs. Type src ctx ss cs x -> Types src ctx ss cs xs -> ret)
	   -> ret
	go ctx (TyApp  _src f a) k = go (a `TypesS` ctx) f k
	go ctx (Term x _te)    k = go ctx x k
	go ctx (TyAny x)       k = go ctx x k
	go ctx t k = k t ctx
-}

-- | Return the /most general unification/ of two 'Type's, when it exists.
--
-- The given 'Subst' holds the bindings already found. On success it is
-- returned updated with the bindings making both 'Type's equal;
-- otherwise an 'Error_Unify' tells why they cannot be unified.
--
-- Both 'Type's are first split into their left spine (see 'spineTy').
-- A head 'TyVar' whose kind is the kind of the whole 'Type' gets bound
-- to the other 'Type'. Two equal head 'TyConst's have their parameters
-- unified pairwise, while two different head 'TyConst's fail.
-- Otherwise two 'TyApp's are unified function-to-function and
-- argument-to-argument, and anything else fails.
unifyType ::
 forall ki (x::ki) (y::ki) vs src.
 SourceInj (TypeVT src) src =>
 ErrorInj (Con_Kind src) (Error_Unify src) =>
 Subst src vs ->
 Type src vs (x::ki) ->
 Type src vs (y::ki) ->
 Either (Error_Unify src) (Subst src vs)
unifyType vs x y =
	let k = kindOfType x in
	case (spineTy x, spineTy y) of
	 ((TypeT hx, px), (TypeT hy, py)) ->
		case (hx, hy) of
		 -- A variable head is bound only when its kind is the kind of the whole
		 -- 'Type' (presumably meaning it has no parameters); otherwise the next
		 -- alternatives are tried.
		 -- NOTE(review): when x (or y) is a constrained @q #> v@, 'spineTy'
		 -- dropped @q@, so @v@ gets bound to the whole other 'Type'; confirm intended.
		 (TyVar _ _n vx, _) | Just Refl <- k `eqKind` kindOfVar vx -> goVar vs vx y
		 (_, TyVar _ _n vy) | Just Refl <- k `eqKind` kindOfVar vy -> goVar vs vy x
		 -- Same constant head: unify the parameters pairwise.
		 (TyConst _sx _lx cx, TyConst _sy _ly cy)
			| Just HRefl <- cx `eqConstKi` cy -> goList vs px py
			| otherwise -> Left $ Error_Unify_Const_mismatch (TypeVT hx) (TypeVT hy)
		 -- Any other heads: fall back to a structural comparison of the
		 -- outermost application of the original (unsplit) 'Type's.
		 _ ->
			case (x, y) of
			 (TyApp _ fx ax, TyApp _ fy ay) ->
				goList vs
				 [TypeT fx, TypeT ax]
				 [TypeT fy, TypeT ay]
			 _ -> Left $ Error_Unify_mismatch (TypeVT x) (TypeVT y)
	where
	-- Unify the variable va with the 'Type' b under the 'Subst' vs'.
	-- If va is already bound, its binding is unified with b instead.
	-- Otherwise va is bound to b with vs' applied, unless that 'Type' is va
	-- itself (nothing to do) or contains va (occurrence check).
	goVar ::
	 forall k (a::k) (b::k) vs'.
	 Subst src vs' -> Var src vs' a -> Type src vs' b ->
	 Either (Error_Unify src) (Subst src vs')
	goVar vs' va b =
		case va `lookupSubst` vs' of
		 Just a -> unifyType vs' b a
		 Nothing ->
			case vs' `subst` b of
			 TyVar _src _kb vb | Just HRefl <- va `eqVarKi` vb -> Right vs'
			 b' | va `varOccursIn` b' -> Left $ Error_Unify_Var_loop (indexVar va) (TypeVT b')
			    -- 'insertSubst' requires the 'Var' and its 'Type' to have the
			    -- same index, but here they are only known to share the kind k.
			    -- Binding va to b' amounts to asserting a ~ b, hence the
			    -- coerced 'Refl'.
			    -- NOTE(review): soundness relies on the index of a 'Var'
			    -- being a mere placeholder; confirm.
			    -- The new binding is put first so that 'unionSubst' applies it
			    -- to the 'Type's already in vs', keeping them expanded.
			    | Refl :: a :~: b <- unsafeCoerce Refl ->
				Right $ insertSubst va b' mempty <> vs'
	-- Unify two lists of 'Type's pairwise, threading the 'Subst' along,
	-- and fail with a kind error when two paired 'Type's differ in kind.
	goList ::
	 forall vs'.
	 Subst src vs' -> [TypeT src vs'] -> [TypeT src vs'] ->
	 Either (Error_Unify src) (Subst src vs')
	goList vs' [] [] = Right vs'
	goList vs' (TypeT a:as) (TypeT b:bs) =
		when_EqKind (kindOfType a) (kindOfType b) $ \Refl ->
		unifyType vs' a b >>= \vs'' -> goList vs'' as bs
	-- Lists of different lengths: presumably unreachable, since both heads
	-- were checked to be the same constant, or both lists have two elements.
	goList _vs _a _b = error "[BUG] unifyType: kinds mismatch"