{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TupleSections #-}

-- | Defunctionalization of typed, monomorphic Futhark programs without modules.
module Futhark.Internalise.Defunctionalise (transformProg) where

import qualified Control.Arrow as Arrow
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List (partition, sortOn, tails)
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.IR.Pretty ()
import qualified Futhark.Internalise.FreeVars as FV
import Futhark.MonadFreshNames
import Language.Futhark
import Language.Futhark.Traversals

-- | An expression or an extended 'Lambda' (with size parameters,
-- which AST lambdas do not support).
--
-- 'ExtLambda' carries the parameter patterns, the body, the return
-- type and the source location; 'ExtExp' wraps an ordinary expression.
data ExtExp
  = ExtLambda [Pat] Exp StructRetType SrcLoc
  | ExtExp Exp
  deriving (Show)

-- | A static value stores additional information about the result of
-- defunctionalization of an expression, aside from the residual expression.
data StaticVal
  = Dynamic PatType
  | -- | A lambda: its parameter pattern, return type, body, and the
    -- closure environment it is paired with.
    LambdaSV Pat StructRetType ExtExp Env
  | RecordSV [(Name, StaticVal)]
  | -- | The constructor that is actually present, plus
    -- the others that are not.
    SumSV Name [StaticVal] [(Name, [PatType])]
  | -- | The pair is the StaticVal and residual expression of this
    -- function as a whole, while the second StaticVal is its
    -- body. (Don't trust this too much, my understanding may have
    -- holes.)
    DynamicFun (Exp, StaticVal) StaticVal
  | IntrinsicSV
  deriving (Show)

-- | An environment entry: an optional polymorphic type together with
-- the static value of the bound name.
--
-- The type is 'Just' if this is a polymorphic binding that must be
-- instantiated.
data Binding = Binding (Maybe ([VName], StructType)) StaticVal
  deriving (Show)

-- | Project the static value out of a 'Binding', discarding the
-- optional polymorphic type.
bindingSV :: Binding -> StaticVal
bindingSV (Binding _ sv) = sv

-- | Environment mapping variable names to their 'Binding', which
-- carries the associated static value (see 'bindingSV') and, for
-- polymorphic bindings, the type to instantiate.
type Env = M.Map VName Binding

-- | Run an action with the given bindings added to the current
-- environment.  The new bindings are on the left of '<>', so they take
-- precedence over existing bindings of the same name.  The set of
-- globals is left untouched.
localEnv :: Env -> DefM a -> DefM a
localEnv env = local $ Arrow.second (env <>)

-- | Run an action in a "new" environment (for evaluating closures).
--
-- Even when using a "new" environment we still ram the global
-- environment of DynamicFuns in there: the old environment is
-- restricted to the names recorded as global, and those entries take
-- precedence over the supplied bindings (they are on the left of '<>').
localNewEnv :: Env -> DefM a -> DefM a
localNewEnv env = local $ \(globals, old_env) ->
  (globals, M.restrictKeys old_env globals <> env)

-- | Retrieve the current environment (the second component of the
-- reader state).
askEnv :: DefM Env
askEnv = asks snd

-- | Run an action with the given name added to the set of globals
-- (the first component of the reader state).
isGlobal :: VName -> DefM a -> DefM a
isGlobal v = local $ Arrow.first (S.insert v)

-- | Apply size substitutions to the dimensions of a type.
--
-- A named dimension whose leaf name is in the substitution map is
-- replaced by the substituted name ('SubstNamed') or by a constant
-- dimension ('SubstConst').  All other dimensions are left unchanged.
replaceTypeSizes ::
  M.Map VName SizeSubst ->
  TypeBase (DimDecl VName) als ->
  TypeBase (DimDecl VName) als
replaceTypeSizes substs = first onDim
  where
    onDim (NamedDim v) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') -> NamedDim v'
        Just (SubstConst d) -> ConstDim d
        Nothing -> NamedDim v
    onDim d = d

-- | Apply size substitutions throughout a static value: in its types,
-- residual expressions, patterns, type expressions and closure
-- environments.
--
-- If the substitution map is empty the static value is returned
-- unchanged.  The set of globals is only threaded through the
-- recursive calls.
replaceStaticValSizes ::
  S.Set VName ->
  M.Map VName SizeSubst ->
  StaticVal ->
  StaticVal
replaceStaticValSizes globals orig_substs sv =
  case sv of
    _ | M.null orig_substs -> sv
    LambdaSV param (RetType t_dims t) e closure_env ->
      -- Names bound by the closure environment are dropped from the
      -- substitutions used for the parameter, return type and body.
      let substs = orig_substs `M.difference` closure_env
       in LambdaSV
            (onAST substs param)
            (RetType t_dims (replaceTypeSizes substs t))
            (onExtExp substs e)
            -- Intentional: the closure environment itself is rewritten
            -- with the original, unfiltered substitutions.
            (onEnv orig_substs closure_env)
    Dynamic t ->
      Dynamic $ replaceTypeSizes orig_substs t
    RecordSV fs ->
      RecordSV $ map (fmap (replaceStaticValSizes globals orig_substs)) fs
    SumSV c svs ts ->
      SumSV c (map (replaceStaticValSizes globals orig_substs) svs) $
        map (fmap $ map $ replaceTypeSizes orig_substs) ts
    DynamicFun (e, sv1) sv2 ->
      DynamicFun (onExp orig_substs e, replaceStaticValSizes globals orig_substs sv1) $
        replaceStaticValSizes globals orig_substs sv2
    IntrinsicSV ->
      IntrinsicSV
  where
    -- AST mapper that applies the substitutions to types, expressions
    -- and names.
    tv substs =
      identityMapper
        { mapOnPatType = pure . replaceTypeSizes substs,
          mapOnStructType = pure . replaceTypeSizes substs,
          mapOnExp = pure . onExp substs,
          mapOnName = pure . onName substs
        }

    -- Only name-for-name substitutions can rewrite a bare name; a
    -- constant substitution leaves the name as it is.
    onName substs v =
      case M.lookup v substs of
        Just (SubstNamed v') -> qualLeaf v'
        _ -> v

    -- A substituted variable becomes either the new variable or an
    -- i64 literal; an unsubstituted one only has its type rewritten.
    onExp substs (Var v t loc) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') ->
          Var v' t loc
        Just (SubstConst d) ->
          Literal (SignedValue (Int64Value (fromIntegral d))) loc
        Nothing ->
          Var v (replaceTypeSizes substs <$> t) loc
    -- Coercions are handled explicitly so that the type expression in
    -- the type declaration is rewritten via 'onTypeExp'.
    onExp substs (AppExp (Coerce e tdecl loc) (Info (AppRes t ext))) =
      AppExp (Coerce (onExp substs e) tdecl' loc) (Info (AppRes (replaceTypeSizes substs t) ext))
      where
        tdecl' =
          TypeDecl
            { declaredType = onTypeExp substs $ declaredType tdecl,
              expandedType = replaceTypeSizes substs <$> expandedType tdecl
            }
    onExp substs e = onAST substs e

    onTypeExpDim substs d@(DimExpNamed v loc) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') ->
          DimExpNamed v' loc
        Just (SubstConst x) ->
          DimExpConst x loc
        Nothing ->
          d
    onTypeExpDim _ d = d

    onTypeArgExp substs (TypeArgExpDim d loc) =
      TypeArgExpDim (onTypeExpDim substs d) loc
    onTypeArgExp substs (TypeArgExpType te) =
      TypeArgExpType (onTypeExp substs te)

    -- Structural traversal of type expressions; only dimensions are
    -- actually changed (see 'onTypeExpDim').
    onTypeExp substs (TEArray te d loc) =
      TEArray (onTypeExp substs te) (onTypeExpDim substs d) loc
    onTypeExp substs (TEUnique t loc) =
      TEUnique (onTypeExp substs t) loc
    onTypeExp substs (TEApply t1 t2 loc) =
      TEApply (onTypeExp substs t1) (onTypeArgExp substs t2) loc
    onTypeExp substs (TEArrow p t1 t2 loc) =
      TEArrow p (onTypeExp substs t1) (onTypeExp substs t2) loc
    onTypeExp substs (TETuple ts loc) =
      TETuple (map (onTypeExp substs) ts) loc
    onTypeExp substs (TERecord ts loc) =
      TERecord (map (fmap $ onTypeExp substs) ts) loc
    onTypeExp substs (TESum ts loc) =
      TESum (map (fmap $ map $ onTypeExp substs) ts) loc
    onTypeExp substs (TEDim dims t loc) =
      TEDim dims (onTypeExp substs t) loc
    onTypeExp _ (TEVar v loc) =
      TEVar v loc

    onExtExp substs (ExtExp e) =
      ExtExp $ onExp substs e
    onExtExp substs (ExtLambda params e (RetType t_dims t) loc) =
      ExtLambda
        (map (onAST substs) params)
        (onExp substs e)
        (RetType t_dims (replaceTypeSizes substs t))
        loc

    -- Rewrite every binding of an environment; keys are untouched.
    onEnv substs = M.map (onBinding substs)

    onBinding substs (Binding t bsv) =
      Binding
        (second (replaceTypeSizes substs) <$> t)
        (replaceStaticValSizes globals substs bsv)

    onAST :: ASTMappable x => M.Map VName SizeSubst -> x -> x
    onAST substs = runIdentity . astMap (tv substs)

-- | Returns the defunctionalization environment restricted
-- to the given set of variable names and types.
--
-- A binding is kept only if its name is not a global and it occurs in
-- the given name set.  The uniqueness of the type recorded in the name
-- set is then pushed into the kept static value: when it is
-- 'Nonunique', the types of 'Dynamic' values are made non-unique.
restrictEnvTo :: FV.NameSet -> DefM Env
restrictEnvTo (FV.NameSet m) = asks restrict
  where
    restrict (globals, env) = M.mapMaybeWithKey keep env
      where
        keep k (Binding t sv) = do
          guard $ not $ k `S.member` globals
          -- Names absent from the name set are dropped here.
          u <- uniqueness <$> M.lookup k m
          Just $ Binding t $ restrict' u sv

    restrict' Nonunique (Dynamic t) =
      Dynamic $ t `setUniqueness` Nonunique
    restrict' _ (Dynamic t) =
      Dynamic t
    restrict' u (LambdaSV pat t e env) =
      LambdaSV pat t e $ M.map (restrict'' u) env
    restrict' u (RecordSV fields) =
      RecordSV $ map (fmap $ restrict' u) fields
    restrict' u (SumSV c svs fields) =
      SumSV c (map (restrict' u) svs) fields
    restrict' u (DynamicFun (e, sv1) sv2) =
      DynamicFun (e, restrict' u sv1) $ restrict' u sv2
    restrict' _ IntrinsicSV = IntrinsicSV

    restrict'' u (Binding t sv) = Binding t $ restrict' u sv

-- | Defunctionalization monad.  The Reader environment tracks both
-- the current Env as well as the set of globally defined dynamic
-- functions.  This is used to avoid unnecessarily large closure
-- environments.
--
-- The State component is a pair of the value bindings accumulated so
-- far (see 'addValBind') and the name source used for generating
-- fresh names (see the 'MonadFreshNames' instance).
newtype DefM a
  = DefM (ReaderT (S.Set VName, Env) (State ([ValBind], VNameSource)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (S.Set VName, Env),
      MonadState ([ValBind], VNameSource)
    )

-- | The name source lives in the second component of the 'DefM'
-- state; the first component (the accumulated value bindings) is
-- left untouched by both operations.
instance MonadFreshNames DefM where
  putNameSource src = modify $ \(x, _) -> (x, src)
  getNameSource = gets snd

-- | Run a computation in the defunctionalization monad. Returns the result of
-- the computation, a new name source, and a list of lifted function
-- declarations.
--
-- The computation starts with an empty reader environment and an
-- empty list of value bindings.  The accumulated bindings are
-- reversed before being returned because 'addValBind' prepends new
-- bindings to the list.
runDefM :: VNameSource -> DefM a -> (a, VNameSource, [ValBind])
runDefM src (DefM m) =
  let (x, (vbs, src')) = runState (runReaderT m mempty) (mempty, src)
   in (x, src', reverse vbs)

-- | Record a value binding in the monad's state.  The binding is
-- prepended to the accumulated list; the name source is unchanged.
addValBind :: ValBind -> DefM ()
addValBind vb = modify $ first (vb :)

-- | Looks up the associated static value for a given name in the environment.
--
-- If the binding carries size parameters and a type, the static
-- value is instantiated via 'instStaticVal', which is given the set
-- of globals from the reader environment, the binding's size
-- parameters, the type @t@ passed by the caller, and the binding's
-- recorded type.  A binding without that information is returned
-- as-is.
lookupVar :: StructType -> VName -> DefM StaticVal
lookupVar t x = do
  env <- askEnv
  case M.lookup x env of
    Just (Binding (Just (dims, sv_t)) sv) -> do
      globals <- asks fst
      instStaticVal globals dims t sv_t sv
    Just (Binding Nothing sv) ->
      pure sv
    Nothing
      -- If the variable is unknown, it may refer to the 'intrinsics'
      -- module, which we will have to treat specially.
      | baseTag x <= maxIntrinsicTag -> pure IntrinsicSV
      -- Anything not in scope is going to be an existential size.
      | otherwise ->
          pure $ Dynamic $ Scalar $ Prim $ Signed Int64

-- Like patternDimNames, but ignores sizes that are only found in
-- function types.
--
-- Named sizes are collected from array shapes and from the size
-- arguments of type variables, recursing through records, sums,
-- array element types and type arguments.  Function types and
-- primitive types contribute nothing.
arraySizes :: StructType -> S.Set VName
arraySizes (Scalar Arrow {}) = mempty
arraySizes (Scalar (Record fields)) = foldMap arraySizes fields
arraySizes (Scalar (Sum cs)) = foldMap (foldMap arraySizes) cs
arraySizes (Scalar (TypeVar _ _ _ targs)) =
  foldMap f targs
  where
    -- Only named size arguments yield a name; any other size
    -- argument is ignored.
    f (TypeArgDim (NamedDim d) _) = S.singleton $ qualLeaf d
    f TypeArgDim {} = mempty
    f (TypeArgType t _) = arraySizes t
arraySizes (Scalar Prim {}) = mempty
arraySizes (Array _ _ t shape) =
  arraySizes (Scalar t) <> foldMap dimName (shapeDims shape)
  where
    dimName :: DimDecl VName -> S.Set VName
    dimName (NamedDim qn) = S.singleton $ qualLeaf qn
    dimName _ = mempty

-- | The result of 'arraySizes' applied to the struct type of a
-- pattern.
patternArraySizes :: Pat -> S.Set VName
patternArraySizes = arraySizes . patternStructType

-- | What a size name is to be replaced with.
data SizeSubst
  = -- | Replace with another named size.
    SubstNamed (QualName VName)
  | -- | Replace with a constant size.
    SubstConst Int
  deriving (Eq, Ord, Show)

-- | Traverse two types in lockstep with 'matchDims' and record, for
-- every named dimension of the first type, the dimension found at
-- the same position in the second type.
--
-- The rules, tried in order for each pair of dimensions, are:
--
-- * If the second dimension is named and that name occurs in the
--   list of names supplied by 'matchDims' as the first argument of
--   the callback, nothing is recorded.
--   NOTE(review): presumably these are names bound within the types
--   being traversed -- confirm against 'matchDims'.
--
-- * A named first dimension paired with a named or constant second
--   dimension is recorded as 'SubstNamed' or 'SubstConst'
--   respectively.  A later occurrence of the same name overwrites an
--   earlier one.
--
-- * Any other pair is ignored.
--
-- The callback always hands the first type's dimension back to
-- 'matchDims'; only the accumulated map is returned.
dimMapping ::
  Monoid a =>
  TypeBase (DimDecl VName) a ->
  TypeBase (DimDecl VName) a ->
  M.Map VName SizeSubst
dimMapping t1 t2 = execState (matchDims f t1 t2) mempty
  where
    f bound d1 (NamedDim d2)
      | qualLeaf d2 `elem` bound = pure d1
    f _ (NamedDim d1) (NamedDim d2) = do
      modify $ M.insert (qualLeaf d1) $ SubstNamed d2
      pure $ NamedDim d1
    f _ (NamedDim d1) (ConstDim d2) = do
      modify $ M.insert (qualLeaf d1) $ SubstConst d2
      pure $ NamedDim d1
    f _ d _ = pure d

-- | Like 'dimMapping', but keeps only the entries that map to a
-- named size and returns the unqualified name of each.  Entries
-- that map to a constant are dropped.
dimMapping' ::
  Monoid a =>
  TypeBase (DimDecl VName) a ->
  TypeBase (DimDecl VName) a ->
  M.Map VName VName
dimMapping' t1 t2 = M.mapMaybe f $ dimMapping t1 t2
  where
    f (SubstNamed d) = Just $ qualLeaf d
    f _ = Nothing

-- | The names that 'instStaticVal' gives fresh replacements when
-- instantiating a static value.
--
-- Only 'LambdaSV' contributes names directly: the dimension names
-- of its parameter pattern, plus every identifier bound by that
-- pattern whose type is exactly @i64@.  'DynamicFun', 'RecordSV'
-- and 'SumSV' are traversed recursively; 'Dynamic' and
-- 'IntrinsicSV' contribute nothing.
sizesToRename :: StaticVal -> S.Set VName
sizesToRename (DynamicFun (_, sv1) sv2) =
  sizesToRename sv1 <> sizesToRename sv2
sizesToRename IntrinsicSV =
  mempty
sizesToRename Dynamic {} =
  mempty
sizesToRename (RecordSV fs) =
  foldMap (sizesToRename . snd) fs
sizesToRename (SumSV _ svs _) =
  foldMap sizesToRename svs
sizesToRename (LambdaSV param _ _ _) =
  patternDimNames param
    <> S.map identName (S.filter couldBeSize $ patIdents param)
  where
    -- An identifier is treated as a potential size if its type is
    -- the signed 64-bit integer type.
    couldBeSize ident =
      unInfo (identType ident) == Scalar (Prim (Signed Int64))

-- When we instantiate a polymorphic StaticVal, we rename all the
-- sizes to avoid name conflicts later on.  This is a bit of a hack...
--
-- The substitution applied to the static value is built as follows:
--
-- 1. A fresh name is generated for every name in 'dims' and every
--    name reported by 'sizesToRename' for the static value.
--
-- 2. The fresh names are substituted into 'sv_t', and the result is
--    matched against 't' with 'dimMapping'.  Only entries whose key
--    is one of the freshly renamed 'dims' are kept.
--
-- 3. Each fresh name from step 1 that has an entry in the map from
--    step 2 is replaced by that entry.  The result is unioned with
--    the map from step 2 (the union is left-biased, so on a key
--    collision the entry from step 1 wins).
instStaticVal ::
  MonadFreshNames m =>
  S.Set VName ->
  [VName] ->
  StructType ->
  StructType ->
  StaticVal ->
  m StaticVal
instStaticVal globals dims t sv_t sv = do
  fresh_substs <- mkSubsts $ S.toList $ S.fromList dims <> sizesToRename sv

  let dims' = map (onName fresh_substs) dims
      isDim k _ = k `elem` dims'
      dim_substs =
        M.filterWithKey isDim $
          dimMapping (replaceTypeSizes fresh_substs sv_t) t
      replace (SubstNamed k) =
        fromMaybe (SubstNamed k) $ M.lookup (qualLeaf k) dim_substs
      replace k = k
      substs = M.map replace fresh_substs <> dim_substs

  pure $ replaceStaticValSizes globals substs sv
  where
    -- Map every given name to a freshly generated name.
    mkSubsts names =
      M.fromList . zip names . map (SubstNamed . qualName)
        <$> mapM newName names

    -- Apply a substitution to a single name; names without a named
    -- replacement are returned unchanged.
    onName substs v =
      case M.lookup v substs of
        Just (SubstNamed v') -> qualLeaf v'
        _ -> v

defuncFun ::
  [VName] ->
  [Pat] ->
  Exp ->
  StructRetType ->
  SrcLoc ->
  DefM (Exp, StaticVal)
defuncFun :: [VName]
-> [Pat]
-> Exp
-> RetTypeBase (DimDecl VName) ()
-> SrcLoc
-> DefM (Exp, StaticVal)
defuncFun [VName]
tparams [Pat]
pats Exp
e0 RetTypeBase (DimDecl VName) ()
ret SrcLoc
loc = do
  -- Extract the first parameter of the lambda and "push" the
  -- remaining ones (if there are any) into the body of the lambda.
  let (Pat
pat, RetTypeBase (DimDecl VName) ()
ret', ExtExp
e0') = case [Pat]
pats of
        [] -> String -> (Pat, RetTypeBase (DimDecl VName) (), ExtExp)
forall a. HasCallStack => String -> a
error String
"Received a lambda with no parameters."
        [Pat
pat'] -> (Pat
pat', RetTypeBase (DimDecl VName) ()
ret, Exp -> ExtExp
ExtExp Exp
e0)
        (Pat
pat' : [Pat]
pats') ->
          ( Pat
pat',
            [VName] -> StructType -> RetTypeBase (DimDecl VName) ()
forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [] (StructType -> RetTypeBase (DimDecl VName) ())
-> StructType -> RetTypeBase (DimDecl VName) ()
forall a b. (a -> b) -> a -> b
$ [StructType] -> RetTypeBase (DimDecl VName) () -> StructType
forall as dim.
Monoid as =>
[TypeBase dim as] -> RetTypeBase dim as -> TypeBase dim as
foldFunType ((Pat -> StructType) -> [Pat] -> [StructType]
forall a b. (a -> b) -> [a] -> [b]
map (PatType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatType -> StructType) -> (Pat -> PatType) -> Pat -> StructType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat -> PatType
patternType) [Pat]
pats') RetTypeBase (DimDecl VName) ()
ret,
            [Pat] -> Exp -> RetTypeBase (DimDecl VName) () -> SrcLoc -> ExtExp
ExtLambda [Pat]
pats' Exp
e0 RetTypeBase (DimDecl VName) ()
ret SrcLoc
loc
          )

  -- Construct a record literal that closes over the environment of
  -- the lambda.  Closed-over 'DynamicFun's are converted to their
  -- closure representation.
  let used :: NameSet
used =
        Exp -> NameSet
FV.freeVars ([Pat]
-> Exp
-> Maybe (TypeExp VName)
-> Info (Aliasing, RetTypeBase (DimDecl VName) ())
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
[PatBase f vn]
-> ExpBase f vn
-> Maybe (TypeExp vn)
-> f (Aliasing, RetTypeBase (DimDecl VName) ())
-> SrcLoc
-> ExpBase f vn
Lambda [Pat]
pats Exp
e0 Maybe (TypeExp VName)
forall a. Maybe a
Nothing ((Aliasing, RetTypeBase (DimDecl VName) ())
-> Info (Aliasing, RetTypeBase (DimDecl VName) ())
forall a. a -> Info a
Info (Aliasing
forall a. Monoid a => a
mempty, RetTypeBase (DimDecl VName) ()
ret)) SrcLoc
loc)
          NameSet -> Set VName -> NameSet
`FV.without` [VName] -> Set VName
forall a. Ord a => [a] -> Set a
S.fromList [VName]
tparams
  Env
used_env <- NameSet -> DefM Env
restrictEnvTo NameSet
used

  -- The closure parts that are sizes are proactively turned into size
  -- parameters.
  let sizes_of_arrays :: Set VName
sizes_of_arrays =
        (Binding -> Set VName) -> Env -> Set VName
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (StructType -> Set VName
arraySizes (StructType -> Set VName)
-> (Binding -> StructType) -> Binding -> Set VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatType -> StructType)
-> (Binding -> PatType) -> Binding -> StructType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StaticVal -> PatType
typeFromSV (StaticVal -> PatType)
-> (Binding -> StaticVal) -> Binding -> PatType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Binding -> StaticVal
bindingSV) Env
used_env
          Set VName -> Set VName -> Set VName
forall a. Semigroup a => a -> a -> a
<> Pat -> Set VName
patternArraySizes Pat
pat
      notSize :: VName -> Bool
notSize = Bool -> Bool
not (Bool -> Bool) -> (VName -> Bool) -> VName -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Set VName -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.member` Set VName
sizes_of_arrays)
      ([FieldBase Info VName]
fields, Env
env) =
        ([(VName, Binding)] -> Env)
-> ([FieldBase Info VName], [(VName, Binding)])
-> ([FieldBase Info VName], Env)
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second [(VName, Binding)] -> Env
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList (([FieldBase Info VName], [(VName, Binding)])
 -> ([FieldBase Info VName], Env))
-> ([(VName, Binding)]
    -> ([FieldBase Info VName], [(VName, Binding)]))
-> [(VName, Binding)]
-> ([FieldBase Info VName], Env)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(FieldBase Info VName, (VName, Binding))]
-> ([FieldBase Info VName], [(VName, Binding)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(FieldBase Info VName, (VName, Binding))]
 -> ([FieldBase Info VName], [(VName, Binding)]))
-> ([(VName, Binding)]
    -> [(FieldBase Info VName, (VName, Binding))])
-> [(VName, Binding)]
-> ([FieldBase Info VName], [(VName, Binding)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((VName, Binding) -> (FieldBase Info VName, (VName, Binding)))
-> [(VName, Binding)] -> [(FieldBase Info VName, (VName, Binding))]
forall a b. (a -> b) -> [a] -> [b]
map (VName, Binding) -> (FieldBase Info VName, (VName, Binding))
closureFromDynamicFun
          ([(VName, Binding)] -> [(FieldBase Info VName, (VName, Binding))])
-> ([(VName, Binding)] -> [(VName, Binding)])
-> [(VName, Binding)]
-> [(FieldBase Info VName, (VName, Binding))]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((VName, Binding) -> Bool)
-> [(VName, Binding)] -> [(VName, Binding)]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Bool
notSize (VName -> Bool)
-> ((VName, Binding) -> VName) -> (VName, Binding) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Binding) -> VName
forall a b. (a, b) -> a
fst)
          ([(VName, Binding)] -> ([FieldBase Info VName], Env))
-> [(VName, Binding)] -> ([FieldBase Info VName], Env)
forall a b. (a -> b) -> a -> b
$ Env -> [(VName, Binding)]
forall k a. Map k a -> [(k, a)]
M.toList Env
used_env

  (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( [FieldBase Info VName] -> SrcLoc -> Exp
forall (f :: * -> *) vn. [FieldBase f vn] -> SrcLoc -> ExpBase f vn
RecordLit [FieldBase Info VName]
fields SrcLoc
loc,
      Pat -> RetTypeBase (DimDecl VName) () -> ExtExp -> Env -> StaticVal
LambdaSV Pat
pat RetTypeBase (DimDecl VName) ()
ret' ExtExp
e0' Env
env
    )
  where
    -- Turn one environment entry into a field of the closure record,
    -- paired with the entry that is kept in the closure environment.
    -- The field is labelled with the pretty-printed variable name, and
    -- the first component of the binding is always reset to 'Nothing'.
    closureFromDynamicFun (vn, Binding _ captured) =
      case captured of
        -- A dynamic function contributes its closure expression as the
        -- field value and is rebound to the closure's static value.
        DynamicFun (clsr_env, clsr_sv) _ ->
          (field clsr_env, (vn, Binding Nothing clsr_sv))
        -- Anything else contributes a reference to the variable itself,
        -- typed according to its static value.
        _ ->
          ( field $ Var (qualName vn) (Info $ typeFromSV captured) mempty,
            (vn, Binding Nothing captured)
          )
      where
        field e = RecordFieldExplicit (nameFromString $ pretty vn) e mempty

-- | Defunctionalization of an expression.  Produces, in the
-- defunctionalization monad, the residual expression paired with the
-- static value describing it.
defuncExp :: Exp -> DefM (Exp, StaticVal)
-- Literals are returned unchanged; their static value is 'Dynamic' at
-- the type reported by 'typeOf'.
defuncExp lit@Literal {} = pure (lit, Dynamic (typeOf lit))
defuncExp lit@IntLit {} = pure (lit, Dynamic (typeOf lit))
defuncExp lit@FloatLit {} = pure (lit, Dynamic (typeOf lit))
defuncExp lit@StringLit {} = pure (lit, Dynamic (typeOf lit))
-- Parentheses are kept around the transformed inner expression, whose
-- static value is passed through unchanged.
defuncExp (Parens inner loc) =
  rewrap <$> defuncExp inner
  where
    rewrap (inner', sv) = (Parens inner' loc, sv)
-- Qualified parentheses behave like plain parentheses: the qualifier is
-- retained and the inner static value is passed through unchanged.
defuncExp (QualParens qn inner loc) =
  rewrap <$> defuncExp inner
  where
    rewrap (inner', sv) = (QualParens qn inner' loc, sv)
-- A tuple literal: each component is defunctionalized in order, and the
-- static value is a 'RecordSV' pairing 'tupleFieldNames' with the
-- component static values.
defuncExp (TupLit elems loc) = do
  results <- traverse defuncExp elems
  let (elems', elem_svs) = unzip results
  pure (TupLit elems' loc, RecordSV (zip tupleFieldNames elem_svs))
-- A record literal: every field is defunctionalized, and the static
-- value is a 'RecordSV' of the per-field static values keyed by name.
defuncExp (RecordLit fields loc) = do
  (fields', named_svs) <- unzip <$> traverse defuncField fields
  pure (RecordLit fields' loc, RecordSV named_svs)
  where
    -- Explicit field: transform the field expression and keep its label.
    defuncField (RecordFieldExplicit name e floc) =
      relabel <$> defuncExp e
      where
        relabel (e', sv) = (RecordFieldExplicit name e' floc, (name, sv))
    -- Implicit field: decide based on the static value of the variable.
    defuncField (RecordFieldImplicit v (Info t) floc) =
      implicitField <$> lookupVar (toStruct t) v
      where
        field_name = baseName v
        -- If the implicit field refers to a dynamic function, we
        -- convert it to an explicit field with a record closing over
        -- the environment and bind the corresponding static value.
        implicitField (DynamicFun (clsr, clsr_sv) _) =
          (RecordFieldExplicit field_name clsr floc, (field_name, clsr_sv))
        -- The field may refer to a functional expression, so we get the
        -- type from the static value and not the one from the AST.
        implicitField sv =
          ( RecordFieldImplicit v (Info $ typeFromSV sv) floc,
            (field_name, sv)
          )
-- An array literal: the elements are transformed with 'defuncExp'', and
-- the result is 'Dynamic' at the type already annotated on the literal.
defuncExp (ArrayLit elems ann@(Info ty) loc) = do
  elems' <- traverse defuncExp' elems
  pure (ArrayLit elems' ann loc, Dynamic ty)
-- A range: the three sub-expressions are transformed left to right with
-- 'defuncExp''.  The static value is 'Dynamic' at the result type
-- recorded in the application result.
defuncExp (AppExp (Range start mnext bound loc) res) = do
  range' <-
    Range
      <$> defuncExp' start
      <*> traverse defuncExp' mnext
      <*> traverse defuncExp' bound
  pure (AppExp (range' loc) res, Dynamic . appResType $ unInfo res)
-- A variable: the outcome depends on the static value bound to it.
defuncExp e@(Var qn (Info t) loc) =
  lookupVar (toStruct t) (qualLeaf qn) >>= resolve
  where
    -- If the variable refers to a dynamic function, we return its closure
    -- representation (i.e., a record expression capturing the free
    -- variables and a 'LambdaSV' static value) instead of the variable
    -- itself.
    resolve (DynamicFun closure _) = pure closure
    -- Intrinsic functions used as variables are eta-expanded, so we
    -- can get rid of them.
    resolve IntrinsicSV = do
      (pats, body, ret) <- etaExpand (typeOf e) e
      defuncExp $ Lambda pats body Nothing (Info (mempty, ret)) mempty
    -- Otherwise keep the variable, retyped according to its static value.
    resolve sv = pure (Var qn (Info (typeFromSV sv)) loc, sv)
-- A type ascription is retained only when 'orderZero' holds for the type
-- of the ascribed expression; otherwise the ascription is dropped and
-- just the inner expression is defunctionalized.
defuncExp (Ascript e0 tydecl loc)
  | orderZero (typeOf e0) = keep <$> defuncExp e0
  | otherwise = defuncExp e0
  where
    keep (e0', sv) = (Ascript e0' tydecl loc, sv)
-- A size coercion is treated like a type ascription: it is retained only
-- when 'orderZero' holds for the type of the coerced expression, and is
-- dropped otherwise.
defuncExp (AppExp (Coerce e0 tydecl loc) res)
  | orderZero (typeOf e0) = keep <$> defuncExp e0
  | otherwise = defuncExp e0
  where
    keep (e0', sv) = (AppExp (Coerce e0' tydecl loc) res, sv)
-- A pattern let-binding: the bound expression is transformed first, then
-- the body is transformed in an environment extended with the bindings
-- obtained by matching the pattern against the bound static value.
defuncExp (AppExp (LetPat sizes pat e1 e2 loc) (Info (AppRes t retext))) = do
  (e1', sv1) <- defuncExp e1
  (e2', sv2) <- localEnv (matchPatSV pat sv1) $ defuncExp e2
  -- To maintain any sizes going out of scope, we need to compute the
  -- old size substitution induced by retext and also apply it to the
  -- newly computed body type.
  let size_renaming = dimMapping' (typeOf e2) t
      -- Names without an entry in the mapping are left as they are.
      rename v = M.findWithDefault v v size_renaming
      t' = first (fmap rename) $ typeOf e2'
      letpat' = LetPat sizes (updatePat pat sv1) e1' e2' loc
  pure (AppExp letpat' (Info (AppRes t' retext)), sv2)
-- Local function bindings are not handled by this function; meeting one
-- is reported as an internal error naming the bound function.
defuncExp (AppExp (LetFun fname _ _ _) _) =
  error ("defuncExp: Unexpected LetFun: " ++ prettyName fname)
-- A conditional: condition and both branches are transformed in order.
-- The static value of the whole expression is the one computed for the
-- first branch; those of the condition and second branch are discarded.
defuncExp (AppExp (If cond tbranch fbranch loc) res) = do
  (cond', _) <- defuncExp cond
  (tbranch', sv) <- defuncExp tbranch
  (fbranch', _) <- defuncExp fbranch
  pure (AppExp (If cond' tbranch' fbranch' loc) res, sv)
-- Application of a variable whose tag is at most 'maxIntrinsicTag' to a
-- tuple literal: only the tuple components are transformed, and the
-- result is 'Dynamic' at the type of the original application.
defuncExp e@(AppExp (Apply f@(Var fname _ _) (TupLit args tuploc) d loc) res)
  | baseTag (qualLeaf fname) <= maxIntrinsicTag = do
    -- defuncSoacExp also works fine for non-SOACs.
    args' <- traverse defuncSoacExp args
    pure
      ( AppExp (Apply f (TupLit args' tuploc) d loc) res,
        Dynamic (typeOf e)
      )
-- Every other application is delegated to 'defuncApply', starting at 0.
defuncExp e@(AppExp Apply {} _) = defuncApply 0 e
-- Arithmetic negation: transform the operand and pass its static value
-- through unchanged.
defuncExp (Negate operand loc) =
  rewrap <$> defuncExp operand
  where
    rewrap (operand', sv) = (Negate operand' loc, sv)
-- Logical negation: transform the operand and pass its static value
-- through unchanged.
defuncExp (Not operand loc) =
  rewrap <$> defuncExp operand
  where
    rewrap (operand', sv) = (Not operand' loc, sv)
-- A lambda is handed to 'defuncFun' with an empty list as its first
-- argument.  The optional type expression and the aliasing component of
-- the annotation are ignored.
defuncExp (Lambda params body _ (Info (_, ret)) loc) =
  defuncFun [] params body ret loc
-- Operator sections are expected to be converted to lambda-expressions
-- by the monomorphizer, so they should no longer occur at this point.
-- Projection and index sections are likewise treated as unexpected.
-- Each case aborts with an internal error naming the offending kind of
-- section.
defuncExp OpSection {} = error "defuncExp: unexpected operator section."
defuncExp OpSectionLeft {} = error "defuncExp: unexpected operator section."
defuncExp OpSectionRight {} = error "defuncExp: unexpected operator section."
defuncExp ProjectSection {} = error "defuncExp: unexpected projection section."
defuncExp IndexSection {} = error "defuncExp: unexpected index section."
-- A loop: the initial value is transformed first, then the loop form,
-- and finally the body.  The body sees the bindings obtained by matching
-- the loop pattern against the initial static value, together with
-- whatever the loop form binds.
defuncExp (AppExp (DoLoop sparams pat e1 form e3 loc) res) = do
  (e1', sv1) <- defuncExp e1
  let pat_env = matchPatSV pat sv1
  (form', form_env) <- defuncForm pat_env form
  (e3', sv) <- localEnv (pat_env <> form_env) $ defuncExp e3
  pure (AppExp (DoLoop sparams pat e1' form' e3' loc) res, sv)
  where
    -- Transform the loop form and report the bindings it introduces.
    -- Only the 'While' condition is transformed under the loop-pattern
    -- environment, and it introduces no bindings of its own.
    defuncForm _ (For i e2) = do
      e2' <- defuncExp' e2
      pure (For i e2', identEnv i)
    defuncForm _ (ForIn pat2 e2) = do
      e2' <- defuncExp' e2
      pure (ForIn pat2 e2', envFromPat pat2)
    defuncForm pat_env (While e2) = do
      e2' <- localEnv pat_env $ defuncExp' e2
      pure (While e2', mempty)

    -- An identifier is bound as a 'Dynamic' value of its annotated type.
    identEnv (Ident v (Info tp) _) =
      M.singleton v (Binding Nothing (Dynamic tp))
-- Binary operator applications are not handled here; meeting one is
-- reported as an internal error showing the pretty-printed expression.
defuncExp e@(AppExp BinOp {} _) =
  error ("defuncExp: unexpected binary operator: " ++ pretty e)
-- A field projection: the outcome depends on the static value of the
-- projected expression.
defuncExp (Project field e0 ann@(Info ty) loc) = do
  (e0', sv0) <- defuncExp e0
  case sv0 of
    -- Statically known record: pick the static value of the field and
    -- retype the projection from it.
    RecordSV field_svs
      | Just sv <- lookup field field_svs ->
        pure (Project field e0' (Info (typeFromSV sv)) loc, sv)
      | otherwise -> error "Invalid record projection."
    -- Dynamic record: keep the annotated type.
    Dynamic _ -> pure (Project field e0' ann loc, Dynamic ty)
    _ -> error $ "Projection of an expression with static value " ++ show sv0
-- A let-with: the new value and the slice indices are transformed first.
-- The body is then transformed with the first identifier bound as a
-- 'Dynamic' value of its annotated type; the body's static value is the
-- result.
defuncExp (AppExp (LetWith id1 id2 idxs e1 body loc) res) = do
  e1' <- defuncExp' e1
  idxs' <- traverse defuncDimIndex idxs
  (body', sv) <- localEnv id1_env $ defuncExp body
  pure (AppExp (LetWith id1 id2 idxs' e1' body' loc) res, sv)
  where
    id1_env =
      M.singleton (identName id1) . Binding Nothing . Dynamic . unInfo $
        identType id1
-- Indexing: the indexed expression and then the slice indices are
-- transformed.  The static value is 'Dynamic' at the type of the
-- original, untransformed expression.
defuncExp expr@(AppExp (Index e0 idxs loc) res) = do
  index' <- Index <$> defuncExp' e0 <*> traverse defuncDimIndex idxs
  pure (AppExp (index' loc) res, Dynamic (typeOf expr))
-- An update: the first expression, the slice indices and the second
-- expression are transformed in that order.  The static value is the one
-- computed for the first expression.
defuncExp (Update e1 idxs e2 loc) = do
  (e1', sv) <- defuncExp e1
  idxs' <- traverse defuncDimIndex idxs
  e2' <- defuncExp' e2
  pure (Update e1' idxs' e2' loc, sv)

-- Note that the type of the updated record field may change here.  The
-- type checker does not permit this because of problems with type
-- inference, but it actually works fine.
defuncExp (RecordUpdate e1 fs e2 _ loc) = do
  (e1', sv1) <- defuncExp e1
  (e2', sv2) <- defuncExp e2
  pure
    ( RecordUpdate e1' fs e2' (Info (typeFromSV sv1)) loc,
      staticField sv1 sv2 fs
    )
  where
    -- Replace the static value found at the given field path with the
    -- new one.  The updated field is placed first among its siblings.
    staticField (RecordSV svs) new_sv (f : rest)
      | Just old_sv <- lookup f svs =
        RecordSV $
          (f, staticField old_sv new_sv rest) : [p | p <- svs, fst p /= f]
      | otherwise = error "Invalid record projection."
    -- A dynamic record is first expanded with 'svFromType' so that the
    -- path can be followed into it.
    staticField (Dynamic t@(Scalar Record {})) new_sv path@(_ : _) =
      staticField (svFromType t) new_sv path
    -- End of the path (or any other static value): take the new value.
    staticField _ new_sv _ = new_sv
-- Assertion: both subexpressions are defunctionalised; only the static
-- value of the second one is propagated.
defuncExp (Assert cond body desc loc) = do
  cond' <- defuncExp' cond
  (body', body_sv) <- defuncExp body
  pure (Assert cond' body' desc loc, body_sv)
-- Constructor application whose annotated type is a sum type.  The
-- payload expressions are defunctionalised, and the resulting 'SumSV'
-- also records the payload types of every *other* constructor, with
-- function types replaced by empty records.
defuncExp (Constr name es (Info (Scalar (Sum all_fs))) loc) = do
  (es', svs) <- unzip <$> traverse defuncExp es
  let other_constrs = M.toList (M.delete name (M.map (map defuncType) all_fs))
      sv = SumSV name svs other_constrs
  pure (Constr name es' (Info (typeFromSV sv)) loc, sv)
  where
    -- Apply 'defuncScalar' to the scalar part of a type.
    defuncType ::
      Monoid als =>
      TypeBase (DimDecl VName) als ->
      TypeBase (DimDecl VName) als
    defuncType t = case t of
      Array as u elem_t shape -> Array as u (defuncScalar elem_t) shape
      Scalar scalar_t -> Scalar (defuncScalar scalar_t)

    -- Recurse through records and sums; any function type becomes the
    -- empty record, and everything else is left as it is.
    defuncScalar ::
      Monoid als =>
      ScalarTypeBase (DimDecl VName) als ->
      ScalarTypeBase (DimDecl VName) als
    defuncScalar t = case t of
      Record fields -> Record (M.map defuncType fields)
      Arrow {} -> Record mempty
      Sum constrs -> Sum (M.map (map defuncType) constrs)
      Prim pt -> Prim pt
      TypeVar as u tn targs -> TypeVar as u tn targs
-- Any constructor whose annotation is not a sum type is reported as an
-- error, naming the constructor, the offending type and the location.
defuncExp (Constr name _ (Info t) loc) =
  error
    ( concat
        [ "Constructor ",
          pretty name,
          " given type ",
          pretty t,
          " at ",
          locStr loc
        ]
    )
-- Pattern match: the scrutinee is defunctionalised first and its
-- static value is handed to every case.  The static value of the
-- whole match is the one produced by the first case.
defuncExp (AppExp (Match scrutinee cases loc) res) = do
  (scrutinee', scrutinee_sv) <- defuncExp scrutinee
  results <- traverse (defuncCase scrutinee_sv) cases
  let cases' = fst <$> results
      result_sv = snd (NE.head results)
  pure (AppExp (Match scrutinee' cases' loc) res, result_sv)
-- Attributes are kept as they are; only the wrapped expression is
-- defunctionalised, and its static value is passed through.
defuncExp (Attr info e loc) =
  defuncExp e >>= \(e', sv) -> pure (Attr info e' loc, sv)

-- | Like 'defuncExp', but discards the static value and returns only
-- the residual expression.
defuncExp' :: Exp -> DefM Exp
defuncExp' e = fst <$> defuncExp e

-- | Defunctionalise an 'ExtExp': plain expressions go through
-- 'defuncExp', while extended lambdas are passed to 'defuncFun' with
-- an empty list of size parameters.
defuncExtExp :: ExtExp -> DefM (Exp, StaticVal)
defuncExtExp ext = case ext of
  ExtExp e -> defuncExp e
  ExtLambda pats body ret loc -> defuncFun [] pats body ret loc

-- | Defunctionalise a single case of a match.  The first argument is
-- the static value of the scrutinee: it is used both to update the
-- case pattern and to build the environment in which the case body is
-- defunctionalised.
defuncCase :: StaticVal -> Case -> DefM (Case, StaticVal)
defuncCase scrutinee_sv (CasePat pat body loc) = do
  (body', body_sv) <-
    localEnv (matchPatSV pat scrutinee_sv) (defuncExp body)
  pure (CasePat (updatePat pat scrutinee_sv) body' loc, body_sv)

-- | Defunctionalize the function argument to a SOAC.  Operator and
-- projection sections are returned untouched, parentheses and lambdas
-- are traversed, and any other expression of function type is
-- eta-expanded into a lambda whose body is then defunctionalized.
-- Expressions that are not of function type go through 'defuncExp''.
defuncSoacExp :: Exp -> DefM Exp
defuncSoacExp e = case e of
  OpSection {} -> pure e
  OpSectionLeft {} -> pure e
  OpSectionRight {} -> pure e
  ProjectSection {} -> pure e
  Parens inner loc -> (`Parens` loc) <$> defuncSoacExp inner
  Lambda params body decl tp loc -> do
    body' <- withParams params (defuncSoacExp body)
    pure (Lambda params body' decl tp loc)
  _
    | Scalar Arrow {} <- e_t -> do
      (pats, body, ret) <- etaExpand e_t e
      body' <- withParams pats (defuncExp' body)
      pure (Lambda pats body' Nothing (Info (mempty, ret)) mempty)
    | otherwise -> defuncExp' e
  where
    e_t = typeOf e
    -- Run an action with the given parameter patterns in the environment.
    withParams ps = localEnv (foldMap envFromPat ps)

-- | Eta-expand an expression of the given type.  One parameter is
-- produced for every argument of the (curried) function type, reusing
-- the parameter name from the type when it has one and generating a
-- fresh name otherwise.  Returns the parameter patterns, the
-- expression applied to all the parameters, and the final return type
-- with aliasing information removed.
etaExpand :: PatType -> Exp -> DefM ([Pat], Exp, StructRetType)
etaExpand e_t e = do
  let (params, ret) = getType (RetType [] e_t)
      param_ts = map snd params
  (pats, args) <- unzip <$> mapM mkParam params
  let -- Apply the function to one more argument; the remaining
      -- parameter types determine the type of the partial application.
      apply f (arg, arg_t, remaining_ts) =
        AppExp
          (Apply f arg (Info (diet arg_t, Nothing)) mempty)
          (Info (AppRes (foldFunType remaining_ts ret) []))
      applied =
        foldl' apply e (zip3 args param_ts (drop 1 (tails param_ts)))
  pure (pats, applied, second (const ()) ret)
  where
    -- Split a function type into its parameters and final return type.
    getType (RetType _ (Scalar (Arrow _ p t1 t2))) = ((p, t1) : more, final)
      where
        (more, final) = getType t2
    getType t = ([], t)

    -- Build the binding pattern and the matching variable expression
    -- for a single parameter.
    mkParam :: (PName, PatType) -> DefM (Pat, Exp)
    mkParam (p, t) = do
      x <- case p of
        Named v -> pure v
        Unnamed -> newNameFromString "x"
      pure (Id x (Info t) mempty, Var (qualName x) (Info t) mempty)

-- | Defunctionalize the index of a single array dimension by
-- defunctionalizing every expression it contains; static values are
-- discarded.
defuncDimIndex :: DimIndexBase Info VName -> DefM (DimIndexBase Info VName)
defuncDimIndex idx = case idx of
  DimFix i -> DimFix <$> defuncExp' i
  DimSlice ma mb mc -> do
    ma' <- onComponent ma
    mb' <- onComponent mb
    mc' <- onComponent mc
    pure (DimSlice ma' mb' mc')
  where
    -- Slice components are optional; absent ones stay absent.
    onComponent = traverse defuncExp'

-- | Defunctionalize a let-bound function, while preserving parameters
-- that have order 0 types (i.e., non-functional).
--
-- Parameters are processed from the left.  As long as the leading
-- parameter is of order zero it is kept, together with those size
-- parameters that it binds, and the remaining parameters are processed
-- recursively.  At the first parameter that is not of order zero, the
-- whole remaining function is handed to 'defuncFun' and no further
-- parameters are returned.  Returns the retained size parameters, the
-- retained parameters, the new body, and the static value.
defuncLet ::
  [VName] ->
  [Pat] ->
  Exp ->
  StructRetType ->
  DefM ([VName], [Pat], Exp, StaticVal)
defuncLet dims ps@(pat : pats) body ret@(RetType _ _)
  | patternOrderZero pat = do
    -- Take care to not include more size parameters than necessary.
    let pat_dim_names = patternDimNames pat
        (pat_dims, rest_dims) = partition (`S.member` pat_dim_names) dims
        env = envFromPat pat <> envFromDimNames pat_dims
    (rest_dims', pats', body', sv) <-
      localEnv env (defuncLet rest_dims pats body ret)
    closure <- defuncFun dims ps body ret mempty
    pure (pat_dims ++ rest_dims', pat : pats', body', DynamicFun closure sv)
  | otherwise = do
    (e, sv) <- defuncFun dims ps body ret mempty
    pure ([], [], e, sv)
defuncLet _ [] body (RetType _ rettype) = do
  (body', sv) <- defuncExp body
  pure ([], [], body', imposeType sv rettype)
  where
    -- Make a static value agree with the declared return type: dynamic
    -- values take the declared type directly, records are handled
    -- fieldwise (keeping only fields present on both sides), and
    -- everything else is left unchanged.
    imposeType sv t = case (sv, t) of
      (Dynamic {}, _) -> Dynamic (fromStruct t)
      (RecordSV fs1, Scalar (Record fs2)) ->
        RecordSV (M.toList (M.intersectionWith imposeType (M.fromList fs1) fs2))
      _ -> sv

-- | Rewrite the types of the given parameter patterns so that no
-- dimension is an 'AnyDim', collecting size names along the way.
--
-- The first argument is a set of size names that are considered
-- already bound.  The result is a pair of
--
--   * the collected size names (in 'S.toList' order), and
--
--   * the parameters with their pattern types rewritten.
--
-- A size name is collected when it comes from an 'AnyDim' (either the
-- name it carried or a freshly generated one), or when it is a
-- 'NamedDim' whose leaf is neither in the given set nor bound by one
-- of the patterns themselves.
sizesForAll :: MonadFreshNames m => S.Set VName -> [Pat] -> m ([VName], [Pat])
sizesForAll bound_sizes params = do
  -- The state is the set of size names collected so far.
  (params', sizes) <- runStateT (mapM (astMap tv) params) mempty
  return (S.toList sizes, params')
  where
    -- Names that must not be reported: the caller-provided ones plus
    -- every name bound by the parameter patterns.
    bound = bound_sizes <> foldMap patNames params

    -- Only pattern types are touched; aliases are left as they are.
    tv = identityMapper {mapOnPatType = bitraverse onDim pure}

    -- An anonymous size that already carries a name: reuse that name.
    onDim (AnyDim (Just v)) = do
      modify $ S.insert v
      pure $ NamedDim $ qualName v
    -- A completely anonymous size: invent a fresh name for it.
    onDim (AnyDim Nothing) = do
      v <- lift $ newVName "size"
      modify $ S.insert v
      pure $ NamedDim $ qualName v
    -- A named size is kept as is, but recorded if it is not yet bound.
    onDim (NamedDim d) = do
      unless (qualLeaf d `S.member` bound) $
        modify $ S.insert $ qualLeaf d
      pure $ NamedDim d
    -- Any other kind of dimension is left untouched.
    onDim d = pure d

-- | Turn a return type into a plain type.  If the return type has no
-- existential sizes the underlying type is returned unchanged;
-- otherwise every 'NamedDim' whose leaf name is one of the existential
-- sizes is replaced by @'AnyDim' Nothing@, and all other dimensions
-- are kept.
unRetType :: StructRetType -> StructType
unRetType (RetType [] t) = t
unRetType (RetType ext t) = first onDim t
  where
    onDim (NamedDim d) | qualLeaf d `elem` ext = AnyDim Nothing
    onDim d = d

-- | Defunctionalize an application expression at a given depth of application.
-- Calls to dynamic (first-order) functions are preserved as much as possible,
-- but a new lifted function is created if a dynamic function is only partially
-- applied.
--
-- The depth is incremented by one each time we descend into the
-- function position of an application, so it counts the applications
-- enclosing the current expression (the outermost application is
-- presumably processed at depth 0 -- confirm against the callers).
--
-- The result is the residual expression together with its static value.
-- Parenthesised expressions are looked through, and anything that is
-- neither an application, a variable, nor a parenthesised expression is
-- handed to 'defuncExp'.
defuncApply :: Int -> Exp -> DefM (Exp, StaticVal)
defuncApply depth e@(AppExp (Apply e1 e2 d loc) t@(Info (AppRes ret ext))) = do
  let (argtypes, _) = unfoldFunType ret
  (e1', sv1) <- defuncApply (depth + 1) e1
  (e2', sv2) <- defuncExp e2
  -- The application rebuilt from the defunctionalized parts, keeping
  -- the original result annotation.
  let e' = AppExp (Apply e1' e2' d loc) t
  case sv1 of
    LambdaSV pat e0_t e0 closure_env -> do
      let env' = matchPatSV pat sv2
          dims = mempty
      -- Defunctionalize the lambda body in the closure environment
      -- extended with the bindings obtained from the argument.
      (e0', sv) <-
        localNewEnv (env' <> closure_env) $
          defuncExtExp e0

      let closure_pat = buildEnvPat dims closure_env
          pat' = updatePat pat sv2

      globals <- asks fst

      -- Lift lambda to top-level function definition.  We put in
      -- a lot of effort to try to infer the uniqueness attributes
      -- of the lifted function, but this is ultimately all a sham
      -- and a hack.  There is some piece we're missing.
      let params = [closure_pat, pat']
          params_for_rettype = params ++ svParams sv1 ++ svParams sv2
          svParams (LambdaSV sv_pat _ _ _) = [sv_pat]
          svParams _ = []
          rettype = buildRetType closure_env params_for_rettype (unRetType e0_t) $ typeOf e0'

          already_bound =
            globals <> S.fromList dims
              <> S.map identName (foldMap patIdents params)

          -- Array sizes mentioned by the parameters that nothing
          -- binds yet; these become size parameters of the lifted
          -- function.
          more_dims =
            S.toList $
              S.filter (`S.notMember` already_bound) $
                foldMap patternArraySizes params

          -- Embed some information about the original function
          -- into the name of the lifted function, to make the
          -- result slightly more human-readable.
          liftedName i (Var f _ _) =
            "defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f)
          liftedName i (AppExp (Apply f _ _ _) _) =
            liftedName (i + 1) f
          liftedName _ _ = "defunc"

      -- Ensure that no parameter sizes are AnyDim.  The internaliser
      -- expects this.  This is easy, because they are all
      -- first-order.
      let bound_sizes = S.fromList (dims <> more_dims) <> globals
      (missing_dims, params') <- sizesForAll bound_sizes params

      fname <- newNameFromString $ liftedName (0 :: Int) e1
      liftValDec
        fname
        (RetType [] $ toStruct rettype)
        (dims ++ more_dims ++ missing_dims)
        params'
        e0'

      let t1 = toStruct $ typeOf e1'
          t2 = toStruct $ typeOf e2'
          fname' = qualName fname
          -- The lifted function, typed as taking first the closure
          -- (of type t1) and then the argument (of type t2).
          fname'' =
            Var
              fname'
              ( Info
                  ( Scalar $
                      Arrow mempty Unnamed (fromStruct t1) $
                        RetType [] $
                          Scalar $
                            Arrow mempty Unnamed (fromStruct t2) $
                              RetType [] rettype
                  )
              )
              loc

          -- FIXME: what if this application returns both a function
          -- and a value?
          callret
            | orderZero ret = AppRes ret ext
            | otherwise = AppRes rettype ext

      -- The residual expression applies the lifted function to the
      -- closure expression and then to the argument.
      return
        ( Parens
            ( AppExp
                ( Apply
                    ( AppExp
                        (Apply fname'' e1' (Info (Observe, Nothing)) loc)
                        ( Info $
                            AppRes
                              ( Scalar $
                                  Arrow mempty Unnamed (fromStruct t2) $
                                    RetType [] rettype
                              )
                              []
                        )
                    )
                    e2'
                    d
                    loc
                )
                (Info callret)
            )
            mempty,
          sv
        )

    -- If e1 is a dynamic function, we just leave the application in place,
    -- but we update the types since it may be partially applied or return
    -- a higher-order term.
    DynamicFun _ sv -> do
      let (argtypes', rettype) = dynamicFunType sv argtypes
          restype = foldFunType argtypes' (RetType [] rettype) `setAliases` aliases ret
          -- FIXME: what if this application returns both a function
          -- and a value?
          callret
            | orderZero ret = AppRes ret ext
            | otherwise = AppRes restype ext
          apply_e = AppExp (Apply e1' e2' d loc) (Info callret)
      return (apply_e, sv)
    -- Propagate the 'IntrinsicSV' until we reach the outermost application,
    -- where we construct a dynamic static value with the appropriate type.
    IntrinsicSV
      | depth == 0 ->
        -- If the intrinsic is fully applied, then we are done.
        -- Otherwise we need to eta-expand it and recursively
        -- defunctionalise. XXX: might it be better to simply
        -- eta-expand immediately any time we encounter a
        -- non-fully-applied intrinsic?
        if null argtypes
          then return (e', Dynamic $ typeOf e)
          else do
            (pats, body, tp) <- etaExpand (typeOf e') e'
            defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty
      | otherwise -> return (e', IntrinsicSV)
    _ ->
      error $
        "Application of an expression\n"
          ++ pretty e1
          ++ "\nthat is neither a static lambda "
          ++ "nor a dynamic function, but has static value:\n"
          ++ show sv1
defuncApply depth e@(Var qn (Info t) loc) = do
  let (argtypes, _) = unfoldFunType t
  sv <- lookupVar (toStruct t) (qualLeaf qn)

  case sv of
    DynamicFun _ _
      | fullyApplied sv depth -> do
        -- We still need to update the types in case the dynamic
        -- function returns a higher-order term.
        let (argtypes', rettype) = dynamicFunType sv argtypes
        return (Var qn (Info (foldFunType argtypes' $ RetType [] rettype)) loc, sv)
      | otherwise -> do
        -- Partially applied dynamic function: lift a new function
        -- that takes only the parameters supplied at this depth.
        fname <- newVName $ "dyn_" <> baseString (qualLeaf qn)
        let (pats, e0, sv') = liftDynFun (pretty qn) sv depth
            (argtypes', rettype) = dynamicFunType sv' argtypes
            dims' = mempty

        -- Ensure that no parameter sizes are AnyDim.  The internaliser
        -- expects this.  This is easy, because they are all
        -- first-order.
        globals <- asks fst
        let bound_sizes = S.fromList dims' <> globals
        (missing_dims, pats') <- sizesForAll bound_sizes pats

        liftValDec fname (RetType [] $ toStruct rettype) (dims' ++ missing_dims) pats' e0
        return
          ( Var
              (qualName fname)
              (Info (foldFunType argtypes' $ RetType [] $ fromStruct rettype))
              loc,
            sv'
          )
    -- Intrinsics are left alone here; the enclosing application
    -- equation deals with them.
    IntrinsicSV -> return (e, IntrinsicSV)
    -- Any other variable just has its type replaced by the one
    -- derived from its static value.
    _ -> return (Var qn (Info (typeFromSV sv)) loc, sv)
defuncApply depth (Parens e _) = defuncApply depth e
defuncApply _ expr = defuncExp expr

-- | Check if a 'StaticVal' and a given application depth corresponds
-- to a fully applied dynamic function.
--
-- Each 'DynamicFun' layer consumes one level of depth: a 'DynamicFun'
-- at depth 0 still expects an argument and is therefore not fully
-- applied.  Any static value that is not a 'DynamicFun' counts as
-- fully applied.  Note that a 'DynamicFun' at a negative depth matches
-- neither guard and so falls through to the final equation, yielding
-- 'True'.
fullyApplied :: StaticVal -> Int -> Bool
fullyApplied (DynamicFun _ sv) depth
  | depth == 0 = False
  | depth > 0 = fullyApplied sv (depth - 1)
fullyApplied _ _ = True

-- | Converts a dynamic function 'StaticVal' into a list of
-- parameters, a function body, and the appropriate static value for
-- applying the function at the given depth of partial application.
--
-- At depth 0 no parameters are produced, and the expression and static
-- value stored in the closure of the 'DynamicFun' are returned.  At a
-- positive depth the closure must hold a 'LambdaSV'; its pattern is
-- prepended to the parameters obtained by lifting the inner static
-- value at one depth less.
--
-- The 'String' argument is only used as a prefix of the error message
-- raised when the static value does not have the expected shape for
-- the given depth.
liftDynFun :: String -> StaticVal -> Int -> ([Pat], Exp, StaticVal)
liftDynFun _ (DynamicFun (e, sv) _) 0 = ([], e, sv)
liftDynFun s (DynamicFun clsr@(_, LambdaSV pat _ _ _) sv) d
  | d > 0 =
    let (pats, e', sv') = liftDynFun s sv (d - 1)
     in (pat : pats, e', DynamicFun clsr sv')
liftDynFun s sv d =
  error $
    s
      ++ " Tried to lift a StaticVal "
      -- The shown static value is truncated to 100 characters.
      ++ take 100 (show sv)
      ++ ", but expected a dynamic function.\n"
      ++ pretty d

-- | Converts a pattern to an environment that binds the individual names of the
-- pattern to their corresponding types wrapped in a 'Dynamic' static value.
--
-- Only 'Id' patterns contribute bindings.  Tuple, record and
-- constructor patterns combine the environments of their
-- sub-patterns; parentheses, attributes and type ascriptions are
-- looked through; wildcards and literals bind nothing.
envFromPat :: Pat -> Env
envFromPat pat = case pat of
  TuplePat ps _ -> foldMap envFromPat ps
  RecordPat fs _ -> foldMap (envFromPat . snd) fs
  PatParens p _ -> envFromPat p
  PatAttr _ p _ -> envFromPat p
  Id vn (Info t) _ -> M.singleton vn $ Binding Nothing $ Dynamic t
  Wildcard _ _ -> mempty
  PatAscription p _ _ -> envFromPat p
  PatLit {} -> mempty
  PatConstr _ _ ps _ -> foldMap envFromPat ps

-- | Build an environment in which each of the given names is bound
-- to a 'Dynamic' static value of signed 64-bit integer type.
envFromDimNames :: [VName] -> Env
envFromDimNames names = M.fromList [(name, i64Binding) | name <- names]
  where
    i64Binding = Binding Nothing (Dynamic (Scalar (Prim (Signed Int64))))

-- | Register a new top-level value binding built from the given
-- function name, return type, size parameters, value parameters and
-- body expression.  The binding carries no entry point, return type
-- declaration, documentation, attributes or source location.
liftValDec :: VName -> StructRetType -> [VName] -> [Pat] -> Exp -> DefM ()
liftValDec fname (RetType ret_dims ret) dims pats body =
  addValBind
    ValBind
      { valBindEntryPoint = Nothing,
        valBindName = fname,
        valBindRetDecl = Nothing,
        valBindRetType = Info (valbind_t, valbind_ext),
        valBindTypeParams = size_params,
        valBindParams = pats,
        valBindBody = body,
        valBindDoc = Nothing,
        valBindAttrs = mempty,
        valBindLocation = mempty
      }
  where
    size_params = [TypeParamDim dim mempty | dim <- dims]

    -- Names bound by the lifted definition itself: its size
    -- parameters and every identifier in its parameter patterns.
    bound_here = S.fromList dims <> S.map identName (foldMap patIdents pats)

    -- FIXME: this pass is still not correctly size-preserving, so
    -- forget those return sizes that we forgot to propagate along
    -- the way.  Hopefully the internaliser is conservative and
    -- will insert reshapes...
    --
    -- Concretely: any size name in the return type that is not bound
    -- here is turned into an existential size of the return type.
    unbound_ret_dims =
      filter (`S.notMember` bound_here) (S.toList (typeDimNames ret))
    rettype_st = RetType (unbound_ret_dims ++ ret_dims) ret

    -- With no value parameters, the existential sizes are stored in
    -- the second component instead of on the return type; otherwise
    -- they stay on the return type and the second component is empty.
    (valbind_t, valbind_ext) =
      if null pats
        then (RetType [] (retType rettype_st), retDims rettype_st)
        else (rettype_st, [])

-- | Construct the record pattern that unpacks a closure environment.
-- There is one field per environment entry, named after the
-- pretty-printed variable name and typed by 'typeFromSV' of the
-- entry's static value.  An entry whose name is among the given size
-- parameters is matched with a wildcard rather than bound, so that it
-- does not clash with the size parameter of the same name.
buildEnvPat :: [VName] -> Env -> Pat
buildEnvPat sizes env =
  RecordPat
    [ (fieldName vn, fieldPat vn (bindingSV binding))
      | (vn, binding) <- M.toList env
    ]
    mempty
  where
    fieldName vn = nameFromString (pretty vn)

    fieldPat vn sv
      | vn `elem` sizes = Wildcard field_type mempty
      | otherwise = Id vn field_type mempty
      where
        field_type = Info (typeFromSV sv)

-- | Given a closure environment, the parameter patterns, and both the
-- annotated (structural) type and the inferred type of a term,
-- construct the return type of that term.  Arrays in the result are
-- made 'Nonunique' when they alias a name that is bound in the
-- environment or in the patterns, unless that name is bound with a
-- unique type in the patterns.  This lets a lifted function return
-- unique arrays as long as they do not alias any of its parameters.
-- XXX: it is not clear that this is a sufficient property,
-- unfortunately.
buildRetType :: Env -> [Pat] -> StructType -> PatType -> PatType
buildRetType env pats = combine
  where
    -- All identifiers bound by the parameter patterns.
    pat_idents = foldMap patIdents pats

    -- Every name visible as a parameter of the lifted function.
    bound = S.fromList (M.keys env) <> S.map identName pat_idents

    -- Is the name bound by the patterns with a unique type?
    boundAsUnique v =
      case find ((== v) . identName) pat_idents of
        Just ident -> unique (unInfo (identType ident))
        Nothing -> False

    problematic v = v `S.member` bound && not (boundAsUnique v)

    -- Walk the annotated and inferred types in lockstep.  Records and
    -- sums are combined componentwise (keeping only the fields or
    -- constructors present in both); for a function annotation the
    -- inferred type is used; in every other case the annotated type
    -- is used, equipped with the aliases of the inferred type.
    combine (Scalar (Record fs_annot)) (Scalar (Record fs_got)) =
      Scalar (Record (M.intersectionWith combine fs_annot fs_got))
    combine (Scalar (Sum cs_annot)) (Scalar (Sum cs_got)) =
      Scalar (Sum (M.intersectionWith (zipWith combine) cs_annot cs_got))
    combine (Scalar Arrow {}) got =
      demote got
    combine annot got =
      demote (fromStruct annot `setAliases` aliases got)

    -- Strip uniqueness from arrays with a problematic alias,
    -- recursing through record fields only.
    demote t@Array {}
      | any (problematic . aliasVar) (aliases t) = setUniqueness t Nonunique
    demote (Scalar (Record fs)) = Scalar (Record (fmap demote fs))
    demote t = t

-- | The type of the value that *represents* a static value after
-- defunctionalisation (as opposed to the type of the original,
-- possibly higher-order, value).  Calling this on 'IntrinsicSV' is an
-- error.
typeFromSV :: StaticVal -> PatType
typeFromSV sv = case sv of
  Dynamic tp ->
    tp
  -- A lambda is represented by a record of its closure environment,
  -- with one field per captured variable.
  LambdaSV _ _ _ env ->
    Scalar . Record . M.fromList $
      [ (nameFromString (pretty vn), typeFromSV (bindingSV binding))
        | (vn, binding) <- M.toList env
      ]
  RecordSV fields ->
    Scalar . Record . M.fromList $
      [(field, typeFromSV field_sv) | (field, field_sv) <- fields]
  -- A dynamic function has the type of the static value paired with
  -- its residual expression.
  DynamicFun (_, closure_sv) _ ->
    typeFromSV closure_sv
  -- The constructor that is present gets the types of its payload's
  -- static values; the others keep the types already recorded.
  SumSV name payload others ->
    Scalar . Sum $ M.insert name (map typeFromSV payload) (M.fromList others)
  IntrinsicSV ->
    error "Tried to get the type from the static value of an intrinsic."

-- | Compute the parameter types and result type of a dynamic function
-- applied to arguments of the given (original) types.  One
-- 'DynamicFun' layer is peeled off per argument type, and that
-- argument type becomes a parameter type.  Once the arguments run out,
-- or the static value is no longer a 'DynamicFun', the result type is
-- 'typeFromSV' of whatever static value remains.
dynamicFunType :: StaticVal -> [PatType] -> ([PatType], PatType)
dynamicFunType (DynamicFun _ body_sv) (arg_t : arg_ts) =
  first (arg_t :) (dynamicFunType body_sv arg_ts)
dynamicFunType sv _ =
  ([], typeFromSV sv)

-- | Match a pattern against a static value, producing an environment
-- that maps each identifier of the pattern to the corresponding
-- component of the static value.  A 'Dynamic' static value that does
-- not match directly is first expanded with 'svFromType'.  It is an
-- error if the pattern and static value cannot be matched at all.
--
-- The order of the equations matters: an equation whose guards all
-- fail falls through to the later, more general ones.
matchPatSV :: PatBase Info VName -> StaticVal -> Env
matchPatSV (TuplePat ps _) (RecordSV fields) =
  -- Tuple components are matched positionally against the fields.
  mconcat [matchPatSV p sv | (p, (_, sv)) <- zip ps fields]
matchPatSV (RecordPat ps _) (RecordSV fields)
  -- Only applies when both sides have exactly the same field names.
  | map fst sorted_ps == map fst sorted_fields =
    mconcat (zipWith matchPatSV (map snd sorted_ps) (map snd sorted_fields))
  where
    sorted_ps = sortOn fst ps
    sorted_fields = sortOn fst fields
matchPatSV (PatParens inner _) sv =
  matchPatSV inner sv
matchPatSV (PatAttr _ inner _) sv =
  matchPatSV inner sv
matchPatSV (Id vn (Info t) _) sv
  -- When matching a pattern with a zero-order StaticVal, the type of
  -- the pattern wins out.  This is important when matching a
  -- nonunique pattern with a unique value.
  | orderZeroSV sv = bindTo (Dynamic t)
  | otherwise = bindTo sv
  where
    bindTo bound_sv = dim_env <> M.singleton vn (Binding Nothing bound_sv)
    -- The size names occurring in the pattern type are bound as well,
    -- each as a dynamic i64.
    dim_env = M.fromList [(dim, i64) | dim <- S.toList (typeDimNames t)]
    i64 = Binding Nothing (Dynamic (Scalar (Prim (Signed Int64))))
matchPatSV Wildcard {} _ =
  mempty
matchPatSV (PatAscription inner _ _) sv =
  matchPatSV inner sv
matchPatSV PatLit {} _ =
  mempty
matchPatSV (PatConstr c _ ps _) (SumSV present payload others)
  -- The pattern names the constructor that is statically present.
  | c == present =
    mconcat (zipWith matchPatSV ps payload)
  -- Otherwise fall back on the recorded types of that constructor.
  | Just ts <- lookup c others =
    mconcat (zipWith matchPatSV ps (map svFromType ts))
  | otherwise =
    error ("matchPatSV: missing constructor in type: " ++ pretty c)
matchPatSV (PatConstr c _ ps _) (Dynamic (Scalar (Sum constrs)))
  | Just ts <- M.lookup c constrs =
    mconcat (zipWith matchPatSV ps (map svFromType ts))
  | otherwise =
    error ("matchPatSV: missing constructor in type: " ++ pretty c)
matchPatSV pat (Dynamic t) =
  matchPatSV pat (svFromType t)
matchPatSV pat sv =
  error . concat $
    [ "Tried to match pattern ",
      pretty pat,
      " with static value ",
      show sv,
      "."
    ]

-- | Is the static value zero-order, i.e. either 'Dynamic' or a record
-- all of whose fields are themselves zero-order?
orderZeroSV :: StaticVal -> Bool
orderZeroSV sv = case sv of
  Dynamic _ -> True
  RecordSV fields -> and [orderZeroSV field_sv | (_, field_sv) <- fields]
  _ -> False

-- | Given a pattern and the static value for the defunctionalized argument,
-- update the pattern to reflect the changes in the types.
--
-- Sub-patterns are updated recursively against the corresponding
-- components of the static value.  Wildcards and constructor patterns
-- whose recorded type is already of order zero are returned unchanged.
-- A pattern and static value that cannot be reconciled are reported
-- with 'error'.
updatePat :: Pat -> StaticVal -> Pat
-- Tuple components are paired positionally with the record fields, in
-- the order in which the 'RecordSV' stores them.
updatePat (TuplePat ps loc) (RecordSV svs) =
  TuplePat (zipWith updatePat ps $ map snd svs) loc
-- Record fields are paired after sorting both sides by field name.
updatePat (RecordPat ps loc) (RecordSV svs)
  | ps' <- sortOn fst ps,
    svs' <- sortOn fst svs =
    RecordPat
      (zipWith (\(n, p) (_, sv) -> (n, updatePat p sv)) ps' svs')
      loc
updatePat (PatParens pat loc) sv =
  PatParens (updatePat pat sv) loc
updatePat (PatAttr attr pat loc) sv =
  PatAttr attr (updatePat pat sv) loc
updatePat (Id vn (Info tp) loc) sv =
  Id vn (Info $ comb tp (typeFromSV sv `setUniqueness` Nonunique)) loc
  where
    -- Preserve any original zeroth-order types: only function types
    -- are replaced by the type derived from the static value, while
    -- records and sums are merged component-wise.
    comb (Scalar Arrow {}) t2 = t2
    comb (Scalar (Record m1)) (Scalar (Record m2)) =
      Scalar $ Record $ M.intersectionWith comb m1 m2
    comb (Scalar (Sum m1)) (Scalar (Sum m2)) =
      Scalar $ Sum $ M.intersectionWith (zipWith comb) m1 m2
    comb t1 _ = t1 -- t1 must be array or prim.
updatePat pat@(Wildcard (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Wildcard (Info $ typeFromSV sv) loc
-- A type ascription is only kept when the ascribed type is of order
-- zero; otherwise it is dropped and just the inner pattern is updated.
updatePat (PatAscription pat tydecl loc) sv
  | orderZero . unInfo $ expandedType tydecl =
    PatAscription (updatePat pat sv) tydecl loc
  | otherwise = updatePat pat sv
updatePat p@PatLit {} _ = p
updatePat pat@(PatConstr c1 (Info t) ps loc) sv@(SumSV _ svs _)
  | orderZero t = pat
  | otherwise = PatConstr c1 (Info t') ps' loc
  where
    t' = typeFromSV sv `setUniqueness` Nonunique
    ps' = zipWith updatePat ps svs
updatePat (PatConstr c1 _ ps loc) (Dynamic t) =
  PatConstr c1 (Info t) ps loc
updatePat pat (Dynamic t)
  -- Unwrap a dynamic record so that it can be matched field by field.
  -- Only recurse when the unwrapping made progress: for a non-record
  -- type 'svFromType' returns @Dynamic t@ again, and recursing on that
  -- would loop forever instead of reaching the error case below.
  | sv@RecordSV {} <- svFromType t = updatePat pat sv
updatePat pat sv =
  error $
    "Tried to update pattern " ++ pretty pat
      ++ " to reflect the static value "
      ++ show sv

-- | Turn a type into a static value.  A record (or tuple) type becomes
-- a 'RecordSV' whose fields are converted recursively; any other type
-- is wrapped as a 'Dynamic' value.  This is how records and tuples
-- hidden inside 'Dynamic' static values get "unwrapped".
svFromType :: PatType -> StaticVal
svFromType t = case t of
  Scalar (Record fs) ->
    let fieldSVs = M.map svFromType fs
     in RecordSV (M.toList fieldSVs)
  _ -> Dynamic t

-- | Defunctionalize a top-level value binding.  The result is a
-- triple: the transformed binding, an environment that maps the name
-- of the binding to the static value of its transformed body, and a
-- flag that is 'True' exactly when that static value is a
-- 'DynamicFun' or a 'Dynamic'.
defuncValBind :: ValBind -> DefM (ValBind, Env, Bool)
-- A binding whose return type is a function is eta-expanded first; the
-- expanded binding (without return type declaration or doc comment) is
-- then defunctionalized by the general case below.
defuncValBind (ValBind entry name _ (Info (RetType _ rettype, retext)) tparams params body _ attrs loc)
  | Scalar Arrow {} <- rettype = do
    (extra_params, expanded_body, expanded_ret) <-
      etaExpand (fromStruct rettype) body
    let expanded =
          ValBind
            entry
            name
            Nothing
            (Info (expanded_ret, retext))
            tparams
            (params <> extra_params)
            expanded_body
            Nothing
            attrs
            loc
    defuncValBind expanded
defuncValBind valbind@(ValBind _ name _ (Info (RetType ret_dims rettype, retext)) tparams params body _ _ _) = do
  -- Only size parameters may remain at this point.
  when (any isTypeParam tparams) . error $
    prettyName name
      ++ " has type parameters, but the defunctionaliser expects a monomorphic input program."
  (dim_names, new_params, new_body, sv) <-
    defuncLet (map typeParamName tparams) params body (RetType ret_dims rettype)
  globals <- asks fst
  let -- Every name a size in the return type may legitimately refer to.
      in_scope =
        foldMap patNames new_params <> S.fromList dim_names <> globals
      -- FIXME: dubious that we cannot assume that all sizes in the
      -- body are in scope.  This is because when we insert
      -- applications of lifted functions, we don't properly update
      -- the types in the return type annotation.
      body_type = first (unboundToAny in_scope) (toStruct (typeOf new_body))
      new_rettype = combineTypeShapes rettype body_type
      -- A binding without parameters gets a non-unique return type
      -- with no existential sizes.
      ret_info
        | null new_params =
          (RetType [] (new_rettype `setUniqueness` Nonunique), retext)
        | otherwise =
          (RetType ret_dims new_rettype, retext)
      is_dynamic = case sv of
        DynamicFun {} -> True
        Dynamic {} -> True
        _ -> False
  (missing_dims, final_params) <- sizesForAll in_scope new_params

  pure
    ( valbind
        { valBindRetType = Info ret_info,
          valBindTypeParams =
            [TypeParamDim d mempty | d <- dim_names ++ missing_dims],
          valBindParams = final_params,
          valBindBody = new_body
        },
      M.singleton name $
        Binding
          (Just (first (map typeParamName) (valBindTypeScheme valbind)))
          sv,
      is_dynamic
    )
  where
    -- Replace a named size that is not among the bound names by an
    -- anonymous size that still remembers the name.
    unboundToAny bound d = case d of
      NamedDim v
        | qualLeaf v `S.notMember` bound -> AnyDim (Just (qualLeaf v))
      _ -> d

-- | Defunctionalize a list of top-level declarations, in order.  Each
-- transformed binding is emitted with 'addValBind', and the remaining
-- declarations are processed with the binding's environment added.
-- When 'defuncValBind' flags the binding as dynamic, its name is
-- additionally registered through 'isGlobal' for the rest of the list.
defuncVals :: [ValBind] -> DefM ()
defuncVals binds = case binds of
  [] -> pure ()
  vb : rest -> do
    (vb', env, dyn) <- defuncValBind vb
    addValBind vb'
    let markGlobal
          | dyn = isGlobal (valBindName vb')
          | otherwise = id
    localEnv env . markGlobal $ defuncVals rest

{-# NOINLINE transformProg #-}

-- | Transform a list of top-level value bindings. May produce new
-- lifted function definitions, which are placed in front of the
-- resulting list of declarations.
--
-- The defunctionalizer is run on the current name source, and the
-- name source it leaves behind is installed afterwards.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg decs = modifyNameSource step
  where
    step src = (transformed, src')
      where
        -- Bound lazily, like the 'let' pattern it replaces.
        ((), src', transformed) = runDefM src (defuncVals decs)