{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Definition of /Second-Order Array Combinators/ (SOACs), which are
-- the main form of parallelism in the early stages of the compiler.
module Futhark.IR.SOACS.SOAC
  ( SOAC (..),
    ScremaForm (..),
    HistOp (..),
    Scan (..),
    scanResults,
    singleScan,
    Reduce (..),
    redResults,
    singleReduce,

    -- * Utility
    scremaType,
    soacType,
    typeCheckSOAC,
    mkIdentityLambda,
    isIdentityLambda,
    nilFn,
    scanomapSOAC,
    redomapSOAC,
    scanSOAC,
    reduceSOAC,
    mapSOAC,
    isScanomapSOAC,
    isRedomapSOAC,
    isScanSOAC,
    isReduceSOAC,
    isMapSOAC,
    scremaLambda,
    ppScrema,
    ppHist,
    ppStream,
    ppScatter,
    groupScatterResults,
    groupScatterResults',
    splitScatterResults,

    -- * Generic traversal
    SOACMapper (..),
    identitySOACMapper,
    mapSOACM,
    traverseSOACStms,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.Alias qualified as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, CanBeAliased (..))
import Futhark.IR.Prop.Aliases
import Futhark.IR.TypeCheck qualified as TC
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty (Doc, align, comma, commasep, docText, parens, ppTuple', pretty, (<+>), (</>))
import Futhark.Util.Pretty qualified as PP
import Prelude hiding (id, (.))

-- | A second-order array combinator (SOAC).
data SOAC rep
  = Stream SubExp [VName] [SubExp] (Lambda rep)
  | -- | @Scatter <length> <inputs> <lambda> <outputs>@
    --
    -- Scatter maps values from a set of input arrays to indices and values of a
    -- set of output arrays. It is able to write multiple values to multiple
    -- outputs each of which may have multiple dimensions.
    --
    -- <inputs> is a list of input arrays, all having size <length>, elements of
    -- which are applied to the <lambda> function. For instance, if there are
    -- two arrays, <lambda> will get two values as input, one from each array.
    --
    -- <outputs> specifies the result of the <lambda> and which arrays to write
    -- to. Each element of the list consists of a <VName> specifying which array
    -- to scatter to, a <Shape> describing the shape of that array, and an <Int>
    -- describing how many elements should be written to that array for each
    -- invocation of the <lambda>.
    --
    -- <lambda> is a function that takes inputs from <inputs> and returns values
    -- according to the output-specification in <outputs>. It returns values in
    -- the following manner:
    --
    --     [index_0, index_1, ..., index_n, value_0, value_1, ..., value_m]
    --
    -- For each output in <outputs>, <lambda> returns <i> * <j> index values and
    -- <j> output values, where <i> is the number of dimensions (rank) of the
    -- given output, and <j> is the number of output values written to the given
    -- output.
    --
    -- For example, given the following output specification:
    --
    --     [([x1, y1, z1], 2, arr1), ([x2, y2], 1, arr2)]
    --
    -- <lambda> will produce 6 (3 * 2) index values and 2 output values for
    -- <arr1>, and 2 (2 * 1) index values and 1 output value for
    -- <arr2>. Additionally, the results are grouped, so the first 6 index values
    -- will correspond to the first two output values, and so on. For this
    -- example, <lambda> should return a total of 11 values, 8 index values and
    -- 3 output values.  See also 'splitScatterResults'.
    Scatter SubExp [VName] (Lambda rep) [(Shape, Int, VName)]
  | -- | @Hist <length> <input arrays> <dest-arrays-and-ops> <bucket fun>@
    --
    -- The final lambda produces indexes and values for the 'HistOp's.
    Hist SubExp [VName] [HistOp rep] (Lambda rep)
  | -- FIXME: this should not be here
    JVP (Lambda rep) [SubExp] [SubExp]
  | -- FIXME: this should not be here
    VJP (Lambda rep) [SubExp] [SubExp]
  | -- | A combination of scan, reduction, and map.  The first
    -- t'SubExp' is the size of the input arrays.
    Screma SubExp [VName] (ScremaForm rep)
  deriving (Eq, Ord, Show)

-- | Information about computing a single histogram.
data HistOp rep = HistOp
  { histShape :: Shape,
    -- | Race factor @RF@ means that only @1/RF@
    -- bins are used.
    histRaceFactor :: SubExp,
    histDest :: [VName],
    histNeutral :: [SubExp],
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | The essential parts of a 'Screma' factored out (everything
-- except the input arrays).
data ScremaForm rep
  = ScremaForm
      [Scan rep]
      -- ^ The scans.
      [Reduce rep]
      -- ^ The reductions.
      (Lambda rep)
      -- ^ The map lambda, applied to the elements of the input arrays.
  deriving (Eq, Ord, Show)

-- Fuse several binary-operator lambdas into one lambda.
--
-- For a lambda with @n@ return values, its first @n@ parameters are
-- treated as the "x" operands and the remaining parameters as the "y"
-- operands (NOTE: presumably each operator has exactly @2*n@
-- parameters; this is not checked here).  The combined lambda takes
-- all "x" parameters followed by all "y" parameters.  Its body is the
-- concatenation of all bodies' statements and results, and its return
-- type is the concatenation of all return types.
singleBinOp :: Buildable rep => [Lambda rep] -> Lambda rep
singleBinOp lams =
  Lambda
    { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
      lambdaReturnType = concatMap lambdaReturnType lams,
      lambdaBody =
        mkBody
          (mconcat (map (bodyStms . lambdaBody) lams))
          (concatMap (bodyResult . lambdaBody) lams)
    }
  where
    -- The first n parameters, where n is the number of return values.
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    -- All parameters after the first n.
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)

-- | How to compute a single scan result.
data Scan rep = Scan
  { scanLambda :: Lambda rep,
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many scan results are produced by these 'Scan's?  This is the
-- total number of neutral elements across all of them.
scanResults :: [Scan rep] -> Int
scanResults = sum . map (length . scanNeutral)

-- | Combine multiple scan operators to a single operator.  The
-- neutral elements are concatenated in the same order as the
-- operators.
singleScan :: Buildable rep => [Scan rep] -> Scan rep
singleScan scans =
  let scan_nes = concatMap scanNeutral scans
      scan_lam = singleBinOp $ map scanLambda scans
   in Scan scan_lam scan_nes

-- | How to compute a single reduction result.
data Reduce rep = Reduce
  { redComm :: Commutativity,
    redLambda :: Lambda rep,
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'Reduce's?  This
-- is the total number of neutral elements across all of them.
redResults :: [Reduce rep] -> Int
redResults = sum . map (length . redNeutral)

-- | Combine multiple reduction operators to a single operator.  The
-- neutral elements are concatenated in operator order, and the
-- commutativity of the result is the 'mconcat' of the individual
-- commutativities.
singleReduce :: Buildable rep => [Reduce rep] -> Reduce rep
singleReduce reds =
  let red_nes = concatMap redNeutral reds
      red_lam = singleBinOp $ map redLambda reds
   in Reduce (mconcat (map redComm reds)) red_lam red_nes

-- | The types produced by a single 'Screma', given the size of the
-- input array.
--
-- The result lists, in order:
-- * the scan results, each made into an array with outer size @w@;
-- * the reduction results, unchanged;
-- * the remaining map-lambda results, each made into an array with
--   outer size @w@.
scremaType :: SubExp -> ScremaForm rep -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps
  where
    scan_tps =
      map (`arrayOfRow` w) $
        concatMap (lambdaReturnType . scanLambda) scans
    red_tps = concatMap (lambdaReturnType . redLambda) reds
    -- The leading map-lambda results (one per scan/reduction result)
    -- are skipped; only what follows them is a plain map result.
    map_tps = drop (length scan_tps + length red_tps) $ lambdaReturnType map_lam

-- | Construct a lambda that takes parameters of the given types and
-- simply returns them unchanged.  Each parameter gets a fresh name
-- based on @"x"@, and the body contains no statements.
mkIdentityLambda ::
  (Buildable rep, MonadFreshNames m) =>
  [Type] ->
  m (Lambda rep)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  pure
    Lambda
      { lambdaParams = params,
        lambdaBody = mkBody mempty $ varsRes $ map paramName params,
        lambdaReturnType = ts
      }

-- | Is the given lambda an identity lambda?
--
-- This holds when the body's results are exactly the lambda's
-- parameters, in order.  Only the result 'SubExp's are compared: the
-- statements in the body are not inspected.
isIdentityLambda :: Lambda rep -> Bool
isIdentityLambda lam =
  map resSubExp (bodyResult (lambdaBody lam))
    == map (Var . paramName) (lambdaParams lam)

-- | A lambda with no parameters that returns no values.  It has an
-- empty body and an empty return type.
nilFn :: Buildable rep => Lambda rep
nilFn = Lambda mempty (mkBody mempty mempty) mempty

-- | Construct a Screma with possibly multiple scans, and
-- the given map function.  The Screma has no reductions.
scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC scans = ScremaForm scans []

-- | Construct a Screma with possibly multiple reductions, and
-- the given map function.  The Screma has no scans.
redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC = ScremaForm []

-- | Construct a Screma with possibly multiple scans, and identity map
-- function.  The identity lambda's types are the concatenated return
-- types of the scan operators.
scanSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Scan rep] ->
  m (ScremaForm rep)
scanSOAC scans = scanomapSOAC scans <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . scanLambda) scans

-- | Construct a Screma with possibly multiple reductions, and
-- identity map function.  The identity lambda's types are the
-- concatenated return types of the reduction operators.
reduceSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Reduce rep] ->
  m (ScremaForm rep)
reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . redLambda) reds

