{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Hist
  ( diffMinMaxHist,
    diffMulHist,
    diffAddHist,
    diffHist,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

getBinOpPlus :: PrimType -> BinOp
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus (IntType IntType
x) = IntType -> Overflow -> BinOp
Add IntType
x Overflow
OverflowUndef
getBinOpPlus (FloatType FloatType
f) = FloatType -> BinOp
FAdd FloatType
f
getBinOpPlus PrimType
_ = forall a. HasCallStack => String -> a
error String
"In getBinOpMul, Hist.hs: input not supported"

getBinOpDiv :: PrimType -> BinOp
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv (IntType IntType
t) = IntType -> Safety -> BinOp
SDiv IntType
t Safety
Unsafe
getBinOpDiv (FloatType FloatType
t) = FloatType -> BinOp
FDiv FloatType
t
getBinOpDiv PrimType
_ = forall a. HasCallStack => String -> a
error String
"In getBinOpDiv, Hist.hs: input not supported"

withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$ forall v. PrimValue -> PrimExp v
ValueExp (Bool -> PrimValue
BoolValue Bool
True)
withinBounds [(SubExp
q, VName
i)] = (forall a. a -> TPrimExp Int64 a
le64 VName
i forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
q) forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (SubExp -> TPrimExp Int64 VName
pe64 (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. forall a. a -> TPrimExp Int64 a
le64 VName
i)
withinBounds ((SubExp, VName)
qi : [(SubExp, VName)]
qis) = [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [(SubExp, VName)
qi] forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [(SubExp, VName)]
qis

elseIf :: PrimType -> [(ADM (Exp SOACS), ADM (Exp SOACS))] -> [ADM (Body SOACS)] -> ADM (Exp SOACS)
elseIf :: PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf PrimType
t [(ADM (Exp SOACS)
c1, ADM (Exp SOACS)
c2)] [ADM (Body SOACS)
bt, ADM (Body SOACS)
bf] =
  forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
    (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) ADM (Exp SOACS)
c1 ADM (Exp SOACS)
c2)
    ADM (Body SOACS)
bt
    ADM (Body SOACS)
bf
elseIf PrimType
t ((ADM (Exp SOACS)
c1, ADM (Exp SOACS)
c2) : [(ADM (Exp SOACS), ADM (Exp SOACS))]
cs) (ADM (Body SOACS)
bt : [ADM (Body SOACS)]
bs) =
  forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
    (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) ADM (Exp SOACS)
c1 ADM (Exp SOACS)
c2)
    ADM (Body SOACS)
bt
    forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
    forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure
    forall a b. (a -> b) -> a -> b
$ PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf PrimType
t [(ADM (Exp SOACS), ADM (Exp SOACS))]
cs [ADM (Body SOACS)]
bs
elseIf PrimType
_ [(ADM (Exp SOACS), ADM (Exp SOACS))]
_ [ADM (Body SOACS)]
_ = forall a. HasCallStack => String -> a
error String
"In elseIf, Hist.hs: input not supported"

bindSubExpRes :: String -> [SubExpRes] -> ADM [VName]
bindSubExpRes :: String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
s =
  forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse
    ( \(SubExpRes Certs
cs SubExp
se) -> do
        VName
bn <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
s
        forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
bn] forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
        forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
bn
    )

nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [] [PrimType]
_ Lambda SOACS
lam = forall (f :: * -> *) a. Applicative f => a -> f a
pure Lambda SOACS
lam
nestedmap s :: [SubExp]
s@(SubExp
h : [SubExp]
r) [PrimType]
pt Lambda SOACS
lam = do
  [Param Type]
params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (\PrimType
tp -> forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
tp (forall d. [d] -> ShapeBase d
Shape [SubExp]
s) NoUniqueness
NoUniqueness) [PrimType]
pt
  Lambda SOACS
body <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
r [PrimType]
pt Lambda SOACS
lam
  forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
params forall a b. (a -> b) -> a -> b
$
    forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
      forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
h (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
params) (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
body)

-- \ds hs -> map2 lam ds hs
mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' :: Lambda SOACS
-> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' Lambda SOACS
lam [Type]
tps SubExp
n = do
  Lambda SOACS
lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam

  [Param Type]
ds_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"ds_param") [Type]
tps
  [Param Type]
hs_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"hs_param") [Type]
tps
  let ds_pars :: [VName]
ds_pars = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall dec. Param dec -> VName
paramName [Param Type]
ds_params
  let hs_pars :: [VName]
hs_pars = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall dec. Param dec -> VName
paramName [Param Type]
hs_params
  Lambda SOACS
lam_map <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda ([Param Type]
ds_params forall a. Semigroup a => a -> a -> a
<> [Param Type]
hs_params) forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_f'" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
        forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n ([VName]
ds_pars forall a. Semigroup a => a -> a -> a
<> [VName]
hs_pars) (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam')

  forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName]
ds_pars, [VName]
hs_pars, Lambda SOACS
lam_map)

-- \ls as rs -> map3 (\li ai ri -> li `lam` ai `lam` ri) ls as rs
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF Lambda SOACS
lam [Type]
tps SubExp
n = do
  Lambda SOACS
lam_l <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam
  Lambda SOACS
lam_r <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam
  let q :: Int
q = forall (t :: * -> *) a. Foldable t => t a -> Int
length forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam
      ([Param Type]
lps, [Param Type]
aps) = forall a. Int -> [a] -> ([a], [a])
splitAt Int
q forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_l
      ([Param Type]
ips, [Param Type]
rps) = forall a. Int -> [a] -> ([a], [a])
splitAt Int
q forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_r
  Lambda SOACS
