{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Definition of /Second-Order Array Combinators/ (SOACs), which are
-- the main form of parallelism in the early stages of the compiler.
module Futhark.IR.SOACS.SOAC
       ( SOAC(..)
       , StreamOrd(..)
       , StreamForm(..)
       , ScremaForm(..)
       , HistOp(..)
       , Scan(..)
       , scanResults
       , singleScan
       , Reduce(..)
       , redResults
       , singleReduce

         -- * Utility
       , getStreamAccums
       , scremaType
       , soacType

       , typeCheckSOAC

       , mkIdentityLambda
       , isIdentityLambda
       , nilFn
       , scanomapSOAC
       , redomapSOAC
       , scanSOAC
       , reduceSOAC
       , mapSOAC
       , isScanomapSOAC
       , isRedomapSOAC
       , isScanSOAC
       , isReduceSOAC
       , isMapSOAC

       , ppScrema
       , ppHist

         -- * Generic traversal
       , SOACMapper(..)
       , identitySOACMapper
       , mapSOACM
       )
       where

import Control.Monad.State.Strict
import Control.Monad.Writer
import Control.Monad.Identity
import qualified Data.Map.Strict as M
import Data.Maybe
import Data.List (intersperse)

import Futhark.IR
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Util.Pretty as PP
import Futhark.Util.Pretty (ppr, Doc, Pretty, parens, comma, (</>), (<+>), commasep, text)
import Futhark.IR.Prop.Aliases
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Optimise.Simplify.Lore
import Futhark.IR.Aliases (Aliases, removeLambdaAliases)
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.TypeCheck as TC
import Futhark.Analysis.Metrics
import Futhark.Construct
import Futhark.Util (maybeNth, chunks)

-- | A second-order array combinator (SOAC).
data SOAC lore =
    Stream SubExp (StreamForm lore) (Lambda lore) [VName]
    -- ^ @Stream <length> <form> <lambda> <input arrays>@
    --
    -- See 'StreamForm' for the stream variants.  NOTE(review): by
    -- analogy with 'Hist' and 'Screma', the t'SubExp' is presumably
    -- the length of the input arrays — confirm.
  | Scatter SubExp (Lambda lore) [VName] [(SubExp, Int, VName)]
    -- ^ @Scatter <length> <lambda> <index and value arrays> <outputs>@
    --
    -- <outputs> lists the destination arrays along with their sizes
    -- and the number of values to write to each array.
    --
    -- <length> is the length of each index array and value array, since they
    -- all must be the same length for any fusion to make sense.  If you have a
    -- list of index-value array pairs of different sizes, you need to use
    -- multiple 'Scatter's instead.
    --
    -- The lambda body returns the output in this manner:
    --
    --     [index_0, index_1, ..., index_n, value_0, value_1, ..., value_n]
    --
    -- This must be consistent along all Scatter-related optimisations.
  | Hist SubExp [HistOp lore] (Lambda lore) [VName]
    -- ^ @Hist <length> <dest-arrays-and-ops> <bucket fun> <input arrays>@
    --
    -- The first SubExp is the length of the input arrays. The first
    -- list describes the operations to perform.  The t'Lambda' is the
    -- bucket function.  Finally comes the input images.
  | Screma SubExp (ScremaForm lore) [VName]
    -- ^ A combination of scan, reduction, and map.  The first
    -- t'SubExp' is the size of the input arrays.
    deriving (Eq, Ord, Show)

-- | Information about computing a single histogram.
data HistOp lore = HistOp { histWidth :: SubExp
                            -- ^ Width of the destination histogram.
                            -- NOTE(review): inferred from the name,
                            -- presumably the number of bins — confirm.
                          , histRaceFactor :: SubExp
                            -- ^ Race factor @RF@ means that only @1/RF@
                            -- bins are used.
                          , histDest :: [VName]
                            -- ^ The destination arrays.
                          , histNeutral :: [SubExp]
                            -- ^ Neutral elements, presumably for
                            -- 'histOp' — confirm.
                          , histOp :: Lambda lore
                            -- ^ The operator used to combine values
                            -- that land in the same bin.
                          }
                      deriving (Eq, Ord, Show)

-- | Is the stream chunk required to correspond to a contiguous
-- subsequence of the original input ('InOrder') or not?  'Disorder'
-- streams can be more efficient, but not all algorithms work with
-- this.
data StreamOrd = InOrder | Disorder
  deriving (Eq, Ord, Show)

-- | What kind of stream is this?
--
-- NOTE(review): in 'Parallel', the t'Lambda' and the t'SubExp's are
-- presumably a combining operator and the initial accumulator values
-- (cf. 'getStreamAccums'), while 'Sequential' carries only the
-- accumulators — confirm against the stream semantics.
data StreamForm lore =
    Parallel StreamOrd Commutativity (Lambda lore) [SubExp]
  | Sequential [SubExp]
  deriving (Eq, Ord, Show)

-- | The essential parts of a 'Screma' factored out (everything
-- except the input arrays).  In order: the scan operators, the
-- reduction operators, and the map function.
data ScremaForm lore = ScremaForm
                         [Scan lore]
                         [Reduce lore]
                         (Lambda lore)
  deriving (Eq, Ord, Show)

-- | Fuse several binary operators into one operator that computes
-- all of them side by side.
--
-- For each input lambda, the first @length (lambdaReturnType lam)@
-- parameters form its "x" half and the remaining parameters form its
-- "y" half.  The fused lambda takes every "x" half (in input order)
-- followed by every "y" half.  Its body is the concatenation of all
-- input statements, and it returns the concatenation of all input
-- results and return types.
--
-- No renaming is performed.  NOTE(review): the input lambdas
-- presumably must already use pairwise distinct names — confirm with
-- callers.
singleBinOp :: Bindable lore => [Lambda lore] -> Lambda lore
singleBinOp lams =
  Lambda { lambdaParams = concat xs ++ concat ys
         , lambdaReturnType = concatMap lambdaReturnType lams
         , lambdaBody = mkBody (mconcat (map (bodyStms . lambdaBody) lams))
                               (concatMap (bodyResult . lambdaBody) lams)
         }
  where (xs, ys) = unzip $ map splitParams lams
        -- Split a lambda's parameters into its "x" and "y" halves.
        splitParams lam =
          splitAt (length (lambdaReturnType lam)) (lambdaParams lam)

-- | How to compute a single scan result.
data Scan lore = Scan { scanLambda :: Lambda lore
                        -- ^ The scan operator.
                      , scanNeutral :: [SubExp]
                        -- ^ Neutral elements for the operator.  Each
                        -- one corresponds to one scan result (see
                        -- 'scanResults').
                      }
               deriving (Eq, Ord, Show)

-- | How many scan results are produced by these 'Scan's?  Each
-- 'Scan' contributes one result per neutral element.
scanResults :: [Scan lore] -> Int
scanResults = sum . map (length . scanNeutral)

-- | Combine multiple scan operators to a single operator.
--
-- The operator is fused with 'singleBinOp'.  The neutral elements are
-- the inputs' neutral elements concatenated in input order.
singleScan :: Bindable lore => [Scan lore] -> Scan lore
singleScan scans =
  let scan_nes = concatMap scanNeutral scans
      scan_lam = singleBinOp $ map scanLambda scans
  in Scan scan_lam scan_nes

-- | How to compute a single reduction result.
data Reduce lore = Reduce { redComm :: Commutativity
                            -- ^ Whether the operator is commutative.
                          , redLambda :: Lambda lore
                            -- ^ The reduction operator.
                          , redNeutral :: [SubExp]
                            -- ^ Neutral elements for the operator.
                            -- Each one corresponds to one reduction
                            -- result (see 'redResults').
                          }
                   deriving (Eq, Ord, Show)

-- | How many reduction results are produced by these 'Reduce's?
-- Each 'Reduce' contributes one result per neutral element.
redResults :: [Reduce lore] -> Int
redResults = sum . map (length . redNeutral)

-- | Combine multiple reduction operators to a single operator.
--
-- The operator is fused with 'singleBinOp'.  The neutral elements are
-- the inputs' neutral elements concatenated in input order.  The
-- commutativity flags are combined with the 'Monoid' instance of
-- 'Commutativity'.  NOTE(review): presumably the result is
-- commutative only if every input is — confirm.
singleReduce :: Bindable lore => [Reduce lore] -> Reduce lore
singleReduce reds =
  let red_nes = concatMap redNeutral reds
      red_lam = singleBinOp $ map redLambda reds
  in Reduce (mconcat (map redComm reds)) red_lam red_nes

-- | The types produced by a single 'Screma', given the size @w@ of the
-- input array.  Results appear in the order: scan results, reduction
-- results, then the remaining map results.  Scan and map results are
-- arrays with outer dimension @w@; reduction results keep the
-- operator's return type unchanged.
scremaType :: SubExp -> ScremaForm lore -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  concat [scan_tps, red_tps, map asRowsOfW map_tps]
  where
    -- Lift a per-element type to an array of @w@ such elements.
    asRowsOfW t = t `arrayOfRow` w

    scan_tps = [asRowsOfW t | s <- scans, t <- lambdaReturnType (scanLambda s)]
    red_tps = [t | r <- reds, t <- lambdaReturnType (redLambda r)]

    -- The map function returns the scan/reduce inputs first; whatever
    -- follows them is the mapped output proper.
    map_tps = drop (length scan_tps + length red_tps) (lambdaReturnType map_lam)

-- | Build a lambda with one fresh parameter (named from @"x"@) per
-- given type, whose body binds nothing and returns the parameters
-- unchanged, in order.
mkIdentityLambda :: (Bindable lore, MonadFreshNames m) =>
                    [Type] -> m (Lambda lore)
mkIdentityLambda ts = do
  params <- traverse (newParam "x") ts
  let results = [Var (paramName p) | p <- params]
  pure Lambda { lambdaParams = params
              , lambdaBody = mkBody mempty results
              , lambdaReturnType = ts
              }

-- | Is the given lambda an identity lambda?  This holds when the body's
-- result list is exactly the lambda's parameters, as variables, in the
-- same order.  The body's statements are not inspected.
isIdentityLambda :: Lambda lore -> Bool
isIdentityLambda lam =
  map (Var . paramName) (lambdaParams lam) == bodyResult (lambdaBody lam)

-- | A lambda with no parameters, an empty body, and no results.
nilFn :: Bindable lore => Lambda lore
nilFn = Lambda { lambdaParams = mempty
               , lambdaBody = mkBody mempty mempty
               , lambdaReturnType = mempty
               }

-- | Build a 'ScremaForm' from the given scans and map function, with no
-- reductions.
scanomapSOAC :: [Scan lore] -> Lambda lore -> ScremaForm lore
scanomapSOAC scans map_lam = ScremaForm scans [] map_lam

-- | Build a 'ScremaForm' from the given reductions and map function,
-- with no scans.
redomapSOAC :: [Reduce lore] -> Lambda lore -> ScremaForm lore
redomapSOAC reds map_lam = ScremaForm [] reds map_lam

-- | Build a scan-only 'ScremaForm'.  The map function is a fresh
-- identity lambda whose parameter types are the concatenated return
-- types of the scan operators.
scanSOAC :: (Bindable lore, MonadFreshNames m) =>
            [Scan lore] -> m (ScremaForm lore)
scanSOAC scans =
  let scan_ts = concatMap (lambdaReturnType . scanLambda) scans
  in fmap (scanomapSOAC scans) (mkIdentityLambda scan_ts)

-- | Build a reduction-only 'ScremaForm'.  The map function is a fresh
-- identity lambda whose parameter types are the concatenated return
-- types of the reduction operators.
reduceSOAC :: (Bindable lore, MonadFreshNames m) =>
              [Reduce lore] -> m (ScremaForm lore)
reduceSOAC reds =
  let red_ts = concatMap (lambdaReturnType . redLambda) reds
  in fmap (redomapSOAC reds) (mkIdentityLambda red_ts)

-- | Build a 'ScremaForm' that is a plain map: no scans, no reductions.
mapSOAC :: Lambda lore -> ScremaForm lore
mapSOAC map_lam = ScremaForm [] [] map_lam

-- | Recognise a scan-map composition.  Succeeds with the scans and the
-- map function when there are no reductions and at least one scan.
isScanomapSOAC :: ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC (ScremaForm scans reds map_lam)
  | null reds, not (null scans) = Just (scans, map_lam)
  | otherwise = Nothing

-- | Recognise a pure scan: a scan-map composition (see
-- 'isScanomapSOAC') whose map function is an identity lambda.
isScanSOAC :: ScremaForm lore -> Maybe [Scan lore]
isScanSOAC form =
  case isScanomapSOAC form of
    Just (scans, map_lam)
      | isIdentityLambda map_lam -> Just scans
    _ -> Nothing

-- | Recognise a reduce-map composition.  Succeeds with the reductions
-- and the map function when there are no scans and at least one
-- reduction.
isRedomapSOAC :: ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC (ScremaForm scans reds map_lam)
  | null scans, not (null reds) = Just (reds, map_lam)
  | otherwise = Nothing

-- | Recognise a pure reduction: a reduce-map composition (see
-- 'isRedomapSOAC') whose map function is an identity lambda.
isReduceSOAC :: ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC form =
  case isRedomapSOAC form of
    Just (reds, map_lam)
      | isIdentityLambda map_lam -> Just reds
    _ -> Nothing

-- | Recognise a simple map: succeeds with the map function when the
-- form has neither scans nor reductions.
isMapSOAC :: ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC (ScremaForm scans reds map_lam)
  | null scans && null reds = Just map_lam
  | otherwise = Nothing

-- | Like 'Mapper', but just for 'SOAC's: one monadic action per kind of
-- immediate child, transforming a SOAC in lore @flore@ into one in lore
-- @tlore@.
data SOACMapper flore tlore m = SOACMapper
  { mapOnSOACSubExp :: SubExp -> m SubExp
    -- ^ Action applied to each immediate 'SubExp' child.
  , mapOnSOACLambda :: Lambda flore -> m (Lambda tlore)
    -- ^ Action applied to each immediate 'Lambda' child.
  , mapOnSOACVName :: VName -> m VName
    -- ^ Action applied to each immediate 'VName' child.
  }

-- | A mapper whose every action returns its argument unchanged, so
-- mapping with it reproduces the SOAC verbatim.
identitySOACMapper :: Monad m => SOACMapper lore lore m
identitySOACMapper = SOACMapper pure pure pure

-- | Apply a 'SOACMapper' to every immediate child of a SOAC.  The
-- traversal does not descend into subexpressions, and effects are run
-- left-to-right in the order the children appear in each constructor
-- (nested 'HistOp', 'Scan', 'Reduce' and 'StreamForm' components are
-- visited in their own field order, at their position).
mapSOACM :: (Applicative m, Monad m) =>
            SOACMapper flore tlore m -> SOAC flore -> m (SOAC tlore)
mapSOACM tv soac =
  case soac of
    Stream size form lam arrs ->
      Stream <$> onSubExp size
             <*> onStreamForm form
             <*> onLambda lam
             <*> onVNames arrs
    Scatter len lam ivs dests ->
      Scatter <$> onSubExp len
              <*> onLambda lam
              <*> onVNames ivs
              <*> traverse onDest dests
    Hist len ops bucket_fun imgs ->
      Hist <$> onSubExp len
           <*> traverse onHistOp ops
           <*> onLambda bucket_fun
           <*> onVNames imgs
    Screma w (ScremaForm scans reds map_lam) arrs ->
      Screma <$> onSubExp w
             <*> (ScremaForm <$> traverse onScan scans
                             <*> traverse onReduce reds
                             <*> onLambda map_lam)
             <*> onVNames arrs
  where
    onSubExp = mapOnSOACSubExp tv
    onLambda = mapOnSOACLambda tv
    onVName = mapOnSOACVName tv
    onSubExps = traverse onSubExp
    onVNames = traverse onVName

    onStreamForm (Parallel o comm lam0 acc) =
      Parallel o comm <$> onLambda lam0 <*> onSubExps acc
    onStreamForm (Sequential acc) =
      Sequential <$> onSubExps acc

    -- Scatter destination triples: the integer component is passed
    -- through untouched; only the 'SubExp' and 'VName' are mapped.
    onDest (aw, an, a) = (,,) <$> onSubExp aw <*> pure an <*> onVName a

    onHistOp (HistOp e rf hist_arrs nes op) =
      HistOp <$> onSubExp e
             <*> onSubExp rf
             <*> onVNames hist_arrs
             <*> onSubExps nes
             <*> onLambda op

    onScan (Scan scan_lam nes) =
      Scan <$> onLambda scan_lam <*> onSubExps nes

    -- The commutativity flag is not a child and is kept as-is.
    onReduce (Reduce comm red_lam nes) =
      Reduce comm <$> onLambda red_lam <*> onSubExps nes

-- | The free variables of a SOAC are the union of the free variables of
-- all its immediate children, collected with 'mapSOACM'.
instance ASTLore lore => FreeIn (SOAC lore) where
  freeIn' soac = execState (mapSOACM collect soac) mempty
    where
      -- Add the free variables of a child to the state and hand the
      -- child back unchanged.
      note f x = x <$ modify (<> f x)
      collect = SOACMapper { mapOnSOACSubExp = note freeIn'
                           , mapOnSOACLambda = note freeIn'
                           , mapOnSOACVName = note freeIn'
                           }

-- | Name substitution in a SOAC is applied to every subexpression, lambda
-- and array name it refers to.
instance ASTLore lore => Substitute (SOAC lore) where
  substituteNames subst soac = runIdentity $ mapSOACM substituter soac
    where
      -- Substitute inside one component, lifted into 'Identity' so that
      -- it fits the monadic 'SOACMapper' interface.
      apply :: Substitute a => a -> Identity a
      apply = Identity . substituteNames subst
      substituter =
        SOACMapper
          { mapOnSOACSubExp = apply
          , mapOnSOACLambda = apply
          , mapOnSOACVName = apply
          }

-- | Renaming a SOAC renames each of its subexpressions, lambdas and
-- array names in turn.
instance ASTLore lore => Rename (SOAC lore) where
  rename soac = mapSOACM renamer soac
    where
      renamer =
        SOACMapper
          { mapOnSOACSubExp = rename
          , mapOnSOACLambda = rename
          , mapOnSOACVName = rename
          }

-- | The result types of a SOAC.
--
-- * 'Stream': the lambda's return types, where the first lambda parameter
--   is replaced by the outer size and the following accumulator
--   parameters by their initial values wherever they occur in shapes.
-- * 'Scatter': one array per destination, of the destination's width.
-- * 'Hist': for every operator, the return types of its lambda, each
--   lifted to an array of the operator's 'histWidth'.
-- * 'Screma': delegated to 'scremaType'.
soacType :: SOAC lore -> [Type]
soacType (Stream outersize form lam _) =
  [substNamesInType substs t | t <- lambdaReturnType lam]
  where
    accs =
      case form of
        Sequential acc -> acc
        Parallel _ _ _ acc -> acc
    -- The leading parameter (bound to the outer size) followed by one
    -- parameter per accumulator.
    leading_params = take (length accs + 1) $ lambdaParams lam
    substs =
      M.fromList $ zip (map paramName leading_params) (outersize : accs)
soacType (Scatter _w lam _ivs dests) =
  zipWith arrayOfRow value_ts widths
  where
    (widths, counts, _) = unzip3 dests
    -- Skip the first @sum counts@ return types (presumably the write
    -- indexes), split the remainder per destination, and keep the first
    -- type of every non-empty group.
    value_ts =
      [ t
      | t : _ <- chunks counts $ drop (sum counts) $ lambdaReturnType lam
      ]
soacType (Hist _len ops _bucket_fun _imgs) =
  concat
    [ map (`arrayOfRow` histWidth op) $ lambdaReturnType $ histOp op
    | op <- ops
    ]
soacType (Screma w form _arrs) = scremaType w form

-- | The type of a SOAC, as seen by the generic IR machinery.
instance TypedOp (SOAC lore) where
  -- 'soacType' yields non-existential types, which are lifted to
  -- (static) extended types here.
  opType soac = pure $ staticShapes $ soacType soac

-- | SOAC results alias nothing; what a SOAC consumes depends on its kind.
instance (ASTLore lore, Aliased lore) => AliasedOp (SOAC lore) where
  -- One empty alias set per result.
  opAliases soac = [mempty | _ <- soacType soac]

  -- For a Screma, only the map function can consume anything.  The
  -- operands to scan and reduce functions are always considered "fresh".
  -- A consumed map parameter is reported as the array bound to it.
  consumedInOp (Screma _ (ScremaForm _ _ map_lam) arrs) =
    mapNames toArray $ consumedByLambda map_lam
    where
      param_arrs = zip (map paramName $ lambdaParams map_lam) arrs
      toArray v = fromMaybe v $ lookup v param_arrs
  -- For a Stream, a consumed accumulator or array parameter is reported
  -- as the corresponding input; the result is filtered through
  -- 'subExpVars'.
  consumedInOp (Stream _ form lam arrs) =
    namesFromList $ subExpVars $ map toInput $ namesToList $ consumedByLambda lam
    where
      accs =
        case form of
          Sequential acc -> acc
          Parallel _ _ _ acc -> acc
      -- Drop the chunk parameter, which cannot alias anything.
      param_inputs =
        zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
      toInput v = fromMaybe (Var v) $ lookup v param_inputs
  -- A Scatter consumes its destination arrays.
  consumedInOp (Scatter _ _ _ dests) =
    namesFromList [dest | (_, _, dest) <- dests]
  -- A Hist consumes the destination arrays of all its operators.
  consumedInOp (Hist _ ops _ _) =
    namesFromList $ concatMap histDest ops
ops

-- | Transform the operator lambda of a 'HistOp'. All other fields are
-- carried over unchanged.
mapHistOp :: (Lambda flore -> Lambda tlore)
          -> HistOp flore -> HistOp tlore
mapHistOp f (HistOp width rf dests neutral lam) =
  let lam' = f lam
   in HistOp width rf dests neutral lam'

-- | Alias annotation of SOACs: every lambda inside the SOAC is analysed
-- with 'Alias.analyseLambda', and annotations are stripped again with
-- 'removeLambdaAliases'.
instance (ASTLore lore,
          ASTLore (Aliases lore),
          CanBeAliased (Op lore)) => CanBeAliased (SOAC lore) where
  type OpWithAliases (SOAC lore) = SOAC (Aliases lore)

  -- Widths, array names and accumulator values are copied unchanged;
  -- only the lambdas are analysed.
  addOpAliases (Stream size form lam arr) =
    Stream size form' (Alias.analyseLambda lam) arr
    where
      form' =
        case form of
          Parallel o comm lam0 acc ->
            Parallel o comm (Alias.analyseLambda lam0) acc
          Sequential acc ->
            Sequential acc
  addOpAliases (Scatter len lam ivs as) =
    Scatter len (Alias.analyseLambda lam) ivs as
  addOpAliases (Hist len ops bucket_fun imgs) =
    Hist len ops' (Alias.analyseLambda bucket_fun) imgs
    where
      ops' = map (mapHistOp Alias.analyseLambda) ops
  addOpAliases (Screma w (ScremaForm scans reds map_lam) arrs) =
    Screma w form' arrs
    where
      form' =
        ScremaForm
          [ scan {scanLambda = Alias.analyseLambda (scanLambda scan)}
          | scan <- scans
          ]
          [ red {redLambda = Alias.analyseLambda (redLambda red)}
          | red <- reds
          ]
          (Alias.analyseLambda map_lam)

  -- Strip the alias annotations from every lambda; subexpressions and
  -- array names pass through untouched.
  removeOpAliases soac = runIdentity $ mapSOACM stripper soac
    where
      stripper =
        SOACMapper
          { mapOnSOACSubExp = pure
          , mapOnSOACLambda = pure . removeLambdaAliases
          , mapOnSOACVName = pure
          }

-- | Every SOAC is reported as not safe but cheap, regardless of its
-- contents.
instance ASTLore lore => IsOp (SOAC lore) where
  -- NOTE(review): presumably "unsafe" keeps SOACs from being executed
  -- speculatively; confirm against the 'IsOp' class documentation.
  safeOp = const False
  cheapOp = const True

-- | Apply a name-to-subexpression substitution to the dimensions of an
-- array type. Primitive and memory types have no dimensions and are
-- returned unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t =
  case t of
    Array btp shp u ->
      Array btp (Shape [substNamesInSubExp subs d | d <- shapeDims shp]) u
    Mem space -> Mem space
    Prim _ -> t

-- | Replace a variable by its image under the substitution, if it has
-- one. Constants and unmapped variables are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Var v -> fromMaybe se $ M.lookup v subs
    Constant _ -> se

-- | Wisdom annotation of SOACs lives entirely in their lambdas.
instance (ASTLore lore, CanBeWise (Op lore)) => CanBeWise (SOAC lore) where
  type OpWithWisdom (SOAC lore) = SOAC (Wise lore)

  -- Strip the 'Wise' annotations from every lambda; subexpressions and
  -- array names pass through untouched.
  removeOpWisdom soac = runIdentity $ mapSOACM unwise soac
    where
      unwise =
        SOACMapper
          { mapOnSOACSubExp = pure
          , mapOnSOACLambda = pure . removeLambdaWisdom
          , mapOnSOACVName = pure
          }

instance Decorations lore => ST.IndexOp (SOAC lore) where
  indexOp :: SymbolTable lore
-> Int -> SOAC lore -> [PrimExp VName] -> Maybe Indexed
indexOp SymbolTable lore
vtable Int
k SOAC lore
soac [PrimExp VName
i] = do
    (LambdaT lore
lam,SubExp
se,[Param (LParamInfo lore)]
arr_params,[VName]
arrs) <- SOAC lore
-> Maybe (LambdaT lore, SubExp, [Param (LParamInfo lore)], [VName])
lambdaAndSubExp SOAC lore
soac
    let arr_indexes :: Map VName (PrimExp VName, Certificates)
arr_indexes = [(VName, (PrimExp VName, Certificates))]
-> Map VName (PrimExp VName, Certificates)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, (PrimExp VName, Certificates))]
 -> Map VName (PrimExp VName, Certificates))
