-----------------------------------------------------------------------------
-- |
-- Module      :  Disco.Typecheck.Unify
-- Copyright   :  disco team and contributors
-- Maintainer  :  byorgey@gmail.com
--
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Unification.
--
-----------------------------------------------------------------------------

module Disco.Typecheck.Unify where

import           Unbound.Generics.LocallyNameless (Name, fv)

import           Control.Lens                     (anyOf)
import           Control.Monad.State
import qualified Data.Map                         as M
import           Data.Set                         (Set)
import qualified Data.Set                         as S

import           Disco.Subst
import           Disco.Types

-- XXX todo: it might be better if unification took sorts into account
-- directly.  As it is, however, I think it works properly;
-- e.g. suppose we have a type variable a with sort {sub} and we unify
-- it with Bool.  unify will just return a substitution [a |-> Bool].
-- But then when we call extendSubst, and in particular substSortMap,
-- the sort {sub} will be applied to Bool and decomposed, which will
-- throw an error.

-- | Given a list of equations between types, return a substitution
--   which makes all the equations satisfied (or fail if it is not
--   possible).  Base types must match exactly (compared with '==').
--
--   This is not the most efficient way to implement unification but
--   it is simple.
unify :: TyDefCtx -> [(Type, Type)] -> Maybe S
unify = unify' (==)

-- | Given a list of equations between types, return a substitution
--   which makes all the equations equal *up to* identifying all base
--   types.  So, for example, Int = Nat weakly unifies but Int = (Int
--   -> Int) does not.  This is used to check whether subtyping
--   constraints are structurally sound before doing constraint
--   simplification/solving, to ensure termination.
weakUnify :: TyDefCtx -> [(Type, Type)] -> Maybe S
weakUnify = unify' (\_ _ -> True)  -- every pair of base types is "equal"

-- | Given a list of equations between types, return a substitution
--   which makes all the equations satisfied (or fail if it is not
--   possible), up to the given comparison on base types.
--
--   Unification runs in @StateT (Set (Type, Type)) Maybe@.  The state
--   records every pair of type constructor applications that has
--   already been decomposed; if such a pair is encountered again it is
--   treated as already solved (presumably so that equations between
--   recursive user-defined types terminate -- NOTE(review): confirm).
--   Any failure makes the whole result 'Nothing'.
unify' :: (BaseTy -> BaseTy -> Bool) -> TyDefCtx
       -> [(Type, Type)] -> Maybe S
