{- -----------------------------------------------------------------------------
Copyright 2019-2021 Kevin P. Barry

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
----------------------------------------------------------------------------- -}

-- Author: Kevin P. Barry [ta0kira@gmail.com]

{-# LANGUAGE Safe #-}

module Types.Function (
  FunctionType(..),
  assignFunctionParams,
  checkFunctionConvert,
  validatateFunctionType,
) where

import Data.List (group,intercalate,sort)
import Control.Monad (when)
import qualified Data.Map as Map
import qualified Data.Set as Set

import Base.CompilerError
import Base.GeneralType
import Base.Positional
import Types.TypeInstance
import Types.Variance


-- | The signature of a function: its value-level arguments and returns, plus
-- the function-level type parameters and the filters attached to them.
--
-- 'ftParams' and 'ftFilters' are positional and are paired up element by
-- element wherever this module consumes them.
data FunctionType =
  FunctionType {
    -- | Types of the positional arguments.
    ftArgs :: Positional ValueType,
    -- | Types of the positional return values.
    ftReturns :: Positional ValueType,
    -- | Names of the function-level type parameters.
    ftParams :: Positional ParamName,
    -- | Filters for each parameter; paired by position with 'ftParams'.
    ftFilters :: Positional [TypeFilter]
  }
  deriving (Eq)

-- | Renders the signature as
--
-- > <params> filters(args) -> (returns)
--
-- where @params@, @args@ and @returns@ are comma-separated, and each filter
-- is written as the parameter name, a space, the filter, and a trailing space.
instance Show FunctionType where
  show (FunctionType as rs ps fa) =
    "<" ++ commaSeparated (pValues ps) ++ "> " ++
    concat (concatMap showFilters paramFilters) ++
    "(" ++ commaSeparated (pValues as) ++ ") -> " ++
    "(" ++ commaSeparated (pValues rs) ++ ")"
    where
      -- zip stops at the shorter list, so unmatched params/filters are
      -- silently omitted from the output.
      paramFilters = zip (pValues ps) (pValues fa)
      commaSeparated xs = intercalate "," (map show xs)
      -- One "name filter " string per filter attached to the parameter.
      showFilters (n,fs) = map (\f -> show n ++ " " ++ show f ++ " ") fs

-- | Validates a function signature in the context of its enclosing category.
--
-- Arguments:
--
--   * @r@: resolver passed through to the type-validation helpers.
--   * @params@: category-level parameter names already in scope.
--   * @vm@: variances of the category-level parameters.
--   * the 'FunctionType' being validated.
--
-- Checks performed:
--
--   1. No function parameter name occurs more than once.
--   2. No function parameter has the same name as a member of @params@.
--   3. Every filter passes 'validateTypeFilter' with the category-level and
--      function-level parameters both in scope.
--   4. Filter variance: @FilterRequires@ and @DefinesFilter@ are checked as
--      'Contravariant', @FilterAllows@ as 'Covariant'; @ImmutableFilter@ has
--      no variance check.
--   5. Arguments must not be weak, must be valid instances, and are checked
--      as 'Contravariant'.
--   6. Returns must not be weak, must be valid instances, and are checked as
--      'Covariant'.
--
-- NOTE(review): whether later checks still run after an earlier one fails
-- depends on the 'CollectErrorsM' instance; confirm before relying on it.
validatateFunctionType :: (CollectErrorsM m, TypeResolver r) =>
  r -> Set.Set ParamName -> ParamVariances -> FunctionType -> m ()
validatateFunctionType r params vm (FunctionType as rs ps fa) = do
  -- Sorting first makes duplicates adjacent so that group collects them.
  mapCompilerM_ checkCount $ group $ sort functionParams
  mapCompilerM_ checkHides functionParams
  let allParams = Set.union params (Set.fromList functionParams)
  -- Flatten to one (param,filter) pair per individual filter.
  expanded <- fmap concat $ processPairs (\n fs -> return [(n,f) | f <- fs]) ps fa
  mapCompilerM_ (checkFilterType allParams) expanded
  mapCompilerM_ checkFilterVariance expanded
  mapCompilerM_ (checkArg allParams) $ pValues as
  mapCompilerM_ (checkReturn allParams) $ pValues rs
  where
    functionParams = pValues ps
    -- Function-level params are treated as Invariant. Map.union is
    -- left-biased, so an entry already in vm wins on a name clash.
    allVariances = Map.union vm (Map.fromList [(n,Invariant) | n <- functionParams])
    -- Error context shared by both filter checks.
    filterContext n f = "In filter " ++ show n ++ " " ++ show f
    -- A group with at least two elements is a repeated parameter name.
    checkCount xa@(x:_:_) =
      compilerErrorM $ "Function parameter " ++ show x ++ " occurs " ++ show (length xa) ++ " times"
    checkCount _ = return ()
    checkHides n =
      when (n `Set.member` params) $
        compilerErrorM $ "Function parameter " ++ show n ++ " hides a category-level parameter"
    checkFilterType fa2 (n,f) =
      validateTypeFilter r fa2 f <?? filterContext n f
    checkFilterVariance (n,f@(TypeFilter FilterRequires t)) =
      validateInstanceVariance r allVariances Contravariant t <?? filterContext n f
    checkFilterVariance (n,f@(TypeFilter FilterAllows t)) =
      validateInstanceVariance r allVariances Covariant t <?? filterContext n f
    checkFilterVariance (n,f@(DefinesFilter t)) =
      validateDefinesVariance r allVariances Contravariant t <?? filterContext n f
    checkFilterVariance (_,ImmutableFilter) = return ()
    checkArg fa2 ta@(ValueType _ t) = ("In argument " ++ show ta) ??> do
      when (isWeakValue ta) $ compilerErrorM "Weak values not allowed as argument types"
      validateGeneralInstance r fa2 t
      validateInstanceVariance r allVariances Contravariant t
    checkReturn fa2 ta@(ValueType _ t) = ("In return " ++ show ta) ??> do
      when (isWeakValue ta) $ compilerErrorM "Weak values not allowed as return types"
      validateGeneralInstance r fa2 t
      validateInstanceVariance r allVariances Covariant t

-- | Substitutes concrete types for a function's type parameters.
--
-- Arguments:
--
--   * @r@: resolver passed through to the validation helpers.
--   * @fm@: filters for the parameters visible at the call site.
--   * @pm@: values already assigned to parameters from the enclosing scope.
--   * @ts@: the types to assign, paired by position with the function's
--     parameters.
--   * the 'FunctionType' whose parameters are being assigned.
--
-- Steps:
--
--   1. Each type in @ts@ is checked with 'validateGeneralInstanceForCall'.
--   2. The function's parameters are paired with @ts@ and merged with @pm@.
--      'Map.union' is left-biased, so an entry in @pm@ wins on a name clash.
--   3. The substitution is applied to the filters, and each type in @ts@ is
--      checked against its substituted filters with 'validateAssignment'.
--   4. The substitution is applied to the argument and return types.
--
-- The result has the substituted arguments and returns, and empty parameter
-- and filter lists.
assignFunctionParams :: (CollectErrorsM m, TypeResolver r) =>
  r -> ParamFilters -> ParamValues -> Positional GeneralInstance ->
  FunctionType -> m FunctionType
assignFunctionParams r fm pm ts (FunctionType as rs ps fa) = do
  mapCompilerM_ (validateGeneralInstanceForCall r fm) $ pValues ts
  assigned <- fmap Map.fromList $ processPairs alwaysPair ps ts
  let pa = pm `Map.union` assigned
  fa' <- fmap Positional $ mapCompilerM (assignFilters pa) (pValues fa)
  processPairs_ (validateAssignment r fm) ts fa'
  as' <- assignValues pa as
  rs' <- assignValues pa rs
  return $ FunctionType as' rs' (Positional []) (Positional [])
  where
    -- Apply the parameter assignment to every filter of one parameter.
    assignFilters fm2 fs = mapCompilerM (uncheckedSubFilter $ getValueForParam fm2) fs
    -- Apply the parameter assignment to every value type in a position list.
    assignValues fm2 vs =
      fmap Positional $
        mapCompilerM (uncheckedSubValueType $ getValueForParam fm2) (pValues vs)

-- | Checks that two function signatures are compatible.
--
-- Arguments:
--
--   * @r@: resolver passed through to the validation helpers.
--   * @fm@: filters for the parameters in the enclosing scope.
--   * @pm@: values already assigned to parameters from the enclosing scope.
--   * the first signature, whose parameters and filters define the scope
--     used for the comparison.
--   * @ff2@: the second signature, which gets re-expressed in terms of the
--     first signature's parameters.
--
-- The second signature's parameters are replaced, position by position, with
-- the first signature's parameter names (via 'assignFunctionParams'). After
-- that, arguments are checked with 'checkValueAssignment' from the first
-- signature to the second, and returns are checked in the opposite direction.
--
-- NOTE(review): the argument order passed to 'checkValueAssignment' is taken
-- from the code; confirm against its definition which side is source and
-- which is destination.
checkFunctionConvert :: (CollectErrorsM m, TypeResolver r) =>
  r -> ParamFilters -> ParamValues -> FunctionType -> FunctionType -> m ()
checkFunctionConvert r fm pm (FunctionType as1 rs1 ps1 fa1) ff2 = do
  -- Bring the first signature's parameters (with their filters) into scope.
  -- Map.union is left-biased, so an entry already in fm wins on a clash.
  mapped <- fmap Map.fromList $ processPairs alwaysPair ps1 fa1
  let fm' = Map.union fm mapped
  let asTypes = Positional $ map (singleType . JustParamName False) $ pValues ps1
  -- Substitute the first signature's params (as types) into ff2, so that
  -- both signatures refer to the same parameter names.
  (FunctionType as2 rs2 _ _) <- assignFunctionParams r fm' pm asTypes ff2
  processPairs_ (validateArg fm') as1 as2
  processPairs_ (validateReturn fm') rs1 rs2
  where
    validateArg fm2 a1 a2 = checkValueAssignment r fm2 a1 a2
    -- Returns are checked in the opposite direction from arguments.
    validateReturn fm2 r1 r2 = checkValueAssignment r fm2 r2 r1