-- | Lambda-lifting of typed, monomorphic Futhark programs without
-- modules.  After this pass, the program will no longer contain any
-- 'LetFun's or 'Lambda's.
module Futhark.Internalise.LiftLambdas (transformProg) where

import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Foldable
import Data.List (partition)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.MonadFreshNames
import Language.Futhark
import Language.Futhark.Traversals

-- | Read-only environment for lifting: maps the name of each local function
-- that has already been lifted to the expression that must replace every
-- reference to it.
newtype Env = Env {envReplace :: M.Map VName Exp}

-- | The starting environment, in which no names are being replaced.
initialEnv :: Env
initialEnv = Env M.empty

-- | Mutable state threaded through the lifting pass.
data LiftState = State
  { -- | Supply of fresh names (backs the 'MonadFreshNames' instance).
    stateNameSource :: VNameSource,
    -- | Value bindings emitted so far, most recently added first.
    stateValBinds :: [ValBind],
    -- | Names bound at the top level: the intrinsics plus every name bound
    -- by an emitted value binding.  These are never captured as free
    -- variables when a function is lifted.
    stateGlobal :: S.Set VName
  }

-- | The starting state.  It uses the given name source, has no emitted
-- bindings yet, and treats only the intrinsics as global names.
initialState :: VNameSource -> LiftState
initialState src =
  State
    { stateNameSource = src,
      stateValBinds = [],
      stateGlobal = M.keysSet intrinsics
    }

-- | The lifting monad.  It pairs a read-only 'Env' of pending name
-- replacements with the mutable 'LiftState'.
newtype LiftM a = LiftM (ReaderT Env (State LiftState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader Env,
      MonadState LiftState
    )

-- | Fresh names are read from and written back to 'stateNameSource'.
instance MonadFreshNames LiftM where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\st -> st {stateNameSource = src})

-- | Run a lifting computation, starting from 'initialEnv' and from
-- 'initialState' built on the given name source.
--
-- Returns the emitted value bindings in the order they were added, together
-- with the final name source.  'addValBind' prepends, so the list is
-- reversed here.
runLiftM :: VNameSource -> LiftM () -> ([ValBind], VNameSource)
runLiftM src (LiftM m) =
  (reverse (stateValBinds final), stateNameSource final)
  where
    final = execState (runReaderT m initialEnv) (initialState src)

-- | Emit a value binding.
--
-- The binding is prepended to 'stateValBinds'.  Every name it binds is
-- added to 'stateGlobal', so later lifted functions do not capture those
-- names as free variables.
addValBind :: ValBind -> LiftM ()
addValBind vb = modify register
  where
    register st =
      st
        { stateValBinds = vb : stateValBinds st,
          stateGlobal = foldl' (flip S.insert) (stateGlobal st) (valBindBound vb)
        }

-- | Run an action in which every reference to the given name is replaced by
-- the given expression.  The 'Var' case of 'transformExp' consults this
-- mapping.
replacing :: VName -> Exp -> LiftM a -> LiftM a
replacing v e = local extend
  where
    extend env = env {envReplace = M.insert v e (envReplace env)}

-- | Collect the size names introduced by applications anywhere within an
-- expression:
--
-- * for every 'Apply', the size names attached to its arguments;
--
-- * for every 'AppExp', the names listed in its 'appResExt'.
--
-- 'liftFunction' uses these names to avoid treating such sizes as free
-- variables of a lifted body.
existentials :: Exp -> S.Set VName
existentials e = execState (astMap collect e) here
  where
    -- Sizes introduced directly by this node.
    here = case e of
      AppExp (Apply _ args _) (Info res) ->
        S.fromList (foldMap onArg args <> appResExt res)
      AppExp _ (Info res) ->
        S.fromList (appResExt res)
      _ -> mempty

    onArg (Info (_, pdim), _) = maybeToList pdim

    -- For each subexpression 'astMap' hands us, add that subexpression's
    -- sizes to the accumulator and return it unchanged.
    collect :: ASTMapper (State (S.Set VName))
    collect =
      identityMapper {mapOnExp = \sub -> sub <$ modify (<> existentials sub)}

-- | Treat each of the given names as a free variable of type @i64@.
freeSizes :: S.Set VName -> FV
freeSizes vs = FV (M.fromSet (const i64) vs)
  where
    i64 = Scalar $ Prim $ Signed Int64

