{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Hist
( diffMinMaxHist,
diffMulHist,
diffAddHist,
diffHist,
)
where
import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
-- | The addition operator for a primitive type.
--
-- Integer types get 'Add' with 'OverflowUndef'; float types get 'FAdd'.
-- Any other 'PrimType' is rejected with an error.
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus (IntType x) = Add x OverflowUndef
getBinOpPlus (FloatType f) = FAdd f
-- The message previously said "getBinOpMul", which pointed at the wrong
-- function when this failed.
getBinOpPlus _ = error "In getBinOpPlus, Hist.hs: input not supported"
-- | The division operator for a primitive type.
--
-- Integers use signed division ('SDiv') marked 'Unsafe'; floats use 'FDiv'.
-- Any other 'PrimType' is an error.
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv pt =
  case pt of
    IntType it -> SDiv it Unsafe
    FloatType ft -> FDiv ft
    _ -> error "In getBinOpDiv, Hist.hs: input not supported"
-- | Conjunction of range checks @0 <= i < q@ (encoded as @-1 < i && i < q@)
-- for every @(q, i)@ pair. The empty list yields the constant @true@.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = TPrimExp $ ValueExp $ BoolValue True
-- The list is non-empty here, so foldr1 is total. It associates to the
-- right, matching the nesting of the recursive formulation.
withinBounds bounds = foldr1 (.&&.) (map inRange bounds)
  where
    inRange (q, i) =
      (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)
-- | Build a chain of nested conditionals.
--
-- The k-th pair of expressions is compared with @CmpEq t@. The body at
-- position k is taken when they are equal; otherwise the next pair is
-- tried. The final body is the fallback when no pair is equal.
--
-- There must be exactly one more body than condition pairs, and at least
-- one pair. Any other shape fails with an error.
elseIf ::
  PrimType ->
  [(ADM (Exp SOACS), ADM (Exp SOACS))] ->
  [ADM (Body SOACS)] ->
  ADM (Exp SOACS)
elseIf t conds bodies =
  case (conds, bodies) of
    ([(lhs, rhs)], [thenBody, elseBody]) ->
      eIf (equal lhs rhs) thenBody elseBody
    ((lhs, rhs) : moreConds, thenBody : moreBodies) ->
      eIf (equal lhs rhs) thenBody $
        eBody [elseIf t moreConds moreBodies]
    _ -> error "In elseIf, Hist.hs: input not supported"
  where
    equal a b = eCmpOp (CmpEq t) a b
-- | Bind each result to a fresh variable whose base name is the given
-- string. Each binding is made under that result's certificates. The new
-- names are returned in input order.
bindSubExpRes :: String -> [SubExpRes] -> ADM [VName]
bindSubExpRes s = mapM bindOne
  where
    bindOne (SubExpRes cs se) = do
      name <- newVName s
      certifying cs . letBindNames [name] . BasicOp $ SubExp se
      pure name
-- | Wrap a lambda in one @map@ per dimension in the first argument.
--
-- At each level, the new lambda takes one array parameter per element type
-- in the second argument. Those arrays have the shape of the remaining
-- dimensions. The lambda maps the next-inner lambda over the outermost of
-- those dimensions. An empty dimension list returns the lambda unchanged.
--
-- NOTE(review): presumably the element types must match the innermost
-- lambda's parameters one-to-one — confirm at call sites.
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap s pt lam =
  case s of
    [] -> pure lam
    outer : inner -> do
      arrParams <-
        mapM (\tp -> newParam "x" (Array tp (Shape s) NoUniqueness)) pt
      innerLam <- nestedmap inner pt lam
      mkLambda arrParams $
        varsRes
          <$> letTupExp
            "res"
            (Op (Screma outer (map paramName arrParams) (mapSOAC innerLam)))
-- | Lift a lambda to a @map@ of width @n@.
--
-- Two groups of fresh parameters are created, @ds_param@ and @hs_param@,
-- one per type in @tps@. The result is a lambda taking both groups that
-- maps a renamed copy of @lam@ over them. Returns the two groups' names
-- and the new lambda.
mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' lam tps n = do
  fresh <- renameLambda lam
  dsParams <- mapM (newParam "ds_param") tps
  hsParams <- mapM (newParam "hs_param") tps
  let dsNames = map paramName dsParams
      hsNames = map paramName hsParams
  mapLam <- mkLambda (dsParams <> hsParams) $ do
    out <-
      letTupExp "map_f'" $
        Op $
          Screma n (dsNames <> hsNames) (mapSOAC fresh)
    pure (varsRes out)
  pure (dsNames, hsNames, mapLam)
-- | Compose a lambda with itself, then lift the composition to a @map@ of
-- width @n@.
--
-- Let @q@ be the number of results of @lam@. The composed lambda takes
-- three parameter groups @(l, a, r)@:
--
--   * the first copy of @lam@ is applied to @(l, a)@;
--   * its results are bound to the first @q@ parameters of the second copy;
--   * the second copy's remaining parameters are @r@, and its results are
--     returned.
--
-- Fresh @ls_param@/@as_param@/@rs_param@ groups (one per type in @tps@)
-- become the parameters of the outer lambda, which maps the composition
-- over them. Returns the names of the @as_param@ group and that outer
-- lambda.
--
-- NOTE(review): assumes @lam@ takes 2*q parameters, as a binary operator
-- would — confirm at call sites.
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF lam tps n = do
  -- Two independently renamed copies, so their bindings do not clash.
  left <- renameLambda lam
  right <- renameLambda lam
  let arity = length (lambdaReturnType lam)
      (leftAccs, leftArgs) = splitAt arity (lambdaParams left)
      (rightAccs, rightArgs) = splitAt arity (lambdaParams right)
  composed <- mkLambda (leftAccs <> leftArgs <> rightArgs) $ do
    leftRes <- bodyBind (lambdaBody left)
    -- Feed the first copy's results into the second copy's leading params.
    zipWithM_ bindParam rightAccs leftRes
    bodyBind (lambdaBody right)

  lsParams <- mapM (newParam "ls_param") tps
  asParams <- mapM (newParam "as_param") tps
  rsParams <- mapM (newParam "rs_param") tps
  let mapParams = lsParams <> asParams <> rsParams
  mapLam <- mkLambda mapParams $ do
    out <-
      letTupExp "map_f" $
        Op $
          Screma n (map paramName mapParams) (mapSOAC composed)
    pure (varsRes out)

  pure (map paramName asParams, mapLam)
  where
    bindParam p (SubExpRes cs se) =
      certifying cs $ letBindNames [paramName p] $ BasicOp $ SubExp se
-- | Map over the first @n@ elements of the index array @is@.
--
-- An index @i@ is kept when @0 <= i < w@ and replaced by @w@ otherwise.
-- Returns the name of the resulting index array.
--
-- NOTE(review): presumably @w@ is used as an "ignore" index by a later
-- scatter/histogram of width @w@ — confirm at call sites.
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout is n w = do
  idx <- newParam "is" (Prim int64)
  boundedLam <- mkLambda [idx] $ do
    branch <-
      eIf
        (toExp (withinBounds [(w, paramName idx)]))
        (eBody [eParam idx])
        (eBody [eSubExp w])
    varsRes <$> letTupExp "is'" branch
  letExp "is'" . Op . Screma n [is] $ mapSOAC boundedLam
-- | Scatter several value arrays in one @scatter@ of width @n@.
--
-- The k-th array in @vs@ is written into the k-th array in @dst@, using
-- the single shared index array @is@. Each output is one-dimensional in
-- its indexing and has one value per index. Returns the names bound to
-- the scattered results.
--
-- NOTE(review): each destination's shape is taken from the outer size of
-- the corresponding value array, not from the @dst@ array itself — confirm
-- the two always agree at call sites.
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter n dst is vs = do
  valTypes <- mapM lookupType vs
  idx <- newParam "i" (Prim int64)
  valParams <- mapM (newParam "scatter_param" . rowType) valTypes
  scatterLam <- mkLambda (idx : valParams) $ do
    -- Scatter expects all index results first, then all value results.
    idxExps <- replicateM (length valParams) (eParam idx)
    valExps <- mapM eParam valParams
    subExpsRes <$> mapM (letSubExp "scatter_map_res") (idxExps <> valExps)
  let outputs =
        [ (Shape [arraySize 0 tp], 1, d)
        | (tp, d) <- zip valTypes dst
        ]
  letTupExp "scatter_res" $ Op $ Scatter n (is : vs) scatterLam outputs
-- | Index every array in @vs@ with the same slice prefix.
--
-- The prefix is completed to a full slice for each array's own type.
-- Each result is bound under the base name @"sorted"@, and the new names
-- are returned in input order.
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex vs s = mapM indexOne vs
  where
    indexOne arr = do
      arrType <- lookupType arr
      letExp "sorted" $ BasicOp $ Index arr $ fullSlice arrType s
diffMinMaxHist ::
VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMinMaxHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
minmax SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
minmax
Type
vs_type <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
let vs_elm_type :: PrimType
vs_elm_type = forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
let vs_dims :: [SubExp]
vs_dims = forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
let inner_dims :: [SubExp]
inner_dims = forall a. [a] -> [a]
tail [SubExp]
vs_dims
let nr_dims :: Int
nr_dims = forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims
Type
dst_type <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
let dst_dims :: [SubExp]
dst_dims = forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type
VName
dst_cpy <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst forall a. Semigroup a => a -> a -> a
<> String
"_copy") forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
dst
Param Type
acc_v_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_v" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Param Type
acc_i_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Param Type
v_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Param Type
i_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
hist_lam_inner <-
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
acc_v_p, Param Type
acc_i_p, Param Type
v_p, Param Type
i_p] forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"idx_res"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
[ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p,
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
SMin IntType
Int64) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p)
]
)
( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
[ forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp
(PrimType -> CmpOp
CmpEq PrimType
t)
(forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p)
(forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
minmax (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p, forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p])
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p, forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p])
]
)
Lambda SOACS
hist_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
int64, PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
hist_lam_inner
VName
dst_minus_ones <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"minus_ones" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
SubExp
ne_minus_ones <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"minus_ones" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
VName
iota_n <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
VName
inp_iota <- do
if Int
nr_dims forall a. Eq a => a -> a -> Bool
== Int
1
then forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
iota_n
else do
Param Type
i <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
lam <-
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
i] forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
i
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"res" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam
let hist_op :: HistOp SOACS
hist_op = forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
dst_cpy, VName
dst_minus_ones] [SubExp
ne, if Int
nr_dims forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1) else SubExp
ne_minus_ones] Lambda SOACS
hist_lam
Lambda SOACS
f' <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, forall u. TypeBase Shape u -> TypeBase Shape u
rowType forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 (forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
VName
x_inds <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_inds")
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x, VName
x_inds] forall a b. (a -> b) -> a -> b
$
forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
vs, VName
inp_iota] [HistOp SOACS
hist_op] Lambda SOACS
f'
ADM ()
m
VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x
Param Type
x_ind_dst <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Param Type
x_bar_dst <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Lambda SOACS
dst_lam_inner <-
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
x_ind_dst, Param Type
x_bar_dst] forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"dst_bar"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 (forall dec. Param dec -> VName
paramName Param Type
x_ind_dst) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
x_bar_dst)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
Lambda SOACS
dst_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
vs_elm_type] Lambda SOACS
dst_lam_inner
VName
dst_bar <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst forall a. Semigroup a => a -> a -> a
<> String
"_bar") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
x_inds, VName
x_bar] (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
dst_lam)
VName -> VName -> ADM ()
updateAdj VName
dst VName
dst_bar
VName
vs_bar <- VName -> ADM VName
lookupAdjVal VName
vs
[VName]
inds' <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"inds" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
inner_dims []
let inds :: [VName]
inds = VName
x_inds forall a. a -> [a] -> [a]
: [VName]
inds'
[Param Type]
par_x_ind_vs <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
nr_dims forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Param Type
par_x_bar_vs <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Lambda SOACS
vs_lam_inner <-
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
par_x_bar_vs forall a. a -> [a] -> [a]
: [Param Type]
par_x_ind_vs) forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 (forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Param Type]
par_x_ind_vs) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ do
SubExp
vs_bar_i <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
vs_bar forall a. Semigroup a => a -> a -> a
<> String
"_el") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
vs_bar forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> VName
paramName) [Param Type]
par_x_ind_vs
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpPlus PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_x_bar_vs) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
vs_bar_i)
)
Lambda SOACS
vs_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims (PrimType
vs_elm_type forall a. a -> [a] -> [a]
: forall a. Int -> a -> [a]
replicate Int
nr_dims PrimType
int64) Lambda SOACS
vs_lam_inner
VName
vs_bar_p <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs forall a. Semigroup a => a -> a -> a
<> String
"_partial") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w (VName
x_bar forall a. a -> [a] -> [a]
: [VName]
inds) (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
vs_lam)
SubExp
q <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"q"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dst_dims
[VName]
scatter_inps <- do
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flat" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary (forall d. [d] -> ShapeBase d
Shape [SubExp
q])) forall a b. (a -> b) -> a -> b
$
[VName]
inds forall a. [a] -> [a] -> [a]
++ [VName
vs_bar_p]
Lambda SOACS
f'' <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda forall a b. (a -> b) -> a -> b
$ forall a. Int -> a -> [a]
replicate Int
nr_dims (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) forall a. [a] -> [a] -> [a]
++ [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t]
VName
vs_bar' <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs forall a. Semigroup a => a -> a -> a
<> String
"_bar") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
q [VName]
scatter_inps Lambda SOACS
f'' [(forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims, Int
1, VName
vs_bar)]
VName -> VName -> ADM ()
insAdj VName
vs VName
vs_bar'
where
mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [] [SubExp]
_ = forall (f :: * -> *) a. Applicative f => a -> f a
pure []
mk_indices [SubExp
d] [SubExp]
iotas = do
[VName]
reps <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"rep" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp
d])) [SubExp]
iotas
VName
iota_d <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [VName]
reps forall a. [a] -> [a] -> [a]
++ [VName
iota_d]
mk_indices (SubExp
d : [SubExp]
dims) [SubExp]
iotas = do
VName
iota_d <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
Param Type
i_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
lam <-
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
i_param] forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall a b. (a -> b) -> a -> b
$
[SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
dims forall a b. (a -> b) -> a -> b
$
[SubExp]
iotas forall a. [a] -> [a] -> [a]
++ [VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
i_param]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
d [VName
iota_d] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam
-- | Reverse-mode rule for a histogram whose operator is the multiplication
-- 'BinOp' @mul@.
--
-- Forward part: every element of @vs@ is split into a pair @(p, z)@.
-- An element equal to @blankPrimValue t@ becomes @(onePrimValue t, 1)@.
-- Any other element @v@ becomes @(v, 0)@.
-- A single 'Hist' over @is@ then computes, per bin:
--
-- * the product of the non-zero contributions (@nz_prods@), and
-- * the number of zero contributions (@zr_counts@).
--
-- The bin value @h_part@ is that product when the zero count is 0, and zero
-- otherwise. @x@ is bound to @dst@ multiplied element-wise by @h_part@.
--
-- After running @m@, adjoints are propagated:
--
-- * @dst@ is updated with @h_part * x_bar@.
--
-- * Each element @a@ of @vs@ whose index lies in @[0, w)@ receives
--   @(dst * x_bar) * (nz_prod / a)@ if its bin had no zeros, or
--   @nz_prod * (dst * x_bar)@ if the bin had exactly one zero and @a@ is
--   zero. Otherwise it receives zero.
--
-- * Elements whose index is out of bounds receive zero.
diffMulHist ::
  -- | VJP operations (unused by this rule).
  VjpOps ->
  -- | Name to bind the forward histogram result to.
  VName ->
  -- | Attached to the generated 'Hist' and to the binding of @x@.
  StmAux () ->
  -- | Number of input elements (the width of the 'Hist' over @is@/@vs@).
  SubExp ->
  -- | The multiplication operator; its type is the element type @t@.
  BinOp ->
  -- | Neutral element used to initialise and combine the product bins.
  SubExp ->
  -- | Bin indices.
  VName ->
  -- | Values multiplied into the bins.
  VName ->
  -- | Number of bins.
  SubExp ->
  -- | Race factor passed on to each 'HistOp'.
  SubExp ->
  -- | The initial histogram.
  VName ->
  -- | Action run after @x@ is bound and before its adjoint is read
  -- (presumably the rest of the reverse sweep; verify against callers).
  ADM () ->
  ADM ()
diffMulHist _ops x aux n mul ne is vs w rf dst m = do
  let t = binOpType mul
      zeroT = Constant $ blankPrimValue t
      oneT = Constant $ onePrimValue t
  vs_type <- lookupType vs
  dst_type <- lookupType dst
  let vs_dims = arrayDims vs_type
      inner_dims = tail vs_dims
      elem_t = elemType vs_type
      dst_dims = arrayDims dst_type

  -- Split each element into (non-zero factor, zero flag).
  elem_p <- newParam "v" $ Prim t
  split_inner <-
    mkLambda [elem_p] $ do
      res <-
        letTupExp "map_res"
          =<< eIf
            (eCmpOp (CmpEq t) (eParam elem_p) (eSubExp zeroT))
            (eBody [eSubExp oneT, eSubExp $ intConst Int64 1])
            (eBody [eParam elem_p, eSubExp $ intConst Int64 0])
      pure $ varsRes res
  split_lam <- nestedmap vs_dims [elem_t] split_inner
  split_names <- applyLam "ps_zs" split_lam [eSubExp $ Var vs]
  let [ps, zs] = split_names

  -- One Hist computing per-bin non-zero products and zero counts.
  mul_inner <- binOpLambda mul t
  mul_lam <- nestedmap inner_dims [elem_t, elem_t] mul_inner
  prods_init <- letExp "nz_prd" $ BasicOp $ Replicate (Shape [w]) ne
  add_inner <- binOpLambda (Add Int64 OverflowUndef) int64
  add_lam <- nestedmap inner_dims [int64, int64] add_inner
  counts_init <-
    letExp "zr_cts" $ BasicOp $ Replicate (Shape dst_dims) (intConst Int64 0)
  counts_ne_arr <-
    letSubExp "zr_ne" $ BasicOp $ Replicate (Shape inner_dims) (intConst Int64 0)
  let counts_ne
        | length vs_dims == 1 = intConst Int64 0
        | otherwise = counts_ne_arr
      prod_op = HistOp (Shape [w]) rf [prods_init] [ne] mul_lam
      count_op = HistOp (Shape [w]) rf [counts_init] [counts_ne] add_lam
  hist_f <-
    mkIdentityLambda
      [ Prim int64,
        Prim int64,
        rowType vs_type,
        rowType $ Array int64 (Shape vs_dims) NoUniqueness
      ]
  nz_prods <- newVName "non_zero_prod"
  zr_counts <- newVName "zero_count"
  auxing aux . letBindNames [nz_prods, zr_counts] . Op $
    Hist n [is, is, ps, zs] [prod_op, count_op] hist_f

  -- Bin value: the non-zero product if the bin saw no zeros, else zero.
  prod_p <- newParam "prod" $ Prim t
  count_p <- newParam "count" $ Prim int64
  bin_inner <-
    mkLambda [prod_p, count_p] $ do
      res <-
        letTupExp "h_part"
          =<< eIf
            (toExp $ 0 .==. le64 (paramName count_p))
            (eBody [eParam prod_p])
            (eBody [eSubExp zeroT])
      pure $ varsRes res
  bin_lam <- nestedmap dst_dims [elem_t, int64] bin_inner
  bin_names <- applyLam "h_part" bin_lam $ map (eSubExp . Var) [nz_prods, zr_counts]
  let [h_part] = bin_names

  -- Forward result: x = dst * h_part, element-wise.
  fwd_inner <- binOpLambda mul t
  fwd_lam <- nestedmap dst_dims [elem_t, elem_t] fwd_inner
  x_names <- applyLam "x" fwd_lam $ map (eSubExp . Var) [dst, h_part]
  auxing aux . letBindNames [x] . BasicOp . SubExp . Var $ head x_names

  m

  -- Adjoint of dst: h_part * x_bar.
  x_bar <- lookupAdjVal x
  dst_bar_lam <- renameLambda fwd_lam
  dst_bar <-
    applyLam (baseString dst <> "_bar") dst_bar_lam $
      map (eSubExp . Var) [h_part, x_bar]
  updateAdj dst $ head dst_bar

  -- Adjoint flowing into each bin's product: dst * x_bar.
  part_bar_lam <- renameLambda fwd_lam
  part_bar_names <-
    applyLam "part_bar" part_bar_lam $ map (eSubExp . Var) [dst, x_bar]
  let [part_bar] = part_bar_names

  -- Per-element adjoint of vs, given the bin's zero count, product
  -- adjoint, non-zero product, and the element itself.
  inner_params <-
    zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $
      map Prim [int64, t, t, t]
  let [zr_cts, pr_bar, nz_prd, a_p] = inner_params
  grad_inner <-
    mkLambda inner_params $ do
      res <-
        letTupExp "vs_bar"
          =<< eIf
            (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts))
            -- No zeros in the bin: scale by (non-zero product / a).
            ( eBody
                [ eBinOp mul (eParam pr_bar) $
                    eBinOp (getBinOpDiv t) (eParam nz_prd) (eParam a_p)
                ]
            )
            ( eBody
                [ eIf
                    ( eBinOp
                        LogAnd
                        (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts))
                        (eCmpOp (CmpEq t) (eSubExp zeroT) (eParam a_p))
                    )
                    -- Exactly one zero and a is zero: scale by the non-zero product.
                    (eBody [eBinOp mul (eParam nz_prd) (eParam pr_bar)])
                    (eBody [eSubExp zeroT])
                ]
            )
      pure $ varsRes res
  grad_middle <- nestedmap inner_dims [int64, t, t, t] grad_inner

  -- Map over (is, vs): gather the bin data for in-bounds indices; else zero.
  i_p <- newParam "i" $ Prim int64
  row_p <- newParam "a" $ rowType vs_type
  grad_lam <-
    mkLambda [i_p, row_p] $ do
      res <-
        letTupExp "vs_bar"
          =<< eIf
            (toExp $ withinBounds [(w, paramName i_p)])
            ( buildBody_ $ do
                let row_slice = fullSlice vs_type [DimFix $ Var $ paramName i_p]
                names <- traverse newVName ["zr_cts", "pr_bar", "nz_prd"]
                forM_ (zip names [zr_counts, part_bar, nz_prods]) $ \(name, arr) ->
                  letBindNames [name] $ BasicOp $ Index arr row_slice
                eLambda grad_middle $ map (eSubExp . Var) names ++ [eParam row_p]
            )
            (eBody [pure $ zeroExp $ rowType dst_type])
      pure $ varsRes res
  vs_bar <-
    letExp (baseString vs <> "_bar") . Op . Screma n [is, vs] $ mapSOAC grad_lam
  updateAdj vs vs_bar
  where
    -- Apply a lambda to the given arguments and bind its results to fresh
    -- names derived from @desc@.
    applyLam :: String -> Lambda SOACS -> [ADM (Exp SOACS)] -> ADM [VName]
    applyLam desc lam args = bindSubExpRes desc =<< eLambda lam args
