{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Trustworthy #-}
-- | Defunctionalization of typed, monomorphic Futhark programs without modules.
module Futhark.Internalise.Defunctionalise
  ( transformProg ) where

import qualified Control.Arrow as Arrow
import           Control.Monad.State
import           Control.Monad.RWS hiding (Sum)
import           Data.Bifunctor
import           Data.Foldable
import           Data.List (sortOn, nub, partition, tails)
import qualified Data.List.NonEmpty as NE
import           Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.Sequence as Seq

import           Futhark.MonadFreshNames
import           Language.Futhark
import           Futhark.IR.Pretty ()

-- | An expression or an extended 'Lambda' (with size parameters,
-- which AST lambdas do not support).
--
-- 'ExtLambda' carries the size parameters, the parameter patterns, the
-- body, the closure aliasing paired with the return type, and the source
-- location.  'ExtExp' wraps an ordinary expression.
data ExtExp = ExtLambda [TypeParam] [Pattern] Exp (Aliasing, StructType) SrcLoc
            | ExtExp Exp
  deriving (Show)

-- | A static value stores additional information about the result of
-- defunctionalization of an expression, aside from the residual expression.
data StaticVal = Dynamic PatternType
               | LambdaSV [VName] Pattern StructType ExtExp Env
                 -- ^ The 'VName's are shape parameters that are bound
                 -- by the 'Pattern'.
               | RecordSV [(Name, StaticVal)]
               | SumSV Name [StaticVal] [(Name, [PatternType])]
                 -- ^ The constructor that is actually present, plus
                 -- the others that are not.
               | DynamicFun (Exp, StaticVal) StaticVal
               | IntrinsicSV
  deriving (Show)

-- | Environment mapping each variable name ('VName') to its associated
-- static value ('StaticVal').
type Env = M.Map VName StaticVal

-- | Run a computation with the given bindings added to the current
-- environment.  The set of global names is left untouched.  The new
-- bindings are placed on the left of '<>', so with the left-biased union
-- of 'M.Map' they shadow existing bindings of the same name.
localEnv :: Env -> DefM a -> DefM a
localEnv env = local $ Arrow.second (env <>)

-- | Run a computation in a "new" environment (for evaluating closures).
--
-- Even when using a "new" environment we still ram the global
-- environment of DynamicFuns in there: every binding of the old
-- environment whose name is in the set of globals is kept, and the given
-- environment is merged in to its right (so, 'M.Map' union being
-- left-biased, a retained global wins over a same-named new binding).
localNewEnv :: Env -> DefM a -> DefM a
localNewEnv env = local $ \(globals, old_env) ->
  (globals, M.filterWithKey (\k _ -> k `S.member` globals) old_env <> env)

-- | Run a computation with a single extra binding (name to static
-- value) added to the environment via 'localEnv'.
extendEnv :: VName -> StaticVal -> DefM a -> DefM a
extendEnv vn sv = localEnv (M.singleton vn sv)

-- | Fetch the current environment, i.e. the second component of the
-- reader context (the first being the set of global names).
askEnv :: DefM Env
askEnv = asks snd

-- | Run a computation with the given name added to the set of global
-- names.  The environment itself is not modified.
isGlobal :: VName -> DefM a -> DefM a
isGlobal v = local $ Arrow.first (S.insert v)

-- | Returns the defunctionalization environment restricted
-- to the given set of variable names and types.
--
-- A binding is kept only if its name is not in the set of globals and is
-- present in the given 'NameSet'.  The 'Uniqueness' recorded for the name
-- in the 'NameSet' is then pushed through the static value with
-- @restrict'@.
restrictEnvTo :: NameSet -> DefM Env
restrictEnvTo (NameSet m) = restrict <$> ask
  where
    restrict (globals, env) = M.mapMaybeWithKey keep env
      where
        -- 'Nothing' (drop the binding) for globals and for names
        -- that are not in the requested set.
        keep k sv = do
          guard $ not $ k `S.member` globals
          u <- M.lookup k m
          Just $ restrict' u sv

    -- A dynamic value used non-uniquely has its type made non-unique;
    -- all other cases merely recurse into the nested static values.
    restrict' Nonunique (Dynamic t) =
      Dynamic $ t `setUniqueness` Nonunique
    restrict' _ (Dynamic t) =
      Dynamic t
    restrict' u (LambdaSV dims pat t e env) =
      LambdaSV dims pat t e $ M.map (restrict' u) env
    restrict' u (RecordSV fields) =
      RecordSV $ map (fmap $ restrict' u) fields
    restrict' u (SumSV c svs fields) =
      SumSV c (map (restrict' u) svs) fields
    restrict' u (DynamicFun (e, sv1) sv2) =
      DynamicFun (e, restrict' u sv1) $ restrict' u sv2
    restrict' _ IntrinsicSV = IntrinsicSV

-- | Defunctionalization monad.  The Reader environment tracks both
-- the current Env as well as the set of globally defined dynamic
-- functions.  This is used to avoid unnecessarily large closure
-- environments.
--
-- The Writer component accumulates a sequence of 'ValBind's and the
-- State component is the 'VNameSource' behind the 'MonadFreshNames'
-- instance.
newtype DefM a = DefM (RWS (S.Set VName, Env) (Seq.Seq ValBind) VNameSource a)
  deriving (Functor, Applicative, Monad,
            MonadReader (S.Set VName, Env),
            MonadWriter (Seq.Seq ValBind),
            MonadFreshNames)

-- | Run a computation in the defunctionalization monad. Returns the result of
-- the computation, a new name source, and a list of lifted function
-- declarations.  The computation starts with an empty set of globals and an
-- empty environment.
runDefM :: VNameSource -> DefM a -> (a, VNameSource, Seq.Seq ValBind)
runDefM src (DefM m) = runRWS m mempty src

-- | Run a computation and return the 'ValBind's it wrote alongside its
-- result.  The collected bindings are removed from the writer output
-- (it is replaced by 'mempty'), so they are not also emitted to the
-- enclosing computation.
collectFuns :: DefM a -> DefM (a, Seq.Seq ValBind)
collectFuns m = pass $ do
  (x, decs) <- listen m
  return ((x, decs), const mempty)

-- | Looks up the associated static value for a given name in the environment.
--
-- A name that is absent from the environment never causes a failure: if
-- its 'baseTag' is at most 'maxIntrinsicTag' it yields 'IntrinsicSV', and
-- otherwise it is treated as an existential size and yields a 'Dynamic'
-- signed 32-bit integer.  The 'SrcLoc' argument is therefore unused; it is
-- kept so the signature stays the same for existing callers.
lookupVar :: SrcLoc -> VName -> DefM StaticVal
lookupVar _loc x = do
  env <- askEnv
  case M.lookup x env of
    Just sv -> return sv
    Nothing -- If the variable is unknown, it may refer to the 'intrinsics'
            -- module, which we will have to treat specially.
      | baseTag x <= maxIntrinsicTag -> return IntrinsicSV
      | otherwise -> -- Anything not in scope is going to be an
                     -- existential size.
          return $ Dynamic $ Scalar $ Prim $ Signed Int32

-- Like patternDimNames, but ignores sizes that are only found in
-- function types.
--
-- Collects the names of all named dimensions of array types occurring in
-- the given type, descending into record fields and sum-type constructor
-- payloads.  Function types, type variables and primitive types
-- contribute nothing.
arraySizes :: StructType -> S.Set VName
arraySizes (Scalar Arrow{}) = mempty
arraySizes (Scalar (Record fields)) = foldMap arraySizes fields
arraySizes (Scalar (Sum cs)) = foldMap (foldMap arraySizes) cs
arraySizes (Scalar TypeVar{}) = mempty
arraySizes (Scalar Prim{}) = mempty
arraySizes (Array _ _ t shape) =
  arraySizes (Scalar t) <> foldMap dimName (shapeDims shape)
  where dimName :: DimDecl VName -> S.Set VName
        -- Only named dimensions contribute a size name.
        dimName (NamedDim qn) = S.singleton $ qualLeaf qn
        dimName _             = mempty

-- | The array sizes (see 'arraySizes') of the structural type of a pattern.
patternArraySizes :: Pattern -> S.Set VName
patternArraySizes = arraySizes . patternStructType

-- | Build a mapping from size names in the first type to the size names
-- at the corresponding positions in the second type.
--
-- The two types are traversed together with 'matchDims'.  Whenever both
-- dimensions are named, the leaf name of the first is mapped to the leaf
-- name of the second (a later occurrence of the same name overwrites an
-- earlier one).  All other pairs of dimensions are ignored.
dimMapping :: Monoid a =>
              TypeBase (DimDecl VName) a
           -> TypeBase (DimDecl VName) a
           -> M.Map VName VName
dimMapping t1 t2 = execState (matchDims f t1 t2) mempty
  where f (NamedDim d1) (NamedDim d2) = do
          modify $ M.insert (qualLeaf d1) (qualLeaf d2)
          return $ NamedDim d1
        -- Not both named: nothing to record; keep the first dimension.
        f d _ = return d

-- | Defunctionalization of a lambda, given as its size\/type parameters,
-- parameter patterns, body, closure aliasing paired with return type, and
-- source location.
--
-- Returns a record literal that holds the closed-over variables, together
-- with a 'LambdaSV' describing the first parameter, the remaining function
-- (any further parameters are pushed into an 'ExtLambda'), and the closure
-- environment.
--
-- Calls 'error' if any of the type parameters satisfies 'isTypeParam', or
-- if the list of parameter patterns is empty.
defuncFun :: [TypeParam] -> [Pattern] -> Exp -> (Aliasing, StructType) -> SrcLoc
          -> DefM (Exp, StaticVal)
defuncFun tparams pats e0 (closure, ret) loc = do
  when (any isTypeParam tparams) $
    error $ "Received a lambda with type parameters at " ++ locStr loc
         ++ ", but the defunctionalizer expects a monomorphic input program."
  -- Extract the first parameter of the lambda and "push" the
  -- remaining ones (if there are any) into the body of the lambda.
  let (dims, pat, ret', e0') = case pats of
        [] -> error "Received a lambda with no parameters."
        [pat'] -> (map typeParamName tparams, pat', ret, ExtExp e0)
        (pat' : pats') ->
          -- Split shape parameters into those that are determined by
          -- the first pattern, and those that are determined by later
          -- patterns.
          let bound_by_pat = (`S.member` patternArraySizes pat') . typeParamName
              (pat_dims, rest_dims) = partition bound_by_pat tparams
          in (map typeParamName pat_dims, pat',
              foldFunType (map (toStruct . patternType) pats') ret,
              ExtLambda rest_dims pats' e0 (closure, ret) loc)

  -- Construct a record literal that closes over the environment of
  -- the lambda.  Closed-over 'DynamicFun's are converted to their
  -- closure representation.
  let used = freeVars (Lambda pats e0 Nothing (Info (closure, ret)) loc)
             `without` mconcat (map oneName dims)
  used_env <- restrictEnvTo used

  -- The closure parts that are sizes are proactively turned into size
  -- parameters.
  let sizes_of_arrays = foldMap (arraySizes . toStruct . typeFromSV) used_env <>
                        patternArraySizes pat
      notSize = not . (`S.member` sizes_of_arrays)
      (fields, env) = unzip $ map closureFromDynamicFun $
                      filter (notSize . fst) $ M.toList used_env
      env' = M.fromList env
      closure_dims = S.toList sizes_of_arrays

  return (RecordLit fields loc,
          LambdaSV (nub $ dims <> closure_dims) pat ret' e0' env')

  where -- Turn one closed-over binding into a record field (named after
        -- the pretty-printed variable) and the binding to put in the
        -- closure environment.  For a 'DynamicFun' the field holds its
        -- closure expression and the binding its closure static value.
        closureFromDynamicFun (vn, DynamicFun (clsr_env, sv) _) =
          let name = nameFromString $ pretty vn
          in (RecordFieldExplicit name clsr_env mempty, (vn, sv))

        -- Anything else is stored as a plain variable reference.
        closureFromDynamicFun (vn, sv) =
          let name = nameFromString $ pretty vn
              tp' = typeFromSV sv
          in (RecordFieldExplicit name
               (Var (qualName vn) (Info tp') mempty) mempty, (vn, sv))

-- | Defunctionalization of an expression. Returns the residual expression and
-- the associated static value in the defunctionalization monad.
defuncExp :: Exp -> DefM (Exp, StaticVal)

defuncExp :: Exp -> DefM (Exp, StaticVal)
defuncExp e :: Exp
e@Literal{} =
  (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
e, PatternType -> StaticVal
Dynamic (PatternType -> StaticVal) -> PatternType -> StaticVal
forall a b. (a -> b) -> a -> b
$ Exp -> PatternType
typeOf Exp
e)

defuncExp e :: Exp
e@IntLit{} =
  (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
e, PatternType -> StaticVal
Dynamic (PatternType -> StaticVal) -> PatternType -> StaticVal
forall a b. (a -> b) -> a -> b
$ Exp -> PatternType
typeOf Exp
e)

defuncExp e :: Exp
e@FloatLit{} =
  (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
e, PatternType -> StaticVal
Dynamic (PatternType -> StaticVal) -> PatternType -> StaticVal
forall a b. (a -> b) -> a -> b
$ Exp -> PatternType
typeOf Exp
e)

-- String literals are returned unchanged with a 'Dynamic' static
-- value of their type.
defuncExp e@StringLit{} =
  return (e, Dynamic $ typeOf e)

-- Parentheses are transparent: defunctionalise the inner expression,
-- re-wrap the residual expression, and pass its static value through.
defuncExp (Parens e loc) = do
  (e', sv) <- defuncExp e
  return (Parens e' loc, sv)

-- Qualified parentheses behave like plain parentheses: the qualifier
-- is kept as-is and the static value of the inner expression is
-- passed through.
defuncExp (QualParens qn e loc) = do
  (e', sv) <- defuncExp e
  return (QualParens qn e' loc, sv)

-- A tuple literal is defunctionalised component-wise.  Its static
-- value is a 'RecordSV' whose fields are named by 'tupleFieldNames',
-- paired positionally with the component static values.
defuncExp (TupLit es loc) = do
  (es', svs) <- unzip <$> mapM defuncExp es
  return (TupLit es' loc, RecordSV $ zip tupleFieldNames svs)

-- A record literal is defunctionalised field by field; the resulting
-- static value is a 'RecordSV' over the per-field static values.
defuncExp (RecordLit fs loc) = do
  (fs', names_svs) <- unzip <$> mapM defuncField fs
  return (RecordLit fs' loc, RecordSV names_svs)

  where -- Defunctionalise a single field, returning the residual field
        -- together with the field name and its static value.
        defuncField :: FieldBase Info VName
                    -> DefM (FieldBase Info VName, (Name, StaticVal))
        defuncField (RecordFieldExplicit vn e loc') = do
          (e', sv) <- defuncExp e
          return (RecordFieldExplicit vn e' loc', (vn, sv))
        defuncField (RecordFieldImplicit vn _ loc') = do
          sv <- lookupVar loc' vn
          case sv of
            -- If the implicit field refers to a dynamic function, we
            -- convert it to an explicit field with a record closing over
            -- the environment and bind the corresponding static value.
            DynamicFun (e, sv') _ ->
              let vn' = baseName vn
              in return (RecordFieldExplicit vn' e loc', (vn', sv'))
            -- The field may refer to a functional expression, so we get the
            -- type from the static value and not the one from the AST.
            _ ->
              let tp = Info $ typeFromSV sv
              in return (RecordFieldImplicit vn tp loc', (baseName vn, sv))

-- Array literal: each element is defunctionalised with 'defuncExp''
-- (only the residual expression is kept), and the result is 'Dynamic'
-- at the type annotation already present on the literal.
defuncExp (ArrayLit es t@(Info t') loc) = do
  es' <- mapM defuncExp' es
  return (ArrayLit es' t loc, Dynamic t')

-- Range expression: the start, the optional second element and the
-- bound are each defunctionalised with 'defuncExp''.  The result is
-- 'Dynamic' at the annotated type of the range.
defuncExp (Range e1 me incl t@(Info t', _) loc) = do
  e1' <- defuncExp' e1
  me' <- mapM defuncExp' me
  incl' <- mapM defuncExp' incl
  return (Range e1' me' incl' t loc, Dynamic t')

-- Variable reference: the outcome depends on the static value bound
-- to the variable in the environment.
defuncExp e@(Var qn _ loc) = do
  sv <- lookupVar loc (qualLeaf qn)
  case sv of
    -- If the variable refers to a dynamic function, we return its closure
    -- representation (i.e., a record expression capturing the free variables
    -- and a 'LambdaSV' static value) instead of the variable itself.
    DynamicFun closure _ -> return closure
    -- Intrinsic functions used as variables are eta-expanded, so we
    -- can get rid of them.
    IntrinsicSV -> do
      (pats, body, tp) <- etaExpand (typeOf e) e
      defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) mempty
    -- Otherwise keep the variable, but re-annotate it with the type
    -- derived from its static value.
    _ ->
      let tp = typeFromSV sv
      in return (Var qn (Info tp) loc, sv)

-- Type ascription: kept only when the ascribed expression has an
-- order-zero type; otherwise the ascription is dropped and only the
-- inner expression is defunctionalised.
defuncExp (Ascript e0 tydecl loc)
  | orderZero (typeOf e0) = do
      (e0', sv) <- defuncExp e0
      return (Ascript e0' tydecl loc, sv)
  | otherwise = defuncExp e0

-- Size coercion: kept only when the coerced expression has an
-- order-zero type; otherwise the coercion is dropped and only the
-- inner expression is defunctionalised.
defuncExp (Coerce e0 tydecl t loc)
  | orderZero (typeOf e0) = do
      (e0', sv) <- defuncExp e0
      return (Coerce e0' tydecl t loc, sv)
  | otherwise = defuncExp e0

-- Pattern binding: the bound expression is defunctionalised first; its
-- static value determines both the environment for the body (via
-- 'matchPatternSV') and the updated pattern (via 'updatePattern').
defuncExp (LetPat pat e1 e2 (Info t, Info retext) loc) = do
  (e1', sv1) <- defuncExp e1
  let env  = matchPatternSV pat sv1
      pat' = updatePattern pat sv1
  (e2', sv2) <- localEnv env $ defuncExp e2
  -- To maintain any sizes going out of scope, we need to compute the
  -- old size substitution induced by retext and also apply it to the
  -- newly computed body type.
  let mapping = dimMapping (typeOf e2) t
      subst v = fromMaybe v $ M.lookup v mapping
      t' = first (fmap subst) $ typeOf e2'
  return (LetPat pat' e1' e2' (Info t', Info retext) loc, sv2)

-- Local functions are handled by rewriting them to lambdas, so that
-- the same machinery can be re-used.
defuncExp (LetFun vn (dims, pats, _, Info ret, e1) e2 _ loc) = do
  -- The function definition goes through 'defuncFun', exactly like a
  -- lambda but with its size parameters.
  (e1', sv1) <- defuncFun dims pats e1 (mempty, ret) loc
  -- The body sees the function name bound to the resulting static value.
  (e2', sv2) <- localEnv (M.singleton vn sv1) $ defuncExp e2
  -- The residual expression is an ordinary pattern binding of the
  -- function name to the residual of the definition.
  return (LetPat (Id vn (Info (typeOf e1')) loc) e1' e2'
                 (Info $ typeOf e2', Info []) loc,
          sv2)

-- Conditional: all three sub-expressions are defunctionalised.  The
-- static value of the whole expression is the one computed for the
-- "then" branch; the ones for the condition and the "else" branch are
-- discarded.
defuncExp (If e1 e2 e3 tp loc) = do
  (e1', _ ) <- defuncExp e1
  (e2', sv) <- defuncExp e2
  (e3', _ ) <- defuncExp e3
  return (If e1' e2' e3' tp loc, sv)

-- Application of an intrinsic (a variable whose tag is at most
-- 'maxIntrinsicTag') to a tuple literal: the tuple components are
-- defunctionalised individually and the application itself is kept.
defuncExp e@(Apply f@(Var f' _ _) arg d (Info t, Info ext) loc)
  | baseTag (qualLeaf f') <= maxIntrinsicTag,
    TupLit es tuploc <- arg = do
      -- defuncSoacExp also works fine for non-SOACs.
      es' <- mapM defuncSoacExp es
      return (Apply f (TupLit es' tuploc) d (Info t, Info ext) loc,
              Dynamic $ typeOf e)

-- Every other application is delegated to 'defuncApply', starting
-- with a count of 0.
defuncExp e@Apply{} = defuncApply 0 e

-- Negation: defunctionalise the operand and pass its static value
-- through unchanged.
defuncExp (Negate e0 loc) = do
  (e0', sv) <- defuncExp e0
  return (Negate e0' loc, sv)

-- A lambda is handled by 'defuncFun' with no size parameters; the
-- optional return type expression is ignored.
defuncExp (Lambda pats e0 _ (Info (closure, ret)) loc) =
  defuncFun [] pats e0 (closure, ret) loc

-- Operator sections are expected to be converted to lambda-expressions
-- by the monomorphizer, so they should no longer occur at this point.
defuncExp OpSection{}      = error "defuncExp: unexpected operator section."
defuncExp OpSectionLeft{}  = error "defuncExp: unexpected operator section."
defuncExp OpSectionRight{} = error "defuncExp: unexpected operator section."
defuncExp ProjectSection{} = error "defuncExp: unexpected projection section."
-- The message names the actual construct (an index section), so that a
-- failure here is not mistaken for one caused by a projection section.
defuncExp IndexSection{}   = error "defuncExp: unexpected index section."

-- Loop: the initial value is defunctionalised first and its static
-- value is matched against the merge pattern to obtain 'env1'.  The
-- loop form contributes a second environment ('env2') for the names it
-- binds; the body is processed under both.
defuncExp (DoLoop sparams pat e1 form e3 ret loc) = do
  (e1', sv1) <- defuncExp e1
  let env1 = matchPatternSV pat sv1
  (form', env2) <- case form of
    For v e2 -> do
      e2' <- defuncExp' e2
      return (For v e2', envFromIdent v)
    ForIn pat2 e2 -> do
      e2' <- defuncExp' e2
      return (ForIn pat2 e2', envFromPattern pat2)
    -- Only the 'While' condition is processed under the merge
    -- pattern's environment; it binds no names of its own.
    While e2 -> do
      e2' <- localEnv env1 $ defuncExp' e2
      return (While e2', mempty)
  (e3', sv) <- localEnv (env1 <> env2) $ defuncExp e3
  return (DoLoop sparams pat e1' form' e3' ret loc, sv)
  where -- The loop counter is bound as a 'Dynamic' value of its type.
        envFromIdent (Ident vn (Info tp) _) =
          M.singleton vn $ Dynamic tp

-- We handle BinOps by turning them into ordinary function applications.
defuncExp (BinOp (qn, qnloc) (Info t)
           (e1, Info (pt1, ext1)) (e2, Info (pt2, ext2))
           (Info ret) (Info retext) loc) =
  -- Build @(qn e1) e2@ as two nested applications and defunctionalise
  -- that.  The inner application is given the type @pt2 -> ret@.
  defuncExp $ Apply (Apply (Var qn (Info t) qnloc)
                     e1 (Info (diet pt1, ext1))
                     (Info (Scalar $ Arrow mempty Unnamed (fromStruct pt2) ret), Info []) loc)
                    e2 (Info (diet pt2, ext2)) (Info ret, Info retext) loc

-- Field projection.  When the record has a 'RecordSV' static value,
-- the field's static value is looked up and the projection is
-- re-annotated with the type derived from it.  For a 'Dynamic' record
-- the original annotation is kept.  Any other static value is an error.
defuncExp (Project vn e0 tp@(Info tp') loc) = do
  (e0', sv0) <- defuncExp e0
  case sv0 of
    RecordSV svs -> case lookup vn svs of
      Just sv -> return (Project vn e0' (Info $ typeFromSV sv) loc, sv)
      Nothing -> error "Invalid record projection."
    Dynamic _ -> return (Project vn e0' tp loc, Dynamic tp')
    _ -> error $ "Projection of an expression with static value " ++ show sv0

-- In-place update binding: the new name ('id1') is bound, for the
-- body, to the static value looked up for the source name ('id2').
defuncExp (LetWith id1 id2 idxs e1 body t loc) = do
  e1' <- defuncExp' e1
  sv1 <- lookupVar (identSrcLoc id2) $ identName id2
  idxs' <- mapM defuncDimIndex idxs
  (body', sv) <- extendEnv (identName id1) sv1 $ defuncExp body
  return (LetWith id1 id2 idxs' e1' body' t loc, sv)

-- Array indexing: the array and the indices are defunctionalised, and
-- the result is 'Dynamic' at the type of the original index expression.
defuncExp expr@(Index e0 idxs info loc) = do
  e0' <- defuncExp' e0
  idxs' <- mapM defuncDimIndex idxs
  return (Index e0' idxs' info loc, Dynamic $ typeOf expr)

-- Array update: the static value of the result is that of the array
-- being updated; the indices and the new value only contribute
-- residual expressions.
defuncExp (Update e1 idxs e2 loc) = do
  (e1', sv) <- defuncExp e1
  idxs' <- mapM defuncDimIndex idxs
  e2' <- defuncExp' e2
  return (Update e1' idxs' e2' loc, sv)

-- Note that we might change the type of the record field here.  This
-- is not permitted in the type checker due to problems with type
-- inference, but it actually works fine.
defuncExp (RecordUpdate e1 fs e2 _ loc) = do
  (e1', sv1) <- defuncExp e1
  (e2', sv2) <- defuncExp e2
  let sv = staticField sv1 sv2 fs
  -- The residual expression is annotated with the type derived from
  -- the source record's static value ('sv1'); the returned static value
  -- ('sv') reflects the updated field.
  return (RecordUpdate e1' fs e2' (Info $ typeFromSV sv1) loc,
          sv)
  where -- Replace the static value found at the given field path with
        -- the static value of the new field contents.
        staticField (RecordSV svs) sv2 (f:fs') =
          case lookup f svs of
            Just sv -> RecordSV $
                       (f, staticField sv sv2 fs') : filter ((/=f) . fst) svs
            Nothing -> error "Invalid record projection."
        -- A dynamic record is first expanded via 'svFromType' so the
        -- path can be followed into it.
        staticField (Dynamic t@(Scalar Record{})) sv2 fs'@(_:_) =
          staticField (svFromType t) sv2 fs'
        -- Every other combination yields the new value's static value.
        staticField _ sv2 _ = sv2

-- The 'Unsafe' wrapper is rebuilt around the defunctionalised
-- subexpression; the static value passes through unchanged.
defuncExp (Unsafe inner loc) = do
  (inner', inner_sv) <- defuncExp inner
  let wrapped = Unsafe inner' loc
  return (wrapped, inner_sv)

-- Both subexpressions of an assertion are defunctionalised (first
-- operand first).  The static value of the first operand is discarded;
-- the result carries the static value of the second.
defuncExp (Assert cond body desc loc) = do
  (cond', _) <- defuncExp cond
  (body', body_sv) <- defuncExp body
  let assertion = Assert cond' body' desc loc
  return (assertion, body_sv)

-- A constructor application whose annotated type is a sum type.  The
-- payload expressions are defunctionalised in order, and the static
-- value records the present constructor together with the remaining
-- constructors of the sum, whose payload types have function types
-- replaced by empty records.
defuncExp (Constr cname args (Info (Scalar (Sum constrs))) loc) = do
  (args', arg_svs) <- unzip <$> mapM defuncExp args
  let absent = M.delete cname (M.map (map defuncType) constrs)
      sv = SumSV cname arg_svs (M.toList absent)
  return (Constr cname args' (Info (typeFromSV sv)) loc, sv)
  where
    -- Rewrite a type so that every function type inside it becomes
    -- the empty record.
    defuncType :: Monoid als =>
                  TypeBase (DimDecl VName) als
               -> TypeBase (DimDecl VName) als
    defuncType ty = case ty of
      Array als u elem_t shape -> Array als u (defuncScalar elem_t) shape
      Scalar scalar_t          -> Scalar (defuncScalar scalar_t)

    defuncScalar :: Monoid als =>
                    ScalarTypeBase (DimDecl VName) als
                 -> ScalarTypeBase (DimDecl VName) als
    defuncScalar st = case st of
      Record fields         -> Record (M.map defuncType fields)
      Arrow{}               -> Record mempty
      Sum cs                -> Sum (M.map (map defuncType) cs)
      Prim pt               -> Prim pt
      TypeVar als u tn targs -> TypeVar als u tn targs

-- Any constructor application that reaches this clause has an
-- annotated type that did not match the sum-type clause; report it.
defuncExp (Constr cname _ (Info ty) loc) =
  error $ concat
    [ "Constructor "
    , pretty cname
    , " given type "
    , pretty ty
    , " at "
    , locStr loc
    ]

-- The scrutinee is defunctionalised first and its static value is
-- handed to every case.  The static value of the whole match is the
-- one produced by the first case.
defuncExp (Match scrutinee cases t loc) = do
  (scrutinee', scrut_sv) <- defuncExp scrutinee
  results <- mapM (defuncCase scrut_sv) cases
  let cases' = fst <$> results
      result_sv = snd (NE.head results)
  return (Match scrutinee' cases' t loc, result_sv)

-- The attribute is kept as is around the defunctionalised
-- subexpression; the static value passes through unchanged.
defuncExp (Attr attr inner loc) = do
  (inner', inner_sv) <- defuncExp inner
  let attributed = Attr attr inner' loc
  return (attributed, inner_sv)

-- | Like 'defuncExp', but returns only the residual expression and
-- drops the static value.
defuncExp' :: Exp -> DefM Exp
defuncExp' e = fst <$> defuncExp e

-- | Defunctionalise an 'ExtExp': ordinary expressions are passed to
-- 'defuncExp', and extended lambdas are passed to 'defuncFun' with all
-- of their components.
defuncExtExp :: ExtExp -> DefM (Exp, StaticVal)
defuncExtExp ext = case ext of
  ExtExp e ->
    defuncExp e
  ExtLambda tparams pats body (closure, ret) loc ->
    defuncFun tparams pats body (closure, ret) loc

-- | Defunctionalise one case of a match, given the static value of the
-- scrutinee.  The case body is processed under the environment that
-- 'matchPatternSV' produces for the pattern, and the pattern itself is
-- rewritten with 'updatePattern'.  Returns the new case along with the
-- static value of its body.
defuncCase :: StaticVal -> Case -> DefM (Case, StaticVal)
defuncCase scrut_sv (CasePat pat body loc) = do
  (body', body_sv) <- localEnv pat_env (defuncExp body)
  return (CasePat pat' body' loc, body_sv)
  where
    pat' = updatePattern pat scrut_sv
    pat_env = matchPatternSV pat scrut_sv

-- | Defunctionalize the function argument to a SOAC by eta-expanding if
-- necessary and then defunctionalizing the body of the introduced lambda.
--
-- Operator and projection sections are returned untouched, while
-- parentheses and lambdas are traversed.  Any other expression of
-- function type is eta-expanded into a lambda whose body goes through
-- 'defuncExp''; expressions of any other type go to 'defuncExp''
-- directly.
defuncSoacExp :: Exp -> DefM Exp
defuncSoacExp e = case e of
  OpSection{}      -> return e
  OpSectionLeft{}  -> return e
  OpSectionRight{} -> return e
  ProjectSection{} -> return e

  Parens inner loc -> do
    inner' <- defuncSoacExp inner
    return $ Parens inner' loc

  Lambda params body decl tp loc -> do
    body' <- underParams params $ defuncSoacExp body
    return $ Lambda params body' decl tp loc

  _ | Scalar Arrow{} <- typeOf e -> do
        (pats, body, ret) <- etaExpand (typeOf e) e
        body' <- underParams pats $ defuncExp' body
        -- The introduced lambda has no type declaration, empty
        -- aliasing, and no source location.
        return $ Lambda pats body' Nothing (Info (mempty, ret)) mempty
    | otherwise ->
        defuncExp' e
  where
    -- Run an action under the environment that 'envFromPattern'
    -- derives from the given parameters.
    underParams :: [Pattern] -> DefM a -> DefM a
    underParams = localEnv . foldMap envFromPattern

-- | Eta-expand an expression of the given (function) type.  Returns
-- one 'Id' pattern per parameter of the type, the expression applied
-- to variables for all those parameters, and the final return type as
-- a 'StructType'.  Named parameters keep their name; unnamed ones get
-- a fresh name based on @"x"@.
etaExpand :: PatternType -> Exp -> DefM ([Pattern], Exp, StructType)
etaExpand e_t e = do
  binders <- forM params $ \(pname, t) -> do
    v <- case pname of
      Named x -> pure x
      Unnamed -> newNameFromString "x"
    return ( Id v (Info t) mempty
           , Var (qualName v) (Info t) mempty
           )
  let (pats, vars) = unzip binders
      e' = foldl' applyArg e $
           zip3 vars param_ts (drop 1 (tails param_ts))
  return (pats, e', toStruct ret)
  where
    (params, ret) = getType e_t
    param_ts = map snd params

    -- Apply the function built so far to one more argument.  The
    -- third component holds the types of the parameters still to be
    -- supplied, from which the type of this application is built.
    applyArg f (arg, arg_t, remaining) =
      Apply f arg (Info (diet arg_t, Nothing))
            (Info (foldFunType remaining ret), Info []) mempty

    -- Split a function type into its parameters and return type.
    getType (Scalar (Arrow _ p t1 t2)) = ((p, t1) : rest, r)
      where (rest, r) = getType t2
    getType t = ([], t)

-- | Defunctionalize an indexing of a single array dimension.  Every
-- expression inside the index is defunctionalised and its static value
-- is dropped; absent slice components stay absent.
defuncDimIndex :: DimIndexBase Info VName -> DefM (DimIndexBase Info VName)
defuncDimIndex idx = case idx of
  DimFix i ->
    DimFix <$> defuncExp' i
  DimSlice start end stride ->
    DimSlice <$> optional start <*> optional end <*> optional stride
  where
    optional :: Maybe Exp -> DefM (Maybe Exp)
    optional = traverse defuncExp'

-- | Defunctionalize a let-bound function, while preserving parameters
-- that have order 0 types (i.e., non-functional).
--
-- Returns the size parameters and value parameters that are kept, the
-- residual body, and the static value of the binding.  As soon as a
-- parameter is not of order zero, the remaining function is handed to
-- 'defuncFun' and no further parameters are kept.
defuncLet :: [TypeParam] -> [Pattern] -> Exp -> StructType
          -> DefM ([TypeParam], [Pattern], Exp, StaticVal)
defuncLet dims ps@(pat:pats) body rettype
  | patternOrderZero pat = do
      (rest_dims', pats', body', sv) <-
        localEnv env $ defuncLet rest_dims pats body rettype
      -- The whole function is also defunctionalised as a closure,
      -- which is stored alongside the static value of the rest.
      closure <- defuncFun dims ps body (mempty, rettype) mempty
      return ( pat_dims ++ rest_dims'
             , pat : pats'
             , body'
             , DynamicFun closure sv
             )
  | otherwise = do
      (e, sv) <- defuncFun dims ps body (mempty, rettype) mempty
      return ([], [], e, sv)
  where
    pat_sizes = patternDimNames pat
    bound_by_pat tparam = typeParamName tparam `S.member` pat_sizes
    -- Take care to not include more size parameters than necessary.
    (pat_dims, rest_dims) = partition bound_by_pat dims
    env = envFromPattern pat <> envFromShapeParams pat_dims

defuncLet _ [] body rettype = do
  (body', sv) <- defuncExp body
  return ([], [], body', imposeType sv rettype)
  where
    -- Replace the types of dynamic static values by the corresponding
    -- parts of the declared return type, descending through records.
    imposeType :: StaticVal -> TypeBase (DimDecl VName) als -> StaticVal
    imposeType Dynamic{} t =
      Dynamic (fromStruct t)
    imposeType (RecordSV fields) (Scalar (Record field_ts)) =
      RecordSV . M.toList $
        M.intersectionWith imposeType (M.fromList fields) field_ts
    imposeType other _ = other

-- | Defunctionalize an application expression at a given depth of application.
-- Calls to dynamic (first-order) functions are preserved as much as possible,
-- but a new lifted function is created if a dynamic function is only partially
-- applied.
defuncApply :: Int -> Exp -> DefM (Exp, StaticVal)
defuncApply :: Int -> Exp -> DefM (Exp, StaticVal)
defuncApply Int
depth e :: Exp
e@(Apply Exp
e1 Exp
e2 Info (Diet, Maybe VName)
d t :: (Info PatternType, Info [VName])
t@(Info PatternType
ret, Info [VName]
ext) SrcLoc
loc) = do
  let ([PatternType]
argtypes, PatternType
_) = PatternType -> ([PatternType], PatternType)
forall dim as.
TypeBase dim as -> ([TypeBase dim as], TypeBase dim as)
unfoldFunType PatternType
ret
  (Exp
e1', StaticVal
sv1) <- Int -> Exp -> DefM (Exp, StaticVal)
defuncApply (Int
depthInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1) Exp
e1
  (Exp
e2', StaticVal
sv2) <- Exp -> DefM (Exp, StaticVal)
defuncExp Exp
e2
  let e' :: Exp
e' = Exp
-> Exp
-> Info (Diet, Maybe VName)
-> (Info PatternType, Info [VName])
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
ExpBase f vn
-> ExpBase f vn
-> f (Diet, Maybe VName)
-> (f PatternType, f [VName])
-> SrcLoc
-> ExpBase f vn
Apply Exp
e1' Exp
e2' Info (Diet, Maybe VName)
d (Info PatternType, Info [VName])
t SrcLoc
loc
  case StaticVal
sv1 of
    LambdaSV [VName]
dims Pattern
pat StructType
e0_t ExtExp
e0 Env
closure_env -> do
      let env' :: Env
env' = Pattern -> StaticVal -> Env
matchPatternSV Pattern
pat StaticVal
sv2
          env_dim :: Env
env_dim = [VName] -> Env
envFromDimNames [VName]
dims
      (Exp
e0', StaticVal
sv) <- Env -> DefM (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall a. Env -> DefM a -> DefM a
localNewEnv (Env
env' Env -> Env -> Env
forall a. Semigroup a => a -> a -> a
<> Env
closure_env Env -> Env -> Env
forall a. Semigroup a => a -> a -> a
<> Env
env_dim) (DefM (Exp, StaticVal) -> DefM (Exp, StaticVal))
-> DefM (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall a b. (a -> b) -> a -> b
$ ExtExp -> DefM (Exp, StaticVal)
defuncExtExp ExtExp
e0

      let closure_pat :: Pattern
closure_pat = Env -> Pattern
buildEnvPattern Env
closure_env
          pat' :: Pattern
pat' = Pattern -> StaticVal -> Pattern
updatePattern Pattern
pat StaticVal
sv2

      -- Lift lambda to top-level function definition.  We put in
      -- a lot of effort to try to infer the uniqueness attributes
      -- of the lifted function, but this is ultimately all a sham
      -- and a hack.  There is some piece we're missing.
      let params :: [Pattern]
params = [Pattern
closure_pat, Pattern
pat']
          params_for_rettype :: [Pattern]
params_for_rettype = [Pattern]
params [Pattern] -> [Pattern] -> [Pattern]
forall a. [a] -> [a] -> [a]
++ StaticVal -> [Pattern]
svParams StaticVal
sv1 [Pattern] -> [Pattern] -> [Pattern]
forall a. [a] -> [a] -> [a]
++ StaticVal -> [Pattern]
svParams StaticVal
sv2
          svParams :: StaticVal -> [Pattern]
svParams (LambdaSV [VName]
_ Pattern
sv_pat StructType
_ ExtExp
_ Env
_) = [Pattern
sv_pat]
          svParams StaticVal
_                         = []
          rettype :: PatternType
rettype = Env -> [Pattern] -> StructType -> PatternType -> PatternType
buildRetType Env
closure_env [Pattern]
params_for_rettype StructType
e0_t (PatternType -> PatternType) -> PatternType -> PatternType
forall a b. (a -> b) -> a -> b
$ Exp -> PatternType
typeOf Exp
e0'

          already_bound :: Set VName
already_bound = [VName] -> Set VName
forall a. Ord a => [a] -> Set a
S.fromList [VName]
dims Set VName -> Set VName -> Set VName
forall a. Semigroup a => a -> a -> a
<>
                          (IdentBase Info VName -> VName)
-> Set (IdentBase Info VName) -> Set VName
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map IdentBase Info VName -> VName
forall (f :: * -> *) vn. IdentBase f vn -> vn
identName ((Pattern -> Set (IdentBase Info VName))
-> [Pattern] -> Set (IdentBase Info VName)
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Pattern -> Set (IdentBase Info VName)
forall (f :: * -> *) vn.
(Functor f, Ord vn) =>
PatternBase f vn -> Set (IdentBase f vn)
patternIdents [Pattern]
params)
          more_dims :: [VName]
more_dims = Set VName -> [VName]
forall a. Set a -> [a]
S.toList (Set VName -> [VName]) -> Set VName -> [VName]
forall a b. (a -> b) -> a -> b
$
                      (VName -> Bool) -> Set VName -> Set VName
forall a. (a -> Bool) -> Set a -> Set a
S.filter (VName -> Set VName -> Bool
forall a. Ord a => a -> Set a -> Bool
`S.notMember` Set VName
already_bound) (Set VName -> Set VName) -> Set VName -> Set VName
forall a b. (a -> b) -> a -> b
$
                      (Pattern -> Set VName) -> [Pattern] -> Set VName
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Pattern -> Set VName
patternArraySizes [Pattern]
params

          -- Embed some information about the original function
          -- into the name of the lifted function, to make the
          -- result slightly more human-readable.
          liftedName :: t -> ExpBase f VName -> String
liftedName t
i (Var QualName VName
f f PatternType
_ SrcLoc
_) =
            String
"lifted_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ t -> String
forall a. Show a => a -> String
show t
i String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> String
baseString (QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
f)
          liftedName t
i (Apply ExpBase f VName
f ExpBase f VName
_ f (Diet, Maybe VName)
_ (f PatternType, f [VName])
_ SrcLoc
_) =
            t -> ExpBase f VName -> String
liftedName (t
it -> t -> t
forall a. Num a => a -> a -> a
+t
1) ExpBase f VName
f
          liftedName t
_ ExpBase f VName
_ = String
"lifted"

      VName
fname <- String -> DefM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newNameFromString (String -> DefM VName) -> String -> DefM VName
forall a b. (a -> b) -> a -> b
$ Int -> Exp -> String
forall t (f :: * -> *).
(Show t, Num t) =>
t -> ExpBase f VName -> String
liftedName (Int
0::Int) Exp
e1
      VName -> PatternType -> [VName] -> [Pattern] -> Exp -> DefM ()
liftValDec VName
fname PatternType
rettype ([VName]
dims [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
more_dims) [Pattern]
params Exp
e0'

      let t1 :: StructType
t1 = PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatternType -> StructType) -> PatternType -> StructType
forall a b. (a -> b) -> a -> b
$ Exp -> PatternType
typeOf Exp
e1'
          t2 :: StructType
t2 = PatternType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct (PatternType -> StructType) -> PatternType -> StructType
forall a b. (a -> b) -> a -> b
$ Exp -> PatternType
typeOf Exp
e2'
          fname' :: QualName VName
fname' = VName -> QualName VName
forall v. v -> QualName v
qualName VName
fname
          fname'' :: Exp
fname'' = QualName VName -> Info PatternType -> SrcLoc -> Exp
forall (f :: * -> *) vn.
QualName vn -> f PatternType -> SrcLoc -> ExpBase f vn
Var QualName VName
fname' (PatternType -> Info PatternType
forall a. a -> Info a
Info (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
forall a. Monoid a => a
mempty PName
Unnamed (StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t1) (PatternType -> ScalarTypeBase (DimDecl VName) Aliasing)
-> PatternType -> ScalarTypeBase (DimDecl VName) Aliasing
forall a b. (a -> b) -> a -> b
$
                                      ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
forall a. Monoid a => a
mempty PName
Unnamed (StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t2) PatternType
rettype))
                    SrcLoc
loc

          -- FIXME: what if this application returns both a function
          -- and a value?
          callret :: (Info PatternType, Info [VName])
callret | PatternType -> Bool
forall dim as. TypeBase dim as -> Bool
orderZero PatternType
ret = (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
ret, [VName] -> Info [VName]
forall a. a -> Info a
Info [VName]
ext)
                  | Bool
otherwise     = (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
rettype, [VName] -> Info [VName]
forall a. a -> Info a
Info [VName]
ext)

      (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> SrcLoc -> Exp
forall (f :: * -> *) vn. ExpBase f vn -> SrcLoc -> ExpBase f vn
Parens (Exp
-> Exp
-> Info (Diet, Maybe VName)
-> (Info PatternType, Info [VName])
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
ExpBase f vn
-> ExpBase f vn
-> f (Diet, Maybe VName)
-> (f PatternType, f [VName])
-> SrcLoc
-> ExpBase f vn
Apply (Exp
-> Exp
-> Info (Diet, Maybe VName)
-> (Info PatternType, Info [VName])
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
ExpBase f vn
-> ExpBase f vn
-> f (Diet, Maybe VName)
-> (f PatternType, f [VName])
-> SrcLoc
-> ExpBase f vn
Apply Exp
fname'' Exp
e1'
                              ((Diet, Maybe VName) -> Info (Diet, Maybe VName)
forall a. a -> Info a
Info (Diet
Observe, Maybe VName
forall a. Maybe a
Nothing))
                              (PatternType -> Info PatternType
forall a. a -> Info a
Info (PatternType -> Info PatternType)
-> PatternType -> Info PatternType
forall a b. (a -> b) -> a -> b
$ ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) Aliasing -> PatternType)
-> ScalarTypeBase (DimDecl VName) Aliasing -> PatternType
forall a b. (a -> b) -> a -> b
$ Aliasing
-> PName
-> PatternType
-> PatternType
-> ScalarTypeBase (DimDecl VName) Aliasing
forall dim as.
as
-> PName
-> TypeBase dim as
-> TypeBase dim as
-> ScalarTypeBase dim as
Arrow Aliasing
forall a. Monoid a => a
mempty PName
Unnamed (StructType -> PatternType
forall dim as. TypeBase dim as -> TypeBase dim Aliasing
fromStruct StructType
t2) PatternType
rettype,
                               [VName] -> Info [VName]
forall a. a -> Info a
Info [])
                              SrcLoc
loc)
                      Exp
e2' Info (Diet, Maybe VName)
d (Info PatternType, Info [VName])
callret SrcLoc
loc) SrcLoc
forall a. Monoid a => a
mempty, StaticVal
sv)

    -- If e1 is a dynamic function, we just leave the application in place,
    -- but we update the types since it may be partially applied or return
    -- a higher-order term.
    DynamicFun (Exp, StaticVal)
_ StaticVal
sv ->
      let ([PatternType]
argtypes', PatternType
rettype) = StaticVal -> [PatternType] -> ([PatternType], PatternType)
dynamicFunType StaticVal
sv [PatternType]
argtypes
          restype :: PatternType
restype = [PatternType] -> PatternType -> PatternType
forall as dim.
Monoid as =>
[TypeBase dim as] -> TypeBase dim as -> TypeBase dim as
foldFunType [PatternType]
argtypes' PatternType
rettype PatternType -> Aliasing -> PatternType
forall dim asf ast. TypeBase dim asf -> ast -> TypeBase dim ast
`setAliases` PatternType -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases PatternType
ret
          -- FIXME: what if this application returns both a function
          -- and a value?
          callret :: (Info PatternType, Info [VName])
callret | PatternType -> Bool
forall dim as. TypeBase dim as -> Bool
orderZero PatternType
ret = (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
ret, [VName] -> Info [VName]
forall a. a -> Info a
Info [VName]
ext)
                  | Bool
otherwise     = (PatternType -> Info PatternType
forall a. a -> Info a
Info PatternType
restype, [VName] -> Info [VName]
forall a. a -> Info a
Info [VName]
ext)
          apply_e :: Exp
apply_e = Exp
-> Exp
-> Info (Diet, Maybe VName)
-> (Info PatternType, Info [VName])
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
ExpBase f vn
-> ExpBase f vn
-> f (Diet, Maybe VName)
-> (f PatternType, f [VName])
-> SrcLoc
-> ExpBase f vn
Apply Exp
e1' Exp
e2' Info (Diet, Maybe VName)
d (Info PatternType, Info [VName])
callret SrcLoc
loc
      in (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
apply_e, StaticVal
sv)

    -- Propagate the 'IntrinsicsSV' until we reach the outermost application,
    -- where we construct a dynamic static value with the appropriate type.
    StaticVal
IntrinsicSV
      | Int
depth Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0 ->
          -- If the intrinsic is fully applied, then we are done.
          -- Otherwise we need to eta-expand it and recursively
          -- defunctionalise. XXX: might it be better to simply
          -- eta-expand immediately any time we encounter a
          -- non-fully-applied intrinsic?
          if [PatternType] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [PatternType]
argtypes
            then (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
e', PatternType -> StaticVal
Dynamic (PatternType -> StaticVal) -> PatternType -> StaticVal
forall a b. (a -> b) -> a -> b
$ Exp -> PatternType
typeOf Exp
e)
            else do ([Pattern]
pats, Exp
body, StructType
tp) <- PatternType -> Exp -> DefM ([Pattern], Exp, StructType)
etaExpand (Exp -> PatternType
typeOf Exp
e') Exp
e'
                    Exp -> DefM (Exp, StaticVal)
defuncExp (Exp -> DefM (Exp, StaticVal)) -> Exp -> DefM (Exp, StaticVal)
forall a b. (a -> b) -> a -> b
$ [Pattern]
-> Exp
-> Maybe (TypeExp VName)
-> Info (Aliasing, StructType)
-> SrcLoc
-> Exp
forall (f :: * -> *) vn.
[PatternBase f vn]
-> ExpBase f vn
-> Maybe (TypeExp vn)
-> f (Aliasing, StructType)
-> SrcLoc
-> ExpBase f vn
Lambda [Pattern]
pats Exp
body Maybe (TypeExp VName)
forall a. Maybe a
Nothing ((Aliasing, StructType) -> Info (Aliasing, StructType)
forall a. a -> Info a
Info (Aliasing
forall a. Monoid a => a
mempty, StructType
tp)) SrcLoc
forall a. Monoid a => a
mempty
      | Bool
otherwise -> (Exp, StaticVal) -> DefM (Exp, StaticVal)
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp
e', StaticVal
IntrinsicSV)

    StaticVal
_ -> String -> DefM (Exp, StaticVal)
forall a. HasCallStack => String -> a
error (String -> DefM (Exp, StaticVal))
-> String -> DefM (Exp, StaticVal)
forall a b. (a -> b) -> a -> b
$ String
"Application of an expression that is neither a static lambda "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"nor a dynamic function, but has static value: " String -> ShowS
forall a. [a] -> [a] -> [a]
++ StaticVal -> String
forall a. Show a => a -> String
show StaticVal
sv1

-- The head of the application is a plain variable: look up its static
-- value and rewrite the reference according to what kind of value it is.
defuncApply depth e@(Var qn (Info t) loc) = do
    sv <- lookupVar loc (qualLeaf qn)
    case sv of
      DynamicFun{}
        | fullyApplied sv depth -> do
            -- The variable itself is kept, but its type is recomputed,
            -- since the dynamic function may return a higher-order term.
            let (arg_ts, ret_t) = dynamicFunType sv argtypes
            return (Var qn (Info (foldFunType arg_ts ret_t)) loc, sv)

        | otherwise -> do
            -- Not fully applied at this depth: emit a fresh top-level
            -- definition via 'liftDynFun'/'liftValDec' and refer to that
            -- definition instead of the original variable.
            fname <- newName (qualLeaf qn)
            let (dims, pats, body, lifted_sv) = liftDynFun sv depth
                bound_by_pats = S.map identName (foldMap patternIdents pats)
                -- Only sizes that are not already bound by a parameter
                -- pattern are passed on as explicit size parameters.
                free_dims = [ d | d <- dims, d `S.notMember` bound_by_pats ]
                (arg_ts, ret_t) = dynamicFunType lifted_sv argtypes
                ret_t' = fromStruct ret_t
            liftValDec fname ret_t' free_dims pats body
            return ( Var (qualName fname) (Info (foldFunType arg_ts ret_t')) loc
                   , lifted_sv
                   )

      -- Intrinsics are passed through untouched at this point.
      IntrinsicSV -> return (e, IntrinsicSV)

      -- Any other static value: keep the variable, but give it the type
      -- derived from the static value.
      _ -> return (Var qn (Info (typeFromSV sv)) loc, sv)
  where
    -- Argument types of the variable's (possibly functional) type.
    argtypes = fst (unfoldFunType t)

-- Parentheses do not affect the application depth.
defuncApply depth (Parens inner _) = defuncApply depth inner

-- Anything else is not an application head; defunctionalise it normally.
defuncApply _ other = defuncExp other

-- | Decide whether a 'StaticVal', applied at the given application
-- depth, amounts to a fully applied dynamic function.
--
-- Every 'DynamicFun' layer uses up one unit of depth.  Reaching a
-- 'DynamicFun' with depth zero yields 'False'; any static value that
-- is not a 'DynamicFun' yields 'True'.  A negative depth on a
-- 'DynamicFun' satisfies neither guard and therefore falls through to
-- the catch-all alternative.
fullyApplied :: StaticVal -> Int -> Bool
fullyApplied sv depth =
  case sv of
    DynamicFun _ inner
      | depth == 0 -> False
      | depth > 0  -> fullyApplied inner (depth - 1)
    _ -> True

-- | Turn a dynamic function 'StaticVal' into the ingredients of a
-- lifted function for a given depth of partial application: the size
-- names, the parameter patterns, the body expression, and the static
-- value to use for the application at that depth.
--
-- At depth zero the expression and static value stored in the closure
-- are returned with no sizes or parameters.  At a positive depth the
-- closure must carry a 'LambdaSV'; its sizes and pattern are prepended
-- to those collected from the remaining layers, with duplicate size
-- names removed.  Any other combination is reported via 'error'.
liftDynFun :: StaticVal -> Int -> ([VName], [Pattern], Exp, StaticVal)
liftDynFun (DynamicFun (body, body_sv) _) 0 =
  ([], [], body, body_sv)
liftDynFun (DynamicFun closure@(_, LambdaSV lam_dims lam_pat _ _ _) rest) depth
  | depth > 0 =
      ( nub (lam_dims ++ rest_dims)
      , lam_pat : rest_pats
      , body
      , DynamicFun closure rest'
      )
  where (rest_dims, rest_pats, body, rest') = liftDynFun rest (depth - 1)
liftDynFun sv _ =
  error $ concat [ "Tried to lift a StaticVal "
                 , show sv
                 , ", but expected a dynamic function."
                 ]

-- | Build the environment induced by a pattern: every name bound by
-- the pattern is mapped to a 'Dynamic' static value carrying the type
-- recorded at its binding site.  Wildcards and literal patterns bind
-- nothing; wrappers (parentheses, ascriptions) are looked through, and
-- compound patterns contribute the combined environments of their
-- sub-patterns.
envFromPattern :: Pattern -> Env
envFromPattern (Id name (Info t) _)          = M.singleton name (Dynamic t)
envFromPattern (Wildcard _ _)                = mempty
envFromPattern PatternLit{}                  = mempty
envFromPattern (PatternParens inner _)       = envFromPattern inner
envFromPattern (PatternAscription inner _ _) = envFromPattern inner
envFromPattern (TuplePattern elems _)        = foldMap envFromPattern elems
envFromPattern (RecordPattern fields _)      = foldMap envFromPattern (map snd fields)
envFromPattern (PatternConstr _ _ args _)    = foldMap envFromPattern args

-- | Build an environment binding the names of the given shape (size)
-- parameters, by way of 'envFromDimNames'.  Any type parameter that is
-- not a 'TypeParamDim' is reported via 'error', because the
-- defunctionaliser expects monomorphic input.
envFromShapeParams :: [TypeParamBase VName] -> Env
envFromShapeParams tparams = envFromDimNames [ dimName tp | tp <- tparams ]
  where
    dimName (TypeParamDim v _) = v
    dimName other =
      error $ concat
        [ "The defunctionalizer expects a monomorphic input program,\n"
        , "but it received a type parameter "
        , pretty other
        , " at "
        , locStr (srclocOf other)
        , "."
        ]

-- | Build an environment in which each of the given names is bound to
-- a 'Dynamic' static value of signed 32-bit integer type.
envFromDimNames :: [VName] -> Env
envFromDimNames dims = M.fromList [ (dim, i32_sv) | dim <- dims ]
  where
    i32_sv = Dynamic (Scalar (Prim (Signed Int32)))

-- | Emit a new top-level value binding with the given name, return
-- type, size parameters, parameter patterns and body.  The binding is
-- appended to the writer output of 'DefM'; nothing is returned.
liftValDec :: VName -> PatternType -> [VName] -> [Pattern] -> Exp -> DefM ()
liftValDec fname rettype dims pats body =
  tell (Seq.singleton lifted)
  where
    -- Each size name becomes a size type parameter of the binding.
    size_params = [ TypeParamDim d mempty | d <- dims ]

    -- Names that are in scope for the return type: the size parameters
    -- plus everything bound by the parameter patterns.
    in_scope =
      S.fromList dims `S.union` S.map identName (foldMap patternIdents pats)

    -- FIXME: this pass is still not correctly size-preserving, so any
    -- named size in the return type that is not bound here is replaced
    -- by 'AnyDim'.  Hopefully the internaliser is conservative and
    -- will insert reshapes...
    forgetUnbound dim@(NamedDim v)
      | qualLeaf v `S.notMember` in_scope = AnyDim
      | otherwise                         = dim
    forgetUnbound dim = dim

    ret_struct = first forgetUnbound (toStruct rettype)

    lifted = ValBind
      { valBindEntryPoint = Nothing
      , valBindName       = fname
      , valBindRetDecl    = Nothing
      , valBindRetType    = Info (ret_struct, [])
      , valBindTypeParams = size_params
      , valBindParams     = pats
      , valBindBody       = body
      , valBindDoc        = Nothing
      , valBindLocation   = mempty
      }

-- | Construct the record pattern that binds the variables of a closure
-- environment.  Each entry of the environment becomes one field: the
-- field label is the pretty-printed variable name, and the field
-- pattern binds that variable at the type given by 'typeFromSV'.
buildEnvPattern :: Env -> Pattern
buildEnvPattern env = RecordPattern fields mempty
  where
    fields =
      [ (nameFromString (pretty v), Id v (Info (typeFromSV sv)) mempty)
      | (v, sv) <- M.toList env
      ]

-- | Given a closure environment pattern and the type of a term,
-- construct the type of that term, where uniqueness is set to
-- `Nonunique` for those arrays that are bound in the environment or
-- pattern (except if they are unique there).  This ensures that a
-- lifted function can create unique arrays as long as they do not
-- alias any of its parameters.  XXX: it is not clear that this is a
-- sufficient property, unfortunately.
buildRetType :: Env -> [Pattern] -> StructType -> PatternType -> PatternType
buildRetType :: Env -> [Pattern] -> StructType -> PatternType -> PatternType
-- Combine a declared return type (third argument) with the type
-- computed for the body (fourth argument).  Records and sums are
-- combined component-wise; otherwise the declared type is used with
-- the aliases of the computed type.  Arrays that alias a name bound
-- by the environment or the patterns are made nonunique, unless that
-- name is bound with a unique type by one of the patterns.
buildRetType env pats = comb
  where
    -- Every name bound by the environment or by the patterns.
    bound = foldMap oneName (M.keys env) <> foldMap patternVars pats

    -- The identifiers bound by the patterns.
    paramIdents = S.toList (foldMap patternIdents pats)

    -- Is this name bound by a pattern identifier of unique type?
    boundAsUnique v =
      case find (\i -> identName i == v) paramIdents of
        Just i  -> unique (unInfo (identType i))
        Nothing -> False

    -- Bound here, but not as a unique pattern identifier.
    problematic v = v `member` bound && not (boundAsUnique v)

    comb (Scalar (Record annotFields)) (Scalar (Record gotFields)) =
      Scalar (Record (M.intersectionWith comb annotFields gotFields))
    comb (Scalar (Sum annotConstrs)) (Scalar (Sum gotConstrs)) =
      Scalar (Sum (M.intersectionWith (zipWith comb) annotConstrs gotConstrs))
    comb (Scalar Arrow{}) inferred =
      descend inferred
    comb annot inferred =
      descend (setAliases (fromStruct annot) (aliases inferred))

    -- Strip uniqueness from arrays with a problematic alias, looking
    -- inside records; everything else is returned unchanged.
    descend t =
      case t of
        Array{}
          | any (problematic . aliasVar) (aliases t) ->
              t `setUniqueness` Nonunique
        Scalar (Record fields) ->
          Scalar (Record (fmap descend fields))
        _ -> t

-- | Compute the type that corresponds to a static value.  Calling
-- this on the static value of an intrinsic is an error.
typeFromSV :: StaticVal -> PatternType
typeFromSV staticVal =
  case staticVal of
    Dynamic tp ->
      tp
    LambdaSV sizes _ _ _ env ->
      -- The closure is typed as its environment, with the size
      -- parameters taken out of scope.
      unscopeType (S.fromList sizes) (typeFromEnv env)
    RecordSV ls ->
      Scalar (Record (M.fromList [(field, typeFromSV sv) | (field, sv) <- ls]))
    DynamicFun (_, sv) _ ->
      typeFromSV sv
    SumSV name svs fields ->
      -- The present constructor gets the types of its payload; the
      -- absent constructors keep the types recorded for them.
      Scalar (Sum (M.insert name (map typeFromSV svs) (M.fromList fields)))
    IntrinsicSV ->
      error "Tried to get the type from the static value of an intrinsic."

-- | The record type of an environment: one field per bound name,
-- labelled with the pretty-printed name and typed via 'typeFromSV'.
typeFromEnv :: Env -> PatternType
typeFromEnv env =
  Scalar (Record (M.fromList (map field (M.toList env))))
  where
    field (vn, sv) = (nameFromString (pretty vn), typeFromSV sv)

-- | Construct the type for a fully-applied dynamic function from its
-- static value and the original types of its arguments.  One argument
-- type is consumed per nested 'DynamicFun'; the result type is that
-- of the first static value reached that is not a 'DynamicFun' (or
-- of whatever remains once the argument types run out).
dynamicFunType :: StaticVal -> [PatternType] -> ([PatternType], PatternType)
dynamicFunType (DynamicFun _ inner) (argType : argTypes) = (argType : rest, ret)
  where
    (rest, ret) = dynamicFunType inner argTypes
dynamicFunType sv _ = ([], typeFromSV sv)

-- | Match a pattern with its static value. Returns an environment with
-- the identifier components of the pattern mapped to the corresponding
-- subcomponents of the static value.
--
-- Combinations of pattern and static value that are not handled, as
-- well as a constructor pattern whose constructor is missing from the
-- sum type, result in a call to 'error'.
matchPatternSV :: PatternBase Info VName -> StaticVal -> Env
matchPatternSV (TuplePattern ps _) (RecordSV ls) =
  -- Components are paired positionally with the record fields.
  mconcat $ zipWith (\p (_, sv) -> matchPatternSV p sv) ps ls
matchPatternSV (RecordPattern ps _) (RecordSV ls)
  -- Only matches when both sides have exactly the same field names;
  -- otherwise we fall through to the error clause at the bottom.
  | ps' <- sortOn fst ps, ls' <- sortOn fst ls,
    map fst ps' == map fst ls' =
      mconcat $ zipWith (\(_, p) (_, sv) -> matchPatternSV p sv) ps' ls'
matchPatternSV (PatternParens pat _) sv = matchPatternSV pat sv
matchPatternSV (Id vn (Info t) _) sv =
  -- When matching a pattern with a zero-order StaticVal, the type of
  -- the pattern wins out.  This is important when matching a
  -- nonunique pattern with a unique value.
  if orderZeroSV sv
  then M.singleton vn $ Dynamic t
  else M.singleton vn sv
matchPatternSV (Wildcard _ _) _ = mempty
matchPatternSV (PatternAscription pat _ _) sv = matchPatternSV pat sv
matchPatternSV PatternLit{} _ = mempty
matchPatternSV (PatternConstr c1 _ ps _) (SumSV c2 ls fs)
  | c1 == c2 =
      mconcat $ zipWith matchPatternSV ps ls
  -- The pattern names a constructor other than the one present; match
  -- against static values derived from the recorded payload types.
  | Just ts <- lookup c1 fs =
      mconcat $ zipWith matchPatternSV ps $ map svFromType ts
  | otherwise =
      error $ "matchPatternSV: missing constructor in type: " ++ pretty c1
matchPatternSV (PatternConstr c1 _ ps _) (Dynamic (Scalar (Sum fs)))
  | Just ts <- M.lookup c1 fs =
      mconcat $ zipWith matchPatternSV ps $ map svFromType ts
  | otherwise =
      error $ "matchPatternSV: missing constructor in type: " ++ pretty c1
matchPatternSV pat (Dynamic t)
  -- Only unwrap when the type really is a record.  For any other type
  -- 'svFromType' returns @Dynamic t@ again, and recursing on that
  -- would never terminate; fall through to the error clause instead.
  | sv@RecordSV{} <- svFromType t = matchPatternSV pat sv
matchPatternSV pat sv = error $ "Tried to match pattern " ++ pretty pat
                             ++ " with static value " ++ show sv ++ "."

-- | Is the static value 'Dynamic', or a record all of whose fields
-- (recursively) satisfy 'orderZeroSV'?
orderZeroSV :: StaticVal -> Bool
orderZeroSV staticVal =
  case staticVal of
    Dynamic{}       -> True
    RecordSV fields -> all (\(_, sv) -> orderZeroSV sv) fields
    _               -> False

-- | Given a pattern and the static value for the defunctionalized argument,
-- update the pattern to reflect the changes in the types.
--
-- Subpatterns whose type is already of order zero are left untouched.
-- Combinations of pattern and static value that are not handled
-- result in a call to 'error'.
updatePattern :: Pattern -> StaticVal -> Pattern
updatePattern (TuplePattern ps loc) (RecordSV svs) =
  -- Components are paired positionally with the record fields.
  TuplePattern (zipWith updatePattern ps $ map snd svs) loc
updatePattern (RecordPattern ps loc) (RecordSV svs)
  -- Sort both sides by field name so that they line up.
  | ps' <- sortOn fst ps, svs' <- sortOn fst svs =
      RecordPattern (zipWith (\(n, p) (_, sv) ->
                                (n, updatePattern p sv)) ps' svs') loc
updatePattern (PatternParens pat loc) sv =
  PatternParens (updatePattern pat sv) loc
updatePattern pat@(Id vn (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Id vn (Info $ typeFromSV sv `setUniqueness` Nonunique) loc
updatePattern pat@(Wildcard (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Wildcard (Info $ typeFromSV sv) loc
updatePattern (PatternAscription pat tydecl loc) sv
  | orderZero . unInfo $ expandedType tydecl =
      PatternAscription (updatePattern pat sv) tydecl loc
  -- The ascribed type is not of order zero; drop the ascription.
  | otherwise = updatePattern pat sv
updatePattern p@PatternLit{} _ = p
updatePattern pat@(PatternConstr c1 (Info t) ps loc) sv@(SumSV _ svs _)
  | orderZero t = pat
  | otherwise = PatternConstr c1 (Info t') ps' loc
  where t' = typeFromSV sv `setUniqueness` Nonunique
        ps' = zipWith updatePattern ps svs
updatePattern (PatternConstr c1 _ ps loc) (Dynamic t) =
  PatternConstr c1 (Info t) ps loc
updatePattern pat (Dynamic t)
  -- Only unwrap when the type really is a record.  For any other type
  -- 'svFromType' returns @Dynamic t@ again, and recursing on that
  -- would never terminate; fall through to the error clause instead.
  | sv@RecordSV{} <- svFromType t = updatePattern pat sv
updatePattern pat sv =
  error $ "Tried to update pattern " ++ pretty pat
       ++ " to reflect the static value " ++ show sv

-- | Turn a record (or tuple) type into a 'RecordSV', converting the
-- field types recursively; any other type becomes a 'Dynamic' static
-- value.  This is used for "unwrapping" tuples and records that are
-- nested in 'Dynamic' static values.
svFromType :: PatternType -> StaticVal
svFromType tp =
  case tp of
    Scalar (Record fields) -> RecordSV (M.toList (M.map svFromType fields))
    _                      -> Dynamic tp

-- | A set of names in which every member also carries a 'Uniqueness'.
newtype NameSet = NameSet (M.Map VName Uniqueness)
  deriving (Show)

-- | Union of the two sets; a name present on both sides keeps the
-- larger (by 'max') of its two uniqueness attributes.
instance Semigroup NameSet where
  (<>) (NameSet lhs) (NameSet rhs) = NameSet (M.unionWith max lhs rhs)

-- | The identity is the empty set.
instance Monoid NameSet where
  mempty = NameSet M.empty

-- | The names of the first set that do not occur in the second.
without :: NameSet -> NameSet -> NameSet
without (NameSet keep) (NameSet remove) = NameSet (M.difference keep remove)

-- | Does the name occur in the set?
member :: VName -> NameSet -> Bool
member name (NameSet entries) = M.member name entries

-- | The singleton set holding the name of an identifier, tagged with
-- the uniqueness of the identifier's type.
ident :: Ident -> NameSet
ident v = NameSet (M.singleton name u)
  where
    name = identName v
    u = uniqueness (unInfo (identType v))

-- | The singleton set holding one name, tagged as 'Nonunique'.
oneName :: VName -> NameSet
oneName name = NameSet (M.singleton name Nonunique)

-- | Convert an ordinary set of names into a 'NameSet' in which every
-- name is tagged as 'Nonunique' (see 'oneName').
names :: S.Set VName -> NameSet
names vs = foldMap oneName (S.toList vs)

-- | Compute the set of free variables of an expression.
--
-- Apart from term-level variables, the result also contains the
-- names that occur as sizes in the type annotations and patterns
-- inspected below (via 'typeDimNames' and 'patternDimNames').  Names
-- bound by a construct are removed from the free variables of those
-- subexpressions in which the binding is in scope.
freeVars :: Exp -> NameSet
freeVars expr = case expr of
  Literal{}            -> mempty
  IntLit{}             -> mempty
  FloatLit{}           -> mempty
  StringLit{}          -> mempty
  Parens e _           -> freeVars e
  QualParens _ e _     -> freeVars e
  TupLit es _          -> foldMap freeVars es

  RecordLit fs _       -> foldMap freeVarsField fs
    where freeVarsField (RecordFieldExplicit _ e _)  = freeVars e
          -- An implicit field is a use of the variable of that name.
          freeVarsField (RecordFieldImplicit vn t _) = ident $ Ident vn t mempty

  ArrayLit es t _      -> foldMap freeVars es <>
                          names (typeDimNames $ unInfo t)
  Range e me incl _ _  -> freeVars e <> foldMap freeVars me <>
                          foldMap freeVars incl
  -- A variable is recorded together with the uniqueness of its type.
  Var qn (Info t) _    -> NameSet $ M.singleton (qualLeaf qn) $ uniqueness t
  Ascript e t _        -> freeVars e <> names (typeDimNames $ unInfo $ expandedType t)
  Coerce e t _ _       -> freeVars e <> names (typeDimNames $ unInfo $ expandedType t)
  -- The pattern binds its variables in the body and in its own sizes,
  -- but not in the bound expression.
  LetPat pat e1 e2 _ _ -> freeVars e1 <> ((names (patternDimNames pat) <> freeVars e2)
                                          `without` patternVars pat)

  -- The parameters are bound in the function body; the function name
  -- is bound only in the continuation.
  LetFun vn (_, pats, _, _, e1) e2 _ _ ->
    ((freeVars e1 <> names (foldMap patternDimNames pats))
      `without` foldMap patternVars pats) <>
    (freeVars e2 `without` oneName vn)

  If e1 e2 e3 _ _             -> freeVars e1 <> freeVars e2 <> freeVars e3
  Apply e1 e2 _ _ _           -> freeVars e1 <> freeVars e2
  Negate e _                  -> freeVars e
  Lambda pats e0 _ _ _        -> (names (foldMap patternDimNames pats) <> freeVars e0)
                                 `without` foldMap patternVars pats
  OpSection{}                 -> mempty
  OpSectionLeft _  _ e _ _ _  -> freeVars e
  OpSectionRight _ _ e _ _ _  -> freeVars e
  ProjectSection{}            -> mempty
  IndexSection idxs _ _       -> foldMap freeDimIndex idxs

  DoLoop sparams pat e1 form e3 _ _ ->
    freeVars e1 <> formFree <>
    (freeVars e3 `without` (loopBound <> formBound))
    where -- Names bound by the loop itself: its size parameters and
          -- the variables of the merge pattern.
          loopBound = names (S.fromList sparams) <> patternVars pat
          -- Free variables contributed by the loop form, and the
          -- additional names the form binds in the loop body.
          (formFree, formBound) = formVars form
          formVars (For v e2)   = (freeVars e2, ident v)
          formVars (ForIn p e2) = (freeVars e2, patternVars p)
          -- The condition of a while-loop is evaluated with the loop
          -- parameters in scope, so occurrences of them in the
          -- condition are not free in the loop expression.
          formVars (While e2)   = (freeVars e2 `without` loopBound, mempty)

  -- The operator itself counts as a (nonunique) free name.
  BinOp (qn, _) _ (e1, _) (e2, _) _ _ _ -> oneName (qualLeaf qn) <>
                                           freeVars e1 <> freeVars e2
  Project _ e _ _                -> freeVars e

  -- The source array is used; the destination name is bound only in
  -- the continuation.
  LetWith id1 id2 idxs e1 e2 _ _ ->
    ident id2 <> foldMap freeDimIndex idxs <> freeVars e1 <>
    (freeVars e2 `without` ident id1)

  Index e idxs _ _         -> freeVars e <> foldMap freeDimIndex idxs
  Update e1 idxs e2 _      -> freeVars e1 <> foldMap freeDimIndex idxs <> freeVars e2
  RecordUpdate e1 _ e2 _ _ -> freeVars e1 <> freeVars e2

  Unsafe e _          -> freeVars e
  Assert e1 e2 _ _    -> freeVars e1 <> freeVars e2
  Constr _ es _ _     -> foldMap freeVars es
  Attr _ e _          -> freeVars e
  Match e cs _ _      -> freeVars e <> foldMap caseFV cs
    where -- Each case pattern binds its variables in the case body and
          -- in the pattern's own sizes.
          caseFV (CasePat p eCase _) = (names (patternDimNames p) <> freeVars eCase)
                                       `without` patternVars p

-- | The free variables of a single index: those of the index
-- expression for a fixed index, or those of whichever slice
-- components are present for a slice.
freeDimIndex :: DimIndexBase Info VName -> NameSet
freeDimIndex idx = case idx of
  DimFix e                  -> freeVars e
  DimSlice start end stride -> foldMap freeVars $ catMaybes [start, end, stride]

-- | The names of all the identifiers that a pattern binds, collected
-- by applying 'ident' to each of the pattern's identifiers.
patternVars :: Pattern -> NameSet
patternVars pat = foldMap ident (S.toList (patternIdents pat))

-- | Defunctionalize a top-level value binding.  The result is the
-- transformed binding, an environment that maps the name of the
-- binding to the static value of its transformed body, and a boolean
-- that is true exactly when that static value is a 'DynamicFun'.
--
-- A binding whose return type is a function is first eta-expanded
-- with 'etaExpand': the new parameters are appended after the
-- existing ones, the return type declaration and doc comment are
-- dropped, and the expanded binding is then processed again.
defuncValBind :: ValBind -> DefM (ValBind, Env, Bool)
defuncValBind valbind@(ValBind entry name retdecl (Info (rettype, retext)) tparams params body _ loc) =
  case rettype of
    Scalar Arrow{} -> do
      (extra_params, expanded_body, expanded_rettype) <-
        etaExpand (fromStruct rettype) body
      defuncValBind $
        ValBind entry name Nothing (Info (expanded_rettype, retext))
        tparams (params <> extra_params) expanded_body Nothing loc

    _ -> do
      (tparams', params', body', sv) <- defuncLet tparams params body rettype
      let -- Return type taken from the transformed body, with its sizes
          -- passed through 'anySizes' and then combined with the
          -- original return type by 'combineTypeShapes'.
          shaped_rettype =
            combineTypeShapes rettype (anySizes (toStruct (typeOf body')))
          -- A binding left without parameters gets a nonunique
          -- return type.
          final_rettype
            | null params' = shaped_rettype `setUniqueness` Nonunique
            | otherwise    = shaped_rettype
          is_dynamic = case sv of
            DynamicFun{} -> True
            _            -> False
          valbind' =
            valbind { valBindRetDecl    = retdecl
                    , valBindRetType    = Info (final_rettype, retext)
                    , valBindTypeParams = tparams'
                    , valBindParams     = params'
                    , valBindBody       = body'
                    }
      return (valbind', M.singleton name sv, is_dynamic)

-- | Defunctionalize a list of top-level declarations, in order.
-- Every binding is transformed with the environments of all earlier
-- bindings in scope.  The functions collected while transforming a
-- binding are placed directly in front of that binding in the result.
defuncVals :: [ValBind] -> DefM (Seq.Seq ValBind)
defuncVals [] = pure Seq.empty
defuncVals (valbind : rest) = do
  ((valbind', env, dyn), lifted) <- collectFuns (defuncValBind valbind)
  let -- A binding that became a dynamic function is registered as
      -- global while the remaining declarations are processed.
      withGlobal
        | dyn       = isGlobal (valBindName valbind')
        | otherwise = id
  rest' <- localEnv env (withGlobal (defuncVals rest))
  pure ((lifted Seq.|> valbind') <> rest')

-- | Transform a list of top-level value bindings.  The
-- transformation is run against the current name source, which is
-- replaced by the one it leaves behind.  Any lifted function
-- definitions that 'runDefM' returns are placed in front of the
-- transformed declarations.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg decs = modifyNameSource defunctionalise
  where
    defunctionalise src = (toList (lifted <> decs'), src')
      where (decs', src', lifted) = runDefM src (defuncVals decs)