{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Map (vjpMap) where
import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)
-- | Classification of a free variable according to the type of its
-- current adjoint, as assigned by 'classifyAdjVars'.  The 'VName'
-- stored in every constructor is the primal variable, not its adjoint.
data AdjVar
  = -- | The adjoint has an accumulator ('Acc') type.
    FreeAcc VName
  | -- | The adjoint is an array with this shape and element type.
    FreeArr VName Shape PrimType
  | -- | The adjoint is neither an array nor an accumulator.
    FreeNonAcc VName
-- | Classify each variable by the type of its current adjoint value
-- (obtained with 'lookupAdjVal'):
--
-- * an array adjoint yields 'FreeArr' with the adjoint's shape and
--   element type;
-- * an accumulator adjoint yields 'FreeAcc';
-- * anything else yields 'FreeNonAcc'.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = mapM classify
  where
    classify :: VName -> ADM AdjVar
    classify v = do
      v_adj_t <- lookupType =<< lookupAdjVal v
      pure $ case v_adj_t of
        Array pt shape _ -> FreeArr v shape pt
        Acc {} -> FreeAcc v
        _ -> FreeNonAcc v
-- | Split a list of 'AdjVar's by constructor, preserving the relative
-- order inside each group.  The result contains:
--
-- * the 'FreeArr' variables, each paired with its shape and element type;
-- * the 'FreeAcc' variables;
-- * the 'FreeNonAcc' variables.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars = foldr step ([], [], [])
  where
    -- The irrefutable pattern keeps the result as lazy in the tail as
    -- a plain recursive definition with a @where@-bound tuple.
    step fv ~(xs, ys, zs) =
      case fv of
        FreeArr v shape t -> ((v, (shape, t)) : xs, ys, zs)
        FreeAcc v -> (xs, v : ys, zs)
        FreeNonAcc v -> (xs, ys, v : zs)
-- | Like 'buildBody', but the resulting body is additionally passed
-- through 'renameBody' before being returned.  The auxiliary value
-- produced by the action is returned unchanged.
--
-- NOTE(review): presumably the renaming is there so that bodies built
-- this way never share bound names with each other; confirm against
-- 'renameBody'.
buildRenamedBody ::
  (MonadBuilder m) =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m = do
  (body, x) <- buildBody m
  body' <- renameBody body
  pure (body', x)
-- | Run an action inside a 'WithAcc' construct.
--
-- Each input is a triple of an accumulator shape, the arrays it covers,
-- and an optional operator (passed through to 'WithAcc' untouched).
--
-- For every input, two parameters are created:
--
-- * a certificate parameter of type @Prim Unit@;
-- * an accumulator parameter whose element types are the array types
--   with the accumulator shape's dimensions stripped off.
--
-- The action receives the names of the accumulator parameters, and its
-- result becomes the body of the 'WithAcc' lambda.  The names bound to
-- the 'WithAcc' results are returned.
--
-- With no inputs no 'WithAcc' is emitted.  Instead, the action is run
-- with an empty list and each of its results is bound to a fresh
-- variable.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m =
  mapM (letExp "withacc_res" . BasicOp . SubExp . resSubExp) =<< m []
withAcc inputs m = do
  (cert_params, acc_params) <- unzip <$> mapM mkAccParams inputs
  acc_lam <-
    subAD . mkLambda (cert_params ++ acc_params) . m $
      map paramName acc_params
  letTupExp "withhacc_res" $ WithAcc inputs acc_lam
  where
    -- Certificate and accumulator parameter for a single input.
    mkAccParams ::
      (Shape, [VName], Maybe (Lambda SOACS, [SubExp])) ->
      ADM (Param Type, Param Type)
    mkAccParams (shape, arrs, _) = do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) arrs
      acc_param <-
        newParam "acc_p" $ Acc (paramName cert_param) shape ts NoUniqueness
      pure (cert_param, acc_param)
