-- | Hindley-Milner style type unification

module Hydra.Unification (
  Constraint,
  solveConstraints,
) where

import Hydra.Basics
import Hydra.Strip
import Hydra.Compute
import Hydra.Core
import Hydra.Lexical
import Hydra.Printing
import Hydra.Rewriting
import Hydra.Substitution
import Hydra.Tier1
import Hydra.Dsl.Types as Types
import Hydra.Lib.Io

import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S


-- | A unification constraint: a pair of types which are required to unify with each other.
type Constraint = (Type, Type)

-- | The state of the constraint solver: the substitution accumulated so far,
--   paired with the constraints which remain to be solved.
type Unifier = (Subst, [Constraint])

-- | Produce the substitution which results from binding a type variable to a type.
--
--   * If the type is the variable itself, the result is the empty substitution.
--   * If the variable occurs free in the type, the result is also the empty substitution
--     (rather than an "infinite type" failure; see the note below).
--   * Otherwise, the result is the singleton substitution mapping the variable to the type.
--
-- Note: type variables in Hydra are allowed to bind to type expressions which contain the variable;
--       i.e. type recursion by name is allowed.
bind :: Name -> Type -> Flow s Subst
bind name typ
  -- Binding a variable to itself carries no information
  | typ == TypeVariable name = return M.empty
  -- The occurs check is deliberately non-fatal. The strict alternative was:
  --     fail $ "infinite type for " ++ unName name ++ ": " ++ show typ
  | variableOccursInType name typ = return M.empty
  | otherwise = return $ M.singleton name typ

-- | Solve a list of type constraints, yielding a single substitution.
--   The solver is started with the empty substitution and the complete list of constraints;
--   the flow fails if any pair of types cannot be unified.
solveConstraints :: [Constraint] -> Flow s Subst
solveConstraints cs = unificationSolver (M.empty, cs)

-- | Process the pending constraints one at a time, threading the accumulated substitution.
--   When no constraints remain, the accumulated substitution is the result.
--   Otherwise the head constraint is unified, the resulting substitution is composed with the
--   accumulated one, and it is applied to both sides of every remaining constraint before recursing.
unificationSolver :: Unifier -> Flow s Subst
unificationSolver (su, []) = return su
unificationSolver (su, (t1, t2) : rest) = do
    su1 <- unify t1 t2
    unificationSolver (composeSubst su1 su, applyTo su1 <$> rest)
  where
    -- Apply a substitution to both sides of a constraint
    applyTo s (l, r) = (substituteInType s l, substituteInType s r)