-> [(VName, (PrimExp VName, Certificates))]
-> Map VName (PrimExp VName, Certificates)
forall a b. (a -> b) -> a -> b
$ [Maybe (VName, (PrimExp VName, Certificates))]
-> [(VName, (PrimExp VName, Certificates))]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe (VName, (PrimExp VName, Certificates))]
 -> [(VName, (PrimExp VName, Certificates))])
-> [Maybe (VName, (PrimExp VName, Certificates))]
-> [(VName, (PrimExp VName, Certificates))]
forall a b. (a -> b) -> a -> b
$ (Param (LParamInfo lore)
 -> VName -> Maybe (VName, (PrimExp VName, Certificates)))
-> [Param (LParamInfo lore)]
-> [VName]
-> [Maybe (VName, (PrimExp VName, Certificates))]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Param (LParamInfo lore)
-> VName -> Maybe (VName, (PrimExp VName, Certificates))
arrIndex [Param (LParamInfo lore)]
arr_params [VName]
arrs
        arr_indexes' :: Map VName (PrimExp VName, Certificates)
arr_indexes' = (Map VName (PrimExp VName, Certificates)
 -> Stm lore -> Map VName (PrimExp VName, Certificates))
-> Map VName (PrimExp VName, Certificates)
-> Seq (Stm lore)
-> Map VName (PrimExp VName, Certificates)
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl Map VName (PrimExp VName, Certificates)
-> Stm lore -> Map VName (PrimExp VName, Certificates)
expandPrimExpTable Map VName (PrimExp VName, Certificates)
arr_indexes (Seq (Stm lore) -> Map VName (PrimExp VName, Certificates))
-> Seq (Stm lore) -> Map VName (PrimExp VName, Certificates)
forall a b. (a -> b) -> a -> b
$ BodyT lore -> Seq (Stm lore)
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT lore -> Seq (Stm lore)) -> BodyT lore -> Seq (Stm lore)
forall a b. (a -> b) -> a -> b
$ LambdaT lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT lore
lam
    case SubExp