lam' <- forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda ([Param Type]
lps forall a. Semigroup a => a -> a -> a
<> [Param Type]
aps forall a. Semigroup a => a -> a -> a
<> [Param Type]
rps) forall a b. (a -> b) -> a -> b
$ do
    [SubExpRes]
lam_l_res <- forall (m :: * -> *).
MonadBuilder m =>
Body (Rep m) -> m [SubExpRes]
bodyBind forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_l
    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
ips [SubExpRes]
lam_l_res) forall a b. (a -> b) -> a -> b
$ \(Param Type
ip, SubExpRes Certs
cs SubExp
se) ->
      forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [forall dec. Param dec -> VName
paramName Param Type
ip] forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
    forall (m :: * -> *).
MonadBuilder m =>
Body (Rep m) -> m [SubExpRes]
bodyBind forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_r

  [Param Type]
ls_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"ls_param") [Type]
tps
  [Param Type]
as_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"as_param") [Type]
tps
  [Param Type]
rs_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"rs_param") [Type]
tps
  let map_params :: [Param Type]
map_params = [Param Type]
ls_params forall a. Semigroup a => a -> a -> a
<> [Param Type]
as_params forall a. Semigroup a => a -> a -> a
<> [Param Type]
rs_params
  Lambda SOACS
lam_map <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
map_params forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_f" forall a b. (a -> b) -> a -> b
$
        forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
          forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
map_params) forall a b. (a -> b) -> a -> b
$
            forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam'

  forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
as_params, Lambda SOACS
lam_map)

mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout VName
is SubExp
n SubExp
w = do
  Param Type
par_is <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"is" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
is'_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
par_is] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"is'"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, forall dec. Param dec -> VName
paramName Param Type
par_is))
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_is)
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
w)

  forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"is'" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n (forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
is) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
is'_lam

multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
dst VName
is [VName]
vs = do
  [Type]
tps <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
vs
  Param Type
par_i <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  [Param Type]
scatter_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"scatter_param" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall u. TypeBase Shape u -> TypeBase Shape u
rowType) [Type]
tps
  Lambda SOACS
scatter_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
par_i forall a. a -> [a] -> [a]
: [Param Type]
scatter_params) forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scatter_map_res") forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        [Exp SOACS]
p1 <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param Type]
scatter_params) forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_i
        [Exp SOACS]
p2 <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
scatter_params
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [Exp SOACS]
p1 forall a. Semigroup a => a -> a -> a
<> [Exp SOACS]
p2

  forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scatter_res" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
    forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
n (VName
is forall a. a -> [a] -> [a]
: [VName]
vs) Lambda SOACS
scatter_lam forall a b. (a -> b) -> a -> b
$
      forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Type
t -> (,,) (forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
t) Int
1) [Type]
tps [VName]
dst

multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
vs [DimIndex SubExp]
s = do
  forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse
    ( \VName
x -> do
        Type
t <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
x
        forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"sorted" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
t [DimIndex SubExp]
s)
    )
    [VName]
vs

--
-- special case of histogram with min/max as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst minmax ne is vs
-- Forward sweep:
--     need to copy dst: reverse sweep might use it 7
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let (x, x_inds) = zip vs (iota n)
--                       |> reduce_by_index (dst_cpy,-1s) argminmax (ne,-1) is
--
-- Reverse sweep:
--     dst_bar += map2 (\i b -> if i == -1
--                              then b
--                              else 0
--                     ) x_inds x_bar

--     vs_ctrbs = map2 (\i b -> if i == -1
--                              then 0
--                              else vs_bar[i] + b
--                     ) x_inds x_bar
--     vs_bar <- scatter vs_bar x_inds vs_ctrbs
diffMinMaxHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMinMaxHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
minmax SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
minmax
  Type
vs_type <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
  let vs_elm_type :: PrimType
vs_elm_type = forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
  let vs_dims :: [SubExp]
vs_dims = forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
  let inner_dims :: [SubExp]
inner_dims = forall a. [a] -> [a]
tail [SubExp]
vs_dims
  let nr_dims :: Int
nr_dims = forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims
  Type
dst_type <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
  let dst_dims :: [SubExp]
dst_dims = forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type

  VName
dst_cpy <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst forall a. Semigroup a => a -> a -> a
<> String
"_copy") forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
dst

  Param Type
acc_v_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_v" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
acc_i_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
v_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
i_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
hist_lam_inner <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
acc_v_p, Param Type
acc_i_p, Param Type
v_p, Param Type
i_p] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"idx_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p,
                forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
SMin IntType
Int64) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p)
              ]
          )
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  ( forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp
                      (PrimType -> CmpOp
CmpEq PrimType
t)
                      (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p)
                      (forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
minmax (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
                  )
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p, forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p])
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p, forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p])
              ]
          )
  Lambda SOACS
hist_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
int64, PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
hist_lam_inner

  VName
dst_minus_ones <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"minus_ones" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
  SubExp
ne_minus_ones <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"minus_ones" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
  VName
iota_n <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
      SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

  VName
inp_iota <- do
    if Int
nr_dims forall a. Eq a => a -> a -> Bool
== Int
1
      then forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
iota_n
      else do
        Param Type
i <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
        Lambda SOACS
lam <-
          forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
i] forall a b. (a -> b) -> a -> b
$
            forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
              forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
i

        forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"res" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam

  let hist_op :: HistOp SOACS
