{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Definition of /Second-Order Array Combinators/ (SOACs), which are
-- the main form of parallelism in the early stages of the compiler.
module Futhark.IR.SOACS.SOAC
  ( SOAC (..),
    ScremaForm (..),
    HistOp (..),
    Scan (..),
    scanResults,
    singleScan,
    Reduce (..),
    redResults,
    singleReduce,

    -- * Utility
    scremaType,
    soacType,
    typeCheckSOAC,
    mkIdentityLambda,
    isIdentityLambda,
    nilFn,
    scanomapSOAC,
    redomapSOAC,
    scanSOAC,
    reduceSOAC,
    mapSOAC,
    isScanomapSOAC,
    isRedomapSOAC,
    isScanSOAC,
    isReduceSOAC,
    isMapSOAC,
    scremaLambda,
    ppScrema,
    ppHist,
    ppStream,
    ppScatter,
    groupScatterResults,
    groupScatterResults',
    splitScatterResults,

    -- * Generic traversal
    SOACMapper (..),
    identitySOACMapper,
    mapSOACM,
    traverseSOACStms,
  )
where

import Control.Category
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, removeLambdaAliases)
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty (Doc, Pretty, align, comma, commasep, docText, parens, ppTuple', pretty, (<+>), (</>))
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | A second-order array combinator (SOAC).
data SOAC rep
  = Stream SubExp [VName] [SubExp] (Lambda rep)
  | -- | @Scatter length inputs lambda outputs@
    --
    -- Scatter maps values from a set of input arrays to indices and
    -- values of a set of output arrays. It is able to write multiple
    -- values to multiple outputs, each of which may have multiple
    -- dimensions.
    --
    -- /inputs/ is a list of input arrays, all having size /length/,
    -- whose elements are passed to /lambda/. For instance, if there
    -- are two arrays, /lambda/ gets two values as input, one from each
    -- array.
    --
    -- /outputs/ specifies the results of /lambda/ and which arrays to
    -- write to. Each element of the list is a triple of a 'Shape'
    -- describing the shape of the destination array, an 'Int' giving
    -- how many elements are written to that array for each invocation
    -- of /lambda/, and a 'VName' naming the array to scatter to.
    --
    -- /lambda/ takes its arguments from /inputs/ and returns values
    -- according to the specification in /outputs/, in the following
    -- manner:
    --
    -- > [index_0, index_1, ..., index_n, value_0, value_1, ..., value_m]
    --
    -- For each output in /outputs/, /lambda/ returns @i * j@ index
    -- values and @j@ output values, where @i@ is the number of
    -- dimensions (rank) of the given output, and @j@ is the number of
    -- output values written to the given output.
    --
    -- For example, given the following output specification:
    --
    -- > [([x1, y1, z1], 2, arr1), ([x2, y2], 1, arr2)]
    --
    -- /lambda/ will produce 6 (3 * 2) index values and 2 output values
    -- for @arr1@, and 2 (2 * 1) index values and 1 output value for
    -- @arr2@. Additionally, the results are grouped, so the first 6
    -- index values will correspond to the first two output values, and
    -- so on. For this example, /lambda/ should return a total of 11
    -- values: 8 index values and 3 output values.  See also
    -- 'splitScatterResults'.
    Scatter SubExp [VName] (Lambda rep) [(Shape, Int, VName)]
  | -- | @Hist length inputs ops bucket_fun@
    --
    -- Each 'HistOp' in /ops/ holds destination arrays and an operator.
    -- The final lambda produces indexes and values for the 'HistOp's.
    Hist SubExp [VName] [HistOp rep] (Lambda rep)
  | -- FIXME: this should not be here
    JVP (Lambda rep) [SubExp] [SubExp]
  | -- FIXME: this should not be here
    VJP (Lambda rep) [SubExp] [SubExp]
  | -- | A combination of scan, reduction, and map.  The first
    -- t'SubExp' is the size of the input arrays.
    Screma SubExp [VName] (ScremaForm rep)
  deriving (Eq, Ord, Show)

-- | Information about computing a single histogram.
data HistOp rep = HistOp
  { -- | Shape of the histogram.
    histShape :: Shape,
    -- | Race factor @RF@ means that only @1/RF@
    -- bins are used.
    histRaceFactor :: SubExp,
    -- | The destination arrays of this histogram.
    histDest :: [VName],
    -- | Neutral elements for 'histOp'.
    histNeutral :: [SubExp],
    -- | The operator used to combine values in the histogram.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The essential parts of a 'Screma' factored out (everything
-- except the input arrays).
data ScremaForm rep
  = ScremaForm
      [Scan rep]
      -- ^ Scan operators.
      [Reduce rep]
      -- ^ Reduction operators.
      (Lambda rep)
      -- ^ The map function; its leading results feed the scans and
      -- reductions (in that order), the rest are mapped results.
  deriving (Eq, Ord, Show)

-- | Fuse several binary operators into a single operator.
--
-- For each input lambda, its first @k@ parameters (where @k@ is the
-- number of values it returns) are treated as the left operand and the
-- remaining parameters as the right operand.  The fused lambda takes
-- all left operands followed by all right operands, runs the
-- statements of every input body in order, and returns all of their
-- results concatenated.
--
-- No renaming is performed, so the input lambdas presumably need
-- pairwise-distinct names (NOTE(review): confirm at call sites).
singleBinOp :: (Buildable rep) => [Lambda rep] -> Lambda rep
singleBinOp lams =
  Lambda
    { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
      lambdaReturnType = concatMap lambdaReturnType lams,
      lambdaBody =
        mkBody
          (mconcat (map (bodyStms . lambdaBody) lams))
          (concatMap (bodyResult . lambdaBody) lams)
    }
  where
    -- Split the parameters into (left operand, right operand).
    operands lam = splitAt (length (lambdaReturnType lam)) (lambdaParams lam)
    xParams lam = fst (operands lam)
    yParams lam = snd (operands lam)

-- | How to compute a single scan result.
data Scan rep = Scan
  { -- | The binary scan operator.
    scanLambda :: Lambda rep,
    -- | Neutral elements for 'scanLambda'; one per scan result.
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many scan results are produced by these 'Scan's?  Each
-- neutral element corresponds to one result.
scanResults :: [Scan rep] -> Int
scanResults = sum . map (length . scanNeutral)

-- | Combine multiple scan operators to a single operator.  The
-- operators are fused with 'singleBinOp' and their neutral elements
-- are concatenated in the same order.
singleScan :: (Buildable rep) => [Scan rep] -> Scan rep
singleScan scans =
  Scan
    { scanLambda = singleBinOp (map scanLambda scans),
      scanNeutral = concatMap scanNeutral scans
    }

-- | How to compute a single reduction result.
data Reduce rep = Reduce
  { -- | Whether the reduction operator is commutative.
    redComm :: Commutativity,
    -- | The binary reduction operator.
    redLambda :: Lambda rep,
    -- | Neutral elements for 'redLambda'; one per reduction result.
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'Reduce's?  Each
-- neutral element corresponds to one result.
redResults :: [Reduce rep] -> Int
redResults = sum . map (length . redNeutral)

-- | Combine multiple reduction operators to a single operator.  The
-- operators are fused with 'singleBinOp', their neutral elements are
-- concatenated in the same order, and their commutativities are
-- combined with the 'Commutativity' monoid.
singleReduce :: (Buildable rep) => [Reduce rep] -> Reduce rep
singleReduce reds =
  Reduce
    { redComm = foldMap redComm reds,
      redLambda = singleBinOp (map redLambda reds),
      redNeutral = concatMap redNeutral reds
    }

-- | The types produced by a single 'Screma', given the size of the
-- input array.
--
-- The result consists of the scan results (arrays of @w@ rows), then
-- the reduction results (not arrays), then the remaining results of
-- the map function (arrays of @w@ rows).  The map function's leading
-- return values, the ones consumed by the scans and reductions, are
-- not part of the mapped results.
scremaType :: SubExp -> ScremaForm rep -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps
  where
    scan_tps =
      map (`arrayOfRow` w) $ concatMap (lambdaReturnType . scanLambda) scans
    red_tps = concatMap (lambdaReturnType . redLambda) reds
    map_tps =
      drop (length scan_tps + length red_tps) $ lambdaReturnType map_lam

-- | Construct a lambda that takes parameters of the given types and
-- simply returns them unchanged.  Each parameter gets a fresh name
-- based on @"x"@.
mkIdentityLambda ::
  (Buildable rep, MonadFreshNames m) =>
  [Type] ->
  m (Lambda rep)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  pure
    Lambda
      { lambdaParams = params,
        lambdaBody = mkBody mempty $ varsRes $ map paramName params,
        lambdaReturnType = ts
      }

-- | Is the given lambda an identity lambda?  That is, does it return
-- exactly its parameters, in order?  Only the result subexpressions
-- are compared: certificates attached to the results and the body's
-- statements are not inspected.
isIdentityLambda :: Lambda rep -> Bool
isIdentityLambda lam =
  map resSubExp (bodyResult (lambdaBody lam))
    == map (Var . paramName) (lambdaParams lam)

-- | A lambda with no parameters that returns no values.
nilFn :: (Buildable rep) => Lambda rep
nilFn =
  Lambda
    { lambdaParams = mempty,
      lambdaBody = mkBody mempty mempty,
      lambdaReturnType = mempty
    }

-- | Construct a Screma with possibly multiple scans, and
-- the given map function.  The form has no reductions.
scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC scans map_lam = ScremaForm scans [] map_lam

-- | Construct a Screma with possibly multiple reductions, and
-- the given map function.  The form has no scans.
redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC reds map_lam = ScremaForm [] reds map_lam

-- | Construct a Screma with possibly multiple scans, and identity map
-- function.  The identity lambda's types are the return types of the
-- scan operators.
scanSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Scan rep] ->
  m (ScremaForm rep)
scanSOAC scans = scanomapSOAC scans <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . scanLambda) scans

-- | Construct a Screma with possibly multiple reductions, and
-- identity map function.  The identity lambda's types are the return
-- types of the reduction operators.
reduceSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Reduce rep] ->
  m (ScremaForm rep)
reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . redLambda) reds