-- | Return-sweep differentiation of a @map@ with lambda @map_lam@ over
-- the arrays @as@ of outer size @w@, given the adjoints @res_adjs@ of
-- the map's results.  The generated code is emitted via
-- 'returnSweepCode'.
--
-- The equation below handles only the case where every result adjoint
-- is 'AdjSparse' with outer shape exactly @[w]@.  All other cases fall
-- through to the following equation.  Since only the listed positions
-- carry adjoints, no @map@ is generated.  Instead, the lambda is
-- differentiated once per listed (index, value) entry, and the
-- adjoints of @as@ and of the active free variables are updated
-- directly.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap ops res_adjs _ w map_lam as
  | Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do
      -- Active free variables of the lambda, excluding the mapped arrays.
      free <-
        filterM isActive . namesToList $
          freeIn map_lam `namesSubtract` namesFromList as
      free_ts <- mapM lookupType free
      let lam_params = lambdaParams map_lam
          adjs_for = map paramName lam_params ++ free
          adjs_ts = map paramType lam_params ++ free_ts

          -- Result adjoints in which result number res_i has adjoint
          -- adj_v and every other result has a zero adjoint.
          oneHot :: Int -> Adj -> [Adj]
          oneHot res_i adj_v = zipWith pick [0 :: Int ..] $ lambdaReturnType map_lam
            where
              pick j t
                | res_i == j = adj_v
                | otherwise = AdjZero (arrayShape t) (elemType t)

          -- Branch taken when adj_i is out of bounds.  Blank scratch
          -- values are passed to updateAdjIndex tagged 'OutOfBounds'
          -- (presumably ignored there -- confirm in updateAdjIndex).
          ooBounds :: SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
          ooBounds adj_i = subAD . buildRenamedBody $ do
            forM_ (zip as adjs_ts) $ \(a, t) -> do
              scratch <- letSubExp "oo_scratch" =<< eBlank t
              updateAdjIndex a (OutOfBounds, adj_i) scratch
            -- Manifest the free variables' adjoints with insAdj, as
            -- inBounds does.  The 'CheckBounds' case reuses this branch's
            -- mkadjs for both branches of the generated if, so the
            -- representations must agree.
            forM_ free $ \v -> insAdj v =<< adjVal =<< lookupAdj v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)

          -- Branch taken when adj_i is in bounds.  The steps are:
          --   1. bind the lambda parameters to element adj_i of as;
          --   2. inline the VJP of the lambda under a one-hot result
          --      adjoint;
          --   3. update the adjoints of as at adj_i;
          --   4. install the new adjoints of the free variables.
          inBounds :: Int -> SubExp -> SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
          inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
            forM_ (zip lam_params as) $ \(p, a) -> do
              a_t <- lookupType a
              letBindNames [paramName p] . BasicOp . Index a $
                fullSlice a_t [DimFix adj_i]
            adj_elems <-
              fmap (map resSubExp) . bodyBind . lambdaBody
                =<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
            let (as_adj_elems, free_adj_elems) = splitAt (length as) adj_elems
            forM_ (zip as as_adj_elems) $ \(a, a_adj_elem) ->
              updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem
            forM_ (zip free free_adj_elems) $ \(v, adj_se) -> do
              adj_se_v <- letExp "adj_v" $ BasicOp $ SubExp adj_se
              insAdj v adj_se_v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)

          -- Handle one (bounds check, index, value) entry of the sparse
          -- adjoint of result number res_i.
          forPos :: Int -> (InBounds, SubExp, SubExp) -> ADM ()
          forPos res_i (check, adj_i, adj_v) = do
            adjs <-
              case check of
                CheckBounds b -> do
                  (obbranch, mkadjs) <- ooBounds adj_i
                  (ibbranch, _) <- inBounds res_i adj_i adj_v
                  -- The condition is the given SubExp if present, and
                  -- otherwise an eDimInBounds check of adj_i against w.
                  fmap mkadjs . letTupExp' "map_adj_elem"
                    =<< eIf
                      (maybe (eDimInBounds (eSubExp w) (eSubExp adj_i)) eSubExp b)
                      (pure ibbranch)
                      (pure obbranch)
                AssumeBounds -> do
                  (body, mkadjs) <- inBounds res_i adj_i adj_v
                  mkadjs . map resSubExp <$> bodyBind body
                OutOfBounds ->
                  -- Only the adjoints of as are looked up here.  The
                  -- zipWithM_ below stops at the shorter list, so the
                  -- adjoints of the free variables are not re-set.
                  mapM lookupAdj as
            zipWithM_ setAdj (as <> free) adjs

          -- Handle every listed entry of result number res_i.
          forRes :: Int -> [(InBounds, SubExp, SubExp)] -> ADM ()
          forRes res_i = mapM_ (forPos res_i)

      zipWithM_ forRes [0 ..] res_ivs
  where
    -- The (bounds check, index, value) entries of a sparse adjoint
    -- whose shape is exactly [w]; Nothing for any other adjoint.
    isSparse :: Adj -> Maybe [(InBounds, SubExp, SubExp)]
    isSparse (AdjSparse (Sparse shape _ ivs)) = do
      guard $ shapeDims shape == [w]
      Just ivs
    isSparse _ = Nothing
