{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Hist
  ( diffMinMaxHist,
    diffMulHist,
    diffAddHist,
    diffHist,
  )
where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename

-- | The addition operator for a primitive type: 'Add' with 'OverflowUndef'
-- for integer types and 'FAdd' for floating-point types.  Any other
-- primitive type is rejected with 'error'.
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus (IntType x) = Add x OverflowUndef
getBinOpPlus (FloatType f) = FAdd f
getBinOpPlus _ = error "In getBinOpPlus, Hist.hs: input not supported"

-- | The division operator for a primitive type: 'SDiv' (with 'Unsafe') for
-- integer types and 'FDiv' for floating-point types.  Any other primitive
-- type is rejected with 'error'.
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv (IntType t) = SDiv t Unsafe
getBinOpDiv (FloatType t) = FDiv t
getBinOpDiv _ = error "In getBinOpDiv, Hist.hs: input not supported"

-- | Conjunction of bounds checks.  For every pair @(q, i)@ the index @i@ must
-- satisfy @i < q@ and @-1 < i@ (i.e. @0 <= i < q@).  An empty list yields the
-- constant @True@.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = TPrimExp $ ValueExp (BoolValue True)
withinBounds [(q, i)] =
  (le64 i .<. pe64 q) .&&. (pe64 (intConst Int64 (-1)) .<. le64 i)
withinBounds (qi : qis) = withinBounds [qi] .&&. withinBounds qis

-- | Build a chain of conditionals.  Each pair @(c1, c2)@ is an equality test
-- at primitive type @t@.  The first test that holds selects the body at the
-- same position, and the final body is the fallback else-branch.  There must
-- be exactly one more body than there are tests; any other combination
-- reaches the 'error' case.
elseIf ::
  PrimType ->
  [(ADM (Exp SOACS), ADM (Exp SOACS))] ->
  [ADM (Body SOACS)] ->
  ADM (Exp SOACS)
elseIf t [(c1, c2)] [bt, bf] =
  eIf (eCmpOp (CmpEq t) c1 c2) bt bf
elseIf t ((c1, c2) : cs) (bt : bs) =
  eIf (eCmpOp (CmpEq t) c1 c2) bt $
    eBody $ pure $ elseIf t cs bs
elseIf _ _ _ = error "In elseIf, Hist.hs: input not supported"

-- | Bind every 'SubExpRes' to a fresh variable whose name is derived from
-- @s@, attaching that result's certificates to the binding.  Returns the new
-- names in input order.
bindSubExpRes :: String -> [SubExpRes] -> ADM [VName]
bindSubExpRes s =
  traverse $ \(SubExpRes cs se) -> do
    bn <- newVName s
    certifying cs $ letBindNames [bn] $ BasicOp $ SubExp se
    pure bn

-- | Wrap @lam@ in one @map@ per dimension of the given shape.  At each level
-- the new lambda takes one array parameter per element type in @pt@.  Each
-- parameter's shape is the dimensions from that level inward, and the level
-- maps over the outermost of them.  With an empty shape, @lam@ is returned
-- unchanged.
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [] _ lam = pure lam
nestedmap s@(h : r) pt lam = do
  -- Parameters are created before recursing so fresh names are drawn in the
  -- same order as before.
  params <- traverse (\tp -> newParam "x" $ Array tp (Shape s) NoUniqueness) pt
  body <- nestedmap r pt lam
  mkLambda params $
    fmap varsRes . letTupExp "res" . Op $
      Screma h (map paramName params) (mapSOAC body)

-- | Build a lambda that maps a renamed copy of @lam@ over two groups of
-- arrays, one parameter of each type in @tps@ per group, i.e.
--
-- > \ds hs -> map2 lam ds hs
--
-- @n@ is the width of the map.  Returns the names of the @ds@ parameters, the
-- names of the @hs@ parameters, and the lambda itself.
mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' lam tps n = do
  lam' <- renameLambda lam

  ds_params <- traverse (newParam "ds_param") tps
  hs_params <- traverse (newParam "hs_param") tps
  let ds_pars = fmap paramName ds_params
      hs_pars = fmap paramName hs_params
  lam_map <-
    mkLambda (ds_params <> hs_params) $
      fmap varsRes . letTupExp "map_f'" . Op $
        Screma n (ds_pars <> hs_pars) (mapSOAC lam')

  pure (ds_pars, hs_pars, lam_map)

-- | Build a lambda that applies @lam@ twice, left-associatively, over three
-- groups of arrays (one parameter of each type in @tps@ per group), i.e.
--
-- > \ls as rs -> map3 (\li ai ri -> li `lam` ai `lam` ri) ls as rs
--
-- The first @q@ parameters of @lam@ are its left operand, where @q@ is the
-- number of results of @lam@.  @n@ is the width of the map.  Returns the
-- names of the middle (@as@) parameters and the lambda itself.
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF lam tps n = do
  lam_l <- renameLambda lam
  lam_r <- renameLambda lam
  let q = length $ lambdaReturnType lam
      (lps, aps) = splitAt q $ lambdaParams lam_l
      (ips, rps) = splitAt q $ lambdaParams lam_r
  -- lam' l a r = lam_r (lam_l l a) r.  The results of lam_l are bound
  -- directly to the left-operand parameters (ips) of lam_r.
  lam' <- mkLambda (lps <> aps <> rps) $ do
    lam_l_res <- bodyBind $ lambdaBody lam_l
    forM_ (zip ips lam_l_res) $ \(ip, SubExpRes cs se) ->
      certifying cs $ letBindNames [paramName ip] $ BasicOp $ SubExp se
    bodyBind $ lambdaBody lam_r

  ls_params <- traverse (newParam "ls_param") tps
  as_params <- traverse (newParam "as_param") tps
  rs_params <- traverse (newParam "rs_param") tps
  let map_params = ls_params <> as_params <> rs_params
  lam_map <-
    mkLambda map_params $
      fmap varsRes . letTupExp "map_f" $
        Op $
          Screma n (map paramName map_params) $
            mapSOAC lam'

  pure (map paramName as_params, lam_map)

-- | Map over @n@ elements of the index array @is@.  Each index @i@ is kept
-- when @0 <= i < w@ (see 'withinBounds') and replaced by @w@ otherwise.
-- Returns the name of the resulting array.
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout is n w = do
  par_is <- newParam "is" $ Prim int64
  is'_lam <-
    mkLambda [par_is] $
      fmap varsRes . letTupExp "is'"
        =<< eIf
          (toExp $ withinBounds $ pure (w, paramName par_is))
          (eBody $ pure $ eParam par_is)
          (eBody $ pure $ eSubExp w)

  letExp "is'" $ Op $ Screma n (pure is) $ mapSOAC is'_lam

-- | Scatter each array in @vs@ into the corresponding array of @dst@, all at
-- the indices in @is@, using a single 'Scatter' of width @n@.  The scatter
-- lambda returns the index once per value array, followed by the values.
-- Returns the names of the scattered results.
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter n dst is vs = do
  tps <- traverse lookupType vs
  par_i <- newParam "i" $ Prim int64
  scatter_params <- traverse (newParam "scatter_param" . rowType) tps
  scatter_lam <-
    mkLambda (par_i : scatter_params) $
      fmap subExpsRes . mapM (letSubExp "scatter_map_res") =<< do
        p1 <- replicateM (length scatter_params) $ eParam par_i
        p2 <- traverse eParam scatter_params
        pure $ p1 <> p2

  -- NOTE(review): the destination shape is taken from the outer size of each
  -- *value* array (tps), not of the matching dst array.  This presumes every
  -- dst array has the same outer size as its vs array -- confirm at call sites.
  letTupExp "scatter_res" . Op $
    Scatter n (is : vs) scatter_lam $
      zipWith (\t -> (,,) (Shape $ pure $ arraySize 0 t) 1) tps dst

-- | Index every array in @vs@ with the dimension indices @s@, combined with
-- that array's type via 'fullSlice'.  Each result is bound under the base
-- name "sorted".  Returns the new names in input order.
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex vs s =
  forM vs $ \x -> do
    t <- lookupType x
    letExp "sorted" $ BasicOp $ Index x $ fullSlice t s

--
-- special case of histogram with min/max as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst minmax ne is vs
-- Forward sweep:
--     need to copy dst, because the reverse sweep might use it
--       (see e.g. reducebyindexminmax6.fut, where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let (x, x_inds) = zip vs (iota n)
--                       |> reduce_by_index (dst_cpy,-1s) argminmax (ne,-1) is
--
-- Reverse sweep:
--     dst_bar += map2 (\i b -> if i == -1
--                              then b
--                              else 0
--                     ) x_inds x_bar

--     vs_ctrbs = map2 (\i b -> if i == -1
--                              then 0
--                              else vs_bar[i] + b
--                     ) x_inds x_bar
--     vs_bar <- scatter vs_bar x_inds vs_ctrbs
diffMinMaxHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMinMaxHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
minmax SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
minmax
  Type
vs_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
  let vs_elm_type :: PrimType
vs_elm_type = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
  let vs_dims :: [SubExp]
vs_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
  let inner_dims :: [SubExp]
inner_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail [SubExp]
vs_dims
  let nr_dims :: Int
nr_dims = [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims
  Type
dst_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
  let dst_dims :: [SubExp]
dst_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type

  VName
dst_cpy <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_copy") (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (VName -> SubExp
Var VName
dst)

  Param Type
acc_v_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
acc_i_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
v_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
i_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
hist_lam_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
acc_v_p, Param Type
LParam (Rep ADM)
acc_i_p, Param Type
LParam (Rep ADM)
v_p, Param Type
LParam (Rep ADM)
i_p] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"idx_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p,
                BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
SMin IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p)
              ]
          )
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
              [ ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  ( CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp
                      (PrimType -> CmpOp
CmpEq PrimType
t)
                      (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p)
                      (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
minmax (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
                  )
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p, Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p])
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p, Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p])
              ]
          )
  Lambda SOACS
hist_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
int64, PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
hist_lam_inner

  VName
dst_minus_ones <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"minus_ones" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
  SubExp
ne_minus_ones <-
    String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"minus_ones" (Exp SOACS -> ADM SubExp)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM SubExp) -> BasicOp -> ADM SubExp
forall a b. (a -> b) -> a -> b
$
      Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
  VName
iota_n <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

  VName
inp_iota <- do
    if Int
nr_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1
      then VName -> ADM VName
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
iota_n
      else do
        Param Type
i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
        Lambda SOACS
lam <-
          [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
            ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
              Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i

        String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"res" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam

  let hist_op :: HistOp SOACS
hist_op = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
dst_cpy, VName
dst_minus_ones] [SubExp
ne, if Int
nr_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1) else SubExp
ne_minus_ones] Lambda SOACS
hist_lam
  Lambda SOACS
f' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
  VName
x_inds <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_inds")
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
    [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x, VName
x_inds] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$
      Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$
        SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
vs, VName
inp_iota] [HistOp SOACS
hist_op] Lambda SOACS
f'

  ADM ()
m

  VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x

  Param Type
x_ind_dst <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
x_bar_dst <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
dst_lam_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
x_ind_dst, Param Type
LParam (Rep ADM)
x_bar_dst] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"dst_bar"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
x_ind_dst) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
x_bar_dst)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
  Lambda SOACS
dst_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
vs_elm_type] Lambda SOACS
dst_lam_inner

  VName
dst_bar <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
x_inds, VName
x_bar] (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
dst_lam)

  VName -> VName -> ADM ()
updateAdj VName
dst VName
dst_bar

  VName
vs_bar <- VName -> ADM VName
lookupAdjVal VName
vs

  [VName]
inds' <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"inds" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM [VName]) -> ADM [VName] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
inner_dims []
  let inds :: [VName]
inds = VName
x_inds VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
inds'

  [Param Type]
par_x_ind_vs <- Int -> ADM (Param Type) -> ADM [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
nr_dims (ADM (Param Type) -> ADM [Param Type])
-> ADM (Param Type) -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
par_x_bar_vs <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
vs_lam_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
par_x_bar_vs Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
par_x_ind_vs) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName) -> Param Type -> VName
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head [Param Type]
par_x_ind_vs) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
              ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                SubExp
vs_bar_i <-
                  String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
vs_bar String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_el") (Exp SOACS -> ADM SubExp)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM SubExp) -> BasicOp -> ADM SubExp
forall a b. (a -> b) -> a -> b
$
                    VName -> Slice SubExp -> BasicOp
Index VName
vs_bar (Slice SubExp -> BasicOp)
-> ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp]
-> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> BasicOp) -> [DimIndex SubExp] -> BasicOp
forall a b. (a -> b) -> a -> b
$
                      (Param Type -> DimIndex SubExp)
-> [Param Type] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (Param Type -> SubExp) -> Param Type -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
par_x_ind_vs
                BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpPlus PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_x_bar_vs) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
vs_bar_i)
          )
  Lambda SOACS
vs_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims (PrimType
vs_elm_type PrimType -> [PrimType] -> [PrimType]
forall a. a -> [a] -> [a]
: Int -> PrimType -> [PrimType]
forall a. Int -> a -> [a]
replicate Int
nr_dims PrimType
int64) Lambda SOACS
vs_lam_inner

  VName
vs_bar_p <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_partial") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w (VName
x_bar VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
inds) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
vs_lam)

  SubExp
q <-
    String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"q"
      (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp -> SubExp -> [SubExp] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dst_dims

  [VName]
scatter_inps <- do
    -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p]
    -- ToDo: Cosmin asks: is the below the correct translation of the line above?
    (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flat" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
q])) ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
      [VName]
inds [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
vs_bar_p]

  Lambda SOACS
f'' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda ([Type] -> ADM (Lambda SOACS)) -> [Type] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate Int
nr_dims (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t]
  VName
vs_bar' <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
      SubExp
-> [VName] -> Lambda SOACS -> [(Shape, Int, VName)] -> SOAC SOACS
forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
q [VName]
scatter_inps Lambda SOACS
f'' [([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims, Int
1, VName
vs_bar)]
  VName -> VName -> ADM ()
insAdj VName
vs VName
vs_bar'
  where
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
    mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [] [SubExp]
_ = [VName] -> ADM [VName]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure []
    mk_indices [SubExp
d] [SubExp]
iotas = do
      [VName]
reps <- (SubExp -> ADM VName) -> [SubExp] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"rep" (Exp SOACS -> ADM VName)
-> (SubExp -> Exp SOACS) -> SubExp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS)
-> (SubExp -> BasicOp) -> SubExp -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
d])) [SubExp]
iotas
      VName
iota_d <-
        String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
      [VName] -> ADM [VName]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ [VName]
reps [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
iota_d]
    mk_indices (SubExp
d : [SubExp]
dims) [SubExp]
iotas = do
      VName
iota_d <-
        String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

      Param Type
i_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      Lambda SOACS
lam <-
        [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
          ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$
            [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
dims ([SubExp] -> ADM [VName]) -> [SubExp] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
              [SubExp]
iotas [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param]

      String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
d [VName
iota_d] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam

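--
-- Special case of a histogram with multiplication as the operator.
-- Original, assuming `is: [n]i64`, `vs: [n]btp` and `dst: [w]btp`:
--     let x = reduce_by_index dst (*) ne is vs
-- Forward sweep:
--     dst does not need to be copied, since it is not overwritten.
--     (ps replaces every zero factor by 1, and zs marks every zero factor
--      with 1, so that the zeros of each bucket can be counted separately.)
--     let (ps, zs) = map (\v -> if v == 0 then (1,1) else (v,0)) vs
--     let non_zero_prod = reduce_by_index (replicate w ne) (*) ne is ps
--     let zero_count = reduce_by_index (replicate w 0) (+) 0 is zs
--     let h_part = map2 (\p c -> if c == 0 then p else 0
--                       ) non_zero_prod zero_count
--     let x = map2 (*) dst h_part
--
-- Reverse sweep:
--     dst_bar += map2 (*) h_part x_bar
--
--     If a bucket received no zeros, the adjoint of each factor v is the
--     product of the other factors, i.e. nz_prd / v. If it received exactly
--     one zero, only that zero factor has a non-zero adjoint, namely the
--     product of the non-zero factors. With two or more zeros, every
--     adjoint is 0.
--     let part_bar = map2 (*) dst x_bar
--     vs_bar += map2 (\i v -> let zr_cts = zero_count[i]
--                             let pr_bar = part_bar[i]
--                             let nz_prd = non_zero_prod[i]
--                             in if zr_cts == 0
--                             then pr_bar * (nz_prd / v)
--                             else if zr_cts == 1 && v == 0
--                             then nz_prd * pr_bar
--                             else 0
--                    ) is vs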
diffMulHist ::
  VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMulHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMulHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
mul SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
  let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
mul
  Type
vs_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
  let vs_dims :: [SubExp]
vs_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
  let vs_elm_type :: PrimType
vs_elm_type = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
  Type
dst_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
  let dst_dims :: [SubExp]
dst_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type
  let inner_dims :: [SubExp]
inner_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail [SubExp]
vs_dims

  Param Type
v_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Lambda SOACS
lam_ps_zs_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
v_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
t, IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1])
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param, SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0])
  Lambda SOACS
lam_ps_zs <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
vs_dims [PrimType
vs_elm_type] Lambda SOACS
lam_ps_zs_inner
  [SubExpRes]
ps_zs_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_ps_zs [SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vs]
  [VName]
ps_zs <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"ps_zs" [SubExpRes]
ps_zs_res
  let [VName
ps, VName
zs] = [VName]
ps_zs

  Lambda SOACS
lam_mul_inner <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
  Lambda SOACS
lam_mul <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner
  VName
nz_prods0 <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nz_prd" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
ne
  let hist_nzp :: HistOp SOACS
hist_nzp = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
nz_prods0] [SubExp
ne] Lambda SOACS
lam_mul

  Lambda SOACS
lam_add_inner <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) PrimType
int64
  Lambda SOACS
lam_add <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
int64] Lambda SOACS
lam_add_inner
  VName
zr_counts0 <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"zr_cts" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
  SubExp
zrn_ne <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zr_ne" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
  let hist_zrn :: HistOp SOACS
hist_zrn = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
zr_counts0] [if [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0 else SubExp
zrn_ne] Lambda SOACS
lam_add

  Lambda SOACS
f' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
  VName
nz_prods <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"non_zero_prod"
  VName
zr_counts <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"zero_count"
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
    [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
nz_prods, VName
zr_counts] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$
      Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$
        SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
is, VName
ps, VName
zs] [HistOp SOACS
hist_nzp, HistOp SOACS
hist_zrn] Lambda SOACS
f'

  Param Type
p_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"prod" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
  Param Type
c_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"count" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
lam_h_part_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
p_param, Param Type
LParam (Rep ADM)
c_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"h_part"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
0 TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
c_param))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p_param)
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
  Lambda SOACS
lam_h_part <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
lam_h_part_inner
  [SubExpRes]
h_part_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_h_part ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
nz_prods, VName
zr_counts]
  [VName]
h_part' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"h_part" [SubExpRes]
h_part_res
  let [VName
h_part] = [VName]
h_part'

  Lambda SOACS
lam_mul_inner' <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
  Lambda SOACS
lam_mul' <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner'
  [SubExpRes]
x_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
h_part]
  [VName]
x' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"x" [SubExpRes]
x_res
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
x'

  ADM ()
m

  VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x

  Lambda SOACS
lam_mul'' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
  [SubExpRes]
dst_bar_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul'' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
h_part, VName
x_bar]
  [VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") [SubExpRes]
dst_bar_res
  VName -> VName -> ADM ()
updateAdj VName
dst (VName -> ADM ()) -> VName -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
dst_bar

  Lambda SOACS
lam_mul''' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
  [SubExpRes]
part_bar_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul''' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
x_bar]
  [VName]
part_bar' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"part_bar" [SubExpRes]
part_bar_res
  let [VName
part_bar] = [VName]
part_bar'

  [Param Type]
inner_params <- (String -> Type -> ADM (Param Type))
-> [String] -> [Type] -> ADM [Param Type]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam [String
"zr_cts", String
"pr_bar", String
"nz_prd", String
"a"] ([Type] -> ADM [Param Type]) -> [Type] -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ (PrimType -> Type) -> [PrimType] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim [PrimType
int64, PrimType
t, PrimType
t, PrimType
t]
  let [Param Type
zr_cts, Param Type
pr_bar, Param Type
nz_prd, Param Type
a_param] = [Param Type]
inner_params
  Lambda SOACS
lam_vsbar_inner <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
inner_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpDiv PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
              ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$
                ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                  ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
                      BinOp
LogAnd
                      (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
                      (CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
                  )
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar))
                  ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
          )

  Lambda SOACS
lam_vsbar_middle <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
t, PrimType
t, PrimType
t] Lambda SOACS
lam_vsbar_inner

  Param Type
i_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Param Type
a_param' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"a" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type
  Lambda SOACS
lam_vsbar <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i_param, Param Type
LParam (Rep ADM)
a_param'] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param))
          ( ADM [SubExpRes] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ (ADM [SubExpRes] -> ADM (Body (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
              let i :: Slice SubExp
i = Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
vs_type [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param]
              [VName]
names <- (String -> ADM VName) -> [String] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"zr_cts", String
"pr_bar", String
"nz_prd"]
              (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
name -> [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
name] (Exp SOACS -> ADM ()) -> (VName -> Exp SOACS) -> VName -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Slice SubExp -> BasicOp)
-> Slice SubExp -> VName -> BasicOp
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
names [VName
zr_counts, VName
part_bar, VName
nz_prods]
              Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_vsbar_middle ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
names [ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)]
forall a. Semigroup a => a -> a -> a
<> [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param']
          )
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep ADM) -> ADM (Exp (Rep ADM)))
-> Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Type -> Exp (Rep ADM)
forall rep. Type -> Exp rep
zeroExp (Type -> Exp (Rep ADM)) -> Type -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
dst_type)

  VName
vs_bar <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is, VName
vs] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_vsbar

  VName -> VName -> ADM ()
updateAdj VName
vs VName
vs_bar

--
-- special case of histogram with add as operator.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--     let x = reduce_by_index dst (+) ne is vs
-- Forward sweep:
--     need to copy dst: the reverse sweep might use it
--       (see ex. in reducebyindexminmax6.fut where the first map requires the original dst to be differentiated).
--     let dst_cpy = copy dst
--     let x = reduce_by_index dst_cpy (+) ne is vs
--
-- Reverse sweep:
--     dst_bar += x_bar
--
--     vs_bar += map (\i -> x_bar[i]) is
-- | Reverse-mode differentiation of a histogram whose operator is addition.
--
-- Emits the forward histogram (on a copy of the destination), runs the
-- continuation @m@, and then propagates the adjoint of the result:
-- @dst_bar += x_bar@ and @vs_bar += map (\\i -> x_bar[i]) is@.
diffAddHist ::
  -- | VJP operations (unused by this case).
  VjpOps ->
  -- | Name bound to the histogram result.
  VName ->
  -- | Auxiliary information attached to the emitted histogram statement.
  StmAux () ->
  -- | Number of input elements (length of @is@ and @vs@).
  SubExp ->
  -- | The addition operator lambda.
  Lambda SOACS ->
  -- | Neutral element passed to the histogram operator.
  SubExp ->
  -- | Array of bucket indices.
  VName ->
  -- | Array of values to accumulate.
  VName ->
  -- | Number of histogram bins (width of @dst@).
  SubExp ->
  -- | Passed through as the 'HistOp' race-factor argument.
  SubExp ->
  -- | Destination (initial histogram) array.
  VName ->
  -- | Continuation run between the forward statement and the reverse sweep.
  ADM () ->
  ADM ()
diffAddHist _ops x aux n add ne is vs w rf dst m = do
  -- Element type of the histogram, taken from the operator's first parameter.
  let t = case lambdaParams add of
        p : _ -> paramDec p
        [] -> error "In diffAddHist, Hist.hs: operator lambda has no parameters"

  -- Run the histogram on a copy of dst (a Replicate with an empty shape), so
  -- the original dst remains available to the reverse sweep.
  dst_cpy <-
    letExp (baseString dst <> "_copy") . BasicOp $
      Replicate mempty (Var dst)

  f <- mkIdentityLambda [Prim int64, t]
  auxing aux . letBindNames [x] . Op $
    Hist n [is, vs] [HistOp (Shape [w]) rf [dst_cpy] [ne] add] f

  -- Remainder of the computation; afterwards the adjoint of x can be read.
  m

  x_bar <- lookupAdjVal x

  -- dst_bar += x_bar: every bin of dst contributes additively to x.
  updateAdj dst x_bar

  -- vs_bar += map (\i -> x_bar[i]) is.  An index that fails the bounds
  -- check is ignored by the histogram, so its value has a zero adjoint.
  -- A literal zero is used rather than `ne`, which is only zero when the
  -- program supplies a genuine additive identity.
  x_type <- lookupType x
  i_param <- newParam (baseString vs <> "_i") $ Prim int64
  let i = paramName i_param
  lam_vsbar <-
    mkLambda [i_param] $
      fmap varsRes . letTupExp "vs_bar"
        =<< eIf
          (toExp $ withinBounds $ pure (w, i))
          (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i])
          (eBody $ pure $ pure $ zeroExp $ rowType x_type)

  vs_bar <- letExp (baseString vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar
  updateAdj vs vs_bar

--
-- A single step of the radix sort implementation.
-- It assumes that the key we are sorting by
-- is [n]i64 and that it is the first VName.
--
-- local def radix_sort_step [n] 't (xs: [n]t) (get_bit: i32 -> t -> i32)
--                                  (digit_n: i32): [n]t =
--   let num x = get_bit (digit_n+1) x * 2 + get_bit digit_n x
--   let pairwise op (a1,b1,c1,d1) (a2,b2,c2,d2) =
--     (a1 `op` a2, b1 `op` b2, c1 `op` c2, d1 `op` d2)
--   let bins = xs |> map num
--   let flags = bins |> map (\x -> if x == 0 then (1,0,0,0)
--                                  else if x == 1 then (0,1,0,0)
--                                  else if x == 2 then (0,0,1,0)
--                                  else (0,0,0,1))
--   let offsets = scan (pairwise (+)) (0,0,0,0) flags
--   let (na,nb,nc,_nd) = last offsets
--   let f bin (a,b,c,d) = match bin
--                         case 0 -> a-1
--                         case 1 -> na+b-1
--                         case 2 -> na+nb+c-1
--                         case _ -> na+nb+nc+d-1
--   let is = map2 f bins offsets
--   in scatter scratch is xs
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep [VName]
xs [Type]
tps SubExp
bit SubExp
n SubExp
w = do
  -- let is = head xs
  VName
is <- VName -> SubExp -> SubExp -> ADM VName
mapout ([VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
xs) SubExp
n SubExp
w

  Param Type
num_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"num" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
num_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
num_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"num_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
          (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)
          ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
              (IntType -> BinOp
And IntType
Int64)
              (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit))
              (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
          )
          ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
              (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef)
              (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
2)
              ( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
                  (IntType -> BinOp
And IntType
Int64)
                  (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)))
                  (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
              )
          )

  VName
bins <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"bins" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
num_lam
  Param Type
flag_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"flag" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
flag_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
flag_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf
          PrimType
int64
          ((Integer -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> [Integer] -> [(ADM (Exp SOACS), ADM (Exp SOACS))]
forall a b. (a -> b) -> [a] -> [b]
map ((,) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
flag_param) (ADM (Exp SOACS) -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> (Integer -> ADM (Exp SOACS))
-> Integer
-> (ADM (Exp SOACS), ADM (Exp SOACS))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
          ((Integer -> ADM (Body SOACS)) -> [Integer] -> [ADM (Body SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> (Integer -> [ADM (Exp SOACS)]) -> Integer -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Integer -> ADM (Exp SOACS)) -> [Integer] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst ([Integer] -> [ADM (Exp SOACS)])
-> (Integer -> [Integer]) -> Integer -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (\Integer
i -> (Integer -> Integer) -> [Integer] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map (\Integer
j -> if Integer
i Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== Integer
j then Integer
1 else Integer
0) [Integer
0 .. Integer
3])) ([Integer
0 .. Integer
3] :: [Integer]))

  [VName]
flags <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flags" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
bins] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam

  [Param Type]
scan_params <- (String -> ADM (Param Type)) -> [String] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((String -> Type -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (Type -> String -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"a1", String
"b1", String
"c1", String
"d1", String
"a2", String
"b2", String
"c2", String
"d2"]
  Lambda SOACS
scan_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
scan_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([SubExp] -> [SubExpRes]) -> ADM [SubExp] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (ADM [SubExp] -> ADM [SubExpRes])
-> ([Exp SOACS] -> ADM [SubExp]) -> [Exp SOACS] -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM SubExp) -> [Exp SOACS] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_res") ([Exp SOACS] -> ADM [SubExpRes])
-> ADM [Exp SOACS] -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        ([ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> ADM [Exp SOACS])
-> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ((ADM (Exp SOACS) -> ADM (Exp SOACS) -> ADM (Exp SOACS))
-> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> ADM [Exp SOACS]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (BinOp
 -> ADM (Exp (Rep ADM))
 -> ADM (Exp (Rep ADM))
 -> ADM (Exp (Rep ADM)))
-> BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)) (([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS])
-> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ Int -> [ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
4 ([ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]))
-> [ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)])
forall a b. (a -> b) -> a -> b
$ (Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
scan_params

  ScremaForm SOACS
scan <- [Scan SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Scan rep] -> m (ScremaForm rep)
scanSOAC ([Scan SOACS] -> ADM (ScremaForm SOACS))
-> [Scan SOACS] -> ADM (ScremaForm SOACS)
forall a b. (a -> b) -> a -> b
$ Scan SOACS -> [Scan SOACS]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Scan SOACS -> [Scan SOACS]) -> Scan SOACS -> [Scan SOACS]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
scan_lam ([SubExp] -> Scan SOACS) -> [SubExp] -> Scan SOACS
forall a b. (a -> b) -> a -> b
$ (Integer -> SubExp) -> [Integer] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (IntType -> Integer -> SubExp
intConst IntType
Int64) [Integer
0, Integer
0, Integer
0, Integer
0]
  [VName]
offsets <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"offsets" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName]
flags ScremaForm SOACS
scan

  SubExp
ind <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"ind_last" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
n) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
  let i :: Slice SubExp
i = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
ind]
  [VName]
nabcd <- (String -> ADM VName) -> [String] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"na", String
"nb", String
"nc", String
"nd"]
  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
abcd -> [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
abcd] (Exp SOACS -> ADM ()) -> (VName -> Exp SOACS) -> VName -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Slice SubExp -> BasicOp)
-> Slice SubExp -> VName -> BasicOp
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
nabcd [VName]
offsets

  let vars :: [SubExp]
vars = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
nabcd
  [Param Type]
map_params <- (String -> ADM (Param Type)) -> [String] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((String -> Type -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (Type -> String -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"bin", String
"a", String
"b", String
"c", String
"d"]
  Lambda SOACS
map_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
map_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf
          PrimType
int64
          ((Integer -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> [Integer] -> [(ADM (Exp SOACS), ADM (Exp SOACS))]
forall a b. (a -> b) -> [a] -> [b]
map ((,) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam (Param Type -> ADM (Exp (Rep ADM)))
-> Param Type -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head [Param Type]
map_params) (ADM (Exp SOACS) -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> (Integer -> ADM (Exp SOACS))
-> Integer
-> (ADM (Exp SOACS), ADM (Exp SOACS))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
          ( (Int -> Param Type -> ADM (Body SOACS))
-> [Int] -> [Param Type] -> [ADM (Body SOACS)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              ( \Int
j Param Type
p ->
                  [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
                    ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                      SubExp
t <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"t" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
                      BinOp -> SubExp -> [SubExp] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (SubExp
t SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
j [SubExp]
vars)
              )
              [Int
0 .. Int
3]
              ([Param Type] -> [Param Type]
forall a. HasCallStack => [a] -> [a]
tail [Param Type]
map_params)
          )

  VName
nis <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nis" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n (VName
bins VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
offsets) (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam

  [VName]
scatter_dst <- (Type -> ADM VName) -> [Type] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
tps
  SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
nis [VName]
xs
  where
    iConst :: Integer -> m (Exp (Rep m))
iConst Integer
c = SubExp -> m (Exp (Rep m))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> m (Exp (Rep m))) -> SubExp -> m (Exp (Rep m))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
c

--
-- The radix sort implementation.  Reference (Futhark source):
--
-- def radix_sort [n] 't (xs: [n]i64) =
--   let iters = if n == 0 then 0 else 32
--   in loop xs for i < iters do radix_sort_step xs i64.get_bit (i*2)
--
-- Rather than a fixed 32 passes, the number of two-bit passes is derived
-- from the bit length of @w + 1@ (presumably the keys are bin indices bounded
-- by @w@ -- NOTE(review): confirm against 'radixSortStep').  As in the
-- reference, no pass is run when @n == 0@: 'radixSortStep' indexes its
-- scanned offsets at @n - 1@, which would be out of bounds for an empty
-- input.
--
-- Returns the arrays @xs@ after all passes, bound as "sorted".
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort xs n w = do
  logw <- log2 =<< letSubExp "w1" =<< toExp (pe64 w + 1)
  -- Each pass consumes two bits: ceil (logw / 2) = (logw + 1) / 2.
  passes <-
    letSubExp "passes"
      =<< toExp (untyped (pe64 logw + 1) ~/~ untyped (pe64 (intConst Int64 2)))
  -- Skip every pass for empty input (mirrors the reference's n == 0 guard).
  iters <-
    letSubExp "iters"
      =<< eIf
        (toExp (pe64 n .==. 0))
        (eBody [eSubExp (intConst Int64 0)])
        (eBody [eSubExp passes])

  types <- traverse lookupType xs
  params <-
    zipWithM (\x t -> newParam (baseString x) (toDecl t Nonunique)) xs types
  i <- newVName "i"
  loopbody <- buildBody_ . localScope (scopeOfFParams params) $ do
    -- Pass number i sorts on bits [2*i, 2*i + 1].
    bit <- letSubExp "bit" =<< toExp (le64 i * 2)
    varsRes <$> radixSortStep (map paramName params) types bit n w

  letTupExp "sorted" $
    Loop (zip params $ map Var xs) (ForLoop i Int64 iters) loopbody
  where
    -- Emits a while-loop counting how many right shifts bring @m@ to zero,
    -- i.e. the bit length of @m@ (0 when m == 0).  Assumes m >= 0; here it
    -- is always called with w + 1.
    log2 :: SubExp -> ADM SubExp
    log2 m = do
      cond <- newParam "cond" $ Prim Bool
      r <- newParam "r" $ Prim int64
      cnt <- newParam "i" $ Prim int64
      let params = [cond, r, cnt]

      body <- buildBody_ . localScope (scopeOfFParams params) $ do
        r' <- letSubExp "r'" =<< toExp (le64 (paramName r) .>>. 1)
        cond' <- letSubExp "cond'" =<< toExp (bNot $ pe64 r' .==. 0)
        cnt' <- letSubExp "i'" =<< toExp (le64 (paramName cnt) + 1)
        pure $ subExpsRes [cond', r', cnt']

      -- Do not enter the loop at all when m is already zero.
      cond_init <- letSubExp "test" =<< toExp (bNot $ pe64 m .==. 0)

      results <-
        letTupExp' "log2res" $
          Loop
            (zip params [cond_init, m, Constant $ blankPrimValue int64])
            (WhileLoop $ paramName cond)
            body

      case results of
        [_, _, count] -> pure count
        _ -> error "radixSort.log2: expected exactly three loop results"

-- | Sort a key array and carry payload arrays along with it.
--
-- The first array of @xs@ holds the keys.  It is radix-sorted together with
-- @iota n@, so the sorted iota is the permutation applied by the sort.  The
-- remaining arrays of @xs@ are then gathered through that permutation with a
-- map over @n@ elements.
--
-- Returns @perm : sorted_keys : sorted_payloads@.  @xs@ must be non-empty.
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [] _ _ = error "radixSort': expected at least one (key) array"
radixSort' (keys : payload) n w = do
  iota_n <-
    letExp "red_iota" . BasicOp $
      Iota n (intConst Int64 0) (intConst Int64 1) Int64

  sorted_pair <- radixSort [keys, iota_n] n w
  (sorted_keys, perm) <- case sorted_pair of
    [k, p] -> pure (k, p)
    _ -> error "radixSort': radixSort returned an unexpected number of arrays"

  -- Gather each payload array at the permuted positions.
  i_param <- newParam "i" $ Prim int64
  gather_lam <-
    mkLambda [i_param] $
      varsRes <$> multiIndex payload [DimFix $ Var $ paramName i_param]

  gathered <- letTupExp "sorted" $ Op $ Screma n [perm] $ mapSOAC gather_lam
  pure $ perm : sorted_keys : gathered

--
-- generic case of histogram.
-- Original, assuming `is: [n]i64` and `dst: [w]btp`
--   let xs = reduce_by_index dst odot ne is as
-- Forward sweep:
-- let h_part = reduce_by_index (replicate w ne) odot ne is as
-- let xs = map2 odot dst h_part
-- Reverse sweep:
-- h_part_bar += f'' dst h_part
-- dst_bar += f' dst h_part

-- let flag = map (\i -> i == 0 || sis[i] != sis[i-1]) (iota n)
-- let flag_rev = map (\i -> i==0 || flag[n-i]) (iota n)
-- let ls = seg_scan_exc odot ne flag sas
-- let rs = reverse sas |>
--          seg_scan_exc odot ne flag_rev |> reverse
-- let f_bar = map (\i -> if i < w && -1 < i
--                        then h_part_bar[i]
--                        else 0s
--                 ) sis
-- let sas_bar = f f_dst ls sas rs
-- as_bar += scatter (Scratch alpha n) siota sas_bar
-- Where:
--  siota: 'iota n' permuted into the order that sorts 'is'
--  sis: 'is' in sorted order
--  sas: 'as' permuted into the order that sorts 'is'
--  f'' = vjpLambda xs_bar h_part (map2 odot)
--  f' = vjpLambda xs_bar dst (map2 odot)
--  f  = vjpLambda f_bar sas (map4 (\di li ai ri -> di odot li odot ai odot ri))
--  0s is the zero value of the element type alpha of 'as', i.e. an array
--  of zeros shaped like one element of 'as' (a scalar when alpha is 0-dim)
diffHist :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> [SubExp]
-> SubExp
-> [VName]
-> ADM ()
-> ADM ()
diffHist VjpOps
ops [VName]
xs StmAux ()
aux SubExp
n Lambda SOACS
lam0 [SubExp]
ne [VName]
as [SubExp]
w SubExp
rf [VName]
dst ADM ()
m = do
  [Type]
as_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType ([VName] -> ADM [Type]) -> [VName] -> ADM [Type]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as
  [Type]
dst_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
dst

  [VName]
nes <- (SubExp -> ADM VName) -> [SubExp] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"new_dst" (Exp SOACS -> ADM VName)
-> (SubExp -> Exp SOACS) -> SubExp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS)
-> (SubExp -> BasicOp) -> SubExp -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ SubExp -> [SubExp]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> [SubExp]) -> SubExp -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w)) [SubExp]
ne

  Lambda SOACS
h_map <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda ([Type] -> ADM (Lambda SOACS)) -> [Type] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType [Type]
as_type
  [VName]
h_part <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> ADM VName) -> (VName -> String) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (String -> String -> String) -> String -> String -> String
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> String -> String
forall a. Semigroup a => a -> a -> a
(<>) String
"_h_part" (String -> String) -> (VName -> String) -> VName -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
baseString) [VName]
xs
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ())
-> (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
h_part (Exp SOACS -> ADM ())
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
    SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName]
as [Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
w) SubExp
rf [VName]
nes [SubExp]
ne Lambda SOACS
lam0] Lambda SOACS
h_map

  Lambda SOACS
lam0' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
  StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ())
-> (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
xs (Exp SOACS -> ADM ())
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
    SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma ([SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w) ([VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam0')

  ADM ()
m

  [VName]
xs_bar <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM VName
lookupAdjVal [VName]
xs

  ([VName]
dst_params, [VName]
hp_params, Lambda SOACS
f') <- Lambda SOACS
-> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' Lambda SOACS
lam0 [Type]
dst_type (SubExp -> ADM ([VName], [VName], Lambda SOACS))
-> SubExp -> ADM ([VName], [VName], Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w
  Lambda SOACS
f'_adj_dst <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
dst_params Lambda SOACS
f'
  Lambda SOACS
f'_adj_hp <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
hp_params Lambda SOACS
f'

  [SubExpRes]
dst_bar' <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f'_adj_dst ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
  [VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"dst_bar" [SubExpRes]
dst_bar'
  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj [VName]
dst [VName]
dst_bar

  [SubExpRes]
h_part_bar' <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f'_adj_hp ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
  [VName]
h_part_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"h_part_bar" [SubExpRes]
h_part_bar'

  Lambda SOACS
lam <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
  Lambda SOACS
lam' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0

  -- is' <- mapout (head as) n (head w)
  -- sorted <- radixSort' (is' : tail as) n $ head w
  [VName]
sorted <- [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [VName]
as SubExp
n (SubExp -> ADM [VName]) -> SubExp -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w
  let siota :: VName
siota = [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
sorted
  let sis :: VName
sis = [VName] -> VName
forall a. HasCallStack => [a] -> a
head ([VName] -> VName) -> [VName] -> VName
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
sorted
  let sas :: [VName]
sas = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop Int
2 [VName]
sorted

  VName
iota_n <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64

  Param Type
par_i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
flag_lam <- LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam Param Type
LParam SOACS
par_i VName
sis
  VName
flag <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flag" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam

  -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n)
  Param Type
par_i' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i' :: VName
i' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'
  Lambda SOACS
g_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i'] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([SubExp] -> [SubExpRes]) -> ADM [SubExp] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (ADM [SubExp] -> ADM [SubExpRes])
-> ([Exp SOACS] -> ADM [SubExp]) -> [Exp SOACS] -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM SubExp) -> [Exp SOACS] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_inps") ([Exp SOACS] -> ADM [SubExpRes])
-> ADM [Exp SOACS] -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
        SubExp
im1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i_1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
        SubExp
nmi <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i')
        let s1 :: [DimIndex SubExp]
s1 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
im1]
        let s2 :: [DimIndex SubExp]
s2 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmi]

        -- flag array for left scan
        SubExp
f1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"f1" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i']

        -- array for left scan
        [SubExp]
r1 <-
          String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r1"
            (Exp SOACS -> ADM [SubExp]) -> ADM (Exp SOACS) -> ADM [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
              (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
f1)
              ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [SubExp]
ne)
              ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s1)

        -- array for right scan inc flag
        [SubExp]
r2 <-
          String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r2"
            (Exp SOACS -> ADM [SubExp]) -> ADM (Exp SOACS) -> ADM [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
              (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i' TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
              ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp (Rep ADM)))
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp (Rep ADM))])
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne)
              ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
                  ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                    ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                      (Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep ADM) -> ADM (Exp (Rep ADM)))
-> Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
s2)
                      ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp (Rep ADM)))
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp (Rep ADM))])
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne)
                      ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp SOACS)])
-> ([VName] -> [SubExp]) -> [VName] -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimValue -> SubExp
Constant (PrimType -> PrimValue
blankPrimValue PrimType
Bool) :) ([SubExp] -> [SubExp])
-> ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var
                          ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s2
                      )
              )

        (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> ADM [Exp SOACS]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> ADM [Exp SOACS]) -> [SubExp] -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ SubExp
f1 SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
r1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp]
r2

  -- scan (\(f1,v1) (f2,v2) ->
  --   let f = f1 || f2
  --   let v = if f2 then v2 else g v1 v2
  --   in (f,v) ) (false,ne) (zip flags vals)
  [Lambda SOACS]
scan_lams <-
    (Lambda SOACS -> ADM (Lambda SOACS))
-> [Lambda SOACS] -> ADM [Lambda SOACS]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse
      ( \Lambda SOACS
l -> do
          Param Type
f1 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f1" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
          Param Type
f2 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f2" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
          [Param Type]
ps <- Lambda SOACS -> [Param Type]
Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda SOACS -> [Param Type])
-> ADM (Lambda SOACS) -> ADM [Param Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
          let ([Param Type]
p1, [Param Type]
p2) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne) [Param Type]
ps

          [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
f1 Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
p1 [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ Param Type
f2 Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
p2) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
            ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
              let f :: ADM (Exp (Rep ADM))
f = BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
LogOr (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
              ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
                (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
                ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM))
ADM (Exp SOACS)
f ADM (Exp SOACS) -> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)]
forall a. a -> [a] -> [a]
: (Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
p2)
                ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ADM (Exp (Rep ADM))
ADM (Exp SOACS)
f :) ([ADM (Exp SOACS)] -> [ADM (Exp SOACS)])
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var)
                    ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"gres"
                    ([SubExpRes] -> ADM [VName]) -> ADM [SubExpRes] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
l ((Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
ps)
                )
      )
      [Lambda SOACS
lam, Lambda SOACS
lam']

  let ne' :: [SubExp]
ne' = PrimValue -> SubExp
Constant (Bool -> PrimValue
BoolValue Bool
False) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne

  [VName]
scansres <-
    String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"adj_ctrb_scan" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
      SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] ([Scan SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC ((Lambda SOACS -> Scan SOACS) -> [Lambda SOACS] -> [Scan SOACS]
forall a b. (a -> b) -> [a] -> [b]
map (Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
`Scan` [SubExp]
ne') [Lambda SOACS]
scan_lams) Lambda SOACS
g_lam)

  let (VName
_ : [VName]
ls_arr, VName
_ : [VName]
rs_arr_rev) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [VName]
scansres

  -- map (\i -> if i < w && -1 < w then (xs_bar[i], dst[i]) else (0,ne)) sis
  Param Type
par_i'' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i'' :: VName
i'' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i''
  Lambda SOACS
map_lam <-
    [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i''] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res"
        (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w, VName
i''))
          ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
h_part_bar [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i''])
          ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
              (Type -> ADM (Exp SOACS)) -> [Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (\Type
t -> Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t) (PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue (PrimType -> PrimValue) -> PrimType -> PrimValue
forall a b. (a -> b) -> a -> b
$ Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)) [Type]
as_type
          )

  [VName]
f_bar <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"f_bar" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
sis] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam

  ([VName]
as_params, Lambda SOACS
f) <- Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF Lambda SOACS
lam0 [Type]
as_type SubExp
n
  Lambda SOACS
f_adj <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
f_bar) [VName]
as_params Lambda SOACS
f

  -- map (\i -> rs_arr_rev[n-i-1]) (iota n)
  Param Type
par_i''' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i''' :: VName
i''' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'''
  Lambda SOACS
rev_lam <- [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i'''] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
    SubExp
nmim1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i_1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i''' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
    [VName] -> [SubExpRes]
varsRes ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
rs_arr_rev [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmim1]

  [VName]
rs_arr <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"rs_arr" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
rev_lam

  [VName]
sas_bar <-
    String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"sas_bar"
      ([SubExpRes] -> ADM [VName]) -> ADM [SubExpRes] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f_adj ((VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
ls_arr [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
sas [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
rs_arr)

  [VName]
scatter_dst <- (Type -> ADM VName) -> [Type] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
as_type
  [VName]
as_bar <- SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
siota [VName]
sas_bar

  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as) [VName]
as_bar
  where
    -- map (\i -> if i == 0 then true else is[i] != is[i-1]) (iota n)
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
    mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam LParam SOACS
par_i VName
sis =
      [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [LParam (Rep ADM)
LParam SOACS
par_i] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
        ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
          let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
par_i
          ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
            (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0))
            ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
Bool)
            ( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
                ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
                  VName
i_p <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"i_p" (Exp SOACS -> ADM VName) -> ADM (Exp SOACS) -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
                  [VName]
vs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"vs" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
sis (Slice SubExp -> BasicOp)
-> (VName -> Slice SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> (VName -> [DimIndex SubExp]) -> VName -> Slice SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimIndex SubExp -> [DimIndex SubExp]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (DimIndex SubExp -> [DimIndex SubExp])
-> (VName -> DimIndex SubExp) -> VName -> [DimIndex SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (VName -> SubExp) -> VName -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
i, VName
i_p]
                  let [VName
vs_i, VName
vs_p] = [VName]
vs
                  TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool VName -> TPrimExp Bool VName)
-> TPrimExp Bool VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
vs_i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
vs_p
            )