-- | Defunctionalization of typed, monomorphic Futhark programs without modules.
module Futhark.Internalise.Defunctionalise (transformProg) where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.List (partition, sortOn)
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.MonadFreshNames
import Futhark.Util (mapAccumLM)
import Language.Futhark
import Language.Futhark.Traversals
import Language.Futhark.TypeChecker.Types (Subst (..), applySubst)

-- | A static value stores additional information about the result of
-- defunctionalization of an expression, aside from the residual expression.
data StaticVal
  = -- | Carries only a type.  Presumably a value with no statically
    -- known function structure (assumption; confirm against users).
    Dynamic PatType
  | -- | The Env is the lexical closure of the lambda.
    LambdaSV Pat StructRetType Exp Env
  | -- | A static value for each named field.
    RecordSV [(Name, StaticVal)]
  | -- | The constructor that is actually present, plus
    -- the others that are not.
    SumSV Name [StaticVal] [(Name, [PatType])]
  | -- | The pair is the StaticVal and residual expression of this
    -- function as a whole, while the second StaticVal is its
    -- body. (Don't trust this too much, my understanding may have
    -- holes.)
    DynamicFun (Exp, StaticVal) StaticVal
  | -- | Produced by 'lookupVar' for a name that is absent from the
    -- environment and whose tag is at most 'maxIntrinsicTag'.
    IntrinsicSV
  | -- | A type together with a source location.  Presumably stands
    -- for a typed hole (assumption from the name; confirm).
    HoleSV PatType SrcLoc
  deriving (Show)

-- | An environment entry: a static value, optionally paired with a
-- type scheme.  The type is Just if this is a polymorphic binding
-- that must be instantiated (see 'lookupVar'); the list of names
-- holds the parameters that instantiation replaces.
data Binding = Binding (Maybe ([VName], StructType)) StaticVal
  deriving (Show)

-- | Project the static value out of a 'Binding', discarding any
-- attached type scheme.
bindingSV :: Binding -> StaticVal
bindingSV (Binding _ sv) = sv

-- | Environment mapping variable names to their associated static
-- value.  Each entry is a 'Binding', i.e. the static value together
-- with an optional type scheme for polymorphic bindings.
type Env = M.Map VName Binding

-- | Run a computation with the given bindings added to the current
-- environment.  The union is left-biased, so the new bindings shadow
-- existing ones with the same name.  The set of globals is untouched.
localEnv :: Env -> DefM a -> DefM a
localEnv env = local $ second (env <>)

-- | Run a computation in a "new" environment (for evaluating
-- closures).  Even then we still ram the global environment of
-- DynamicFuns in there: the bindings of the old environment whose
-- names are registered as globals are kept, and they take precedence
-- over entries of the supplied environment on a name clash.
localNewEnv :: Env -> DefM a -> DefM a
localNewEnv env = local $ \(globals, old_env) ->
  (globals, M.restrictKeys old_env globals <> env)

-- | Fetch the current environment (the second component of the
-- reader state).
askEnv :: DefM Env
askEnv = asks snd

-- | Run a computation with the given names added to the set of
-- globals (the first component of the reader state).
areGlobal :: [VName] -> DefM a -> DefM a
areGlobal vs = local $ first (S.fromList vs <>)

-- | Apply a size substitution to every dimension of a type.  A named
-- size with an entry in the map is replaced by the substituted name
-- or constant; all other sizes are left alone.
replaceTypeSizes ::
  M.Map VName SizeSubst ->
  TypeBase Size als ->
  TypeBase Size als
replaceTypeSizes substs = first onDim
  where
    onDim (NamedSize v) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') -> NamedSize v'
        Just (SubstConst d) -> ConstSize d
        Nothing -> NamedSize v
    -- Anything that is not a named size has nothing to substitute.
    onDim d = d

-- | Apply a size substitution throughout a static value: in the types
-- it carries, in residual expressions, and in closure environments.
--
-- If the substitution is empty the static value is returned as is.
-- The set of globals is only threaded through the recursion; it is
-- not inspected here.
replaceStaticValSizes ::
  S.Set VName ->
  M.Map VName SizeSubst ->
  StaticVal ->
  StaticVal