hist_op = forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
dst_cpy, VName
dst_minus_ones] [SubExp
ne, if Int
nr_dims forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1) else SubExp
ne_minus_ones] Lambda SOACS
hist_lam
  Lambda SOACS
f' <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, forall u. TypeBase Shape u -> TypeBase Shape u
rowType forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 (forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
  VName
x_inds <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_inds")
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$
    forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x, VName
x_inds] forall a b. (a -> b) -> a -> b
$
      forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
        forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
vs, VName
inp_iota] [HistOp SOACS
hist_op] Lambda SOACS
f'

  ADM ()
m

  VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x

  Param Type
x_ind_dst <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
x_bar_dst <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
dst_lam_inner <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
x_ind_dst, Param Type
x_bar_dst] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"dst_bar"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 (forall dec. Param dec -> VName
paramName Param Type
x_ind_dst) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
x_bar_dst)
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
  Lambda SOACS
dst_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
vs_elm_type] Lambda SOACS
dst_lam_inner

  VName
dst_bar <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst forall a. Semigroup a => a -> a -> a
<> String
"_bar") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
      forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
x_inds, VName
x_bar] (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
dst_lam)

  VName -> VName -> ADM ()
updateAdj VName
dst VName
dst_bar

  VName
vs_bar <- VName -> ADM VName
lookupAdjVal VName
vs

  [VName]
inds' <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"inds" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
inner_dims []
  let inds :: [VName]
inds = VName
x_inds forall a. a -> [a] -> [a]
: [VName]
inds'

  [Param Type]
par_x_ind_vs <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
nr_dims forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
par_x_bar_vs <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
vs_lam_inner <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
par_x_bar_vs forall a. a -> [a] -> [a]
: [Param Type]
par_x_ind_vs) forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 (forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Param Type]
par_x_ind_vs) forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
              forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ do
                SubExp
vs_bar_i <-
                  forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
vs_bar forall a. Semigroup a => a -> a -> a
<> String
"_el") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
                    VName -> Slice SubExp -> BasicOp
Index VName
vs_bar forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall d. [DimIndex d] -> Slice d
Slice forall a b. (a -> b) -> a -> b
$
                      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> VName
paramName) [Param Type]
par_x_ind_vs
                forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpPlus PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_x_bar_vs) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
vs_bar_i)
          )
  Lambda SOACS
vs_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims (PrimType
vs_elm_type forall a. a -> [a] -> [a]
: forall a. Int -> a -> [a]
replicate Int
nr_dims PrimType
int64) Lambda SOACS
vs_lam_inner

  VName
vs_bar_p <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs forall a. Semigroup a => a -> a -> a
<> String
"_partial") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
      forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w (VName
x_bar forall a. a -> [a] -> [a]
: [VName]
inds) (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
vs_lam)

  SubExp
q <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"q"
      forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dst_dims

  [VName]
scatter_inps <- do
    -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p]
    -- ToDo: Cosmin asks: is the below the correct translation of the line above?
    forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flat" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary (forall d. [d] -> ShapeBase d
Shape [SubExp
q])) forall a b. (a -> b) -> a -> b
$
      [VName]
inds forall a. [a] -> [a] -> [a]
++ [VName
vs_bar_p]

  Lambda SOACS
f'' <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda forall a b. (a -> b) -> a -> b
$ forall a. Int -> a -> [a]
replicate Int
nr_dims (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) forall a. [a] -> [a] -> [a]
++ [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t]
  VName
vs_bar' <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs forall a. Semigroup a => a -> a -> a
<> String
"_bar") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
      forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
q [VName]
scatter_inps Lambda SOACS
f'' [(forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims, Int
1, VName
vs_bar)]
  VName -> VName -> ADM ()
insAdj VName
vs VName
vs_bar'
  where
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [] [SubExp]
_ = forall (f :: * -> *) a. Applicative f => a -> f a
pure []
    mk_indices [SubExp
d] [SubExp]
iotas = do
      [VName]
reps <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"rep" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp
d])) [SubExp]
iotas
      VName
iota_d <-
        forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [VName]
reps forall a. [a] -> [a] -> [a]
++ [VName
iota_d]
    mk_indices (SubExp
d : [SubExp]
dims) [SubExp]
iotas = do
      VName
iota_d <-
        forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

      Param Type
i_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      Lambda SOACS
lam <-
        forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
i_param] forall a b. (a -> b) -> a -> b
$
          forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall a b. (a -> b) -> a -> b
$
            [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
dims forall a b. (a -> b) -> a -> b
$
              [SubExp]
iotas forall a. [a] -> [a] -> [a]
++ [VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
i_param]

      forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
d [VName
iota_d] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam

--
-- special case of histogram with multiplication as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (*) ne is vs
-- Forward sweep:
--     dst does not need to be copied: dst is not overwritten
--     let (ps, zs) = map (\v -> if v == 0 then (1,1) else (v,0)) vs
--     let non_zero_prod = reduce_by_index nes (*) ne is ps
--     let zero_count = reduce_by_index 0s (+) 0 is zs
--     let h_part = map2 (\p c -> if c == 0 then p else 0
--                       ) non_zero_prod zero_count
--     let x = map2 (*) dst h_part
--
-- Reverse sweep:
--     dst_bar += map2 (*) h_part x_bar

--     let part_bar = map2 (*) dst x_bar
--     vs_bar += map2 (\i v -> let zr_cts = zero_count[i]
--                             let pr_bar = part_bar[i]
--                             let nz_prd = non_zero_prod[i]
--                             in if zr_cts == 0
--                             then pr_bar * (nz_prd / v)
--                             else if zr_cts == 1 and v == 0
--                             then nz_prd * pr_bar
--                             else 0
--                    ) is vs
diffMulHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMulHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMulHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
mul SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
mul
  Type
vs_type <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
  let vs_dims :: [SubExp]
vs_dims = forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
  let vs_elm_type :: PrimType
vs_elm_type = forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
  Type
dst_type <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
  let dst_dims :: [SubExp]
dst_dims = forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type
  let inner_dims :: [SubExp]
inner_dims = forall a. [a] -> [a]
tail [SubExp]
vs_dims

  Param Type
v_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
lam_ps_zs_inner <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
v_param] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t))
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
t, IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1])
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param, forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0])
  Lambda SOACS