-- | Reverse-mode rule for a histogram whose combining operator is addition.
--
-- Arguments:
--
-- * @_ops@: unused.
-- * @x@: the result name; @aux@: the statement's aux.
-- * @n@: the input length.
-- * @add@: the operator, with neutral element @ne@.
-- * @is@ and @vs@: the bin indices and the values.
-- * @w@: the histogram width; @rf@: the race factor.
-- * @dst@: the destination array.
-- * @m@: the continuation that differentiates the rest of the program.
--
-- The primal histogram is recomputed into a copy of @dst@ (presumably so that
-- @dst@ itself is not consumed by the in-place 'Hist') and bound to @x@; then
-- @m@ runs. Afterwards, treating @add@ as addition (as this rule's name
-- implies):
--
-- * the adjoint of @x@ is added to the adjoint of @dst@, and
-- * @vs[k]@ receives @x_bar[is[k]]@ when @0 <= is[k] < w@ (see
--   'withinBounds'), and zero otherwise, because values whose index is out of
--   bounds do not contribute to @x@.
diffAddHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> Lambda SOACS -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffAddHist _ops x aux n add ne is vs w rf dst m = do
  let t = paramDec $ head $ lambdaParams add

  dst_cpy <- letExp (baseString dst <> "_copy") $ BasicOp $ Copy dst

  f <- mkIdentityLambda [Prim int64, t]
  auxing aux . letBindNames [x] . Op $
    Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f

  m

  x_bar <- lookupAdjVal x

  -- For addition, each x[j] depends on dst[j] with derivative one.
  updateAdj dst x_bar

  x_type <- lookupType x
  i_param <- newParam (baseString vs <> "_i") $ Prim int64
  let i = paramName i_param
  lam_vsbar <-
    mkLambda [i_param] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, i))
          (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i])
          -- Dropped values get a zero adjoint. Build the zero explicitly
          -- rather than reusing @ne@: the two only coincide when @ne@ is the
          -- additive identity. This matches the 'zeroExp' branch used in
          -- 'diffMulHist'.
          (eBody $ pure $ pure $ zeroExp $ rowType x_type)

  vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar
  updateAdj vs vs_bar