replaceStaticValSizes globals orig_substs sv =
  case sv of
    _ | M.null orig_substs -> sv
    LambdaSV param (RetType t_dims t) e closure_env ->
      -- Names bound in the closure environment must not be
      -- substituted in the parameter, return type or body.
      let substs =
            foldl' (flip M.delete) orig_substs (M.keys closure_env)
       in LambdaSV
            (onAST substs param)
            (RetType t_dims (replaceTypeSizes substs t))
            (onExp substs e)
            -- Using the full substitution here is intentional.
            (onEnv orig_substs closure_env)
    Dynamic t ->
      Dynamic $ replaceTypeSizes orig_substs t
    RecordSV fs ->
      RecordSV $ map (fmap (replaceStaticValSizes globals orig_substs)) fs
    SumSV c svs ts ->
      SumSV c (map (replaceStaticValSizes globals orig_substs) svs) $
        map (fmap $ map $ replaceTypeSizes orig_substs) ts
    DynamicFun (e, sv1) sv2 ->
      DynamicFun
        (onExp orig_substs e, replaceStaticValSizes globals orig_substs sv1)
        (replaceStaticValSizes globals orig_substs sv2)
    IntrinsicSV ->
      IntrinsicSV
    HoleSV t loc ->
      HoleSV t loc
  where
    -- AST mapper that substitutes in types, expressions and names.
    tv substs =
      identityMapper
        { mapOnPatType = pure . replaceTypeSizes substs,
          mapOnStructType = pure . replaceTypeSizes substs,
          mapOnExp = pure . onExp substs,
          mapOnName = pure . onName substs
        }

    -- Only name-for-name substitutions can rewrite a bare name; a
    -- constant substitution leaves the name untouched.
    onName substs v =
      case M.lookup v substs of
        Just (SubstNamed v') -> qualLeaf v'
        _ -> v

    onExp substs (Var v t loc) =
      case M.lookup (qualLeaf v) substs of
        Just (SubstNamed v') ->
          Var v' t loc
        Just (SubstConst d) ->
          -- A variable replaced by a constant becomes an i64 literal.
          Literal (SignedValue (Int64Value (fromIntegral d))) loc
        Nothing ->
          Var v (replaceTypeSizes substs <$> t) loc
    onExp substs (AppExp (Coerce e te loc) (Info (AppRes t ext))) =
      AppExp
        (Coerce (onExp substs e) te' loc)
        (Info (AppRes (replaceTypeSizes substs t) ext))
      where
        te' = onTypeExp substs te
    onExp substs (Lambda params e ret (Info (als, RetType t_dims t)) loc) =
      Lambda
        (map (onAST substs) params)
        (onExp substs e)
        ret
        (Info (als, RetType t_dims (replaceTypeSizes substs t)))
        loc
    onExp substs e = onAST substs e

    onTypeExpDim substs (SizeExp e loc) = SizeExp (onExp substs e) loc
    onTypeExpDim _ (SizeExpAny loc) = SizeExpAny loc

    onTypeArgExp substs (TypeArgExpSize d) =
      TypeArgExpSize (onTypeExpDim substs d)
    onTypeArgExp substs (TypeArgExpType te) =
      TypeArgExpType (onTypeExp substs te)

    onTypeExp substs (TEArray d te loc) =
      TEArray (onTypeExpDim substs d) (onTypeExp substs te) loc
    onTypeExp substs (TEUnique t loc) =
      TEUnique (onTypeExp substs t) loc
    onTypeExp substs (TEApply t1 t2 loc) =
      TEApply (onTypeExp substs t1) (onTypeArgExp substs t2) loc
    onTypeExp substs (TEArrow p t1 t2 loc) =
      TEArrow p (onTypeExp substs t1) (onTypeExp substs t2) loc
    onTypeExp substs (TETuple ts loc) =
      TETuple (map (onTypeExp substs) ts) loc
    onTypeExp substs (TERecord ts loc) =
      TERecord (map (fmap $ onTypeExp substs) ts) loc
    onTypeExp substs (TESum ts loc) =
      TESum (map (fmap $ map $ onTypeExp substs) ts) loc
    onTypeExp substs (TEDim dims t loc) =
      TEDim dims (onTypeExp substs t) loc
    onTypeExp substs (TEParens te loc) =
      TEParens (onTypeExp substs te) loc
    onTypeExp _ (TEVar v loc) =
      TEVar v loc

    -- Keys are unchanged; only the bindings are rewritten.
    onEnv substs = M.map (onBinding substs)

    onBinding substs (Binding t bsv) =
      Binding
        (second (replaceTypeSizes substs) <$> t)
        (replaceStaticValSizes globals substs bsv)

    onAST :: ASTMappable x => M.Map VName SizeSubst -> x -> x
    onAST substs = runIdentity . astMap (tv substs)

-- | Returns the defunctionalization environment restricted
-- to the given set of variable names and types.
--
-- A binding is kept only if its name is not a global and it occurs
-- in the given free-variable map.  The uniqueness of the type
-- recorded there is then pushed through the static value: for
-- 'Nonunique', every 'Dynamic' type inside is made nonunique.
restrictEnvTo :: FV -> DefM Env
restrictEnvTo (FV m) = asks restrict
  where
    restrict (globals, env) = M.mapMaybeWithKey keep env
      where
        keep k (Binding t sv) = do
          guard $ not $ k `S.member` globals
          -- Fails (dropping the binding) when k is not a free variable.
          u <- uniqueness <$> M.lookup k m
          Just $ Binding t $ restrict' u sv
    restrict' Nonunique (Dynamic t) =
      Dynamic $ t `setUniqueness` Nonunique
    restrict' _ (Dynamic t) =
      Dynamic t
    restrict' u (LambdaSV pat t e env) =
      LambdaSV pat t e $ M.map (restrict'' u) env
    restrict' u (RecordSV fields) =
      RecordSV $ map (fmap $ restrict' u) fields
    restrict' u (SumSV c svs fields) =
      SumSV c (map (restrict' u) svs) fields
    restrict' u (DynamicFun (e, sv1) sv2) =
      DynamicFun (e, restrict' u sv1) $ restrict' u sv2
    restrict' _ IntrinsicSV = IntrinsicSV
    restrict' _ (HoleSV t loc) = HoleSV t loc
    -- Same as restrict', lifted to bindings in a closure environment.
    restrict'' u (Binding t sv) = Binding t $ restrict' u sv

-- | Defunctionalization monad.  The Reader environment tracks both
-- the current Env as well as the set of globally defined dynamic
-- functions.  This is used to avoid unnecessarily large closure
-- environments.  The State holds the value bindings accumulated so
-- far (see 'addValBind') and the name source.
newtype DefM a
  = DefM (ReaderT (S.Set VName, Env) (State ([ValBind], VNameSource)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (S.Set VName, Env),
      MonadState ([ValBind], VNameSource)
    )

-- | The name source is the second component of the 'DefM' state; the
-- accumulated value bindings are left untouched.
instance MonadFreshNames DefM where
  putNameSource src = modify $ \(x, _) -> (x, src)
  getNameSource = gets snd

-- | Run a computation in the defunctionalization monad. Returns the result of
-- the computation, a new name source, and a list of lifted function
-- declarations.  The computation starts with an empty set of globals,
-- an empty environment and no value bindings.
runDefM :: VNameSource -> DefM a -> (a, VNameSource, [ValBind])
runDefM src (DefM m) =
  let (x, (vbs, src')) = runState (runReaderT m mempty) (mempty, src)
   in -- Bindings are consed onto the front as they are added, so
      -- reverse them to recover insertion order.
      (x, src', reverse vbs)

-- | Record a value binding by consing it onto the front of the
-- accumulated list in the state.
addValBind :: ValBind -> DefM ()
addValBind vb = modify $ first (vb :)

-- | Looks up the associated static value for a given name in the environment.
--
-- A polymorphic binding is instantiated at the given type via
-- 'instStaticVal'; a monomorphic binding is returned as is.
lookupVar :: StructType -> VName -> DefM StaticVal
lookupVar t x = do
  env <- askEnv
  case M.lookup x env of
    Just (Binding (Just (dims, sv_t)) sv) -> do
      globals <- asks fst
      instStaticVal globals dims t sv_t sv
    Just (Binding Nothing sv) ->
      pure sv
    Nothing
      -- If the variable is unknown, it may refer to the 'intrinsics'
      -- module, which we will have to treat specially.
      | baseTag x <= maxIntrinsicTag -> pure IntrinsicSV
      | otherwise ->
          -- Anything not in scope is going to be an existential size.
          pure $ Dynamic $ Scalar $ Prim $ Signed Int64

-- | Like freeInPat, but ignores sizes that are only found in
-- function types.  Collects the named sizes occurring in array
-- shapes and in size arguments of type constructors.
arraySizes :: StructType -> S.Set VName
arraySizes (Scalar Arrow {}) = mempty
arraySizes (Scalar (Record fields)) = foldMap arraySizes fields
arraySizes (Scalar (Sum cs)) = foldMap (foldMap arraySizes) cs
arraySizes (Scalar (TypeVar _ _ _ targs)) =
  foldMap f targs
  where
    f (TypeArgDim (NamedSize d) _) = S.singleton $ qualLeaf d
    -- Size arguments that are not named contribute nothing.
    f TypeArgDim {} = mempty
    f (TypeArgType t _) = arraySizes t
arraySizes (Scalar Prim {}) = mempty
arraySizes (Array _ _ shape t) =
  arraySizes (Scalar t) <> foldMap dimName (shapeDims shape)
  where
    dimName :: Size -> S.Set VName
    dimName (NamedSize qn) = S.singleton $ qualLeaf qn
    dimName _ = mempty

-- | The result of 'arraySizes' on the struct type of a pattern.
patternArraySizes :: Pat -> S.Set VName
patternArraySizes = arraySizes . patternStructType

-- | What a size name is to be replaced with.
data SizeSubst
  = -- | Replace with another (qualified) name.
    SubstNamed (QualName VName)
  | -- | Replace with a constant.
    SubstConst Int64
  deriving (Eq, Ord, Show)

-- | Walk two types in lockstep with 'matchDims' and record, for every
-- named size of the first type, the size found at the corresponding
-- position in the second type.
--
-- Only named sizes of the first type produce entries, and only when
-- the second type has a named or constant size in that position; all
-- other combinations are ignored.  If the same name occurs more than
-- once, the entry recorded last wins (plain 'M.insert').
dimMapping ::
  Monoid a =>
  TypeBase Size a ->
  TypeBase Size a ->
  M.Map VName SizeSubst
dimMapping t1 t2 = execState (matchDims onDims t1 t2) mempty
  where
    -- The first argument is the list of names that 'matchDims' hands
    -- to the callback (presumably the names bound inside the type at
    -- that point - confirm against 'matchDims').  A size of the second
    -- type that refers to one of them is never recorded.
    onDims bound d1 (NamedSize d2)
      | qualLeaf d2 `elem` bound = pure d1
    onDims _ (NamedSize d1) (NamedSize d2) = do
      modify $ M.insert (qualLeaf d1) $ SubstNamed d2
      pure $ NamedSize d1
    onDims _ (NamedSize d1) (ConstSize d2) = do
      modify $ M.insert (qualLeaf d1) $ SubstConst d2
      pure $ NamedSize d1
    onDims _ d _ = pure d

-- | Like 'dimMapping', but keeps only the entries that map a name to
-- another name; substitutions by a constant are dropped.
dimMapping' ::
  Monoid a =>
  TypeBase Size a ->
  TypeBase Size a ->
  M.Map VName VName
dimMapping' t1 t2 = M.mapMaybe namedOnly $ dimMapping t1 t2
  where
    namedOnly (SubstNamed d) = Just $ qualLeaf d
    namedOnly _ = Nothing

-- | The names in a static value that 'instStaticVal' should give
-- fresh names to.
--
-- For a 'LambdaSV' this is the result of 'freeInPat' on the parameter
-- pattern together with every identifier bound by that pattern whose
-- type is exactly @i64@ (and which could therefore be used as a
-- size).  The closure environment of the lambda is not inspected.
-- Records, sums and dynamic functions are traversed recursively; the
-- remaining static values contribute nothing.
sizesToRename :: StaticVal -> S.Set VName
sizesToRename (DynamicFun (_, sv1) sv2) =
  sizesToRename sv1 <> sizesToRename sv2
sizesToRename IntrinsicSV =
  mempty
sizesToRename HoleSV {} =
  mempty
sizesToRename Dynamic {} =
  mempty
sizesToRename (RecordSV fs) =
  foldMap (sizesToRename . snd) fs
sizesToRename (SumSV _ svs _) =
  foldMap sizesToRename svs
sizesToRename (LambdaSV param _ _ _) =
  freeInPat param
    <> S.map identName (S.filter couldBeSize $ patIdents param)
  where
    -- Only parameters of type i64 are considered.
    couldBeSize ident =
      unInfo (identType ident) == Scalar (Prim (Signed Int64))

-- When we instantiate a polymorphic StaticVal, we rename all the
-- sizes to avoid name conflicts later on.  This is a bit of a hack...
--
-- The arguments are, in order:
--
-- * the set of names that must never be renamed (it is also handed
--   on to 'replaceStaticValSizes');
--
-- * the size names of the binding being instantiated;
--
-- * the type that the renamed @sv_t@ is matched against with
--   'dimMapping' (presumably the type at the use site - confirm
--   against the caller);
--
-- * the type belonging to the static value (presumably the declared
--   type of the binding - confirm against the caller);
--
-- * the static value itself.
instStaticVal ::
  MonadFreshNames m =>
  S.Set VName ->
  [VName] ->
  StructType ->
  StructType ->
  StaticVal ->
  m StaticVal
instStaticVal globals dims t sv_t sv = do
  -- Fresh names for the binding's own sizes and for whatever
  -- 'sizesToRename' finds in the static value, except globals.
  fresh_substs <-
    mkSubsts . filter (`S.notMember` globals) . S.toList $
      S.fromList dims <> sizesToRename sv

  let dims' = map (onName fresh_substs) dims
      isDim k _ = k `elem` dims'
      -- What the (renamed) sizes of the binding are instantiated with,
      -- found by matching the renamed type of the static value against
      -- 't'.  Only entries for the renamed 'dims' are kept.
      dim_substs =
        M.filterWithKey isDim $
          dimMapping (replaceTypeSizes fresh_substs sv_t) t
      -- Send a fresh name straight on to its instantiation, if it has
      -- one.
      replace (SubstNamed k) =
        fromMaybe (SubstNamed k) $ M.lookup (qualLeaf k) dim_substs
      replace k = k
      -- Map union is left-biased, so on a common key the entries
      -- derived from 'fresh_substs' take precedence.
      substs = M.map replace fresh_substs <> dim_substs

  pure $ replaceStaticValSizes globals substs sv
  where
    -- Pair every name with a freshly generated replacement.
    mkSubsts names =
      M.fromList . zip names . map (SubstNamed . qualName)
        <$> mapM newName names

    -- Look a name up in a substitution map; names without an entry,
    -- or replaced by a constant, are returned unchanged.
    onName m v =
      case M.lookup v m of
        Just (SubstNamed v') -> qualLeaf v'
        _ -> v

-- | Defunctionalize a function given by its type parameters,
-- parameter patterns, body, return type and location.
--
-- The result is a record literal holding the closed-over variables,
-- together with a 'LambdaSV' that carries the first parameter, the
-- (possibly rewritten) return type and body, and the environment of
-- the closed-over variables.  When there is more than one parameter,
-- the remaining ones are wrapped in a new 'Lambda' that becomes the
-- body.
--
-- Calls 'error' if the list of parameter patterns is empty.
defuncFun ::
  [VName] ->
  [Pat] ->
  Exp ->
  StructRetType ->
  SrcLoc ->
  DefM (Exp, StaticVal)
defuncFun tparams pats e0 ret loc = do
  -- Extract the first parameter of the lambda and "push" the
  -- remaining ones (if there are any) into the body of the lambda.
  let (pat, ret', e0') = case pats of
        [] -> error "Received a lambda with no parameters."
        [pat'] -> (pat', ret, e0)
        (pat' : pats') ->
          ( pat',
            RetType [] $ funType pats' ret,
            Lambda pats' e0 Nothing (Info (mempty, ret)) loc
          )

  -- Construct a record literal that closes over the environment of
  -- the lambda.  Closed-over 'DynamicFun's are converted to their
  -- closure representation.
  let used =
        freeInExp (Lambda pats e0 Nothing (Info (mempty, ret)) loc)
          `freeWithout` S.fromList tparams
  used_env <- restrictEnvTo used

  -- The closure parts that are sizes are proactively turned into size
  -- parameters.
  let sizes_of_arrays =
        foldMap (arraySizes . toStruct . typeFromSV . bindingSV) used_env
          <> patternArraySizes pat
      notSize = not . (`S.member` sizes_of_arrays)
      (fields, env) =
        second M.fromList
          . unzip
          . map closureFromDynamicFun
          . filter (notSize . fst)
          $ M.toList used_env

  pure
    ( RecordLit fields loc,
      LambdaSV pat ret' e0' env
    )
  where
    -- Turn one closed-over binding into a field of the closure record
    -- and the binding stored in the environment of the 'LambdaSV'.
    -- The field is named after the pretty-printed variable name.
    --
    -- A dynamic function contributes its closure expression and the
    -- static value paired with it ...
    closureFromDynamicFun (vn, Binding _ (DynamicFun (clsr_env, sv) _)) =
      let name = nameFromString $ prettyString vn
       in ( RecordFieldExplicit name clsr_env mempty,
            (vn, Binding Nothing sv)
          )
    -- ... while anything else is referenced as a variable, typed
    -- according to its static value.
    closureFromDynamicFun (vn, Binding _ sv) =
      let name = nameFromString $ prettyString vn
          tp' = typeFromSV sv
       in ( RecordFieldExplicit
              name
              (Var (qualName vn) (Info tp') mempty)
              mempty,
            (vn, Binding Nothing sv)
          )

-- | Defunctionalization of an expression. Returns the residual expression and
-- the associated static value in the defunctionalization monad.
defuncExp :: Exp -> DefM (Exp, StaticVal)
defuncExp :: Exp -> DefM (Exp, StaticVal)
defuncExp e :: Exp
e@Literal {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp
e, PatType -> StaticVal
Dynamic forall a b. (a -> b) -> a -> b
$ Exp -> PatType
typeOf Exp
e)
defuncExp e :: Exp
e@IntLit {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp
e, PatType -> StaticVal
Dynamic forall a b. (a -> b) -> a -> b
$ Exp -> PatType
typeOf Exp
e)
defuncExp e :: Exp
e@FloatLit {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp
e, PatType -> StaticVal
Dynamic forall a b. (a -> b) -> a -> b
$ Exp -> PatType
typeOf Exp
e)
defuncExp e :: Exp
e@StringLit {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp
e, PatType -> StaticVal
Dynamic forall a b. (a -> b) -> a -> b
$ Exp -> PatType
typeOf Exp
e)
defuncExp (Parens Exp
e SrcLoc
loc) = do
  (Exp
e', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. ExpBase f vn -> SrcLoc -> ExpBase f vn
Parens Exp
e' SrcLoc
loc, StaticVal
sv)
defuncExp (QualParens (QualName VName, SrcLoc)
qn Exp
e SrcLoc
loc) = do
  (Exp
e', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
(QualName vn, SrcLoc) -> ExpBase f vn -> SrcLoc -> ExpBase f vn
QualParens (QualName VName, SrcLoc)
qn Exp
e' SrcLoc
loc, StaticVal
sv)
defuncExp (TupLit [Exp]
es SrcLoc
loc) = do
  ([Exp]
es', [StaticVal]
svs) <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> m (b, c)) -> [a] -> m ([b], [c])
mapAndUnzipM Exp -> DefM (Exp, StaticVal)
defuncExp [Exp]
es
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. [ExpBase f vn] -> SrcLoc -> ExpBase f vn
TupLit [Exp]
es' SrcLoc
loc, [(Name, StaticVal)] -> StaticVal
RecordSV forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [Name]
tupleFieldNames [StaticVal]
svs)
defuncExp (RecordLit [FieldBase Info VName]
fs SrcLoc
loc) = do
  ([FieldBase Info VName]
fs', [(Name, StaticVal)]
names_svs) <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> m (b, c)) -> [a] -> m ([b], [c])
mapAndUnzipM FieldBase Info VName
-> DefM (FieldBase Info VName, (Name, StaticVal))
defuncField [FieldBase Info VName]
fs
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. [FieldBase f vn] -> SrcLoc -> ExpBase f vn
RecordLit [FieldBase Info VName]
fs' SrcLoc
loc, [(Name, StaticVal)] -> StaticVal
RecordSV [(Name, StaticVal)]
names_svs)
  where
    defuncField :: FieldBase Info VName
-> DefM (FieldBase Info VName, (Name, StaticVal))
defuncField (RecordFieldExplicit Name
vn Exp
e SrcLoc
loc') = do
      (Exp
e', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
Name -> ExpBase f vn -> SrcLoc -> FieldBase f vn
RecordFieldExplicit Name
vn Exp
e' SrcLoc
loc', (Name
vn, StaticVal
sv))
    defuncField (RecordFieldImplicit VName
vn (Info PatType
t) SrcLoc
loc') = do
      StaticVal
sv <- StructType -> VName -> DefM StaticVal
lookupVar (forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
t) VName
vn
      case StaticVal
sv of
        -- If the implicit field refers to a dynamic function, we
        -- convert it to an explicit field with a record closing over
        -- the environment and bind the corresponding static value.
        DynamicFun (Exp
e, StaticVal
sv') StaticVal
_ ->
          let vn' :: Name
vn' = VName -> Name
baseName VName
vn
           in forall (f :: * -> *) a. Applicative f => a -> f a
pure
                ( forall (f :: * -> *) vn.
Name -> ExpBase f vn -> SrcLoc -> FieldBase f vn
RecordFieldExplicit Name
vn' Exp
e SrcLoc
loc',
                  (Name
vn', StaticVal
sv')
                )
        -- The field may refer to a functional expression, so we get the
        -- type from the static value and not the one from the AST.
        StaticVal
_ ->
          let tp :: Info PatType
tp = forall a. a -> Info a
Info forall a b. (a -> b) -> a -> b
$ StaticVal -> PatType
typeFromSV StaticVal
sv
           in forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
vn -> f PatType -> SrcLoc -> FieldBase f vn
RecordFieldImplicit VName
vn Info PatType
tp SrcLoc
loc', (VName -> Name
baseName VName
vn, StaticVal
sv))
defuncExp (ArrayLit [Exp]
es t :: Info PatType
t@(Info PatType
t') SrcLoc
loc) = do
  [Exp]
es' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Exp -> DefM Exp
defuncExp' [Exp]
es
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
[ExpBase f vn] -> f PatType -> SrcLoc -> ExpBase f vn
ArrayLit [Exp]
es' Info PatType
t SrcLoc
loc, PatType -> StaticVal
Dynamic PatType
t')
defuncExp (AppExp (Range Exp
e1 Maybe Exp
me Inclusiveness Exp
incl SrcLoc
loc) Info AppRes
res) = do
  Exp
e1' <- Exp -> DefM Exp
defuncExp' Exp
e1
  Maybe Exp
me' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Exp -> DefM Exp
defuncExp' Maybe Exp
me
  Inclusiveness Exp
incl' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Exp -> DefM Exp
defuncExp' Inclusiveness Exp
incl
  forall (f :: * -> *) a. Applicative f => a -> f a
pure
    ( forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
ExpBase f vn
-> Maybe (ExpBase f vn)
-> Inclusiveness (ExpBase f vn)
-> SrcLoc
-> AppExpBase f vn
Range Exp
e1' Maybe Exp
me' Inclusiveness Exp
incl' SrcLoc
loc) Info AppRes
res,
      PatType -> StaticVal
Dynamic forall a b. (a -> b) -> a -> b
$ AppRes -> PatType
appResType forall a b. (a -> b) -> a -> b
$ forall a. Info a -> a
unInfo Info AppRes
res
    )
defuncExp e :: Exp
e@(Var QualName VName
qn (Info PatType
t) SrcLoc
loc) = do
  StaticVal
sv <- StructType -> VName -> DefM StaticVal
lookupVar (forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
t) (forall vn. QualName vn -> vn
qualLeaf QualName VName
qn)
  case StaticVal
sv of
    -- If the variable refers to a dynamic function, we return its closure
    -- representation (i.e., a record expression capturing the free variables
    -- and a 'LambdaSV' static value) instead of the variable itself.
    DynamicFun (Exp, StaticVal)
closure StaticVal
_ -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp, StaticVal)
closure
    -- Intrinsic functions used as variables are eta-expanded, so we
    -- can get rid of them.
    StaticVal
IntrinsicSV -> do
      ([Pat]
pats, Exp
body, StructRetType
tp) <- PatRetType -> Exp -> DefM ([Pat], Exp, StructRetType)
etaExpand (forall dim as. [VName] -> TypeBase dim as -> RetTypeBase dim as
RetType [] (Exp -> PatType
typeOf Exp
e)) Exp
e
      Exp -> DefM (Exp, StaticVal)
defuncExp forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) vn.
[PatBase f vn]
-> ExpBase f vn
-> Maybe (TypeExp f vn)
-> f (Set Alias, StructRetType)
-> SrcLoc
-> ExpBase f vn
Lambda [Pat]
pats Exp
body forall a. Maybe a
Nothing (forall a. a -> Info a
Info (forall a. Monoid a => a
mempty, StructRetType
tp)) forall a. Monoid a => a
mempty
    HoleSV PatType
_ SrcLoc
hole_loc ->
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. f PatType -> SrcLoc -> ExpBase f vn
Hole (forall a. a -> Info a
Info PatType
t) SrcLoc
hole_loc, StaticVal
sv)
    StaticVal
_ ->
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
QualName vn -> f PatType -> SrcLoc -> ExpBase f vn
Var QualName VName
qn (forall a. a -> Info a
Info (StaticVal -> PatType
typeFromSV StaticVal
sv)) SrcLoc
loc, StaticVal
sv)
defuncExp (Hole (Info PatType
t) SrcLoc
loc) =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. f PatType -> SrcLoc -> ExpBase f vn
Hole (forall a. a -> Info a
Info PatType
t) SrcLoc
loc, PatType -> SrcLoc -> StaticVal
HoleSV PatType
t SrcLoc
loc)
defuncExp (Ascript Exp
e0 TypeExp Info VName
tydecl SrcLoc
loc)
  | forall dim as. TypeBase dim as -> Bool
orderZero (Exp -> PatType
typeOf Exp
e0) = do
      (Exp
e0', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e0
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
ExpBase f vn -> TypeExp f vn -> SrcLoc -> ExpBase f vn
Ascript Exp
e0' TypeExp Info VName
tydecl SrcLoc
loc, StaticVal
sv)
  | Bool
otherwise = Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e0
defuncExp (AppExp (Coerce Exp
e0 TypeExp Info VName
tydecl SrcLoc
loc) Info AppRes
res)
  | forall dim as. TypeBase dim as -> Bool
orderZero (Exp -> PatType
typeOf Exp
e0) = do
      (Exp
e0', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e0
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
ExpBase f vn -> TypeExp f vn -> SrcLoc -> AppExpBase f vn
Coerce Exp
e0' TypeExp Info VName
tydecl SrcLoc
loc) Info AppRes
res, StaticVal
sv)
  | Bool
otherwise = Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e0
defuncExp (AppExp (LetPat [SizeBinder VName]
sizes Pat
pat Exp
e1 Exp
e2 SrcLoc
loc) (Info (AppRes PatType
t [VName]
retext))) = do
  (Exp
e1', StaticVal
sv1) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e1
  let env :: Env
env = Pat -> StaticVal -> Env
alwaysMatchPatSV Pat
pat StaticVal
sv1
      pat' :: Pat
pat' = Pat -> StaticVal -> Pat
updatePat Pat
pat StaticVal
sv1
  (Exp
e2', StaticVal
sv2) <- forall a. Env -> DefM a -> DefM a
localEnv Env
env forall a b. (a -> b) -> a -> b
$ Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e2
  -- To maintain any sizes going out of scope, we need to compute the
  -- old size substitution induced by retext and also apply it to the
  -- newly computed body type.
  let mapping :: Map VName VName
mapping = forall a.
Monoid a =>
TypeBase Size a -> TypeBase Size a -> Map VName VName
dimMapping' (Exp -> PatType
typeOf Exp
e2) PatType
t
      subst :: VName -> VName
subst VName
v = forall a. a -> Maybe a -> a
fromMaybe VName
v forall a b. (a -> b) -> a -> b
$ forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName VName
mapping
      mapper :: ASTMapper Identity
mapper = forall (m :: * -> *). Monad m => ASTMapper m
identityMapper {mapOnName :: VName -> Identity VName
mapOnName = forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> VName
subst}
      t' :: PatType
t' = forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first (forall a. Identity a -> a
runIdentity forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall x (m :: * -> *).
(ASTMappable x, Monad m) =>
ASTMapper m -> x -> m x
astMap ASTMapper Identity
mapper) forall a b. (a -> b) -> a -> b
$ Exp -> PatType
typeOf Exp
e2'
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
[SizeBinder vn]
-> PatBase f vn
-> ExpBase f vn
-> ExpBase f vn
-> SrcLoc
-> AppExpBase f vn
LetPat [SizeBinder VName]
sizes Pat
pat' Exp
e1' Exp
e2' SrcLoc
loc) (forall a. a -> Info a
Info (PatType -> [VName] -> AppRes
AppRes PatType
t' [VName]
retext)), StaticVal
sv2)
defuncExp (AppExp (LetFun VName
vn ([TypeParamBase VName], [Pat], Maybe (TypeExp Info VName),
 Info StructRetType, Exp)
_ Exp
_ SrcLoc
_) Info AppRes
_) =
  forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$ [Char]
"defuncExp: Unexpected LetFun: " forall a. [a] -> [a] -> [a]
++ forall a. Show a => a -> [Char]
show VName
vn
defuncExp (AppExp (If Exp
e1 Exp
e2 Exp
e3 SrcLoc
loc) Info AppRes
res) = do
  (Exp
e1', StaticVal
_) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e1
  (Exp
e2', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e2
  (Exp
e3', StaticVal
_) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e3
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
ExpBase f vn
-> ExpBase f vn -> ExpBase f vn -> SrcLoc -> AppExpBase f vn
If Exp
e1' Exp
e2' Exp
e3' SrcLoc
loc) Info AppRes
res, StaticVal
sv)
defuncExp (AppExp (Apply Exp
f NonEmpty (Info (Diet, Maybe VName), Exp)
args SrcLoc
loc) (Info AppRes
appres)) =
  Exp
