{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Definition of /Second-Order Array Combinators/ (SOACs), which are
-- the main form of parallelism in the early stages of the compiler.
module Futhark.IR.SOACS.SOAC
  ( SOAC (..),
    StreamOrd (..),
    StreamForm (..),
    ScremaForm (..),
    HistOp (..),
    Scan (..),
    scanResults,
    singleScan,
    Reduce (..),
    redResults,
    singleReduce,

    -- * Utility
    scremaType,
    soacType,
    typeCheckSOAC,
    mkIdentityLambda,
    isIdentityLambda,
    nilFn,
    scanomapSOAC,
    redomapSOAC,
    scanSOAC,
    reduceSOAC,
    mapSOAC,
    isScanomapSOAC,
    isRedomapSOAC,
    isScanSOAC,
    isReduceSOAC,
    isMapSOAC,
    ppScrema,
    ppHist,
    groupScatterResults,
    groupScatterResults',
    splitScatterResults,

    -- * Generic traversal
    SOACMapper (..),
    identitySOACMapper,
    mapSOACM,
  )
where

import Control.Category
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, removeLambdaAliases)
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty (Doc, Pretty, comma, commasep, parens, ppr, text, (<+>), (</>))
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))

-- | A second-order array combinator (SOAC).
data SOAC rep
  = Stream SubExp [VName] (StreamForm rep) [SubExp] (Lambda rep)
  | -- | @Scatter <length> <lambda> <inputs> <outputs>@
    --
    -- Scatter maps values from a set of input arrays to indices and values of a
    -- set of output arrays. It is able to write multiple values to multiple
    -- outputs each of which may have multiple dimensions.
    --
    -- <inputs> is a list of input arrays, all having size <length>, elements of
    -- which are applied to the <lambda> function. For instance, if there are
    -- two arrays, <lambda> will get two values as input, one from each array.
    --
    -- <outputs> specifies the result of the <lambda> and which arrays to write
    -- to. Each element of the list consists of a <VName> specifying which array
    -- to scatter to, a <Shape> describing the shape of that array, and an <Int>
    -- describing how many elements should be written to that array for each
    -- invocation of the <lambda>.
    --
    -- <lambda> is a function that takes inputs from <inputs> and returns values
    -- according to the output-specification in <outputs>. It returns values in
    -- the following manner:
    --
    --     [index_0, index_1, ..., index_n, value_0, value_1, ..., value_m]
    --
    -- For each output in <outputs>, <lambda> returns <i> * <j> index values and
    -- <j> output values, where <i> is the number of dimensions (rank) of the
    -- given output, and <j> is the number of output values written to the given
    -- output.
    --
    -- For example, given the following output specification:
    --
    --     [([x1, y1, z1], 2, arr1), ([x2, y2], 1, arr2)]
    --
    -- <lambda> will produce 6 (3 * 2) index values and 2 output values for
    -- <arr1>, and 2 (2 * 1) index values and 1 output value for
    -- <arr2>. Additionally, the results are grouped, so the first 6 index values
    -- will correspond to the first two output values, and so on. For this
    -- example, <lambda> should return a total of 11 values, 8 index values and
    -- 3 output values.
    Scatter SubExp (Lambda rep) [VName] [(Shape, Int, VName)]
  | -- | @Hist <length> <dest-arrays-and-ops> <bucket fun> <input arrays>@
    --
    -- The first SubExp is the length of the input arrays. The first
    -- list describes the operations to perform.  The t'Lambda' is the
    -- bucket function.  Finally comes the input images.
    Hist SubExp [HistOp rep] (Lambda rep) [VName]
  | -- | A combination of scan, reduction, and map.  The first
    -- t'SubExp' is the size of the input arrays.
    Screma SubExp [VName] (ScremaForm rep)
  deriving (Eq, Ord, Show)

-- | Information about computing a single histogram.
data HistOp rep = HistOp
  { histWidth :: SubExp,
    -- | Race factor @RF@ means that only @1/RF@
    -- bins are used.
    histRaceFactor :: SubExp,
    histDest :: [VName],
    histNeutral :: [SubExp],
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)

-- | Is the stream chunk required to correspond to a contiguous
-- subsequence of the original input ('InOrder') or not?  'Disorder'
-- streams can be more efficient, but not all algorithms work with
-- this.
data StreamOrd = InOrder | Disorder
  deriving (Eq, Ord, Show)

-- | What kind of stream is this?
data StreamForm rep
  = Parallel StreamOrd Commutativity (Lambda rep)
  | Sequential
  deriving (Eq, Ord, Show)

-- | The essential parts of a 'Screma' factored out (everything
-- except the input arrays): the scan operators, the reduction
-- operators, and the map function, in that order.
data ScremaForm rep
  = ScremaForm
      [Scan rep]
      [Reduce rep]
      (Lambda rep)
  deriving (Eq, Ord, Show)

-- | Fuse several operators into a single operator that computes all
-- of them side by side.
--
-- For an operator returning @n@ values, its first @n@ parameters are
-- taken as the left operands and the remaining parameters as the right
-- operands.  The combined lambda takes the left operands of every
-- operator (in list order), followed by the right operands of every
-- operator (in list order).  Its return types and body results are the
-- concatenation of those of the individual operators, and its body
-- statements are the individual bodies' statements concatenated in list
-- order.
--
-- NOTE(review): each operator is assumed to take exactly two groups of
-- parameters matching its return types; this is not checked here.
singleBinOp :: Buildable rep => [Lambda rep] -> Lambda rep
singleBinOp lams =
  Lambda
    { lambdaParams = concatMap xParams lams ++ concatMap yParams lams,
      lambdaReturnType = concatMap lambdaReturnType lams,
      lambdaBody =
        mkBody
          (mconcat (map (bodyStms . lambdaBody) lams))
          (concatMap (bodyResult . lambdaBody) lams)
    }
  where
    -- Left operands: one parameter per result of the operator.
    xParams lam = take (length (lambdaReturnType lam)) (lambdaParams lam)
    -- Right operands: every parameter after the left operands.
    yParams lam = drop (length (lambdaReturnType lam)) (lambdaParams lam)