se of
      Var VName
v -> (PrimExp VName -> Certificates -> Indexed)
-> (PrimExp VName, Certificates) -> Indexed
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ((Certificates -> PrimExp VName -> Indexed)
-> PrimExp VName -> Certificates -> Indexed
forall a b c. (a -> b -> c) -> b -> a -> c
flip Certificates -> PrimExp VName -> Indexed
ST.Indexed) ((PrimExp VName, Certificates) -> Indexed)
-> Maybe (PrimExp VName, Certificates) -> Maybe Indexed
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> Map VName (PrimExp VName, Certificates)
-> Maybe (PrimExp VName, Certificates)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName (PrimExp VName, Certificates)
arr_indexes'
      SubExp
_ -> Maybe Indexed
forall a. Maybe a
Nothing
      where lambdaAndSubExp :: SOAC lore
-> Maybe (LambdaT lore, SubExp, [Param (LParamInfo lore)], [VName])
lambdaAndSubExp (Screma SubExp
_ (ScremaForm [Scan lore]
scans [Reduce lore]
reds LambdaT lore
map_lam) [VName]
arrs) =
              Int
-> LambdaT lore
-> [VName]
-> Maybe (LambdaT lore, SubExp, [Param (LParamInfo lore)], [VName])
nthMapOut ([Scan lore] -> Int
forall lore. [Scan lore] -> Int
scanResults [Scan lore]
scans Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce lore] -> Int
forall lore. [Reduce lore] -> Int
redResults [Reduce lore]
reds) LambdaT lore
map_lam [VName]
arrs
            lambdaAndSubExp SOAC lore
