{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.SOAC (vjpSOAC) where

import Control.Monad
import Futhark.AD.Rev.Hist
import Futhark.AD.Rev.Map
import Futhark.AD.Rev.Monad
import Futhark.AD.Rev.Reduce
import Futhark.AD.Rev.Scan
import Futhark.AD.Rev.Scatter
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Util (chunks)

-- We split any multi-op scan or reduction into multiple operations so
-- we can detect special cases.  Post-AD, the result may be fused
-- again.
splitScanRed ::
  VjpOps ->
  ([a] -> ADM (ScremaForm SOACS), a -> [SubExp]) ->
  (Pat Type, StmAux (), [a], SubExp, [VName]) ->
  ADM () ->
  ADM ()
splitScanRed vjpops (opSOAC, opNeutral) (pat, aux, ops, w, as) m =
  go ops (map Pat (chunks sizes (patElems pat))) (chunks sizes as)
  where
    -- How many results (and input arrays) belong to each operator; one
    -- per neutral element.
    sizes = map (length . opNeutral) ops

    -- Differentiate the first operator as a single-operator Screma whose
    -- continuation handles the remaining operators.  When the operators
    -- run out, fall back to the caller's continuation.
    go (op : ops') (op_pat : op_pats') (op_as : op_as') = do
      op_form <- opSOAC [op]
      vjpSOAC vjpops op_pat aux (Screma w op_as op_form) $
        go ops' op_pats' op_as'
    go _ _ _ = m

-- We split multi-op histograms into multiple operations so we
-- can take advantage of special cases. Post-AD, the result may
-- be fused again.
splitHist :: VjpOps -> Pat Type -> StmAux () -> [HistOp SOACS] -> SubExp -> [VName] -> [VName] -> ADM () -> ADM ()
splitHist vjpops pat aux ops w is as m = do
  -- Each operator owns as many results and value arrays as it has
  -- neutral elements, and exactly one index array from 'is'.
  let ks = map (length . histNeutral) ops
      pat_per_op = map Pat $ chunks ks $ patElems pat
      as_per_op = chunks ks as
      onOps (op : ops') (op_pat : op_pats') (op_is : op_is') (op_as : op_as') = do
        -- The Hist lambda is applied to one element of each input array,
        -- so its parameters must have the *row* types of the value arrays
        -- (matching the scalar 'int64' index parameter), not the full
        -- array types that 'lookupType' returns.
        f <- mkIdentityLambda . (Prim int64 :) =<< traverse (fmap rowType . lookupType) op_as
        vjpSOAC vjpops op_pat aux (Hist w (op_is : op_as) [op] f) $
          onOps ops' op_pats' op_is' op_as'
      -- Once every operator has been handled, run the continuation.
      onOps _ _ _ _ = m
  onOps ops pat_per_op is as_per_op

-- Unfuses a map-histogram construct into a separate map and histogram.
histomapToMapAndHist :: Pat Type -> (SubExp, [HistOp SOACS], Lambda SOACS, [VName]) -> ADM (Stm SOACS, Stm SOACS)
histomapToMapAndHist pat (w, histops, map_lam, as) = do
  -- The map produces one array (of outer size 'w') per lambda result.
  let row_ts = lambdaReturnType map_lam
  map_idents <- forM row_ts $ \t -> newIdent "hist_map_res" $ arrayOfRow t w
  -- The histogram then consumes those arrays unchanged.
  id_lam <- mkIdentityLambda row_ts
  pure
    ( mkLet map_idents . Op . Screma w as $ mapSOAC map_lam,
      Let pat (defAux ()) . Op $ Hist w (map identName map_idents) histops id_lam
    )

commonSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () -> ADM [Adj]
commonSOAC pat aux soac m =
  -- Emit the primal statement, run the continuation, then look up the
  -- adjoints of the statement's results inside 'returnSweepCode'.
  addStm (Let pat aux (Op soac))
    >> m
    >> returnSweepCode (traverse lookupAdj (patNames pat))

-- Reverse-mode differentiation of SOACs
vjpSOAC :: VjpOps -> Pat Type -> StmAux () -> SOAC SOACS -> ADM () -> ADM ()
-- Differentiating Reduces
vjpSOAC vjpops pat aux soac@(Screma w as form) m
  -- A reduction whose operator is a 'map' of an operator over vectors.
  | Just [Reduce iscomm lam [Var ne]] <- isReduceSOAC form,
    [a] <- as,
    Just vec_op <- mapOp lam =
      diffVecReduce vjpops pat aux w iscomm vec_op ne a m
  -- Several fused reductions: handle them one at a time.
  | Just reds <- isReduceSOAC form,
    length reds > 1 =
      splitScanRed vjpops (reduceSOAC, redNeutral) (pat, aux, reds, w, as) m
  -- Single-operand reductions with a recognised binary operator.
  | Just (x, ne, a, bop) <- scalarReduce,
    isMinMaxOp bop =
      diffMinMaxReduce vjpops x aux w bop ne a m
  | Just (x, ne, a, bop) <- scalarReduce,
    isMulOp bop =
      diffMulReduce vjpops x aux w bop ne a m
  -- Any other reduction.
  | Just red <- singleReduce <$> isReduceSOAC form =
      commonSOAC pat aux soac m
        >>= mapM adjVal
        >>= \pat_adj -> diffReduce vjpops pat_adj w as red
  where
    -- Matches a single reduction with one result, one neutral element,
    -- one input array, and a lambda that is a lone binary operation.
    scalarReduce = do
      [red] <- isReduceSOAC form
      [x] <- Just $ patNames pat
      [ne] <- Just $ redNeutral red
      [a] <- Just as
      [(bop, _, _, _)] <- lamIsBinOp $ redLambda red
      Just (x, ne, a, bop)

-- Differentiating Scans
vjpSOAC vjpops pat aux soac@(Screma w as form) m
  -- A single-operand scan whose lambda is an addition.
  | Just [Scan lam [ne]] <- isScanSOAC form,
    [x] <- patNames pat,
    [a] <- as,
    Just [(bop, _, _, _)] <- lamIsBinOp lam,
    isAddOp bop =
      forward >> diffScanAdd vjpops x w lam ne a
  -- A scan whose operator is a 'map' of an operator over vectors.
  | Just [Scan lam nes] <- isScanSOAC form,
    Just vec_op <- mapOp lam =
      diffScanVec vjpops (patNames pat) aux w vec_op nes as m
  -- Several fused scans: handle them one at a time.
  | Just scans <- isScanSOAC form,
    length scans > 1 =
      splitScanRed vjpops (scanSOAC, scanNeutral) (pat, aux, scans, w, as) m
  -- Any other scan.
  | Just scan <- singleScan <$> isScanSOAC form =
      forward >> diffScan vjpops (patNames pat) w as scan
  where
    -- Emit the primal scan and run the continuation, discarding the
    -- adjoints it returns; the rules above receive result names instead.
    forward = void $ commonSOAC pat aux soac m

-- Differentiating Maps
vjpSOAC vjpops pat aux soac@(Screma w as form) m
  -- A plain map: emit it, then differentiate its lambda against the
  -- adjoints of its results.
  | Just lam <- isMapSOAC form =
      commonSOAC pat aux soac m >>= \pat_adj ->
        vjpMap vjpops pat_adj aux w lam as

-- Differentiating Redomaps
vjpSOAC vjpops pat _aux (Screma w as form) m
  -- Unfuse the redomap into a map followed by a reduction, and
  -- differentiate both statements with the map statement outermost.
  -- NOTE(review): '_aux' is not attached to either new statement.
  | Just (reds, map_lam) <- isRedomapSOAC form =
      redomapToMapAndReduce pat (w, reds, map_lam, as)
        >>= \(map_stm, red_stm) -> vjpStm vjpops map_stm $ vjpStm vjpops red_stm m

-- Differentiating Scanomaps
vjpSOAC vjpops pat _aux (Screma w as form) m
  -- Unfuse the scanomap into a map followed by a scan, and
  -- differentiate both statements with the map statement outermost.
  -- NOTE(review): '_aux' is not attached to either new statement.
  | Just (scans, map_lam) <- isScanomapSOAC form =
      scanomapToMapAndScan pat (w, scans, map_lam, as)
        >>= \(map_stm, scan_stm) -> vjpStm vjpops map_stm $ vjpStm vjpops scan_stm m

-- Differentiating Scatter
vjpSOAC vjpops pat aux (Scatter w ass lam written_info) m
  -- A scatter with an identity lambda is handled directly.
  | isIdentityLambda lam =
      vjpScatter vjpops pat aux (w, ass, lam, written_info) m
  -- Otherwise compute the lambda in a separate map, then scatter its
  -- results through an identity lambda.
  | otherwise = do
      let res_ts = lambdaReturnType lam
      map_idents <- forM res_ts $ \t -> newIdent "map_res" $ arrayOfRow t w
      lam_id <- mkIdentityLambda res_ts
      let map_stm :: Stm SOACS
          map_stm = mkLet map_idents . Op . Screma w ass $ mapSOAC lam
          scatter_stm :: Stm SOACS
          scatter_stm =
            Let pat aux . Op $ Scatter w (map identName map_idents) lam_id written_info
      vjpStm vjpops map_stm $ vjpStm vjpops scatter_stm m

-- Differentiating Histograms
vjpSOAC vjpops pat aux (Hist n as histops f) m
  -- Several fused histograms with an identity lambda: the first
  -- 'length histops' inputs are taken as index arrays and the rest as
  -- value arrays, and each operator is handled separately.
  | isIdentityLambda f,
    length histops > 1 =
      uncurry (splitHist vjpops pat aux histops n) (splitAt (length histops) as) m
vjpSOAC vjpops pat aux (Hist n [is, vs] [histop] f) m
  -- One-dimensional, single-result histograms whose (map-nested)
  -- operator is a recognised binary operation.
  | Just (x, w, rf, dst, ne, _, bop) <- scalarHist,
    isMinMaxOp bop =
      diffMinMaxHist vjpops x aux n bop ne is vs w rf dst m
  | Just (x, w, rf, dst, ne, _, bop) <- scalarHist,
    isMulOp bop =
      diffMulHist vjpops x aux n bop ne is vs w rf dst m
  | Just (x, w, rf, dst, ne, lam, bop) <- scalarHist,
    isAddOp bop =
      diffAddHist vjpops x aux n lam ne is vs w rf dst m
  where
    -- Matches an identity-lambda histogram with one result, a rank-1
    -- shape, one destination and one neutral element, whose innermost
    -- operator (after peeling perfect map nests) is a lone binary
    -- operation.  The original, unpeeled lambda is returned too.
    scalarHist = do
      guard $ isIdentityLambda f
      [x] <- Just $ patNames pat
      HistOp (Shape [w]) rf [dst] [ne] lam <- Just histop
      [(bop, _, _, _)] <- lamIsBinOp $ nestedMapOp lam
      Just (x, w, rf, dst, ne, lam, bop)
vjpSOAC vjpops pat aux (Hist n as [histop] f) m
  -- A single histogram operator with an arbitrary combining lambda.
  | isIdentityLambda f =
      diffHist vjpops (patNames pat) aux n lam ne as w rf dst m
  where
    HistOp (Shape w) rf dst ne lam = histop
vjpSOAC vjpops pat _aux (Hist n as histops f) m
  -- A histogram with a non-identity lambda is split into a map that
  -- computes the lambda and a histogram with an identity lambda.
  -- NOTE(review): '_aux' is not attached to either new statement.
  | not $ isIdentityLambda f =
      histomapToMapAndHist pat (n, histops, f, as)
        >>= \(map_stm, hist_stm) -> vjpStm vjpops map_stm $ vjpStm vjpops hist_stm m
vjpSOAC _ _ _ soac _ =
  -- Any SOAC not matched by an earlier equation is unsupported.
  error ("vjpSOAC unhandled:\n" <> prettyString soac)

---------------
--- Helpers ---
---------------

isMinMaxOp :: BinOp -> Bool
isMinMaxOp bop = case bop of
  -- Signed, unsigned and floating-point minimum/maximum.
  SMin {} -> True
  UMin {} -> True
  FMin {} -> True
  SMax {} -> True
  UMax {} -> True
  FMax {} -> True
  _ -> False

isMulOp :: BinOp -> Bool
isMulOp bop = case bop of
  -- Integer multiplication (any overflow mode) or float multiplication.
  Mul {} -> True
  FMul {} -> True
  _ -> False

isAddOp :: BinOp -> Bool
isAddOp bop = case bop of
  -- Integer addition (any overflow mode) or float addition.
  Add {} -> True
  FAdd {} -> True
  _ -> False

-- Identifies vectorized operators (lambdas):
--   if the lambda argument is a map, then returns
--   just the map's lambda; otherwise nothing.
mapOp :: Lambda SOACS -> Maybe (Lambda SOACS)
mapOp (Lambda [pa1, pa2] lam_body _)
  | [SubExpRes cs r] <- bodyResult lam_body,
    cs == mempty,
    [map_stm] <- stmsToList (bodyStms lam_body),
    (Let (Pat [pe]) _ (Op scrm)) <- map_stm,
    (Screma _ [a1, a2] (ScremaForm [] [] map_lam)) <- scrm,
    r == Var (patElemName pe) =
      orient a1 a2 map_lam
  where
    p1 = paramName pa1
    p2 = paramName pa2
    -- Relates the operand order of the inner map to that of the outer
    -- lambda, so that the returned lambda, applied to (x, y), computes
    -- what the outer lambda computes element-wise for (x, y).
    orient x y inner
      -- Same order: the inner lambda already is the element-wise operator.
      | x == p1, y == p2 = Just inner
      -- Swapped order: the outer lambda computes @map g pa2 pa1@, i.e. it
      -- applies @g@ to its operands in reverse.  Swap the inner lambda's
      -- parameters so that it takes its operands in the outer lambda's
      -- order.  Returning @g@ unchanged would silently flip a
      -- non-commutative operator.
      | x == p2, y == p1 = Just (inner {lambdaParams = reverse $ lambdaParams inner})
      | otherwise = Nothing
mapOp _ = Nothing

-- Gets the innermost lambda of a perfect map nest
--   (i.e., the first lambda that does not consist of exactly a map).
nestedMapOp :: Lambda SOACS -> Lambda SOACS
nestedMapOp lam =
  -- Keep peeling off map layers until 'mapOp' no longer recognises one.
  case mapOp lam of
    Just inner -> nestedMapOp inner
    Nothing -> lam