-- | Construct a Screma corresponding to a map: no scans and no
-- reductions, only the given map function.
mapSOAC :: Lambda rep -> ScremaForm rep
mapSOAC map_lam = ScremaForm [] [] map_lam

-- | Does this Screma correspond to a scan-map composition?  Succeeds
-- when there is at least one scan and no reductions.
isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC (ScremaForm scans reds map_lam)
  | null reds, not (null scans) = Just (scans, map_lam)
  | otherwise = Nothing

-- | Does this Screma correspond to a pure scan?  That is, a
-- scan-map composition (see 'isScanomapSOAC') whose map function is
-- an identity lambda (see 'isIdentityLambda').
isScanSOAC :: ScremaForm rep -> Maybe [Scan rep]
isScanSOAC form =
  case isScanomapSOAC form of
    Just (scans, map_lam) | isIdentityLambda map_lam -> Just scans
    _ -> Nothing

-- | Does this Screma correspond to a reduce-map composition?
-- Succeeds when there is at least one reduction and no scans.
isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC (ScremaForm scans reds map_lam)
  | null scans, not (null reds) = Just (reds, map_lam)
  | otherwise = Nothing

-- | Does this Screma correspond to a pure reduce?  This is a reduce-map
-- composition (see 'isRedomapSOAC') whose map lambda is the identity.
isReduceSOAC :: ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC form =
  isRedomapSOAC form >>= \(reds, map_lam) ->
    if isIdentityLambda map_lam then Just reds else Nothing

-- | Does this Screma correspond to a simple map, without any
-- reduction or scan results?  If so, the map lambda is returned.
isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC (ScremaForm scans reds map_lam)
  | null scans && null reds = Just map_lam
  | otherwise = Nothing