_ =
              Maybe (LambdaT lore, SubExp, [Param (LParamInfo lore)], [VName])
forall a. Maybe a
Nothing

            nthMapOut :: Int
-> LambdaT lore
-> [VName]
-> Maybe (LambdaT lore, SubExp, [Param (LParamInfo lore)], [VName])
nthMapOut Int
num_accs LambdaT lore
lam [VName]
arrs = do
              SubExp
se <- Int -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth (Int
num_accsInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
k) ([SubExp] -> Maybe SubExp) -> [SubExp] -> Maybe SubExp
forall a b. (a -> b) -> a -> b
$ BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT lore -> [SubExp]) -> BodyT lore -> [SubExp]
forall a b. (a -> b) -> a -> b
$ LambdaT lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT lore
lam
              (LambdaT lore, SubExp, [Param (LParamInfo lore)], [VName])
-> Maybe (LambdaT lore, SubExp, [Param (LParamInfo lore)], [VName])
forall (m :: * -> *) a. Monad m => a -> m a
return (LambdaT lore
lam, SubExp
se, Int -> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a. Int -> [a] -> [a]
drop Int
num_accs ([Param (LParamInfo lore)] -> [Param (LParamInfo lore)])
-> [Param (LParamInfo lore)] -> [Param (LParamInfo lore)]
forall a b. (a -> b) -> a -> b
$ LambdaT lore -> [Param (LParamInfo lore)]
forall lore. LambdaT lore -> [Param (LParamInfo lore)]
lambdaParams LambdaT lore
lam, [VName]
arrs)

            arrIndex :: Param (LParamInfo lore)
