{-# LANGUAGE TypeFamilies #-}

-- | This module exports the type classes covering representations of
-- body and function return types, along with a helper for computing
-- the argument types a function expects.
module Futhark.IR.RetType
  ( IsBodyType (..),
    IsRetType (..),
    expectedTypes,
  )
where

import Control.Monad.Identity
import Data.Map.Strict qualified as M
import Futhark.IR.Prop.Types
import Futhark.IR.Syntax.Core

-- | A type representing the return type of a body.  It should contain
-- at least the information contained in a list of 'ExtType's, but may
-- have more, notably an existential context.
--
-- Every instance must also be an instance of 'Show', 'Eq', 'Ord' and
-- 'ExtTyped'.
class (Show rt, Eq rt, Ord rt, ExtTyped rt) => IsBodyType rt where
  -- | Construct a body type from a primitive type.
  primBodyType :: PrimType -> rt

-- | An 'ExtType' can be used directly as a body type.
instance IsBodyType ExtType where
  -- A primitive body type is simply the 'Prim' constructor applied to
  -- the primitive type.
  primBodyType = Prim

-- | A type representing the return type of a function.  In practice,
-- a list of these will be used.  It should contain at least the
-- information contained in an 'ExtType', but may have more, notably
-- an existential context.
--
-- Every instance must also be an instance of 'Show', 'Eq', 'Ord' and
-- 'DeclExtTyped'.
class (Show rt, Eq rt, Ord rt, DeclExtTyped rt) => IsRetType rt where
  -- | Construct a return type from a primitive type.
  primRetType :: PrimType -> rt

  -- | Given a function return type, the parameters of the function,
  -- and the arguments for a concrete call (each argument paired with
  -- its type), return the instantiated return type for the concrete
  -- call, or 'Nothing' if the call is not valid.
  applyRetType ::
    Typed dec =>
    [rt] ->
    [Param dec] ->
    [(SubExp, Type)] ->
    Maybe [rt]

-- | Given shape parameter names and types, produce the types of
-- arguments accepted.
--
-- Each name in @shapes@ is paired positionally with a sub-expression
-- from @args@ (via 'zip', so surplus entries on either side are
-- ignored).  For every element of @value_ts@, its type is taken with
-- 'typeOf', and each 'Var' visited by 'mapOnType' whose name is one of
-- @shapes@ is replaced by the paired argument.  All other
-- sub-expressions are left unchanged.
expectedTypes :: Typed t => [VName] -> [t] -> [SubExp] -> [Type]
expectedTypes shapes value_ts args = map (correctDims . typeOf) value_ts
  where
    -- Shape parameter name -> the argument supplied for it.
    parammap :: M.Map VName SubExp
    parammap = M.fromList $ zip shapes args

    -- 'mapOnType' is monadic; running it in 'Identity' makes it a pure
    -- traversal.
    correctDims :: Type -> Type
    correctDims = runIdentity . mapOnType (pure . substitute)

    -- Replace a variable bound in 'parammap'; keep anything else.
    substitute :: SubExp -> SubExp
    substitute (Var v)
      | Just se <- M.lookup v parammap = se
    substitute se = se

-- | A 'DeclExtType' can be used directly as a function return type.
instance IsRetType DeclExtType where
  -- A primitive return type is simply the 'Prim' constructor applied
  -- to the primitive type.
  primRetType = Prim

  -- The call is accepted when there are exactly as many arguments as
  -- parameters and every argument type is a 'subtypeOf' the type
  -- expected for the corresponding parameter (as computed by
  -- 'expectedTypes', with parameter names replaced by the actual
  -- arguments).  On success, the same name-for-argument replacement
  -- is applied to each declared return type via 'mapOnExtType'.
  applyRetType extret params args
    | length args == length params,
      and (zipWith subtypeOf argtypes expected) =
        Just $ map correctExtDims extret
    | otherwise = Nothing
    where
      -- Split the call arguments into the expressions and their types.
      argexps :: [SubExp]
      argtypes :: [Type]
      (argexps, argtypes) = unzip args

      paramnames :: [VName]
      paramnames = map paramName params

      -- Parameter types with parameter names replaced by arguments.
      expected :: [Type]
      expected = expectedTypes paramnames params argexps

      -- Parameter name -> the argument supplied for it.
      parammap :: M.Map VName SubExp
      parammap = M.fromList $ zip paramnames argexps

      -- 'mapOnExtType' is monadic; running it in 'Identity' makes it a
      -- pure traversal.
      correctExtDims :: DeclExtType -> DeclExtType
      correctExtDims = runIdentity . mapOnExtType (pure . substitute)

      -- Replace a variable bound in 'parammap'; keep anything else.
      substitute :: SubExp -> SubExp
      substitute (Var v)
        | Just se <- M.lookup v parammap = se
      substitute se = se