-- | How to compute a single scan result.
data Scan rep = Scan
  { -- | The scan operator.
    scanLambda :: Lambda rep,
    -- | The neutral elements; the scan produces one result per
    -- neutral element (see 'scanResults').
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many scan results are produced by these 'Scan's?  Each scan
-- contributes one result per neutral element.
scanResults :: [Scan rep] -> Int
scanResults = sum . map (length . scanNeutral)

-- | Combine multiple scan operators to a single operator.  The neutral
-- elements are concatenated in list order, and the operators are fused
-- with 'singleBinOp'.
singleScan :: Buildable rep => [Scan rep] -> Scan rep
singleScan scans =
  let scan_nes = concatMap scanNeutral scans
      scan_lam = singleBinOp $ map scanLambda scans
   in Scan scan_lam scan_nes

-- | How to compute a single reduction result.
data Reduce rep = Reduce
  { -- | Whether the reduction operator is commutative.
    redComm :: Commutativity,
    -- | The reduction operator.
    redLambda :: Lambda rep,
    -- | The neutral elements; the reduction produces one result per
    -- neutral element (see 'redResults').
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'Reduce's?  Each
-- reduction contributes one result per neutral element.
redResults :: [Reduce rep] -> Int
redResults = sum . map (length . redNeutral)

-- | Combine multiple reduction operators to a single operator.  The
-- neutral elements are concatenated in list order, the operators are
-- fused with 'singleBinOp', and the commutativity of the result is the
-- monoidal combination of the individual commutativities.
singleReduce :: Buildable rep => [Reduce rep] -> Reduce rep
singleReduce reds =
  let red_nes = concatMap redNeutral reds
      red_lam = singleBinOp $ map redLambda reds
   in Reduce (foldMap redComm reds) red_lam red_nes

-- | The types produced by a single 'Screma', given the size @w@ of the
-- input array.
--
-- Results are ordered as follows:
--
-- 1. the scan results;
-- 2. the reduction results;
-- 3. the remaining map results.
--
-- The scan and map types are made into arrays of @w@ rows with
-- 'arrayOfRow'.  The reduction types are returned as-is.  The leading
-- return types of the map function, one per scan/reduction result, are
-- skipped when computing the map results.
scremaType :: SubExp -> ScremaForm rep -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  map with_outer scan_ts ++ red_ts ++ map with_outer map_ts
  where
    with_outer t = t `arrayOfRow` w
    scan_ts = concatMap (lambdaReturnType . scanLambda) scans
    red_ts = concatMap (lambdaReturnType . redLambda) reds
    -- The map function first returns the inputs to the scans and
    -- reductions; only what follows are genuine map results.
    map_ts =
      drop (length scan_ts + length red_ts) (lambdaReturnType map_lam)

-- | Build a lambda that takes one parameter per given type and returns
-- those parameters unchanged, in order.
--
-- Parameter names are freshly generated from the base name @x@.  The
-- body contains no statements.
mkIdentityLambda ::
  (Buildable rep, MonadFreshNames m) =>
  [Type] ->
  m (Lambda rep)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  let body = mkBody mempty (varsRes (map paramName params))
  pure (Lambda params body ts)

-- | Does the lambda return exactly its own parameters, in order?
--
-- Only the 'resSubExp' of each body result is compared against the
-- parameter names.  Statements in the body are not examined.
isIdentityLambda :: Lambda rep -> Bool
isIdentityLambda lam = results == param_vars
  where
    results = map resSubExp $ bodyResult $ lambdaBody lam
    param_vars = map (Var . paramName) $ lambdaParams lam

-- | The empty lambda: no parameters, a body with no statements and no
-- results, and no return types.
nilFn :: Buildable rep => Lambda rep
nilFn =
  Lambda
    { lambdaParams = [],
      lambdaBody = mkBody mempty [],
      lambdaReturnType = []
    }

-- | Build a 'ScremaForm' from the given scans and map function, with no
-- reductions.
scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC scans = ScremaForm scans mempty

-- | Build a 'ScremaForm' from the given reductions and map function,
-- with no scans.
redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC = ScremaForm mempty

-- | Build a 'ScremaForm' that performs only the given scans.
--
-- The map function is a freshly named identity lambda.  Its parameter
-- types are the concatenated return types of the scan operators.
scanSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Scan rep] ->
  m (ScremaForm rep)
scanSOAC scans = do
  map_lam <- mkIdentityLambda $ concatMap (lambdaReturnType . scanLambda) scans
  pure $ scanomapSOAC scans map_lam

-- | Build a 'ScremaForm' that performs only the given reductions.
--
-- The map function is a freshly named identity lambda.  Its parameter
-- types are the concatenated return types of the reduction operators.
reduceSOAC ::
  (Buildable rep, MonadFreshNames m) =>
  [Reduce rep] ->
  m (ScremaForm rep)
reduceSOAC reds = do
  map_lam <- mkIdentityLambda $ concatMap (lambdaReturnType . redLambda) reds
  pure $ redomapSOAC reds map_lam

-- | Build a 'ScremaForm' that is a plain map with the given function:
-- no scans and no reductions.
mapSOAC :: Lambda rep -> ScremaForm rep
mapSOAC = ScremaForm mempty mempty

-- | Recognise a scan-map composition.
--
-- Succeeds when there are no reductions and at least one scan, and
-- returns the scans together with the map function.
isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC (ScremaForm scans reds map_lam)
  | null reds, not (null scans) = Just (scans, map_lam)
  | otherwise = Nothing

-- | Recognise a pure scan: a scan-map composition (see
-- 'isScanomapSOAC') whose map function is an identity lambda.
isScanSOAC :: ScremaForm rep -> Maybe [Scan rep]
isScanSOAC form =
  case isScanomapSOAC form of
    Just (scans, map_lam) | isIdentityLambda map_lam -> Just scans
    _ -> Nothing

-- | Recognise a reduce-map composition.
--
-- Succeeds when there are no scans and at least one reduction, and
-- returns the reductions together with the map function.
isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC (ScremaForm scans reds map_lam)
  | null scans, not (null reds) = Just (reds, map_lam)
  | otherwise = Nothing

-- | Recognise a pure reduction: a reduce-map composition (see
-- 'isRedomapSOAC') whose map function is an identity lambda.
isReduceSOAC :: ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC form =
  case isRedomapSOAC form of
    Just (reds, map_lam) | isIdentityLambda map_lam -> Just reds
    _ -> Nothing

-- | Recognise a plain map: a Screma with neither scans nor reductions.
-- On success the map function is returned.
isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC (ScremaForm scans reds map_lam)
  | null scans && null reds = Just map_lam
  | otherwise = Nothing

-- | @groupScatterResults output_spec results@
--
-- Groups the index values and written values in @results@ according to
-- the output specification @output_spec@.
--
-- This function extracts and groups the results of a scatter.  In the
-- SOAC representation, the lambda inside a 'Scatter' returns all
-- indices, followed by all values, as one flat list.  This function
-- pairs each value with its corresponding indices; the number of
-- indices per value is the rank of the output array's t'Shape'.
--
-- Each element of the result corresponds to one entry of
-- @output_spec@.  It holds that entry's shape and array, plus the list
-- of values written to that array, each paired with the indices where
-- it is written.
--
-- See 'Scatter' for more information.
groupScatterResults :: [(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  zip3 shapes arrays (chunks ns writes)
  where
    (shapes, ns, arrays) = unzip3 output_spec
    -- One (indices, value) pair per write, across all outputs; split
    -- into per-output groups of the given counts.
    writes = groupScatterResults' output_spec results

-- | @groupScatterResults' output_spec results@
--
-- Groups the index values and written values in @results@ according to
-- @output_spec@.  This is a simpler version of 'groupScatterResults'
-- that returns no information about shapes or output arrays: just one
-- (indices, value) pair per write, across all outputs.
--
-- See 'groupScatterResults' for more information.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  zip (chunks index_counts indices) values
  where
    (indices, values) = splitScatterResults output_spec results
    -- Every one of the @n@ writes to an output needs as many indices
    -- as that output's shape has dimensions.
    index_counts =
      concat [replicate n (length shp) | (shp, n, _) <- output_spec]

-- | @splitScatterResults output_spec results@
--
-- Splits @results@ into the leading index values and the trailing
-- written values.  For each entry @(shape, n, _)@ of @output_spec@,
-- @n * length shape@ indices are taken from the front.
--
-- See 'groupScatterResults' for more information.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec results = splitAt num_indices results
  where
    num_indices = sum [n * length shp | (shp, n, _) <- output_spec]

-- | Like 'Mapper', but just for 'SOAC's: a record of monadic actions
-- that 'mapSOACM' applies to the immediate children of a SOAC.
data SOACMapper frep trep m = SOACMapper
  { -- | Action applied to 'SubExp' children.
    mapOnSOACSubExp :: SubExp -> m SubExp,
    -- | Action applied to 'Lambda' children; it may change the
    -- representation from @frep@ to @trep@.
    mapOnSOACLambda :: Lambda frep -> m (Lambda trep),
    -- | Action applied to 'VName' children.
    mapOnSOACVName :: VName -> m VName
  }

-- | A mapper in which every action is 'pure', so mapping it over a SOAC
-- yields the SOAC unchanged.
identitySOACMapper :: Monad m => SOACMapper rep rep m
identitySOACMapper = SOACMapper pure pure pure

-- | Map a monadic action across the immediate children of a
-- SOAC.  The mapping does not descend recursively into subexpressions
-- and is done left-to-right.
mapSOACM ::
  (Applicative m, Monad m) =>
  SOACMapper frep trep m ->
  SOAC frep ->
  m (SOAC trep)
mapSOACM :: SOACMapper frep trep m -> SOAC frep -> m (SOAC trep)
mapSOACM SOACMapper frep trep m
tv (Stream SubExp
size [VName]
arrs StreamForm frep
form [SubExp]
accs Lambda frep
lam) =
  SubExp
-> [VName]
-> StreamForm trep
-> [SubExp]
-> Lambda trep
-> SOAC trep
forall rep.
SubExp
-> [VName] -> StreamForm rep -> [SubExp] -> Lambda rep -> SOAC rep
Stream (SubExp
 -> [VName]
 -> StreamForm trep
 -> [SubExp]
 -> Lambda trep
 -> SOAC trep)
-> m SubExp
-> m ([VName]
      -> StreamForm trep -> [SubExp] -> Lambda trep -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
size
    m ([VName]
   -> StreamForm trep -> [SubExp] -> Lambda trep -> SOAC trep)
-> m [VName]
-> m (StreamForm trep -> [SubExp] -> Lambda trep -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
arrs
    m (StreamForm trep -> [SubExp] -> Lambda trep -> SOAC trep)
-> m (StreamForm trep) -> m ([SubExp] -> Lambda trep -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> StreamForm frep -> m (StreamForm trep)
mapOnStreamForm StreamForm frep
form
    m ([SubExp] -> Lambda trep -> SOAC trep)
-> m [SubExp] -> m (Lambda trep -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
accs
    m (Lambda trep -> SOAC trep) -> m (Lambda trep) -> m (SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
lam
  where
    mapOnStreamForm :: StreamForm frep -> m (StreamForm trep)
mapOnStreamForm (Parallel StreamOrd
o Commutativity
comm Lambda frep
lam0) =
      StreamOrd -> Commutativity -> Lambda trep -> StreamForm trep
forall rep.
StreamOrd -> Commutativity -> Lambda rep -> StreamForm rep
Parallel StreamOrd
o Commutativity
comm (Lambda trep -> StreamForm trep)
-> m (Lambda trep) -> m (StreamForm trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
lam0
    mapOnStreamForm StreamForm frep
Sequential =
      StreamForm trep -> m (StreamForm trep)
forall (f :: * -> *) a. Applicative f => a -> f a
pure StreamForm trep
forall rep. StreamForm rep
Sequential
mapSOACM SOACMapper frep trep m
tv (Scatter SubExp
len Lambda frep
lam [VName]
ivs [(Shape, Int, VName)]
as) =
  SubExp
-> Lambda trep -> [VName] -> [(Shape, Int, VName)] -> SOAC trep
forall rep.
SubExp
-> Lambda rep -> [VName] -> [(Shape, Int, VName)] -> SOAC rep
Scatter
    (SubExp
 -> Lambda trep -> [VName] -> [(Shape, Int, VName)] -> SOAC trep)
-> m SubExp
-> m (Lambda trep -> [VName] -> [(Shape, Int, VName)] -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
len
    m (Lambda trep -> [VName] -> [(Shape, Int, VName)] -> SOAC trep)
-> m (Lambda trep)
-> m ([VName] -> [(Shape, Int, VName)] -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
lam
    m ([VName] -> [(Shape, Int, VName)] -> SOAC trep)
-> m [VName] -> m ([(Shape, Int, VName)] -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
ivs
    m ([(Shape, Int, VName)] -> SOAC trep)
-> m [(Shape, Int, VName)] -> m (SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ((Shape, Int, VName) -> m (Shape, Int, VName))
-> [(Shape, Int, VName)] -> m [(Shape, Int, VName)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
      ( \(Shape
aw, Int
an, VName
a) ->
          (,,) (Shape -> Int -> VName -> (Shape, Int, VName))
-> m Shape -> m (Int -> VName -> (Shape, Int, VName))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (SubExp -> m SubExp) -> Shape -> m Shape
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) Shape
aw
            m (Int -> VName -> (Shape, Int, VName))
-> m Int -> m (VName -> (Shape, Int, VName))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Int -> m Int
forall (f :: * -> *) a. Applicative f => a -> f a
pure Int
an
            m (VName -> (Shape, Int, VName))
-> m VName -> m (Shape, Int, VName)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv VName
a
      )
      [(Shape, Int, VName)]
as
mapSOACM SOACMapper frep trep m
tv (Hist SubExp
len [HistOp frep]
ops Lambda frep
bucket_fun [VName]
imgs) =
  SubExp -> [HistOp trep] -> Lambda trep -> [VName] -> SOAC trep
forall rep.
SubExp -> [HistOp rep] -> Lambda rep -> [VName] -> SOAC rep
Hist
    (SubExp -> [HistOp trep] -> Lambda trep -> [VName] -> SOAC trep)
-> m SubExp
-> m ([HistOp trep] -> Lambda trep -> [VName] -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
len
    m ([HistOp trep] -> Lambda trep -> [VName] -> SOAC trep)
-> m [HistOp trep] -> m (Lambda trep -> [VName] -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (HistOp frep -> m (HistOp trep))
-> [HistOp frep] -> m [HistOp trep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM
      ( \(HistOp SubExp
e SubExp
rf [VName]
arrs [SubExp]
nes Lambda frep
op) ->
          SubExp
-> SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep
forall rep.
SubExp -> SubExp -> [VName] -> [SubExp] -> Lambda rep -> HistOp rep
HistOp (SubExp
 -> SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep)
-> m SubExp
-> m (SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
e
            m (SubExp -> [VName] -> [SubExp] -> Lambda trep -> HistOp trep)
-> m SubExp
-> m ([VName] -> [SubExp] -> Lambda trep -> HistOp trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
rf
            m ([VName] -> [SubExp] -> Lambda trep -> HistOp trep)
-> m [VName] -> m ([SubExp] -> Lambda trep -> HistOp trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
arrs
            m ([SubExp] -> Lambda trep -> HistOp trep)
-> m [SubExp] -> m (Lambda trep -> HistOp trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
nes
            m (Lambda trep -> HistOp trep)
-> m (Lambda trep) -> m (HistOp trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
op
      )
      [HistOp frep]
ops
    m (Lambda trep -> [VName] -> SOAC trep)
-> m (Lambda trep) -> m ([VName] -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
bucket_fun
    m ([VName] -> SOAC trep) -> m [VName] -> m (SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
imgs
mapSOACM SOACMapper frep trep m
tv (Screma SubExp
w [VName]
arrs (ScremaForm [Scan frep]
scans [Reduce frep]
reds Lambda frep
map_lam)) =
  SubExp -> [VName] -> ScremaForm trep -> SOAC trep
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma (SubExp -> [VName] -> ScremaForm trep -> SOAC trep)
-> m SubExp -> m ([VName] -> ScremaForm trep -> SOAC trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv SubExp
w
    m ([VName] -> ScremaForm trep -> SOAC trep)
-> m [VName] -> m (ScremaForm trep -> SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> VName -> m VName
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> VName -> m VName
mapOnSOACVName SOACMapper frep trep m
tv) [VName]
arrs
    m (ScremaForm trep -> SOAC trep)
-> m (ScremaForm trep) -> m (SOAC trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ( [Scan trep] -> [Reduce trep] -> Lambda trep -> ScremaForm trep
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm
            ([Scan trep] -> [Reduce trep] -> Lambda trep -> ScremaForm trep)
-> m [Scan trep]
-> m ([Reduce trep] -> Lambda trep -> ScremaForm trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Scan frep] -> (Scan frep -> m (Scan trep)) -> m [Scan trep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM
              [Scan frep]
scans
              ( \(Scan Lambda frep
red_lam [SubExp]
red_nes) ->
                  Lambda trep -> [SubExp] -> Scan trep
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan (Lambda trep -> [SubExp] -> Scan trep)
-> m (Lambda trep) -> m ([SubExp] -> Scan trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
red_lam
                    m ([SubExp] -> Scan trep) -> m [SubExp] -> m (Scan trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
red_nes
              )
            m ([Reduce trep] -> Lambda trep -> ScremaForm trep)
-> m [Reduce trep] -> m (Lambda trep -> ScremaForm trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Reduce frep]
-> (Reduce frep -> m (Reduce trep)) -> m [Reduce trep]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM
              [Reduce frep]
reds
              ( \(Reduce Commutativity
comm Lambda frep
red_lam [SubExp]
red_nes) ->
                  Commutativity -> Lambda trep -> [SubExp] -> Reduce trep
forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Reduce Commutativity
comm (Lambda trep -> [SubExp] -> Reduce trep)
-> m (Lambda trep) -> m ([SubExp] -> Reduce trep)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
red_lam
                    m ([SubExp] -> Reduce trep) -> m [SubExp] -> m (Reduce trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (SubExp -> m SubExp) -> [SubExp] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SOACMapper frep trep m -> SubExp -> m SubExp
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> SubExp -> m SubExp
mapOnSOACSubExp SOACMapper frep trep m
tv) [SubExp]
red_nes
              )
            m (Lambda trep -> ScremaForm trep)
-> m (Lambda trep) -> m (ScremaForm trep)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
forall frep trep (m :: * -> *).
SOACMapper frep trep m -> Lambda frep -> m (Lambda trep)
mapOnSOACLambda SOACMapper frep trep m
tv Lambda frep
map_lam
        )

instance ASTRep rep => FreeIn (SOAC rep) where
  -- Traverse the SOAC with a stateful mapper that returns every
  -- component unchanged while accumulating the free variables of each
  -- subexpression, lambda and variable name it visits.
  freeIn' soac = execState (mapSOACM collector soac) mempty
    where
      -- Append @f x@ to the accumulated state and hand @x@ back untouched.
      record f x = do
        modify (<> f x)
        pure x
      collector =
        SOACMapper
          { mapOnSOACSubExp = record freeIn',
            mapOnSOACLambda = record freeIn',
            mapOnSOACVName = record freeIn'
          }

instance ASTRep rep => Substitute (SOAC rep) where
  -- Apply the name substitution to every subexpression, lambda and
  -- variable name referenced by the SOAC.
  substituteNames subst soac = runIdentity $ mapSOACM substituter soac
    where
      sub :: Substitute a => a -> Identity a
      sub = pure . substituteNames subst
      substituter =
        SOACMapper
          { mapOnSOACSubExp = sub,
            mapOnSOACLambda = sub,
            mapOnSOACVName = sub
          }

instance ASTRep rep => Rename (SOAC rep) where
  -- Rename each subexpression, lambda and variable name of the SOAC
  -- using their own 'Rename' instances.
  rename soac = mapSOACM renamer soac
    where
      renamer =
        SOACMapper
          { mapOnSOACSubExp = rename,
            mapOnSOACLambda = rename,
            mapOnSOACVName = rename
          }

-- | The type of a SOAC: one entry per result it produces.
soacType :: SOAC rep -> [Type]
soacType (Stream outersize _ _ accs lam) =
  map (substNamesInType substs) rtp
  where
    -- The first lambda parameter is the chunk size and the next
    -- @length accs@ are the accumulators.  In the result types these
    -- names are replaced by the outer size and the accumulator inputs.
    nms = map paramName $ take (1 + length accs) params
    substs = M.fromList $ zip nms (outersize : accs)
    Lambda params _ rtp = lam
soacType (Scatter _w lam _ivs as) =
  -- One result per destination: the destination shape wrapped around
  -- the type of the values written to it.
  [arrayOfShape val_t shape | (shape, val_t : _) <- zip ws val_groups]
  where
    (ws, ns, _) = unzip3 as
    -- The lambda first returns all index components (@n * rank@ per
    -- destination), followed by @n@ values per destination.
    indexes = sum $ zipWith (*) ns $ map length ws
    -- Group the value types per destination, so that a destination
    -- written @n > 1@ times does not shift the types paired with the
    -- following destinations.  The values written to one destination
    -- are assumed to share a type, so the first is representative.
    -- NOTE(review): a destination with @n == 0@ has no value type and
    -- is skipped; presumably the type checker rules this out -- confirm.
    val_groups = chunks ns $ drop indexes $ lambdaReturnType lam
soacType (Hist _len ops _bucket_fun _imgs) = do
  op <- ops
  map (`arrayOfRow` histWidth op) (lambdaReturnType $ histOp op)
soacType (Screma w _arrs form) =
  scremaType w form

instance TypedOp (SOAC rep) where
  -- The result types of 'soacType', lifted to (static) extended types.
  opType soac = pure $ staticShapes $ soacType soac

instance (ASTRep rep, Aliased rep) => AliasedOp (SOAC rep) where
  -- Every result of a SOAC is reported as aliasing nothing: one empty
  -- alias set per result type.
  opAliases = map (const mempty) . soacType

  -- Only map functions can consume anything.  The operands to scan
  -- and reduce functions are always considered "fresh".
  consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
    mapNames consumedArray $ consumedByLambda map_lam
    where
      -- A consumed lambda parameter is reported as its corresponding
      -- input array; any other consumed name is reported as-is.
      consumedArray v = fromMaybe v $ lookup v params_to_arrs
      params_to_arrs = zip (map paramName $ lambdaParams map_lam) arrs
  consumedInOp (Stream _ arrs _ accs lam) =
    -- Sequential and parallel streams are treated alike: only the main
    -- stream lambda is inspected, never the lambda carried by a
    -- 'Parallel' stream form.
    namesFromList $
      subExpVars $
        map consumedArray $
          namesToList $
            consumedByLambda lam
    where
      -- A consumed parameter is reported as the accumulator or array
      -- it is bound to; any other consumed name is reported as-is.
      consumedArray v = fromMaybe (Var v) $ lookup v paramsToInput
      -- Drop the chunk parameter, which cannot alias anything.
      paramsToInput =
        zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
  consumedInOp (Scatter _ _ _ as) =
    -- Every destination array of a scatter is consumed.
    namesFromList $ map (\(_, _, a) -> a) as
  consumedInOp (Hist _ ops _ _) =
    -- Every destination array of every histogram operation is consumed.
    namesFromList $ concatMap histDest ops

-- | Transform the operator lambda of a histogram operation, leaving all
-- of its other fields unchanged.
mapHistOp ::
  (Lambda frep -> Lambda trep) ->
  HistOp frep ->
  HistOp trep
mapHistOp f op = op {histOp = f $ histOp op}

instance
  ( ASTRep rep,
    ASTRep (Aliases rep),
    CanBeAliased (Op rep)
  ) =>
  CanBeAliased (SOAC rep)
  where
  type OpWithAliases (SOAC rep) = SOAC (Aliases rep)

  -- Run alias analysis on every lambda embedded in the SOAC; all other
  -- operands are carried over unchanged.
  addOpAliases aliases soac =
    case soac of
      Stream size arr form accs lam ->
        Stream size arr (onForm form) accs (analyse lam)
      Scatter len lam ivs as ->
        Scatter len (analyse lam) ivs as
      Hist len ops bucket_fun imgs ->
        Hist len (map (mapHistOp analyse) ops) (analyse bucket_fun) imgs
      Screma w arrs (ScremaForm scans reds map_lam) ->
        Screma w arrs $
          ScremaForm
            [scan {scanLambda = analyse $ scanLambda scan} | scan <- scans]
            [red {redLambda = analyse $ redLambda red} | red <- reds]
            (analyse map_lam)
    where
      analyse = Alias.analyseLambda aliases
      -- A parallel stream form carries its own lambda, analysed as well.
      onForm Sequential = Sequential
      onForm (Parallel o comm lam0) = Parallel o comm $ analyse lam0

  -- Strip alias annotations from every lambda in the SOAC;
  -- subexpressions and variable names are passed through unchanged.
  removeOpAliases soac = runIdentity $ mapSOACM stripper soac
    where
      stripper =
        SOACMapper
          { mapOnSOACSubExp = pure,
            mapOnSOACLambda = pure . removeLambdaAliases,
            mapOnSOACVName = pure
          }

instance ASTRep rep => IsOp (SOAC rep) where
  -- Every SOAC, regardless of its contents, is reported as not safe
  -- and as cheap.
  safeOp = const False
  cheapOp = const True

-- | Replace variable names occurring in the shape of an array type by
-- the subexpressions they map to.  Non-array types are returned as-is.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t =
  case t of
    Array btp shp u ->
      Array btp (Shape [substNamesInSubExp subs d | d <- shapeDims shp]) u
    Prim {} -> t
    Acc {} -> t
    Mem {} -> t

-- | Replace a variable by the subexpression it maps to, if any.
-- Constants and unmapped variables are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Constant {} -> se
    Var v -> fromMaybe se $ M.lookup v subs

instance (ASTRep rep, CanBeWise (Op rep)) => CanBeWise (SOAC rep) where
  type OpWithWisdom (SOAC rep) = SOAC (Wise rep)

  -- Strip wisdom from every lambda in the SOAC; subexpressions and
  -- variable names are passed through unchanged.
  removeOpWisdom soac = runIdentity $ mapSOACM unwise soac
    where
      unwise =
        SOACMapper
          { mapOnSOACSubExp = pure,
            mapOnSOACLambda = pure . removeLambdaWisdom,
            mapOnSOACVName = pure
          }

instance RepTypes rep => ST.IndexOp (SOAC rep) where
  indexOp :: SymbolTable rep
-> Int -> SOAC rep -> [TPrimExp Int64 VName] -> Maybe Indexed
indexOp SymbolTable rep
vtable Int
k SOAC rep
soac [TPrimExp Int64 VName
i] = do
    (LambdaT rep
lam, SubExpRes
se, [Param (LParamInfo rep)]
arr_params, [VName]
arrs) <- SOAC rep
-> Maybe
     (LambdaT rep, SubExpRes, [Param (LParamInfo rep)], [VName])
lambdaAndSubExp SOAC rep
soac
    let arr_indexes :: Map VName (PrimExp VName, Certs)
arr_indexes = [(VName, (PrimExp VName, Certs))]
-> Map VName (PrimExp VName, Certs)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, (PrimExp VName, Certs))]
 -> Map VName (PrimExp VName, Certs))
-> [(VName, (PrimExp VName, Certs))]
-> Map VName (PrimExp VName, Certs)
forall a b. (a -> b) -> a -> b
$ [Maybe (VName, (PrimExp VName, Certs))]
-> [(VName, (PrimExp VName, Certs))]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (VName, (PrimExp VName, Certs))]
 -> [(VName, (PrimExp VName, Certs))])
-> [Maybe (VName, (PrimExp VName, Certs))]
-> [(VName, (PrimExp VName, Certs))]
forall a b. (a -> b) -> a -> b
$ (Param (LParamInfo rep)
 -> VName -> Maybe (VName, (PrimExp VName, Certs)))
-> [Param (LParamInfo rep)]
-> [VName]
-> [Maybe (VName, (PrimExp VName, Certs))]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Param (LParamInfo rep)
-> VName -> Maybe (VName, (PrimExp VName, Certs))
arrIndex [Param (LParamInfo rep)]
arr_params [VName]
arrs
        arr_indexes' :: Map VName (PrimExp VName, Certs)
arr_indexes' = (Map VName (PrimExp VName, Certs)
 -> Stm rep -> Map VName (PrimExp VName, Certs))
-> Map VName (PrimExp VName, Certs)
-> Seq (Stm rep)
-> Map VName (PrimExp VName, Certs)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Map VName (PrimExp VName, Certs)
-> Stm rep -> Map VName (PrimExp VName, Certs)
expandPrimExpTable Map VName (PrimExp VName, Certs)
arr_indexes (Seq (Stm rep) -> Map VName (PrimExp VName, Certs))
-> Seq (Stm rep) -> Map VName (PrimExp VName, Certs)
forall a b. (a -> b) -> a -> b
$ BodyT rep -> Seq (Stm rep)
forall rep. BodyT rep -> Stms rep
bodyStms (BodyT rep -> Seq (Stm rep)) -> BodyT rep -> Seq (Stm rep)
forall a b. (a -> b) -> a -> b
$ LambdaT rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT rep
lam
    case SubExpRes
se of
      SubExpRes Certs
_ (Var VName
v) -> (PrimExp VName -> Certs -> Indexed)
-> (PrimExp VName, Certs) -> Indexed
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ((Certs -> PrimExp VName -> Indexed)
-> PrimExp VName -> Certs -> Indexed
forall a b c. (a -> b -> c) -> b -> a -> c
flip Certs -> PrimExp VName -> Indexed
ST.Indexed) ((PrimExp VName, Certs) -> Indexed)
-> Maybe (PrimExp VName, Certs) -> Maybe Indexed
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> Map VName (PrimExp VName, Certs) -> Maybe (PrimExp VName, Certs)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName (PrimExp VName, Certs)
arr_indexes'
      SubExpRes
_ -> Maybe Indexed
forall a. Maybe a
Nothing
    where
      lambdaAndSubExp :: SOAC rep
-> Maybe
     (LambdaT rep, SubExpRes, [Param (LParamInfo rep)], [VName])
lambdaAndSubExp (Screma SubExp
_ [VName]
arrs (ScremaForm [Scan rep]
scans [Reduce rep]
reds LambdaT rep
map_lam)) =
        Int
-> LambdaT rep
-> [VName]
-> Maybe
     (LambdaT rep, SubExpRes, [Param (LParamInfo rep)], [VName])
nthMapOut ([Scan rep] -> Int
forall rep. [Scan rep] -> Int
scanResults [Scan rep]
scans Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce rep] -> Int
forall rep. [Reduce rep] -> Int
redResults [Reduce rep]
reds) LambdaT rep
map_lam [VName]
arrs
      lambdaAndSubExp SOAC rep
_ =
        Maybe (LambdaT rep, SubExpRes, [Param (LParamInfo rep)], [VName])
forall a. Maybe a
Nothing

      nthMapOut :: Int
-> LambdaT rep
-> [VName]
-> Maybe
     (LambdaT rep, SubExpRes, [Param (LParamInfo rep)], [VName])
nthMapOut Int
num_accs LambdaT rep
lam [VName]
arrs = do
        SubExpRes
se <- Int -> Result -> Maybe SubExpRes
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth (Int
num_accs Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
k) (Result -> Maybe SubExpRes) -> Result -> Maybe SubExpRes
forall a b. (a -> b) -> a -> b
$ BodyT rep -> Result
forall rep. BodyT rep -> Result
bodyResult (BodyT rep -> Result) -> BodyT rep -> Result
forall a b. (a -> b) -> a -> b
$ LambdaT rep -> BodyT rep
forall rep. LambdaT rep -> BodyT rep
lambdaBody LambdaT rep
lam
        (LambdaT rep, SubExpRes, [Param (LParamInfo rep)], [VName])
-> Maybe
     (LambdaT rep, SubExpRes, [Param (LParamInfo rep)], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (LambdaT rep
lam, SubExpRes
se, Int -> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a. Int -> [a] -> [a]
drop Int
num_accs ([Param (LParamInfo rep)] -> [Param (LParamInfo rep)])
-> [Param (LParamInfo rep)] -> [Param (LParamInfo rep)]
forall a b. (a -> b) -> a -> b
$ LambdaT rep -> [Param (LParamInfo rep)]
forall rep. LambdaT rep -> [Param (LParamInfo rep)]
lambdaParams LambdaT rep
lam, [VName]
arrs)

      arrIndex :: Param (LParamInfo rep)
-> VName -> Maybe (VName, (PrimExp VName, Certs))
arrIndex Param (LParamInfo rep)
p VName
arr = do
        ST.Indexed Certs
cs PrimExp VName
pe <- VName -> [TPrimExp Int64 VName] -> SymbolTable rep -> Maybe Indexed
forall rep.
VName -> [TPrimExp Int64 VName] -> SymbolTable rep -> Maybe Indexed
ST.index' VName
arr [TPrimExp Int64 VName
i] SymbolTable rep
vtable
        (VName, (PrimExp VName, Certs))
-> Maybe (VName, (PrimExp VName, Certs))
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (LParamInfo rep) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo rep)
p, (PrimExp VName
pe, Certs
cs))

      expandPrimExpTable :: Map VName (PrimExp VName, Certs)
-> Stm rep -> Map VName (PrimExp VName, Certs)
expandPrimExpTable Map VName (PrimExp VName, Certs)
table Stm rep
stm
        | [VName
v] <- PatT (LetDec rep) -> [VName]
forall dec. PatT dec -> [VName]
patNames (PatT (LetDec rep) -> [VName]) -> PatT (LetDec rep) -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm rep -> PatT (LetDec rep)
forall rep. Stm rep -> Pat rep
stmPat Stm rep
stm,
          Just (PrimExp VName
pe, Certs
cs) <-
            WriterT Certs Maybe (PrimExp VName) -> Maybe (PrimExp VName, Certs)
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT Certs Maybe (PrimExp VName)
 -> Maybe (PrimExp VName, Certs))
-> WriterT Certs Maybe (PrimExp VName)
-> Maybe (PrimExp VName, Certs)
forall a b. (a -> b) -> a -> b
$ (VName -> WriterT Certs Maybe (PrimExp VName))
-> Exp rep -> WriterT Certs Maybe (PrimExp VName)
forall (m :: * -> *) rep v.
(MonadFail m, RepTypes rep) =>
(VName -> m (PrimExp v)) -> Exp rep -> m (PrimExp v)
primExpFromExp (Map VName (PrimExp VName, Certs)
-> VName -> WriterT Certs Maybe (PrimExp VName)
asPrimExp Map VName (PrimExp VName, Certs)
table) (Exp rep -> WriterT Certs Maybe (PrimExp VName))
-> Exp rep -> WriterT Certs Maybe (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ Stm rep -> Exp rep
forall rep. Stm rep -> Exp rep
stmExp Stm rep
stm,
          (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable rep
vtable) (Certs -> [VName]
unCerts (Certs -> [VName]) -> Certs -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm rep -> Certs
forall rep. Stm rep -> Certs
stmCerts Stm rep
stm) =
          VName
-> (PrimExp VName, Certs)
-> Map VName (PrimExp VName, Certs)
-> Map VName (PrimExp VName, Certs)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
v (PrimExp VName
pe, Stm rep -> Certs
forall rep. Stm rep -> Certs
stmCerts Stm rep
stm Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs) Map VName (PrimExp VName, Certs)
table
        | Bool
otherwise =
          Map VName (PrimExp VName, Certs)
table

      asPrimExp :: Map VName (PrimExp VName, Certs)
-> VName -> WriterT Certs Maybe (PrimExp VName)
asPrimExp Map VName (PrimExp VName, Certs)
table VName
v
        | Just (PrimExp VName
e, Certs
cs) <- VName
-> Map VName (PrimExp VName, Certs) -> Maybe (PrimExp VName, Certs)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName (PrimExp VName, Certs)
table = Certs -> WriterT Certs Maybe ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell Certs
cs WriterT Certs Maybe ()
-> WriterT Certs Maybe (PrimExp VName)
-> WriterT Certs Maybe (PrimExp VName)
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> PrimExp VName -> WriterT Certs Maybe (PrimExp VName)
forall (m :: * -> *) a. Monad m => a -> m a
return PrimExp VName
e
        | Just (Prim PrimType
pt) <- VName -> SymbolTable rep -> Maybe Type
forall rep. ASTRep rep => VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
v SymbolTable rep
vtable =
          PrimExp VName -> WriterT Certs Maybe (PrimExp VName)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimExp VName -> WriterT Certs Maybe (PrimExp VName))
-> PrimExp VName -> WriterT Certs Maybe (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> PrimExp VName
forall v. v -> PrimType -> PrimExp v
LeafExp VName
v PrimType
pt
        | Bool
otherwise = Maybe (PrimExp VName) -> WriterT Certs Maybe (PrimExp VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift Maybe (PrimExp VName)
forall a. Maybe a
Nothing
  indexOp SymbolTable rep
_ Int
_ SOAC rep
_ [TPrimExp Int64 VName]
_ = Maybe Indexed
forall a. Maybe a
Nothing

-- | Type-check a SOAC.
typeCheckSOAC :: TC.Checkable rep => SOAC (Aliases rep) -> TC.TypeM rep ()
typeCheckSOAC :: SOAC (Aliases rep) -> TypeM rep ()
typeCheckSOAC (Stream SubExp
size [VName]
arrexps StreamForm (Aliases rep)
form [SubExp]
accexps Lambda (Aliases rep)
lam) = do
  [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
size
  [Arg]
accargs <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
accexps
  [Type]
arrargs <- (VName -> TypeM rep Type) -> [VName] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> TypeM rep Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
arrexps
  [Arg]
_ <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
size [VName]
arrexps
  let chunk :: Param (LParamInfo rep)
chunk = [Param (LParamInfo rep)] -> Param (LParamInfo rep)
forall a. [a] -> a
head ([Param (LParamInfo rep)] -> Param (LParamInfo rep))
-> [Param (LParamInfo rep)] -> Param (LParamInfo rep)
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases rep) -> [LParam (Aliases rep)]
forall rep. LambdaT rep -> [Param (LParamInfo rep)]
lambdaParams Lambda (Aliases rep)
lam
  let asArg :: a -> (a, b)
asArg a
t = (a
t, b
forall a. Monoid a => a
mempty)
      inttp :: TypeBase shape u
inttp = PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
      lamarrs' :: [Type]
lamarrs' = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` VName -> SubExp
Var (Param (LParamInfo rep) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo rep)
chunk)) [Type]
arrargs
  let acc_len :: Int
acc_len = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accexps
  let lamrtp :: [Type]
lamrtp = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
acc_len ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
lamrtp) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError String
"Stream with inconsistent accumulator type in lambda."
  -- check reduce's lambda, if any
  ()
_ <- case StreamForm (Aliases rep)
form of
    Parallel StreamOrd
_ Commutativity
_ Lambda (Aliases rep)
lam0 -> do
      let acct :: [Type]
acct = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs
          outerRetType :: [Type]
outerRetType = Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam0
      Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam0 ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
accargs
      Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
acct [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
outerRetType) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
        ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
          String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
            String
"Initial value is of type " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
acct
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
", but stream's reduce lambda returns type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
outerRetType
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"."
    StreamForm (Aliases rep)
Sequential -> () -> TypeM rep ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
  -- just get the dflow of lambda on the fakearg, which does not alias
  -- arr, so we can later check that aliases of arr are not used inside lam.
  let fake_lamarrs' :: [Arg]
fake_lamarrs' = (Type -> Arg) -> [Type] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Arg
forall b a. Monoid b => a -> (a, b)
asArg [Type]
lamarrs'
  Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ Type -> Arg
forall b a. Monoid b => a -> (a, b)
asArg Type
forall shape u. TypeBase shape u
inttp Arg -> [Arg] -> [Arg]
forall a. a -> [a] -> [a]
: [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
fake_lamarrs'
typeCheckSOAC (Scatter SubExp
w Lambda (Aliases rep)
lam [VName]
ivs [(Shape, Int, VName)]
as) = do
  -- Requirements:
  --
  --   0. @lambdaReturnType@ of @lam@ must be a list
  --      [index types..., value types, ...].
  --
  --   1. The number of index types and value types must be equal to the number
  --      of return values from @lam@.
  --
  --   2. Each index type must have the type i64.
  --
  --   3. Each array in @as@ and the value types must have the same type
  --
  --   4. Each array in @as@ is consumed.  This is not really a check, but more
  --      of a requirement, so that e.g. the source is not hoisted out of a
  --      loop, which will mean it cannot be consumed.
  --
  --   5. Each of ivs must be an array matching a corresponding lambda
  --      parameters.
  --
  -- Code:

  -- First check the input size.
  [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w

  -- 0.
  let ([Shape]
as_ws, [Int]
as_ns, [VName]
_as_vs) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
as
      indexes :: Int
indexes = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws
      rts :: [Type]
rts = Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
      rtsI :: [Type]
rtsI = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
indexes [Type]
rts
      rtsV :: [Type]
rtsV = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
indexes [Type]
rts

  -- 1.
  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
rts Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws)) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError String
"Scatter: number of index types, value types and array outputs do not match."

  -- 2.
  [Type] -> (Type -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtsI ((Type -> TypeM rep ()) -> TypeM rep ())
-> (Type -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \Type
rtI ->
    Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
== Type
rtI) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError String
"Scatter: Index return type must be i64."

  [([Type], (Shape, Int, VName))]
-> (([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[Type]]
-> [(Shape, Int, VName)] -> [([Type], (Shape, Int, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ([Int] -> [Type] -> [[Type]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns [Type]
rtsV) [(Shape, Int, VName)]
as) ((([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ())
-> (([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \([Type]
rtVs, (Shape
aw, Int
_, VName
a)) -> do
    -- All lengths must have type i64.
    (SubExp -> TypeM rep ()) -> Shape -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
aw

    -- 3.
    [Type] -> (Type -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtVs ((Type -> TypeM rep ()) -> TypeM rep ())
-> (Type -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \Type
rtV -> [Type] -> VName -> TypeM rep ()
forall rep. Checkable rep => [Type] -> VName -> TypeM rep ()
TC.requireI [Type -> Shape -> Type
arrayOfShape Type
rtV Shape
aw] VName
a

    -- 4.
    Names -> TypeM rep ()
forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume (Names -> TypeM rep ()) -> TypeM rep Names -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM rep Names
forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
a

  -- 5.
  [Arg]
arrargs <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
ivs
  Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam [Arg]
arrargs
typeCheckSOAC (Hist SubExp
len [HistOp (Aliases rep)]
ops Lambda (Aliases rep)
bucket_fun [VName]
imgs) = do
  [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
len

  -- Check the operators.
  [HistOp (Aliases rep)]
-> (HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [HistOp (Aliases rep)]
ops ((HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ())
-> (HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
dest_w SubExp
rf [VName]
dests [SubExp]
nes Lambda (Aliases rep)
op) -> do
    [Arg]
nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
nes
    [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
dest_w
    [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
rf

    -- Operator type must match the type of neutral elements.
    Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
op ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
nes'
    let nes_t :: [Type]
nes_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
nes'
    Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
nes_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
        String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
          String
"Operator has return type "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op)
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
            String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
nes_t

    -- Arrays must have proper type.
    [(Type, VName)] -> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [(Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
nes_t [VName]
dests) (((Type, VName) -> TypeM rep ()) -> TypeM rep ())
-> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
dest) -> do
      [Type] -> VName -> TypeM rep ()
forall rep. Checkable rep => [Type] -> VName -> TypeM rep ()
TC.requireI [Type
t Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
dest_w] VName
dest
      Names -> TypeM rep ()
forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume (Names -> TypeM rep ()) -> TypeM rep Names -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM rep Names
forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
dest

  -- Types of input arrays must equal parameter types for bucket function.
  [Arg]
img' <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
len [VName]
imgs
  Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
bucket_fun [Arg]
img'

  -- Return type of bucket function must be an index for each
  -- operation followed by the values to write.
  [Type]
nes_ts <- [[Type]] -> [Type]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[Type]] -> [Type]) -> TypeM rep [[Type]] -> TypeM rep [Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp (Aliases rep) -> TypeM rep [Type])
-> [HistOp (Aliases rep)] -> TypeM rep [[Type]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((SubExp -> TypeM rep Type) -> [SubExp] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType ([SubExp] -> TypeM rep [Type])
-> (HistOp (Aliases rep) -> [SubExp])
-> HistOp (Aliases rep)
-> TypeM rep [Type]
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp (Aliases rep) -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp (Aliases rep)]
ops
  let bucket_ret_t :: [Type]
bucket_ret_t = Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([HistOp (Aliases rep)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp (Aliases rep)]
ops) (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
nes_ts
  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
bucket_ret_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
bucket_fun) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
    ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
        String
"Bucket function has return type "
          String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
bucket_fun)
          String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but should have type "
          String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
bucket_ret_t
typeCheckSOAC (Screma SubExp
w [VName]
arrs (ScremaForm [Scan (Aliases rep)]
scans [Reduce (Aliases rep)]
reds Lambda (Aliases rep)
map_lam)) = do
  [Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w
  [Arg]
arrs' <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
arrs
  Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
map_lam [Arg]
arrs'

  [Arg]
scan_nes' <- ([[Arg]] -> [Arg]) -> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [[Arg]] -> [Arg]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat (TypeM rep [[Arg]] -> TypeM rep [Arg])
-> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall a b. (a -> b) -> a -> b
$
    [Scan (Aliases rep)]
-> (Scan (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Scan (Aliases rep)]
scans ((Scan (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]])
-> (Scan (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall a b. (a -> b) -> a -> b
$ \(Scan Lambda (Aliases rep)
scan_lam [SubExp]
scan_nes) -> do
      [Arg]
scan_nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
scan_nes
      let scan_t :: [Type]
scan_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
scan_nes'
      Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
scan_lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
scan_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
scan_nes'
      Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
scan_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
scan_lam) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
        ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
          String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
            String
"Scan function returns type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
scan_lam)
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
scan_t
      [Arg] -> TypeM rep [Arg]
forall (m :: * -> *) a. Monad m => a -> m a
return [Arg]
scan_nes'

  [Arg]
red_nes' <- ([[Arg]] -> [Arg]) -> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [[Arg]] -> [Arg]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat (TypeM rep [[Arg]] -> TypeM rep [Arg])
-> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall a b. (a -> b) -> a -> b
$
    [Reduce (Aliases rep)]
-> (Reduce (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Reduce (Aliases rep)]
reds ((Reduce (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]])
-> (Reduce (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall a b. (a -> b) -> a -> b
$ \(Reduce Commutativity
_ Lambda (Aliases rep)
red_lam [SubExp]
red_nes) -> do
      [Arg]
red_nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
red_nes
      let red_t :: [Type]
red_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
red_nes'
      Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
red_lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
red_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
red_nes'
      Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
red_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
red_lam) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
        ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
          String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
            String
"Reduce function returns type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
red_lam)
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
red_t
      [Arg] -> TypeM rep [Arg]
forall (m :: * -> *) a. Monad m => a -> m a
return [Arg]
red_nes'

  let map_lam_ts :: [Type]
map_lam_ts = Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
map_lam

  Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless
    ( Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take ([Arg] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Arg]
scan_nes' Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Arg] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Arg]
red_nes') [Type]
map_lam_ts
        [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType ([Arg]
scan_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
red_nes')
    )
    (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
      String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
        String
"Map function return type " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
map_lam_ts
          String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" wrong for given scan and reduction functions."

-- | Metrics for a SOAC.  Every constructor opens a scope named after
-- itself and accumulates the metrics of the lambdas it carries.  For
-- 'Stream', only the body lambda is visited; the 'StreamForm' (and any
-- lambda it contains) is ignored.
instance OpMetrics (Op rep) => OpMetrics (SOAC rep) where
  opMetrics soac =
    case soac of
      Stream _ _ _ _ body ->
        inside "Stream" (lambdaMetrics body)
      Scatter _ body _ _ ->
        inside "Scatter" (lambdaMetrics body)
      Hist _ hist_ops bucket_fun _ ->
        inside "Hist" $ do
          -- Operator lambdas first, then the bucket function.
          mapM_ (\hop -> lambdaMetrics (histOp hop)) hist_ops
          lambdaMetrics bucket_fun
      Screma _ _ (ScremaForm scans reds map_lam) ->
        inside "Screma" $ do
          mapM_ (\scan -> lambdaMetrics (scanLambda scan)) scans
          mapM_ (\red -> lambdaMetrics (redLambda red)) reds
          lambdaMetrics map_lam

-- | Prettyprinting of SOACs.  A 'Screma' that is a plain map, a redomap
-- (no scans) or a scanomap (no reductions) is printed with that
-- dedicated keyword; any other 'Screma' is printed by 'ppScrema'.
-- Parallel streams encode their ordering and commutativity in the
-- keyword (e.g. @streamParPerComm@).
instance PrettyRep rep => PP.Pretty (SOAC rep) where
  ppr soac =
    case soac of
      Stream size arrs (Parallel o comm red_lam) acc body ->
        text ("streamPar" ++ orderSuffix o ++ commSuffix comm)
          <> parens
            (argList [ppr size, ppTuple' arrs, ppr red_lam, ppTuple' acc, ppr body])
      Stream size arrs Sequential acc body ->
        text "streamSeq"
          <> parens (argList [ppr size, ppTuple' arrs, ppTuple' acc, ppr body])
      Scatter w body inputs dests ->
        "scatter"
          <> parens
            (argList [ppr w, ppr body, commasep (ppTuple' inputs : map ppr dests)])
      Hist len hist_ops bucket_fun imgs ->
        ppHist len hist_ops bucket_fun imgs
      Screma w arrs form@(ScremaForm scans reds map_lam)
        | null scans && null reds ->
            text "map"
              <> parens (argList [ppr w, ppTuple' arrs, ppr map_lam])
        | null scans ->
            text "redomap"
              <> parens (argList [ppr w, ppTuple' arrs, bracedLines reds, ppr map_lam])
        | null reds ->
            text "scanomap"
              <> parens (argList [ppr w, ppTuple' arrs, bracedLines scans, ppr map_lam])
        | otherwise ->
            ppScrema w arrs form
    where
      -- Comma-separate the arguments, allowing a soft line break after
      -- each comma; equivalent to writing @a <> comma </> b <> comma </> c@
      -- inline.
      argList :: [Doc] -> Doc
      argList [] = mempty
      argList [d] = d
      argList (d : ds) = d <> comma </> argList ds

      -- Prettyprint each element, separate them with a comma and a line
      -- break, and wrap the whole thing in braces.
      bracedLines :: Pretty a => [a] -> Doc
      bracedLines = PP.braces . mconcat . intersperse (comma <> PP.line) . map ppr

      -- Keyword suffix for a disordered (permuting) stream.
      orderSuffix :: StreamOrd -> String
      orderSuffix o
        | o == Disorder = "Per"
        | otherwise = ""

      -- Keyword suffix for a commutative stream.
      commSuffix :: Commutativity -> String
      commSuffix Commutative = "Comm"
      commSuffix Noncommutative = ""

-- | Prettyprint the given Screma in its general form: width, inputs,
-- the braced list of scans, the braced list of reductions, and finally
-- the map lambda.
ppScrema ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc
ppScrema w arrs (ScremaForm scans reds map_lam) =
  text "screma" <> parens contents
  where
    contents =
      ppr w <> comma
        </> ppTuple' arrs <> comma
        </> opsDoc scans <> comma
        </> opsDoc reds <> comma
        </> ppr map_lam

    -- Operators in braces, separated by a comma and a line break.
    opsDoc :: Pretty a => [a] -> Doc
    opsDoc = PP.braces . mconcat . intersperse (comma <> PP.line) . map ppr

-- | A scan is printed as its operator lambda, a comma, and its neutral
-- elements in braces.
instance PrettyRep rep => Pretty (Scan rep) where
  ppr (Scan op nes) =
    let op_doc = ppr op
        nes_doc = PP.braces . commasep $ map ppr nes
     in op_doc <> comma </> nes_doc

-- | Prefix printed before a reduction operator: @commutative @ for a
-- commutative reduction, nothing otherwise.
ppComm :: Commutativity -> Doc
ppComm comm =
  case comm of
    Commutative -> text "commutative "
    Noncommutative -> mempty

-- | A reduction is printed as an optional commutativity marker (see
-- 'ppComm'), its operator lambda, a comma, and its neutral elements in
-- braces.
instance PrettyRep rep => Pretty (Reduce rep) where
  ppr (Reduce comm op nes) = header </> neutral
    where
      header = ppComm comm <> ppr op <> comma
      neutral = PP.braces (commasep (map ppr nes))

-- | Prettyprint the given histogram operation: the input length, the
-- braced list of histogram operators (one per line), the bucket
-- function, and the input arrays.
ppHist ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [HistOp rep] ->
  Lambda rep ->
  [inp] ->
  Doc
ppHist len ops bucket_fun imgs =
  text "hist" <> parens contents
  where
    contents =
      ppr len <> comma
        </> opsDoc <> comma
        </> ppr bucket_fun <> comma
        </> commasep (map ppr imgs)

    opsDoc = PP.braces . mconcat . intersperse (comma <> PP.line) $ map ppOp ops

    -- One histogram operator: width and race factor, destination arrays,
    -- neutral elements, and the operator lambda.
    ppOp (HistOp w rf dests nes op) =
      let braced :: Pretty a => [a] -> Doc
          braced = PP.braces . commasep . map ppr
       in ppr w <> comma <+> ppr rf <> comma <+> braced dests <> comma
            </> braced nes <> comma
            </> ppr op