-- | Hindley-Milner style type unification

module Hydra.Unification (
  solveConstraints
) where

import Hydra.Basics
import Hydra.Strip
import Hydra.Compute
import Hydra.Core
import Hydra.Lexical
import Hydra.Mantle
import Hydra.Printing
import Hydra.Rewriting
import Hydra.Inference.Substitution
import Hydra.Tier1
import Hydra.Dsl.Types as Types
import Hydra.Lib.Io

import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S


-- | Bind a type variable to a type, yielding the substitution that results.
--
-- Note: type variables in Hydra are allowed to bind to type expressions which contain the variable;
--       i.e. type recursion by name is allowed. Consequently, when the variable occurs inside the
--       type, no "infinite type" failure is raised; the empty substitution is returned instead.
--
-- The empty substitution is also returned when the type is exactly the variable itself.
-- In every other case the result maps the variable to the given type.
bind :: Name -> Type -> Flow s Subst
bind name typ = return substitution
  where
    substitution
      -- Binding a variable to itself contributes nothing.
      | typ == TypeVariable name = M.empty
      -- Occurs check: recursion by name is tolerated rather than rejected.
      | variableOccursInType name typ = M.empty
      | otherwise = M.singleton name typ

-- | Solve a list of type constraints, starting from the empty substitution.
--
-- This delegates to 'unificationSolver'; the result is whatever substitution that solver
-- accumulates, or a failure in 'Flow' if it fails.
solveConstraints :: [TypeConstraint] -> Flow s Subst
solveConstraints = unificationSolver M.empty

-- | Process constraints one at a time, threading an accumulated substitution.
--
-- For each constraint, the two types are unified; the resulting substitution is composed
-- (via 'composeSubst') with the accumulator and applied to both sides of every remaining
-- constraint before continuing. The third field of each constraint is carried along unchanged
-- on the remaining constraints and is ignored for the constraint being solved.
unificationSolver :: Subst -> [TypeConstraint] -> Flow s Subst
unificationSolver acc [] = return acc
unificationSolver acc (TypeConstraint lhs rhs _ : pending) = do
  step <- unify lhs rhs
  -- Rewrite a pending constraint under the substitution just obtained.
  let refine (TypeConstraint l r ctx) =
        TypeConstraint (substituteInType step l) (substituteInType step r) ctx
  unificationSolver (composeSubst step acc) (map refine pending)

