-- | Variable substitution and normalization of type expressions
module Hydra.Substitution where

import Hydra.Core
import Hydra.Mantle
import Hydra.Rewriting
import Hydra.Tier1
import Hydra.Dsl.Types as Types

import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Maybe as Y


-- | A substitution: a finite mapping from type variable names to the types
-- that should replace them.
type Subst = M.Map Name Type

-- | Compose two substitutions. Applying the result is intended to be the same
-- as applying @s2@ first and then @s1@:
--
-- > substituteInType (composeSubst s1 s2) == substituteInType s1 . substituteInType s2
--
-- Every binding of @s2@ has @s1@ applied to its right-hand side. Bindings of
-- @s1@ are kept only for variables that @s2@ does not bind.
composeSubst :: Subst -> Subst -> Subst
composeSubst s1 s2 = M.union (M.map (substituteInType s1) s2) s1
  -- 'M.union' is left-biased. When both substitutions bind a variable, @s2@
  -- has already replaced it before @s1@ runs, so the (s1-applied) @s2@
  -- binding must win over @s1@'s own binding.

-- | The infinite, ordered supply of canonical type variable names:
-- t0, t1, t2, ...
normalVariables :: [Name]
normalVariables = map normalVariable [0 ..]

-- | The canonical name of the @i@-th type variable. The naming convention
-- follows Haskell: t0, t1, etc.
normalVariable :: Int -> Name
normalVariable = Name . ("t" ++) . show

-- | Normalize a type scheme by renaming the free type variables of its body to
-- the canonical names t0, t1, ... (see 'normalVariables').
--
-- The scheme's original variable list is ignored. The result quantifies over
-- the free variables of the body, as computed by 'freeVariablesInType'. They
-- are taken in the ascending order of the resulting 'S.Set' and paired with
-- 'normalVariables' in that order.
--
-- Variables bound by a 'TypeLambda' inside the body are left untouched, both
-- at the binder and at their occurrences within the lambda's body.
-- NOTE(review): binders are not renamed, so a lambda whose bound name equals
-- one of the new canonical names would capture a renamed free variable;
-- presumably such names do not occur in practice — confirm.
--
-- Calls 'error' if the body contains a variable that is neither lambda-bound
-- nor in the computed renaming.
normalizeScheme :: TypeScheme -> TypeScheme
normalizeScheme ts@(TypeScheme _ body) =
    TypeScheme (fmap snd ord) (normalizeType S.empty body)
  where
    -- Each free variable of the body, paired with its canonical replacement.
    ord :: [(Name, Name)]
    ord = L.zip (S.toList $ freeVariablesInType body) normalVariables

    renaming :: M.Map Name Name
    renaming = M.fromList ord

    normalizeFieldType :: S.Set Name -> FieldType -> FieldType
    normalizeFieldType bound (FieldType fname typ) =
      FieldType fname $ normalizeType bound typ

    -- The set argument holds the variables bound by enclosing type lambdas.
    normalizeType :: S.Set Name -> Type -> Type
    normalizeType bound typ = case typ of
        TypeApplication (ApplicationType lhs rhs) ->
          TypeApplication (ApplicationType (norm lhs) (norm rhs))
        TypeAnnotated (AnnotatedType t ann) ->
          TypeAnnotated (AnnotatedType (norm t) ann)
        TypeFunction (FunctionType dom cod) -> function (norm dom) (norm cod)
        TypeList t -> list $ norm t
        TypeLiteral _ -> typ
        TypeMap (MapType kt vt) -> Types.map (norm kt) (norm vt)
        TypeOptional t -> optional $ norm t
        TypeProduct types -> TypeProduct (norm <$> types)
        TypeRecord (RowType n e fields) ->
          TypeRecord $ RowType n e (normField <$> fields)
        TypeSet t -> set $ norm t
        TypeSum types -> TypeSum (norm <$> types)
        TypeUnion (RowType n e fields) ->
          TypeUnion $ RowType n e (normField <$> fields)
        TypeLambda (LambdaType v t) ->
          -- Record the binder so its occurrences in the body are not renamed.
          TypeLambda (LambdaType v $ normalizeType (S.insert v bound) t)
        TypeVariable v
          -- A lambda-bound occurrence is not one of the scheme's free variables
          -- (presumably it has no entry in 'renaming'), so keep it as is
          -- rather than failing.
          | S.member v bound -> typ
          | otherwise -> case M.lookup v renaming of
              Just (Name v1) -> var v1
              Nothing -> error $ "type variable " ++ show v
                ++ " not in signature of type scheme: " ++ show ts
        TypeWrap _ -> typ
      where
        norm = normalizeType bound
        normField = normalizeFieldType bound

-- | Apply a substitution to a type scheme. The scheme's quantified variables
-- are protected: any bindings for them are dropped from the substitution
-- before it is applied to the scheme body. The variable list is unchanged.
substituteInScheme :: M.Map Name Type -> TypeScheme -> TypeScheme
substituteInScheme s (TypeScheme vars t) =
    TypeScheme vars (substituteInType unbound t)
  where
    unbound = M.withoutKeys s (S.fromList vars)

-- | Apply a substitution to a type. Each occurrence of a type variable bound
-- in the substitution is replaced by its image; other variables are kept.
--
-- A type lambda shadows its bound variable. Within the lambda's body the
-- substitution is applied with only that variable removed, so bound
-- occurrences are preserved while all other variables are still replaced.
-- NOTE(review): the binder is not renamed, so if the image of some variable
-- mentions the lambda-bound name it will be captured; callers presumably use
-- fresh names — confirm.
--
-- Wrapped (nominal) types are not traversed; see the comment on that case.
substituteInType :: M.Map Name Type -> Type -> Type
substituteInType s typ = case typ of
    TypeApplication (ApplicationType lhs rhs) ->
      TypeApplication (ApplicationType (subst lhs) (subst rhs))
    TypeAnnotated (AnnotatedType t ann) ->
      TypeAnnotated (AnnotatedType (subst t) ann)
    TypeFunction (FunctionType dom cod) -> function (subst dom) (subst cod)
    TypeList t -> list $ subst t
    TypeLiteral _ -> typ
    TypeMap (MapType kt vt) -> Types.map (subst kt) (subst vt)
    TypeOptional t -> optional $ subst t
    TypeProduct types -> TypeProduct (subst <$> types)
    TypeRecord (RowType n e fields) ->
      TypeRecord $ RowType n e (substField <$> fields)
    TypeSet t -> set $ subst t
    TypeSum types -> TypeSum (subst <$> types)
    TypeUnion (RowType n e fields) ->
      TypeUnion $ RowType n e (substField <$> fields)
    TypeLambda (LambdaType v body) ->
      -- Drop only the shadowed variable. The remaining bindings must still
      -- reach the body, even when @v@ itself appears in the substitution.
      TypeLambda (LambdaType v $ substituteInType (M.delete v s) body)
    TypeVariable a -> M.findWithDefault typ a s
    TypeWrap _ -> typ -- because we do not allow names to be bound to types with free variables
  where
    subst :: Type -> Type
    subst = substituteInType s

    substField :: FieldType -> FieldType
    substField (FieldType fname t) = FieldType fname $ subst t