{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Reduce
  ( diffReduce,
    diffMinMaxReduce,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

-- | Bind a reversed copy/view of @arr@ along its outermost dimension.
--
-- The outer dimension is indexed with a 'DimSlice' that starts at @w - 1@
-- (where @w@ is the outer size of @arr@), takes @w@ elements and uses a
-- stride of @-1@; 'fullSlice' supplies the indexing for any remaining
-- dimensions.  The result is bound under the base name of @arr@ with the
-- suffix @_rev@, and its name is returned.
eReverse :: MonadBuilder m => VName -> m VName
eReverse arr = do
  arr_t <- lookupType arr
  let w = arraySize 0 arr_t
  -- Position of the last element; it becomes the first element of the result.
  start <-
    letSubExp "rev_start" . BasicOp $
      BinOp (Sub Int64 OverflowUndef) w (intConst Int64 1)
  let stride = intConst Int64 (-1)
      slice = fullSlice arr_t [DimSlice start w stride]
  letExp (baseString arr <> "_rev") . BasicOp $ Index arr slice

-- | Bind a 'Rotate' of @arr@ by the per-dimension amounts @rots@.
--
-- The result is bound under the base name of @arr@ with the suffix @_rot@,
-- and its name is returned.
eRotate :: MonadBuilder m => [SubExp] -> VName -> m VName
eRotate rots arr =
  letExp (baseString arr <> "_rot") . BasicOp $ Rotate rots arr

-- | Build an exclusive scan of @arrs@ out of an inclusive one.
--
-- The construction has three steps:
--
--   1. run the inclusive scan @scan@ over @arrs@ (results are named
--      @desc <> "_incl"@);
--
--   2. rotate every result by @-1@ with 'eRotate';
--
--   3. map over @iota w@ zipped with the rotated results, yielding the
--      neutral elements of @scan@ at index @0@ and the rotated value at
--      every other index.
--
-- The arrays produced by the final map are bound under @desc@ and their
-- names are returned.
scanExc ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  String ->
  Scan SOACS ->
  [VName] ->
  m [VName]
scanExc desc scan arrs = do
  w <- arraysSize 0 <$> mapM lookupType arrs

  -- Step 1: the ordinary (inclusive) scan.
  form <- scanSOAC [scan]
  res_incl <- letTupExp (desc <> "_incl") . Op $ Screma w arrs form

  -- Step 2: rotate by -1, so that position i holds the inclusive result of
  -- position i-1 (position 0 is overwritten in step 3).
  res_incl_rot <- mapM (eRotate [intConst Int64 (-1)]) res_incl

  iota <-
    letExp "iota" . BasicOp $
      Iota w (intConst Int64 0) (intConst Int64 1) Int64

  -- Step 3: a map whose lambda takes the index followed by one value per
  -- scan result, and substitutes the neutral elements at index 0.
  iparam <- newParam "iota_param" $ Prim int64
  vparams <- mapM (newParam "vp") ts
  let params = iparam : vparams

  body <- runBodyBuilder . localScope (scopeOfLParams params) $ do
    let first_elem =
          eCmpOp
            (CmpEq int64)
            (eSubExp (Var (paramName iparam)))
            (eSubExp (intConst Int64 0))
    eBody
      [ eIf
          first_elem
          (resultBodyM nes)
          (resultBodyM $ map (Var . paramName) vparams)
      ]

  let lam = Lambda params body ts
  letTupExp desc . Op $ Screma w (iota : res_incl_rot) (mapSOAC lam)
  where
    -- Neutral elements and element types of the scan operator.
    nes = scanNeutral scan
    ts = lambdaReturnType $ scanLambda scan

-- | From a binary operator lambda @lam@ (taking @2 * q@ parameters and
-- returning @q@ values), build a lambda of @3 * q@ parameters that applies
-- the operator twice.
--
-- Two renamed copies of @lam@ are made, @lam_l@ and @lam_r@.  With
--
-- > (lps, aps) = split of lam_l's parameters at q
-- > (ips, rps) = split of lam_r's parameters at q
--
-- the resulting lambda takes @lps ++ aps ++ rps@, runs the body of @lam_l@,
-- binds its results to the names of @ips@, and then runs the body of
-- @lam_r@.  In the notation of the comment on 'diffReduce' this is
-- @f l a r = (l \odot a) \odot r@.
--
-- Returns the names of the middle parameters @aps@ together with the new
-- lambda.
mkF :: Lambda SOACS -> ADM ([VName], Lambda SOACS)
mkF lam = do
  lam_l <- renameLambda lam
  lam_r <- renameLambda lam
  let q = length $ lambdaReturnType lam
      (lps, aps) = splitAt q $ lambdaParams lam_l
      (ips, rps) = splitAt q $ lambdaParams lam_r
  lam' <- mkLambda (lps <> aps <> rps) $ do
    lam_l_res <- bodyBind $ lambdaBody lam_l
    -- Feed the results of the first application into the left operands of
    -- the second one, preserving any certificates attached to the results.
    forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) ->
      certifying cs . letBindNames [paramName ip] . BasicOp $ SubExp se
    bodyBind $ lambdaBody lam_r
  pure (map paramName aps, lam')

-- | Emit the reverse-mode code for a single @reduce@.
--
-- Arguments, in order: the 'VjpOps' used to differentiate lambdas, the
-- adjoints of the reduction results, the width of the reduction, the input
-- arrays, and the reduction operator itself.
--
-- The first clause is a special case for a reduction over a single array
-- whose operator is a plain addition ('Add' or 'FAdd'): the adjoint of the
-- result is replicated @w@ times and used to update the adjoint of the
-- input array.  Every other reduction is handled by the general clause
-- described in the comment below.
diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM ()
diffReduce _ops [adj] w [a] red
  | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red,
    isAdd op = do
      adj_rep <-
        letExp (baseString adj <> "_rep") . BasicOp $
          Replicate (Shape [w]) (Var adj)
      updateAdj a adj_rep
  where
    isAdd :: BinOp -> Bool
    isAdd FAdd {} = True
    isAdd Add {} = True
    isAdd _ = False
--
-- Differentiating a general single reduce:
--    let y = reduce \odot ne as
-- Forward sweep:
--    let ls = scan_exc \odot  ne as
--    let rs = scan_exc \odot' ne (reverse as)
-- Reverse sweep:
--    let as_c = map3 (f_bar y_bar) ls as (reverse rs)
-- where
--   x \odot' y = y \odot x
--   y_bar is the adjoint of the result y
--   f l_i a_i r_i = l_i \odot a_i \odot r_i
--   f_bar = the reverse diff of f with respect to a_i under the adjoint y_bar
-- The plan is to create
--   one scanomap SOAC which computes ls and rs
--   another map which computes as_c
--
diffReduce ops pat_adj w as red = do
  red' <- renameRed red
  flip_red <- renameRed =<< flipReduce red

  -- ls: exclusive scan from the left.
  ls <- scanExc "ls" (redToScan red') as
  -- rs: exclusive scan from the right, obtained by scanning the reversed
  -- input with the flipped operator and reversing the result again.
  rs <-
    mapM eReverse
      =<< scanExc "rs" (redToScan flip_red)
      =<< mapM eReverse as

  (as_params, f) <- mkF $ redLambda red

  f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f

  as_adj <- letTupExp "adjs" . Op $ Screma w (ls ++ as ++ rs) (mapSOAC f_adj)

  zipWithM_ updateAdj as as_adj
  where
    -- A copy of the reduction whose lambda has been renamed.
    renameRed :: Reduce SOACS -> ADM (Reduce SOACS)
    renameRed (Reduce comm lam nes) =
      Reduce comm <$> renameLambda lam <*> pure nes

    -- Reuse the operator and neutral elements of a reduction as a scan.
    redToScan :: Reduce SOACS -> Scan SOACS
    redToScan (Reduce _ lam nes) = Scan lam nes

    -- The reduction with the two halves of its lambda parameters swapped,
    -- i.e. the operator \odot' from the comment above.
    flipReduce :: Reduce SOACS -> ADM (Reduce SOACS)
    flipReduce (Reduce comm lam nes) = do
      lam' <- renameLambda lam {lambdaParams = flipParams $ lambdaParams lam}
      pure $ Reduce comm lam' nes

    -- Swap the first and second half of a list.
    flipParams :: [a] -> [a]
    flipParams ps = uncurry (flip (++)) $ splitAt (length ps `div` 2) ps

--
-- Special case of reduce with min/max:
--    let x = reduce minmax ne as
-- Forward trace (assuming w = length as):
--    let (x, x_ind) =
--      reduce (\ acc_v acc_i v i ->
--                 if (acc_v == v) then (acc_v, min acc_i i)
--                 else if (acc_v == minmax acc_v v)
--                      then (acc_v, acc_i)
--                      else (v, i))
--             (ne_min, -1)
--             (zip as (iota w))
-- Reverse trace:
--    num_elems = i64.bool (0 <= x_ind)
--    m_bar_repl = replicate num_elems m_bar
--    as_bar[x_ind:num_elems:1] += m_bar_repl
-- | Emit forward and reverse code for a @reduce@ whose operator is the
-- min/max operator @minmax@.
--
-- Arguments, in order: the 'VjpOps' (unused), the name @x@ that the
-- reduction result is bound to, the 'StmAux' applied to the generated
-- reduction statement, the width @w@, the operator @minmax@, its neutral
-- element @ne@, the input array @as@, and an action @m@ that is run between
-- the forward and the reverse part.
--
-- Forward part: @as@ is reduced together with @iota w@ by an operator on
-- (value, index) pairs, with neutral element @(ne, -1)@; the results are
-- bound to @x@ and a fresh name @x <> "_ind"@.
--
-- Reverse part: the adjoint of @x@ is added to the adjoint of @as@ at
-- position @x_ind@ through 'updateAdjIndex', guarded by the check @0 < w@.
diffMinMaxReduce ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxReduce _ops x aux w minmax ne as m = do
  let t = binOpType minmax

  acc_v_p <- newParam "acc_v" $ Prim t
  acc_i_p <- newParam "acc_i" $ Prim int64
  v_p <- newParam "v" $ Prim t
  i_p <- newParam "i" $ Prim int64

  let acc_v = eParam acc_v_p
      acc_i = eParam acc_i_p
      v = eParam v_p
      i = eParam i_p
      -- Equality test on values of the element type.
      sameValue = eCmpOp (CmpEq t)

  -- Operator on (value, index) pairs:
  --   * equal values: keep the value and the smaller index;
  --   * otherwise keep whichever pair holds the value selected by minmax.
  red_lam <-
    mkLambda [acc_v_p, acc_i_p, v_p, i_p] $
      fmap varsRes . letTupExp "idx_res"
        =<< eIf
          (sameValue acc_v v)
          (eBody [acc_v, eBinOp (SMin Int64) acc_i i])
          ( eBody
              [ eIf
                  (sameValue acc_v (eBinOp minmax acc_v v))
                  (eBody [acc_v, acc_i])
                  (eBody [v, i])
              ]
          )

  red_iota <-
    letExp "red_iota" . BasicOp $
      Iota w (intConst Int64 0) (intConst Int64 1) Int64
  form <- reduceSOAC [Reduce Commutative red_lam [ne, intConst Int64 (-1)]]
  x_ind <- newVName (baseString x <> "_ind")
  auxing aux . letBindNames [x, x_ind] . Op $ Screma w [as, red_iota] form

  m

  x_adj <- lookupAdjVal x
  -- True when the input has at least one element (0 < w).
  in_bounds <-
    letSubExp "minmax_in_bounds" . BasicOp $
      CmpOp (CmpSlt Int64) (intConst Int64 0) w
  updateAdjIndex as (CheckBounds (Just in_bounds), Var x_ind) (Var x_adj)