lam_ps_zs <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
vs_dims [PrimType
vs_elm_type] Lambda SOACS
lam_ps_zs_inner
  [SubExpRes]
ps_zs_res <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
lam_ps_zs [forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vs]
  [VName]
ps_zs <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"ps_zs" [SubExpRes]
ps_zs_res
  let [VName
ps, VName
zs] = [VName]
ps_zs

  Lambda SOACS
lam_mul_inner <- forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
  Lambda SOACS
lam_mul <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner
  VName
nz_prods0 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nz_prd" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
ne
  let hist_nzp :: HistOp SOACS
hist_nzp = forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
nz_prods0] [SubExp
ne] Lambda SOACS
lam_mul

  Lambda SOACS
lam_add_inner <- forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) PrimType
int64
  Lambda SOACS
lam_add <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
int64] Lambda SOACS
lam_add_inner
  VName
zr_counts0 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"zr_cts" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
  SubExp
zrn_ne <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zr_ne" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
  let hist_zrn :: HistOp SOACS
hist_zrn = forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
zr_counts0] [if forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0 else SubExp
zrn_ne] Lambda SOACS
lam_add

  Lambda SOACS
f' <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, forall u. TypeBase Shape u -> TypeBase Shape u
rowType forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 (forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
  VName
nz_prods <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"non_zero_prod"
  VName
zr_counts <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"zero_count"
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$
    forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
nz_prods, VName
zr_counts] forall a b. (a -> b) -> a -> b
$
      forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
        forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
is, VName
ps, VName
zs] [HistOp SOACS
hist_nzp, HistOp SOACS
hist_zrn] Lambda SOACS
f'

  Param Type
p_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"prod" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
c_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"count" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
lam_h_part_inner <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
p_param, Param Type
c_param] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"h_part"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
0 forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. a -> TPrimExp Int64 a
le64 (forall dec. Param dec -> VName
paramName Param Type
c_param))
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p_param)
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
  Lambda SOACS
lam_h_part <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
lam_h_part_inner
  [SubExpRes]
h_part_res <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
lam_h_part forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
nz_prods, VName
zr_counts]
  [VName]
h_part' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"h_part" [SubExpRes]
h_part_res
  let [VName
h_part] = [VName]
h_part'

  Lambda SOACS
lam_mul_inner' <- forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
  Lambda SOACS
lam_mul' <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner'
  [SubExpRes]
x_res <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
lam_mul' forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
h_part]
  [VName]
x' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"x" [SubExpRes]
x_res
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x] forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [VName]
x'

  ADM ()
m

  VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x

  Lambda SOACS
lam_mul'' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
  [SubExpRes]
dst_bar_res <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
lam_mul'' forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
h_part, VName
x_bar]
  [VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes (VName -> String
baseString VName
dst forall a. Semigroup a => a -> a -> a
<> String
"_bar") [SubExpRes]
dst_bar_res
  VName -> VName -> ADM ()
updateAdj VName
dst forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [VName]
dst_bar

  Lambda SOACS
lam_mul''' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
  [SubExpRes]
part_bar_res <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
lam_mul''' forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
x_bar]
  [VName]
part_bar' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"part_bar" [SubExpRes]
part_bar_res
  let [VName
part_bar] = [VName]
part_bar'

  [Param Type]
inner_params <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam [String
"zr_cts", String
"pr_bar", String
"nz_prd", String
"a"] forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall shape u. PrimType -> TypeBase shape u
Prim [PrimType
int64, PrimType
t, PrimType
t, PrimType
t]
  let [Param Type
zr_cts, Param Type
pr_bar, Param Type
nz_prd, Param Type
a_param] = [Param Type]
inner_params
  Lambda SOACS
lam_vsbar_inner <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
inner_params forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar) forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpDiv PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
              forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
                forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  ( forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
                      BinOp
LogAnd
                      (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
                      (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t) forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
                  )
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar))
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
          )

  Lambda SOACS
lam_vsbar_middle <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
t, PrimType
t, PrimType
t] Lambda SOACS
lam_vsbar_inner

  Param Type
i_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
a_param' <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"a" forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type
  Lambda SOACS
lam_vsbar <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
i_param, Param Type
a_param'] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, forall dec. Param dec -> VName
paramName Param Type
i_param))
          ( forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ forall a b. (a -> b) -> a -> b
$ do
              let i :: Slice SubExp
i = Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
vs_type [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
i_param]
              [VName]
names <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"zr_cts", String
"pr_bar", String
"nz_prd"]
              forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
name -> forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
name] forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
names [VName
zr_counts, VName
part_bar, VName
nz_prods]
              forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
lam_vsbar_middle forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
names forall a. Semigroup a => a -> a -> a
<> [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param']
          )
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. Type -> Exp rep
zeroExp forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
dst_type)

  VName
