{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Hist
( diffMinMaxHist,
diffMulHist,
diffAddHist,
diffHist,
)
where
import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus :: PrimType -> BinOp
getBinOpPlus (IntType IntType
x) = IntType -> Overflow -> BinOp
Add IntType
x Overflow
OverflowUndef
getBinOpPlus (FloatType FloatType
f) = FloatType -> BinOp
FAdd FloatType
f
getBinOpPlus PrimType
_ = String -> BinOp
forall a. HasCallStack => String -> a
error String
"In getBinOpMul, Hist.hs: input not supported"
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv :: PrimType -> BinOp
getBinOpDiv (IntType IntType
t) = IntType -> Safety -> BinOp
SDiv IntType
t Safety
Unsafe
getBinOpDiv (FloatType FloatType
t) = FloatType -> BinOp
FDiv FloatType
t
getBinOpDiv PrimType
_ = String -> BinOp
forall a. HasCallStack => String -> a
error String
"In getBinOpDiv, Hist.hs: input not supported"
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [] = PrimExp VName -> TPrimExp Bool VName
forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TPrimExp Bool VName)
-> PrimExp VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimExp VName
forall v. PrimValue -> PrimExp v
ValueExp (Bool -> PrimValue
BoolValue Bool
True)
withinBounds [(SubExp
q, VName
i)] = (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. SubExp -> TPrimExp Int64 VName
pe64 SubExp
q) TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (SubExp -> TPrimExp Int64 VName
pe64 (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1)) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
withinBounds ((SubExp, VName)
qi : [(SubExp, VName)]
qis) = [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [(SubExp, VName)
qi] TPrimExp Bool VName -> TPrimExp Bool VName -> TPrimExp Bool VName
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds [(SubExp, VName)]
qis
elseIf :: PrimType -> [(ADM (Exp SOACS), ADM (Exp SOACS))] -> [ADM (Body SOACS)] -> ADM (Exp SOACS)
elseIf :: PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf PrimType
t [(ADM (Exp SOACS)
c1, ADM (Exp SOACS)
c2)] [ADM (Body SOACS)
bt, ADM (Body SOACS)
bf] =
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) ADM (Exp (Rep ADM))
ADM (Exp SOACS)
c1 ADM (Exp (Rep ADM))
ADM (Exp SOACS)
c2)
ADM (Body (Rep ADM))
ADM (Body SOACS)
bt
ADM (Body (Rep ADM))
ADM (Body SOACS)
bf
elseIf PrimType
t ((ADM (Exp SOACS)
c1, ADM (Exp SOACS)
c2) : [(ADM (Exp SOACS), ADM (Exp SOACS))]
cs) (ADM (Body SOACS)
bt : [ADM (Body SOACS)]
bs) =
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) ADM (Exp (Rep ADM))
ADM (Exp SOACS)
c1 ADM (Exp (Rep ADM))
ADM (Exp SOACS)
c2)
ADM (Body (Rep ADM))
ADM (Body SOACS)
bt
(ADM (Body (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Body (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp SOACS) -> [ADM (Exp SOACS)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure
(ADM (Exp SOACS) -> [ADM (Exp SOACS)])
-> ADM (Exp SOACS) -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> a -> b
$ PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf PrimType
t [(ADM (Exp SOACS), ADM (Exp SOACS))]
cs [ADM (Body SOACS)]
bs
elseIf PrimType
_ [(ADM (Exp SOACS), ADM (Exp SOACS))]
_ [ADM (Body SOACS)]
_ = String -> ADM (Exp SOACS)
forall a. HasCallStack => String -> a
error String
"In elseIf, Hist.hs: input not supported"
bindSubExpRes :: String -> [SubExpRes] -> ADM [VName]
bindSubExpRes :: String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
s =
(SubExpRes -> ADM VName) -> [SubExpRes] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse
( \(SubExpRes Certs
cs SubExp
se) -> do
VName
bn <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
s
Certs -> ADM () -> ADM ()
forall a. Certs -> ADM a -> ADM a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
bn] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
VName -> ADM VName
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
bn
)
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap :: [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [] [PrimType]
_ Lambda SOACS
lam = Lambda SOACS -> ADM (Lambda SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Lambda SOACS
lam
nestedmap s :: [SubExp]
s@(SubExp
h : [SubExp]
r) [PrimType]
pt Lambda SOACS
lam = do
[Param Type]
params <- (PrimType -> ADM (Param Type)) -> [PrimType] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (\PrimType
tp -> String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"x" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
tp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
s) NoUniqueness
NoUniqueness) [PrimType]
pt
Lambda SOACS
body <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
r [PrimType]
pt Lambda SOACS
lam
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [SubExpRes]) -> SOAC SOACS -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
h ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
params) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
body)
mkF' :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' :: Lambda SOACS
-> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' Lambda SOACS
lam [Type]
tps SubExp
n = do
Lambda SOACS
lam' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam
[Param Type]
ds_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"ds_param") [Type]
tps
[Param Type]
hs_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"hs_param") [Type]
tps
let ds_pars :: [VName]
ds_pars = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
ds_params
let hs_pars :: [VName]
hs_pars = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
hs_params
Lambda SOACS
lam_map <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda ([Param Type]
ds_params [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
hs_params) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_f'" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [SubExpRes]) -> SOAC SOACS -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n ([VName]
ds_pars [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
hs_pars) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam')
([VName], [VName], Lambda SOACS)
-> ADM ([VName], [VName], Lambda SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName]
ds_pars, [VName]
hs_pars, Lambda SOACS
lam_map)
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF :: Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF Lambda SOACS
lam [Type]
tps SubExp
n = do
Lambda SOACS
lam_l <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam
Lambda SOACS
lam_r <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam
let q :: Int
q = [Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int) -> [Type] -> Int
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam
([Param Type]
lps, [Param Type]
aps) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
q ([Param Type] -> ([Param Type], [Param Type]))
-> [Param Type] -> ([Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_l
([Param Type]
ips, [Param Type]
rps) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
q ([Param Type] -> ([Param Type], [Param Type]))
-> [Param Type] -> ([Param Type], [Param Type])
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam_r
Lambda SOACS
lam' <- [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda ([Param Type]
lps [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
aps [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
rps) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
[SubExpRes]
lam_l_res <- Body (Rep ADM) -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Body (Rep m) -> m [SubExpRes]
bodyBind (Body (Rep ADM) -> ADM [SubExpRes])
-> Body (Rep ADM) -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_l
[(Param Type, SubExpRes)]
-> ((Param Type, SubExpRes) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [SubExpRes] -> [(Param Type, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
ips [SubExpRes]
lam_l_res) (((Param Type, SubExpRes) -> ADM ()) -> ADM ())
-> ((Param Type, SubExpRes) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
ip, SubExpRes Certs
cs SubExp
se) ->
Certs -> ADM () -> ADM ()
forall a. Certs -> ADM a -> ADM a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
ip] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
Body (Rep ADM) -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Body (Rep m) -> m [SubExpRes]
bodyBind (Body (Rep ADM) -> ADM [SubExpRes])
-> Body (Rep ADM) -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam_r
[Param Type]
ls_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"ls_param") [Type]
tps
[Param Type]
as_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"as_param") [Type]
tps
[Param Type]
rs_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"rs_param") [Type]
tps
let map_params :: [Param Type]
map_params = [Param Type]
ls_params [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
as_params [Param Type] -> [Param Type] -> [Param Type]
forall a. Semigroup a => a -> a -> a
<> [Param Type]
rs_params
Lambda SOACS
lam_map <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
map_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_f" (Exp SOACS -> ADM [SubExpRes]) -> Exp SOACS -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$
Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
map_params) (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$
Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam'
([VName], Lambda SOACS) -> ADM ([VName], Lambda SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
as_params, Lambda SOACS
lam_map)
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout :: VName -> SubExp -> SubExp -> ADM VName
mapout VName
is SubExp
n SubExp
w = do
Param Type
par_is <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"is" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
is'_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_is] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"is'"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_is))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_is)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
w)
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"is'" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n (VName -> [VName]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
is) (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
is'_lam
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter :: SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
dst VName
is [VName]
vs = do
[Type]
tps <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
vs
Param Type
par_i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
[Param Type]
scatter_params <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"scatter_param" (Type -> ADM (Param Type))
-> (Type -> Type) -> Type -> ADM (Param Type)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType) [Type]
tps
Lambda SOACS
scatter_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
par_i Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
scatter_params) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([SubExp] -> [SubExpRes]) -> ADM [SubExp] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (ADM [SubExp] -> ADM [SubExpRes])
-> ([Exp SOACS] -> ADM [SubExp]) -> [Exp SOACS] -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM SubExp) -> [Exp SOACS] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scatter_map_res") ([Exp SOACS] -> ADM [SubExpRes])
-> ADM [Exp SOACS] -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
[Exp SOACS]
p1 <- Int -> ADM (Exp SOACS) -> ADM [Exp SOACS]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([Param Type] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Param Type]
scatter_params) (ADM (Exp SOACS) -> ADM [Exp SOACS])
-> ADM (Exp SOACS) -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_i
[Exp SOACS]
p2 <- (Param Type -> ADM (Exp SOACS)) -> [Param Type] -> ADM [Exp SOACS]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
scatter_params
[Exp SOACS] -> ADM [Exp SOACS]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([Exp SOACS] -> ADM [Exp SOACS]) -> [Exp SOACS] -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ [Exp SOACS]
p1 [Exp SOACS] -> [Exp SOACS] -> [Exp SOACS]
forall a. Semigroup a => a -> a -> a
<> [Exp SOACS]
p2
String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scatter_res" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
SubExp
-> [VName] -> Lambda SOACS -> [(Shape, Int, VName)] -> SOAC SOACS
forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
n (VName
is VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
vs) Lambda SOACS
scatter_lam ([(Shape, Int, VName)] -> SOAC SOACS)
-> [(Shape, Int, VName)] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$
(Type -> VName -> (Shape, Int, VName))
-> [Type] -> [VName] -> [(Shape, Int, VName)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\Type
t -> (,,) ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ SubExp -> [SubExp]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> [SubExp]) -> SubExp -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
t) Int
1) [Type]
tps [VName]
dst
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex :: [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
vs [DimIndex SubExp]
s = do
(VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse
( \VName
x -> do
Type
t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
x
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"sorted" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
t [DimIndex SubExp]
s)
)
[VName]
vs
diffMinMaxHist ::
VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMinMaxHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMinMaxHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
minmax SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
minmax
Type
vs_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
let vs_elm_type :: PrimType
vs_elm_type = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
let vs_dims :: [SubExp]
vs_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
let inner_dims :: [SubExp]
inner_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail [SubExp]
vs_dims
let nr_dims :: Int
nr_dims = [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims
Type
dst_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
let dst_dims :: [SubExp]
dst_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type
VName
dst_cpy <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_copy") (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (VName -> SubExp
Var VName
dst)
Param Type
acc_v_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Param Type
acc_i_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"acc_i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Param Type
v_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Param Type
i_p <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
hist_lam_inner <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
acc_v_p, Param Type
LParam (Rep ADM)
acc_i_p, Param Type
LParam (Rep ADM)
v_p, Param Type
LParam (Rep ADM)
i_p] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"idx_res"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
[ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p,
BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
SMin IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p)
]
)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody
[ ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp
(PrimType -> CmpOp
CmpEq PrimType
t)
(Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p)
(BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
minmax (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p))
)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_v_p, Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
acc_i_p])
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_p, Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
i_p])
]
)
Lambda SOACS
hist_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
int64, PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
hist_lam_inner
VName
dst_minus_ones <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"minus_ones" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
SubExp
ne_minus_ones <-
String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"minus_ones" (Exp SOACS -> ADM SubExp)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM SubExp) -> BasicOp -> ADM SubExp
forall a b. (a -> b) -> a -> b
$
Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1))
VName
iota_n <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
VName
inp_iota <- do
if Int
nr_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1
then VName -> ADM VName
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure VName
iota_n
else do
Param Type
i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"res" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam
let hist_op :: HistOp SOACS
hist_op = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
dst_cpy, VName
dst_minus_ones] [SubExp
ne, if Int
nr_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 (-Integer
1) else SubExp
ne_minus_ones] Lambda SOACS
hist_lam
Lambda SOACS
f' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
VName
x_inds <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_inds")
StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
[VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x, VName
x_inds] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$
Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
vs, VName
inp_iota] [HistOp SOACS
hist_op] Lambda SOACS
f'
ADM ()
m
VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x
Param Type
x_ind_dst <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Param Type
x_bar_dst <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Lambda SOACS
dst_lam_inner <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
x_ind_dst, Param Type
LParam (Rep ADM)
x_bar_dst] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"dst_bar"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
x_ind_dst) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
x_bar_dst)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
Lambda SOACS
dst_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
vs_elm_type] Lambda SOACS
dst_lam_inner
VName
dst_bar <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
x_inds, VName
x_bar] (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
dst_lam)
VName -> VName -> ADM ()
updateAdj VName
dst VName
dst_bar
VName
vs_bar <- VName -> ADM VName
lookupAdjVal VName
vs
[VName]
inds' <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"inds" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM [VName]) -> ADM [VName] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
inner_dims []
let inds :: [VName]
inds = VName
x_inds VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
inds'
[Param Type]
par_x_ind_vs <- Int -> ADM (Param Type) -> ADM [Param Type]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM Int
nr_dims (ADM (Param Type) -> ADM [Param Type])
-> ADM (Param Type) -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_ind_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Param Type
par_x_bar_vs <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar_param") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Lambda SOACS
vs_lam_inner <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
par_x_bar_vs Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
par_x_ind_vs) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName) -> Param Type -> VName
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head [Param Type]
par_x_ind_vs) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. -TPrimExp Int64 VName
1)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
SubExp
vs_bar_i <-
String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
vs_bar String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_el") (Exp SOACS -> ADM SubExp)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM SubExp) -> BasicOp -> ADM SubExp
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index VName
vs_bar (Slice SubExp -> BasicOp)
-> ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp]
-> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> BasicOp) -> [DimIndex SubExp] -> BasicOp
forall a b. (a -> b) -> a -> b
$
(Param Type -> DimIndex SubExp)
-> [Param Type] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (Param Type -> SubExp) -> Param Type -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
par_x_ind_vs
BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpPlus PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
par_x_bar_vs) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
vs_bar_i)
)
Lambda SOACS
vs_lam <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims (PrimType
vs_elm_type PrimType -> [PrimType] -> [PrimType]
forall a. a -> [a] -> [a]
: Int -> PrimType -> [PrimType]
forall a. Int -> a -> [a]
replicate Int
nr_dims PrimType
int64) Lambda SOACS
vs_lam_inner
VName
vs_bar_p <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_partial") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w (VName
x_bar VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
inds) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
vs_lam)
SubExp
q <-
String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"q"
(Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp -> SubExp -> [SubExp] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dst_dims
[VName]
scatter_inps <- do
(VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flat" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
q])) ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
[VName]
inds [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
vs_bar_p]
Lambda SOACS
f'' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda ([Type] -> ADM (Lambda SOACS)) -> [Type] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate Int
nr_dims (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t]
VName
vs_bar' <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp SOACS -> ADM VName)
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM VName) -> SOAC SOACS -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp
-> [VName] -> Lambda SOACS -> [(Shape, Int, VName)] -> SOAC SOACS
forall rep.
SubExp
-> [VName] -> Lambda rep -> [(Shape, Int, VName)] -> SOAC rep
Scatter SubExp
q [VName]
scatter_inps Lambda SOACS
f'' [([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims, Int
1, VName
vs_bar)]
VName -> VName -> ADM ()
insAdj VName
vs VName
vs_bar'
where
mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
mk_indices :: [SubExp] -> [SubExp] -> ADM [VName]
mk_indices [] [SubExp]
_ = [VName] -> ADM [VName]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure []
mk_indices [SubExp
d] [SubExp]
iotas = do
[VName]
reps <- (SubExp -> ADM VName) -> [SubExp] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"rep" (Exp SOACS -> ADM VName)
-> (SubExp -> Exp SOACS) -> SubExp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS)
-> (SubExp -> BasicOp) -> SubExp -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
d])) [SubExp]
iotas
VName
iota_d <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
[VName] -> ADM [VName]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ [VName]
reps [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName
iota_d]
mk_indices (SubExp
d : [SubExp]
dims) [SubExp]
iotas = do
VName
iota_d <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
d (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
Param Type
i_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$
[SubExp] -> [SubExp] -> ADM [VName]
mk_indices [SubExp]
dims ([SubExp] -> ADM [VName]) -> [SubExp] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
[SubExp]
iotas [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param]
String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
d [VName
iota_d] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam
diffMulHist ::
VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffMulHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> BinOp
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffMulHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n BinOp
mul SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
let t :: PrimType
t = BinOp -> PrimType
binOpType BinOp
mul
Type
vs_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
let vs_dims :: [SubExp]
vs_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
vs_type
let vs_elm_type :: PrimType
vs_elm_type = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
vs_type
Type
dst_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
dst
let dst_dims :: [SubExp]
dst_dims = Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dst_type
let inner_dims :: [SubExp]
inner_dims = [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail [SubExp]
vs_dims
Param Type
v_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"v" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Lambda SOACS
lam_ps_zs_inner <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
v_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
t, IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1])
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
v_param, SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0])
Lambda SOACS
lam_ps_zs <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
vs_dims [PrimType
vs_elm_type] Lambda SOACS
lam_ps_zs_inner
[SubExpRes]
ps_zs_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_ps_zs [SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vs]
[VName]
ps_zs <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"ps_zs" [SubExpRes]
ps_zs_res
let [VName
ps, VName
zs] = [VName]
ps_zs
Lambda SOACS
lam_mul_inner <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
Lambda SOACS
lam_mul <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner
VName
nz_prods0 <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nz_prd" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
ne
let hist_nzp :: HistOp SOACS
hist_nzp = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
nz_prods0] [SubExp
ne] Lambda SOACS
lam_mul
Lambda SOACS
lam_add_inner <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) PrimType
int64
Lambda SOACS
lam_add <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
int64] Lambda SOACS
lam_add_inner
VName
zr_counts0 <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"zr_cts" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dst_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
SubExp
zrn_ne <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zr_ne" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
inner_dims) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0)
let hist_zrn :: HistOp SOACS
hist_zrn = Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
zr_counts0] [if [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs_dims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 then IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0 else SubExp
zrn_ne] Lambda SOACS
lam_add
Lambda SOACS
f' <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type, Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType (Type -> Type) -> Type -> Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
int64 ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
vs_dims) NoUniqueness
NoUniqueness]
VName
nz_prods <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"non_zero_prod"
VName
zr_counts <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"zero_count"
StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
[VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
nz_prods, VName
zr_counts] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$
Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
is, VName
ps, VName
zs] [HistOp SOACS
hist_nzp, HistOp SOACS
hist_zrn] Lambda SOACS
f'
Param Type
p_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"prod" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
t
Param Type
c_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"count" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
lam_h_part_inner <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
p_param, Param Type
LParam (Rep ADM)
c_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"h_part"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
0 TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
c_param))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p_param)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
Lambda SOACS
lam_h_part <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
int64] Lambda SOACS
lam_h_part_inner
[SubExpRes]
h_part_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_h_part ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
nz_prods, VName
zr_counts]
[VName]
h_part' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"h_part" [SubExpRes]
h_part_res
let [VName
h_part] = [VName]
h_part'
Lambda SOACS
lam_mul_inner' <- BinOp -> PrimType -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
BinOp -> PrimType -> m (Lambda (Rep m))
binOpLambda BinOp
mul PrimType
t
Lambda SOACS
lam_mul' <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
dst_dims [PrimType
vs_elm_type, PrimType
vs_elm_type] Lambda SOACS
lam_mul_inner'
[SubExpRes]
x_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
h_part]
[VName]
x' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"x" [SubExpRes]
x_res
StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
x'
ADM ()
m
VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x
Lambda SOACS
lam_mul'' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
[SubExpRes]
dst_bar_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul'' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
h_part, VName
x_bar]
[VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") [SubExpRes]
dst_bar_res
VName -> VName -> ADM ()
updateAdj VName
dst (VName -> ADM ()) -> VName -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
dst_bar
Lambda SOACS
lam_mul''' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam_mul'
[SubExpRes]
part_bar_res <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_mul''' ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
dst, VName
x_bar]
[VName]
part_bar' <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"part_bar" [SubExpRes]
part_bar_res
let [VName
part_bar] = [VName]
part_bar'
[Param Type]
inner_params <- (String -> Type -> ADM (Param Type))
-> [String] -> [Type] -> ADM [Param Type]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam [String
"zr_cts", String
"pr_bar", String
"nz_prd", String
"a"] ([Type] -> ADM [Param Type]) -> [Type] -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ (PrimType -> Type) -> [PrimType] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim [PrimType
int64, PrimType
t, PrimType
t, PrimType
t]
let [Param Type
zr_cts, Param Type
pr_bar, Param Type
nz_prd, Param Type
a_param] = [Param Type]
inner_params
Lambda SOACS
lam_vsbar_inner <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
inner_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (PrimType -> BinOp
getBinOpDiv PrimType
t) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
BinOp
LogAnd
(CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
int64) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
zr_cts))
(CmpOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
CmpOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eCmpOp (PrimType -> CmpOp
CmpEq PrimType
t) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t) (ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM)))
-> ADM (Exp (Rep ADM)) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param)
)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
mul (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
nz_prd) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
pr_bar))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
t)
)
Lambda SOACS
lam_vsbar_middle <- [SubExp] -> [PrimType] -> Lambda SOACS -> ADM (Lambda SOACS)
nestedmap [SubExp]
inner_dims [PrimType
int64, PrimType
t, PrimType
t, PrimType
t] Lambda SOACS
lam_vsbar_inner
Param Type
i_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Param Type
a_param' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"a" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
vs_type
Lambda SOACS
lam_vsbar <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i_param, Param Type
LParam (Rep ADM)
a_param'] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param))
( ADM [SubExpRes] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ (ADM [SubExpRes] -> ADM (Body (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
let i :: Slice SubExp
i = Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
vs_type [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param]
[VName]
names <- (String -> ADM VName) -> [String] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"zr_cts", String
"pr_bar", String
"nz_prd"]
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
name -> [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
name] (Exp SOACS -> ADM ()) -> (VName -> Exp SOACS) -> VName -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Slice SubExp -> BasicOp)
-> Slice SubExp -> VName -> BasicOp
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
names [VName
zr_counts, VName
part_bar, VName
nz_prods]
Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
lam_vsbar_middle ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
names [ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)]
forall a. Semigroup a => a -> a -> a
<> [Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
a_param']
)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep ADM) -> ADM (Exp (Rep ADM)))
-> Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Type -> Exp (Rep ADM)
forall rep. Type -> Exp rep
zeroExp (Type -> Exp (Rep ADM)) -> Type -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
dst_type)
VName
vs_bar <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is, VName
vs] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_vsbar
VName -> VName -> ADM ()
updateAdj VName
vs VName
vs_bar
diffAddHist ::
VjpOps -> VName -> StmAux () -> SubExp -> Lambda SOACS -> SubExp -> VName -> VName -> SubExp -> SubExp -> VName -> ADM () -> ADM ()
diffAddHist :: VjpOps
-> VName
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> SubExp
-> VName
-> VName
-> SubExp
-> SubExp
-> VName
-> ADM ()
-> ADM ()
diffAddHist VjpOps
_ops VName
x StmAux ()
aux SubExp
n Lambda SOACS
add SubExp
ne VName
is VName
vs SubExp
w SubExp
rf VName
dst ADM ()
m = do
let t :: Type
t = Param Type -> Type
forall dec. Param dec -> dec
paramDec (Param Type -> Type) -> Param Type -> Type
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head ([Param Type] -> Param Type) -> [Param Type] -> Param Type
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
add
VName
dst_cpy <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
dst String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_copy") (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
Shape -> SubExp -> BasicOp
Replicate Shape
forall a. Monoid a => a
mempty (VName -> SubExp
Var VName
dst)
Lambda SOACS
f <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64, Type
t]
StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ())
-> (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
x] (Exp SOACS -> ADM ())
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName
is, VName
vs] [Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp
w]) SubExp
rf [VName
dst_cpy] [SubExp
ne] Lambda SOACS
add] Lambda SOACS
f
ADM ()
m
VName
x_bar <- VName -> ADM VName
lookupAdjVal VName
x
VName -> VName -> ADM ()
updateAdj VName
dst VName
x_bar
Type
x_type <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
x
Param Type
i_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_i") (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param
Lambda SOACS
lam_vsbar <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"vs_bar"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
w, VName
i))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep ADM) -> ADM (Exp (Rep ADM)))
-> Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x_bar (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
x_type [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i])
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
ne)
VName
vs_bar <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
vs String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_bar") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam_vsbar
VName -> VName -> ADM ()
updateAdj VName
vs VName
vs_bar
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep :: [VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep [VName]
xs [Type]
tps SubExp
bit SubExp
n SubExp
w = do
VName
is <- VName -> SubExp -> SubExp -> ADM VName
mapout ([VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
xs) SubExp
n SubExp
w
Param Type
num_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"num" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
num_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
num_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"num_res"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
(IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)
( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
(IntType -> BinOp
And IntType
Int64)
(BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit))
(Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
)
( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
(IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowUndef)
(Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
2)
( BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp
(IntType -> BinOp
And IntType
Int64)
(BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> BinOp
AShr IntType
Int64) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
num_param) (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
bit) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)))
(Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
)
)
VName
bins <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"bins" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
is] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
num_lam
Param Type
flag_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"flag" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
flag_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
flag_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag_res"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf
PrimType
int64
((Integer -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> [Integer] -> [(ADM (Exp SOACS), ADM (Exp SOACS))]
forall a b. (a -> b) -> [a] -> [b]
map ((,) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
flag_param) (ADM (Exp SOACS) -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> (Integer -> ADM (Exp SOACS))
-> Integer
-> (ADM (Exp SOACS), ADM (Exp SOACS))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
((Integer -> ADM (Body SOACS)) -> [Integer] -> [ADM (Body SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> (Integer -> [ADM (Exp SOACS)]) -> Integer -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Integer -> ADM (Exp SOACS)) -> [Integer] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst ([Integer] -> [ADM (Exp SOACS)])
-> (Integer -> [Integer]) -> Integer -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (\Integer
i -> (Integer -> Integer) -> [Integer] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map (\Integer
j -> if Integer
i Integer -> Integer -> Bool
forall a. Eq a => a -> a -> Bool
== Integer
j then Integer
1 else Integer
0) [Integer
0 .. Integer
3])) ([Integer
0 .. Integer
3] :: [Integer]))
[VName]
flags <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flags" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
bins] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam
[Param Type]
scan_params <- (String -> ADM (Param Type)) -> [String] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((String -> Type -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (Type -> String -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"a1", String
"b1", String
"c1", String
"d1", String
"a2", String
"b2", String
"c2", String
"d2"]
Lambda SOACS
scan_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
scan_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([SubExp] -> [SubExpRes]) -> ADM [SubExp] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (ADM [SubExp] -> ADM [SubExpRes])
-> ([Exp SOACS] -> ADM [SubExp]) -> [Exp SOACS] -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM SubExp) -> [Exp SOACS] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_res") ([Exp SOACS] -> ADM [SubExpRes])
-> ADM [Exp SOACS] -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
([ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> ADM [Exp SOACS])
-> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS]
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ((ADM (Exp SOACS) -> ADM (Exp SOACS) -> ADM (Exp SOACS))
-> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)] -> ADM [Exp SOACS]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM)))
-> BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef)) (([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS])
-> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]) -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ Int -> [ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
4 ([ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)]))
-> [ADM (Exp SOACS)] -> ([ADM (Exp SOACS)], [ADM (Exp SOACS)])
forall a b. (a -> b) -> a -> b
$ (Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
scan_params
ScremaForm SOACS
scan <- [Scan SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Scan rep] -> m (ScremaForm rep)
scanSOAC ([Scan SOACS] -> ADM (ScremaForm SOACS))
-> [Scan SOACS] -> ADM (ScremaForm SOACS)
forall a b. (a -> b) -> a -> b
$ Scan SOACS -> [Scan SOACS]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Scan SOACS -> [Scan SOACS]) -> Scan SOACS -> [Scan SOACS]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
scan_lam ([SubExp] -> Scan SOACS) -> [SubExp] -> Scan SOACS
forall a b. (a -> b) -> a -> b
$ (Integer -> SubExp) -> [Integer] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (IntType -> Integer -> SubExp
intConst IntType
Int64) [Integer
0, Integer
0, Integer
0, Integer
0]
[VName]
offsets <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"offsets" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName]
flags ScremaForm SOACS
scan
SubExp
ind <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"ind_last" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
n) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
let i :: Slice SubExp
i = [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
ind]
[VName]
nabcd <- (String -> ADM VName) -> [String] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"na", String
"nb", String
"nc", String
"nd"]
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (\VName
abcd -> [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
abcd] (Exp SOACS -> ADM ()) -> (VName -> Exp SOACS) -> VName -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Slice SubExp -> BasicOp)
-> Slice SubExp -> VName -> BasicOp
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Slice SubExp -> BasicOp
Index Slice SubExp
i) [VName]
nabcd [VName]
offsets
let vars :: [SubExp]
vars = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
nabcd
[Param Type]
map_params <- (String -> ADM (Param Type)) -> [String] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((String -> Type -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (Type -> String -> ADM (Param Type))
-> Type -> String -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [String
"bin", String
"a", String
"b", String
"c", String
"d"]
Lambda SOACS
map_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type]
[LParam (Rep ADM)]
map_params (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"map_res"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> [(ADM (Exp SOACS), ADM (Exp SOACS))]
-> [ADM (Body SOACS)]
-> ADM (Exp SOACS)
elseIf
PrimType
int64
((Integer -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> [Integer] -> [(ADM (Exp SOACS), ADM (Exp SOACS))]
forall a b. (a -> b) -> [a] -> [b]
map ((,) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam (Param Type -> ADM (Exp (Rep ADM)))
-> Param Type -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [Param Type] -> Param Type
forall a. HasCallStack => [a] -> a
head [Param Type]
map_params) (ADM (Exp SOACS) -> (ADM (Exp SOACS), ADM (Exp SOACS)))
-> (Integer -> ADM (Exp SOACS))
-> Integer
-> (ADM (Exp SOACS), ADM (Exp SOACS))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Integer -> ADM (Exp (Rep ADM))
Integer -> ADM (Exp SOACS)
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst) [Integer
0 .. Integer
2])
( (Int -> Param Type -> ADM (Body SOACS))
-> [Int] -> [Param Type] -> [ADM (Body SOACS)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
( \Int
j Param Type
p ->
[ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
SubExp
t <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"t" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
p) (Integer -> ADM (Exp (Rep ADM))
forall {m :: * -> *}. MonadBuilder m => Integer -> m (Exp (Rep m))
iConst Integer
1)
BinOp -> SubExp -> [SubExp] -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> SubExp -> [SubExp] -> m (Exp (Rep m))
foldBinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (SubExp
t SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
j [SubExp]
vars)
)
[Int
0 .. Int
3]
([Param Type] -> [Param Type]
forall a. HasCallStack => [a] -> [a]
tail [Param Type]
map_params)
)
VName
nis <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"nis" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n (VName
bins VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
offsets) (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam
[VName]
scatter_dst <- (Type -> ADM VName) -> [Type] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
tps
SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
nis [VName]
xs
where
iConst :: Integer -> m (Exp (Rep m))
iConst Integer
c = SubExp -> m (Exp (Rep m))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> m (Exp (Rep m))) -> SubExp -> m (Exp (Rep m))
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
c
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort [VName]
xs SubExp
n SubExp
w = do
SubExp
logw <- SubExp -> ADM SubExp
log2 (SubExp -> ADM SubExp) -> ADM SubExp -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"w1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
SubExp
iters <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"iters" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimExp VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
PrimExp VName -> m (Exp (Rep m))
toExp (TPrimExp Int64 VName -> PrimExp VName
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (SubExp -> TPrimExp Int64 VName
pe64 SubExp
logw TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1) PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~/~ TPrimExp Int64 VName -> PrimExp VName
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (SubExp -> TPrimExp Int64 VName
pe64 (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
2)))
[Type]
types <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
xs
[Param (TypeBase Shape Uniqueness)]
params <- (VName -> Type -> ADM (Param (TypeBase Shape Uniqueness)))
-> [VName] -> [Type] -> ADM [Param (TypeBase Shape Uniqueness)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (\VName
x -> String
-> TypeBase Shape Uniqueness
-> ADM (Param (TypeBase Shape Uniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x) (TypeBase Shape Uniqueness
-> ADM (Param (TypeBase Shape Uniqueness)))
-> (Type -> TypeBase Shape Uniqueness)
-> Type
-> ADM (Param (TypeBase Shape Uniqueness))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Type -> Uniqueness -> TypeBase Shape Uniqueness)
-> Uniqueness -> Type -> TypeBase Shape Uniqueness
forall a b c. (a -> b -> c) -> b -> a -> c
flip Type -> Uniqueness -> TypeBase Shape Uniqueness
forall shape.
TypeBase shape NoUniqueness
-> Uniqueness -> TypeBase shape Uniqueness
toDecl Uniqueness
Nonunique) [VName]
xs [Type]
types
VName
i <- String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"i"
Body SOACS
loopbody <- ADM [SubExpRes] -> ADM (Body (Rep ADM))
ADM [SubExpRes] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ (ADM [SubExpRes] -> ADM (Body SOACS))
-> (ADM [SubExpRes] -> ADM [SubExpRes])
-> ADM [SubExpRes]
-> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope SOACS -> ADM [SubExpRes] -> ADM [SubExpRes]
forall a. Scope SOACS -> ADM a -> ADM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope SOACS
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
params) (ADM [SubExpRes] -> ADM (Body SOACS))
-> ADM [SubExpRes] -> ADM (Body SOACS)
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ do
SubExp
bit <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"bit" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
2)
[VName] -> [Type] -> SubExp -> SubExp -> SubExp -> ADM [VName]
radixSortStep ((Param (TypeBase Shape Uniqueness) -> VName)
-> [Param (TypeBase Shape Uniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase Shape Uniqueness)]
params) [Type]
types SubExp
bit SubExp
n SubExp
w
String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"sorted" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
[(FParam SOACS, SubExp)] -> LoopForm -> Body SOACS -> Exp SOACS
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop
([FParam SOACS] -> [SubExp] -> [(FParam SOACS, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
[FParam SOACS]
params ([SubExp] -> [(FParam SOACS, SubExp)])
-> [SubExp] -> [(FParam SOACS, SubExp)]
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
xs)
(VName -> IntType -> SubExp -> LoopForm
ForLoop VName
i IntType
Int64 SubExp
iters)
Body SOACS
loopbody
where
log2 :: SubExp -> ADM SubExp
log2 :: SubExp -> ADM SubExp
log2 SubExp
m = do
[Param (TypeBase Shape Uniqueness)]
params <- (String
-> TypeBase Shape Uniqueness
-> ADM (Param (TypeBase Shape Uniqueness)))
-> [String]
-> [TypeBase Shape Uniqueness]
-> ADM [Param (TypeBase Shape Uniqueness)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM String
-> TypeBase Shape Uniqueness
-> ADM (Param (TypeBase Shape Uniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam [String
"cond", String
"r", String
"i"] ([TypeBase Shape Uniqueness]
-> ADM [Param (TypeBase Shape Uniqueness)])
-> [TypeBase Shape Uniqueness]
-> ADM [Param (TypeBase Shape Uniqueness)]
forall a b. (a -> b) -> a -> b
$ (PrimType -> TypeBase Shape Uniqueness)
-> [PrimType] -> [TypeBase Shape Uniqueness]
forall a b. (a -> b) -> [a] -> [b]
map PrimType -> TypeBase Shape Uniqueness
forall shape u. PrimType -> TypeBase shape u
Prim [PrimType
Bool, PrimType
int64, PrimType
int64]
let [Param (TypeBase Shape Uniqueness)
cond, Param (TypeBase Shape Uniqueness)
r, Param (TypeBase Shape Uniqueness)
i] = [Param (TypeBase Shape Uniqueness)]
params
Body SOACS
body <- ADM [SubExpRes] -> ADM (Body (Rep ADM))
ADM [SubExpRes] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
m [SubExpRes] -> m (Body (Rep m))
buildBody_ (ADM [SubExpRes] -> ADM (Body SOACS))
-> (ADM [SubExpRes] -> ADM [SubExpRes])
-> ADM [SubExpRes]
-> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Scope SOACS -> ADM [SubExpRes] -> ADM [SubExpRes]
forall a. Scope SOACS -> ADM a -> ADM a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param (TypeBase Shape Uniqueness)] -> Scope SOACS
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams [Param (TypeBase Shape Uniqueness)]
params) (ADM [SubExpRes] -> ADM (Body SOACS))
-> ADM [SubExpRes] -> ADM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ do
SubExp
r' <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"r'" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
r) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.>>. TPrimExp Int64 VName
1)
SubExp
cond' <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"cond'" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool VName -> TPrimExp Bool VName)
-> TPrimExp Bool VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
r' TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
SubExp
i' <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i'" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 (Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
i) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1)
[SubExpRes] -> ADM [SubExpRes]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExpRes] -> ADM [SubExpRes]) -> [SubExpRes] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExpRes]
subExpsRes [SubExp
cond', SubExp
r', SubExp
i']
SubExp
cond_init <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"test" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool VName -> TPrimExp Bool VName)
-> TPrimExp Bool VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ SubExp -> TPrimExp Int64 VName
pe64 SubExp
m TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
[SubExp]
l <-
String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"log2res" (Exp (Rep ADM) -> ADM [SubExp]) -> Exp (Rep ADM) -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$
[(FParam SOACS, SubExp)] -> LoopForm -> Body SOACS -> Exp SOACS
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop
([Param (TypeBase Shape Uniqueness)]
-> [SubExp] -> [(Param (TypeBase Shape Uniqueness), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape Uniqueness)]
params [SubExp
cond_init, SubExp
m, PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue PrimType
int64])
(VName -> LoopForm
WhileLoop (VName -> LoopForm) -> VName -> LoopForm
forall a b. (a -> b) -> a -> b
$ Param (TypeBase Shape Uniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape Uniqueness)
cond)
Body SOACS
body
let [SubExp
_, SubExp
_, SubExp
res] = [SubExp]
l
SubExp -> ADM SubExp
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
res
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' :: [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [VName]
xs SubExp
n SubExp
w = do
VName
iota_n <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"red_iota" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
[VName]
radres <- [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort [[VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
xs, VName
iota_n] SubExp
n SubExp
w
let [VName
is', VName
iota'] = [VName]
radres
Param Type
i_param <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let slice :: [DimIndex SubExp]
slice = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
i_param]
Lambda SOACS
map_lam <- [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
i_param] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExpRes]
varsRes ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
xs) [DimIndex SubExp]
slice
[VName]
sorted <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"sorted" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota'] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam
[VName] -> ADM [VName]
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ VName
iota' VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: VName
is' VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
sorted
diffHist :: VjpOps -> [VName] -> StmAux () -> SubExp -> Lambda SOACS -> [SubExp] -> [VName] -> [SubExp] -> SubExp -> [VName] -> ADM () -> ADM ()
diffHist :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> [SubExp]
-> SubExp
-> [VName]
-> ADM ()
-> ADM ()
diffHist VjpOps
ops [VName]
xs StmAux ()
aux SubExp
n Lambda SOACS
lam0 [SubExp]
ne [VName]
as [SubExp]
w SubExp
rf [VName]
dst ADM ()
m = do
[Type]
as_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType ([VName] -> ADM [Type]) -> [VName] -> ADM [Type]
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as
[Type]
dst_type <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
dst
[VName]
nes <- (SubExp -> ADM VName) -> [SubExp] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"new_dst" (Exp SOACS -> ADM VName)
-> (SubExp -> Exp SOACS) -> SubExp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS)
-> (SubExp -> BasicOp) -> SubExp -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ SubExp -> [SubExp]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp -> [SubExp]) -> SubExp -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w)) [SubExp]
ne
Lambda SOACS
h_map <- [Type] -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Type] -> m (Lambda rep)
mkIdentityLambda ([Type] -> ADM (Lambda SOACS)) -> [Type] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> [Type] -> [Type]
forall a. a -> [a] -> [a]
: (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType [Type]
as_type
[VName]
h_part <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> ADM VName) -> (VName -> String) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (String -> String -> String) -> String -> String -> String
forall a b c. (a -> b -> c) -> b -> a -> c
flip String -> String -> String
forall a. Semigroup a => a -> a -> a
(<>) String
"_h_part" (String -> String) -> (VName -> String) -> VName -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> String
baseString) [VName]
xs
StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ())
-> (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
h_part (Exp SOACS -> ADM ())
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> [HistOp SOACS] -> Lambda SOACS -> SOAC SOACS
forall rep.
SubExp -> [VName] -> [HistOp rep] -> Lambda rep -> SOAC rep
Hist SubExp
n [VName]
as [Shape
-> SubExp -> [VName] -> [SubExp] -> Lambda SOACS -> HistOp SOACS
forall rep.
Shape -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
w) SubExp
rf [VName]
nes [SubExp]
ne Lambda SOACS
lam0] Lambda SOACS
h_map
Lambda SOACS
lam0' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ())
-> (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName]
xs (Exp SOACS -> ADM ())
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM ()) -> SOAC SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma ([SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w) ([VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
lam0')
ADM ()
m
[VName]
xs_bar <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM VName
lookupAdjVal [VName]
xs
([VName]
dst_params, [VName]
hp_params, Lambda SOACS
f') <- Lambda SOACS
-> [Type] -> SubExp -> ADM ([VName], [VName], Lambda SOACS)
mkF' Lambda SOACS
lam0 [Type]
dst_type (SubExp -> ADM ([VName], [VName], Lambda SOACS))
-> SubExp -> ADM ([VName], [VName], Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w
Lambda SOACS
f'_adj_dst <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
dst_params Lambda SOACS
f'
Lambda SOACS
f'_adj_hp <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
xs_bar) [VName]
hp_params Lambda SOACS
f'
[SubExpRes]
dst_bar' <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f'_adj_dst ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
[VName]
dst_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"dst_bar" [SubExpRes]
dst_bar'
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj [VName]
dst [VName]
dst_bar
[SubExpRes]
h_part_bar' <- Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f'_adj_hp ([ADM (Exp (Rep ADM))] -> ADM [SubExpRes])
-> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall a b. (a -> b) -> a -> b
$ (VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
dst [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
h_part
[VName]
h_part_bar <- String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"h_part_bar" [SubExpRes]
h_part_bar'
Lambda SOACS
lam <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
Lambda SOACS
lam' <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
[VName]
sorted <- [VName] -> SubExp -> SubExp -> ADM [VName]
radixSort' [VName]
as SubExp
n (SubExp -> ADM [VName]) -> SubExp -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w
let siota :: VName
siota = [VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
sorted
let sis :: VName
sis = [VName] -> VName
forall a. HasCallStack => [a] -> a
head ([VName] -> VName) -> [VName] -> VName
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
sorted
let sas :: [VName]
sas = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop Int
2 [VName]
sorted
VName
iota_n <-
String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
n (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
Param Type
par_i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
Lambda SOACS
flag_lam <- LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam Param Type
LParam SOACS
par_i VName
sis
VName
flag <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"flag" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
flag_lam
Param Type
par_i' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let i' :: VName
i' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'
Lambda SOACS
g_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i'] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([SubExp] -> [SubExpRes]) -> ADM [SubExp] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [SubExp] -> [SubExpRes]
subExpsRes (ADM [SubExp] -> ADM [SubExpRes])
-> ([Exp SOACS] -> ADM [SubExp]) -> [Exp SOACS] -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM SubExp) -> [Exp SOACS] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"scan_inps") ([Exp SOACS] -> ADM [SubExpRes])
-> ADM [Exp SOACS] -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
SubExp
im1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"i_1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
SubExp
nmi <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i')
let s1 :: [DimIndex SubExp]
s1 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
im1]
let s2 :: [DimIndex SubExp]
s2 = [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmi]
SubExp
f1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"f1" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i']
[SubExp]
r1 <-
String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r1"
(Exp SOACS -> ADM [SubExp]) -> ADM (Exp SOACS) -> ADM [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp SubExp
f1)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp [SubExp]
ne)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s1)
[SubExp]
r2 <-
String -> Exp (Rep ADM) -> ADM [SubExp]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [SubExp]
letTupExp' String
"r2"
(Exp SOACS -> ADM [SubExp]) -> ADM (Exp SOACS) -> ADM [SubExp]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i' TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp (Rep ADM)))
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp (Rep ADM))])
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp (Rep ADM) -> ADM (Exp (Rep ADM)))
-> Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
flag (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
s2)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (SubExp -> ADM (Exp (Rep ADM)))
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp (Rep ADM))])
-> [SubExp] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimType -> PrimValue
onePrimValue PrimType
Bool) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp SOACS)])
-> ([VName] -> [SubExp]) -> [VName] -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PrimValue -> SubExp
Constant (PrimType -> PrimValue
blankPrimValue PrimType
Bool) :) ([SubExp] -> [SubExp])
-> ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> SubExp
Var
([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
sas [DimIndex SubExp]
s2
)
)
(SubExp -> ADM (Exp SOACS)) -> [SubExp] -> ADM [Exp SOACS]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> ADM [Exp SOACS]) -> [SubExp] -> ADM [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ SubExp
f1 SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
r1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [SubExp]
r2
[Lambda SOACS]
scan_lams <-
(Lambda SOACS -> ADM (Lambda SOACS))
-> [Lambda SOACS] -> ADM [Lambda SOACS]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse
( \Lambda SOACS
l -> do
Param Type
f1 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f1" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
Param Type
f2 <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"f2" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Bool
[Param Type]
ps <- Lambda SOACS -> [Param Type]
Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams (Lambda SOACS -> [Param Type])
-> ADM (Lambda SOACS) -> ADM [Param Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
let ([Param Type]
p1, [Param Type]
p2) = Int -> [Param Type] -> ([Param Type], [Param Type])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne) [Param Type]
ps
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda (Param Type
f1 Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
p1 [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ Param Type
f2 Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
p2) (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
let f :: ADM (Exp (Rep ADM))
f = BinOp
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
BinOp -> m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eBinOp BinOp
LogOr (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f1) (Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(Param Type -> ADM (Exp (Rep ADM))
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam Param Type
f2)
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM))
ADM (Exp SOACS)
f ADM (Exp SOACS) -> [ADM (Exp SOACS)] -> [ADM (Exp SOACS)]
forall a. a -> [a] -> [a]
: (Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
p2)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ADM (Exp (Rep ADM))
ADM (Exp SOACS)
f :) ([ADM (Exp SOACS)] -> [ADM (Exp SOACS)])
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> [ADM (Exp SOACS)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var)
([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"gres"
([SubExpRes] -> ADM [VName]) -> ADM [SubExpRes] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
l ((Param Type -> ADM (Exp SOACS))
-> [Param Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Param Type -> ADM (Exp (Rep ADM))
Param Type -> ADM (Exp SOACS)
forall (m :: * -> *) t.
MonadBuilder m =>
Param t -> m (Exp (Rep m))
eParam [Param Type]
ps)
)
)
[Lambda SOACS
lam, Lambda SOACS
lam']
let ne' :: [SubExp]
ne' = PrimValue -> SubExp
Constant (Bool -> PrimValue
BoolValue Bool
False) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
ne
[VName]
scansres <-
String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"adj_ctrb_scan" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] ([Scan SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC ((Lambda SOACS -> Scan SOACS) -> [Lambda SOACS] -> [Scan SOACS]
forall a b. (a -> b) -> [a] -> [b]
map (Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
`Scan` [SubExp]
ne') [Lambda SOACS]
scan_lams) Lambda SOACS
g_lam)
let (VName
_ : [VName]
ls_arr, VName
_ : [VName]
rs_arr_rev) = Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ne Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) [VName]
scansres
Param Type
par_i'' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let i'' :: VName
i'' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i''
Lambda SOACS
map_lam <-
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i''] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_res"
(Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds ([(SubExp, VName)] -> TPrimExp Bool VName)
-> [(SubExp, VName)] -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ (SubExp, VName) -> [(SubExp, VName)]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([SubExp] -> SubExp
forall a. HasCallStack => [a] -> a
head [SubExp]
w, VName
i''))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
[ADM (Exp SOACS)] -> ADM (Body SOACS)
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp SOACS)] -> ADM (Body SOACS))
-> ([VName] -> [ADM (Exp SOACS)]) -> [VName] -> ADM (Body SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ADM (Exp SOACS)) -> [VName] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> ADM (Body SOACS)) -> ADM [VName] -> ADM (Body SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
h_part_bar [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i''])
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
(Type -> ADM (Exp SOACS)) -> [Type] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map (\Type
t -> Exp SOACS -> ADM (Exp SOACS)
forall a. a -> ADM a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> ADM (Exp SOACS)) -> Exp SOACS -> ADM (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. HasCallStack => [a] -> [a]
tail ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t) (PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
blankPrimValue (PrimType -> PrimValue) -> PrimType -> PrimValue
forall a b. (a -> b) -> a -> b
$ Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t)) [Type]
as_type
)
[VName]
f_bar <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"f_bar" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
sis] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam
([VName]
as_params, Lambda SOACS
f) <- Lambda SOACS -> [Type] -> SubExp -> ADM ([VName], Lambda SOACS)
mkF Lambda SOACS
lam0 [Type]
as_type SubExp
n
Lambda SOACS
f_adj <- VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops ((VName -> Adj) -> [VName] -> [Adj]
forall a b. (a -> b) -> [a] -> [b]
map VName -> Adj
adjFromVar [VName]
f_bar) [VName]
as_params Lambda SOACS
f
Param Type
par_i''' <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let i''' :: VName
i''' = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i'''
Lambda SOACS
rev_lam <- [LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i'''] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
SubExp
nmim1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"n_i_1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
n TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i''' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
[VName] -> [SubExpRes]
varsRes ([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> [DimIndex SubExp] -> ADM [VName]
multiIndex [VName]
rs_arr_rev [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nmim1]
[VName]
rs_arr <- String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"rs_arr" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op (Rep ADM) -> Exp (Rep ADM)
forall rep. Op rep -> Exp rep
Op (Op (Rep ADM) -> Exp (Rep ADM)) -> Op (Rep ADM) -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n [VName
iota_n] (ScremaForm SOACS -> SOAC SOACS) -> ScremaForm SOACS -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
rev_lam
[VName]
sas_bar <-
String -> [SubExpRes] -> ADM [VName]
bindSubExpRes String
"sas_bar"
([SubExpRes] -> ADM [VName]) -> ADM [SubExpRes] -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM [SubExpRes]
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m [SubExpRes]
eLambda Lambda (Rep ADM)
Lambda SOACS
f_adj ((VName -> ADM (Exp (Rep ADM))) -> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> ADM (Exp (Rep ADM))
SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp SOACS))
-> (VName -> SubExp) -> VName -> ADM (Exp SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [ADM (Exp (Rep ADM))])
-> [VName] -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ [VName]
ls_arr [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
sas [VName] -> [VName] -> [VName]
forall a. Semigroup a => a -> a -> a
<> [VName]
rs_arr)
[VName]
scatter_dst <- (Type -> ADM VName) -> [Type] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (\Type
t -> String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"scatter_dst" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ PrimType -> [SubExp] -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t)) [Type]
as_type
[VName]
as_bar <- SubExp -> [VName] -> VName -> [VName] -> ADM [VName]
multiScatter SubExp
n [VName]
scatter_dst VName
siota [VName]
sas_bar
(VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj ([VName] -> [VName]
forall a. HasCallStack => [a] -> [a]
tail [VName]
as) [VName]
as_bar
where
mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS)
mkFlagLam LParam SOACS
par_i VName
sis =
[LParam (Rep ADM)] -> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m [SubExpRes] -> m (Lambda (Rep m))
mkLambda [LParam (Rep ADM)
LParam SOACS
par_i] (ADM [SubExpRes] -> ADM (Lambda (Rep ADM)))
-> ADM [SubExpRes] -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
([VName] -> [SubExpRes]) -> ADM [VName] -> ADM [SubExpRes]
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> [SubExpRes]
varsRes (ADM [VName] -> ADM [SubExpRes])
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM [SubExpRes]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"flag" (Exp SOACS -> ADM [SubExpRes])
-> ADM (Exp SOACS) -> ADM [SubExpRes]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< do
let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
par_i
ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0))
([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ SubExp -> ADM (Exp (Rep ADM))
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp (SubExp -> ADM (Exp (Rep ADM))) -> SubExp -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ PrimType -> PrimValue
onePrimValue PrimType
Bool)
( [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody ([ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM)))
-> [ADM (Exp (Rep ADM))] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$
ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))])
-> ADM (Exp (Rep ADM)) -> [ADM (Exp (Rep ADM))]
forall a b. (a -> b) -> a -> b
$ do
VName
i_p <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"i_p" (Exp SOACS -> ADM VName) -> ADM (Exp SOACS) -> ADM VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Int64 VName -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
[VName]
vs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"vs" (Exp SOACS -> ADM VName)
-> (VName -> Exp SOACS) -> VName -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
sis (Slice SubExp -> BasicOp)
-> (VName -> Slice SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> (VName -> [DimIndex SubExp]) -> VName -> Slice SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimIndex SubExp -> [DimIndex SubExp]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (DimIndex SubExp -> [DimIndex SubExp])
-> (VName -> DimIndex SubExp) -> VName -> [DimIndex SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (VName -> SubExp) -> VName -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName
i, VName
i_p]
let [VName
vs_i, VName
vs_p] = [VName]
vs
TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
TPrimExp Bool VName -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ TPrimExp Bool VName -> TPrimExp Bool VName
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot (TPrimExp Bool VName -> TPrimExp Bool VName)
-> TPrimExp Bool VName -> TPrimExp Bool VName
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
vs_i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
vs_p
)