-- | Construct a Screma corresponding to a map: no scans and no
-- reductions, only the given lambda.
mapSOAC :: Lambda rep -> ScremaForm rep
mapSOAC = ScremaForm [] []

-- | Does this Screma correspond to a scan-map composition?  Succeeds
-- only if there are no reductions and at least one scan.
isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC (ScremaForm scans reds map_lam) = do
  guard $ null reds
  guard $ not $ null scans
  pure (scans, map_lam)

-- | Does this Screma correspond to a pure scan?  That is, a
-- scan-map composition (see 'isScanomapSOAC') whose map lambda is an
-- identity lambda (see 'isIdentityLambda').
isScanSOAC :: ScremaForm rep -> Maybe [Scan rep]
isScanSOAC form = do
  (scans, map_lam) <- isScanomapSOAC form
  guard $ isIdentityLambda map_lam
  pure scans

-- | Does this Screma correspond to a reduce-map composition?
-- Succeeds only if there are no scans and at least one reduction.
isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC (ScremaForm scans reds map_lam) = do
  guard $ null scans
  guard $ not $ null reds
  pure (reds, map_lam)

-- | Does this Screma correspond to a pure reduce?  That is, a
-- reduce-map composition (see 'isRedomapSOAC') whose map lambda is an
-- identity lambda (see 'isIdentityLambda').
isReduceSOAC :: ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC form = do
  (reds, map_lam) <- isRedomapSOAC form
  guard $ isIdentityLambda map_lam
  pure reds

