{-# LANGUAGE FlexibleInstances, TypeFamilies #-}
-- | This module exports a type class covering representations of
-- function return types.
module Futhark.IR.RetType
       (
         IsBodyType (..)
       , IsRetType (..)
       , expectedTypes
       )
       where

import qualified Data.Map.Strict as M

import Futhark.IR.Syntax.Core
import Futhark.IR.Prop.Types

-- | The return type of a body.  A value of this type must carry at
-- least what a list of 'ExtType's would describe, and is free to
-- carry more — an existential context being the typical addition.
class ( Show rt
      , Eq rt
      , Ord rt
      , ExtTyped rt
      ) => IsBodyType rt where
  -- | Build a body type from a primitive type.
  primBodyType :: PrimType -> rt

-- | An 'ExtType' body type for a primitive is simply the 'Prim'
-- type.  (No instance signature: this module does not enable
-- @InstanceSigs@.)
instance IsBodyType ExtType where
  primBodyType = Prim

-- | The return type of a function.  Since functions may return several
-- values, these are normally used in lists.  Each value must carry at
-- least what an 'ExtType' would describe, and is free to carry more —
-- an existential context being the typical addition.
class ( Show rt
      , Eq rt
      , Ord rt
      , DeclExtTyped rt
      ) => IsRetType rt where
  -- | Construct a return type from a primitive type.
  primRetType :: PrimType -> rt

  -- | Instantiate a function's return type for one concrete call.
  --
  -- Yields the return type specialised to that call when the call is
  -- valid, and 'Nothing' otherwise.
  applyRetType
    :: Typed dec
    => [rt]              -- ^ The function's declared return type.
    -> [Param dec]       -- ^ The function's parameters.
    -> [(SubExp, Type)]  -- ^ The call's arguments, each with a type.
    -> Maybe [rt]

-- | Given shape parameter names and value parameter types, produce the
-- types of arguments accepted.
--
-- The shape parameter names are paired positionally with the argument
-- 'SubExp's using 'zip', so surplus entries on either side are ignored.
-- In each value type, every dimension that is a variable named by one
-- of those shape parameters is replaced by the paired argument.
-- Constant dimensions, and variables that are not shape parameters,
-- are kept unchanged.
expectedTypes :: Typed t => [VName] -> [t] -> [SubExp] -> [Type]
expectedTypes shapes value_ts args = map (substDims . typeOf) value_ts
  where
    -- Shape parameter name -> the argument passed for it.
    parammap :: M.Map VName SubExp
    parammap = M.fromList $ zip shapes args

    -- Rewrite every dimension of the array shape of a type.
    substDims :: TypeBase (ShapeBase SubExp) u -> TypeBase (ShapeBase SubExp) u
    substDims t =
      t `setArrayShape` Shape (map substDim $ shapeDims $ arrayShape t)

    -- Only variables can refer to shape parameters; anything else
    -- (i.e. a constant) is returned as-is.
    substDim :: SubExp -> SubExp
    substDim (Var v) = M.findWithDefault (Var v) v parammap
    substDim se = se

-- | Return types that are plain 'DeclExtType's.  Method signatures are
-- omitted because this module does not enable @InstanceSigs@.
instance IsRetType DeclExtType where
  primRetType = Prim

  -- A call is accepted only when the following both hold:
  --
  --  * there is exactly one argument per parameter, and
  --  * each argument's type is a subtype of the corresponding
  --    parameter type, where the parameter types have their
  --    dimensions instantiated by 'expectedTypes' (parameter names
  --    substituted by the arguments).
  --
  -- On success, that same name-to-argument substitution is applied to
  -- the free dimensions of the declared return types.  Existential
  -- dimensions are left untouched.
  applyRetType extret params args
    | length args == length params,
      and $ zipWith subtypeOf argtypes expected =
        Just $ map substExtDims extret
    | otherwise = Nothing
    where
      argses :: [SubExp]
      argses = map fst args

      argtypes :: [Type]
      argtypes = map snd args

      paramnames :: [VName]
      paramnames = map paramName params

      -- Parameter types as seen by this particular call.
      expected = expectedTypes paramnames params argses

      -- Parameter name -> the argument passed for it.
      parammap :: M.Map VName SubExp
      parammap = M.fromList $ zip paramnames argses

      -- Rewrite every dimension of the (existential) array shape.
      substExtDims
        :: TypeBase (ShapeBase (Ext SubExp)) u
        -> TypeBase (ShapeBase (Ext SubExp)) u
      substExtDims t =
        t `setArrayShape` Shape (map substExtDim $ shapeDims $ arrayShape t)

      -- Existential dimensions stay as they are; free ones are
      -- substituted.
      substExtDim :: Ext SubExp -> Ext SubExp
      substExtDim (Ext i) = Ext i
      substExtDim (Free d) = Free $ substDim d

      -- Only variables can refer to parameters; constants pass through.
      substDim :: SubExp -> SubExp
      substDim (Var v) = M.findWithDefault (Var v) v parammap
      substDim se = se