-> NonEmpty ((Diet, Maybe VName), Exp)
-> AppRes
-> SrcLoc
-> DefM (Exp, StaticVal)
defuncApply Exp
f (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first forall a. Info a -> a
unInfo) NonEmpty (Info (Diet, Maybe VName), Exp)
args) AppRes
appres SrcLoc
loc
defuncExp (Negate Exp
e0 SrcLoc
loc) = do
  (Exp
e0', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e0
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. ExpBase f vn -> SrcLoc -> ExpBase f vn
Negate Exp
e0' SrcLoc
loc, StaticVal
sv)
defuncExp (Not Exp
e0 SrcLoc
loc) = do
  (Exp
e0', StaticVal
sv) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e0
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. ExpBase f vn -> SrcLoc -> ExpBase f vn
Not Exp
e0' SrcLoc
loc, StaticVal
sv)
defuncExp (Lambda [Pat]
pats Exp
e0 Maybe (TypeExp Info VName)
_ (Info (Set Alias
_, StructRetType
ret)) SrcLoc
loc) =
  [VName]
-> [Pat] -> Exp -> StructRetType -> SrcLoc -> DefM (Exp, StaticVal)
defuncFun [] [Pat]
pats Exp
e0 StructRetType
ret SrcLoc
loc
-- Operator sections are expected to be converted to lambda-expressions
-- by the monomorphizer, so they should no longer occur at this point.
defuncExp OpSection {} = forall a. HasCallStack => [Char] -> a
error [Char]
"defuncExp: unexpected operator section."
defuncExp OpSectionLeft {} = forall a. HasCallStack => [Char] -> a
error [Char]
"defuncExp: unexpected operator section."
defuncExp OpSectionRight {} = forall a. HasCallStack => [Char] -> a
error [Char]
"defuncExp: unexpected operator section."
defuncExp ProjectSection {} = forall a. HasCallStack => [Char] -> a
error [Char]
"defuncExp: unexpected projection section."
defuncExp IndexSection {} = forall a. HasCallStack => [Char] -> a
error [Char]
"defuncExp: unexpected projection section."
defuncExp (AppExp (DoLoop [VName]
sparams Pat
pat Exp
e1 LoopFormBase Info VName
form Exp
e3 SrcLoc
loc) Info AppRes
res) = do
  (Exp
e1', StaticVal
sv1) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e1
  let env1 :: Env
