{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Language.Symantic.Typing.Unify where

import Data.Map.Strict (Map)
import Data.Semigroup (Semigroup(..))
import Unsafe.Coerce (unsafeCoerce)
import qualified Data.Map.Strict as Map

import Language.Symantic.Grammar
import Language.Symantic.Typing.Variable
import Language.Symantic.Typing.Kind
import Language.Symantic.Typing.Type
import Language.Symantic.Typing.Show ()

-- * Type 'Subst'
-- | /Type variable substitution/: a finite map from the 'IndexVar'
-- of a 'Var' to that 'Var' packed with its replacement 'Type' (see 'VT').
--
-- WARNING: a 'Subst' MUST be without loops, and fully expanded.
newtype Subst src vs = Subst (Map IndexVar (VT src vs))
deriving instance Source src => Show (Subst src vs)
-- | Combining two 'Subst's is done by 'unionSubst'.
instance Semigroup (Subst src vs) where
        sx <> sy = unionSubst sx sy
-- | The neutral 'Subst' is the one binding no 'Var' at all.
instance Monoid (Subst src vs) where
        mempty = Subst Map.empty
        mappend sx sy = sx <> sy

-- | Union of two 'Subst's.
--
-- NOTE: when a 'Var' is bound in both 'Subst's,
-- the binding of the first 'Subst' wins (left-biased union).
--
-- NOTE: every replacement 'Type' of the second 'Subst'
-- is first rewritten with the first 'Subst' (using 'subst'),
-- so that each 'Var' of the result maps directly to an expanded 'Type',
-- and a single application of the resulting 'Subst' is enough.
unionSubst :: Subst src vs -> Subst src vs -> Subst src vs
unionSubst sx@(Subst x) (Subst y) = Subst (Map.union x (fmap expand y))
        where
        -- Rewrite the replacement 'Type' of one binding with the first 'Subst'.
        expand (VT v r) = VT v (subst sx r)

-- * Type 'VT'
-- | A 'Var' and a 'Type' existentialized over their type index.
--
-- Both fields share the same (hidden) index @t@,
-- which is what makes it possible to store bindings
-- for 'Var's of different indices as the values of a single 'Subst' map.
-- The index is recovered by comparing the stored 'Var'
-- with another one (see 'lookupSubst').
data VT src vs = forall t. VT (Var src vs t) (Type src vs t)
deriving instance Source src => Show (VT src vs)

-- | Bind the given 'Var' to the given 'Type' in the given 'Subst',
-- replacing any binding already stored under the same 'IndexVar'.
--
-- NOTE: neither the new 'Type' nor the existing bindings are rewritten.
insertSubst :: Var src vs v -> Type src vs v -> Subst src vs -> Subst src vs
insertSubst v t (Subst s) = Subst (Map.insert key binding s)
        where
        key     = indexVar v
        binding = VT v t

-- | Return the 'Type' bound to the given 'Var' in the given 'Subst'.
--
-- 'Nothing' is returned when no binding is stored under the 'IndexVar'
-- of the 'Var', or when the 'Var' stored there is not recognized
-- by 'eqVarKi' as being the given one.
lookupSubst :: Var src vs v -> Subst src vs -> Maybe (Type src vs v)
lookupSubst v (Subst s) =
        case Map.lookup (indexVar v) s of
         Nothing -> Nothing
         Just (VT v' t) ->
                case v `eqVarKi` v' of
                 Just HRefl -> Just t
                 Nothing -> Nothing

-- * Class 'Substable'
-- | Structures in which 'Var's can be replaced by 'Type's.
--
-- The source and variables indices of the 'Var's and 'Type's involved
-- are tied to the structure through 'SourceOf' and 'VarsOf'.
class Substable a where
        -- | Like 'substVar', but without the /occurrence check/:
        -- replace the given 'Var' by the given 'Type'
        -- without verifying that this 'Type' does not contain this 'Var'.
        substVarUnsafe ::
         src ~ SourceOf a =>
         vs ~ VarsOf a =>
         Var src vs v -> Type src vs v -> a -> a
        -- | Substitute all the 'Var's which have a match in given 'Subst'.
        subst ::
         src ~ SourceOf a =>
         vs ~ VarsOf a =>
         Subst src vs -> a -> a
-- | Substitution within a single 'Type', by structural recursion.
instance Substable (Type src vs t) where
        -- A constant contains no 'Var'.
        substVarUnsafe _ _ ty@TyConst{} = ty
        -- Substitute in both the function and its argument.
        substVarUnsafe var rep (TyApp s fun arg) =
                TyApp s (substVarUnsafe var rep fun) (substVarUnsafe var rep arg)
        -- Only the very 'Var' being substituted is replaced.
        substVarUnsafe var rep ty@(TyVar _ _ tv)
         | Just HRefl <- var `eqVarKi` tv = rep
         | otherwise = ty
        -- Substitute in the arguments of the type family.
        substVarUnsafe var rep (TyFam s len fam args) =
                TyFam s len fam (substVarUnsafe var rep args)

        -- A constant contains no 'Var'.
        subst _ ty@TyConst{} = ty
        -- Substitute in both the function and its argument.
        subst sub (TyApp s fun arg) =
                TyApp s (subst sub fun) (subst sub arg)
        -- NOTE: the replacement is returned as is, without substituting in it,
        -- which is enough as long as the 'Subst' is fully expanded.
        subst (Subst m) ty@(TyVar _ _ tv)
         | Just (VT bound rep) <- Map.lookup (indexVar tv) m =
                case tv `eqVarKi` bound of
                 Just HRefl -> rep
                 Nothing -> error "[BUG] subst: kind mismatch"
         | otherwise = ty
        -- Substitute in the arguments of the type family.
        subst sub (TyFam s len fam args) =
                TyFam s len fam (subst sub args)
-- | Substitution within a list of 'Type's: done on each of its elements.
instance Substable (Types src vs ts) where
        substVarUnsafe _ _ TypesZ = TypesZ
        substVarUnsafe var rep (TypesS ty tys) =
                TypesS
                 (substVarUnsafe var rep ty)
                 (substVarUnsafe var rep tys)
        subst _ TypesZ = TypesZ
        subst sub (TypesS ty tys) =
                TypesS (subst sub ty) (subst sub tys)

-- | Replace the given 'Var' by the given 'Type' in the given structure.
--
-- 'Nothing' is returned when the replacement 'Type'
-- itself contains the 'Var' (/occurrence check/),
-- otherwise the substitution is done by 'substVarUnsafe'.
substVar ::
 src ~ SourceOf a =>
 vs ~ VarsOf a =>
 Source src => VarOccursIn a => Substable a =>
 Var src vs v -> Type src vs v -> a -> Maybe a
substVar v r t
 | v `varOccursIn` r = Nothing -- NOTE: occurrence check
 | otherwise = Just (substVarUnsafe v r t)

-- ** Type 'Error_Unify'
-- | Reasons why two 'Type's cannot be unified.
data Error_Unify src
 =   Error_Unify_Var_loop IndexVar (TypeVT src)
     -- ^ /occurrence check/: a 'Var' (given by its 'IndexVar')
     --   is unified with a 'Type' which contains this same 'Var'.
 |   Error_Unify_Const_mismatch (TypeVT src) (TypeVT src)
     -- ^ Two 'TyConst's should be the same, but are different.
 |   Error_Unify_Kind_mismatch (KindK src) (KindK src)
     -- ^ Two 'Kind's should be the same, but are different.
 |   Error_Unify_Kind (Con_Kind src)
     -- ^ A 'Con_Kind' error, embedded as is
     --   (this is the constructor used by the 'ErrorInj' instance
     --   from 'Con_Kind' to 'Error_Unify').
 |   Error_Unify_mismatch (TypeVT src) (TypeVT src)
     -- ^ Cannot unify those two 'Type's.

deriving instance Source src => Eq   (Error_Unify src)
deriving instance Source src => Show (Error_Unify src)

-- | An 'Error_Unify' is injected into itself unchanged.
instance ErrorInj (Error_Unify src) (Error_Unify src) where
        errorInj err = err
-- | A 'Con_Kind' is injected by wrapping it into 'Error_Unify_Kind'.
instance ErrorInj (Con_Kind src) (Error_Unify src) where
        errorInj con = Error_Unify_Kind con

-- | Return the left spine of a 'Type':
-- the root 'Type' and its 'Type' parameters,
-- from the left to the right.
--
-- A constraint put in front of a 'Type' with @(#>)@ is skipped:
-- the spine returned is then the one of the 'Type' under the constraint.
--
-- Every 'Type' returned is passed through 'withSource'
-- together with the whole given 'Type' wrapped in 'TypeVT'
-- (presumably to record where it comes from — see 'withSource').
spineTy ::
 forall src vs t.
 Source src =>
 SourceInj (TypeVT src) src =>
 Type src vs t ->
 (TypeT src vs, [TypeT src vs])
spineTy typ = go [] typ
        where
        -- The first argument accumulates the parameters already walked past;
        -- since the walk goes down the function side of each 'TyApp'
        -- and conses each argument, they end up ordered from left to right.
        go :: forall kx (x::kx). [TypeT src vs] -> Type src vs x -> (TypeT src vs, [TypeT src vs])
        -- An application of the @(#>)@ constant to @q@ then @t@:
        -- continue with @t@, keeping the accumulator unchanged.
        go ctx (TyApp _ (TyApp _ (TyConst _ _ c) _q) t)
         | Just HRefl <- proj_ConstKi @(K (#>)) @(#>) c
         = go ctx t -- NOTE: skip the constraint @q@.
        -- Any other application: record the argument, then walk the function.
        go ctx (TyApp  _src f a) = go (TypeT (a `withSource` TypeVT typ) : ctx) f
        -- Not an application: this is the root of the spine.
        go ctx t = (TypeT (t `withSource` TypeVT typ), ctx)

{-
spineTy
 :: Type src ctx ss cs (t::k)
 -> (forall kx (x::kx) xs. Type src ctx ss cs x -> Types src ctx ss cs xs -> ret)
 -> ret
spineTy = go TypesZ
	where
	go :: Types src ctx ss cs hs
	   -> Type src ctx ss cs (t::k)
	   -> (forall kx (x::kx) xs. Type src ctx ss cs x -> Types src ctx ss cs xs -> ret)
	   -> ret
	go ctx (TyApp  _src f a) k = go (a `TypesS` ctx) f k
	go ctx (Term x _te)    k = go ctx x k
	go ctx (TyAny x)       k = go ctx x k
	go ctx t k = k t ctx
-}

-- | Return the /most general unification/ of two 'Type's, when it exists.
--
-- The given 'Subst' holds the bindings already known;
-- on success it is returned, extended with the bindings found.
--
-- Both 'Type's are first decomposed with 'spineTy', then:
--
-- * if the head of either spine is a 'TyVar' whose kind
--   is the kind of the whole 'Type', this 'Var' is unified
--   with the other 'Type' (see @goVar@);
-- * if both heads are 'TyConst's, they must be the same constant
--   (else 'Error_Unify_Const_mismatch'),
--   and their parameters are unified pairwise (see @goList@);
-- * otherwise, two 'TyApp's are unified function with function
--   and argument with argument, and anything else
--   is an 'Error_Unify_mismatch'.
unifyType ::
 forall ki (x::ki) (y::ki) vs src.
 SourceInj (TypeVT src) src =>
 ErrorInj (Con_Kind src) (Error_Unify src) =>
 Subst src vs ->
 Type src vs (x::ki) ->
 Type src vs (y::ki) ->
 Either (Error_Unify src) (Subst src vs)
unifyType vs x y =
        -- NOTE: @k@ is also the kind of @y@, as both 'Type's are indexed by @ki@.
        let k = kindOfType x in
        case (spineTy x, spineTy y) of
         ((TypeT hx, px), (TypeT hy, py)) ->
                case (hx, hy) of
                 -- NOTE: the guards require the 'Var' at the head
                 -- to have the kind of the whole 'Type'.
                 (TyVar _ _n vx, _) | Just Refl <- k `eqKind` kindOfVar vx -> goVar vs vx y
                 (_, TyVar _ _n vy) | Just Refl <- k `eqKind` kindOfVar vy -> goVar vs vy x
                 (TyConst _sx _lx cx, TyConst _sy _ly cy)
                        | Just HRefl <- cx `eqConstKi` cy -> goList vs px py
                        | otherwise -> Left $ Error_Unify_Const_mismatch (TypeVT hx) (TypeVT hy)
                 _ ->
                        -- NOTE: the spines did not help:
                        -- fall back on the top-level structure of the 'Type's.
                        case (x, y) of
                         (TyApp _ fx ax, TyApp _ fy ay) ->
                                goList vs
                                 [TypeT fx, TypeT ax]
                                 [TypeT fy, TypeT ay]
                         _ -> Left $ Error_Unify_mismatch (TypeVT x) (TypeVT y)
        where
        -- Unify a 'Var' with a 'Type' of the same kind.
        goVar ::
         forall k (a::k) (b::k) vs'.
         Subst src vs' -> Var src vs' a -> Type src vs' b ->
         Either (Error_Unify src) (Subst src vs')
        goVar vs' va b =
                case va `lookupSubst` vs' of
                 -- The 'Var' is already bound: unify with its replacement.
                 Just a -> unifyType vs' b a
                 Nothing ->
                        -- Expand the 'Type' with the bindings known so far
                        -- before looking into it.
                        case vs' `subst` b of
                         -- A 'Var' unified with itself: nothing to record.
                         TyVar _src _kb vb | Just HRefl <- va `eqVarKi` vb -> Right vs'
                         -- NOTE: occurrence check.
                         b' | va `varOccursIn` b' -> Left $ Error_Unify_Var_loop (indexVar va) (TypeVT b')
                            -- NOTE: @a@ and @b@ only share their kind @k@,
                            -- their equality is not proven but asserted here
                            -- with 'unsafeCoerce', in order to be able to record
                            -- @b'@ as the replacement of @va@.
                            | Refl :: a :~: b <- unsafeCoerce Refl ->
                                -- NOTE: putting the new binding on the left
                                -- of '<>' (ie. 'unionSubst') also applies it
                                -- to the replacements already in @vs'@.
                                Right $ insertSubst va b' mempty <> vs'
        -- Unify two lists of 'Type's pairwise, from left to right,
        -- threading the 'Subst' from one pair to the next.
        goList ::
         forall vs'.
         Subst src vs' -> [TypeT src vs'] -> [TypeT src vs'] ->
         Either (Error_Unify src) (Subst src vs')
        goList vs' [] [] = Right vs'
        goList vs' (TypeT a:as) (TypeT b:bs) =
                -- NOTE: 'unifyType' requires both 'Type's to have the same kind;
                -- presumably 'when_EqKind' fails with a kind error otherwise
                -- (TODO: confirm against its definition).
                when_EqKind (kindOfType a) (kindOfType b) $ \Refl ->
                unifyType vs' a b >>= \vs'' -> goList vs'' as bs
        -- NOTE: only reached when the two lists have different lengths,
        -- which is considered an internal error.
        goList _vs _a _b = error "[BUG] unifyType: kinds mismatch"