-- | Does this Screma correspond to a simple map, without any
-- reduction or scan results?  If so, return the map lambda.
isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC (ScremaForm scans reds map_lam) = do
  guard $ null scans
  guard $ null reds
  pure map_lam

-- | Return the "main" lambda of the Screma.  For a map, this is
-- equivalent to 'isMapSOAC'.  Note that the meaning of the return
-- value of this lambda depends crucially on exactly which Screma this
-- is.  The parameters will correspond exactly to elements of the
-- input arrays, however.
scremaLambda :: ScremaForm rep -> Lambda rep
scremaLambda (ScremaForm _ _ map_lam) = map_lam

-- | @groupScatterResults <output specification> <results>@
--
-- Groups the index values and result values of <results> according to the
-- <output specification>.
--
-- This function is used for extracting and grouping the results of a
-- scatter. In the SOAC representation, the lambda inside a 'Scatter' returns
-- all indices and values as one big list. This function groups each value with
-- its corresponding indices (as determined by the t'Shape' of the output array).
--
-- The elements of the resulting list correspond to the shape and name of the
-- output parameters, in addition to a list of values written to that output
-- parameter, along with the array indices marking where to write them to.
--
-- See 'Scatter' for more information.
groupScatterResults ::
  [(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  -- Pair every value with its indices, then gather the @n@ pairs that
  -- belong to each output and attach that output's shape and array.
  groupScatterResults' output_spec results
    & chunks ns
    & zip3 shapes arrays
  where
    (shapes, ns, arrays) = unzip3 output_spec

-- | @groupScatterResults' <output specification> <results>@
--
-- Groups the index values and result values of <results> according to the
-- output specification. This is the simpler version of @groupScatterResults@,
-- which doesn't return any information about shapes or output arrays.
--
-- See 'groupScatterResults' for more information.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  zip (chunks index_counts indices) values
  where
    (indices, values) = splitScatterResults output_spec results
    -- One entry per written value: the number of indices it needs, which
    -- is the rank of the output it is written to.
    index_counts =
      concatMap (\(shape, n, _) -> replicate n (length shape)) output_spec

-- | @splitScatterResults <output specification> <results>@
--
-- Splits the results array into indices and values according to the output
-- specification.  All indices come first, followed by all values.
--
-- See 'groupScatterResults' for more information.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec results =
  splitAt num_indices results
  where
    -- Each of the @n@ values written to an output needs one index per
    -- dimension of that output's shape.
    num_indices = sum [n * length shape | (shape, n, _) <- output_spec]

-- | Like 'Mapper', but just for 'SOAC's.  Each field is the action that
-- 'mapSOACM' applies to one kind of immediate child of a SOAC.
data SOACMapper frep trep m = SOACMapper
  { -- | Applied to each immediate 'SubExp' child.
    mapOnSOACSubExp :: SubExp -> m SubExp,
    -- | Applied to each immediate 'Lambda' child; may change its
    -- representation from @frep@ to @trep@.
    mapOnSOACLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied to each immediate 'VName' child.
    mapOnSOACVName :: VName -> m VName
  }

-- | A mapper that simply returns the SOAC verbatim: every component
-- action is 'pure', in field order subexpression, lambda, name.
identitySOACMapper :: forall rep m. Monad m => SOACMapper rep rep m
identitySOACMapper = SOACMapper pure pure pure

-- | Map a monadic action across the immediate children of a
-- SOAC.  The mapping does not descend recursively into subexpressions
-- and is done left-to-right, in the order the children appear in the
-- constructor.
mapSOACM ::
  Monad m =>
  SOACMapper frep trep m ->
  SOAC frep ->
  m (SOAC trep)
mapSOACM tv soac =
  case soac of
    JVP lam args vec ->
      JVP <$> onLambda lam <*> onSubExps args <*> onSubExps vec
    VJP lam args vec ->
      VJP <$> onLambda lam <*> onSubExps args <*> onSubExps vec
    Stream size arrs accs lam ->
      Stream
        <$> onSubExp size
        <*> onVNames arrs
        <*> onSubExps accs
        <*> onLambda lam
    Scatter w ivs lam dests ->
      Scatter
        <$> onSubExp w
        <*> onVNames ivs
        <*> onLambda lam
        <*> traverse onDest dests
    Hist w arrs ops bucket_fun ->
      Hist
        <$> onSubExp w
        <*> onVNames arrs
        <*> traverse onHistOp ops
        <*> onLambda bucket_fun
    Screma w arrs (ScremaForm scans reds map_lam) ->
      Screma
        <$> onSubExp w
        <*> onVNames arrs
        <*> ( ScremaForm
                <$> traverse onScan scans
                <*> traverse onReduce reds
                <*> onLambda map_lam
            )
  where
    onSubExp = mapOnSOACSubExp tv
    onLambda = mapOnSOACLambda tv
    onVName = mapOnSOACVName tv
    onSubExps ses = traverse onSubExp ses
    onVNames vs = traverse onVName vs

    -- A scatter output: its shape dimensions and array name are mapped;
    -- the number of values written to it is left untouched.
    onDest (dest_shape, n, arr) =
      (\dest_shape' arr' -> (dest_shape', n, arr'))
        <$> traverse onSubExp dest_shape
        <*> onVName arr

    onHistOp (HistOp op_shape rf op_arrs nes op) =
      HistOp
        <$> traverse onSubExp op_shape
        <*> onSubExp rf
        <*> onVNames op_arrs
        <*> onSubExps nes
        <*> onLambda op

    onScan (Scan scan_lam nes) =
      Scan <$> onLambda scan_lam <*> onSubExps nes

    onReduce (Reduce comm red_lam nes) =
      Reduce comm <$> onLambda red_lam <*> onSubExps nes

-- | A helper for defining 'TraverseOpStms'.  The given statement
-- traverser is applied (via 'traverseLambdaStms') to every immediate
-- lambda of the SOAC; all other children are returned unchanged.
traverseSOACStms :: Monad m => OpStmsTraverser m (SOAC rep) rep
traverseSOACStms f =
  mapSOACM identitySOACMapper {mapOnSOACLambda = traverseLambdaStms f}

-- | The free variables of a scan: those of its operator, then those of
-- its neutral elements.
instance ASTRep rep => FreeIn (Scan rep) where
  freeIn' (Scan scan_op scan_nes) =
    freeIn' scan_op <> freeIn' scan_nes

-- | The free variables of a reduction: those of its operator, then those
-- of its neutral elements.  The commutativity flag carries no names.
instance ASTRep rep => FreeIn (Reduce rep) where
  freeIn' (Reduce _comm red_op red_nes) =
    freeIn' red_op <> freeIn' red_nes

-- | The free variables of a Screma form: those of its scans, its
-- reductions, and its map function.
instance ASTRep rep => FreeIn (ScremaForm rep) where
  freeIn' (ScremaForm form_scans form_reds form_lam) =
    freeIn' form_scans
      <> freeIn' form_reds
      <> freeIn' form_lam

-- | The free variables of a histogram operator: those of its shape,
-- race factor, destination arrays, neutral elements, and operator.
instance ASTRep rep => FreeIn (HistOp rep) where
  freeIn' (HistOp op_shape race_factor op_dests op_nes op_lam) =
    freeIn' op_shape
      <> freeIn' race_factor
      <> freeIn' op_dests
      <> freeIn' op_nes
      <> freeIn' op_lam

-- | The free variables of a SOAC are accumulated by walking its
-- immediate children with 'mapSOACM' and collecting 'freeIn'' of each.
instance ASTRep rep => FreeIn (SOAC rep) where
  freeIn' soac = execState (mapSOACM collect soac) mempty
    where
      -- Record the free variables of a child and return it unchanged.
      -- The explicit signature keeps it polymorphic over the child type.
      note :: FreeIn a => a -> State FV a
      note x = do
        modify (<> freeIn' x)
        pure x
      collect =
        SOACMapper
          { mapOnSOACSubExp = note,
            mapOnSOACLambda = note,
            mapOnSOACVName = note
          }

-- | Substitution applies 'substituteNames' to every immediate
-- subexpression, lambda and name of the SOAC.
instance ASTRep rep => Substitute (SOAC rep) where
  substituteNames subst soac = runIdentity $ mapSOACM substituter soac
    where
      -- The explicit signature keeps this polymorphic, so the same
      -- helper serves all three child types.
      sub :: Substitute a => a -> Identity a
      sub = pure . substituteNames subst
      substituter =
        SOACMapper
          { mapOnSOACSubExp = sub,
            mapOnSOACLambda = sub,
            mapOnSOACVName = sub
          }

-- | Renaming applies 'rename' to every immediate subexpression, lambda
-- and name of the SOAC.
instance ASTRep rep => Rename (SOAC rep) where
  rename = mapSOACM renamer
    where
      renamer =
        SOACMapper
          { mapOnSOACSubExp = rename,
            mapOnSOACLambda = rename,
            mapOnSOACVName = rename
          }

-- | The type of a SOAC.
soacType :: Typed (LParamInfo rep) => SOAC rep -> [Type]
soacType (JVP lam _ _) =
  -- The lambda's result types, listed twice.
  lambdaReturnType lam ++ lambdaReturnType lam
soacType (VJP lam _ _) =
  -- The lambda's result types, followed by the types of its parameters.
  lambdaReturnType lam ++ map paramType (lambdaParams lam)
soacType (Stream outersize _ accs lam) =
  map (substNamesInType substs) rtp
  where
    Lambda params _ rtp = lam
    -- The first parameter is the chunk size and the following ones bind
    -- the accumulators; in the result types they are replaced by the
    -- outer size and the initial accumulator values respectively.
    nms = map paramName $ take (1 + length accs) params
    substs = M.fromList $ zip nms (outersize : accs)
soacType (Scatter _w _ivs lam dests) =
  -- Each output gets the type of the values written to it, lifted to the
  -- output's shape.  The lambda returns all indices before all values,
  -- and an output may receive several values, so the value types must be
  -- grouped per output (via 'groupScatterResults') rather than zipped
  -- positionally against the outputs; positional zipping pairs outputs
  -- with the wrong value type as soon as an earlier output has n > 1.
  map outputType $ groupScatterResults dests $ lambdaReturnType lam
  where
    outputType (shape, _, (_, val_t) : _) = val_t `arrayOfShape` shape
    outputType (_, _, []) =
      error "soacType: Scatter output is not written any values."
soacType (Hist _ _ ops _bucket_fun) =
  concatMap opResultTypes ops
  where
    -- One result per operator result, lifted to the histogram's shape.
    opResultTypes op =
      map (`arrayOfShape` histShape op) $ lambdaReturnType $ histOp op
soacType (Screma w _arrs form) =
  scremaType w form

-- | The (existential) type of a SOAC is 'soacType' with every shape
-- marked static.
instance ASTRep rep => TypedOp (SOAC rep) where
  opType soac = pure $ staticShapes $ soacType soac

-- | Alias information for SOACs.  Results never alias anything (each
-- gets an empty alias set); consumption is described per constructor.
instance Aliased rep => AliasedOp (SOAC rep) where
  opAliases soac = [mempty | _ <- soacType soac]

  consumedInOp JVP {} = mempty
  consumedInOp VJP {} = mempty
  -- Only map functions can consume anything.  The operands to scan
  -- and reduce functions are always considered "fresh".
  consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
    mapNames toArray $ consumedByLambda map_lam
    where
      -- A consumed lambda parameter stands for the input array it is
      -- bound to; other names are reported as they are.
      toArray v = fromMaybe v $ lookup v param_arrays
      param_arrays = zip (map paramName $ lambdaParams map_lam) arrs
  consumedInOp (Stream _ arrs accs lam) =
    namesFromList . subExpVars . map toInput . namesToList $
      consumedByLambda lam
    where
      toInput v = fromMaybe (Var v) $ lookup v param_inputs
      -- Drop the chunk parameter, which cannot alias anything.
      param_inputs =
        zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
  -- A scatter consumes its destination arrays.
  consumedInOp (Scatter _ _ _ dests) =
    namesFromList [arr | (_, _, arr) <- dests]
  -- A histogram consumes the destination arrays of all its operators.
  consumedInOp (Hist _ _ ops _) =
    namesFromList $ concatMap histDest ops

-- | Apply a function to the operator lambda of a 'HistOp', changing its
-- representation; all other fields are kept unchanged.
mapHistOp ::
  (Lambda frep -> Lambda trep) ->
  HistOp frep ->
  HistOp trep
mapHistOp f op = op {histOp = f (histOp op)}

-- | Adding aliases to a SOAC runs 'Alias.analyseLambda' on every lambda
-- it contains; all other components are kept unchanged.
instance CanBeAliased SOAC where
  addOpAliases aliases soac =
    case soac of
      JVP lam args vec -> JVP (analyse lam) args vec
      VJP lam args vec -> VJP (analyse lam) args vec
      Stream size arrs accs lam -> Stream size arrs accs (analyse lam)
      Scatter len arrs lam dests -> Scatter len arrs (analyse lam) dests
      Hist w arrs ops bucket_fun ->
        Hist w arrs (map (mapHistOp analyse) ops) (analyse bucket_fun)
      Screma w arrs (ScremaForm scans reds map_lam) ->
        Screma w arrs $
          ScremaForm
            (map onScan scans)
            (map onRed reds)
            (analyse map_lam)
    where
      analyse = Alias.analyseLambda aliases
      onScan scan = scan {scanLambda = analyse $ scanLambda scan}
      onRed red = red {redLambda = analyse $ redLambda red}

instance ASTRep rep => IsOp (SOAC rep) where
  -- Conservatively, no SOAC is ever reported as safe ('safeOp') or as
  -- cheap ('cheapOp'), regardless of its contents.
  safeOp = const False
  cheapOp = const False

-- | Apply a variable-to-'SubExp' substitution to the dimensions of an
-- array type.  Primitive, accumulator and memory types are returned
-- as-is (in particular, the shape of an 'Acc' is not rewritten).
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t =
  case t of
    Array pt (Shape dims) u ->
      Array pt (Shape (map (substNamesInSubExp subs) dims)) u
    Prim {} -> t
    Acc {} -> t
    Mem {} -> t

-- | Replace a variable by its image under the substitution, if it has
-- one.  Constants and unmapped variables are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Constant {} -> se
    Var v -> fromMaybe se $ M.lookup v subs

instance CanBeWise SOAC where
  -- Wisdom is attached to every lambda inside the SOAC via
  -- 'informLambda'; subexpressions and variable names are passed
  -- through untouched.  The traversal runs in 'Identity', so it is
  -- pure.
  addOpWisdom soac =
    runIdentity $
      mapSOACM (SOACMapper Identity (Identity . informLambda) Identity) soac

instance RepTypes rep => ST.IndexOp (SOAC rep) where
  -- Symbolically index result number @k@ of a SOAC at the (single)
  -- index @i@.  Only map outputs of a 'Screma' are handled.  The
  -- element at @i@ is the corresponding map-lambda result, with each
  -- lambda parameter bound to the matching input array indexed at @i@.
  --
  -- @k@ is the position in the SOAC's result pattern, which lists scan
  -- results, then reduction results, then map results, in the same
  -- order as the map lambda's body result.  Scan/reduction results are
  -- not elementwise images of the lambda, so they are never indexed.
  indexOp vtable k soac [i] = do
    (lam, se, arr_params, arrs) <- lambdaAndSubExp soac
    let arr_indexes = M.fromList $ catMaybes $ zipWith arrIndex arr_params arrs
        arr_indexes' = foldl expandPrimExpTable arr_indexes $ bodyStms $ lambdaBody lam
    case se of
      SubExpRes _ (Var v) -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes'
      _ -> Nothing
    where
      lambdaAndSubExp (Screma _ arrs (ScremaForm scans reds map_lam)) =
        nthMapOut (scanResults scans + redResults reds) map_lam arrs
      lambdaAndSubExp _ =
        Nothing

      -- The first @num_accs@ results are scan/reduce results; map
      -- outputs follow.  Since pattern position and body-result
      -- position coincide, result @k@ is looked up directly.  The map
      -- lambda has exactly one parameter per input array (no
      -- accumulator parameters), so no parameters are dropped.
      nthMapOut num_accs lam arrs = do
        guard $ k >= num_accs
        se <- maybeNth k $ bodyResult $ lambdaBody lam
        pure (lam, se, lambdaParams lam, arrs)

      -- Bind a lambda parameter to the symbolic value of its input
      -- array at index @i@, if the symbol table can provide one.
      arrIndex p arr = do
        ST.Indexed cs pe <- ST.index' arr [i] vtable
        pure (paramName p, (pe, cs))

      -- Extend the table with single-name statements whose expression
      -- can be turned into a 'PrimExp' over already-known names, provided
      -- all their certificates are in scope.
      expandPrimExpTable table stm
        | [v] <- patNames $ stmPat stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
          all (`ST.elem` vtable) (unCerts $ stmCerts stm) =
            M.insert v (pe, stmCerts stm <> cs) table
        | otherwise =
            table

      -- Translate a variable to a 'PrimExp', accumulating the
      -- certificates of any table entries used.
      asPrimExp table v
        | Just (e, cs) <- M.lookup v table = tell cs >> pure e
        | Just (Prim pt) <- ST.lookupType v vtable =
            pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing

-- | Type-check a SOAC.  Besides checking every lambda against its
-- arguments and neutral elements, this marks the destination arrays
-- of 'Scatter' and 'Hist' as consumed.
typeCheckSOAC :: TC.Checkable rep => SOAC (Aliases rep) -> TC.TypeM rep ()
typeCheckSOAC (VJP lam args vec) = do
  arg_infos <- traverse TC.checkArg args
  TC.checkLambda lam (map TC.noArgAliases arg_infos)
  seed_ts <- traverse TC.checkSubExp vec
  -- The seed vector must match the lambda's results.
  let ret_ts = lambdaReturnType lam
  when (seed_ts /= ret_ts) $
    TC.bad $
      TC.TypeError $
        docText $
          "Return type"
            </> PP.indent 2 (pretty ret_ts)
            </> "does not match type of seed vector"
            </> PP.indent 2 (pretty seed_ts)
typeCheckSOAC (JVP lam args vec) = do
  arg_infos <- traverse TC.checkArg args
  TC.checkLambda lam (map TC.noArgAliases arg_infos)
  seed_ts <- traverse TC.checkSubExp vec
  -- The seed vector must match the lambda's parameters.
  let param_ts = map TC.argType arg_infos
  when (seed_ts /= param_ts) $
    TC.bad $
      TC.TypeError $
        docText $
          "Parameter type"
            </> PP.indent 2 (pretty param_ts)
            </> "does not match type of seed vector"
            </> PP.indent 2 (pretty seed_ts)
typeCheckSOAC (Stream size arrexps accexps lam) = do
  TC.require [Prim int64] size
  acc_args <- traverse TC.checkArg accexps
  arr_ts <- traverse lookupType arrexps
  _ <- TC.checkSOACArrayArgs size arrexps
  -- The first lambda parameter is the chunk size.
  chunk_param <-
    case lambdaParams lam of
      [] -> TC.bad $ TC.TypeError "Stream lambda without parameters."
      p : _ -> pure p
  let chunk_size = Var (paramName chunk_param)
      chunk_arr_ts = map (`setOuterSize` chunk_size) arr_ts
      acc_ret_ts = take (length accexps) (lambdaReturnType lam)
  when (map TC.argType acc_args /= acc_ret_ts) $
    TC.bad $
      TC.TypeError "Stream with inconsistent accumulator type in lambda."
  -- The chunked arrays are passed as fresh, alias-free arguments, so
  -- the lambda's data flow is computed without aliasing the streamed
  -- arrays; uses of their aliases inside the lambda can then be
  -- detected later.
  TC.checkLambda lam $
    (Prim int64, mempty) : acc_args ++ map (\t -> (t, mempty)) chunk_arr_ts
typeCheckSOAC (Scatter w arrs lam as) = do
  -- Requirements:
  --
  --   0. The lambda returns all index values first, then all values
  --      to write.
  --
  --   1. The total number of returned values must equal the number of
  --      index values plus the number of value outputs.
  --
  --   2. Every index value has type i64.
  --
  --   3. Every destination array in @as@ has the type of its values,
  --      lifted to the destination shape.
  --
  --   4. Every destination array is consumed.  This is a requirement
  --      rather than a check: it stops the destination from being
  --      hoisted out of a loop, where it could not be consumed.
  --
  --   5. The input arrays @arrs@ match the lambda parameters.
  TC.require [Prim int64] w

  -- 0.
  let (as_ws, as_ns, _) = unzip3 as
      num_indexes = sum (zipWith (\n ws -> n * length ws) as_ns as_ws)
      rts = lambdaReturnType lam
      (idx_ts, val_ts) = splitAt num_indexes rts

  -- 1.
  unless (length rts == sum as_ns + num_indexes) $
    TC.bad $
      TC.TypeError "Scatter: number of index types, value types and array outputs do not match."

  -- 2.
  forM_ idx_ts $ \idx_t ->
    unless (idx_t == Prim int64) $
      TC.bad $
        TC.TypeError "Scatter: Index return type must be i64."

  forM_ (zip (chunks as_ns val_ts) as) $ \(dest_val_ts, (dest_shape, _, dest)) -> do
    -- Every destination dimension must be an i64.
    mapM_ (TC.require [Prim int64]) dest_shape
    -- 3.
    forM_ dest_val_ts $ \val_t ->
      TC.requireI [arrayOfShape val_t dest_shape] dest
    -- 4.
    TC.consume =<< TC.lookupAliases dest

  -- 5.
  arr_args <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda lam arr_args
typeCheckSOAC (Hist w arrs ops bucket_fun) = do
  TC.require [Prim int64] w

  -- Each operator: sizes are i64, the operator is applied to two
  -- copies of its neutral elements, and its destinations have the
  -- neutral element types lifted to the destination shape.
  forM_ ops $ \(HistOp dest_shape rf dests nes op) -> do
    ne_args <- traverse TC.checkArg nes
    mapM_ (TC.require [Prim int64]) dest_shape
    TC.require [Prim int64] rf

    TC.checkLambda op $ map TC.noArgAliases (ne_args ++ ne_args)
    let ne_ts = map TC.argType ne_args
        op_ts = lambdaReturnType op
    when (ne_ts /= op_ts) $
      TC.bad $
        TC.TypeError $
          "Operator has return type "
            <> prettyTuple op_ts
            <> " but neutral element has type "
            <> prettyTuple ne_ts

    forM_ (zip ne_ts dests) $ \(ne_t, dest) -> do
      TC.requireI [ne_t `arrayOfShape` dest_shape] dest
      TC.consume =<< TC.lookupAliases dest

  -- The bucket function is applied to the input arrays.
  img_args <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda bucket_fun img_args

  -- It must return one i64 per destination dimension of each
  -- operator, followed by the values to combine.
  all_ne_ts <- concat <$> traverse (traverse subExpType . histNeutral) ops
  let idx_ts = concatMap (\h -> replicate (shapeRank (histShape h)) (Prim int64)) ops
      expected_ts = idx_ts ++ all_ne_ts
      bucket_ts = lambdaReturnType bucket_fun
  when (expected_ts /= bucket_ts) $
    TC.bad $
      TC.TypeError $
        "Bucket function has return type "
          <> prettyTuple bucket_ts
          <> " but should have type "
          <> prettyTuple expected_ts
typeCheckSOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
  TC.require [Prim int64] w
  arr_args <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda map_lam arr_args

  scan_ne_args <-
    concat
      <$> forM
        scans
        (\(Scan scan_lam scan_nes) -> checkOperator "Scan function returns type " scan_lam scan_nes)
  red_ne_args <-
    concat
      <$> forM
        reds
        (\(Reduce _ red_lam red_nes) -> checkOperator "Reduce function returns type " red_lam red_nes)

  -- The leading map results feed the scans and reductions, so they
  -- must have the neutral element types.
  let map_ts = lambdaReturnType map_lam
      all_ne_args = scan_ne_args ++ red_ne_args
  when (take (length all_ne_args) map_ts /= map TC.argType all_ne_args) $
    TC.bad $
      TC.TypeError $
        "Map function return type "
          <> prettyTuple map_ts
          <> " wrong for given scan and reduction functions."
  where
    -- Check a scan or reduction operator against two copies of its
    -- neutral elements and return the checked neutral elements.
    -- @what@ is the start of the error message.
    checkOperator what op nes = do
      ne_args <- traverse TC.checkArg nes
      let ne_ts = map TC.argType ne_args
      TC.checkLambda op $ map TC.noArgAliases (ne_args ++ ne_args)
      when (ne_ts /= lambdaReturnType op) $
        TC.bad $
          TC.TypeError $
            what
              <> prettyTuple (lambdaReturnType op)
              <> " but neutral element has type "
              <> prettyTuple ne_ts
      pure ne_args

instance RephraseOp SOAC where
  -- Every lambda embedded in the SOAC (including the operators of
  -- histograms, scans and reductions) is rephrased with the given
  -- rephraser; all non-lambda components are carried over unchanged.
  -- Lambdas are visited in the order they appear in the constructor.
  rephraseInOp r soac =
    case soac of
      VJP lam args vec ->
        VJP <$> onLam lam <*> pure args <*> pure vec
      JVP lam args vec ->
        JVP <$> onLam lam <*> pure args <*> pure vec
      Stream w arrs acc lam ->
        Stream w arrs acc <$> onLam lam
      Scatter w arrs lam dests ->
        Scatter w arrs <$> onLam lam <*> pure dests
      Hist w arrs ops lam ->
        Hist w arrs <$> traverse onHistOp ops <*> onLam lam
      Screma w arrs (ScremaForm scans reds lam) ->
        Screma w arrs
          <$> ( ScremaForm
                  <$> traverse onScan scans
                  <*> traverse onRed reds
                  <*> onLam lam
              )
    where
      -- Rephrase a single lambda.
      onLam lam = rephraseLambda r lam
      -- Rephrase the operator of a histogram, keeping its other fields.
      onHistOp (HistOp dest_shape rf dests nes op) =
        HistOp dest_shape rf dests nes <$> onLam op
      -- Rephrase the operator of a scan, keeping its neutral elements.
      onScan (Scan op nes) =
        Scan <$> onLam op <*> pure nes
      -- Rephrase the operator of a reduction, keeping commutativity and
      -- neutral elements.
      onRed (Reduce comm op nes) =
        Reduce comm <$> onLam op <*> pure nes

instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where
  -- Each SOAC is recorded under a label naming its constructor, and the
  -- metrics of every lambda it contains are gathered inside that label,
  -- in the order the lambdas appear in the constructor.
  opMetrics soac =
    case soac of
      VJP lam _ _ ->
        inside "VJP" $ lambdaMetrics lam
      JVP lam _ _ ->
        inside "JVP" $ lambdaMetrics lam
      Stream _ _ _ lam ->
        inside "Stream" $ lambdaMetrics lam
      Scatter _ _ lam _ ->
        inside "Scatter" $ lambdaMetrics lam
      Hist _ _ ops bucket_fun ->
        inside "Hist" $
          mapM_ lambdaMetrics (map histOp ops ++ [bucket_fun])
      Screma _ _ (ScremaForm scans reds map_lam) ->
        inside "Screma" $
          mapM_ lambdaMetrics $
            map scanLambda scans ++ map redLambda reds ++ [map_lam]

instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
  -- VJP/JVP print their lambda followed by the braced argument and seed
  -- vectors.  A Screma with neither scans nor reductions prints as
  -- @map@, one without scans as @redomap@, one without reductions as
  -- @scanomap@; any other Screma is printed by 'ppScrema'.  The remaining
  -- SOACs delegate to their dedicated printers.
  pretty soac =
    case soac of
      VJP lam args vec -> ppAD "vjp" lam args vec
      JVP lam args vec -> ppAD "jvp" lam args vec
      Stream size arrs acc lam -> ppStream size arrs acc lam
      Scatter w arrs lam dests -> ppScatter w arrs lam dests
      Hist w arrs ops bucket_fun -> ppHist w arrs ops bucket_fun
      Screma w arrs (ScremaForm scans reds map_lam)
        | null scans,
          null reds ->
            "map"
              <> (parens . align)
                ( pretty w
                    <> comma
                    </> ppTuple' (map pretty arrs)
                    <> comma
                    </> pretty map_lam
                )
        | null scans ->
            ppSpecial "redomap" w arrs (map pretty reds) map_lam
        | null reds ->
            ppSpecial "scanomap" w arrs (map pretty scans) map_lam
      -- Reached only when the Screma has both scans and reductions.
      Screma w arrs form -> ppScrema w arrs form
    where
      -- Printer shared by the two automatic-differentiation operators.
      ppAD op_name lam args vec =
        op_name
          <> parens
            ( PP.align $
                pretty lam
                  <> comma
                  </> PP.braces (commasep $ map pretty args)
                  <> comma
                  </> PP.braces (commasep $ map pretty vec)
            )
      -- Printer shared by redomap and scanomap; the operator docs are
      -- separated by a comma and a line break inside braces.
      ppSpecial op_name w arrs op_docs map_lam =
        op_name
          <> (parens . align)
            ( pretty w
                <> comma
                </> ppTuple' (map pretty arrs)
                <> comma
                </> PP.braces (mconcat $ intersperse (comma <> PP.line) op_docs)
                <> comma
                </> pretty map_lam
            )

-- | Prettyprint the given Screma as
-- @screma(w, (arrs), {scans}, {reductions}, map_lam)@, where the scan
-- and reduction operators are each separated by a comma and a line
-- break.
ppScrema ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc ann
ppScrema w arrs (ScremaForm scans reds map_lam) =
  "screma" <> (parens . align) screma_args
  where
    screma_args =
      pretty w
        <> comma
        </> ppTuple' (map pretty arrs)
        <> comma
        </> braced_lines (map pretty scans)
        <> comma
        </> braced_lines (map pretty reds)
        <> comma
        </> pretty map_lam
    -- Braces around documents separated by a comma and a line break.
    braced_lines docs = PP.braces (mconcat $ intersperse (comma <> PP.line) docs)

-- | Prettyprint the given Stream as
-- @streamSeq(size, (arrs), (acc), lam)@.
ppStream ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> [SubExp] -> Lambda rep -> Doc ann
ppStream size arrs acc lam =
  "streamSeq" <> (parens . align) stream_args
  where
    stream_args =
      pretty size
        <> comma
        </> as_tuple arrs
        <> comma
        </> as_tuple acc
        <> comma
        </> pretty lam
    -- Print a list of printable things as a tuple.
    as_tuple xs = ppTuple' (map pretty xs)

-- | Prettyprint the given Scatter as
-- @scatter(w, (arrs), lam, dests)@, with the destination triples
-- separated by commas.
ppScatter ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> Lambda rep -> [(Shape, Int, VName)] -> Doc ann
ppScatter w arrs lam dests =
  "scatter" <> (parens . align) scatter_args
  where
    scatter_args =
      pretty w
        <> comma
        </> ppTuple' (map pretty arrs)
        <> comma
        </> pretty lam
        <> comma
        </> commasep (map pretty dests)

instance (PrettyRep rep) => Pretty (Scan rep) where
  -- A scan prints its operator lambda followed by its neutral elements
  -- in braces.
  pretty (Scan op nes) =
    pretty op <> comma </> PP.braces (commasep (map pretty nes))

-- | Prefix printed in front of a reduction operator: the word
-- @commutative@ followed by a space for commutative operators, and
-- nothing for non-commutative ones.
ppComm :: Commutativity -> Doc ann
ppComm comm =
  case comm of
    Commutative -> "commutative "
    Noncommutative -> mempty

instance (PrettyRep rep) => Pretty (Reduce rep) where
  -- A reduction prints its commutativity marker (see 'ppComm'), its
  -- operator lambda, and then its neutral elements in braces.
  pretty (Reduce comm op nes) =
    ppComm comm
      <> pretty op
      <> comma
      </> PP.braces (commasep (map pretty nes))

-- | Prettyprint the given histogram operation as
-- @hist(w, (arrs), {ops}, bucket_fun)@.  The histogram operators are
-- separated by a comma and a line break.  Each operator prints its
-- shape, its @rf@ sub-expression, its braced destination arrays, its
-- neutral elements as a tuple, and finally its operator lambda.
ppHist ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [inp] ->
  [HistOp rep] ->
  Lambda rep ->
  Doc ann
ppHist w arrs ops bucket_fun =
  -- 'align' matches the other SOAC printers (screma, streamSeq,
  -- scatter, map, ...), so arguments that wrap onto new lines line up
  -- under the opening parenthesis instead of at the enclosing indent.
  "hist"
    <> (parens . align)
      ( pretty w
          <> comma
          </> ppTuple' (map pretty arrs)
          <> comma
          </> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppOp ops)
          <> comma
          </> pretty bucket_fun
      )
  where
    ppOp (HistOp dest_w rf dests nes op) =
      pretty dest_w
        <> comma
        <+> pretty rf
        <> comma
        <+> PP.braces (commasep $ map pretty dests)
        <> comma
        </> ppTuple' (map pretty nes)
        <> comma
        </> pretty op