-- | One pass of a least-significant-digit radix sort that consumes two bits
-- of the key.
--
-- @xs@ are arrays of outer size @n@ with types @tps@, and @head xs@ holds the
-- keys. The keys are first passed through 'mapout' together with @w@
-- (NOTE(review): presumably this maps out-of-range keys into a fixed bucket;
-- confirm against its definition).
--
-- The two-bit digit starting at position @bit@ is then extracted from every
-- key. Each element's new position is computed as
--
-- > (running count of its digit, minus one) + (totals of all smaller digits)
--
-- This keeps elements with equal digits in their original relative order.
-- Finally every array in @xs@ is scattered at those positions into a fresh
-- destination via 'multiScatter', whose result is returned.
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep xs tps bit n w = do
  keys <- mapout (head xs) n w

  -- digit(k) = ((k >> bit) & 1) + 2 * ((k >> (bit + 1)) & 1), in [0, 3].
  digitParam <- newParam "num" $ Prim int64
  digitLam <- mkLambda [digitParam] $ do
    let key = eParam digitParam
        bitAt pos = eBinOp (And Int64) (eBinOp (AShr Int64) key pos) (lit 1)
        lowBit = bitAt (eSubExp bit)
        highBit = bitAt (eBinOp (Add Int64 OverflowUndef) (eSubExp bit) (lit 1))
    digit <- eBinOp (Add Int64 OverflowUndef) lowBit (eBinOp (Mul Int64 OverflowUndef) (lit 2) highBit)
    varsRes <$> letTupExp "num_res" digit
  bins <- letExp "bins" $ Op $ Screma n [keys] $ mapSOAC digitLam

  -- One-hot encode each digit as four 0/1 flags.
  flagParam <- newParam "flag" $ Prim int64
  let oneHot :: Integer -> [Integer]
      oneHot d = [if d == j then 1 else 0 | j <- [0 .. 3]]
  flagLam <- mkLambda [flagParam] $ do
    chosen <-
      elseIf
        int64
        [(eParam flagParam, lit d) | d <- [0 .. 2]]
        [eBody $ map lit $ oneHot d | d <- [0 .. 3]]
    varsRes <$> letTupExp "flag_res" chosen
  flags <- letTupExp "flags" $ Op $ Screma n [bins] $ mapSOAC flagLam

  -- Element-wise sum of two 4-tuples of counts. Scanning with it gives, at
  -- every position, how many elements so far carry each digit.
  scanParams <- forM ["a1", "b1", "c1", "d1", "a2", "b2", "c2", "d2"] $ \pname ->
    newParam pname $ Prim int64
  scanLam <- mkLambda scanParams $ do
    let (lhs, rhs) = splitAt 4 $ map eParam scanParams
    sums <- zipWithM (eBinOp $ Add Int64 OverflowUndef) lhs rhs
    subExpsRes <$> mapM (letSubExp "scan_res") sums
  scanForm <- scanSOAC [Scan scanLam $ replicate 4 $ intConst Int64 0]
  offsets <- letTupExp "offsets" $ Op $ Screma n flags scanForm

  -- The per-digit totals are the last element of each scanned array.
  -- NOTE(review): when n = 0 this indexes position -1 of empty arrays. The
  -- totals are then only consumed by a map of size 0; confirm the read
  -- itself cannot fault.
  lastIdx <- letSubExp "ind_last" =<< eBinOp (Sub Int64 OverflowUndef) (eSubExp n) (lit 1)
  totals <- mapM newVName ["na", "nb", "nc", "nd"]
  forM_ (zip totals offsets) $ \(total, offset) ->
    letBindNames [total] $ BasicOp $ Index offset $ Slice [DimFix lastIdx]

  -- New position of an element with digit j and running count c:
  --   (c - 1) + sum of the totals of digits 0 .. j-1.
  let totalSes = map Var totals
      positionBody j countParam =
        eBody
          [ do
              rank <- letSubExp "t" =<< eBinOp (Sub Int64 OverflowUndef) (eParam countParam) (lit 1)
              foldBinOp (Add Int64 OverflowUndef) (intConst Int64 0) (rank : take j totalSes)
          ]
  mapParams <- forM ["bin", "a", "b", "c", "d"] $ \pname -> newParam pname $ Prim int64
  let (binParam : countParams) = mapParams
  mapLam <- mkLambda mapParams $ do
    pos <-
      elseIf
        int64
        [(eParam binParam, lit d) | d <- [0 .. 2]]
        (zipWith positionBody [0 .. 3] countParams)
    varsRes <$> letTupExp "map_res" pos
  nis <- letExp "nis" $ Op $ Screma n (bins : offsets) $ mapSOAC mapLam

  dsts <- forM tps $ \tp ->
    letExp "scatter_dst" $ BasicOp $ Scratch (elemType tp) (arrayDims tp)
  multiScatter n dsts nis xs
  where
    -- An int64 literal as a builder expression.
    lit :: MonadBuilder m => Integer -> m (Exp (Rep m))
    lit = eSubExp . intConst Int64