-> VName -> Maybe (VName, (PrimExp VName, Certificates))
arrIndex Param (LParamInfo lore)
p VName
arr = do
              ST.Indexed Certificates
cs PrimExp VName
pe <- VName -> [PrimExp VName] -> SymbolTable lore -> Maybe Indexed
forall lore.
VName -> [PrimExp VName] -> SymbolTable lore -> Maybe Indexed
ST.index' VName
arr [PrimExp VName
i] SymbolTable lore
vtable
              (VName, (PrimExp VName, Certificates))
-> Maybe (VName, (PrimExp VName, Certificates))
forall (m :: * -> *) a. Monad m => a -> m a
return (Param (LParamInfo lore) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo lore)
p, (PrimExp VName
pe,Certificates
cs))

            expandPrimExpTable :: Map VName (PrimExp VName, Certificates)
-> Stm lore -> Map VName (PrimExp VName, Certificates)
expandPrimExpTable Map VName (PrimExp VName, Certificates)
table Stm lore
stm
              | [VName
v] <- PatternT (LetDec lore) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (PatternT (LetDec lore) -> [VName])
-> PatternT (LetDec lore) -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm lore -> PatternT (LetDec lore)
forall lore. Stm lore -> Pattern lore
stmPattern Stm lore
stm,
                Just (PrimExp VName
pe,Certificates
cs) <-
                  WriterT Certificates Maybe (PrimExp VName)
