{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Reduce
  ( diffReduce,
    diffMinMaxReduce,
    diffVecReduce,
    diffMulReduce,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

-- | Bind a reversed copy (a stride @-1@ slice) of the outermost dimension
-- of @arr@ to a fresh name derived from @arr@ with the suffix @_rev@.
--
-- The slice starts at the last row (@w - 1@), takes @w@ elements and
-- walks backwards with stride @-1@; inner dimensions are sliced fully.
eReverse :: MonadBuilder m => VName -> m VName
eReverse arr = do
  arr_t <- lookupType arr
  let n = arraySize 0 arr_t
  last_idx <-
    letSubExp "rev_start" $
      BasicOp $
        BinOp (Sub Int64 OverflowUndef) n (intConst Int64 1)
  let backwards = DimSlice last_idx n (intConst Int64 (-1))
  letExp (baseString arr <> "_rev") . BasicOp . Index arr $
    fullSlice arr_t [backwards]

-- | Bind the rotation of @arr@ by the offsets in @rots@ to a fresh name
-- formed from @arr@'s base name plus the suffix @_rot@.
eRotate :: MonadBuilder m => [SubExp] -> VName -> m VName
eRotate rots arr = letExp rot_name rotated
  where
    rot_name = baseString arr <> "_rot"
    rotated = BasicOp (Rotate rots arr)

-- | Exclusive scan of @arrs@ with the operator and neutral elements of
-- @scan@; the results are bound under the name @desc@.
--
-- Construction:
--
--   1. an inclusive scan, bound under @desc <> "_incl"@;
--   2. each result array is rotated by @-1@ (so, presumably, position @i@
--      now holds the inclusive result for @i - 1@);
--   3. a map over @iota w@ replaces position 0 with the neutral elements
--      and passes every other position through unchanged.
scanExc ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  String ->
  Scan SOACS ->
  [VName] ->
  m [VName]
scanExc desc scan arrs = do
  w <- arraysSize 0 <$> traverse lookupType arrs
  incl <- letTupExp (desc <> "_incl") . Op . Screma w arrs =<< scanSOAC [scan]
  shifted <- mapM (eRotate [intConst Int64 (-1)]) incl

  idxs <-
    letExp "iota" $
      BasicOp $
        Iota w (intConst Int64 0) (intConst Int64 1) Int64

  i_p <- newParam "iota_param" (Prim int64)
  v_ps <- mapM (newParam "vp") ret_ts
  let lam_params = i_p : v_ps

  lam_body <-
    runBodyBuilder $
      localScope (scopeOfLParams lam_params) $ do
        -- True only for the first position of the output.
        let isFirst =
              eCmpOp
                (CmpEq int64)
                (eSubExp $ Var $ paramName i_p)
                (eSubExp $ intConst Int64 0)
        eBody
          [ eIf
              isFirst
              (resultBodyM neutral)
              (resultBodyM [Var (paramName p) | p <- v_ps])
          ]

  letTupExp desc . Op . Screma w (idxs : shifted) . mapSOAC $
    Lambda lam_params lam_body ret_ts
  where
    neutral = scanNeutral scan
    ret_ts = lambdaReturnType (scanLambda scan)

-- | From a binary reduction operator @lam@ (written @⊙@), build the ternary
-- lambda @f l a r = l ⊙ a ⊙ r@ with parameters @l ++ a ++ r@.
--
-- Two renamed copies of @lam@ are used.
--
-- * The left copy computes @l ⊙ a@.
-- * Its results are bound to the first-operand parameters of the right
--   copy, which then combines them with @r@.
--
-- Also returns the names of the @a@ parameters, i.e. the parameters to
-- differentiate with respect to.
mkF :: Lambda SOACS -> ADM ([VName], Lambda SOACS)
mkF lam = do
  left <- renameLambda lam
  right <- renameLambda lam
  let arity = length (lambdaReturnType lam)
      (l_params, a_params) = splitAt arity (lambdaParams left)
      (mid_params, r_params) = splitAt arity (lambdaParams right)
      -- Feed one result of the left copy into the right copy's parameter,
      -- keeping the certificates attached to that result.
      bindMid (p, SubExpRes cs se) =
        certifying cs . letBindNames [paramName p] . BasicOp $ SubExp se
  f <- mkLambda (l_params ++ a_params ++ r_params) $ do
    mapM_ bindMid . zip mid_params =<< bodyBind (lambdaBody left)
    bodyBind (lambdaBody right)
  pure (paramName <$> a_params, f)

-- | Reverse-mode differentiation of a single @reduce@.
--
-- Arguments:
--
-- * the VJP operations;
-- * the adjoints of the reduction results;
-- * the width @w@ of the input;
-- * the input arrays @as@;
-- * the reduction itself.
--
-- The adjoint contributions for @as@ are accumulated with 'updateAdj'.
diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM ()
-- Special case: a single array reduced with an addition operator. Every
-- element then contributes with coefficient one, so its adjoint is the
-- result adjoint replicated @w@ times.
diffReduce _ops [adj] w [a] red
  | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red,
    isAdd op = do
      adj_rep <-
        letExp (baseString adj <> "_rep") $
          BasicOp $
            Replicate (Shape [w]) $
              Var adj
      updateAdj a adj_rep
  where
    isAdd FAdd {} = True
    isAdd Add {} = True
    isAdd _ = False
--
-- Differentiating a general single reduce:
--    let y = reduce \odot ne as
-- Forward sweep:
--    let ls = scan_exc \odot  ne as
--    let rs = reverse (scan_exc \odot' ne (reverse as))
-- Reverse sweep:
--    let as_c = map3 (f_bar y_bar) ls as rs
-- where
--   x \odot' y = y \odot x
--   y_bar is the adjoint of the result y
--   f l_i a_i r_i = l_i \odot a_i \odot r_i
--   f_bar = the reverse diff of f with respect to a_i under the adjoint y_bar
-- The plan is to create
--   exclusive scans (scanExc) which compute ls and rs
--   a map which computes as_c
--
diffReduce ops pat_adj w as red = do
  red' <- renameRed red
  flip_red <- renameRed =<< flipReduce red
  ls <- scanExc "ls" (redToScan red') as
  -- rs is already reversed back into the original element order.
  rs <-
    mapM eReverse
      =<< scanExc "rs" (redToScan flip_red)
      =<< mapM eReverse as

  (as_params, f) <- mkF $ redLambda red

  f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f

  as_adj <- letTupExp "adjs" $ Op $ Screma w (ls ++ as ++ rs) (mapSOAC f_adj)

  zipWithM_ updateAdj as as_adj
  where
    -- Fresh copy of the reduction with a renamed operator.
    renameRed :: Reduce SOACS -> ADM (Reduce SOACS)
    renameRed (Reduce comm lam nes) =
      Reduce comm <$> renameLambda lam <*> pure nes

    redToScan :: Reduce SOACS -> Scan SOACS
    redToScan (Reduce _ lam nes) = Scan lam nes

    -- The operator with its operands swapped: x \odot' y = y \odot x.
    flipReduce :: Reduce SOACS -> ADM (Reduce SOACS)
    flipReduce (Reduce comm lam nes) = do
      lam' <- renameLambda lam {lambdaParams = flipParams $ lambdaParams lam}
      pure $ Reduce comm lam' nes

    -- Swap the first and second half of a parameter list.
    flipParams :: [a] -> [a]
    flipParams ps = uncurry (flip (++)) $ splitAt (length ps `div` 2) ps

--
-- Special case of reduce with min/max:
--    let x = reduce minmax ne as
-- Forward trace (assuming w = length as):
--    let (x, x_ind) =
--      reduce (\ acc_v acc_i v i ->
--                 if (acc_v == v) then (acc_v, min acc_i i)
--                 else if (acc_v == minmax acc_v v)
--                      then (acc_v, acc_i)
--                      else (v, i))
--             (ne_min, -1)
--             (zip as (iota w))
-- Reverse trace:
--    num_elems = i64.bool (0 <= x_ind)
--    m_bar_repl = replicate num_elems m_bar
--    as_bar[x_ind:num_elems:1] += m_bar_repl
diffMinMaxReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxReduce :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMinMaxReduce VjpOps
_ops VName
x StmAux ()
aux SubExp
w BinOp
minmax SubExp
ne VName
as ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
minmax

  Param Type
acc_v_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_v" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
acc_i_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
v_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
i_p <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
red_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type
acc_v_p, Param Type
acc_i_p, Param Type
v_p, Param Type
i_p] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"idx_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p,
                forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
SMin IntType
Int64) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p)
              ]
          )
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  ( forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp
                      (PrimType -> CmpOp
CmpEq PrimType
t)
                      (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p)
                      (forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
minmax (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
                  )
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p, forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p])
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p, forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p])
              ]
          )

  VName