-- | Unify two types, producing a substitution or failing in 'Flow'.
--
-- Both arguments are first passed through 'stripType' (presumably removing annotations or
-- similar wrappers -- confirm against "Hydra.Strip"), and the stripped forms are compared
-- structurally:
--
-- * Matching constructors are unified component-wise; literal and wrapper types must be equal,
--   record types must agree on type name and field count, union types on type name only.
-- * A type variable on either side is bound to the other side via 'bind'; between two variables,
--   the "weak" one (see @isWeak@ below) is the one that gets bound.
-- * Several deliberately lenient ("slop") cases follow; they are marked TODO.
-- * Anything else fails with a message naming both type variants.
unify :: Type -> Type -> Flow s Subst
unify ltyp rtyp =
    case (sltyp, srtyp) of
      -- Symmetric patterns
      (TypeApplication (ApplicationType lhs1 rhs1), TypeApplication (ApplicationType lhs2 rhs2)) ->
        unifyMany [lhs1, rhs1] [lhs2, rhs2]
      (TypeFunction (FunctionType dom1 cod1), TypeFunction (FunctionType dom2 cod2)) ->
        unifyMany [dom1, cod1] [dom2, cod2]
      (TypeList lt1, TypeList lt2) -> unify lt1 lt2
      (TypeLiteral lt1, TypeLiteral lt2) -> verify "different literal types" $ lt1 == lt2
      (TypeMap (MapType k1 v1), TypeMap (MapType k2 v2)) -> unifyMany [k1, v1] [k2, v2]
      (TypeOptional ot1, TypeOptional ot2) -> unify ot1 ot2
      (TypeProduct types1, TypeProduct types2) -> unifyMany types1 types2
      (TypeRecord rt1, TypeRecord rt2) -> do
        -- The two checks are run only for their possible failure; their results are discarded.
        verify "different record type names" (rowTypeTypeName rt1 == rowTypeTypeName rt2)
        verify "different number of record fields" (L.length (rowTypeFields rt1) == L.length (rowTypeFields rt2))
        -- Fields are unified positionally; field names are not compared here.
        unifyMany (fieldTypeType <$> rowTypeFields rt1) (fieldTypeType <$> rowTypeFields rt2)
      (TypeSet st1, TypeSet st2) -> unify st1 st2
      (TypeUnion rt1, TypeUnion rt2) -> verify "different union type names" (rowTypeTypeName rt1 == rowTypeTypeName rt2)
      (TypeLambda (LambdaType (Name v1) body1), TypeLambda (LambdaType (Name v2) body2)) ->
        unifyMany [Types.var v1, body1] [Types.var v2, body2]
      (TypeSum types1, TypeSum types2) -> unifyMany types1 types2
      (TypeWrap n1, TypeWrap n2) -> verify "different wrapper type names" $ n1 == n2

      -- Asymmetric patterns
      (TypeVariable v1, TypeVariable v2) -> bindWeakest v1 v2
      (TypeVariable v, t2) -> bind v t2
      (t1, TypeVariable v) -> bind v t1

      -- TODO; temporary "slop", e.g. (record "RowType" ...) is allowed to unify with (wrap "RowType" @ "a")
      -- Only the function part of an application is considered; its argument is ignored.
      (TypeApplication (ApplicationType lhs _), t2) -> unify lhs t2
      (t1, TypeApplication (ApplicationType lhs _)) -> unify t1 lhs
      -- A lambda unifies through its body; the bound variable is ignored.
      (TypeLambda (LambdaType _ body), t2) -> unify body t2
      (t1, TypeLambda (LambdaType _ body)) -> unify t1 body
      -- TODO; temporary "slop", e.g. (record "RowType" ...) is allowed to unify with (wrap "RowType")
      (TypeWrap _, _) -> return M.empty -- TODO
      (_, TypeWrap _) -> return M.empty -- TODO

      (l, r) -> fail $ "unification of " ++ show (typeVariant l) ++ " with " ++ show (typeVariant r) ++
        ":\n  " ++ showType l ++
        "\n  " ++ showType r
  where
    -- Stripped forms of the inputs, shared by the case analysis and the error message.
    sltyp = stripType ltyp
    srtyp = stripType rtyp

    -- Succeed with the empty substitution when the condition holds; otherwise fail with the reason.
    verify reason b = if b then return M.empty else failUnification reason

    -- Fail, reporting the reason together with both stripped types.
    failUnification reason = fail $ "could not unify types (reason: " ++ reason ++ "):\n\t"
      ++ showType sltyp ++ "\n\t"
      ++ showType srtyp ++ "\n"

    -- Given two variables, bind the first to the second if the first is weak;
    -- otherwise bind the second to the first.
    bindWeakest v1 v2 = if isWeak v1
        then bind v1 (TypeVariable v2)
        else bind v2 (TypeVariable v1)
      where
        -- A variable is considered weak when its name begins with 't'.
        -- An empty name is treated as not weak (rather than crashing on the head of an empty string).
        -- TODO: use a convention like _xxx for temporary variables, then normalize and replace them
        isWeak v = case unName v of
          ('t' : _) -> True
          _ -> False

-- | Unify two lists of types pairwise, left to right.
--
-- After each pair is unified, the resulting substitution is applied to the remaining types on
-- both sides before they are unified in turn; the substitutions are combined with 'composeSubst'.
-- Two empty lists yield the empty substitution. Lists of different lengths fail with a
-- "unification mismatch" message showing whatever remained on each side at the point of mismatch.
unifyMany :: [Type] -> [Type] -> Flow s Subst
unifyMany lefts rights = case (lefts, rights) of
  ([], []) -> return M.empty
  (l : ls, r : rs) -> do
    headSubst <- unify l r
    let refine = map (substituteInType headSubst)
    tailSubst <- unifyMany (refine ls) (refine rs)
    return (composeSubst tailSubst headSubst)
  _ -> fail $ "unification mismatch between " ++ show lefts ++ " and " ++ show rights

-- | Test whether the given name is among the free variables of a type,
-- as reported by 'freeVariablesInType'.
variableOccursInType :: Name -> Type -> Bool
variableOccursInType a t = a `S.member` freeVariablesInType t