-- | Unify two types, producing a substitution or failing with a descriptive message.
--   Both types are first passed through 'stripType', and the results are compared by variant:
--
--   * "Symmetric" cases (same variant on both sides) either unify the component types pairwise
--     via 'unifyMany' / 'unify', or merely verify that the two sides agree (literal types,
--     union type names, wrapped types), yielding the empty substitution on success.
--   * "Asymmetric" cases bind a type variable on either side to the opposite type.
--   * A number of temporary "slop" cases (see the TODOs) let applications, lambdas and wrapped
--     types unify with types of another variant.
--   * Any other combination fails.
unify :: Type -> Type -> Flow s Subst
unify ltyp rtyp =
--     withTrace ("unify " ++ show ltyp ++ " with " ++ show rtyp) $
    case (stripType ltyp, stripType rtyp) of
      -- Symmetric patterns
      (TypeApplication (ApplicationType lhs1 rhs1), TypeApplication (ApplicationType lhs2 rhs2)) ->
        unifyMany [lhs1, rhs1] [lhs2, rhs2]
      (TypeFunction (FunctionType dom1 cod1), TypeFunction (FunctionType dom2 cod2)) ->
        unifyMany [dom1, cod1] [dom2, cod2]
      (TypeList lt1, TypeList lt2) -> unify lt1 lt2
      (TypeLiteral lt1, TypeLiteral lt2) -> verify "different literal types" $ lt1 == lt2
      (TypeMap (MapType k1 v1), TypeMap (MapType k2 v2)) -> unifyMany [k1, v1] [k2, v2]
      (TypeOptional ot1, TypeOptional ot2) -> unify ot1 ot2
      (TypeProduct types1, TypeProduct types2) -> unifyMany types1 types2
      (TypeRecord rt1, TypeRecord rt2) -> do
        -- Records must agree on name and field count; field types are then unified positionally
        verify "different record type names" (rowTypeTypeName rt1 == rowTypeTypeName rt2)
        verify "different number of record fields"
          (L.length (rowTypeFields rt1) == L.length (rowTypeFields rt2))
        unifyMany (fieldTypeType <$> rowTypeFields rt1) (fieldTypeType <$> rowTypeFields rt2)
      (TypeSet st1, TypeSet st2) -> unify st1 st2
      -- Unions are compared by type name only; their fields are not unified
      (TypeUnion rt1, TypeUnion rt2) ->
        verify "different union type names" (rowTypeTypeName rt1 == rowTypeTypeName rt2)
      (TypeLambda (LambdaType (Name v1) body1), TypeLambda (LambdaType (Name v2) body2)) ->
        unifyMany [Types.var v1, body1] [Types.var v2, body2]
      (TypeSum types1, TypeSum types2) -> unifyMany types1 types2
      (TypeWrap n1, TypeWrap n2) -> verify "different wrapper type names" $ n1 == n2

      -- Asymmetric patterns
      (TypeVariable v1, TypeVariable v2) -> bindWeakest v1 v2
      (TypeVariable v, t2) -> bind v t2
      (t1, TypeVariable v) -> bind v t1

      -- TODO; temporary "slop", e.g. (record "RowType" ...) is allowed to unify with (wrap "RowType" @ "a")
      (TypeApplication (ApplicationType lhs _), t2) -> unify lhs t2
      (t1, TypeApplication (ApplicationType lhs _)) -> unify t1 lhs
      (TypeLambda (LambdaType _ body), t2) -> unify body t2
      (t1, TypeLambda (LambdaType _ body)) -> unify t1 body
      -- TODO; temporary "slop", e.g. (record "RowType" ...) is allowed to unify with (wrap "RowType")
      (TypeWrap _, _) -> return M.empty -- TODO
      (_, TypeWrap _) -> return M.empty -- TODO

      (l, r) -> fail $ "unification of " ++ show (typeVariant l) ++ " with " ++ show (typeVariant r) ++
        ":\n  " ++ showType l ++
        "\n  " ++ showType r
  where
    -- Succeed with the empty substitution if the condition holds; otherwise fail with the given reason
    verify reason b = if b then return M.empty else failUnification reason
    -- Fail with a message which includes the reason and both (stripped) types
    failUnification reason = fail $ "could not unify types (reason: " ++ reason ++ "):\n\t"
      ++ showType (stripType ltyp) ++ "\n\t"
      ++ showType (stripType rtyp) ++ "\n"
--     failUnification = fail $ "could not unify type " ++ describeType (stripType ltyp) ++ " with " ++ describeType (stripType rtyp)
    -- Given two type variables, bind the first to the second if the first is "weak";
    -- otherwise bind the second to the first.
    bindWeakest v1 v2
        | isWeak v1 = bind v1 (TypeVariable v2)
        | otherwise = bind v2 (TypeVariable v1)
      where
        -- A variable is considered weak when its name starts with 't'.
        -- Pattern matching (rather than the partial L.head) keeps this total: an empty name is not weak.
        -- TODO: use a convention like _xxx for temporarily variables, then normalize and replace them
        isWeak v = case unName v of
          ('t' : _) -> True
          _ -> False

-- | Unify two lists of types pairwise, from left to right.
--   The substitution obtained from each pair is applied to the remaining types of both lists
--   before they are unified, and the resulting substitutions are composed.
--   Two empty lists yield the empty substitution; lists of different lengths cause a failure.
unifyMany :: [Type] -> [Type] -> Flow s Subst
unifyMany [] [] = return M.empty
unifyMany (t1 : ts1) (t2 : ts2) = do
    su1 <- unify t1 t2
    su2 <- unifyMany (apply su1 ts1) (apply su1 ts2)
    return (composeSubst su2 su1)
  where
    -- Apply a substitution to every type in a list
    apply s ts = substituteInType s <$> ts
-- Reached only when exactly one of the lists has run out, i.e. on a length mismatch
unifyMany ts1 ts2 = fail $ "unification mismatch between " ++ show ts1 ++ " and " ++ show ts2

-- | Check whether the given name is a member of the set of free variables of the given type,
--   as computed by 'freeVariablesInType'.
variableOccursInType :: Name -> Type -> Bool
variableOccursInType a t = S.member a $ freeVariablesInType t