red_iota <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" forall a b. (a -> b) -> a -> b
$
      forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
        SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
  ScremaForm SOACS
form <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC [forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
Commutative Lambda SOACS
red_lam [SubExp
ne, IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)]]
  VName
x_ind <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
x forall a. Semigroup a => a -> a -> a
<> String
"_ind")
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x, VName
x_ind] forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
as, VName
red_iota] ScremaForm SOACS
form

  ADM ()
m

  VName
x_adj <- VName -> ADM VName
lookupAdjVal VName
x
  SubExp
in_bounds <-
    forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"minmax_in_bounds" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
      CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (IntType -> CmpOp
CmpSlt IntType
Int64) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) SubExp
w
  VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex VName
as (Maybe SubExp -> InBounds
CheckBounds (forall a. a -> Maybe a
Just SubExp
in_bounds), VName -> SubExp
Var VName
x_ind) (VName -> SubExp
Var VName
x_adj)

--
-- Special case of vectorised reduce:
--    let x = reduce (map2 op) nes as
-- Idea:
--    rewrite to
--      let x = map2 (\as ne -> reduce op ne as) (transpose as) nes
--    and diff
--
-- The rewritten statements are collected first and then differentiated
-- one by one with 'vjpStm', with @m@ as the continuation.
diffVecReduce ::
  VjpOps -> Pat Type -> StmAux () -> SubExp -> Commutativity -> Lambda SOACS -> VName -> VName -> ADM () -> ADM ()