-- | Return the "main" (map) lambda of the Screma.  For a map this is the
-- same lambda that 'isMapSOAC' returns.  How the lambda's results are
-- interpreted depends on which kind of Screma this is, but its
-- parameters always correspond exactly to elements of the input arrays.
scremaLambda :: ScremaForm rep -> Lambda rep
scremaLambda form =
  case form of
    ScremaForm _ _ lam -> lam

-- | @groupScatterResults <output specification> <results>@
--
-- Splits the flat result list of a 'Scatter' lambda by destination
-- array.  The lambda returns every index followed by every written
-- value in one list.  This function pairs each written value with its
-- indices (one index per dimension of the destination's t'Shape'), and
-- then collects those pairs per entry of the output specification.
--
-- Each element of the result holds the shape and array of one entry in
-- the output specification, together with the @(indices, value)@ pairs
-- written to it, in order.
--
-- See 'Scatter' for more information.
groupScatterResults :: [(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  zip3 dest_shapes dest_arrays $
    chunks writes_per_dest $
      groupScatterResults' output_spec results
  where
    (dest_shapes, writes_per_dest, dest_arrays) = unzip3 output_spec

-- | @groupScatterResults' <output specification> <results>@
--
-- Pairs every written value in @results@ with its indices, according to
-- the output specification.  This is a simpler form of
-- 'groupScatterResults': it does not group the pairs by destination, and
-- it returns no shape or array information.
--
-- See 'groupScatterResults' for more information.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  zip (chunks index_counts indices) values
  where
    (indices, values) = splitScatterResults output_spec results
    -- Each value written to a destination consumes one index per
    -- dimension of that destination's shape.
    index_counts =
      concatMap (\(shp, n, _) -> replicate n (length shp)) output_spec

-- | @splitScatterResults <output specification> <results>@
--
-- Divides the flat results of a 'Scatter' lambda into the leading index
-- values and the trailing written values.  The output specification
-- determines how many index values there are.
--
-- See 'groupScatterResults' for more information.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec results =
  splitAt num_indices results
  where
    num_indices = sum [n * length shp | (shp, n, _) <- output_spec]

-- | Like 'Mapper', but just for 'SOAC's.  'mapSOACM' applies each field
-- to the corresponding kind of immediate child of a SOAC.
data SOACMapper frep trep m = SOACMapper
  { -- | Applied to each 'SubExp' child.
    mapOnSOACSubExp :: SubExp -> m SubExp,
    -- | Applied to each 'Lambda' child.
    mapOnSOACLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to each 'VName' child.
    mapOnSOACVName :: VName -> m VName
  }

-- | A mapper that leaves every child untouched, so mapping it over a
-- SOAC returns the SOAC verbatim.  Fields are given positionally:
-- subexpressions, lambdas, names.
identitySOACMapper :: (Monad m) => SOACMapper rep rep m
identitySOACMapper = SOACMapper pure pure pure

-- | Map a monadic action across the immediate children of a
-- SOAC.  The mapping does not descend recursively into subexpressions
-- and is done left-to-right.
mapSOACM ::
  (Monad m) =>
  SOACMapper frep trep m ->
  SOAC frep ->
  m (SOAC trep)
mapSOACM tv soac =
  case soac of
    JVP lam args vec ->
      JVP <$> onLambda lam <*> onSubExps args <*> onSubExps vec
    VJP lam args vec ->
      VJP <$> onLambda lam <*> onSubExps args <*> onSubExps vec
    Stream size arrs accs lam ->
      Stream
        <$> onSubExp size
        <*> onVNames arrs
        <*> onSubExps accs
        <*> onLambda lam
    Scatter w ivs lam dests ->
      Scatter
        <$> onSubExp w
        <*> onVNames ivs
        <*> onLambda lam
        <*> traverse onDest dests
    Hist w arrs ops bucket_fun ->
      Hist
        <$> onSubExp w
        <*> onVNames arrs
        <*> traverse onHistOp ops
        <*> onLambda bucket_fun
    Screma w arrs (ScremaForm scans reds map_lam) ->
      Screma
        <$> onSubExp w
        <*> onVNames arrs
        <*> ( ScremaForm
                <$> traverse onScan scans
                <*> traverse onReduce reds
                <*> onLambda map_lam
            )
  where
    onSubExp = mapOnSOACSubExp tv
    onVName = mapOnSOACVName tv
    onLambda = mapOnSOACLambda tv
    onSubExps = traverse onSubExp
    onVNames = traverse onVName
    -- A scatter destination: its shape, the number of values written to
    -- it (passed through unchanged), and the destination array.
    onDest (shape, n, arr) =
      (,,) <$> traverse onSubExp shape <*> pure n <*> onVName arr
    onHistOp (HistOp shape rf op_arrs nes op) =
      HistOp
        <$> traverse onSubExp shape
        <*> onSubExp rf
        <*> onVNames op_arrs
        <*> onSubExps nes
        <*> onLambda op
    onScan (Scan lam nes) =
      Scan <$> onLambda lam <*> onSubExps nes
    onReduce (Reduce comm lam nes) =
      Reduce comm <$> onLambda lam <*> onSubExps nes

-- | A helper for defining 'TraverseOpStms'.  It applies the statement
-- traverser to the body of every lambda that is an immediate child of
-- the SOAC, via 'traverseLambdaStms'.  Every other child is left as is.
traverseSOACStms :: (Monad m) => OpStmsTraverser m (SOAC rep) rep
traverseSOACStms f =
  mapSOACM $ identitySOACMapper {mapOnSOACLambda = traverseLambdaStms f}

-- The free variables of a scan are those of its operator and its
-- neutral elements.
instance (ASTRep rep) => FreeIn (Scan rep) where
  freeIn' (Scan lam nes) = mconcat [freeIn' lam, freeIn' nes]

-- The free variables of a reduction are those of its operator and its
-- neutral elements.  Commutativity carries no names.
instance (ASTRep rep) => FreeIn (Reduce rep) where
  freeIn' (Reduce _comm lam nes) = mconcat [freeIn' lam, freeIn' nes]

-- The free variables of a Screma form are those of its scans, its
-- reductions and its map lambda.
instance (ASTRep rep) => FreeIn (ScremaForm rep) where
  freeIn' (ScremaForm scans reds map_lam) =
    mconcat [freeIn' scans, freeIn' reds, freeIn' map_lam]

-- The free variables of a histogram operation are collected from every
-- one of its components.
instance (ASTRep rep) => FreeIn (HistOp rep) where
  freeIn' (HistOp shape race_factor dests nes op) =
    mconcat
      [ freeIn' shape,
        freeIn' race_factor,
        freeIn' dests,
        freeIn' nes,
        freeIn' op
      ]

-- The free variables of a SOAC are gathered by walking its immediate
-- children with 'mapSOACM', unioning each child's free variables into an
-- accumulator that starts out empty.
instance (ASTRep rep) => FreeIn (SOAC rep) where
  freeIn' soac =
    execState (mapSOACM (SOACMapper collect collect collect) soac) mempty
    where
      -- Add a child's free variables to the state; return it unchanged.
      collect :: (FreeIn a) => a -> State FV a
      collect x = x <$ modify (<> freeIn' x)

-- Substitution applies the name substitution to every immediate child
-- (subexpressions, lambdas and names) of the SOAC.
instance (ASTRep rep) => Substitute (SOAC rep) where
  substituteNames subst soac =
    runIdentity $ mapSOACM (SOACMapper applySubst applySubst applySubst) soac
    where
      -- Needs an explicit polymorphic signature: it is used at three
      -- different child types.
      applySubst :: (Substitute a) => a -> Identity a
      applySubst x = pure (substituteNames subst x)

-- Renaming renames every immediate child of the SOAC.
instance (ASTRep rep) => Rename (SOAC rep) where
  rename =
    mapSOACM $
      SOACMapper
        { mapOnSOACSubExp = rename,
          mapOnSOACLambda = rename,
          mapOnSOACVName = rename
        }

-- | The type of a SOAC.
--
-- For 'Scatter', the lambda returns every index value first, followed by
-- the written values grouped per destination array (see
-- 'groupScatterResults').  Each destination yields one result: an array
-- of the destination's shape, whose element type is that of the values
-- written to it.
soacType :: (Typed (LParamInfo rep)) => SOAC rep -> [Type]
soacType (JVP lam _ _) =
  -- The lambda results, followed by values of the same types.
  lambdaReturnType lam
    ++ lambdaReturnType lam
soacType (VJP lam _ _) =
  -- The lambda results, followed by one value per lambda parameter.
  lambdaReturnType lam
    ++ map paramType (lambdaParams lam)
soacType (Stream outersize _ accs lam) =
  map (substNamesInType substs) rtp
  where
    -- In the return types, replace the first lambda parameter by the
    -- outer size, and the next @length accs@ parameters by the
    -- accumulator inputs.
    nms = map paramName $ take (1 + length accs) params
    substs = M.fromList $ zip nms (outersize : accs)
    Lambda params _ rtp = lam
soacType (Scatter _w _ivs lam dests) =
  map destType $ groupScatterResults dests $ lambdaReturnType lam
  where
    -- Look the value type up within each destination's own group of
    -- writes.  Indexing the flat list of value types by destination
    -- position is wrong as soon as some destination receives more than
    -- one value.
    destType (shape, _, (_, t) : _) = t `arrayOfShape` shape
    destType (_, _, []) =
      error "soacType: Scatter destination with no values written to it."
soacType (Hist _ _ ops _bucket_fun) = do
  op <- ops
  map (`arrayOfShape` histShape op) (lambdaReturnType $ histOp op)
soacType (Screma w _arrs form) =
  scremaType w form

-- The type of a SOAC operation is its 'soacType', with every shape
-- lifted to an (entirely static) existential shape.
instance (ASTRep rep) => TypedOp (SOAC rep) where
  opType soac = pure $ staticShapes $ soacType soac

instance (ASTRep rep, Aliased rep) => AliasedOp (SOAC rep) where
  -- Every SOAC result is given an empty alias set, one per result type.
  opAliases soac = mempty <$ soacType soac

  -- The differentiation operators consume nothing.
  consumedInOp JVP {} = mempty
  consumedInOp VJP {} = mempty
  -- Only map functions can consume anything.  The operands to scan
  -- and reduce functions are always considered "fresh".  A consumed
  -- map parameter is reported as the input array it is bound to.
  consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
    mapNames toArray $ consumedByLambda map_lam
    where
      bindings = zip (map paramName $ lambdaParams map_lam) arrs
      toArray p = fromMaybe p $ lookup p bindings
  -- A consumed stream parameter is reported as the accumulator or input
  -- array bound to it; only those that are variables are kept.
  consumedInOp (Stream _ arrs accs lam) =
    namesFromList . subExpVars . map toInput . namesToList $ consumedByLambda lam
    where
      -- Drop the chunk parameter, which cannot alias anything.
      bindings = zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
      toInput p = fromMaybe (Var p) $ lookup p bindings
  -- Scatter consumes each of its destination arrays.
  consumedInOp (Scatter _ _ _ outs) =
    namesFromList [dest | (_, _, dest) <- outs]
  -- Hist consumes the destination arrays of every operator.
  consumedInOp (Hist _ _ ops _) =
    namesFromList $ ops >>= histDest

-- | Transform the operator lambda of a 'HistOp'; every other field is
-- carried over unchanged.
mapHistOp ::
  (Lambda frep -> Lambda trep) ->
  HistOp frep ->
  HistOp trep
mapHistOp f op = op {histOp = f (histOp op)}

instance
  ( ASTRep rep,
    ASTRep (Aliases rep),
    CanBeAliased (Op rep)
  ) =>
  CanBeAliased (SOAC rep)
  where
  type OpWithAliases (SOAC rep) = SOAC (Aliases rep)

  -- Run alias analysis on every lambda inside the SOAC, including the
  -- scan, reduce and histogram operators; all other operands are kept.
  addOpAliases aliases soac =
    case soac of
      JVP lam args vec -> JVP (analyse lam) args vec
      VJP lam args vec -> VJP (analyse lam) args vec
      Stream size arr accs lam -> Stream size arr accs (analyse lam)
      Scatter len arrs lam dests -> Scatter len arrs (analyse lam) dests
      Hist w arrs ops bucket_fun ->
        Hist w arrs (map (mapHistOp analyse) ops) (analyse bucket_fun)
      Screma w arrs (ScremaForm scans reds map_lam) ->
        Screma w arrs $
          ScremaForm
            [scan {scanLambda = analyse (scanLambda scan)} | scan <- scans]
            [red {redLambda = analyse (redLambda red)} | red <- reds]
            (analyse map_lam)
    where
      analyse = Alias.analyseLambda aliases

  -- Strip the alias annotations from every lambda inside the SOAC.
  removeOpAliases soac =
    runIdentity $ mapSOACM (SOACMapper pure (pure . removeLambdaAliases) pure) soac

-- Every SOAC is reported as neither safe nor cheap, regardless of its
-- contents.
instance (ASTRep rep) => IsOp (SOAC rep) where
  safeOp = const False
  cheapOp = const False

-- | Apply a name substitution to the dimensions of an array type.
-- Primitive, accumulator and memory types are returned unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t =
  case t of
    Array pt shape u ->
      Array pt (Shape [substNamesInSubExp subs d | d <- shapeDims shape]) u
    Prim {} -> t
    Acc {} -> t
    Mem {} -> t

-- | Replace a variable by its image in the substitution map, if it has
-- one; constants and unmapped variables are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Var v -> fromMaybe se $ M.lookup v subs
    Constant {} -> se

instance (ASTRep rep, CanBeWise (Op rep)) => CanBeWise (SOAC rep) where
  type OpWithWisdom (SOAC rep) = SOAC (Wise rep)

  -- Strip wisdom from every lambda inside the SOAC; sub-expressions and
  -- names are passed through untouched.
  removeOpWisdom soac =
    runIdentity $ mapSOACM (SOACMapper pure (Identity . removeLambdaWisdom) pure) soac

  -- Attach wisdom to every lambda inside the SOAC via 'informLambda'.
  addOpWisdom soac =
    runIdentity $ mapSOACM (SOACMapper pure (Identity . informLambda) pure) soac

instance (RepTypes rep) => ST.IndexOp (SOAC rep) where
  -- Try to express element @i@ of the @k@th result of a SOAC as a
  -- 'PrimExp' over elements of the SOAC's input arrays, using the symbol
  -- table @vtable@ to index those arrays.  Returns 'Nothing' whenever
  -- this is not possible.
  --
  -- Only a map-only 'Screma' (no scans, no reductions) is analysed; its
  -- @k@th lambda body result is taken to be the @k@th SOAC result.  The
  -- map lambda takes exactly one parameter per input array (the same
  -- pairing 'consumedInOp' relies on), so the parameters are zipped with
  -- the arrays directly.  Scremas with scans or reductions are rejected:
  -- dropping @scanResults + redResults@ lambda parameters before that
  -- zip, as was done for them, binds parameters to the wrong arrays.
  indexOp vtable k (Screma _ arrs (ScremaForm [] [] map_lam)) [i] = do
    se <- maybeNth k $ bodyResult body
    let arr_indexes =
          M.fromList $ catMaybes $ zipWith arrIndex (lambdaParams map_lam) arrs
        arr_indexes' = foldl expandPrimExpTable arr_indexes $ bodyStms body
    case se of
      SubExpRes _ (Var v) -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes'
      _ -> Nothing
    where
      body = lambdaBody map_lam

      -- Bind a lambda parameter to the element @arr[i]@, if the symbol
      -- table can express that element as a 'PrimExp'.
      arrIndex p arr = do
        ST.Indexed cs pe <- ST.index' arr [i] vtable
        pure (paramName p, (pe, cs))

      -- Extend the table with a single-name statement whose expression
      -- can be rebuilt as a 'PrimExp' from known names, provided all of
      -- its certificates are present in @vtable@.
      expandPrimExpTable table stm
        | [v] <- patNames $ stmPat stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
          all (`ST.elem` vtable) (unCerts $ stmCerts stm) =
            M.insert v (pe, stmCerts stm <> cs) table
        | otherwise =
            table

      -- Resolve a name from the table (recording its certificates), or
      -- as a leaf if @vtable@ gives it a primitive type.
      asPrimExp table v
        | Just (e, cs) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable =
            pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

-- | Type-check a SOAC.
typeCheckSOAC :: TC.Checkable rep => SOAC (Aliases rep) -> TC.TypeM rep ()
typeCheckSOAC :: forall {k} (rep :: k).
Checkable rep =>
SOAC (Aliases rep) -> TypeM rep ()
typeCheckSOAC (VJP Lambda (Aliases rep)
lam [SubExp]
args [SubExp]
vec) = do
  [Arg]
args' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {k} (rep :: k). Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
args
  forall {k} (rep :: k).
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases [Arg]
args'
  [Type]
vec_ts <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {k} (rep :: k). Checkable rep => SubExp -> TypeM rep Type
TC.checkSubExp [SubExp]
vec
  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
vec_ts forall a. Eq a => a -> a -> Bool
== forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam) forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k) a. ErrorCase rep -> TypeM rep a
TC.bad forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (rep :: k). Text -> ErrorCase rep
TC.TypeError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a. Doc a -> Text
docText forall a b. (a -> b) -> a -> b
$
      Doc Any
"Return type"
        forall a. Doc a -> Doc a -> Doc a
</> forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam))
        forall a. Doc a -> Doc a -> Doc a
</> Doc Any
"does not match type of seed vector"
        forall a. Doc a -> Doc a -> Doc a
</> forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty [Type]
vec_ts)
typeCheckSOAC (JVP Lambda (Aliases rep)
lam [SubExp]
args [SubExp]
vec) = do
  [Arg]
args' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {k} (rep :: k). Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
args
  forall {k} (rep :: k).
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases [Arg]
args'
  [Type]
vec_ts <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {k} (rep :: k). Checkable rep => SubExp -> TypeM rep Type
TC.checkSubExp [SubExp]
vec
  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
vec_ts forall a. Eq a => a -> a -> Bool
== forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
args') forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k) a. ErrorCase rep -> TypeM rep a
TC.bad forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (rep :: k). Text -> ErrorCase rep
TC.TypeError forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall a. Doc a -> Text
docText forall a b. (a -> b) -> a -> b
$
      Doc Any
"Parameter type"
        forall a. Doc a -> Doc a -> Doc a
</> forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
args')
        forall a. Doc a -> Doc a -> Doc a
</> Doc Any
"does not match type of seed vector"
        forall a. Doc a -> Doc a -> Doc a
</> forall ann. Int -> Doc ann -> Doc ann
PP.indent Int
2 (forall a ann. Pretty a => a -> Doc ann
pretty [Type]
vec_ts)
typeCheckSOAC (Stream SubExp
size [VName]
arrexps [SubExp]
accexps Lambda (Aliases rep)
lam) = do
  forall {k} (rep :: k).
Checkable rep =>
[Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
size
  [Arg]
accargs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {k} (rep :: k). Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
accexps
  [Type]
arrargs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m Type
lookupType [VName]
arrexps
  [Arg]
_ <- forall {k} (rep :: k).
Checkable rep =>
SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
size [VName]
arrexps
  let chunk :: Param (LParamInfo rep)
chunk = forall a. [a] -> a
head forall a b. (a -> b) -> a -> b
$ forall {k} {rep :: k}. Lambda rep -> [Param (LParamInfo rep)]
lambdaParams Lambda (Aliases rep)
lam
  let asArg :: a -> (a, b)
asArg a
t = (a
t, forall a. Monoid a => a
mempty)
      inttp :: TypeBase shape u
inttp = forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      lamarrs' :: [Type]
lamarrs' = forall a b. (a -> b) -> [a] -> [b]
map (forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` VName -> SubExp
Var (forall dec. Param dec -> VName
paramName Param (LParamInfo rep)
chunk)) [Type]
arrargs
  let acc_len :: Int
acc_len = forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accexps
  let lamrtp :: [Type]
lamrtp = forall a. Int -> [a] -> [a]
take Int
acc_len forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs forall a. Eq a => a -> a -> Bool
== [Type]
lamrtp) forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k) a. ErrorCase rep -> TypeM rep a
TC.bad forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall {k} (rep :: k). Text -> ErrorCase rep
TC.TypeError forall a b. (a -> b) -> a -> b
$
      Text
"Stream with inconsistent accumulator type in lambda."
  -- just get the dflow of lambda on the fakearg, which does not alias
  -- arr, so we can later check that aliases of arr are not used inside lam.
  let fake_lamarrs' :: [Arg]
fake_lamarrs' = forall a b. (a -> b) -> [a] -> [b]
map forall {b} {a}. Monoid b => a -> (a, b)
asArg [Type]
lamarrs'
  forall {k} (rep :: k).
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam forall a b. (a -> b) -> a -> b
$ forall {b} {a}. Monoid b => a -> (a, b)
asArg forall {shape} {u}. TypeBase shape u
inttp forall a. a -> [a] -> [a]
: [Arg]
accargs forall a. [a] -> [a] -> [a]
++ [Arg]
fake_lamarrs'
typeCheckSOAC (Scatter SubExp
w [VName]
arrs Lambda (Aliases rep)
lam [(Shape, Int, VName)]
as) = do
  -- Requirements:
  --
  --   0. @lambdaReturnType@ of @lam@ must be a list
  --      [index types..., value types, ...].
  --
  --   1. The number of index types and value types must be equal to the number
  --      of return values from @lam@.
  --
  --   2. Each index type must have the type i64.
  --
  --   3. Each array in @as@ and the value types must have the same type
  --
  --   4. Each array in @as@ is consumed.  This is not really a check, but more
  --      of a requirement, so that e.g. the source is not hoisted out of a
  --      loop, which will mean it cannot be consumed.
  --
  --   5. Each of arrs must be an array matching a corresponding lambda
  --      parameters.
  --
  -- Code:

  -- First check the input size.
  forall {k} (rep :: k).
Checkable rep =>
[Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w

  -- 0.
  let ([Shape]
as_ws, [Int]
as_ns, [VName]
_as_vs) = forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
as
      indexes :: Int
indexes = forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum forall a b. (a -> b) -> a -> b
$ forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall a. Num a => a -> a -> a
(*) [Int]
as_ns forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws
      rts :: [Type]
rts = forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
      rtsI :: [Type]
rtsI = forall a. Int -> [a] -> [a]
take Int
indexes [Type]
rts
      rtsV :: [Type]
rtsV = forall a. Int -> [a] -> [a]
drop Int
indexes [Type]
rts

  -- 1.
  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
rts forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns forall a. Num a => a -> a -> a
+ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum (forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall a. Num a => a -> a -> a
(*) [Int]
as_ns forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws)) forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k) a. ErrorCase rep -> TypeM rep a
TC.bad forall a b. (a -> b) -> a -> b
$
      forall {k} (rep :: k). Text -> ErrorCase rep
TC.TypeError Text
"Scatter: number of index types, value types and array outputs do not match."

  -- 2.
  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtsI forall a b. (a -> b) -> a -> b
$ \Type
rtI ->
    forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 forall a. Eq a => a -> a -> Bool
== Type
rtI) forall a b. (a -> b) -> a -> b
$
      forall {k} (rep :: k) a. ErrorCase rep -> TypeM rep a
TC.bad forall a b. (a -> b) -> a -> b
$
        forall {k} (rep :: k). Text -> ErrorCase rep
TC.TypeError Text
"Scatter: Index return type must be i64."

  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip (forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns [Type]
rtsV) [(Shape, Int, VName)]
as) forall a b. (a -> b) -> a -> b
$ \([Type]
rtVs, (Shape
aw, Int
_, VName
a)) -> do
    -- All lengths must have type i64.
    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall {k} (rep :: k).
Checkable rep =>
[Type] -> SubExp -> TypeM rep ()
TC.require [forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
aw

    -- 3.
    forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtVs forall a b. (a -> b) -> a -> b
$ \Type
rtV -> forall {k} (rep :: k).
Checkable rep =>
[Type] -> VName -> TypeM rep ()
TC.requireI [Type -> Shape -> Type
arrayOfShape Type
rtV Shape
aw] VName
a

    -- 4.
    forall {k} (rep :: k). Checkable rep => Names -> TypeM rep ()
TC.consume forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k} (rep :: k). Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
a

  -- 5.
  [Arg]
arrargs <- forall {k} (rep :: k).
Checkable rep =>
SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
arrs
  forall {k} (rep :: k).
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam [Arg]
arrargs
-- Histogram: the width must be an i64, and every operator must have an
-- i64 destination shape and replication factor. Each operator's lambda
-- must agree with its neutral elements, and it must have one correctly
-- shaped, consumable destination array per neutral element. The bucket
-- function must return, per operator, one i64 index per destination
-- dimension, followed by the values to combine.
typeCheckSOAC (Hist w arrs ops bucket_fun) = do
  TC.require [Prim int64] w

  -- Check the operators.
  forM_ ops $ \(HistOp dest_shape rf dests nes op) -> do
    nes' <- mapM TC.checkArg nes
    mapM_ (TC.require [Prim int64]) dest_shape
    TC.require [Prim int64] rf

    -- Operator type must match the type of neutral elements.
    TC.checkLambda op $ map TC.noArgAliases $ nes' ++ nes'
    let nes_t = map TC.argType nes'
    unless (nes_t == lambdaReturnType op) $
      TC.bad . TC.TypeError $
        "Operator has return type "
          <> prettyTuple (lambdaReturnType op)
          <> " but neutral element has type "
          <> prettyTuple nes_t

    -- Every neutral element needs exactly one destination array. Without
    -- this check the 'zip' below would silently skip any surplus
    -- destinations (or neutral elements), leaving them unchecked and
    -- unconsumed.
    unless (length dests == length nes) $
      TC.bad . TC.TypeError $
        "Histogram operator has neutral elements "
          <> prettyTuple nes
          <> " but destination arrays "
          <> prettyTuple dests

    -- Arrays must have proper type, and are consumed by the operation.
    forM_ (zip nes_t dests) $ \(t, dest) -> do
      TC.requireI [t `arrayOfShape` dest_shape] dest
      TC.consume =<< TC.lookupAliases dest

  -- Types of input arrays must equal parameter types for bucket function.
  img' <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda bucket_fun img'

  -- Return type of bucket function must be an index for each
  -- operation followed by the values to write.
  nes_ts <- concat <$> mapM (mapM subExpType . histNeutral) ops
  let bucket_ret_t =
        concatMap ((`replicate` Prim int64) . shapeRank . histShape) ops
          ++ nes_ts
  unless (bucket_ret_t == lambdaReturnType bucket_fun) $
    TC.bad . TC.TypeError $
      "Bucket function has return type "
        <> prettyTuple (lambdaReturnType bucket_fun)
        <> " but should have type "
        <> prettyTuple bucket_ret_t
-- Screma: the map lambda is checked against the input arrays, then every
-- scan and reduction operator is checked against its neutral elements.
-- Finally, the leading results of the map lambda must have the types of
-- those neutral elements: scans first, then reductions.
typeCheckSOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  TC.require [Prim int64] w
  TC.checkLambda map_lam =<< TC.checkSOACArrayArgs w arrs

  scan_nes' <- concat <$> forM scans (\(Scan lam nes) -> checkOperator "Scan" lam nes)
  red_nes' <- concat <$> forM reds (\(Reduce _ lam nes) -> checkOperator "Reduce" lam nes)

  let map_lam_ts = lambdaReturnType map_lam
      operator_ts = map TC.argType $ scan_nes' ++ red_nes'
  unless (take (length operator_ts) map_lam_ts == operator_ts) $
    TC.bad . TC.TypeError $
      "Map function return type "
        <> prettyTuple map_lam_ts
        <> " wrong for given scan and reduction functions."
  where
    -- Check an operator lambda against (two copies of) its neutral
    -- elements, require its return type to equal their types, and hand
    -- back the checked neutral elements. 'what' prefixes the error text.
    checkOperator what lam nes = do
      nes' <- mapM TC.checkArg nes
      let nes_t = map TC.argType nes'
      TC.checkLambda lam $ map TC.noArgAliases $ nes' ++ nes'
      unless (nes_t == lambdaReturnType lam) $
        TC.bad . TC.TypeError $
          what
            <> " function returns type "
            <> prettyTuple (lambdaReturnType lam)
            <> " but neutral element has type "
            <> prettyTuple nes_t
      pure nes'

-- | Metrics for a SOAC: every lambda it contains is measured inside a
-- context named after the SOAC's constructor.
instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where
  opMetrics soac =
    case soac of
      VJP lam _ _ -> inside "VJP" $ lambdaMetrics lam
      JVP lam _ _ -> inside "JVP" $ lambdaMetrics lam
      Stream _ _ _ lam -> inside "Stream" $ lambdaMetrics lam
      Scatter _len _ lam _ -> inside "Scatter" $ lambdaMetrics lam
      Hist _ _ ops bucket_fun -> inside "Hist" $ do
        -- Operator lambdas first, then the bucket function.
        mapM_ (lambdaMetrics . histOp) ops
        lambdaMetrics bucket_fun
      Screma _ _ (ScremaForm scans reds map_lam) -> inside "Screma" $ do
        mapM_ (lambdaMetrics . scanLambda) scans
        mapM_ (lambdaMetrics . redLambda) reds
        lambdaMetrics map_lam

-- | Prettyprinting of SOACs. Constructors with a dedicated printer
-- delegate to it. A Screma uses the most specific keyword its form
-- allows ("map", "redomap", "scanomap") and otherwise falls back to the
-- general "screma" syntax.
instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
  pretty soac =
    case soac of
      VJP lam args vec -> ppAD "vjp" lam args vec
      JVP lam args vec -> ppAD "jvp" lam args vec
      Stream size arrs acc lam -> ppStream size arrs acc lam
      Scatter w arrs lam dests -> ppScatter w arrs lam dests
      Hist w arrs ops bucket_fun -> ppHist w arrs ops bucket_fun
      Screma w arrs (ScremaForm scans reds map_lam)
        | null scans,
          null reds ->
            "map"
              <> (parens . align)
                ( pretty w
                    <> comma
                    </> ppTuple' (map pretty arrs)
                    <> comma
                    </> pretty map_lam
                )
        | null scans ->
            "redomap"
              <> (parens . align)
                ( pretty w
                    <> comma
                    </> ppTuple' (map pretty arrs)
                    <> comma
                    </> ppOperators reds
                    <> comma
                    </> pretty map_lam
                )
        | null reds ->
            "scanomap"
              <> (parens . align)
                ( pretty w
                    <> comma
                    </> ppTuple' (map pretty arrs)
                    <> comma
                    </> ppOperators scans
                    <> comma
                    </> pretty map_lam
                )
      Screma w arrs form -> ppScrema w arrs form
    where
      -- Shared layout of "vjp" and "jvp": the lambda, followed by two
      -- braced, comma-separated lists of subexpressions.
      ppAD kw lam args vec =
        kw
          <> parens
            ( PP.align $
                pretty lam
                  <> comma
                  </> PP.braces (commasep $ map pretty args)
                  <> comma
                  </> PP.braces (commasep $ map pretty vec)
            )
      -- Scan/reduce operators: braced, comma-separated, one per line.
      ppOperators operators =
        PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty operators)

-- | Prettyprint the given Screma using the general @screma@ syntax: the
-- width, the input arrays, the braced scan operators, the braced
-- reduction operators, and finally the map lambda.
ppScrema ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc ann
ppScrema w arrs (ScremaForm scans reds map_lam) =
  "screma" <> (parens . align) fields
  where
    fields =
      pretty w
        <> comma
        </> ppTuple' (map pretty arrs)
        <> comma
        </> ppOperators scans
        <> comma
        </> ppOperators reds
        <> comma
        </> pretty map_lam
    -- Operators are listed in braces, separated by a comma and a line.
    ppOperators operators =
      PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty operators)

-- | Prettyprint the given Stream as @streamSeq@: the size, the input
-- arrays, the accumulator subexpressions, and the lambda.
ppStream ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> [SubExp] -> Lambda rep -> Doc ann
ppStream size arrs acc lam =
  "streamSeq" <> parens (align fields)
  where
    fields =
      pretty size
        <> comma
        </> ppTuple' (map pretty arrs)
        <> comma
        </> ppTuple' (map pretty acc)
        <> comma
        </> pretty lam

-- | Prettyprint the given Scatter: the width, the input arrays, the
-- lambda, and the comma-separated destination triples.
ppScatter ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> Lambda rep -> [(Shape, Int, VName)] -> Doc ann
ppScatter w arrs lam dests =
  "scatter" <> parens (align fields)
  where
    fields =
      pretty w
        <> comma
        </> ppTuple' (map pretty arrs)
        <> comma
        </> pretty lam
        <> comma
        </> commasep (map pretty dests)

-- | A scan prints as its operator lambda followed by the braced list of
-- neutral elements.
instance (PrettyRep rep) => Pretty (Scan rep) where
  pretty (Scan scan_lam scan_nes) =
    pretty scan_lam <> comma </> neutrals
    where
      neutrals = PP.braces $ commasep $ map pretty scan_nes

-- Keyword prefix for a reduction: "commutative " for commutative
-- operators, nothing otherwise.
ppComm :: Commutativity -> Doc ann
ppComm comm =
  case comm of
    Commutative -> "commutative "
    Noncommutative -> mempty

-- | A reduction prints as an optional commutativity keyword, the
-- operator lambda, and the braced list of neutral elements.
instance (PrettyRep rep) => Pretty (Reduce rep) where
  pretty (Reduce comm red_lam red_nes) =
    ppComm comm <> pretty red_lam <> comma </> neutrals
    where
      neutrals = PP.braces $ commasep $ map pretty red_nes

-- | Prettyprint the given histogram operation: the width, the input
-- arrays, the braced list of histogram operators (one per line), and the
-- bucket function.
ppHist ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [inp] ->
  [HistOp rep] ->
  Lambda rep ->
  Doc ann
ppHist w arrs ops bucket_fun =
  "hist" <> parens fields
  where
    fields =
      pretty w
        <> comma
        </> ppTuple' (map pretty arrs)
        <> comma
        </> PP.braces (mconcat (intersperse (comma <> PP.line) (map ppOp ops)))
        <> comma
        </> pretty bucket_fun
    -- One operator: destination shape, replication factor, braced
    -- destination arrays, neutral-element tuple, and operator lambda.
    ppOp (HistOp dest_w rf dests nes op) =
      pretty dest_w
        <> comma
        <+> pretty rf
        <> comma
        <+> PP.braces (commasep (map pretty dests))
        <> comma
        </> ppTuple' (map pretty nes)
        <> comma
        </> pretty op