vs_bar <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs forall a. Semigroup a => a -> a -> a
<> String
"_bar") forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is, VName
vs] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_vsbar

  VName -> VName -> ADM ()
updateAdj VName
vs VName
vs_bar

--
-- special case of histogram with add as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (+) ne is vs
-- Forward sweep:
--     need to copy dst: reverse sweep might use it 7
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let x = reduce_by_index dst_cpy (+) ne is vs
--
-- Reverse sweep:
--     dst_bar += x_bar
--
--     vs_bar += map (\i -> x_bar[i]) is
diffAddHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> Lambda SOACS -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffAddHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffAddHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n Lambda SOACS
add SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: Type
t = forall dec. Param dec -> dec
paramDec forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
add

  VName
dst_cpy <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst forall a. Semigroup a => a -> a -> a
<> String
"_copy") forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
dst

  Lambda SOACS
f <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type
t]
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x] forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
    forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
vs] [forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
dst_cpy] [SubExp
ne] Lambda SOACS
add] Lambda SOACS
f

  ADM ()
m

  VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x

  VName -> VName -> ADM ()
updateAdj VName
dst VName
x_bar

  Type
x_type <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
x
  Param Type
i_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
vs forall a. Semigroup a => a -> a -> a
<> String
"_i") forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i :: VName
i = forall dec. Param dec -> VName
paramName Param Type
i_param
  Lambda SOACS
lam_vsbar <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
i_param] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, VName
i))
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x_bar forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
x_type [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i])
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
ne)

  VName
vs_bar <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs forall a. Semigroup a => a -> a -> a
<> String
"_bar") forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_vsbar
  VName -> VName -> ADM ()
updateAdj VName
vs VName
vs_bar

--
-- a step in the radix sort implementation
-- it assumes the key we are sorting
-- after is [n]i64 and it is the first VName
--
-- local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
--                                  (digit_n: i32): [n]t =
--   let num x = get_bit (digit_n+1) x * 2 + get_bit digit_n x
--   let pairwise op (a1,b1,c1,d1) (a2,b2,c2,d2) =
--     (a1 `op` a2, b1 `op` b2, c1 `op` c2, d1 `op` d2)
--   let bins = xs |> map num
--   let flags = bins |> map (\x -> if x == 0 then (1,0,0,0)
--                                  else if x == 1 then (0,1,0,0)
--                                  else if x == 2 then (0,0,1,0)
--                                  else (0,0,0,1))
--   let offsets = scan (pairwise (+)) (0,0,0,0) flags
--   let (na,nb,nc,_nd) = last offsets
--   let f bin (a,b,c,d) = match bin
--                         case 0 -> a-1
--                         case 1 -> na+b-1
--                         case 2 -> na+nb+c-1
--                         case _ -> na+nb+nc+d-1
--   let is = map2 f bins offsets
--   in scatter scratch is xs
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep [VName]
xs [Type]
tps SubExp
bit SubExp
n SubExp
w = do
  -- let is = head xs
  VName
is <- VName -> SubExp -> SubExp -> ADM VName
mapout (forall a. [a] -> a
head [VName]
xs) SubExp
n SubExp
w

  Param Type
num_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"num" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
num_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
num_param] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"num_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
          (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)
          ( forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
              (IntType -> BinOp
And IntType
Int64)
              (forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit))
              (forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
          )
          ( forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
              (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef)
              (forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
2)
              ( forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
                  (IntType -> BinOp
And IntType
Int64)
                  (forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit) (forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)))
                  (forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
              )
          )

  VName
bins <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"bins" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
num_lam
  Param Type
flag_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"flag" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
flag_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
flag_param] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf
          PrimType
int64
          (forall a b. (a -> b) -> [a] -> [b]
map ((,) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
flag_param) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
          (forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst forall b c a. (b -> c) -> (a -> b) -> a -> c
. (\Integer
i -> forall a b. (a -> b) -> [a] -> [b]
map (\Integer
j -> if Integer
i forall a. Eq a => a -> a -> Bool
== Integer
j then Integer
1 else Integer
0) [Integer
0 .. Integer
3])) ([Integer
0 .. Integer
3] :: [Integer]))

  [VName]
flags <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flags" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
bins] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam

  [Param Type]
scan_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall a b c. (a -> b -> c) -> b -> a -> c
flip forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"a1", String
"b1", String
"c1", String
"d1", String
"a2", String
"b2", String
"c2", String
"d2"]
  Lambda SOACS
scan_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
scan_params forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_res") forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry (forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)) forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> ([a], [a])
splitAt Int
4 forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
scan_params

  ScremaForm SOACS
scan <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Scan rep] -> m (ScremaForm rep)
scanSOAC forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
scan_lam forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (IntType -> Integer -> SubExp
intConst IntType
Int64) [Integer
0, Integer
0, Integer
0, Integer
0]
  [VName]
offsets <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"offsets" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName]
flags ScremaForm SOACS
scan

  SubExp
ind <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"ind_last" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
n) (forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
  let i :: Slice SubExp
i = forall d. [DimIndex d] -> Slice d
Slice [forall d. d -> DimIndex d
DimFix SubExp
ind]
  [VName]
nabcd <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"na", String
"nb", String
"nc", String
"nd"]
  forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
abcd -> forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
abcd] forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
nabcd [VName]
offsets

  let vars :: [SubExp]
vars = forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
nabcd
  [Param Type]