-> Maybe (PrimExp VName, Certificates)
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT Certificates Maybe (PrimExp VName)
 -> Maybe (PrimExp VName, Certificates))
-> WriterT Certificates Maybe (PrimExp VName)
-> Maybe (PrimExp VName, Certificates)
forall a b. (a -> b) -> a -> b
$ (VName -> WriterT Certificates Maybe (PrimExp VName))
-> Exp lore -> WriterT Certificates Maybe (PrimExp VName)
forall (m :: * -> *) lore v.
(MonadFail m, Decorations lore) =>
(VName -> m (PrimExp v)) -> Exp lore -> m (PrimExp v)
primExpFromExp (Map VName (PrimExp VName, Certificates)
-> VName -> WriterT Certificates Maybe (PrimExp VName)
asPrimExp Map VName (PrimExp VName, Certificates)
table) (Exp lore -> WriterT Certificates Maybe (PrimExp VName))
-> Exp lore -> WriterT Certificates Maybe (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ Stm lore -> Exp lore
forall lore. Stm lore -> Exp lore
stmExp Stm lore
stm,
                (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable lore -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable lore
vtable) (Certificates -> [VName]
unCertificates (Certificates -> [VName]) -> Certificates -> [VName]
forall a b. (a -> b) -> a -> b
$ Stm lore -> Certificates
forall lore. Stm lore -> Certificates
stmCerts Stm lore
stm) =
                  VName
-> (PrimExp VName, Certificates)
-> Map VName (PrimExp VName, Certificates)
-> Map VName (PrimExp VName, Certificates)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
v (PrimExp VName
pe, Stm lore -> Certificates
forall lore. Stm lore -> Certificates
stmCerts Stm lore
stm Certificates -> Certificates -> Certificates
forall a. Semigroup a => a -> a -> a
<> Certificates
cs) Map VName (PrimExp VName, Certificates)
table
              | Bool
otherwise =
                  Map VName (PrimExp VName, Certificates)
table

            asPrimExp :: Map VName (PrimExp VName, Certificates)
-> VName -> WriterT Certificates Maybe (PrimExp VName)
asPrimExp Map VName (PrimExp VName, Certificates)
table VName
v
              | Just (PrimExp VName
e,Certificates
cs) <- VName
-> Map VName (PrimExp VName, Certificates)
-> Maybe (PrimExp VName, Certificates)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
v Map VName (PrimExp VName, Certificates)
table = Certificates -> WriterT Certificates Maybe ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell Certificates
cs WriterT Certificates Maybe ()
-> WriterT Certificates Maybe (PrimExp VName)
-> WriterT Certificates Maybe (PrimExp VName)
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> PrimExp VName -> WriterT Certificates Maybe (PrimExp VName)
forall (m :: * -> *) a. Monad m => a -> m a
return PrimExp VName
e
              | Just (Prim PrimType
pt) <- VName -> SymbolTable lore -> Maybe Type
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe Type
ST.lookupType VName
v SymbolTable lore
vtable =
                  PrimExp VName -> WriterT Certificates Maybe (PrimExp VName)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimExp VName -> WriterT Certificates Maybe (PrimExp VName))
-> PrimExp VName -> WriterT Certificates Maybe (PrimExp VName)
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> PrimExp VName
forall v. v -> PrimType -> PrimExp v
LeafExp VName
v PrimType
pt
              | Bool
otherwise = Maybe (PrimExp VName) -> WriterT Certificates Maybe (PrimExp VName)
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift Maybe (PrimExp VName)
forall a. Maybe a
Nothing
  indexOp SymbolTable lore
_ Int
_ SOAC lore
_ [PrimExp VName]
_ = Maybe Indexed
forall a. Maybe a
Nothing

-- | Type-check a SOAC.
typeCheckSOAC :: TC.Checkable lore => SOAC (Aliases lore) -> TC.TypeM lore ()
typeCheckSOAC (Stream size form lam arrexps) = do
  let accexps = getStreamAccums form
  TC.require [Prim int32] size
  accargs <- mapM TC.checkArg accexps
  arrargs <- mapM lookupType arrexps
  _ <- TC.checkSOACArrayArgs size arrexps
  -- The first lambda parameter is the chunk size.  Report a proper type
  -- error rather than crashing (as 'head' would) when the lambda takes no
  -- parameters at all.
  chunk <- case lambdaParams lam of
    p : _ -> return p
    [] -> TC.bad $ TC.TypeError "Stream lambda must take a chunk size parameter."
  let asArg t = (t, mempty)
      inttp = Prim int32
      -- Inside the lambda, the array arguments have the chunk size as
      -- their outer dimension.
      lamarrs' = map (`setOuterSize` Var (paramName chunk)) arrargs
      acc_len = length accexps
      lamrtp = take acc_len $ lambdaReturnType lam
  unless (map TC.argType accargs == lamrtp) $
    TC.bad $ TC.TypeError "Stream with inconsistent accumulator type in lambda."
  -- Check the reduction lambda of a parallel stream, if any.
  case form of
    Parallel _ _ lam0 _ -> do
      let acct = map TC.argType accargs
          outerRetType = lambdaReturnType lam0
      TC.checkLambda lam0 $ map TC.noArgAliases $ accargs ++ accargs
      unless (acct == outerRetType) $
        TC.bad $ TC.TypeError $
          "Initial value is of type " ++ prettyTuple acct ++
          ", but stream's reduce lambda returns type " ++ prettyTuple outerRetType ++ "."
    _ -> return ()
  -- just get the dflow of lambda on the fakearg, which does not alias
  -- arr, so we can later check that aliases of arr are not used inside lam.
  let fake_lamarrs' = map asArg lamarrs'
  TC.checkLambda lam $ asArg inttp : accargs ++ fake_lamarrs'

