{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.SOAC (vjpSOAC) where

import Control.Monad
import Futhark.AD.Rev.Map
import Futhark.AD.Rev.Monad
import Futhark.AD.Rev.Reduce
import Futhark.AD.Rev.Scan
import Futhark.AD.Rev.Scatter
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Util (chunks)

-- | Differentiate a Screma that carries several scan or reduction
-- operators by peeling it apart into one single-operator Screma per
-- operator, so that the single-operator special cases in 'vjpSOAC' can
-- be detected.  Post-AD, the result may be fused again.
--
-- The first component of the operator pair rebuilds a 'ScremaForm' from
-- a singleton operator list; the second yields an operator's neutral
-- elements, whose count decides how many pattern elements and input
-- arrays are assigned to that operator (via 'chunks').  Each split
-- Screma is differentiated with the next one as its continuation, and
-- the caller's continuation @m@ runs innermost.
splitScanRed ::
  VjpOps ->
  ([a] -> ADM (ScremaForm SOACS), a -> [SubExp]) ->
  (Pat Type, StmAux (), [a], SubExp, [VName]) ->
  ADM () ->
  ADM ()
splitScanRed vjpops (opSOAC, opNeutral) (pat, aux, ops, w, as) m =
  go ops (map Pat (chunks sizes (patElems pat))) (chunks sizes as)
  where
    -- How many results/inputs belong to each operator.
    sizes = map (length . opNeutral) ops

    go (op : rest_ops) (op_pat : rest_pats) (op_as : rest_as) = do
      op_form <- opSOAC [op]
      vjpSOAC vjpops op_pat aux (Screma w op_as op_form) $
        go rest_ops rest_pats rest_as
    -- Once any list runs out, continue with the caller's code.
    go _ _ _ = m

-- | Shared skeleton for differentiating a SOAC statement: re-emit the
-- statement @pat = soac@ (with auxiliary information @aux@), run the
-- continuation @m@, and then, wrapped in 'returnSweepCode', look up the
-- adjoint of every name bound by @pat@.
commonSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () -> ADM [Adj]
commonSOAC pat aux soac m =
  addStm (Let pat aux (Op soac))
    >> m
    >> returnSweepCode (traverse lookupAdj (patNames pat))

-- | Reverse-mode differentiation of a single SOAC statement.
--
-- @vjpSOAC ops pat aux soac m@ handles the statement binding @pat@ to
-- @soac@ (with auxiliary information @aux@), running the continuation
-- @m@ in between.  Screma reductions, scans, maps and redomaps, as well
-- as Scatter, have rules; any other SOAC is an error.
vjpSOAC :: VjpOps -> Pat Type -> StmAux () -> SOAC SOACS -> ADM () -> ADM ()
vjpSOAC ops pat aux soac@(Screma w as form) m
  -- Two or more reduction operators: split them so each one can be
  -- handled (and special-cased) on its own.
  | Just reds@(_ : _ : _) <- isReduceSOAC form =
      splitScanRed ops (reduceSOAC, redNeutral) (pat, aux, reds, w, as) m
  -- One min/max binop reduction over a single array with one result.
  | Just [red] <- isReduceSOAC form,
    [x] <- patNames pat,
    [ne] <- redNeutral red,
    [a] <- as,
    Just [(op, _, _, _)] <- lamIsBinOp (redLambda red),
    minOrMax op =
      diffMinMaxReduce ops x aux w op ne a m
  -- Any other reduction.
  | Just red <- singleReduce <$> isReduceSOAC form = do
      adjs <- commonSOAC pat aux soac m
      pat_adj <- mapM adjVal adjs
      diffReduce ops pat_adj w as red
  where
    minOrMax :: BinOp -> Bool
    minOrMax bop = case bop of
      SMin _ -> True
      UMin _ -> True
      FMin _ -> True
      SMax _ -> True
      UMax _ -> True
      FMax _ -> True
      _ -> False
vjpSOAC ops pat aux soac@(Screma w as form) m
  -- Two or more scan operators: split them up, as for reductions.
  | Just scans@(_ : _ : _) <- isScanSOAC form =
      splitScanRed ops (scanSOAC, scanNeutral) (pat, aux, scans, w, as) m
  -- Otherwise, a scan collapsed into one operator by 'singleScan'.
  | Just scan <- singleScan <$> isScanSOAC form = do
      _ <- commonSOAC pat aux soac m
      diffScan ops (patNames pat) w as scan
vjpSOAC ops pat aux soac@(Screma w as form) m
  -- A plain map: differentiate its lambda using the pattern's adjoints.
  | Just lam <- isMapSOAC form =
      commonSOAC pat aux soac m >>= \pat_adj ->
        vjpMap ops pat_adj aux w lam as
vjpSOAC ops pat aux (Screma w as form) m
  -- A redomap: split it into a map statement followed by a reduce
  -- statement, and differentiate the two in sequence.
  --
  -- 'redomapToMapAndReduce' is not given the original statement's
  -- auxiliary information, so its certificates and attributes are merged
  -- back into both derived statements here.  This matches the other
  -- cases, which all pass @aux@ on.  Dropping the certificates could
  -- detach the computation from checks it depends on.
  | Just (reds, map_lam) <- isRedomapSOAC form = do
      (mapstm, redstm) <- redomapToMapAndReduce pat (w, reds, map_lam, as)
      vjpStm ops (inheritAux mapstm) $ vjpStm ops (inheritAux redstm) m
  where
    -- Add the redomap's certificates and attributes to a derived
    -- statement, keeping that statement's own decoration.
    inheritAux (Let stm_pat stm_aux e) =
      Let
        stm_pat
        stm_aux
          { stmAuxCerts = stmAuxCerts aux <> stmAuxCerts stm_aux,
            stmAuxAttrs = stmAuxAttrs aux <> stmAuxAttrs stm_aux
          }
        e
vjpSOAC ops pat aux (Scatter w inputs scatter_lam written_info) m =
  -- Scatter has a dedicated rule; the constructor's fields (width, the
  -- list of input names, the lambda, and the written-array info) are
  -- forwarded unchanged.
  vjpScatter ops pat aux (w, inputs, scatter_lam, written_info) m
vjpSOAC _ _ _ soac _ =
  -- Reached when no earlier equation applies: a Screma form none of the
  -- guards above recognised, or a SOAC other than Screma and Scatter.
  error ("vjpSOAC unhandled:\n" <> pretty soac)