vjpMap VjpOps
ops [Adj]
pat_adj StmAux ()
aux SubExp
w Lambda SOACS
map_lam [VName]
as = ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
[VName]
pat_adj_vals <- [(Adj, Type)] -> ((Adj, Type) -> ADM VName) -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Adj] -> [Type] -> [(Adj, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Adj]
pat_adj (Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
map_lam)) (((Adj, Type) -> ADM VName) -> ADM [VName])
-> ((Adj, Type) -> ADM VName) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ \(Adj
adj, Type
t) ->
case Type
t of
Acc {} -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"acc_adj_rep" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> ADM VName) -> ADM VName -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Adj -> ADM VName
adjVal Adj
adj
Type
_ -> Adj -> ADM VName
adjVal Adj
adj
[Param Type]
pat_adj_params <-
(VName -> ADM (Param Type)) -> [VName] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"map_adj_p" (Type -> ADM (Param Type))
-> (Type -> Type) -> Type -> ADM (Param Type)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> ADM (Param Type))
-> (VName -> ADM Type) -> VName -> ADM (Param Type)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType) [VName]
pat_adj_vals
Lambda SOACS
map_lam' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
map_lam
[VName]
free <- (VName -> ADM Bool) -> [VName] -> ADM [VName]
forall (m :: * -> *) a.
Applicative m =>
(a -> m Bool) -> [a] -> m [a]
filterM VName -> ADM Bool
isActive ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Names -> [VName]
namesToList (Names -> [VName]) -> Names -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Lambda SOACS
map_lam'
[VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
accAdjoints [VName]
free (([VName] -> Names -> ADM ()) -> ADM ())
-> ([VName] -> Names -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \[VName]
free_with_adjs Names
free_without_adjs -> do
[VName]
free_adjs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
free_with_adjs
[Type]
free_adjs_ts <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
free_adjs
[Param Type]
free_adjs_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"free_adj_p") [Type]
free_adjs_ts
let lam_rev_params :: [Param Type]
lam_rev_params =
Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam' [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
pat_adj_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
free_adjs_params
adjs_for :: [VName]
adjs_for = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam') [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
free
Lambda SOACS
lam_rev <-
[LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
lam_rev_params (ADM Result -> ADM (Lambda SOACS))
-> (ADM Result -> ADM Result) -> ADM Result -> ADM (Lambda SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ADM Result -> ADM Result
forall a. ADM a -> ADM a
subAD (ADM Result -> ADM Result)
-> (ADM Result -> ADM Result) -> ADM Result -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> ADM Result -> ADM Result
forall a. Names -> ADM a -> ADM a
noAdjsFor Names
free_without_adjs (ADM Result -> ADM (Lambda SOACS))
-> ADM Result -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ do
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
free_with_adjs ([VName] -> ADM ()) -> [VName] -> ADM ()
forall a b. (a -> b) -> a -> b
$ (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
free_adjs_params
Body (Rep ADM) -> ADM Result
Body SOACS -> ADM Result
forall (m :: * -> *). MonadBuilder m => Body (Rep m) -> m Result
bodyBind (Body SOACS -> ADM Result)
-> (Lambda SOACS -> Body SOACS) -> Lambda SOACS -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody
(Lambda SOACS -> ADM Result) -> ADM (Lambda SOACS) -> ADM Result
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((Param Type -> Adj) -> [Param Type] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Adj
forall t. Param t -> Adj
adjFromParam [Param Type]
pat_adj_params) [VName]
adjs_for Lambda SOACS
map_lam'
([VName]
param_contribs, [VName]
free_contribs) <-
([VName] -> ([VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName])
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Param Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam'))) (ADM [VName] -> ADM ([VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName])
forall a b. (a -> b) -> a -> b
$
StmAux () -> ADM [VName] -> ADM [VName]
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM [VName] -> ADM [VName])
-> (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_adjs" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ([VName]
as [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
pat_adj_vals [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
free_adjs) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_rev)
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
freeContrib [VName]
free [VName]
free_contribs
let param_ts :: [Type]
param_ts = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
map_lam')
[(Type, VName, VName)]
-> ((Type, VName, VName) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [VName] -> [(Type, VName, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
param_ts [VName]
as [VName]
param_contribs) (((Type, VName, VName) -> ADM ()) -> ADM ())
-> ((Type, VName, VName) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(Type
param_t, VName
a, VName
param_contrib) ->
case Type
param_t of
Acc {} -> VName -> VName -> ADM ()
freeContrib VName
a VName
param_contrib
Type
_ -> VName -> VName -> ADM ()
updateAdj VName
a VName
param_contrib
where
addIdxParams :: Int -> Lambda rep -> m (Lambda rep)
addIdxParams Int
n Lambda rep
lam = do
[Param (TypeBase shape u)]
idxs <- Int -> m (Param (TypeBase shape u)) -> m [Param (TypeBase shape u)]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
n (m (Param (TypeBase shape u)) -> m [Param (TypeBase shape u)])
-> m (Param (TypeBase shape u)) -> m [Param (TypeBase shape u)]
forall a b. (a -> b) -> a -> b
$ String -> TypeBase shape u -> m (Param (TypeBase shape u))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"idx" (TypeBase shape u -> m (Param (TypeBase shape u)))
-> TypeBase shape u -> m (Param (TypeBase shape u))
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda rep -> m (Lambda rep)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda rep -> m (Lambda rep)) -> Lambda rep -> m (Lambda rep)
forall a b. (a -> b) -> a -> b
$ Lambda rep
lam {lambdaParams :: [LParam rep]
lambdaParams = [Param (TypeBase shape u)]
idxs [Param (TypeBase shape u)]
-> [Param (TypeBase shape u)] -> [Param (TypeBase shape u)]
forall a. [a] -> [a] -> [a]
++ Lambda rep -> [LParam rep]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam}
accAddLambda :: Int -> Type -> ADM (Lambda SOACS)
accAddLambda Int
n Type
t = Int -> Lambda SOACS -> ADM (Lambda SOACS)
forall {rep} {shape} {u} {m :: * -> *}.
(LParamInfo rep ~ TypeBase shape u, MonadFreshNames m) =>
Int -> Lambda rep -> m (Lambda rep)
addIdxParams Int
n (Lambda SOACS -> ADM (Lambda SOACS))
-> ADM (Lambda SOACS) -> ADM (Lambda SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Type -> ADM (Lambda SOACS)
addLambda Type
t
withAccInput :: (VName, (a, PrimType))
-> ADM (a, [VName], Maybe (Lambda SOACS, [SubExp]))
withAccInput (VName
v, (a
shape, PrimType
pt)) = do
VName
v_adj <- VName -> ADM VName
lookupAdjVal VName
v
Lambda SOACS
add_lam <- Int -> Type -> ADM (Lambda SOACS)
accAddLambda (a -> Int
forall a. ArrayShape a => a -> Int
shapeRank a
shape) (Type -> ADM (Lambda SOACS)) -> Type -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
pt
SubExp
zero <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zero" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp (Rep ADM)
forall rep. Type -> Exp rep
zeroExp (Type -> Exp (Rep ADM)) -> Type -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
pt
(a, [VName], Maybe (Lambda SOACS, [SubExp]))
-> ADM (a, [VName], Maybe (Lambda SOACS, [SubExp]))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (a
shape, [VName
v_adj], (Lambda SOACS, [SubExp]) -> Maybe (Lambda SOACS, [SubExp])
forall a. a -> Maybe a
Just (Lambda SOACS
add_lam, [SubExp
zero]))
accAdjoints :: [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
accAdjoints [VName]
free [VName] -> Names -> ADM ()
m = do
([(VName, (Shape, PrimType))]
arr_free, [VName]
acc_free, [VName]
nonacc_free) <-
[AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars ([AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName]))
-> ADM [AdjVar]
-> ADM ([(VName, (Shape, PrimType))], [VName], [VName])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> ADM [AdjVar]
classifyAdjVars [VName]
free
[(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
arr_free' <- ((VName, (Shape, PrimType))
-> ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp])))
-> [(VName, (Shape, PrimType))]
-> ADM [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (VName, (Shape, PrimType))
-> ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp]))
forall {a}.
ArrayShape a =>
(VName, (a, PrimType))
-> ADM (a, [VName], Maybe (Lambda SOACS, [SubExp]))
withAccInput [(VName, (Shape, PrimType))]
arr_free
let as_nonfree :: [VName]
as_nonfree = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
free) [VName]
as
([VName]
arr_adjs, [VName]
acc_adjs, [VName]
rest_adjs) <-
([VName] -> ([VName], [VName], [VName]))
-> ADM [VName] -> ADM ([VName], [VName], [VName])
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([(VName, (Shape, PrimType))] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(VName, (Shape, PrimType))]
arr_free) ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
acc_free)) (ADM [VName] -> ADM ([VName], [VName], [VName]))
-> (([VName] -> ADM Result) -> ADM [VName])
-> ([VName] -> ADM Result)
-> ADM ([VName], [VName], [VName])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
-> ([VName] -> ADM Result) -> ADM [VName]
withAcc [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))]
arr_free' (([VName] -> ADM Result) -> ADM ([VName], [VName], [VName]))
-> ([VName] -> ADM Result) -> ADM ([VName], [VName], [VName])
forall a b. (a -> b) -> a -> b
$ \[VName]
accs -> do
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj (((VName, (Shape, PrimType)) -> VName)
-> [(VName, (Shape, PrimType))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst [(VName, (Shape, PrimType))]
arr_free) [VName]
accs
() <- [VName] -> Names -> ADM ()
m ([VName]
acc_free [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ ((VName, (Shape, PrimType)) -> VName)
-> [(VName, (Shape, PrimType))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst [(VName, (Shape, PrimType))]
arr_free) ([VName] -> Names
namesFromList [VName]
nonacc_free)
[VName]
acc_free_adj <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
acc_free
[VName]
arr_free_adj <- ((VName, (Shape, PrimType)) -> ADM VName)
-> [(VName, (Shape, PrimType))] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (VName -> ADM VName
lookupAdjVal (VName -> ADM VName)
-> ((VName, (Shape, PrimType)) -> VName)
-> (VName, (Shape, PrimType))
-> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst) [(VName, (Shape, PrimType))]
arr_free
[VName]
nonacc_free_adj <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
nonacc_free
[VName]
as_nonfree_adj <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> ADM VName
lookupAdjVal [VName]
as_nonfree
Result -> ADM Result
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> ADM Result) -> Result -> ADM Result
forall a b. (a -> b) -> a -> b
$ [VName] -> Result
varsRes ([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ [VName]
arr_free_adj [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
acc_free_adj [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
nonacc_free_adj [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
as_nonfree_adj
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
acc_free [VName]
acc_adjs
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj (((VName, (Shape, PrimType)) -> VName)
-> [(VName, (Shape, PrimType))] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, (Shape, PrimType)) -> VName
forall a b. (a, b) -> a
fst [(VName, (Shape, PrimType))]
arr_free) [VName]
arr_adjs
let ([VName]
nonacc_adjs, [VName]
as_nonfree_adjs) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
nonacc_free) [VName]
rest_adjs
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
nonacc_free [VName]
nonacc_adjs
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj [VName]
as_nonfree [VName]
as_nonfree_adjs
freeContrib :: VName -> VName -> ADM ()
freeContrib VName
v VName
contribs = do
Type
contribs_t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
contribs
case Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
contribs_t of
Acc {} -> ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> ADM ()
insAdj VName
v VName
contribs
Type
t -> do
Lambda SOACS
lam <- Type -> ADM (Lambda SOACS)
addLambda Type
t
SubExp
zero <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zero" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp SOACS
forall rep. Type -> Exp rep
zeroExp Type
t
ScremaForm SOACS
reduce <- [Reduce SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC [Commutativity -> Lambda SOACS -> [SubExp] -> Reduce SOACS
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
Commutative Lambda SOACS
lam [SubExp
zero]]
VName
contrib_sum <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
v String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_contrib_sum") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
contribs] ScremaForm SOACS
reduce
ADM () -> ADM ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> ADM ()
updateAdj VName
v VName
contrib_sum