{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Reduce
  ( diffReduce,
    diffMinMaxReduce,
    diffVecReduce,
    diffMulReduce,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

-- | Bind a new array that is the given array indexed with a negative-stride
-- slice over dimension 0, i.e. its outermost dimension traversed backwards.
-- The result is named after the input with a @_rev@ suffix.
eReverse :: (MonadBuilder m) => VName -> m VName
eReverse arr = do
  arr_t <- lookupType arr
  let n = arraySize 0 arr_t
      minus_one = intConst Int64 (-1)
  -- The slice starts at index n - 1 (the last element) ...
  last_i <-
    letSubExp "rev_start" $
      BasicOp (BinOp (Sub Int64 OverflowUndef) n (intConst Int64 1))
  -- ... and steps by -1.  The middle argument of DimSlice is n; presumably
  -- the number of elements taken (TODO confirm against DimIndex docs).
  let rev_slice = fullSlice arr_t [DimSlice last_i n minus_one]
  letExp (baseString arr <> "_rev") (BasicOp (Index arr rev_slice))

-- | Build an exclusive scan out of an inclusive one.
--
-- First the inclusive scan of @arrs@ is bound (named @desc ++ "_incl"@).
-- Then a map over @iota w@ produces, for index 0, the neutral elements of
-- the scan, and for every other index @i@, element @i - 1@ of the inclusive
-- results.  The names of the arrays bound by that map (using @desc@ as the
-- name hint) are returned.
scanExc ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  String ->
  Scan SOACS ->
  [VName] ->
  m [VName]
scanExc desc scan arrs = do
  arr_ts <- mapM lookupType arrs
  let w = arraysSize 0 arr_ts
  scan_form <- scanSOAC [scan]
  incl <- letTupExp (desc <> "_incl") (Op (Screma w arrs scan_form))

  idxs <-
    letExp "iota" $
      BasicOp (Iota w (intConst Int64 0) (intConst Int64 1) Int64)

  i_p <- newParam "iota_param" (Prim int64)

  shift_lam <- mkLambda [i_p] $ do
    let i = paramName i_p
        -- i == 0: there is no previous element to read.
        is_first =
          eCmpOp (CmpEq int64) (eSubExp (Var i)) (eSubExp (intConst Int64 0))
        prev_i = toExp (le64 i - 1)
        neutral_body = resultBodyM (scanNeutral scan)
        shifted_body = eBody [eIndex r [prev_i] | r <- incl]
    branch <- eIf is_first neutral_body shifted_body
    subExpsRes <$> letTupExp' "scan_ex_res" branch

  letTupExp desc (Op (Screma w [idxs] (mapSOAC shift_lam)))

-- | From a lambda @lam@ with @n@ results, build a lambda that composes two
-- freshly renamed copies of @lam@: the body of the first copy is bound, its
-- results are bound to the first @n@ parameters of the second copy, and then
-- the body of the second copy is bound and provides the result.
--
-- The parameters of the built lambda are, in order: the first @n@ and the
-- remaining parameters of the first copy, followed by the remaining
-- parameters of the second copy.  Returned alongside the lambda are the
-- names of the first copy's remaining (post-@n@) parameters.
mkF :: Lambda SOACS -> ADM ([VName], Lambda SOACS)
mkF lam = do
  left <- renameLambda lam
  right <- renameLambda lam
  let n = length (lambdaReturnType lam)
      (l_params, a_params) = splitAt n (lambdaParams left)
      (mid_params, r_params) = splitAt n (lambdaParams right)
  composed <- mkLambda (l_params ++ a_params ++ r_params) $ do
    left_res <- bodyBind (lambdaBody left)
    -- Feed the results of the first copy into the second copy.
    zipWithM_ bindMid mid_params left_res
    bodyBind (lambdaBody right)
  pure (map paramName a_params, composed)
  where
    -- Bind one parameter name to one result, keeping the result's certificates.
    bindMid :: Param Type -> SubExpRes -> ADM ()
    bindMid p (SubExpRes cs se) =
      certifying cs $ letBindNames [paramName p] (BasicOp (SubExp se))

-- | Reverse-mode differentiation of a single @reduce@.
--
-- Arguments, in order: the VJP operations, the adjoints of the reduction
-- results, the width of the reduction, the input arrays, and the reduction
-- itself.  The adjoints of the input arrays are updated via 'updateAdj'.
--
-- Special case: one result adjoint, one input array, and an operator that
-- 'lamIsBinOp' recognises as a single 'Add' or 'FAdd'.  Then the adjoint of
-- the input is simply the result adjoint replicated @w@ times.
diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM ()
diffReduce _ops [adj] w [a] red
  | Just [(op, _, _, _)] <- lamIsBinOp (redLambda red),
    isAdd op = do
      adj_rep <-
        letExp (baseString adj <> "_rep") $
          BasicOp $
            Replicate (Shape [w]) (Var adj)
      updateAdj a adj_rep
  where
    isAdd :: BinOp -> Bool
    isAdd FAdd {} = True
    isAdd Add {} = True
    isAdd _ = False
--
-- Differentiating a general single reduce:
--    let y = reduce \odot ne as
-- Forward sweep:
--    let ls = scan_exc \odot  ne as
--    let rs = scan_exc \odot' ne (reverse as)
-- Reverse sweep:
--    let as_c = map3 (f_bar y_bar) ls as (reverse rs)
-- where
--   x \odot' y = y \odot x
--   y_bar is the adjoint of the result y
--   f l_i a_i r_i = l_i \odot a_i \odot r_i
--   f_bar = the reverse diff of f with respect to a_i under the adjoint y_bar
-- The code below creates
--   two exclusive scans (see scanExc), one computing ls and one computing
--     the scan whose reversal is bound as rs
--   one map which computes as_c
--
diffReduce ops pat_adj w as red = do
  red' <- renameRed red
  flip_red <- renameRed =<< flipReduce red
  ls <- scanExc "ls" (redToScan red') as
  -- Scan the reversed input with the flipped operator, then reverse the
  -- scan results so they line up with the original element order.
  rev_as <- mapM eReverse as
  rev_rs <- scanExc "rs" (redToScan flip_red) rev_as
  rs <- mapM eReverse rev_rs

  (as_params, f) <- mkF (redLambda red)

  f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f

  as_adj <-
    letTupExp "adjs" $
      Op (Screma w (ls ++ as ++ rs) (mapSOAC f_adj))

  zipWithM_ updateAdj as as_adj
  where
    -- Give the operator of a reduction fresh names.
    renameRed :: Reduce SOACS -> ADM (Reduce SOACS)
    renameRed (Reduce comm lam nes) = do
      lam' <- renameLambda lam
      pure (Reduce comm lam' nes)

    -- Same operator and neutral elements, viewed as a scan.
    redToScan :: Reduce SOACS -> Scan SOACS
    redToScan (Reduce _ lam nes) = Scan lam nes

    -- Swap the two halves of the operator's parameter list (and rename),
    -- which yields x \odot' y = y \odot x.
    flipReduce :: Reduce SOACS -> ADM (Reduce SOACS)
    flipReduce (Reduce comm lam nes) = do
      lam' <- renameLambda lam {lambdaParams = flipParams (lambdaParams lam)}
      pure (Reduce comm lam' nes)

    -- Move the second half of a list in front of the first half.
    flipParams :: [a] -> [a]
    flipParams ps =
      let (front, back) = splitAt (length ps `div` 2) ps
       in back ++ front

--
-- Special case of reduce with min/max:
--    let x = reduce minmax ne as
-- Forward trace (assuming w = length as):
--    let (x, x_ind) =
--      reduce (\ acc_v acc_i v i ->
--                 if (acc_v == v) then (acc_v, min acc_i i)
--                 else if (acc_v == minmax acc_v v)
--                      then (acc_v, acc_i)
--                      else (v, i))
--             (ne_min, -1)
--             (zip as (iota w))
-- Reverse trace:
--    num_elems = i64.bool (0 <= x_ind)
--    m_bar_repl = replicate num_elems m_bar
--    as_bar[x_ind:num_elems:1] += m_bar_repl
-- | Differentiate a reduction with a min/max operator @minmax@.
--
-- The forward reduction is rebuilt so that it reduces pairs of
-- (value, index) over the input zipped with @iota w@, starting from
-- @(ne, -1)@.  It binds both @x@ and a fresh @x_ind@ (named after @x@ with
-- an @_ind@ suffix), under the auxiliary information @aux@.  After that the
-- action @m@ is run, and finally the adjoint of @x@ is added to @as@ at
-- index @x_ind@ through 'updateAdjIndex', guarded by the condition @0 < w@.
diffMinMaxReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxReduce _ops x aux w minmax ne as m = do
  let t = binOpType minmax

  acc_v_p <- newParam "acc_v" (Prim t)
  acc_i_p <- newParam "acc_i" (Prim int64)
  v_p <- newParam "v" (Prim t)
  i_p <- newParam "i" (Prim int64)

  let acc_v, acc_i, v, i :: ADM (Exp SOACS)
      acc_v = eParam acc_v_p
      acc_i = eParam acc_i_p
      v = eParam v_p
      i = eParam i_p
      -- Equal values: keep the value and the smaller index.
      same_value = eCmpOp (CmpEq t) acc_v v
      on_tie = eBody [acc_v, eBinOp (SMin Int64) acc_i i]
      -- Otherwise keep whichever side the operator selects.
      acc_wins = eCmpOp (CmpEq t) acc_v (eBinOp minmax acc_v v)
      on_diff =
        eBody [eIf acc_wins (eBody [acc_v, acc_i]) (eBody [v, i])]

  red_lam <- mkLambda [acc_v_p, acc_i_p, v_p, i_p] $ do
    pick <- eIf same_value on_tie on_diff
    varsRes <$> letTupExp "idx_res" pick

  red_iota <-
    letExp "red_iota" $
      BasicOp (Iota w (intConst Int64 0) (intConst Int64 1) Int64)
  form <- reduceSOAC [Reduce Commutative red_lam [ne, intConst Int64 (-1)]]
  x_ind <- newVName (baseString x <> "_ind")
  auxing aux $
    letBindNames [x, x_ind] (Op (Screma w [as, red_iota] form))

  m

  x_adj <- lookupAdjVal x
  -- 0 < w
  in_bounds <-
    letSubExp "minmax_in_bounds" $
      BasicOp (CmpOp (CmpSlt Int64) (intConst Int64 0) w)
  updateAdjIndex as (CheckBounds (Just in_bounds), Var x_ind) (Var x_adj)

--
-- Special case of vectorised reduce:
--    let x = reduce (map2 op) nes as
-- Idea:
--    rewrite to
--      let x = map2 (\as ne -> reduce op ne as) (transpose as) nes
--    and differentiate the rewritten statement(s) instead.
-- | Reverse-mode rule for a vectorised reduce.  The reduce is first rewritten
-- into a map over the transposed input whose body is a scalar-operator
-- reduce; the resulting statements are then differentiated with 'vjpStm'.
--
-- Arguments, in order:
--
--   * @ops@: handed unchanged to 'vjpStm' for every generated statement.
--   * @x@: pattern bound by the rewritten (outer map) statement.
--   * @aux@: auxiliary information attached to that statement.
--   * @w@: width used for the inner reduce.
--   * @iscomm@, @lam@: commutativity flag and operator of the inner reduce.
--   * @ne@: array of neutral elements; one row is passed to each inner reduce.
--   * @as@: the array being reduced.
--   * @m@: continuation run after all generated statements (it is the
--     initial accumulator of the final 'foldr').
diffVecReduce ::
  VjpOps -> Pat Type -> StmAux () -> SubExp -> Commutativity -> Lambda SOACS -> VName -> VName -> ADM () -> ADM ()
diffVecReduce ops x aux w iscomm lam ne as m = do
  -- Build the rewritten program in isolation so that it can be handed to
  -- 'vjpStm' statement by statement below.
  stms <- collectStms_ $ do
    rank <- arrayRank <$> lookupType as
    -- Permutation that swaps the two outermost dimensions and keeps the
    -- remaining ones in place.
    -- NOTE(review): assumes @as@ has rank >= 2; this is not checked here.
    let rear = [1, 0] ++ drop 2 [0 .. rank - 1]

    tran_as <- letExp "tran_as" $ BasicOp $ Rearrange rear as
    ts <- lookupType tran_as
    t_ne <- lookupType ne

    -- Parameters of the outer map: one row of the transposed input and the
    -- matching row of the neutral-element array.
    as_param <- newParam "as_param" $ rowType ts
    ne_param <- newParam "ne_param" $ rowType t_ne

    -- Inner reduce: the original operator, with the per-row neutral element.
    reduce_form <- reduceSOAC [Reduce iscomm lam [Var $ paramName ne_param]]

    map_lam <-
      mkLambda [as_param, ne_param] $
        fmap varsRes . letTupExp "idx_res" $
          Op $
            Screma w [paramName as_param] reduce_form

    -- The outer map runs over the leading dimension of the transposed array
    -- and binds the original pattern @x@.
    addStm $ Let x aux $ Op $ Screma (arraySize 0 ts) [tran_as, ne] $ mapSOAC map_lam

  -- Differentiate the generated statements; @m@ is the innermost continuation.
  foldr (vjpStm ops) m stms

--
-- Special case of reduce with mul:
--    let x = reduce (*) ne as
-- Forward trace (assuming w = length as):
--    let (ps, zs) = unzip (map (\a -> if a == 0 then (1, 1) else (a, 0)) as)
--    let non_zero_prod = reduce (*) ne ps
--    let zr_count = reduce (+) 0 zs
--    let x =
--      if 0 == zr_count
--      then non_zero_prod
--      else 0
-- Reverse trace:
--    as_bar = map2
--      (\a a_bar ->
--        if zr_count == 0
--        then a_bar + non_zero_prod/a * x_bar
--        else if zr_count == 1
--        then a_bar + (if a == 0 then non_zero_prod * x_bar else 0)
--        else a_bar
--      ) as as_bar
-- | Reverse-mode rule for a reduce whose operator is a multiplication.
--
-- The forward sweep computes the product of the non-zero elements together
-- with the number of zero elements, and binds @x@ from those two values.
-- After the continuation @m@ has run, the adjoint contribution for every
-- element of @as@ is computed by a single map and added with 'updateAdj'.
--
-- Arguments, in order:
--
--   * @_ops@: unused.
--   * @x@: name bound to the result of the reduce.
--   * @aux@: auxiliary information attached to every forward statement.
--   * @w@: width of all generated SOACs.
--   * @mul@: the multiplication operator; its type fixes the element type.
--   * @ne@: neutral element of the product reduce.
--   * @as@: the array being reduced.
--   * @m@: continuation run between the forward and the reverse part.
diffMulReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMulReduce _ops x aux w mul ne as m = do
  let t = binOpType mul
  -- Zero of the element type; used both as comparison operand and as result.
  let const_zero = eSubExp $ Constant $ blankPrimValue t

  -- Forward map: a zero element becomes (one, 1), any other element (a, 0).
  -- The first component feeds the product, the second the zero counter.
  a_param <- newParam "a" $ Prim t
  map_lam <-
    mkLambda [a_param] $
      fmap varsRes . letTupExp "map_res"
        =<< eIf
          (eCmpOp (CmpEq t) (eParam a_param) const_zero)
          (eBody $ fmap eSubExp [Constant $ onePrimValue t, intConst Int64 1])
          (eBody [eParam a_param, eSubExp $ intConst Int64 0])

  ps <- newVName "ps"
  zs <- newVName "zs"
  auxing aux $
    letBindNames [ps, zs] $
      Op $
        Screma w [as] $
          mapSOAC map_lam

  -- Two reductions: product over @ps@ and an Int64 sum over @zs@.
  red_lam_mul <- binOpLambda mul t
  red_lam_add <- binOpLambda (Add Int64 OverflowUndef) int64

  red_form_mul <- reduceSOAC $ pure $ Reduce Commutative red_lam_mul $ pure ne
  red_form_add <- reduceSOAC $ pure $ Reduce Commutative red_lam_add $ pure $ intConst Int64 0

  nz_prods <- newVName "non_zero_prod"
  zr_count <- newVName "zero_count"
  auxing aux $ letBindNames [nz_prods] $ Op $ Screma w [ps] red_form_mul
  auxing aux $ letBindNames [zr_count] $ Op $ Screma w [zs] red_form_add

  -- The primal result: the non-zero product if no element was zero,
  -- otherwise zero.
  auxing aux $
    letBindNames [x]
      =<< eIf
        (toExp $ 0 .==. le64 zr_count)
        (eBody $ pure $ eSubExp $ Var nz_prods)
        (eBody $ pure const_zero)

  -- Run the rest of the program before emitting the reverse part.
  m

  x_adj <- lookupAdjVal x

  -- Reverse map: per-element adjoint contribution, by number of zeros.
  a_param_rev <- newParam "a" $ Prim t
  map_lam_rev <-
    mkLambda [a_param_rev] $
      fmap varsRes . letTupExp "adj_res"
        =<< eIf
          (toExp $ 0 .==. le64 zr_count)
          -- No zeros: x_adj * (non_zero_prod / a).  The division is only
          -- emitted in this branch, where no element compared equal to zero.
          ( eBody $
              pure $
                eBinOp mul (eSubExp $ Var x_adj) $
                  eBinOp (getDiv t) (eSubExp $ Var nz_prods) $
                    eParam a_param_rev
          )
          ( eBody $
              pure $
                eIf
                  (toExp $ 1 .==. le64 zr_count)
                  -- Exactly one zero: only the zero element receives a
                  -- contribution, namely x_adj * non_zero_prod.
                  ( eBody $
                      pure $
                        eIf
                          (eCmpOp (CmpEq t) (eParam a_param_rev) const_zero)
                          ( eBody $
                              pure $
                                eBinOp mul (eSubExp $ Var x_adj) $
                                  eSubExp $
                                    Var nz_prods
                          )
                          (eBody $ pure const_zero)
                  )
                  -- Two or more zeros: every contribution is zero.
                  (eBody $ pure const_zero)
          )

  as_adjup <- letExp "adjs" $ Op $ Screma w [as] $ mapSOAC map_lam_rev

  updateAdj as as_adjup
  where
    -- Division operator matching the element type: unsafe signed division
    -- for integers, floating-point division for floats; any other type is
    -- rejected with an error.
    getDiv :: PrimType -> BinOp
    getDiv (IntType t) = SDiv t Unsafe
    getDiv (FloatType t) = FDiv t
    getDiv _ = error "In getDiv, Reduce.hs: input not supported"