map_params <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall a b c. (a -> b -> c) -> b -> a -> c
flip forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"bin", String
"a", String
"b", String
"c", String
"d"]
  Lambda SOACS
map_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
map_params forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf
          PrimType
int64
          (forall a b. (a -> b) -> [a] -> [b]
map ((,) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [Param Type]
map_params) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
          ( forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              ( \Int
j Param Type
p ->
                  forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
                    forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ do
                      SubExp
t <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"t" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p) (forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
                      forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (SubExp
t forall a. a -> [a] -> [a]
: forall a. Int -> [a] -> [a]
take Int
j [SubExp]
vars)
              )
              [Int
0 .. Int
3]
              (forall a. [a] -> [a]
tail [Param Type]
map_params)
          )

  VName
nis <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nis" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n (VName
bins forall a. a -> [a] -> [a]
: [VName]
offsets) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam

  [VName]
scatter_dst <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (\Type
t -> forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
tps
  SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
nis [VName]
xs
  where
    iConst :: Integer -> m (Exp (Rep m))
iConst Integer
c = forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
c

--
-- the radix sort implementation
-- def radix_sort [n] 't (xs: [n]i64) =
--   let iters = if n == 0 then 0 else 32
--   in loop xs for i < iters do radix_sort_step xs i64.get_bit (i*2)
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort [VName]
xs SubExp
n SubExp
w = do
  SubExp
logw <- SubExp -> ADM SubExp
log2 forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"w1" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
  -- ceil logw by (logw + 1) / 2
  SubExp
iters <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"iters" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (SubExp -> TPrimExp Int64 VName
pe64 SubExp
logw forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) forall v. PrimExp v -> PrimExp v -> PrimExp v
~/~ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (SubExp -> TPrimExp Int64 VName
pe64 (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
2)))

  [Type]
types <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
xs
  [Param (TypeBase Shape Uniqueness)]
params <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\VName
x -> forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> b -> a -> c
flip forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) [VName]
xs [Type]
types
  VName
i <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"i"
  Body SOACS
loopbody <- forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
params) forall a b. (a -> b) -> a -> b
$
    forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall a b. (a -> b) -> a -> b
$ do
      SubExp
bit <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"bit" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
i forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
2)
      [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
params) [Type]
types SubExp
bit SubExp
n SubExp
w

  forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"sorted" forall a b. (a -> b) -> a -> b
$
    forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop
      (forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
params forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
xs)
      (forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
Int64 SubExp
iters [])
      Body SOACS
loopbody
  where
    log2 :: SubExp -> ADM SubExp
    log2 :: SubExp -> ADM SubExp
log2 SubExp
m = do
      [Param (TypeBase Shape Uniqueness)]
params <- forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam [String
"cond", String
"r", String
"i"] forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall shape u. PrimType -> TypeBase shape u
Prim [PrimType
Bool, PrimType
int64, PrimType
int64]
      let [Param (TypeBase Shape Uniqueness)
cond, Param (TypeBase Shape Uniqueness)
r, Param (TypeBase Shape Uniqueness)
i] = [Param (TypeBase Shape Uniqueness)]
params

      Body SOACS
body <- forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
params) forall a b. (a -> b) -> a -> b
$ do
        SubExp
r' <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"r'" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 (forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
r) forall {k} (t :: k) v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.>>. TPrimExp Int64 VName
1)
        SubExp
cond' <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"cond'" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
r' forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
        SubExp
i' <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i'" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 (forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
i) forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
        forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExpRes]
subExpsRes [SubExp
cond', SubExp
r', SubExp
i']

      SubExp
cond_init <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"test" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
m forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)

      [SubExp]
l <-
        forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"log2res" forall a b. (a -> b) -> a -> b
$
          forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop
            (forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
params [SubExp
cond_init, SubExp
m, PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
int64])
            (forall rep. VName -> LoopForm rep
WhileLoop forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
cond)
            Body SOACS
body

      let [SubExp
_, SubExp
_, SubExp
res] = [SubExp]
l
      forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
res

radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [VName]
xs SubExp
n SubExp
w = do
  VName
iota_n <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
      SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

  [VName]
radres <- [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort [forall a. [a] -> a
head [VName]
xs, VName
iota_n] SubExp
n SubExp
w
  let [VName
is', VName
iota'] = [VName]
radres

  Param Type
i_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let slice :: [DimIndex SubExp]
slice = [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
i_param]
  Lambda SOACS
map_lam <- forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
i_param] forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExpRes]
varsRes forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex (forall a. [a] -> [a]
tail [VName]
xs) [DimIndex SubExp]
slice

  [VName]
sorted <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"sorted" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota'] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam
  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ VName
iota' forall a. a -> [a] -> [a]
: VName
is' forall a. a -> [a] -> [a]
: [VName]
sorted

--
-- generic case of histogram.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--   let xs = reduce_by_index dst odot ne is as
-- Forward sweep:
-- let h_part = reduce_by_index (replicate w ne) odot ne is as
-- let xs = map2 odot dst h_part
-- Reverse sweep:
-- h_part_bar += f'' dst h_part
-- dst_bar += f' dst h_part