-- | Radix sort of the arrays @xs@ (each of outer size @n@), keyed on
-- @head xs@.
--
-- The number of two-bit passes is @(bitLength (w + 1) + 1) / 2@. Each pass
-- is one 'radixSortStep' at bit offset @2 * i@, run inside a sequential
-- for-loop whose merge parameters carry the arrays between passes. The
-- result has one array per element of @xs@, in the same order.
--
-- NOTE(review): this presumes the (mapped-out) keys fit in the bit length
-- of @w + 1@; confirm against 'mapout'.
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort xs n w = do
  wPlusOne <- letSubExp "w1" =<< toExp (pe64 w + 1)
  nbits <- bitLength wPlusOne
  -- Two bits per pass: (nbits + 1) / 2 passes.
  passes <-
    letSubExp "iters"
      =<< toExp (untyped (pe64 nbits + 1) ~/~ untyped (pe64 (intConst Int64 2)))

  tys <- mapM lookupType xs
  loopParams <- zipWithM (\x tp -> newParam (baseString x) $ toDecl tp Nonunique) xs tys
  passIdx <- newVName "i"
  passBody <- buildBody_ $ localScope (scopeOfFParams loopParams) $ do
    bitPos <- letSubExp "bit" =<< toExp (le64 passIdx * 2)
    varsRes <$> radixSortStep (map paramName loopParams) tys bitPos n w

  letTupExp "sorted" $
    DoLoop
      (zip loopParams $ map Var xs)
      (ForLoop passIdx Int64 passes [])
      passBody
  where
    -- Number of right shifts by one needed to reduce @m@ to zero, i.e. the
    -- bit length of @m@ (0 for 0, 1 for 1, 3 for 4). It is computed with a
    -- while-loop over (continue?, remaining value, count).
    bitLength :: SubExp -> ADM SubExp
    bitLength m = do
      ps <- zipWithM newParam ["cond", "r", "i"] $ map Prim [Bool, int64, int64]
      let [continueP, remP, countP] = ps

      step <- buildBody_ $ localScope (scopeOfFParams ps) $ do
        rem' <- letSubExp "r'" =<< toExp (le64 (paramName remP) .>>. 1)
        continue' <- letSubExp "cond'" =<< toExp (bNot $ pe64 rem' .==. 0)
        count' <- letSubExp "i'" =<< toExp (le64 (paramName countP) + 1)
        pure $ subExpsRes [continue', rem', count']

      nonZero <- letSubExp "test" =<< toExp (bNot $ pe64 m .==. 0)
      results <-
        letTupExp' "log2res" $
          DoLoop
            (zip ps [nonZero, m, Constant $ blankPrimValue int64])
            (WhileLoop $ paramName continueP)
            step
      -- The loop yields [cond, r, i]; the count is the last one.
      pure $ last results
-- | Sort the keys @head xs@ and apply the same permutation to @tail xs@.
--
-- An iota @[0 .. n-1]@ is sorted alongside the keys with 'radixSort'. This
-- yields @perm@, the original position of each element in sorted order. The
-- remaining arrays are then gathered through @perm@ with 'multiIndex'.
--
-- Returns @perm : sortedKeys : gatheredRest@.
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' xs n w = do
  positions <-
    letExp "red_iota" $
      BasicOp $
        Iota n (intConst Int64 0) (intConst Int64 1) Int64
  sortedPair <- radixSort [head xs, positions] n w
  let [sortedKeys, perm] = sortedPair

  idx <- newParam "i" $ Prim int64
  gather <-
    mkLambda [idx] $
      varsRes <$> multiIndex (tail xs) [DimFix $ Var $ paramName idx]

  permuted <- letTupExp "sorted" $ Op $ Screma n [perm] $ mapSOAC gather
  pure (perm : sortedKeys : permuted)
diffHist :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> [SubExp]
-> SubExp
-> [VName]
-> ADM ()
-> ADM ()
diffHist VjpOps
ops [VName]
xs StmAux ()
aux SubExp
n Lambda SOACS
lam0 [SubExp]
ne [VName]
as [SubExp]
w SubExp
rf [VName]
dst ADM ()
m = do
[Type]
as_type <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail [VName]
as
[Type]
dst_type <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
dst
[VName]
nes <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"new_dst" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [SubExp]
w)) [SubExp]
ne
Lambda SOACS
h_map <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map forall u. TypeBase Shape u -> TypeBase Shape u
rowType [Type]
as_type
[VName]
h_part <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> b -> a -> c
flip forall a. Semigroup a => a -> a -> a
(<>) String
"_h_part" forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
baseString) [VName]
xs
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
h_part forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName]
as [forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (forall d. [d] -> ShapeBase d
Shape [SubExp]
w) SubExp
rf [VName]
nes [SubExp]
ne Lambda SOACS
lam0] Lambda SOACS
h_map
Lambda SOACS
lam0' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
xs forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma (forall a. [a] -> a
head [SubExp]
w) ([VName]
dst forall a. Semigroup a => a -> a -> a
<> [VName]
h_part) (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam0')
ADM ()
m
[VName]
xs_bar <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse VName -> ADM VName
lookupAdjVal [VName]
xs
([VName]
dst_params, [VName]
hp_params, Lambda SOACS
f') <- Lambda SOACS
-> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' Lambda SOACS
lam0 [Type]
dst_type forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [SubExp]
w
Lambda SOACS
f'_adj_dst <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops (forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
dst_params Lambda SOACS
f'
Lambda SOACS
f'_adj_hp <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops (forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
hp_params Lambda SOACS
f'
[SubExpRes]
dst_bar' <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
f'_adj_dst forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall a b. (a -> b) -> a -> b
$ [VName]
dst forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
[VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"dst_bar" [SubExpRes]
dst_bar'
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj [VName]
dst [VName]
dst_bar
[SubExpRes]
h_part_bar' <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
f'_adj_hp forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall a b. (a -> b) -> a -> b
$ [VName]
dst forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
[VName]
h_part_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"h_part_bar" [SubExpRes]
h_part_bar'
Lambda SOACS
lam <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
Lambda SOACS
lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
[VName]
sorted <- [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [VName]
as SubExp
n forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [SubExp]
w
let siota :: VName
siota = forall a. [a] -> a
head [VName]
sorted
let sis :: VName
sis = forall a. [a] -> a
head forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail [VName]
sorted
let sas :: [VName]
sas = forall a. Int -> [a] -> [a]
drop Int
2 [VName]
sorted
VName
iota_n <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
Param Type
par_i <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
flag_lam <- LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam Param Type
par_i VName
sis
VName
flag <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flag" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam
Param Type
par_i' <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let i' :: VName
i' = forall dec. Param dec -> VName
paramName Param Type
par_i'
Lambda SOACS
g_lam <-
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
par_i'] forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_inps") forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
SubExp
im1 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i_1" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
i' forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
SubExp
nmi <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n forall a. Num a => a -> a -> a
- forall a. a -> TPrimExp Int64 a
le64 VName
i')
let s1 :: [DimIndex SubExp]
s1 = [forall d. d -> DimIndex d
DimFix SubExp
im1]
let s2 :: [DimIndex SubExp]
s2 = [forall d. d -> DimIndex d
DimFix SubExp
nmi]
SubExp
f1 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"f1" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag forall a b. (a -> b) -> a -> b
$ forall d. [DimIndex d] -> Slice d
Slice [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i']
[SubExp]
r1 <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r1"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
f1)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [SubExp]
ne)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s1)
[SubExp]
r2 <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r2"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
i' forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) forall a. a -> [a] -> [a]
: [SubExp]
ne)
( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ do
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag forall a b. (a -> b) -> a -> b
$ forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
s2)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) forall a. a -> [a] -> [a]
: [SubExp]
ne)
( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimValue -> SubExp
Constant (PrimType -> PrimValue
blankPrimValue PrimType
Bool) :) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s2
)
)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ SubExp
f1 forall a. a -> [a] -> [a]
: [SubExp]
r1 forall a. [a] -> [a] -> [a]
++ [SubExp]
r2
[Lambda SOACS]
scan_lams <-
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse
( \Lambda SOACS
l -> do
Param Type
f1 <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f1" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
Param Type
f2 <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f2" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
[Param Type]
ps <- forall rep. Lambda rep -> [LParam rep]
lambdaParams forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
let ([Param Type]
p1, [Param Type]
p2) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne) [Param Type]
ps
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
f1 forall a. a -> [a] -> [a]
: [Param Type]
p1 forall a. [a] -> [a] -> [a]
++ Param Type
f2 forall a. a -> [a] -> [a]
: [Param Type]
p2) forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
let f :: ADM (Exp (Rep ADM))
f = forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
LogOr (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f1) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM))
f forall a. a -> [a] -> [a]
: forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
p2)
( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ADM (Exp (Rep ADM))
f :) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"gres"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
l (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
ps)
)
)
[Lambda SOACS
lam, Lambda SOACS
lam']
let ne' :: [SubExp]
ne' = PrimValue -> SubExp
Constant (Bool -> PrimValue
BoolValue Bool
False) forall a. a -> [a] -> [a]
: [SubExp]
ne
[VName]
scansres <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"adj_ctrb_scan" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC (forall a b. (a -> b) -> [a] -> [b]
map (forall rep. Lambda rep -> [SubExp] -> Scan rep
`Scan` [SubExp]
ne') [Lambda SOACS]
scan_lams) Lambda SOACS
g_lam)
let (VName
_ : [VName]
ls_arr, VName
_ : [VName]
rs_arr_rev) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne forall a. Num a => a -> a -> a
+ Int
1) [VName]
scansres
Param Type
par_i'' <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let i'' :: VName
i'' = forall dec. Param dec -> VName
paramName Param Type
par_i''
Lambda SOACS
map_lam <-
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
par_i''] forall a b. (a -> b) -> a -> b
$
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall a. [a] -> a
head [SubExp]
w, VName
i''))
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
h_part_bar [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i''])
( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ do
forall a b. (a -> b) -> [a] -> [b]
map (\Type
t -> forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t) (PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue forall a b. (a -> b) -> a -> b
$ forall shape u. TypeBase shape u -> PrimType
elemType Type
t)) [Type]
as_type
)
[VName]
f_bar <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"f_bar" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
sis] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam
([VName]
as_params, Lambda SOACS
f) <- Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF Lambda SOACS
lam0 [Type]
as_type SubExp
n
Lambda SOACS
f_adj <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops (forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
f_bar) [VName]
as_params Lambda SOACS
f
Param Type
par_i''' <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let i''' :: VName
i''' = forall dec. Param dec -> VName
paramName Param Type
par_i'''
Lambda SOACS
rev_lam <- forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
par_i'''] forall a b. (a -> b) -> a -> b
$ do
SubExp
nmim1 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i_1" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n forall a. Num a => a -> a -> a
- forall a. a -> TPrimExp Int64 a
le64 VName
i''' forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
[VName] -> [SubExpRes]
varsRes forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
rs_arr_rev [forall d. d -> DimIndex d
DimFix SubExp
nmim1]
[VName]
rs_arr <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"rs_arr" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
rev_lam
[VName]
sas_bar <-
String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"sas_bar"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
f_adj (forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall a b. (a -> b) -> a -> b
$ [VName]
ls_arr forall a. Semigroup a => a -> a -> a
<> [VName]
sas forall a. Semigroup a => a -> a -> a
<> [VName]
rs_arr)
[VName]
scatter_dst <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (\Type
t -> forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
as_type
[VName]
as_bar <- SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
siota [VName]
sas_bar
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj (forall a. [a] -> [a]
tail [VName]
as) [VName]
as_bar
where
-- | Build a single-parameter lambda that, given an index @i@ (the parameter
-- @par_i@), returns a boolean segment-start flag for the array @sis@:
-- the flag is @true@ when @i == 0@, and otherwise @sis[i] /= sis[i-1]@.
mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam par_i sis =
  mkLambda [par_i] $
    fmap varsRes . letTupExp "flag" =<< do
      let i = paramName par_i
      eIf
        (toExp (le64 i .==. 0))
        -- Index 0 always starts a segment; branching here also keeps the
        -- else-branch from ever reading sis[-1].
        (eBody [eSubExp $ Constant $ onePrimValue Bool])
        ( eBody
            [ do
                i_p <- letExp "i_p" =<< toExp (le64 i - 1)
                -- Read sis[i] and sis[i-1] as two separate bindings rather
                -- than destructuring a list with a partial pattern.
                vs_i <- letExp "vs" $ BasicOp $ Index sis $ Slice [DimFix $ Var i]
                vs_p <- letExp "vs" $ BasicOp $ Index sis $ Slice [DimFix $ Var i_p]
                toExp $ bNot $ le64 vs_i .==. le64 vs_p
            ]
        )