unify' baseEq tyDefns eqs = evalStateT (go eqs) S.empty
  where
    -- Solve the equations from left to right.  Each step produces either
    -- a substitution, which is applied to the remaining equations and
    -- composed onto the final result, or a list of simpler equations,
    -- which are solved before the remaining ones.
    go :: [(Type, Type)] -> StateT (Set (Type, Type)) Maybe S
    go []     = return idS
    go (e:es) = do
      u <- unifyOne e
      case u of
        Left sub    -> (@@ sub) <$> go (applySubst sub es)
        Right newEs -> go (newEs ++ es)

    -- Solve a single equation; a pair that has already been decomposed
    -- is considered satisfied by the identity substitution.
    unifyOne :: (Type, Type) -> StateT (Set (Type, Type)) Maybe (Either S [(Type, Type)])
    unifyOne pair = do
      seen <- gets (S.member pair)
      if seen then return (Left idS) else unifyOne' pair

    -- Look up a user-defined type and instantiate its definition at the
    -- given type arguments; fail if the type is not defined.
    unfoldUser :: String -> [Type] -> StateT (Set (Type, Type)) Maybe Type
    unfoldUser t tys = case M.lookup t tyDefns of
      Nothing                 -> mzero
      Just (TyDefBody _ body) -> return (body tys)

    -- Solve a single, not-yet-seen equation: return a substitution
    -- (Left), a list of simpler equations (Right), or fail.  The clauses
    -- are tried in order, and later clauses rely on earlier ones having
    -- been ruled out.
    unifyOne' :: (Type, Type) -> StateT (Set (Type, Type)) Maybe (Either S [(Type, Type)])

    -- Identical types unify trivially.
    unifyOne' (ty1, ty2)
      | ty1 == ty2 = return $ Left idS

    -- A unification variable is bound to the other type, unless it
    -- occurs in that type (occurs check).
    unifyOne' (TyVar x, ty2)
      | occurs x ty2 = mzero
      | otherwise    = return $ Left (x |-> ty2)
    unifyOne' (ty1, x@(TyVar _))
      = unifyOne (x, ty1)

    -- At this point we know ty2 isn't the same skolem nor a unification variable.
    -- Skolems don't unify with anything.
    unifyOne' (TySkolem _, _) = mzero
    unifyOne' (_, TySkolem _) = mzero

    -- Unify two container types: unify the container descriptors as
    -- well as the type arguments.
    unifyOne' p@(TyCon (CContainer ctr1) tys1, TyCon (CContainer ctr2) tys2) = do
      modify (S.insert p)
      return $ Right ((TyAtom ctr1, TyAtom ctr2) : zip tys1 tys2)

    -- If one of the types to be unified is a user-defined type,
    -- unfold its definition before continuing the matching.
    unifyOne' p@(TyCon (CUser t) tys1, ty2) = do
      modify (S.insert p)
      ty1' <- unfoldUser t tys1
      return $ Right [(ty1', ty2)]

    unifyOne' p@(ty1, TyCon (CUser t) tys2) = do
      modify (S.insert p)
      ty2' <- unfoldUser t tys2
      return $ Right [(ty1, ty2')]

    -- Unify any other pair of type constructor applications: the type
    -- constructors must match exactly.
    unifyOne' p@(TyCon c1 tys1, TyCon c2 tys2)
      | c1 == c2  = do
          modify (S.insert p)
          return $ Right (zip tys1 tys2)
      | otherwise = mzero

    -- Base types are compared with the caller-supplied predicate.
    unifyOne' (TyAtom (ABase b1), TyAtom (ABase b2))
      | baseEq b1 b2 = return $ Left idS
      | otherwise    = mzero

    unifyOne' _ = mzero  -- Atom = Cons


-- | Find a substitution making all the given types equal, by unifying
--   each type with the next one in the list (or fail if there is none).
--   An empty or singleton list produces no equations.
equate :: TyDefCtx -> [Type] -> Maybe S
equate tyDefns tys = unify tyDefns eqns
  where
    -- Equations between adjacent types.  'drop 1' is total, unlike
    -- 'tail', so correctness does not depend on zip's laziness.
    eqns :: [(Type, Type)]
    eqns = zip tys (drop 1 tys)

-- | @occurs x ty@ holds when the variable @x@ is among the free
--   variables of @ty@.  It is used as the occurs check when binding a
--   unification variable.
occurs :: Name Type -> Type -> Bool
occurs x = anyOf fv (== x)


-- | Find a substitution making all the given atoms equal, by wrapping
--   them as types and calling 'equate'.  Each type in the resulting
--   substitution is converted back to an 'Atom'.
unifyAtoms :: TyDefCtx -> [Atom] -> Maybe (Substitution Atom)
unifyAtoms tyDefns = fmap (fmap fromTyAtom) . equate tyDefns . map TyAtom
  where
    -- Only atoms were unified, so every resulting type is expected to be
    -- an atom; anything else is reported as an internal error.
    fromTyAtom :: Type -> Atom
    fromTyAtom (TyAtom a) = a
    fromTyAtom _          = error "fromTyAtom on non-TyAtom!"

-- | Find a substitution making all the given unification atoms equal.
--   Each 'UAtom' is converted to a type via 'uatomToAtom' and the
--   list is passed to 'equate'.  Each type in the resulting
--   substitution is converted back to a 'UAtom'.
unifyUAtoms :: TyDefCtx -> [UAtom] -> Maybe (Substitution UAtom)
unifyUAtoms tyDefns =
  fmap (fmap fromTyAtom) . equate tyDefns . map (TyAtom . uatomToAtom)
  where
    -- Only base types and unification variables are expected here;
    -- anything else (e.g. a skolem or a compound type) is reported as
    -- an internal error.
    fromTyAtom :: Type -> UAtom
    fromTyAtom (TyAtom (ABase b))    = UB b
    fromTyAtom (TyAtom (AVar (U v))) = UV v
    fromTyAtom _                     = error "fromTyAtom on wrong thing!"