-- | Hindley-Milner style type unification

module Hydra.Types.Unification (
  Constraint,
  solveConstraints,
) where

import Hydra.Kernel
import Hydra.Types.Substitution
import Hydra.Impl.Haskell.Dsl.Types as Types

import qualified Data.Map as M
import qualified Data.Set as S


-- | An equality constraint between two types, which unification must satisfy.
type Constraint m = (Type m, Type m)

-- | State threaded through 'unificationSolver': the substitution accumulated
-- so far, paired with the constraints that still remain to be solved.
type Unifier m = (Subst m, [Constraint m])

-- | Bind the type variable @var@ to the type @typ@, producing a substitution
-- with (at most) one entry.
--
-- * If @typ@ is exactly @var@, nothing needs binding and the empty
--   substitution is returned.
-- * If @var@ occurs among the free variables of @typ@ (the occurs check,
--   see 'variableOccursInType'), the binding would describe an infinite
--   type, so the flow fails.
-- * Otherwise the result maps @var@ to @typ@.
bind :: (Eq m, Show m) => VariableType -> Type m -> GraphFlow m (Subst m)
bind var typ =
    if typ == TypeVariable var
      then return M.empty
      else if variableOccursInType var typ
        then fail infiniteTypeMessage
        else return (M.singleton var typ)
  where
    infiniteTypeMessage =
      "infinite type for ?" ++ unVariableType var ++ ": " ++ show typ

-- | Solve a list of type equality constraints.
--
-- Starts the solver from the empty substitution and returns a substitution
-- satisfying every constraint. The flow fails if any constraint cannot be
-- unified.
solveConstraints :: (Eq m, Show m) => [Constraint m] -> GraphFlow m (Subst m)
solveConstraints constraints = unificationSolver initialState
  where
    initialState = (M.empty, constraints)

-- | Work through the pending constraints one at a time.
--
-- For the constraint at the head of the list, 'unify' produces a
-- substitution @step@. That substitution is composed onto the accumulated
-- one (@composeSubst step acc@) and applied to both sides of every
-- remaining constraint before recursing. Once no constraints remain, the
-- accumulated substitution is returned.
unificationSolver :: (Eq m, Show m) => Unifier m -> GraphFlow m (Subst m)
unificationSolver (acc, []) = return acc
unificationSolver (acc, (lhs, rhs) : rest) = do
  step <- unify lhs rhs
  let applyStep (a, b) = (substituteInType step a, substituteInType step b)
  unificationSolver (composeSubst step acc, map applyStep rest)

-- | Compute a substitution that makes two types equal, failing in the flow
-- when they cannot be unified.
--
-- Identical types unify with the empty substitution. Otherwise the pair is
-- examined case by case. Several patterns overlap, so the order below is
-- significant:
--
-- * Type applications: only the applied type is unified and the argument
--   is ignored (a temporary simplification).
-- * Annotations are looked through.
-- * Matching element, function, list, map, optional, product, set, lambda
--   and sum types are unified component-wise.
-- * Records and unions unify only when their row type names are equal;
--   their fields are not compared.
-- * A type variable on either side is bound via 'bind'.
-- * Two nominal types unify iff their names are equal. A nominal type
--   against any other remaining type yields the empty substitution (TODO).
-- * Any other combination fails with an "unexpected unification" error.
unify :: (Eq m, Show m) => Type m -> Type m -> GraphFlow m (Subst m)
unify t1 t2
    | t1 == t2 = noSubst
    | otherwise = case (t1, t2) of
        -- Temporary; type parameters are ignored
        (TypeApplication (ApplicationType fun _), _) -> unify fun t2
        (_, TypeApplication (ApplicationType fun _)) -> unify t1 fun

        (TypeAnnotated (Annotated inner _), _) -> unify inner t2
        (_, TypeAnnotated (Annotated inner _)) -> unify t1 inner
        (TypeElement e1, TypeElement e2) -> unify e1 e2
        (TypeFunction (FunctionType d1 c1), TypeFunction (FunctionType d2 c2)) ->
          unifyMany [d1, c1] [d2, c2]
        (TypeList e1, TypeList e2) -> unify e1 e2
        (TypeMap (MapType k1 v1), TypeMap (MapType k2 v2)) ->
          unifyMany [k1, v1] [k2, v2]
        (TypeOptional e1, TypeOptional e2) -> unify e1 e2
        (TypeProduct ts1, TypeProduct ts2) -> unifyMany ts1 ts2
        (TypeRecord r1, TypeRecord r2) -> sameRowName r1 r2
        (TypeSet e1, TypeSet e2) -> unify e1 e2
        (TypeUnion r1, TypeUnion r2) -> sameRowName r1 r2
        -- NOTE(review): the bound parameters are unified as ordinary type
        -- variables, so the resulting substitution may map one lambda's
        -- parameter to the other's; confirm this cannot capture a free
        -- variable of the same name elsewhere.
        (TypeLambda (LambdaType (VariableType p1) b1),
         TypeLambda (LambdaType (VariableType p2) b2)) ->
          unifyMany [Types.variable p1, b1] [Types.variable p2, b2]
        (TypeSum ts1, TypeSum ts2) -> unifyMany ts1 ts2
        (TypeVariable v, _) -> bind v t2
        (_, TypeVariable v) -> bind v t1
        (TypeNominal n1, TypeNominal n2) -> requireThat (n1 == n2)
        (TypeNominal _, _) -> noSubst -- TODO
        (_, TypeNominal name) -> unify (Types.nominal name) t1
        _ -> fail $ "unexpected unification of " ++ show (typeVariant t1)
          ++ " with " ++ show (typeVariant t2)
          ++ ":\n  " ++ show t1 ++ "\n  " ++ show t2
  where
    -- Successful unification that binds nothing.
    noSubst = return M.empty
    -- Succeed with no bindings when the check holds; otherwise fail.
    requireThat ok = if ok then noSubst else couldNotUnify
    -- Records and unions are compared by row type name only.
    sameRowName r1 r2 = requireThat (rowTypeTypeName r1 == rowTypeTypeName r2)
    couldNotUnify =
      fail $ "could not unify type " ++ show t1 ++ " with " ++ show t2

-- | Unify two lists of types pairwise, left to right.
--
-- The substitution obtained from each pair is applied to all remaining
-- types before they are unified. The per-pair substitutions are composed,
-- with later ones applied after earlier ones. Two empty lists unify
-- trivially. Lists of different lengths make the flow fail with a
-- "unification mismatch" error.
unifyMany :: (Eq m, Show m) => [Type m] -> [Type m] -> GraphFlow m (Subst m)
unifyMany lefts rights = case (lefts, rights) of
  ([], []) -> return M.empty
  (l : ls, r : rs) -> do
    s1 <- unify l r
    let applyS1 ts = substituteInType s1 <$> ts
    s2 <- unifyMany (applyS1 ls) (applyS1 rs)
    return (composeSubst s2 s1)
  _ -> fail $ "unification mismatch between " ++ show lefts ++ " and " ++ show rights

-- | Occurs check: does the given type variable appear among the free
-- variables of the type?
variableOccursInType :: Show m => VariableType -> Type m -> Bool
variableOccursInType var = S.member var . freeVariablesInType