diffVecReduce ops x aux w iscomm lam ne as m =
  foldr (vjpStm ops) m =<< collectStms_ rewrite
  where
    rewrite :: ADM ()
    rewrite = do
      as_t <- lookupType as
      -- Swap the two outermost dimensions and keep the remaining ones.
      let perm = 1 : 0 : [2 .. arrayRank as_t - 1]
      tran_as <- letExp "tran_as" $ BasicOp $ Rearrange perm as
      tran_t <- lookupType tran_as
      ne_t <- lookupType ne

      as_param <- newParam "as_param" $ rowType tran_t
      ne_param <- newParam "ne_param" $ rowType ne_t

      -- Inner scalar reduction of one transposed row, seeded with its
      -- own neutral element.
      inner_form <- reduceSOAC [Reduce iscomm lam [Var $ paramName ne_param]]
      inner_lam <-
        mkLambda [as_param, ne_param] $
          fmap varsRes . letTupExp "idx_res" . Op $
            Screma w [paramName as_param] inner_form

      -- The outer map binds the original pattern with the original aux.
      addStm . Let x aux . Op $
        Screma (arraySize 0 tran_t) [tran_as, ne] (mapSOAC inner_lam)

--
-- Special case of reduce with mul:
--    let x = reduce (*) ne as
-- Forward trace (assuming w = length as):
--    let (p, z) = map (\a -> if a == 0 then (1, 1) else (a, 0)) as
--    non_zero_prod = reduce (*) ne p
--    zr_count = reduce (+) 0 z
--    let x =
--      if 0 == zr_count
--      then non_zero_prod
--      else 0
-- Reverse trace:
--    as_bar = map2
--      (\a a_bar ->
--        if zr_count == 0
--        then a_bar + non_zero_prod/a * x_bar
--        else if zr_count == 1
--        then a_bar + (if a == 0 then non_zero_prod * x_bar else 0)
--        else a_bar
--      ) as as_bar
diffMulReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMulReduce :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMulReduce VjpOps
_ops VName
x StmAux ()
aux SubExp
w BinOp
mul SubExp
ne VName
as ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
mul
  let const_zero :: ADM (Exp (Rep ADM))
