{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Inferno.Infer.Env
  ( Env (..),
    Namespace (..),
    TypeMetadata (..),
    closeOver,
    closeOverType,
    empty,
    lookup,
    lookupPinned,
    remove,
    extend,
    merge,
    mergeEnvs,
    singleton,
    keys,
    fromList,
    fromListModule,
    toList,
    normtype,
    normTC,
    fv,
    namespaceToIdent,
    generalize,
  )
where

import Data.Foldable (Foldable (foldl'))
import Data.List (nub)
import qualified Data.Map as Map
import qualified Data.Set as Set
import Inferno.Types.Syntax (ExtIdent)
import Inferno.Types.Type
  ( ImplType (..),
    InfernoType (..),
    Namespace (..),
    Substitutable (..),
    TCScheme (..),
    TV (..),
    TypeClass (..),
    TypeMetadata (..),
    namespaceToIdent,
  )
import Inferno.Types.VersionControl (VCObjectHash)
import Prelude hiding (lookup)

-------------------------------------------------------------------------------
-- Typing Environment
-------------------------------------------------------------------------------

-- | The typing environment consulted during inference.
--
-- It keeps two independent maps of type metadata: one for identifiers
-- ('types') and one for pinned entries addressed by 'VCObjectHash'
-- ('pinnedTypes').
data Env = TypeEnv
  { -- | Type metadata for identifiers, keyed by 'ExtIdent'.
    types :: Map.Map ExtIdent (TypeMetadata TCScheme),
    -- | Type metadata for pinned entries, keyed by 'VCObjectHash'.
    pinnedTypes :: Map.Map VCObjectHash (TypeMetadata TCScheme)
  }
  deriving (Show, Eq)

-- | A substitution is applied to every scheme in the environment, both the
-- named ones and the pinned ones. Free type variables are collected from the
-- named 'types' only.
instance Substitutable Env where
  apply s env =
    env
      { types = Map.map substMeta (types env),
        pinnedTypes = Map.map substMeta (pinnedTypes env)
      }
    where
      -- Rewrite only the scheme stored in the metadata; other fields are kept.
      substMeta meta = meta {ty = apply s (ty meta)}

  ftv = ftv . map ty . Map.elems . types

-- Entries in pinnedTypes are expected to contain no free type variables;
-- accordingly, ftv for Env only collects variables from types.
-- | An environment with no named and no pinned bindings.
empty :: Env
empty = TypeEnv {types = Map.empty, pinnedTypes = Map.empty}

-- | Bind an identifier in the named part of the environment. Any previous
-- binding for the same identifier is replaced; pinned entries are untouched.
extend :: Env -> (ExtIdent, TypeMetadata TCScheme) -> Env
extend env (ident, meta) =
  let named = types env
   in env {types = Map.insert ident meta named}

-- | Drop an identifier's binding from the named part of the environment.
-- The environment is returned unchanged (apart from rebuilding) if the
-- identifier was not bound; pinned entries are untouched.
remove :: Env -> ExtIdent -> Env
remove env ident = env {types = pruned}
  where
    pruned = Map.delete ident (types env)

-- | Find the type metadata bound to an identifier, if there is one.
lookup :: ExtIdent -> Env -> Maybe (TypeMetadata TCScheme)
lookup ident = Map.lookup ident . types

-- | Find the type metadata of a pinned entry by its 'VCObjectHash', if
-- there is one.
lookupPinned :: VCObjectHash -> Env -> Maybe (TypeMetadata TCScheme)
lookupPinned hash = Map.lookup hash . pinnedTypes

-- | Left-biased union of two environments. When both bind the same
-- identifier (or the same pinned hash), the entry from the first argument
-- is kept.
merge :: Env -> Env -> Env
merge (TypeEnv namedL pinnedL) (TypeEnv namedR pinnedR) =
  TypeEnv
    { types = namedL `Map.union` namedR,
      pinnedTypes = pinnedL `Map.union` pinnedR
    }

-- | Merge a list of environments from left to right, starting from 'empty'.
-- Because 'merge' is left-biased, the earliest environment in the list wins
-- on duplicate keys.
mergeEnvs :: [Env] -> Env
mergeEnvs envs = foldl' merge empty envs

-- | An environment with exactly one named binding and no pinned entries.
singleton :: ExtIdent -> TypeMetadata TCScheme -> Env
singleton ident meta = empty {types = Map.singleton ident meta}

-- | All identifiers bound in the named part of the environment, in
-- ascending order. Pinned entries are not included.
keys :: Env -> [ExtIdent]
keys env = Map.keys (types env)

-- | Build an environment from named bindings only; it has no pinned
-- entries. If an identifier appears more than once, the last pair wins
-- (as with 'Map.fromList').
fromList :: [(ExtIdent, TypeMetadata TCScheme)] -> Env
fromList bindings = empty {types = Map.fromList bindings}

-- | Build an environment from pinned entries only; it has no named
-- bindings. If a hash appears more than once, the last pair wins (as with
-- 'Map.fromList').
fromListModule :: [(VCObjectHash, TypeMetadata TCScheme)] -> Env
fromListModule entries = empty {pinnedTypes = Map.fromList entries}

-- | The named bindings as a list of pairs in ascending identifier order.
-- Pinned entries are not included.
toList :: Env -> [(ExtIdent, TypeMetadata TCScheme)]
toList env = Map.toList (types env)

-- | Combining environments is 'merge': a left-biased union of both maps.
instance Semigroup Env where
  envA <> envB = merge envA envB

-- | The identity for '<>' is the environment with no bindings at all.
instance Monoid Env where
  mempty = empty

-- | Apply a type transformation to every type argument of a type class
-- constraint. The class name is left unchanged.
normTC :: (InfernoType -> InfernoType) -> TypeClass -> TypeClass
normTC f (TypeClass name args) = TypeClass name (fmap f args)

-- | The type variables occurring in a type, listed left to right in order
-- of occurrence. No de-duplication is done: a variable that occurs several
-- times is listed several times.
fv :: InfernoType -> [TV]
fv t = case t of
  TVar v -> [v]
  TBase _ -> []
  TArr dom cod -> fv dom ++ fv cod
  TArray inner -> fv inner
  TSeries inner -> fv inner
  TOptional inner -> fv inner
  TTuple members -> concatMap fv members
  TRep inner -> fv inner

-- | Rename the type variables of a type according to the given map,
-- rebuilding the type structurally. A variable absent from the map is kept
-- as it is; it is deliberately not reported as an error.
normtype :: Map.Map TV TV -> InfernoType -> InfernoType
normtype renaming = go
  where
    go t = case t of
      TVar v -> TVar (Map.findWithDefault v v renaming)
      TBase b -> TBase b
      TArr dom cod -> TArr (go dom) (go cod)
      TArray inner -> TArray (go inner)
      TSeries inner -> TSeries (go inner)
      TOptional inner -> TOptional (go inner)
      TTuple members -> TTuple (fmap go members)
      TRep inner -> TRep (go inner)

-- | Rename the type variables of a scheme to the canonical sequence
-- @TV 0, TV 1, ...@, assigned in order of first occurrence.
--
-- Variables are collected first from the body type, then from the implicit
-- parameter types (in key order), and finally from the type class
-- arguments. The scheme's original quantifier list is discarded. It is
-- replaced by the renamed versions of all the collected variables.
normalize :: TCScheme -> TCScheme
normalize (ForallTC _ tcs (ImplType impl body)) =
  ForallTC
    canonicalVars
    (Set.map (normTC rename) tcs)
    (ImplType (Map.map rename impl) (rename body))
  where
    occurrences :: [TV]
    occurrences =
      fv body
        ++ concatMap fv (Map.elems impl)
        ++ concatMap classVars (Set.toList tcs)
    classVars (TypeClass _ args) = concatMap fv args
    -- Pair each distinct variable (first-occurrence order) with its new name.
    renaming :: [(TV, TV)]
    renaming = zip (nub occurrences) (map TV [0 ..])
    canonicalVars = map snd renaming
    rename = normtype (Map.fromList renaming)

-- | Quantify over every free type variable of the implicit type and of the
-- type class constraints. The quantified variables are listed in ascending
-- order, and the constraints and type are stored unchanged.
generalize :: Set.Set TypeClass -> ImplType -> TCScheme
generalize tcs implTy = ForallTC (Set.toList quantified) tcs implTy
  where
    quantified = ftv implTy `Set.union` foldMap ftv tcs

-- | Canonicalize and return the polymorphic toplevel type.
--
-- The type is first generalized over all free type variables of the type
-- and its constraints ('generalize'). Those variables are then renamed to a
-- canonical sequence ('normalize').
closeOver :: Set.Set TypeClass -> ImplType -> TCScheme
closeOver tcs = normalize . generalize tcs

-- | Close over a plain type that has no type class constraints and no
-- implicit parameters. This is 'closeOver' with empty constraints and an
-- empty implicit map.
closeOverType :: InfernoType -> TCScheme
closeOverType t = closeOver Set.empty (ImplType Map.empty t)