env1 = Pat -> StaticVal -> Env
alwaysMatchPatSV Pat
pat StaticVal
sv1
  (LoopFormBase Info VName
form', Env
env2) <- case LoopFormBase Info VName
form of
    For IdentBase Info VName
v Exp
e2 -> do
      Exp
e2' <- Exp -> DefM Exp
defuncExp' Exp
e2
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
IdentBase f vn -> ExpBase f vn -> LoopFormBase f vn
For IdentBase Info VName
v Exp
e2', forall {k}. IdentBase Info k -> Map k Binding
envFromIdent IdentBase Info VName
v)
    ForIn Pat
pat2 Exp
e2 -> do
      Exp
e2' <- Exp -> DefM Exp
defuncExp' Exp
e2
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
PatBase f vn -> ExpBase f vn -> LoopFormBase f vn
ForIn Pat
pat2 Exp
e2', Pat -> Env
envFromPat Pat
pat2)
    While Exp
e2 -> do
      Exp
e2' <- forall a. Env -> DefM a -> DefM a
localEnv Env
env1 forall a b. (a -> b) -> a -> b
$ Exp -> DefM Exp
defuncExp' Exp
e2
      forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn. ExpBase f vn -> LoopFormBase f vn
While Exp
e2', forall a. Monoid a => a
mempty)
  (Exp
e3', StaticVal
sv) <- forall a. Env -> DefM a -> DefM a
localEnv (Env
env1 forall a. Semigroup a => a -> a -> a
<> Env
env2) forall a b. (a -> b) -> a -> b
$ Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e3
  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall (f :: * -> *) vn.
AppExpBase f vn -> f AppRes -> ExpBase f vn
AppExp (forall (f :: * -> *) vn.
[VName]
-> PatBase f vn
-> ExpBase f vn
-> LoopFormBase f vn
-> ExpBase f vn
-> SrcLoc
-> AppExpBase f vn
DoLoop [VName]
sparams Pat
pat Exp
e1' LoopFormBase Info VName
form' Exp
e3' SrcLoc
loc) Info AppRes
res, StaticVal
sv)
  where
    envFromIdent :: IdentBase Info k -> Map k Binding
envFromIdent (Ident k
vn (Info PatType
tp) SrcLoc
_) =
      forall k a. k -> a -> Map k a
M.singleton k
vn forall a b. (a -> b) -> a -> b
$ Maybe ([VName], StructType) -> StaticVal -> Binding
Binding forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ PatType -> StaticVal
Dynamic PatType
tp
defuncExp e :: Exp
e@(AppExp BinOp {} Info AppRes
_) =
  forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$ [Char]
