{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE Trustworthy #-}
module Futhark.Internalise.LiftLambdas (transformProg) where
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Foldable
import Data.List (partition)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.IR.Pretty ()
import Futhark.MonadFreshNames
import Language.Futhark
import Language.Futhark.Traversals
-- | Read-only environment of the lifting pass.
newtype Env = Env
  { -- | Maps the name of each local function that has been lifted to the
    -- expression that replaces references to it: the new top-level
    -- function partially applied to the original function's free variables.
    envReplace :: M.Map VName Exp
  }
-- | The environment at the start of the pass: nothing is being replaced.
initialEnv :: Env
initialEnv = Env M.empty
-- | State threaded through the lifting pass.
data LiftState = State
  { -- | Source of fresh names.
    stateNameSource :: VNameSource,
    -- | Top-level bindings emitted so far, most recently emitted first.
    stateValBinds :: [ValBind],
    -- | Names bound at the top level (intrinsics plus every name bound by
    -- an emitted binding); these are never treated as free variables.
    stateGlobal :: S.Set VName
  }
-- | Initial pass state: the given name source, no bindings emitted yet,
-- and every intrinsic considered globally bound.
initialState :: VNameSource -> LiftState
initialState src =
  State
    { stateNameSource = src,
      stateValBinds = [],
      -- keysSet builds the set in linear time, unlike fromList . keys.
      stateGlobal = M.keysSet intrinsics
    }
-- | The lifting monad: a read-only 'Env' of pending replacements layered
-- over mutable 'LiftState'.  All instances are derived from the
-- underlying transformer stack via GeneralizedNewtypeDeriving.
newtype LiftM a = LiftM (ReaderT Env (State LiftState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader Env,
      MonadState LiftState
    )
-- | Fresh names are read from, and written back to, 'stateNameSource'.
instance MonadFreshNames LiftM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}
-- | Run a lifting computation from the initial environment and state.
--
-- Returns the emitted top-level bindings in emission order, together
-- with the final name source.
runLiftM :: VNameSource -> LiftM () -> ([ValBind], VNameSource)
runLiftM src (LiftM m) =
  (reverse (stateValBinds final), stateNameSource final)
  where
    -- Bindings are accumulated by consing, so they are reversed above
    -- to recover the order in which they were emitted.
    final = execState (runReaderT m initialEnv) (initialState src)
-- | Emit a top-level binding.  Every name it binds is recorded as global,
-- so later lifted functions do not treat it as a free variable.
addValBind :: ValBind -> LiftM ()
addValBind vb = modify $ \s ->
  s
    { stateValBinds = vb : stateValBinds s,
      stateGlobal = foldl' (flip S.insert) (stateGlobal s) (valBindBound vb)
    }
-- | Run an action in an environment where references to the given name
-- are to be replaced by the given expression.
replacing :: VName -> Exp -> LiftM a -> LiftM a
replacing v e = local $ \env ->
  env {envReplace = M.insert v e $ envReplace env}
-- | All existential size names introduced anywhere inside an expression.
--
-- For each 'AppExp' this collects the names in its 'appResExt'.  For an
-- 'Apply' it additionally collects the parameter size recorded in its
-- annotation, if any.
existentials :: Exp -> S.Set VName
existentials e = execState (astMap m e) here
  where
    -- Existentials introduced by this node itself.
    here = case e of
      AppExp (Apply _ _ (Info (_, pdim)) _) (Info res) ->
        S.fromList (maybeToList pdim ++ appResExt res)
      AppExp _ (Info res) ->
        S.fromList (appResExt res)
      _ ->
        mempty

    -- Each subexpression visited by astMap is handled by a recursive call
    -- to 'existentials', and is returned unchanged.  modify' forces each
    -- union as it is performed; the lazy 'State' monad would otherwise
    -- accumulate a chain of unevaluated unions.
    m :: ASTMapper (State (S.Set VName))
    m = identityMapper {mapOnExp = \e' -> modify' (<> existentials e') >> pure e'}
-- | A free-variable set in which each of the given names has type @i64@,
-- the type of sizes.
freeSizes :: S.Set VName -> FV
freeSizes vs = FV $ M.fromSet (const i64) vs
  where
    -- fromSet builds the map in linear time from the already-sorted set.
    i64 = Scalar $ Prim $ Signed Int64
-- | Lift a function to a new top-level binding.
--
-- The function is given by its new name, type parameters, parameters,
-- return type and (already transformed) body.  The result is the
-- expression that should replace references to the original function.
--
-- Every variable that is free in the body becomes an extra leading
-- parameter of the lifted function.  A variable is not free if it is
-- global, a parameter, a type parameter, a return-size name, or an
-- existential size bound inside the body.  Sizes appearing in the types
-- of free variables, parameters or the return type are also added,
-- unless bound.  The returned expression is the lifted function applied
-- to those free variables.
liftFunction :: VName -> [TypeParam] -> [Pat] -> StructRetType -> Exp -> LiftM Exp
liftFunction fname tparams params (RetType dims ret) funbody = do
  global <- gets stateGlobal
  let -- Names that can never be free in the function.
      bound =
        global
          <> foldMap patNames params
          <> S.fromList (map typeParamName tparams)
          <> S.fromList dims

      free =
        let immediate_free =
              freeInExp funbody `freeWithout` (bound <> existentials funbody)
            sizes_in_free =
              foldMap freeInType $ M.elems $ unFV immediate_free
            -- Sizes used in the types involved must be passed in as well.
            sizes =
              freeSizes $
                sizes_in_free
                  <> foldMap freeInPat params
                  <> freeInType ret
         in M.toList $ unFV $ immediate_free <> (sizes `freeWithout` bound)

      -- Free variables that correspond to sizes must come first
      -- (presumably so that they are in scope for later parameter types).
      sizes_in_types =
        foldMap freeInType (ret : map snd free ++ map patternStructType params)
      isSize (v, _) = v `S.member` sizes_in_types
      (free_dims, free_nondims) = partition isSize free
      all_free = free_dims ++ free_nondims

      free_params =
        map (mkParam . second (`setUniqueness` Nonunique)) all_free

  addValBind $
    ValBind
      { valBindName = fname,
        valBindTypeParams = tparams,
        valBindParams = free_params ++ params,
        valBindRetDecl = Nothing,
        valBindRetType = Info (RetType dims ret),
        valBindBody = funbody,
        valBindDoc = Nothing,
        valBindAttrs = mempty,
        valBindLocation = mempty,
        valBindEntryPoint = Nothing
      }

  -- NOTE(review): the extra parameters above are declared 'Nonunique', but
  -- 'augType' below receives the free-variable types unmodified.  Confirm
  -- whether the reference's type annotation should agree with the binding.
  pure $ apply (Var (qualName fname) (Info (augType all_free)) mempty) all_free
  where
    orig_type = funType params $ RetType dims ret

    mkParam (v, t) = Id v (Info (fromStruct t)) mempty

    freeVar (v, t) = Var (qualName v) (Info (fromStruct t)) mempty

    -- Type of the lifted function when only the given free-variable
    -- parameters remain to be supplied: those parameters, followed by the
    -- original function type.
    augType rem_free =
      fromStruct $ funType (map mkParam rem_free) $ RetType [] orig_type

    -- Apply 'f' to the free variables one at a time.  Each partial
    -- application is annotated with the type of the remaining function.
    apply :: Exp -> [(VName, StructType)] -> Exp
    apply f [] = f
    apply f (p : rem_ps) =
      let inner_ret = AppRes (fromStruct (augType rem_ps)) mempty
          inner =
            AppExp
              (Apply f (freeVar p) (Info (Observe, Nothing)) mempty)
              (Info inner_ret)
       in apply inner rem_ps
-- | Lambda-lift an expression.
--
-- Each local function definition ('LetFun') and anonymous function
-- ('Lambda') is lifted to a new top-level function.  A 'Lambda' is
-- replaced by the lifted call.  For a 'LetFun', references to the
-- function within its scope are replaced by the lifted call.  All other
-- expressions are traversed structurally.
transformExp :: Exp -> LiftM Exp
transformExp (AppExp (LetFun fname (tparams, params, _, Info ret, funbody) body _) _) = do
  funbody' <- transformExp funbody
  fname' <- newVName $ "lifted_" ++ baseString fname
  lifted_call <- liftFunction fname' tparams params ret funbody'
  replacing fname lifted_call $ transformExp body
transformExp (Lambda params body _ (Info (_, ret)) _) = do
  body' <- transformExp body
  fname <- newVName "lifted_lambda"
  liftFunction fname [] params ret body'
transformExp e@(Var v _ _) =
  -- Replace references to lifted local functions; leave other variables as is.
  asks $ fromMaybe e . M.lookup (qualLeaf v) . envReplace
transformExp e = astMap m e
  where
    m :: ASTMapper LiftM
    m = identityMapper {mapOnExp = transformExp}
-- | Lambda-lift the body of a top-level binding, then emit the binding.
--
-- Any functions lifted from the body are emitted first, so they precede
-- the binding in the output.
transformValBind :: ValBind -> LiftM ()
transformValBind vb = do
  e <- transformExp $ valBindBody vb
  addValBind $ vb {valBindBody = e}
{-# NOINLINE transformProg #-}

-- | Lambda-lift a program.
--
-- Every top-level binding is transformed in order.  The result contains
-- each original binding, with its body transformed, preceded by the
-- functions lifted out of it.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg vbinds =
  modifyNameSource $ \namesrc ->
    runLiftM namesrc $ mapM_ transformValBind vbinds