typeCheckSOAC (Scatter w lam ivs as) = do
  -- Requirements:
  --
  --   0. @lambdaReturnType@ of @lam@ must be a list
  --      [index types..., value types].
  --
  --   1. The number of index types must be equal to the number of value types
  --      and the number of writes to arrays in @as@.  Equivalently, the total
  --      number of return types must be exactly twice the number of writes.
  --
  --   2. Each index type must have the type i32.
  --
  --   3. Each array in @as@ and the value types must have the same type
  --
  --   4. Each array in @as@ is consumed.  This is not really a check, but more
  --      of a requirement, so that e.g. the source is not hoisted out of a
  --      loop, which will mean it cannot be consumed.
  --
  --   5. Each of ivs must be an array matching a corresponding lambda
  --      parameters.
  --
  -- Code:

  -- First check the input size.
  TC.require [Prim int32] w

  -- 0.
  let (_as_ws, as_ns, _as_vs) = unzip3 as
      rts = lambdaReturnType lam
      rtsLen = length rts `div` 2
      rtsI = take rtsLen rts
      rtsV = drop rtsLen rts

  -- 1.  Compare the full length rather than the halved one: with integer
  -- division, an odd number of return types would otherwise pass this
  -- check and the surplus value type would be silently ignored below.
  unless (length rts == 2 * sum as_ns) $
    TC.bad $ TC.TypeError "Scatter: Uneven number of index types, value types, and arrays outputs."

  -- 2.
  forM_ rtsI $ \rtI -> unless (Prim int32 == rtI) $
    TC.bad $ TC.TypeError "Scatter: Index return type must be i32."

  forM_ (zip (chunks as_ns rtsV) as) $ \(rtVs, (aw, _, a)) -> do
    -- All lengths must have type i32.
    TC.require [Prim int32] aw

    -- 3.
    forM_ rtVs $ \rtV -> TC.requireI [rtV `arrayOfRow` aw] a

    -- 4.
    TC.consume =<< TC.lookupAliases a

  -- 5.
  arrargs <- TC.checkSOACArrayArgs w ivs
  TC.checkLambda lam arrargs

typeCheckSOAC (Hist len ops bucket_fun imgs) = do
  TC.require [Prim int32] len

  -- Check the operators.
  forM_ ops $ \(HistOp dest_w rf dests nes op) -> do
    nes' <- mapM TC.checkArg nes
    TC.require [Prim int32] dest_w
    TC.require [Prim int32] rf

    -- Operator type must match the type of neutral elements.
    TC.checkLambda op $ map TC.noArgAliases $ nes' ++ nes'
    let nes_t = map TC.argType nes'
    unless (nes_t == lambdaReturnType op) $
      TC.bad $ TC.TypeError $ "Operator has return type " ++
      prettyTuple (lambdaReturnType op) ++ " but neutral element has type " ++
      prettyTuple nes_t

    -- There must be exactly one destination array per neutral element;
    -- otherwise the 'zip' below would silently skip type-checking (and
    -- consuming) any surplus destination arrays.
    unless (length dests == length nes) $
      TC.bad $ TC.TypeError $ "Operator has " ++ show (length dests) ++
      " destination arrays but " ++ show (length nes) ++ " neutral elements."

    -- Arrays must have proper type.
    forM_ (zip nes_t dests) $ \(t, dest) -> do
      TC.requireI [t `arrayOfRow` dest_w] dest
      TC.consume =<< TC.lookupAliases dest

  -- Types of input arrays must equal parameter types for bucket function.
  img' <- TC.checkSOACArrayArgs len imgs
  TC.checkLambda bucket_fun img'

  -- Return type of bucket function must be an index for each
  -- operation followed by the values to write.
  nes_ts <- concat <$> mapM (mapM subExpType . histNeutral) ops
  let bucket_ret_t = replicate (length ops) (Prim int32) ++ nes_ts
  unless (bucket_ret_t == lambdaReturnType bucket_fun) $
    TC.bad $ TC.TypeError $ "Bucket function has return type " ++
    prettyTuple (lambdaReturnType bucket_fun) ++ " but should have type " ++
    prettyTuple bucket_ret_t

typeCheckSOAC (Screma w (ScremaForm scans reds map_lam) arrs) = do
  TC.require [Prim int32] w
  arr_args <- TC.checkSOACArrayArgs w arrs
  TC.checkLambda map_lam $ map TC.noArgAliases arr_args

  -- Check every scan operator first, then every reduction operator,
  -- collecting their checked neutral elements in order.
  scan_args <- concat <$>
    mapM (\(Scan scan_lam scan_nes) ->
            checkOperator "Scan function returns type " scan_lam scan_nes)
         scans
  red_args <- concat <$>
    mapM (\(Reduce _ red_lam red_nes) ->
            checkOperator "Reduce function returns type " red_lam red_nes)
         reds

  -- The map function must first produce the scan inputs, then the
  -- reduction inputs; any further results are mapped outputs.
  let map_ts = lambdaReturnType map_lam
      op_ts = map TC.argType $ scan_args ++ red_args
  unless (take (length op_ts) map_ts == op_ts) $
    TC.bad $ TC.TypeError $ "Map function return type " ++ prettyTuple map_ts ++
    " wrong for given scan and reduction functions."
  where
    -- Check an operator lambda against its neutral elements: the lambda
    -- must accept two copies of the neutral elements and return their
    -- types.  Returns the checked neutral elements.  The message prefix
    -- names the kind of operator in the error report.
    checkOperator msg_prefix op_lam nes = do
      nes_args <- mapM TC.checkArg nes
      let nes_ts = map TC.argType nes_args
      TC.checkLambda op_lam $ map TC.noArgAliases $ nes_args ++ nes_args
      unless (nes_ts == lambdaReturnType op_lam) $
        TC.bad $ TC.TypeError $ msg_prefix ++
        prettyTuple (lambdaReturnType op_lam) ++ " but neutral element has type " ++
        prettyTuple nes_ts
      return nes_args

-- | The accumulator initialisers of a stream, as a list of
-- sub-expressions.  Both parallel and sequential stream forms carry
-- such a list; this simply extracts it.
getStreamAccums :: StreamForm lore -> [SubExp]
getStreamAccums form =
  case form of
    Parallel _ _ _ accs -> accs
    Sequential accs -> accs

-- | Metrics for SOACs: each SOAC is recorded under its constructor name,
-- and the lambdas it contains are traversed inside that scope.
instance OpMetrics (Op lore) => OpMetrics (SOAC lore) where
  opMetrics soac =
    case soac of
      -- NOTE(review): only the stream body is traversed; the reduction
      -- lambda held by a 'Parallel' stream form is not counted.  Confirm
      -- this is intended.
      Stream _ _ lam _ ->
        inside "Stream" $ lambdaMetrics lam
      Scatter _ lam _ _ ->
        inside "Scatter" $ lambdaMetrics lam
      Hist _ ops bucket_fun _ ->
        inside "Hist" $ do
          mapM_ (lambdaMetrics . histOp) ops
          lambdaMetrics bucket_fun
      Screma _ (ScremaForm scans reds map_lam) _ ->
        inside "Screma" $ do
          mapM_ (lambdaMetrics . scanLambda) scans
          mapM_ (lambdaMetrics . redLambda) reds
          lambdaMetrics map_lam