-- let flag = map (\i -> i == 0 || sis[i] != sis[i-1]) (iota n)
-- let flag_rev = map (\i -> i==0 || flag[n-i]) (iota n)
-- let ls = seg_scan_exc odot ne flag sas
-- let rs = reverse sas |>
--          seg_scan_exc odot ne flag_rev |> reverse
-- let f_bar = map (\i -> if i < w && -1 < w
--                        then h_part_bar[i]
--                        else 0s
--                 ) sis
-- let sas_bar = f f_dst ls sas rs
-- as_bar += scatter (Scratch alpha n) siota sas_bar
-- Where:
--  siota: 'iota n' sorted wrt 'is'
--  sis: 'is' sorted wrt 'is'
--  sas: 'as' sorted wrt 'is'
--  f'' = vjpLambda xs_bar h_part (map2 odot)
--  f' = vjpLambda xs_bar dst (map2 odot)
--  f  = vjpLambda f_bar sas (map4 (\di li ai ri -> di odot li odot ai odot ri))
--  0s is an alpha-dimensional array with 0 (possibly 0-dim)
diffHist :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> [SubExp]
-> SubExp
-> [VName]
-> ADM ()
-> ADM ()
diffHist VjpOps
ops [VName]
xs StmAux ()
aux SubExp
n Lambda SOACS
lam0 [SubExp]
ne [VName]
as [SubExp]
w SubExp
rf [VName]
dst ADM ()
m = do
  [Type]
as_type <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail [VName]
as
  [Type]
dst_type <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
dst

  [VName]
nes <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"new_dst" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [SubExp]
w)) [SubExp]
ne

  Lambda SOACS
h_map <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map forall u. TypeBase Shape u -> TypeBase Shape u
rowType [Type]
as_type
  [VName]
h_part <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> b -> a -> c
flip forall a. Semigroup a => a -> a -> a
(<>) String
"_h_part" forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
baseString) [VName]
xs
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
h_part forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
    forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName]
as [forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (forall d. [d] -> ShapeBase d
Shape [SubExp]
w) SubExp
rf [VName]
nes [SubExp]
ne Lambda SOACS
lam0] Lambda SOACS
h_map

  Lambda SOACS
lam0' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
xs forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
    forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma (forall a. [a] -> a
head [SubExp]
w) ([VName]
dst forall a. Semigroup a => a -> a -> a
<> [VName]
h_part) (forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam0')

  ADM ()
m

  [VName]
xs_bar <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse VName -> ADM VName
lookupAdjVal [VName]
xs

  ([VName]
dst_params, [VName]
hp_params, Lambda SOACS
f') <- Lambda SOACS
-> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' Lambda SOACS
lam0 [Type]
dst_type forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [SubExp]
w
  Lambda SOACS
f'_adj_dst <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops (forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
dst_params Lambda SOACS
f'
  Lambda SOACS
f'_adj_hp <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops (forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
hp_params Lambda SOACS
f'

  [SubExpRes]
dst_bar' <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
f'_adj_dst forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall a b. (a -> b) -> a -> b
$ [VName]
dst forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
  [VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"dst_bar" [SubExpRes]
dst_bar'
  forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj [VName]
dst [VName]
dst_bar

  [SubExpRes]
h_part_bar' <- forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
f'_adj_hp forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall a b. (a -> b) -> a -> b
$ [VName]
dst forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
  [VName]
h_part_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"h_part_bar" [SubExpRes]
h_part_bar'

  Lambda SOACS
lam <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
  Lambda SOACS
lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0

  -- is' <- mapout (head as) n (head w)
  -- sorted <- radixSort' (is' : tail as) n $ head w
  [VName]
sorted <- [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [VName]
as SubExp
n forall a b. (a -> b) -> a -> b
$ forall a. [a] -> a
head [SubExp]
w
  let siota :: VName
siota = forall a. [a] -> a
head [VName]
sorted
  let sis :: VName
sis = forall a. [a] -> a
head forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail [VName]
sorted
  let sas :: [VName]
sas = forall a. Int -> [a] -> [a]
drop Int
2 [VName]
sorted

  VName
iota_n <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

  Param Type
par_i <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
flag_lam <- LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam Param Type
par_i VName
sis
  VName
flag <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flag" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam

  -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n)
  Param Type
par_i' <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i' :: VName
i' = forall dec. Param dec -> VName
paramName Param Type
par_i'
  Lambda SOACS
g_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
par_i'] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_inps") forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        SubExp
im1 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i_1" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
i' forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
        SubExp
nmi <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n forall a. Num a => a -> a -> a
- forall a. a -> TPrimExp Int64 a
le64 VName
i')
        let s1 :: [DimIndex SubExp]
s1 = [forall d. d -> DimIndex d
DimFix SubExp
im1]
        let s2 :: [DimIndex SubExp]
s2 = [forall d. d -> DimIndex d
DimFix SubExp
nmi]

        -- flag array for left scan
        SubExp
f1 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"f1" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag forall a b. (a -> b) -> a -> b
$ forall d. [DimIndex d] -> Slice d
Slice [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i']

        -- array for left scan
        [SubExp]
r1 <-
          forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r1"
            forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
              (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
f1)
              (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [SubExp]
ne)
              (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s1)

        -- array for right scan inc flag
        [SubExp]
r2 <-
          forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r2"
            forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
              (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
i' forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
              (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) forall a. a -> [a] -> [a]
: [SubExp]
ne)
              ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
                  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ do
                    forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                      (forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag forall a b. (a -> b) -> a -> b
$ forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
s2)
                      (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) forall a. a -> [a] -> [a]
: [SubExp]
ne)
                      ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimValue -> SubExp
Constant (PrimType -> PrimValue
blankPrimValue PrimType
Bool) :) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var
                          forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s2
                      )
              )

        forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ SubExp