-- | Turn a local function into a new top-level value binding named @fname@.
-- Return the expression that should stand in for the original function.
--
-- The following names are not captured:
--
-- * names already global;
--
-- * names bound by @params@ or @tparams@;
--
-- * the return type's existential sizes (@dims@);
--
-- * sizes reported by 'existentials' for the body.
--
-- Every other name free in @funbody@ becomes an extra leading parameter of
-- the lifted binding.  Names free in the types of those variables, in
-- @params@, or in the return type are captured too (as @i64@ sizes), unless
-- they are already bound.  Size-like captures come first.
--
-- The returned expression applies the new binding to the captured variables
-- one argument at a time.  Each partial application is annotated with the
-- function type that remains (see @augType@).
liftFunction :: VName -> [TypeParam] -> [Pat] -> StructRetType -> Exp -> LiftM Exp
liftFunction fname tparams params (RetType dims ret) funbody = do
  global <- gets stateGlobal
  -- Names in scope inside the lifted function without being passed as
  -- extra parameters.
  let bound =
        global
          <> foldMap patNames params
          <> S.fromList (map typeParamName tparams)
          <> S.fromList dims

      -- Free variables of the body, plus sizes that the types involved
      -- depend on.
      free =
        let immediate_free =
              freeInExp funbody `freeWithout` (bound <> existentials funbody)
            sizes_in_free = foldMap freeInType $ M.elems $ unFV immediate_free
            sizes =
              freeSizes $
                sizes_in_free
                  <> foldMap freeInPat params
                  <> freeInType ret
         in M.toList $ unFV $ immediate_free <> (sizes `freeWithout` bound)

      -- Those parameters that correspond to sizes must come first.
      sizes_in_types =
        foldMap freeInType (ret : map snd free ++ map patternStructType params)
      isSize (v, _) = v `S.member` sizes_in_types
      (free_dims, free_nondims) = partition isSize free
      captured = free_dims ++ free_nondims

      -- Captured variables become leading parameters, never unique.
      free_params =
        map (mkParam . second (`setUniqueness` Nonunique)) captured

  addValBind $
    ValBind
      { valBindName = fname,
        valBindTypeParams = tparams,
        valBindParams = free_params ++ params,
        valBindRetDecl = Nothing,
        valBindRetType = Info (RetType dims ret),
        valBindBody = funbody,
        valBindDoc = Nothing,
        valBindAttrs = mempty,
        valBindLocation = mempty,
        valBindEntryPoint = Nothing
      }

  pure $ apply (Var (qualName fname) (Info (augType captured)) mempty) captured
  where
    orig_type = funType params $ RetType dims ret
    mkParam (v, t) = Id v (Info (fromStruct t)) mempty
    freeVar (v, t) = Var (qualName v) (Info (fromStruct t)) mempty
    -- Type of the lifted function once the captured variables before
    -- @rem_free@ have been supplied.
    augType rem_free =
      fromStruct $ funType (map mkParam rem_free) $ RetType [] orig_type

    -- Apply @f@ to the captured variables, one argument per application.
    apply :: Exp -> [(VName, StructType)] -> Exp
    apply f [] = f
    apply f (p : rem_ps) =
      let inner_ret = AppRes (fromStruct (augType rem_ps)) mempty
          inner = mkApply f [(Observe, Nothing, freeVar p)] inner_ret
       in apply inner rem_ps

-- | Remove every 'LetFun' and 'Lambda' from an expression.
--
-- Each one is lifted to a new top-level binding via 'liftFunction'.  A
-- lambda is replaced by the returned call expression.  A 'LetFun' is
-- replaced by its body, in which references to the local function are
-- rewritten to that call expression.  All other expressions are traversed
-- structurally.
transformExp :: Exp -> LiftM Exp
transformExp expr = case expr of
  AppExp (LetFun fname (tparams, params, _, Info ret, funbody) body _) _ -> do
    -- Transform the function body first, so the lifted binding contains no
    -- local functions itself.
    funbody' <- transformExp funbody
    fname' <- newVName $ "lifted_" ++ baseString fname
    lifted_call <- liftFunction fname' tparams params ret funbody'
    replacing fname lifted_call $ transformExp body
  Lambda params body _ (Info (_, ret)) _ -> do
    body' <- transformExp body
    fname <- newVName "lifted_lambda"
    liftFunction fname [] params ret body'
  Var v _ _ ->
    -- Note that function-typed variables can only occur in expressions,
    -- not in other places where VNames/QualNames can occur.
    asks $ fromMaybe expr . M.lookup (qualLeaf v) . envReplace
  _ -> astMap recurse expr
  where
    recurse :: ASTMapper LiftM
    recurse = identityMapper {mapOnExp = transformExp}

-- | Lift all local functions out of a value binding's body, then emit the
-- rewritten binding.  Bindings lifted while transforming the body are
-- emitted before this one.
transformValBind :: ValBind -> LiftM ()
transformValBind vb =
  transformExp (valBindBody vb) >>= \body' -> addValBind vb {valBindBody = body'}

{-# NOINLINE transformProg #-}

-- | Perform the transformation: lift every local function ('LetFun') and
-- every 'Lambda' in the given value bindings to a new top-level binding.
-- In the result, each lifted binding precedes the binding it was lifted
-- from.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg vbinds = modifyNameSource (`runLiftM` liftAll)
  where
    liftAll = traverse_ transformValBind vbinds