-- | Prettyprinting of SOACs.  Streams print their form in the keyword
-- (@streamPar…@ or @streamSeq@).  Scremas with only a map, only
-- reductions, or only scans get the dedicated keywords @map@, @redomap@
-- and @scanomap@.  Any other Screma is printed via 'ppScrema'.
instance PrettyLore lore => PP.Pretty (SOAC lore) where
  ppr (Stream size form lam arrs) =
    case form of
      Parallel o comm lam0 acc ->
        -- Every lambda argument is followed by a comma, exactly as in
        -- the sequential case.  Previously the comma after the reduction
        -- lambda was detached from it, and the one between the fold
        -- lambda and the accumulators was missing.
        let ord_str = if o == Disorder then "Per" else ""
            comm_str = case comm of
                         Commutative -> "Comm"
                         Noncommutative -> ""
        in text ("streamPar" ++ ord_str ++ comm_str) <>
           parens (ppr size <> comma </>
                   ppr lam0 <> comma </>
                   ppr lam <> comma </>
                   accsAndArrs acc)
      Sequential acc ->
        text "streamSeq" <>
        parens (ppr size <> comma </>
                ppr lam <> comma </>
                accsAndArrs acc)
    where
      -- Braced accumulator initialisers, followed by the input arrays,
      -- all comma-separated.
      accsAndArrs acc =
        commasep (PP.braces (commasep $ map ppr acc) : map ppr arrs)
  ppr (Scatter len lam ivs as) =
    ppSOAC "scatter" len [lam] (Just (map Var ivs)) (map (\(_, n, a) -> (n, a)) as)
  ppr (Hist len ops bucket_fun imgs) =
    ppHist len ops bucket_fun imgs
  ppr (Screma w (ScremaForm scans reds map_lam) arrs)
    | null scans, null reds =
        text "map" <>
        parens (ppr w <> comma </>
                ppr map_lam <> comma </>
                commasep (map ppr arrs))
    | null scans =
        text "redomap" <>
        parens (ppr w <> comma </>
                PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppr reds) <> comma </>
                ppr map_lam <> comma </>
                commasep (map ppr arrs))
    | null reds =
        text "scanomap" <>
        parens (ppr w <> comma </>
                PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppr scans) <> comma </>
                ppr map_lam <> comma </>
                commasep (map ppr arrs))
  ppr (Screma w form arrs) = ppScrema w form arrs

-- | Prettyprint the given Screma.  The output contains, in order:
--
-- * the width;
-- * the braced scan operators;
-- * the braced reduction operators;
-- * the map lambda;
-- * the input arrays.
ppScrema :: (PrettyLore lore, Pretty inp) =>
            SubExp -> ScremaForm lore -> [inp] -> Doc
ppScrema w (ScremaForm scans reds map_lam) arrs =
  text "screma" <> parens args
  where
    -- Operators are joined by a comma and a line break, then braced.
    operators docs = PP.braces $ mconcat $ intersperse (comma <> PP.line) docs
    args =
      ppr w <> comma </>
      operators (map ppr scans) <> comma </>
      operators (map ppr reds) <> comma </>
      ppr map_lam <> comma </>
      commasep (map ppr arrs)

-- | A scan prints as its operator lambda, a comma, and then its neutral
-- elements in braces.
instance PrettyLore lore => Pretty (Scan lore) where
  ppr (Scan scan_lam scan_nes) = lam_doc <> comma </> nes_doc
    where lam_doc = ppr scan_lam
          nes_doc = PP.braces $ commasep $ map ppr scan_nes

-- | Prefix printed before a reduction operator.  Commutative operators
-- get the text @commutative @ (with a trailing space); noncommutative
-- ones get nothing.
ppComm :: Commutativity -> Doc
ppComm comm =
  case comm of
    Commutative -> text "commutative "
    Noncommutative -> mempty

-- | A reduction prints as its optional commutativity marker, its
-- operator lambda, a comma, and then its neutral elements in braces.
instance PrettyLore lore => Pretty (Reduce lore) where
  ppr (Reduce comm red_lam red_nes) = header </> neutral
    where header = ppComm comm <> ppr red_lam <> comma
          neutral = PP.braces $ commasep $ map ppr red_nes

-- | Prettyprint the given histogram operation.  The output contains the
-- input length, the braced histogram operators, the bucket function and
-- the input arrays.
ppHist :: (PrettyLore lore, Pretty inp) =>
          SubExp -> [HistOp lore] -> Lambda lore -> [inp] -> Doc
ppHist len ops bucket_fun imgs =
  text "hist" <> parens args
  where
    args =
      ppr len <> comma </>
      PP.braces (mconcat $ intersperse (comma <> PP.line) $ map ppHistOp ops) <> comma </>
      ppr bucket_fun <> comma </>
      commasep (map ppr imgs)

    -- The width, race factor and destinations are joined with '<+>'.
    -- The neutral elements and the operator lambda follow, joined with '</>'.
    ppHistOp (HistOp w rf dests nes op) =
      ppr w <> comma <+> ppr rf <> comma <+> bracedList dests <> comma </>
      bracedList nes <> comma </> ppr op

    bracedList xs = PP.braces $ commasep $ map ppr xs

-- | Generic SOAC printer.  It prints the name, then in parentheses:
--
-- * the size;
-- * the functions, each followed by a comma (see 'ppList');
-- * the optional expression list, rendered with ppTuple';
-- * the remaining arguments.
ppSOAC :: (Pretty fn, Pretty v) =>
          String -> SubExp -> [fn] -> Maybe [SubExp] -> [v] -> Doc
ppSOAC name size funs es as =
  text name <> parens (sizeDoc </> funsDoc </> argsDoc)
  where
    sizeDoc = ppr size <> comma
    funsDoc = ppList funs
    argsDoc = commasep $ [ppTuple' ses | Just ses <- [es]] ++ map ppr as

-- | Print every element followed by a comma, with consecutive elements
-- joined by '</>'.  The empty list prints as 'mempty'.
ppList :: Pretty a => [a] -> Doc
ppList xs = joinDocs $ map ((<> comma) . ppr) xs
  where
    joinDocs [] = mempty
    joinDocs [d] = d
    joinDocs (d : ds) = d </> joinDocs ds