"defuncExp: unexpected binary operator: " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString Exp
e
-- Projection of a record field.  If the record has a 'RecordSV', the
-- static value of the selected field is returned and the projection
-- is annotated with the type computed from that static value.  For
-- 'Dynamic' and 'HoleSV' records the original type annotation is
-- kept.  Any other static value is an internal error.
defuncExp (Project fname rec_e ann@(Info ann_t) loc) = do
  (rec_e', rec_sv) <- defuncExp rec_e
  let sameType sv = pure (Project fname rec_e' ann loc, sv)
  case rec_sv of
    RecordSV field_svs ->
      case lookup fname field_svs of
        Nothing -> error "Invalid record projection."
        Just field_sv ->
          pure
            ( Project fname rec_e' (Info (typeFromSV field_sv)) loc,
              field_sv
            )
    Dynamic {} -> sameType (Dynamic ann_t)
    HoleSV _ hole_loc -> sameType (HoleSV ann_t hole_loc)
    _ ->
      error $
        "Projection of an expression with static value " ++ show rec_sv
-- A 'LetWith' binding.  While the body is processed, the first
-- identifier is bound to a 'Dynamic' static value carrying that
-- identifier's own type.  The static value of the body is the static
-- value of the whole expression.
defuncExp (AppExp (LetWith dest src slice val body loc) res) = do
  val' <- defuncExp' val
  slice' <- mapM defuncDimIndex slice
  let dest_sv = Dynamic (unInfo (identType dest))
      body_env = M.singleton (identName dest) (Binding Nothing dest_sv)
  (body', body_sv) <- localEnv body_env (defuncExp body)
  pure (AppExp (LetWith dest src slice' val' body' loc) res, body_sv)
-- Indexing.  Both the indexed expression and the index expressions
-- are processed with their static values discarded; the result is
-- 'Dynamic' with the type of the original indexing expression.
defuncExp whole@(AppExp (Index arr slice loc) res) = do
  arr' <- defuncExp' arr
  slice' <- mapM defuncDimIndex slice
  let whole' = AppExp (Index arr' slice' loc) res
  pure (whole', Dynamic (typeOf whole))
-- An 'Update' expression.  The static value of the expression being
-- updated is passed on unchanged; the static value of the new value
-- is discarded.
defuncExp (Update arr slice val loc) = do
  (arr', arr_sv) <- defuncExp arr
  slice' <- mapM defuncDimIndex slice
  val' <- defuncExp' val
  pure (Update arr' slice' val' loc, arr_sv)

-- Note that we might change the type of the record field here.  This
-- is not permitted in the type checker due to problems with type
-- inference, but it actually works fine.
--
-- The residual expression is annotated with the type computed from
-- the static value of the original record, while the returned static
-- value is the original one with the static value at the field path
-- replaced by that of the new field value.
defuncExp (RecordUpdate e1 fs e2 _ loc) = do
  (e1', sv1) <- defuncExp e1
  (e2', sv2) <- defuncExp e2
  pure
    ( RecordUpdate e1' fs e2' (Info $ typeFromSV sv1) loc,
      staticField sv1 sv2 fs
    )
  where
    -- @staticField old new path@ replaces the static value found by
    -- following @path@ inside @old@ with @new@.
    staticField (RecordSV svs) sv2 (f : fs') =
      case lookup f svs of
        Just sv ->
          RecordSV $
            (f, staticField sv sv2 fs') : filter ((/= f) . fst) svs
        -- This is a record update, not a projection: say so, and name
        -- the field, so the failure can be told apart from the one
        -- raised when defunctionalising 'Project'.
        Nothing ->
          error $
            "Invalid record update: unknown field " ++ prettyString f ++ "."
    -- A dynamic value of record type is converted with 'svFromType'
    -- before the rest of the path is followed.
    staticField (Dynamic t@(Scalar Record {})) sv2 fs'@(_ : _) =
      staticField (svFromType t) sv2 fs'
    -- Empty path, or a static value that cannot be descended into:
    -- the new static value takes over.
    staticField _ sv2 _ = sv2
-- An 'Assert' expression.  Both subexpressions are processed; the
-- static value of the first is discarded and that of the second is
-- the static value of the whole expression.
defuncExp (Assert cond body desc loc) = do
  cond' <- defuncExp' cond
  (body', body_sv) <- defuncExp body
  pure (Assert cond' body' desc loc, body_sv)
-- Constructor application annotated with a sum type.  The resulting
-- 'SumSV' holds the static values of the payload expressions, plus
-- the payload types of every other constructor of the sum type with
-- function types replaced by the empty record.  The type annotation
-- of the residual expression is obtained by combining the original
-- type with the type of the static value via 'combineTypeShapes'.
defuncExp (Constr name es (Info sum_t@(Scalar (Sum all_fs))) loc) = do
  (es', payload_svs) <- mapAndUnzipM defuncExp es
  let other_constrs = M.toList (M.delete name (M.map (map defuncType) all_fs))
      sv = SumSV name payload_svs other_constrs
      sum_t' = combineTypeShapes sum_t (typeFromSV sv)
  pure (Constr name es' (Info sum_t') loc, sv)
  where
    -- Apply 'defuncScalar' to the scalar part of a type.
    defuncType ::
      (Monoid als) =>
      TypeBase Size als ->
      TypeBase Size als
    defuncType (Scalar t) = Scalar (defuncScalar t)
    defuncType (Array as u shape t) = Array as u shape (defuncScalar t)

    -- Replace function types by the empty record, recursing into
    -- records and sums; primitive types and type variables are
    -- returned as they are.
    defuncScalar ::
      (Monoid als) =>
      ScalarTypeBase Size als ->
      ScalarTypeBase Size als
    defuncScalar scalar = case scalar of
      Arrow {} -> Record mempty
      Record fs -> Record (M.map defuncType fs)
      Sum cs -> Sum (M.map (map defuncType) cs)
      Prim t -> Prim t
      TypeVar as u tn targs -> TypeVar as u tn targs
-- Constructor application whose type annotation did not match the
-- sum-type equation: internal error naming the constructor, the type
-- and the source location.
defuncExp (Constr name _ (Info t) loc) =
  error $
    concat
      [ "Constructor ",
        prettyString name,
        " given type ",
        prettyString t,
        " at ",
        locStr loc
      ]
-- Pattern match.  Every case is passed to 'defuncCase' together with
-- the static value of the scrutinee; only the cases for which it
-- returns a result are kept.  The static value of the first kept case
-- is the static value of the whole match.  It is an internal error if
-- no case is kept.
defuncExp (AppExp (Match scrut cases loc) res) = do
  (scrut', scrut_sv) <- defuncExp scrut
  results <- mapM (defuncCase scrut_sv) (NE.toList cases)
  let noMatch = error $ "No case matches StaticVal\n" <> show scrut_sv
      kept = fromMaybe noMatch (NE.nonEmpty (catMaybes results))
      cases' = fmap fst kept
      match_sv = snd (NE.head kept)
  pure (AppExp (Match scrut' cases' loc) res, match_sv)
-- Attribute annotation: process the annotated expression and keep
-- both the attribute and the resulting static value.
defuncExp (Attr attr inner loc) = do
  (inner', inner_sv) <- defuncExp inner
  pure (Attr attr inner' loc, inner_sv)

-- | Like 'defuncExp', but only the residual expression is returned;
-- the static value is dropped.
defuncExp' :: Exp -> DefM Exp
defuncExp' e = fst <$> defuncExp e

-- | Defunctionalize one case of a match, given the static value of
-- the scrutinee.  Returns 'Nothing' when 'matchPatSV' yields no
-- environment for the pattern and the static value.  Otherwise the
-- case body is processed in that environment, and the case is
-- returned with its pattern rewritten by 'updatePat', along with the
-- static value of the body.
defuncCase :: StaticVal -> Case -> DefM (Maybe (Case, StaticVal))
defuncCase scrut_sv (CasePat pat body loc) =
  forM (matchPatSV pat scrut_sv) $ \env -> do
    (body', body_sv) <- localEnv env (defuncExp body)
    pure (CasePat (updatePat pat scrut_sv) body' loc, body_sv)

-- | Defunctionalize the function argument of a SOAC.  Operator and
-- projection sections are returned untouched.  A lambda keeps its
-- parameters and has its body processed recursively.  Any other
-- expression of function type is eta-expanded into a lambda whose
-- body is then defunctionalized; expressions of non-function type
-- are handled by 'defuncExp''.
defuncSoacExp :: Exp -> DefM Exp
defuncSoacExp e = case e of
  OpSection {} -> pure e
  OpSectionLeft {} -> pure e
  OpSectionRight {} -> pure e
  ProjectSection {} -> pure e
  Parens inner loc -> (`Parens` loc) <$> defuncSoacExp inner
  Lambda params body decl tp loc -> do
    -- The parameters are in scope while the body is processed.
    body' <- localEnv (foldMap envFromPat params) (defuncSoacExp body)
    pure (Lambda params body' decl tp loc)
  _
    | Scalar Arrow {} <- typeOf e -> do
        (pats, body, tp) <- etaExpand (RetType [] (typeOf e)) e
        body' <- localEnv (foldMap envFromPat pats) (defuncExp' body)
        pure (Lambda pats body' Nothing (Info (mempty, tp)) mempty)
    | otherwise -> defuncExp' e

-- | Eta-expand an expression of the given type.  Returns one
-- parameter pattern per arrow in the type, the application of the
-- expression to variables naming those parameters, and the final
-- return type converted to a 'StructRetType'.
etaExpand :: PatRetType -> Exp -> DefM ([Pat], Exp, StructRetType)
etaExpand fun_t fun = do
  let (param_ts, ret) = unfoldArrows fun_t
  -- The names picked so far are threaded through the traversal, so
  -- that no name is used for two parameters.
  (_, params) <- mapAccumLM mkParam [] param_ts
  let (pats, args) = unzip params
      old_ext = retDims ret
  -- Important that we synthesize new existential names and substitute
  -- them into the (body) return type.
  new_ext <- mapM newName old_ext
  let ext_subst =
        M.fromList $
          zip old_ext (map (SizeSubst . NamedSize . qualName) new_ext)
      ret' = applySubst (`M.lookup` ext_subst) ret
      arg_diets = map (\(_, (_, t)) -> diet t) param_ts
      applied =
        mkApply
          fun
          (zip3 arg_diets (repeat Nothing) args)
          (AppRes (retType ret') new_ext)
  pure (pats, applied, second (const ()) ret)
  where
    -- Split a chain of arrows into its parameters and the final
    -- return type.
    unfoldArrows (RetType _ (Scalar (Arrow _ p d t1 t2))) =
      let (rest, r) = unfoldArrows t2
       in ((p, (d, t1)) : rest, r)
    unfoldArrows t = ([], t)

    -- Build the pattern and the variable for one parameter.  A named
    -- parameter keeps its name unless that name is already taken;
    -- otherwise a fresh name is generated.
    mkParam taken (pname, (d, t)) = do
      let uniq = case d of
            Consume -> Unique
            Observe -> Nonunique
          t' = fromStruct t `setUniqueness` uniq
      v <- case pname of
        Named given | given `notElem` taken -> pure given
        _ -> newNameFromString "x"
      pure
        ( v : taken,
          ( Id v (Info t') mempty,
            Var (qualName v) (Info t') mempty
          )
        )

-- | Defunctionalize the expressions inside the index of a single
-- array dimension, discarding their static values.
defuncDimIndex :: DimIndexBase Info VName -> DefM (DimIndexBase Info VName)
defuncDimIndex dim_idx = case dim_idx of
  DimFix i -> DimFix <$> defuncExp' i
  DimSlice x1 x2 x3 -> do
    x1' <- traverse defuncExp' x1
    x2' <- traverse defuncExp' x2
    x3' <- traverse defuncExp' x3
    pure (DimSlice x1' x2' x3')

-- | Defunctionalize a let-bound function, keeping the leading
-- parameters for which 'patternOrderZero' holds.  Returns the size
-- parameters and patterns that are kept, the residual body, and the
-- static value.  As soon as a parameter fails 'patternOrderZero',
-- the remaining parameters and the body are handed to 'defuncFun'
-- and no parameters are kept from that point on.
defuncLet ::
  [VName] ->
  [Pat] ->
  Exp ->
  StructRetType ->
  DefM ([VName], [Pat], Exp, StaticVal)
defuncLet dims params@(p : rest) body ret@(RetType _ _)
  | not (patternOrderZero p) = do
      (fun, fun_sv) <- defuncFun dims params body ret mempty
      pure ([], [], fun, fun_sv)
  | otherwise = do
      -- Take care to not include more size parameters than necessary.
      let (p_dims, other_dims) = partition (`S.member` freeInPat p) dims
          p_env = envFromPat p <> envFromDimNames p_dims
      (other_dims', rest', body', body_sv) <-
        localEnv p_env (defuncLet other_dims rest body ret)
      closure <- defuncFun dims params body ret mempty
      pure
        ( p_dims ++ other_dims',
          p : rest',
          body',
          DynamicFun closure body_sv
        )
defuncLet _ [] body (RetType _ rettype) = do
  (body', body_sv) <- defuncExp body
  pure ([], [], body', imposeType body_sv rettype)
  where
    -- Replace the type of 'Dynamic' static values by the declared
    -- return type, descending into records field by field.
    imposeType Dynamic {} t = Dynamic (fromStruct t)
    imposeType (RecordSV fields) (Scalar (Record field_ts)) =
      RecordSV . M.toList $
        M.intersectionWith imposeType (M.fromList fields) field_ts
    imposeType sv _ = sv

-- | Rewrite the types of the given patterns so that every 'AnySize'
-- becomes a 'NamedSize': one that carries a name uses that name, one
-- without gets a fresh name.  Returns the rewritten patterns together
-- with the names collected along the way: every name taken from or
-- generated for an 'AnySize', plus every 'NamedSize' that is neither
-- in @bound_sizes@ nor bound by the patterns themselves.
sizesForAll :: MonadFreshNames m => S.Set VName -> [Pat] -> m ([VName], [Pat])
sizesForAll bound_sizes params = do
  (params', collected) <- runStateT (traverse (astMap mapper) params) S.empty
  pure (S.toList collected, params')
  where
    already_bound = bound_sizes <> foldMap patNames params
    mapper = identityMapper {mapOnPatType = bitraverse onDim pure}

    -- Remember a size name in the state.
    record v = modify (S.insert v)

    onDim (AnySize (Just v)) = do
      record v
      pure (NamedSize (qualName v))
    onDim (AnySize Nothing) = do
      v <- lift (newVName "size")
      record v
      pure (NamedSize (qualName v))
    onDim (NamedSize d) = do
      let leaf = qualLeaf d
      unless (leaf `S.member` already_bound) (record leaf)
      pure (NamedSize d)
    onDim d = pure d

-- | Extract the type from a return type.  Every 'NamedSize' whose
-- name is among the names bound by the return type is replaced by
-- @AnySize Nothing@; with no bound names the type is returned as is.
unRetType :: StructRetType -> StructType
unRetType (RetType ext t)
  | null ext = t
  | otherwise = first anonymise t
  where
    anonymise (NamedSize d)
      | qualLeaf d `elem` ext = AnySize Nothing
    anonymise d = d

defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal)
defuncApplyFunction :: Exp -> Int -> DefM (Exp, StaticVal)
-- Defunctionalise the expression in function position of an
-- application that receives @num_args@ arguments.
--
-- For a variable, the static value is looked up and dispatched on:
--
-- * a 'DynamicFun' that is fully applied at @num_args@ keeps its name
--   and only has its type recomputed with 'dynamicFunType';
--
-- * a 'DynamicFun' that is not fully applied is lifted (via
--   'liftDynFun' and 'liftValDec') to a fresh top-level function whose
--   name is prefixed with @dyn_@, and a reference to that function is
--   returned;
--
-- * an intrinsic is returned untouched;
--
-- * anything else keeps its name, retyped with 'typeFromSV'.
--
-- Any expression that is not a variable is handed to 'defuncExp'.
defuncApplyFunction e@(Var qn (Info t) loc) num_args = do
  let (argtypes, _) = unfoldFunType t
  sv <- lookupVar (toStruct t) (qualLeaf qn)

  case sv of
    DynamicFun _ _
      | fullyApplied sv num_args -> do
          -- We still need to update the types in case the dynamic
          -- function returns a higher-order term.
          let (argtypes', rettype) = dynamicFunType sv argtypes
          pure (Var qn (Info (foldFunType argtypes' $ RetType [] rettype)) loc, sv)
      | otherwise -> do
          fname <- newVName $ "dyn_" <> baseString (qualLeaf qn)
          let (pats, e0, sv') = liftDynFun (prettyString qn) sv num_args
              (argtypes', rettype) = dynamicFunType sv' argtypes
              -- No size parameters are known up front; the only ones
              -- the lifted function gets are those invented below.
              dims' = mempty

          -- Ensure that no parameter sizes are AnySize.  The internaliser
          -- expects this.  This is easy, because they are all
          -- first-order.
          globals <- asks fst
          let bound_sizes = S.fromList dims' <> globals
          (missing_dims, pats') <- sizesForAll bound_sizes pats

          liftValDec fname (RetType [] $ toStruct rettype) (dims' ++ missing_dims) pats' e0
          pure
            ( Var
                (qualName fname)
                (Info (foldFunType argtypes' $ RetType [] $ fromStruct rettype))
                loc,
              sv'
            )
    IntrinsicSV -> pure (e, IntrinsicSV)
    _ -> pure (Var qn (Info (typeFromSV sv)) loc, sv)
defuncApplyFunction e _ = defuncExp e

-- Embed some information about the original function
-- into the name of the lifted function, to make the
-- result slightly more human-readable.
--
-- The 'Int' counts how many 'Apply' nodes have been peeled off so far.
-- When a variable is reached the result is @defunc_<count>_<base name>@;
-- if the head of the application is not a variable the result is just
-- @defunc@.
liftedName :: Int -> Exp -> String
liftedName i (Var f _ _) =
  "defunc_" ++ show i ++ "_" ++ baseString (qualLeaf f)
liftedName i (AppExp (Apply f _ _) _) =
  liftedName (i + 1) f
liftedName _ _ = "defunc"

-- | Apply an already defunctionalised function to a single argument.
--
-- The first argument is the name hint for any function that gets
-- lifted.  The second is the residual expression and static value of
-- the function.  The third is the argument (with its diet and optional
-- name) paired with a list of parameter types, which is consulted only
-- when the function is a 'DynamicFun'.
--
-- * For a 'LambdaSV', the lambda body is defunctionalised in the
--   closure environment extended with the pattern bindings, lifted to
--   a new top-level function taking the closure record and the
--   argument, and the result is a call to that function.
--
-- * For a 'DynamicFun', the application stays in place with updated
--   types.
--
-- * Any other static value is an internal error.
defuncApplyArg ::
  String ->
  (Exp, StaticVal) ->
  (((Diet, Maybe VName), Exp), [(Diet, StructType)]) ->
  DefM (Exp, StaticVal)
defuncApplyArg fname_s (f', f_sv@(LambdaSV pat lam_e_t lam_e closure_env)) (((d, argext), arg), _) = do
  (arg', arg_sv) <- defuncExp arg
  let env' = alwaysMatchPatSV pat arg_sv
      dims = mempty
  (lam_e', sv) <-
    localNewEnv (env' <> closure_env) $
      defuncExp lam_e

  let closure_pat = buildEnvPat dims closure_env
      pat' = updatePat pat arg_sv

  globals <- asks fst

  -- Lift lambda to top-level function definition.  We put in
  -- a lot of effort to try to infer the uniqueness attributes
  -- of the lifted function, but this is ultimately all a sham
  -- and a hack.  There is some piece we're missing.
  let params = [closure_pat, pat']
      params_for_rettype = params ++ svParams f_sv ++ svParams arg_sv
      svParams (LambdaSV sv_pat _ _ _) = [sv_pat]
      svParams _ = []
      lifted_rettype =
        buildRetType closure_env params_for_rettype (unRetType lam_e_t) $
          typeOf lam_e'

      already_bound =
        globals
          <> S.fromList dims
          <> S.map identName (foldMap patIdents params)

      -- Array sizes mentioned in the parameters that nothing binds
      -- yet; these become size parameters of the lifted function.
      more_dims =
        S.toList $
          S.filter (`S.notMember` already_bound) $
            foldMap patternArraySizes params

  -- Ensure that no parameter sizes are AnySize.  The internaliser
  -- expects this.  This is easy, because they are all
  -- first-order.
  let bound_sizes = S.fromList (dims <> more_dims) <> globals
  (missing_dims, params') <- sizesForAll bound_sizes params

  fname <- newNameFromString fname_s
  liftValDec
    fname
    (RetType [] $ toStruct lifted_rettype)
    (dims ++ more_dims ++ missing_dims)
    params'
    lam_e'

  let f_t = toStruct $ typeOf f'
      arg_t = toStruct $ typeOf arg'
      d1 = Observe
      fname_t = foldFunType [(d1, f_t), (d, arg_t)] $ RetType [] lifted_rettype
      fname' = Var (qualName fname) (Info fname_t) (srclocOf arg)
      callret = AppRes lifted_rettype []

  pure
    ( mkApply fname' [(Observe, Nothing, f'), (Observe, argext, arg')] callret,
      sv
    )
-- If 'f' is a dynamic function, we just leave the application in
-- place, but we update the types since it may be partially
-- applied or return a higher-order value.
defuncApplyArg _ (f', DynamicFun _ sv) (((d, argext), arg), argtypes) = do
  (arg', _) <- defuncExp arg
  let (argtypes', rettype) = dynamicFunType sv argtypes
      restype = foldFunType argtypes' (RetType [] rettype)
      callret = AppRes restype []
      apply_e = mkApply f' [(d, argext, arg')] callret
  pure (apply_e, sv)
-- Nothing else can be applied to an argument.
defuncApplyArg fname_s (_, sv) _ =
  error $
    "defuncApplyArg: cannot apply StaticVal\n"
      <> show sv
      <> "\nFunction name: "
      <> prettyString fname_s

-- | Fold an outer 'AppRes' into the result annotation of an
-- application expression: the types are merged with
-- 'combineTypeShapes' (outer type first) and the outer names are
-- prepended to those already on the expression.  Expressions that are
-- not an 'AppExp' are returned unchanged.
updateReturn :: AppRes -> Exp -> Exp
updateReturn (AppRes ret1 ext1) (AppExp apply (Info (AppRes ret2 ext2))) =
  AppExp apply $ Info $ AppRes (combineTypeShapes ret1 ret2) (ext1 <> ext2)
updateReturn _ e = e

-- | Defunctionalise the application of @f@ to the non-empty argument
-- list @args@, whose overall result annotation is @appres@.
--
-- The function position is processed first with
-- 'defuncApplyFunction'.  Intrinsics and holes keep the application
-- node as it is (only the arguments are transformed); every other
-- function is applied one argument at a time with 'defuncApplyArg',
-- and the result annotation of the outermost application is then
-- reconciled with @appres@ by 'updateReturn'.
defuncApply :: Exp -> NE.NonEmpty ((Diet, Maybe VName), Exp) -> AppRes -> SrcLoc -> DefM (Exp, StaticVal)
defuncApply f args appres loc = do
  (f', f_sv) <- defuncApplyFunction f (length args)
  case f_sv of
    IntrinsicSV -> applyInPlace f' defuncSoacExp
    HoleSV {} -> applyInPlace f' (fmap fst . defuncExp)
    _ -> do
      let fname = liftedName 0 f
          (argtypes, _) = unfoldFunType $ typeOf f
      -- Each argument is paired with the parameter types from its own
      -- position onwards.
      fmap (first $ updateReturn appres) $
        foldM (defuncApplyArg fname) (f', f_sv) $
          NE.zip args $
            NE.tails argtypes
  where
    -- Rebuild the application around the transformed function,
    -- running the given transformation on every argument expression.
    applyInPlace f' defuncArg = do
      args' <- fmap (first Info) <$> traverse (traverse defuncArg) args
      intrinsicOrHole $ AppExp (Apply f' args' loc) (Info appres)

    intrinsicOrHole e' = do
      -- If the intrinsic is fully applied, then we are done.
      -- Otherwise we need to eta-expand it and recursively
      -- defunctionalise. XXX: might it be better to simply eta-expand
      -- immediately any time we encounter a non-fully-applied
      -- intrinsic?
      if null $ fst $ unfoldFunType $ appResType appres
        then pure (e', Dynamic $ appResType appres)
        else do
          (pats, body, tp) <- etaExpand (RetType [] (typeOf e')) e'
          defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty

-- | Check if a 'StaticVal' and a given application depth corresponds
-- to a fully applied dynamic function.
--
-- Each argument consumes one 'DynamicFun' layer.  Running out of
-- arguments while a 'DynamicFun' layer remains means the function is
-- only partially applied.  Every static value that is not a
-- 'DynamicFun' counts as fully applied, and so does a 'DynamicFun' at
-- a negative depth, because neither guard matches and the final
-- equation takes over.
fullyApplied :: StaticVal -> Int -> Bool
fullyApplied (DynamicFun _ sv) depth
  | depth == 0 = False
  | depth > 0 = fullyApplied sv (depth - 1)
fullyApplied _ _ = True

-- | Converts a dynamic function 'StaticVal' into a list of
-- parameters, a function body, and the appropriate static value for
-- applying the function at the given depth of partial application.
--
-- One parameter pattern is collected per 'DynamicFun' layer until the
-- depth reaches zero, at which point the body and static value stored
-- in the current layer are returned.  The 'String' is only used as a
-- prefix of the error raised when the static value does not have the
-- expected shape.
liftDynFun :: String -> StaticVal -> Int -> ([Pat], Exp, StaticVal)
liftDynFun _ (DynamicFun (e, sv) _) 0 = ([], e, sv)
liftDynFun s (DynamicFun clsr@(_, LambdaSV pat _ _ _) sv) d
  | d > 0 =
      let (pats, e', sv') = liftDynFun s sv (d - 1)
       in (pat : pats, e', DynamicFun clsr sv')
liftDynFun s sv d =
  error $
    s
      ++ " Tried to lift a StaticVal "
      ++ take 100 (show sv)
      ++ ", but expected a dynamic function.\n"
      ++ prettyString d

-- | Converts a pattern to an environment that binds the individual names of the
-- pattern to their corresponding types wrapped in a 'Dynamic' static value.
--
-- Wrappers (parentheses, attributes, type ascriptions) are looked
-- through, compound patterns contribute the union of their
-- components, and wildcards and literals bind nothing.
envFromPat :: Pat -> Env
envFromPat pat = case pat of
  TuplePat ps _ -> foldMap envFromPat ps
  RecordPat fs _ -> foldMap (envFromPat . snd) fs
  PatParens p _ -> envFromPat p
  PatAttr _ p _ -> envFromPat p
  Id vn (Info t) _ -> M.singleton vn $ Binding Nothing $ Dynamic t
  Wildcard _ _ -> mempty
  PatAscription p _ _ -> envFromPat p
  PatLit {} -> mempty
  PatConstr _ _ ps _ -> foldMap envFromPat ps

-- | Build an environment in which every given name is bound,
-- monomorphically, to a 'Dynamic' value of type signed 64-bit
-- integer.
envFromDimNames :: [VName] -> Env
envFromDimNames = M.fromList . flip zip (repeat d)
  where
    d = Binding Nothing $ Dynamic $ Scalar $ Prim $ Signed Int64

-- | Create a new top-level value declaration with the given function name,
-- return type, list of parameters, and body expression.
--
-- The list of names becomes the size parameters of the declaration.
-- Any name free in the return type that is bound neither as a size
-- parameter nor by one of the parameter patterns is added to the
-- existential names of the return type, ahead of those already there.
liftValDec :: VName -> StructRetType -> [VName] -> [Pat] -> Exp -> DefM ()
liftValDec fname (RetType ret_dims ret) dims pats body = addValBind dec
  where
    dims' = map (`TypeParamDim` mempty) dims
    -- FIXME: this pass is still not correctly size-preserving, so
    -- forget those return sizes that we forgot to propagate along
    -- the way.  Hopefully the internaliser is conservative and
    -- will insert reshapes...
    bound_here = S.fromList dims <> S.map identName (foldMap patIdents pats)
    unbound_in_ret =
      filter (`S.notMember` bound_here) $ S.toList $ freeInType ret
    rettype_st = RetType (unbound_in_ret ++ ret_dims) ret

    dec =
      ValBind
        { valBindEntryPoint = Nothing,
          valBindName = fname,
          valBindRetDecl = Nothing,
          valBindRetType = Info rettype_st,
          valBindTypeParams = dims',
          valBindParams = pats,
          valBindBody = body,
          valBindDoc = Nothing,
          valBindAttrs = mempty,
          valBindLocation = mempty
        }

-- | Given a closure environment, construct a record pattern that
-- binds the closed over variables.  Insert wildcard for any patterns
-- that would otherwise clash with size parameters.
--
-- Each field is named after the pretty-printed variable it binds and
-- is typed with the representation type of the variable's static
-- value.
buildEnvPat :: [VName] -> Env -> Pat
buildEnvPat sizes env = RecordPat (map buildField $ M.toList env) mempty
  where
    buildField (vn, Binding _ sv) =
      let t = Info $ typeFromSV sv
       in ( nameFromString (prettyString vn),
            if vn `elem` sizes
              then Wildcard t mempty
              else Id vn t mempty
          )

-- | Given a closure environment pattern and the type of a term,
-- construct the type of that term, where uniqueness is set to
-- `Nonunique` for those arrays that are bound in the environment or
-- pattern (except if they are unique there).  This ensures that a
-- lifted function can create unique arrays as long as they do not
-- alias any of its parameters.  XXX: it is not clear that this is a
-- sufficient property, unfortunately.
--
-- The third argument is the annotated return type and the fourth the
-- type actually computed for the body; the two are walked in
-- parallel.
buildRetType :: Env -> [Pat] -> StructType -> PatType -> PatType
buildRetType env pats = comb
  where
    -- Every name bound by the closure environment or the patterns.
    bound =
      S.fromList (M.keys env) <> S.map identName (foldMap patIdents pats)
    -- Whether a pattern binds the name with a unique type.
    boundAsUnique v =
      maybe False (unique . unInfo . identType) $
        find ((== v) . identName) $
          S.toList $
            foldMap patIdents pats
    -- Aliasing one of these forces the result to be non-unique.
    problematic v = (v `S.member` bound) && not (boundAsUnique v)

    -- Records and sums are combined component-wise, keeping only the
    -- fields or constructors present on both sides.  Where the
    -- annotation is a function type, the computed type is used.  In
    -- every other case the annotated type is used, carrying the
    -- aliases of the computed type.
    comb (Scalar (Record fs_annot)) (Scalar (Record fs_got)) =
      Scalar $ Record $ M.intersectionWith comb fs_annot fs_got
    comb (Scalar (Sum cs_annot)) (Scalar (Sum cs_got)) =
      Scalar $ Sum $ M.intersectionWith (zipWith comb) cs_annot cs_got
    comb (Scalar Arrow {}) t =
      descend t
    comb got et =
      descend $ fromStruct got `setAliases` aliases et

    -- Strip uniqueness from arrays with a problematic alias, looking
    -- inside records but no other type constructor.
    descend t@Array {}
      | any (problematic . aliasVar) (aliases t) = t `setUniqueness` Nonunique
    descend (Scalar (Record t)) = Scalar $ Record $ fmap descend t
    descend t = t

-- | Compute the corresponding type for the *representation* of a
-- given static value (not the original possibly higher-order value).
--
-- A lambda is represented as a record of its closure environment,
-- with one field per closed-over variable.  A dynamic function has
-- the type of the static value stored alongside its residual
-- expression.  Intrinsics have no representation, so asking for their
-- type is an internal error.
typeFromSV :: StaticVal -> PatType
typeFromSV (Dynamic tp) =
  tp
typeFromSV (LambdaSV _ _ _ env) =
  Scalar . Record . M.fromList $
    map (bimap (nameFromString . prettyString) (typeFromSV . bindingSV)) $
      M.toList env
typeFromSV (RecordSV ls) =
  let ts = map (fmap typeFromSV) ls
   in Scalar $ Record $ M.fromList ts
typeFromSV (DynamicFun (_, sv) _) =
  typeFromSV sv
typeFromSV (SumSV name svs fields) =
  -- The constructor that is present overrides any entry of the same
  -- name among the remaining constructors.
  let svs' = map typeFromSV svs
   in Scalar $ Sum $ M.insert name svs' $ M.fromList fields
typeFromSV (HoleSV t _) =
  t
typeFromSV IntrinsicSV =
  error "Tried to get the type from the static value of an intrinsic."

-- | Construct the type for a fully-applied dynamic function from its
-- static value and the original types of its arguments.
--
-- One 'DynamicFun' layer is peeled off per supplied argument.  The
-- traversal stops as soon as either the arguments run out or the
-- static value is no longer a 'DynamicFun'; the return type is then
-- 'typeFromSV' of whatever static value remains.  Argument types are
-- converted with 'fromStruct' and keep their 'Diet'.
dynamicFunType :: StaticVal -> [(Diet, StructType)] -> ([(Diet, PatType)], PatType)
dynamicFunType (DynamicFun _ sv) (p : ps) =
  let (ps', ret) = dynamicFunType sv ps
   in (second fromStruct p : ps', ret)
dynamicFunType sv _ = ([], typeFromSV sv)

-- | Match a pattern with its static value. Returns an environment
-- with the identifier components of the pattern mapped to the
-- corresponding subcomponents of the static value.  If this function
-- returns 'Nothing', then it corresponds to an unmatchable case.
-- These should only occur for 'Match' expressions.
--
-- A pattern/static-value combination that no clause handles is
-- reported with 'error'.
matchPatSV :: Pat -> StaticVal -> Maybe Env
-- Tuple components are paired with the record fields by position; the
-- field names are ignored.
-- NOTE(review): presumably the fields of a tuple's 'RecordSV' are
-- already in tuple order - confirm against the producers.
matchPatSV (TuplePat ps _) (RecordSV ls) =
  mconcat <$> zipWithM (\p (_, sv) -> matchPatSV p sv) ps ls
-- Record fields are paired by name: both sides are sorted, and the
-- clause only applies when the field names agree exactly.
matchPatSV (RecordPat ps _) (RecordSV ls)
  | ps' <- sortOn fst ps,
    ls' <- sortOn fst ls,
    map fst ps' == map fst ls' =
      mconcat <$> zipWithM (\(_, p) (_, sv) -> matchPatSV p sv) ps' ls'
matchPatSV (PatParens pat _) sv = matchPatSV pat sv
matchPatSV (PatAttr _ pat _) sv = matchPatSV pat sv
matchPatSV (Id vn (Info t) _) sv =
  -- When matching a zero-order pattern with a StaticVal, the type of
  -- the pattern wins out.  This is important for propagating sizes
  -- (but probably reveals a flaw in our bookkeeping).
  pure $
    if orderZero t
      then dim_env <> M.singleton vn (Binding Nothing $ Dynamic t)
      else dim_env <> M.singleton vn (Binding Nothing sv)
  where
    -- Every name free in the pattern type is bound as a dynamic i64.
    dim_env =
      M.fromList $ map (,i64) $ S.toList $ freeInType t
    i64 = Binding Nothing $ Dynamic $ Scalar $ Prim $ Signed Int64
matchPatSV (Wildcard _ _) _ = pure mempty
matchPatSV (PatAscription pat _ _) sv = matchPatSV pat sv
matchPatSV PatLit {} _ = pure mempty
matchPatSV (PatConstr c1 _ ps _) (SumSV c2 ls fs)
  | c1 == c2 =
      mconcat <$> zipWithM matchPatSV ps ls
  -- A different, but known, constructor: this case cannot match.
  | Just _ <- lookup c1 fs =
      Nothing
  | otherwise =
      error $ "matchPatSV: missing constructor in type: " ++ prettyString c1
matchPatSV (PatConstr c1 _ ps _) (Dynamic (Scalar (Sum fs)))
  | Just ts <- M.lookup c1 fs =
      -- A higher-order pattern can only match an appropriate SumSV.
      if all orderZero ts
        then mconcat <$> zipWithM matchPatSV ps (map svFromType ts)
        else Nothing
  | otherwise =
      error $ "matchPatSV: missing constructor in type: " ++ prettyString c1
-- Unwrap a record nested in a 'Dynamic'.  'svFromType' returns either
-- a 'RecordSV' or the very same 'Dynamic', so recursing on the latter
-- would never terminate; in that case fall through to the error below.
matchPatSV pat (Dynamic t)
  | sv@RecordSV {} <- svFromType t = matchPatSV pat sv
-- A hole is matched as the static value derived from its type (which
-- may be a 'Dynamic' handled by one of the clauses above).
matchPatSV pat (HoleSV t _) = matchPatSV pat $ svFromType t
matchPatSV pat sv =
  error $
    "Tried to match pattern\n"
      ++ prettyString pat
      ++ "\n with static value\n"
      ++ show sv

-- | Like 'matchPatSV', but for contexts where the match must succeed:
-- a 'Nothing' result is turned into an 'error' that shows both the
-- pattern and the static value.
alwaysMatchPatSV :: Pat -> StaticVal -> Env
alwaysMatchPatSV pat sv = fromMaybe bad $ matchPatSV pat sv
  where
    bad = error $ unlines [prettyString pat, "cannot match StaticVal", show sv]

-- | Given a pattern and the static value for the defunctionalized argument,
-- update the pattern to reflect the changes in the types.
--
-- A pattern/static-value combination that no clause handles is
-- reported with 'error'.
updatePat :: Pat -> StaticVal -> Pat
-- Tuple components are paired with the record fields by position.
updatePat (TuplePat ps loc) (RecordSV svs) =
  TuplePat (zipWith updatePat ps $ map snd svs) loc
-- Record fields are paired after sorting both sides by field name.
-- The guards only bind the sorted lists and therefore always succeed.
updatePat (RecordPat ps loc) (RecordSV svs)
  | ps' <- sortOn fst ps,
    svs' <- sortOn fst svs =
      RecordPat
        (zipWith (\(n, p) (_, sv) -> (n, updatePat p sv)) ps' svs')
        loc
updatePat (PatParens pat loc) sv =
  PatParens (updatePat pat sv) loc
updatePat (PatAttr attr pat loc) sv =
  PatAttr attr (updatePat pat sv) loc
updatePat (Id vn (Info tp) loc) sv =
  Id vn (Info $ comb tp (typeFromSV sv `setUniqueness` Nonunique)) loc
  where
    -- Preserve any original zeroth-order types.
    comb (Scalar Arrow {}) t2 = t2
    comb (Scalar (Record m1)) (Scalar (Record m2)) =
      Scalar $ Record $ M.intersectionWith comb m1 m2
    comb (Scalar (Sum m1)) (Scalar (Sum m2)) =
      Scalar $ Sum $ M.intersectionWith (zipWith comb) m1 m2
    comb t1 _ = t1 -- t1 must be array or prim.
updatePat pat@(Wildcard (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Wildcard (Info $ typeFromSV sv) loc
-- The type ascription is dropped; only the inner pattern is kept.
updatePat (PatAscription pat _ _) sv =
  updatePat pat sv
updatePat p@PatLit {} _ = p
updatePat pat@(PatConstr c1 (Info t) ps loc) sv@(SumSV _ svs _)
  | orderZero t = pat
  | otherwise = PatConstr c1 (Info t') ps' loc
  where
    t' = typeFromSV sv `setUniqueness` Nonunique
    ps' = zipWith updatePat ps svs
-- For a dynamic value only the constructor pattern's type annotation
-- is replaced; the sub-patterns are left untouched.
updatePat (PatConstr c1 _ ps loc) (Dynamic t) =
  PatConstr c1 (Info t) ps loc
-- Unwrap a record nested in a 'Dynamic'.  'svFromType' returns either
-- a 'RecordSV' or the very same 'Dynamic', so recursing on the latter
-- would never terminate; in that case fall through to the error below.
updatePat pat (Dynamic t)
  | sv@RecordSV {} <- svFromType t = updatePat pat sv
-- A hole is treated as the static value derived from its type (which
-- may be a 'Dynamic' handled by one of the clauses above).
updatePat pat (HoleSV t _) = updatePat pat (svFromType t)
updatePat pat sv =
  error $
    "Tried to update pattern\n"
      ++ prettyString pat
      ++ "\nto reflect the static value\n"
      ++ show sv

-- | Convert a record (or tuple) type to a record static value. This
-- is used for "unwrapping" tuples and records that are nested in
-- 'Dynamic' static values.
--
-- Record fields are converted recursively; any type that is not a
-- record is simply wrapped in 'Dynamic'.
svFromType :: PatType -> StaticVal
svFromType (Scalar (Record fs)) = RecordSV . M.toList $ M.map svFromType fs
svFromType t = Dynamic t

-- | Defunctionalize a top-level value binding. Returns the
-- transformed result as well as an environment that binds the name of
-- the value binding to the static value of the transformed body.
--
-- It is an error for the binding to have type parameters (as opposed
-- to size parameters): the input program must be monomorphic.
defuncValBind :: ValBind -> DefM (ValBind, Env)
-- Eta-expand bindings with a functional return type, then process the
-- expanded binding.  (The guard looks only at the return type; it does
-- not check whether the binding is an entry point.)  The expanded
-- binding carries neither the return type declaration nor the doc
-- comment of the original.
defuncValBind (ValBind entry name _ (Info rettype) tparams params body _ attrs loc)
  | Scalar Arrow {} <- retType rettype = do
      (body_pats, body', rettype') <- etaExpand (second (const mempty) rettype) body
      defuncValBind $
        ValBind
          entry
          name
          Nothing
          (Info rettype')
          tparams
          (params <> body_pats)
          body'
          Nothing
          attrs
          loc
defuncValBind valbind@(ValBind _ name retdecl (Info (RetType ret_dims rettype)) tparams params body _ _ _) = do
  when (any isTypeParam tparams) $
    error $
      show name
        ++ " has type parameters, "
        ++ "but the defunctionaliser expects a monomorphic input program."
  (tparams', params', body', sv) <-
    defuncLet (map typeParamName tparams) params body $ RetType ret_dims rettype
  globals <- asks fst
  let -- Sizes that may be referenced by name in the return type: names
      -- bound by the parameters, the size parameters, and the globals.
      bound_sizes = foldMap patNames params' <> S.fromList tparams' <> globals
      rettype' =
        -- FIXME: dubious that we cannot assume that all sizes in the
        -- body are in scope.  This is because when we insert
        -- applications of lifted functions, we don't properly update
        -- the types in the return type annotation.
        combineTypeShapes rettype $
          first (anyDimIfNotBound bound_sizes) $
            toStruct $
              typeOf body'
      -- Keep only the existential sizes still mentioned by the new
      -- return type.
      ret_dims' = filter (`S.member` freeInType rettype') ret_dims
  (missing_dims, params'') <- sizesForAll bound_sizes params'

  pure
    ( valbind
        { valBindRetDecl = retdecl,
          valBindRetType =
            Info $
              -- A binding without parameters gets a nonunique return
              -- type.
              if null params'
                then RetType ret_dims' $ rettype' `setUniqueness` Nonunique
                else RetType ret_dims' rettype',
          valBindTypeParams =
            map (`TypeParamDim` mempty) $ tparams' ++ missing_dims,
          valBindParams = params'',
          valBindBody = body'
        },
      -- Note that the type scheme stored in the environment is that of
      -- the original binding, not of the transformed one.
      M.singleton name $
        Binding
          (Just (first (map typeParamName) (valBindTypeScheme valbind)))
          sv
    )
  where
    -- Replace a named size that is not in scope with an anonymous size
    -- that remembers the name; all other sizes are left as they are.
    anyDimIfNotBound bound_sizes (NamedSize v)
      | qualLeaf v `S.notMember` bound_sizes = AnySize $ Just $ qualLeaf v
    anyDimIfNotBound _ d = d

-- | Defunctionalize a list of top-level declarations.
--
-- The bindings are processed in order.  Each transformed binding is
-- recorded with 'addValBind', and the remaining bindings are processed
-- with its environment added and the names it binds marked as global.
defuncVals :: [ValBind] -> DefM ()
defuncVals [] = pure ()
defuncVals (valbind : ds) = do
  (valbind', env) <- defuncValBind valbind
  addValBind valbind'
  let globals = valBindBound valbind'
  localEnv env $ areGlobal globals $ defuncVals ds

{-# NOINLINE transformProg #-}

-- | Transform a list of top-level value bindings. May produce new
-- lifted function definitions, which are placed in front of the
-- resulting list of declarations.
--
-- The name source of the surrounding monad is threaded through the
-- 'DefM' computation, so fresh names generated here are accounted for.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg decs = modifyNameSource $ \namesrc ->
  let ((), namesrc', decs') = runDefM namesrc $ defuncVals decs
   in (decs', namesrc')