const_zero = forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t

  Param Type
a_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"a" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
map_lam <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type
a_param] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param) ADM (Exp (Rep ADM))
const_zero)
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [PrimValue -> SubExp
Constant forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
t, IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1])
          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param, forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0])

  VName
ps <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"ps"
  VName
zs <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"zs"
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$
    forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
ps, VName
zs] forall a b. (a -> b) -> a -> b
$
      forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
        forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
as] forall a b. (a -> b) -> a -> b
$
          forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam

  Lambda SOACS
red_lam_mul <- forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
  Lambda SOACS
red_lam_add <- forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) PrimType
int64

  ScremaForm SOACS
red_form_mul <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
Commutative Lambda SOACS
red_lam_mul forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
ne
  ScremaForm SOACS
red_form_add <- forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Reduce rep] -> m (ScremaForm rep)
reduceSOAC forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
Commutative Lambda SOACS
red_lam_add forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0

  VName
nz_prods <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"non_zero_prod"
  VName
zr_count <- forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"zero_count"
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
nz_prods] forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
ps] ScremaForm SOACS
red_form_mul
  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
zr_count] forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
zs] ScremaForm SOACS
red_form_add

  forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux forall a b. (a -> b) -> a -> b
$
    forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x]
      forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
        (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
0 forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. a -> TPrimExp Int64 a
le64 VName
zr_count)
        (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
nz_prods)
        (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure ADM (Exp (Rep ADM))
const_zero)

  ADM ()
m

  VName
x_adj <- VName -> ADM VName
lookupAdjVal VName
x

  Param Type
a_param_rev <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"a" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
map_lam_rev <-
    forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type
a_param_rev] forall a b. (a -> b) -> a -> b
$
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"adj_res"
        forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
0 forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. a -> TPrimExp Int64 a
le64 VName
zr_count)
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
              forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
                forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
x_adj) forall a b. (a -> b) -> a -> b
$
                  forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getDiv PrimType
t) (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
nz_prods) forall a b. (a -> b) -> a -> b
$
                    forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param_rev
          )
          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
              forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
                forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  (forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
1 forall {k} (t :: k) v.
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. forall a. a -> TPrimExp Int64 a
le64 VName
zr_count)
                  ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
                      forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
                        forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                          (forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param_rev) ADM (Exp (Rep ADM))
const_zero)
                          ( forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$
                              forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$
                                forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
x_adj) forall a b. (a -> b) -> a -> b
$
                                  forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$
                                    VName -> SubExp
Var VName
nz_prods
                          )
                          (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure ADM (Exp (Rep ADM))
const_zero)
                  )
                  (forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a. Applicative f => a -> f a
pure ADM (Exp (Rep ADM))
const_zero)
          )

  VName
as_adjup <- forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"adjs" forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$ forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
as] forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam_rev

  VName -> VName -> ADM ()
updateAdj VName
as VName
as_adjup
  where
    -- Choose the division operator that matches the element type of the
    -- reduction:
    --   * integral types use signed division with 'Unsafe' (no runtime
    --     division-by-zero check is emitted);
    --   * floating-point types use ordinary floating-point division;
    --   * any other primitive type is rejected with a runtime error.
    -- NOTE(review): 'Unsafe' appears to rely on the caller only dividing
    -- in the branch where no zero elements were counted (the
    -- @0 .==. le64 zr_count@ branch above), which presumably guarantees a
    -- non-zero divisor — confirm against the zero-counting map lambda.
    getDiv :: PrimType -> BinOp
    getDiv elemType =
      case elemType of
        IntType intTy -> SDiv intTy Unsafe
        FloatType floatTy -> FDiv floatTy
        _ -> error "In getDiv, Reduce.hs: input not supported"