f1 forall a. a -> [a] -> [a]
: [SubExp]
r1 forall a. [a] -> [a] -> [a]
++ [SubExp]
r2

  -- scan (\(f1,v1) (f2,v2) ->
  --   let f = f1 || f2
  --   let v = if f2 then v2 else g v1 v2
  --   in (f,v) ) (false,ne) (zip flags vals)
  [Lambda SOACS]
scan_lams <-
    forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse
      ( \Lambda SOACS
l -> do
          Param Type
f1 <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f1" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
          Param Type
f2 <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f2" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
          [Param Type]
ps <- forall rep. Lambda rep -> [LParam rep]
lambdaParams forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
          let ([Param Type]
p1, [Param Type]
p2) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne) [Param Type]
ps

          forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
f1 forall a. a -> [a] -> [a]
: [Param Type]
p1 forall a. [a] -> [a] -> [a]
++ Param Type
f2 forall a. a -> [a] -> [a]
: [Param Type]
p2) forall a b. (a -> b) -> a -> b
$
            forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
              let f :: ADM (Exp (Rep ADM))
f = forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
LogOr (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f1) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
              forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
                (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM))
f forall a. a -> [a] -> [a]
: forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
p2)
                ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ADM (Exp (Rep ADM))
f :) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var)
                    forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"gres"
                    forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
l (forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
ps)
                )
      )
      [Lambda SOACS
lam, Lambda SOACS
lam']

  let ne' :: [SubExp]
ne' = PrimValue -> SubExp
Constant (Bool -> PrimValue
BoolValue Bool
False) forall a. a -> [a] -> [a]
: [SubExp]
ne

  [VName]
scansres <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"adj_ctrb_scan" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
      forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC (forall a b. (a -> b) -> [a] -> [b]
map (forall rep. Lambda rep -> [SubExp] -> Scan rep
`Scan` [SubExp]
ne') [Lambda SOACS]
scan_lams) Lambda SOACS
g_lam)

  let (VName
_ : [VName]
ls_arr, VName
_ : [VName]
rs_arr_rev) = forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne forall a. Num a => a -> a -> a
+ Int
1) [VName]
scansres

  -- map (\i -> if i < w && -1 < w then (xs_bar[i], dst[i]) else (0,ne)) sis
  Param Type
par_i'' <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i'' :: VName
i'' = forall dec. Param dec -> VName
paramName Param Type
par_i''
  Lambda SOACS
map_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
par_i''] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall a. [a] -> a
head [SubExp]
w, VName
i''))
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
h_part_bar [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i''])
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ do
              forall a b. (a -> b) -> [a] -> [b]
map (\Type
t -> forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (forall d. [d] -> ShapeBase d
Shape forall a b. (a -> b) -> a -> b
$ forall a. [a] -> [a]
tail forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t) (PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue forall a b. (a -> b) -> a -> b
$ forall shape u. TypeBase shape u -> PrimType
elemType Type
t)) [Type]
as_type
          )

  [VName]
f_bar <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"f_bar" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
sis] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam

  ([VName]
as_params, Lambda SOACS
f) <- Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF Lambda SOACS
lam0 [Type]
as_type SubExp
n
  Lambda SOACS
f_adj <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops (forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
f_bar) [VName]
as_params Lambda SOACS
f

  -- map (\i -> rs_arr_rev[n-i-1]) (iota n)
  Param Type
par_i''' <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i''' :: VName
i''' = forall dec. Param dec -> VName
paramName Param Type
par_i'''
  Lambda SOACS
rev_lam <- forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
par_i'''] forall a b. (a -> b) -> a -> b
$ do
    SubExp
nmim1 <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i_1" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n forall a. Num a => a -> a -> a
- forall a. a -> TPrimExp Int64 a
le64 VName
i''' forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
    [VName] -> [SubExpRes]
varsRes forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
rs_arr_rev [forall d. d -> DimIndex d
DimFix SubExp
nmim1]

  [VName]
rs_arr <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"rs_arr" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
rev_lam

  [VName]
sas_bar <-
    String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"sas_bar"
      forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda SOACS
f_adj (forall a b. (a -> b) -> [a] -> [b]
map (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) forall a b. (a -> b) -> a -> b
$ [VName]
ls_arr forall a. Semigroup a => a -> a -> a
<> [VName]
sas forall a. Semigroup a => a -> a -> a
<> [VName]
rs_arr)

  [VName]
scatter_dst <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (\Type
t -> forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
as_type
  [VName]
as_bar <- SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
siota [VName]
sas_bar

  forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj (forall a. [a] -> [a]
tail [VName]
as) [VName]
as_bar
  where
    -- map (\i -> if i == 0 then true else is[i] != is[i-1]) (iota n)
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam LParam SOACS
par_i VName
sis =
      forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [LParam SOACS
par_i] forall a b. (a -> b) -> a -> b
$
        forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
          let i :: VName
i = forall dec. Param dec -> VName
paramName LParam SOACS
par_i
          forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
            (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
i forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0))
            (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
Bool)
            ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
                forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ do
                  VName
i_p <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"i_p" forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (forall a. a -> TPrimExp Int64 a
le64 VName
i forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
                  [VName]
vs <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"vs" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
sis forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall d. [DimIndex d] -> Slice d
Slice forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a. Applicative f => a -> f a
pure forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
i, VName
i_p]
                  let [VName
vs_i, VName
vs_p] = [VName]
vs
                  forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot forall a b. (a -> b) -> a -> b
$ forall a. a -> TPrimExp Int64 a
le64 VName
vs_i forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. a -> TPrimExp Int64 a
le64 VName
vs_p
            )