{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Futhark.IR.SOACS.SOAC
( SOAC (..),
StreamOrd (..),
StreamForm (..),
ScremaForm (..),
HistOp (..),
Scan (..),
scanResults,
singleScan,
Reduce (..),
redResults,
singleReduce,
scremaType,
soacType,
typeCheckSOAC,
mkIdentityLambda,
isIdentityLambda,
nilFn,
scanomapSOAC,
redomapSOAC,
scanSOAC,
reduceSOAC,
mapSOAC,
isScanomapSOAC,
isRedomapSOAC,
isScanSOAC,
isReduceSOAC,
isMapSOAC,
ppScrema,
ppHist,
groupScatterResults,
groupScatterResults',
splitScatterResults,
SOACMapper (..),
identitySOACMapper,
mapSOACM,
)
where
import Control.Category
import Control.Monad.Identity
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Function ((&))
import Data.List (intersperse)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.Metrics
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Aliases (Aliases, removeLambdaAliases)
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rep
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import qualified Futhark.TypeCheck as TC
import Futhark.Util (chunks, maybeNth)
import Futhark.Util.Pretty (Doc, Pretty, comma, commasep, parens, ppr, text, (<+>), (</>))
import qualified Futhark.Util.Pretty as PP
import Prelude hiding (id, (.))
-- | A second-order array combinator (SOAC).
--
-- The derived 'Eq', 'Ord' and 'Show' instances carry a @RepTypes rep@
-- context, which GHC infers from the 'Lambda' fields.
data SOAC rep
  = -- | Streaming SOAC over the named input arrays.  The 'StreamForm'
    -- selects parallel or sequential streaming; the @[SubExp]@ field is
    -- presumably the initial accumulator values -- TODO confirm.
    Stream SubExp [VName] (StreamForm rep) [SubExp] (Lambda rep)
  | -- | Scatter: a 'Lambda' applied to the input arrays.  Each
    -- @(Shape, Int, VName)@ triple presumably describes one destination
    -- array (see 'groupScatterResults' / 'splitScatterResults').
    Scatter SubExp (Lambda rep) [VName] [(Shape, Int, VName)]
  | -- | Generalised histogram: one or more 'HistOp's, a bucket
    -- 'Lambda', and the input arrays.
    Hist SubExp [HistOp rep] (Lambda rep) [VName]
  | -- | Fused scan/reduce/map over the input arrays; see 'ScremaForm'.
    Screma SubExp [VName] (ScremaForm rep)
  deriving (Eq, Ord, Show)
-- | One histogram operation inside a 'Hist' SOAC.
data HistOp rep = HistOp
  { -- | Width of the destination histogram(s) (presumably the number of
    -- buckets -- TODO confirm).
    histWidth :: SubExp,
    -- | Race factor for this operation; its exact use is decided by the
    -- code generators, not in this module.
    histRaceFactor :: SubExp,
    -- | Destination arrays updated by this operation.
    histDest :: [VName],
    -- | Neutral elements for 'histOp'.
    histNeutral :: [SubExp],
    -- | Operator used to combine values into buckets.
    histOp :: Lambda rep
  }
  deriving (Eq, Ord, Show)
-- | Ordering requirement carried by a 'Parallel' stream: 'InOrder' or
-- 'Disorder' (presumably whether chunk order must be preserved).
data StreamOrd = InOrder | Disorder
  deriving (Eq, Ord, Show)
-- | How a 'Stream' SOAC is executed.
data StreamForm rep
  = -- | Parallel streaming with an ordering requirement, a commutativity
    -- flag, and a 'Lambda' (presumably the operator that combines
    -- per-chunk results -- TODO confirm).
    Parallel StreamOrd Commutativity (Lambda rep)
  | -- | Sequential streaming.
    Sequential
  deriving (Eq, Ord, Show)
-- | The body of a 'Screma' SOAC: a list of scans, a list of reductions,
-- and a map function.  'scremaType' lists the results in that same order:
-- scan results, then reduction results, then mapped arrays.
data ScremaForm rep
  = ScremaForm
      [Scan rep]
      [Reduce rep]
      (Lambda rep)
  deriving (Eq, Ord, Show)
-- | Fuse several binary operators into a single 'Lambda'.
--
-- For each input lambda with @k@ return values, its first @k@ parameters
-- are treated as its \"x\" operands and the remaining parameters as its
-- \"y\" operands.  The fused lambda takes all x parameters (in input
-- order), followed by all y parameters.  Its body is the concatenation of
-- the inputs' statements, and it returns the concatenation of their
-- results.  This assumes the input lambdas share no bound names.
singleBinOp :: Bindable rep => [Lambda rep] -> Lambda rep
singleBinOp lams =
  Lambda
    { lambdaParams = concat xParams ++ concat yParams,
      lambdaReturnType = concatMap lambdaReturnType lams,
      lambdaBody =
        mkBody
          (mconcat (map (bodyStms . lambdaBody) lams))
          (concatMap (bodyResult . lambdaBody) lams)
    }
  where
    -- Split each lambda's parameters once into (x operands, y operands).
    (xParams, yParams) = unzip (map splitParams lams)
    splitParams lam =
      splitAt (length (lambdaReturnType lam)) (lambdaParams lam)
-- | A scan operator together with its neutral elements.
data Scan rep = Scan
  { -- | The (binary) scan operator.
    scanLambda :: Lambda rep,
    -- | Neutral elements; 'scanResults' counts one result per element.
    scanNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)
-- | Total number of results produced by the given scans: the sum of the
-- number of neutral elements of each 'Scan'.
scanResults :: [Scan rep] -> Int
scanResults = sum . map (length . scanNeutral)
-- | Fuse a list of scans into one 'Scan'.
--
-- The operator is built with 'singleBinOp' from the individual scan
-- operators, and the neutral elements are concatenated in the same order.
singleScan :: Bindable rep => [Scan rep] -> Scan rep
singleScan scans =
  let scan_nes = concatMap scanNeutral scans
      scan_lam = singleBinOp $ map scanLambda scans
   in Scan scan_lam scan_nes
-- | A reduction operator, its commutativity, and its neutral elements.
data Reduce rep = Reduce
  { -- | Whether the operator may be treated as commutative.
    redComm :: Commutativity,
    -- | The (binary) reduction operator.
    redLambda :: Lambda rep,
    -- | Neutral elements; 'redResults' counts one result per element.
    redNeutral :: [SubExp]
  }
  deriving (Eq, Ord, Show)
-- | Total number of results produced by the given reductions: the sum of
-- the number of neutral elements of each 'Reduce'.
redResults :: [Reduce rep] -> Int
redResults = sum . map (length . redNeutral)
-- | Fuse a list of reductions into one 'Reduce'.
--
-- The operator is built with 'singleBinOp', and the neutral elements are
-- concatenated in order.  The commutativity is the 'mconcat' of the
-- individual commutativities (presumably commutative only if all inputs
-- are -- see the 'Commutativity' Monoid instance).
singleReduce :: Bindable rep => [Reduce rep] -> Reduce rep
singleReduce reds =
  let red_nes = concatMap redNeutral reds
      red_lam = singleBinOp $ map redLambda reds
      red_comm = mconcat $ map redComm reds
   in Reduce red_comm red_lam red_nes
-- | The result types of a 'Screma' whose input has outer size @w@.
--
-- The results are, in order:
--
--   * each scan-lambda return type, wrapped with 'arrayOfRow' using @w@;
--   * each reduce-lambda return type, unchanged;
--   * the remaining map-lambda return types, each wrapped with
--     'arrayOfRow' using @w@.
--
-- The remaining map-lambda return types are found by dropping as many
-- leading return types as there are scan and reduce results combined.
scremaType :: SubExp -> ScremaForm rep -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
  scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps
  where
    scan_tps = map (`arrayOfRow` w) $ concatMap (lambdaReturnType . scanLambda) scans
    red_tps = concatMap (lambdaReturnType . redLambda) reds
    map_tps = drop (length scan_tps + length red_tps) $ lambdaReturnType map_lam
-- | Build a lambda that takes one fresh parameter per given type and
-- returns its parameters unchanged.
--
-- Each parameter is named from the base name @x@. The body contains no
-- statements, and the return types are exactly the given types.
mkIdentityLambda ::
  (Bindable rep, MonadFreshNames m) =>
  [Type] ->
  m (Lambda rep)
mkIdentityLambda ts = do
  params <- mapM (newParam "x") ts
  pure
    Lambda
      { lambdaParams = params,
        lambdaBody = mkBody mempty $ map (Var . paramName) params,
        lambdaReturnType = ts
      }
-- | Does the lambda's body result consist of exactly its own parameters,
-- in order?
--
-- Only the result list is compared; statements in the body are not
-- inspected.
isIdentityLambda :: Lambda rep -> Bool
isIdentityLambda lam =
  bodyResult (lambdaBody lam) == map (Var . paramName) (lambdaParams lam)
-- | A lambda with no parameters, a body with no statements and no
-- results, and no return types.
nilFn :: Bindable rep => Lambda rep
nilFn =
  Lambda
    { lambdaParams = [],
      lambdaBody = mkBody mempty [],
      lambdaReturnType = []
    }
-- | A 'ScremaForm' with the given scans and map lambda, and no
-- reductions.
scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC scans = ScremaForm scans []
-- | A 'ScremaForm' with the given reductions and map lambda, and no
-- scans.
redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC = ScremaForm []
-- | A 'ScremaForm' that performs only the given scans.
--
-- The map lambda is a freshly named identity lambda (see
-- 'mkIdentityLambda') over the concatenated return types of the scan
-- operators.
scanSOAC ::
  (Bindable rep, MonadFreshNames m) =>
  [Scan rep] ->
  m (ScremaForm rep)
scanSOAC scans = scanomapSOAC scans <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . scanLambda) scans
-- | A 'ScremaForm' that performs only the given reductions.
--
-- The map lambda is a freshly named identity lambda (see
-- 'mkIdentityLambda') over the concatenated return types of the
-- reduction operators.
reduceSOAC ::
  (Bindable rep, MonadFreshNames m) =>
  [Reduce rep] ->
  m (ScremaForm rep)
reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts
  where
    ts = concatMap (lambdaReturnType . redLambda) reds
-- | A 'ScremaForm' with only the given map lambda: no scans and no
-- reductions.
mapSOAC :: Lambda rep -> ScremaForm rep
mapSOAC = ScremaForm [] []
-- | If the form has no reductions and at least one scan, return its scans
-- and map lambda; otherwise 'Nothing'.
isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC (ScremaForm scans reds map_lam)
  | null reds, not (null scans) = Just (scans, map_lam)
  | otherwise = Nothing
-- | Like 'isScanomapSOAC', but additionally requires the map lambda to be
-- an identity lambda (see 'isIdentityLambda'). Returns only the scans.
isScanSOAC :: ScremaForm rep -> Maybe [Scan rep]
isScanSOAC form =
  case isScanomapSOAC form of
    Just (scans, map_lam) | isIdentityLambda map_lam -> Just scans
    _ -> Nothing
-- | If the form has no scans and at least one reduction, return its
-- reductions and map lambda; otherwise 'Nothing'.
isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC (ScremaForm scans reds map_lam)
  | null scans, not (null reds) = Just (reds, map_lam)
  | otherwise = Nothing
-- | Like 'isRedomapSOAC', but additionally requires the map lambda to be
-- an identity lambda (see 'isIdentityLambda'). Returns only the
-- reductions.
isReduceSOAC :: ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC form =
  case isRedomapSOAC form of
    Just (reds, map_lam) | isIdentityLambda map_lam -> Just reds
    _ -> Nothing
-- | If the form has neither scans nor reductions, return its map lambda;
-- otherwise 'Nothing'.
isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC (ScremaForm scans reds map_lam)
  | null scans, null reds = Just map_lam
  | otherwise = Nothing
-- | Group the flat results of a scatter by destination.
--
-- Each output spec @(shape, n, array)@ describes one destination. The
-- @(indices, value)@ pairs from 'groupScatterResults'' are split with
-- 'chunks' into consecutive groups whose sizes are the @n@s. Each group
-- is then paired with its destination's shape and array.
groupScatterResults :: [(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults output_spec results =
  zip3 shapes arrays . chunks ns $ groupScatterResults' output_spec results
  where
    (shapes, ns, arrays) = unzip3 output_spec
-- | Pair each scattered value with its index components.
--
-- The results are first split by 'splitScatterResults' into an index
-- part and a value part. For each output spec @(shape, n, _)@, the index
-- part supplies @n@ groups of @length shape@ elements; these groups are
-- zipped, in order, with the values.
groupScatterResults' :: [(Shape, Int, array)] -> [a] -> [([a], a)]
groupScatterResults' output_spec results =
  zip (chunks chunk_sizes indices) values
  where
    (indices, values) = splitScatterResults output_spec results
    -- One chunk per write, holding as many indices as the destination
    -- shape has dimensions.
    chunk_sizes :: [Int]
    chunk_sizes = concat [replicate n (length shp) | (shp, n, _) <- output_spec]
-- | Split the flat results of a scatter into an index part and a value
-- part.
--
-- The index part is the first @sum (n * length shape)@ elements, summed
-- over all output specs @(shape, n, _)@. The value part is everything
-- after that.
splitScatterResults :: [(Shape, Int, array)] -> [a] -> ([a], [a])
splitScatterResults output_spec results =
  splitAt num_indices results
  where
    num_indices :: Int
    num_indices = sum [n * length shp | (shp, n, _) <- output_spec]
-- | The monadic transformations that 'mapSOACM' applies to the
-- components of a 'SOAC'.
data SOACMapper frep trep m = SOACMapper
  { -- | Applied by 'mapSOACM' to each 'SubExp' it encounters.
    mapOnSOACSubExp :: SubExp -> m SubExp,
    -- | Applied by 'mapSOACM' to each 'Lambda' it encounters.
    mapOnSOACLambda :: Lambda frep -> m (Lambda trep),
    -- | Applied by 'mapSOACM' to each 'VName' it encounters.
    mapOnSOACVName :: VName -> m VName
  }
-- | A 'SOACMapper' that returns every component unchanged.
identitySOACMapper :: Monad m => SOACMapper rep rep m
identitySOACMapper =
  SOACMapper
    { mapOnSOACSubExp = pure,
      mapOnSOACLambda = pure,
      mapOnSOACVName = pure
    }
-- | Rebuild a 'SOAC' by applying a 'SOACMapper' to each 'SubExp', 'VName'
-- and 'Lambda' it holds.
--
-- This covers the components nested inside 'StreamForm', 'HistOp',
-- 'Scan', 'Reduce', and the 'Scatter' destination triples, including
-- their 'Shape's. Effects are sequenced left to right, in
-- constructor-field order. 'StreamOrd', 'Commutativity', and the 'Int' in
-- each 'Scatter' destination are copied unchanged.
mapSOACM ::
  (Applicative m, Monad m) =>
  SOACMapper frep trep m ->
  SOAC frep ->
  m (SOAC trep)
mapSOACM tv soac =
  case soac of
    Stream size arrs form accs lam ->
      Stream
        <$> onSubExp size
        <*> mapM onVName arrs
        <*> onStreamForm form
        <*> mapM onSubExp accs
        <*> onLambda lam
    Scatter len lam ivs dests ->
      Scatter
        <$> onSubExp len
        <*> onLambda lam
        <*> mapM onVName ivs
        <*> mapM onDest dests
    Hist len ops bucket_fun imgs ->
      Hist
        <$> onSubExp len
        <*> mapM onHistOp ops
        <*> onLambda bucket_fun
        <*> mapM onVName imgs
    Screma w arrs (ScremaForm scans reds map_lam) ->
      Screma
        <$> onSubExp w
        <*> mapM onVName arrs
        <*> ( ScremaForm
                <$> mapM onScan scans
                <*> mapM onReduce reds
                <*> onLambda map_lam
            )
  where
    onSubExp = mapOnSOACSubExp tv
    onVName = mapOnSOACVName tv
    onLambda = mapOnSOACLambda tv

    -- Only the 'Parallel' form carries a lambda.
    onStreamForm (Parallel o comm lam0) = Parallel o comm <$> onLambda lam0
    onStreamForm Sequential = pure Sequential

    -- The 'Int' component of a destination is kept as-is.
    onDest (shape, n, arr) =
      (,,) <$> mapM onSubExp shape <*> pure n <*> onVName arr

    onHistOp (HistOp e rf arrs nes op) =
      HistOp
        <$> onSubExp e
        <*> onSubExp rf
        <*> mapM onVName arrs
        <*> mapM onSubExp nes
        <*> onLambda op

    onScan (Scan scan_lam scan_nes) =
      Scan <$> onLambda scan_lam <*> mapM onSubExp scan_nes

    onReduce (Reduce comm red_lam red_nes) =
      Reduce comm <$> onLambda red_lam <*> mapM onSubExp red_nes
-- | The free variables of a SOAC: the union of the free variables of
-- every sub-expression, lambda and variable name visited by 'mapSOACM'.
instance ASTRep rep => FreeIn (SOAC rep) where
  freeIn' = flip execState mempty . mapSOACM free
    where
      -- Add the free variables of one component to the accumulated
      -- state and return the component unchanged.  'modify'' forces the
      -- accumulator at every step, so traversing a large SOAC does not
      -- build up a chain of unevaluated '<>' thunks inside the state.
      walk f x = do
        modify' (<> f x)
        pure x
      free =
        SOACMapper
          { mapOnSOACSubExp = walk freeIn',
            mapOnSOACLambda = walk freeIn',
            mapOnSOACVName = walk freeIn'
          }
-- | Apply a name substitution to every sub-expression, lambda and
-- variable name of a SOAC, as visited by 'mapSOACM'.
instance ASTRep rep => Substitute (SOAC rep) where
  substituteNames subst soac = runIdentity $ mapSOACM substituter soac
    where
      -- Substitute inside one component.  The explicit signature keeps
      -- this helper polymorphic (it is used at three different types).
      sub :: Substitute a => a -> Identity a
      sub = pure . substituteNames subst

      substituter =
        SOACMapper
          { mapOnSOACSubExp = sub,
            mapOnSOACLambda = sub,
            mapOnSOACVName = sub
          }
-- | Rename a SOAC by renaming each sub-expression, lambda and variable
-- name that 'mapSOACM' visits.
instance ASTRep rep => Rename (SOAC rep) where
  rename soac = mapSOACM renamer soac
    where
      renamer =
        SOACMapper
          { mapOnSOACSubExp = rename,
            mapOnSOACLambda = rename,
            mapOnSOACVName = rename
          }
-- | The result types of a SOAC.
soacType :: SOAC rep -> [Type]
soacType soac =
  case soac of
    Stream outersize _ _ accs lam ->
      -- The first lambda parameter is mapped to the outer size and the
      -- next @length accs@ parameters to the accumulator values; those
      -- names are then substituted in the lambda's return types.
      let bound = map paramName $ take (1 + length accs) $ lambdaParams lam
          substs = M.fromList $ zip bound (outersize : accs)
       in map (substNamesInType substs) (lambdaReturnType lam)
    Scatter _w lam _ivs dests ->
      -- The leading lambda results are dropped: each destination
      -- accounts for (its Int count) * (length of its shape) of them.
      -- The remaining results are the value types.
      let (shapes, counts, _) = unzip3 dests
          skipped = sum [n * length shape | (shape, n) <- zip shapes counts]
          value_ts = drop skipped $ lambdaReturnType lam
       in zipWith arrayOfShape value_ts shapes
    Hist _len ops _bucket_fun _imgs ->
      concatMap histResultTypes ops
    Screma w _arrs form ->
      scremaType w form
  where
    -- One array of width 'histWidth' per result of the operator lambda.
    histResultTypes op =
      map (`arrayOfRow` histWidth op) (lambdaReturnType (histOp op))
-- | The (statically shaped) result types of a SOAC, as given by 'soacType'.
instance TypedOp (SOAC rep) where
  opType soac = pure (staticShapes (soacType soac))
-- | Aliasing and consumption information for SOACs.
instance (ASTRep rep, Aliased rep) => AliasedOp (SOAC rep) where
  -- Every result of a SOAC gets an empty alias set.
  opAliases = map (const mempty) . soacType

  -- For a Screma only the map lambda is inspected; its scan and reduce
  -- lambdas do not contribute to what is consumed.
  consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
    mapNames consumedArray $ consumedByLambda map_lam
    where
      -- A consumed map parameter is reported as the array bound to it;
      -- any other consumed name is reported as-is.
      consumedArray v = fromMaybe v $ lookup v params_to_arrs
      params_to_arrs = zip (map paramName $ lambdaParams map_lam) arrs
  -- Consumption does not depend on whether the stream is Sequential or
  -- Parallel, so the form is not inspected.
  consumedInOp (Stream _ arrs _ accs lam) =
    namesFromList . subExpVars . map consumedArray . namesToList $
      consumedByLambda lam
    where
      consumedArray v = fromMaybe (Var v) $ lookup v paramsToInput
      -- The first lambda parameter is skipped; the rest are paired, in
      -- order, with the accumulators followed by the input arrays.
      paramsToInput =
        zip (map paramName $ drop 1 $ lambdaParams lam) (accs ++ map Var arrs)
  -- A Scatter consumes all of its destination arrays.
  consumedInOp (Scatter _ _ _ as) =
    namesFromList [a | (_, _, a) <- as]
  -- A Hist consumes the destination arrays of all its operators.
  consumedInOp (Hist _ ops _ _) =
    namesFromList $ concatMap histDest ops
-- | Transform the operator lambda of a 'HistOp', leaving all other
-- components untouched.
mapHistOp ::
  (Lambda frep -> Lambda trep) ->
  HistOp frep ->
  HistOp trep
mapHistOp f op =
  case op of
    HistOp width rf dests neutral op_lam ->
      HistOp width rf dests neutral (f op_lam)
-- | Adding and removing alias annotations on SOACs.  Aliases are added
-- by running 'Alias.analyseLambda' on every lambda a SOAC contains.
instance
  ( ASTRep rep,
    ASTRep (Aliases rep),
    CanBeAliased (Op rep)
  ) =>
  CanBeAliased (SOAC rep)
  where
  type OpWithAliases (SOAC rep) = SOAC (Aliases rep)

  addOpAliases aliases soac =
    case soac of
      Stream size arr form accs lam ->
        Stream size arr (onForm form) accs (analyse lam)
      Scatter len lam ivs as ->
        Scatter len (analyse lam) ivs as
      Hist len ops bucket_fun imgs ->
        Hist len (map (mapHistOp analyse) ops) (analyse bucket_fun) imgs
      Screma w arrs (ScremaForm scans reds map_lam) ->
        Screma w arrs $
          ScremaForm (map onScan scans) (map onRed reds) (analyse map_lam)
    where
      analyse = Alias.analyseLambda aliases

      -- Only a Parallel stream form carries a lambda.
      onForm (Parallel o comm lam0) = Parallel o comm (analyse lam0)
      onForm Sequential = Sequential

      onScan scan = scan {scanLambda = analyse (scanLambda scan)}
      onRed red = red {redLambda = analyse (redLambda red)}

  removeOpAliases = runIdentity . mapSOACM remover
    where
      -- Only lambdas carry alias annotations that must be stripped.
      remover =
        SOACMapper
          { mapOnSOACSubExp = pure,
            mapOnSOACLambda = pure . removeLambdaAliases,
            mapOnSOACVName = pure
          }
-- | Every SOAC is reported as not safe and as cheap, regardless of its
-- contents.  NOTE(review): the exact meaning of "safe"/"cheap" is
-- defined by the 'IsOp' class, which is not visible here.
instance ASTRep rep => IsOp (SOAC rep) where
  safeOp = const False
  cheapOp = const True
-- | Substitute variables in the dimensions of an array type using
-- 'substNamesInSubExp'.  Non-array types are returned unchanged.
substNamesInType :: M.Map VName SubExp -> Type -> Type
substNamesInType subs t =
  case t of
    Array btp shp u ->
      Array btp (Shape (map (substNamesInSubExp subs) (shapeDims shp))) u
    Mem space -> Mem space
    Prim {} -> t
    Acc {} -> t
-- | Replace a variable by its entry in the substitution map, if any.
-- Constants, and variables without an entry, are returned unchanged.
substNamesInSubExp :: M.Map VName SubExp -> SubExp -> SubExp
substNamesInSubExp subs se =
  case se of
    Constant {} -> se
    Var v -> fromMaybe se (M.lookup v subs)
-- | Removing simplifier wisdom from SOACs: only the contained lambdas
-- carry wisdom, so they are stripped and everything else is kept.
instance (ASTRep rep, CanBeWise (Op rep)) => CanBeWise (SOAC rep) where
  type OpWithWisdom (SOAC rep) = SOAC (Wise rep)

  removeOpWisdom soac = runIdentity $ mapSOACM stripper soac
    where
      stripper =
        SOACMapper
          { mapOnSOACSubExp = pure,
            mapOnSOACLambda = pure . removeLambdaWisdom,
            mapOnSOACVName = pure
          }
-- Symbolic indexing of a SOAC result: given a result position @k@ and a
-- single index @i@, try to express element @i@ of that result as a
-- 'PrimExp' (with the certificates it depends on).
--
-- Only a 'Screma' with no scan or reduce results is handled.  Its map
-- lambda has exactly one parameter per input array (this is how
-- 'typeCheckSOAC' checks it), so each parameter is bound to its own array
-- indexed at @i@.  The lambda's statements are then folded into a table of
-- 'PrimExp's, and lambda result @k@ is looked up in that table.
--
-- Pairing parameters with arrays one-to-one matters: dropping the first
-- @scanResults + redResults@ parameters before zipping with the arrays
-- binds parameters to the wrong arrays' elements.
--
-- NOTE(review): Scremas with scans/reductions are rejected (Nothing)
-- because the correspondence between @k@ and the lambda's results is not
-- established here.  If @k@ is the position in the statement's whole
-- pattern (scan and reduce results first), support could be added by
-- reading lambda result @k@ when @k >= scanResults scans + redResults reds@
-- with the undropped parameters.  Confirm against the caller in
-- "Futhark.Analysis.SymbolTable" before doing so.
instance RepTypes rep => ST.IndexOp (SOAC rep) where
  indexOp vtable k (Screma _ arrs (ScremaForm scans reds map_lam)) [i]
    | scanResults scans + redResults reds == 0 = do
      -- Only a variable result can be found in the table.
      Var res <- maybeNth k $ bodyResult lam_body
      let param_table =
            M.fromList $ catMaybes $ zipWith arrIndex (lambdaParams map_lam) arrs
          table = foldl expandPrimExpTable param_table $ bodyStms lam_body
      uncurry (flip ST.Indexed) <$> M.lookup res table
    where
      lam_body = lambdaBody map_lam

      -- Bind a lambda parameter to @arr[i]@ when the symbol table can
      -- express that element as a scalar 'PrimExp'.
      arrIndex p arr = do
        ST.Indexed cs pe <- ST.index' arr [i] vtable
        pure (paramName p, (pe, cs))

      -- Add a single-name statement to the table if its expression
      -- converts to a 'PrimExp' and all of its certificates are in the
      -- symbol table.  The entry carries the statement's certificates
      -- plus those of the table entries it used; other statements are
      -- skipped.
      expandPrimExpTable table stm
        | [v] <- patternNames $ stmPattern stm,
          Just (pe, cs) <-
            runWriterT $ primExpFromExp (asPrimExp table) $ stmExp stm,
          all (`ST.elem` vtable) (unCertificates $ stmCerts stm) =
          M.insert v (pe, stmCerts stm <> cs) table
        | otherwise = table

      -- Translate a variable: table entries contribute their certificates,
      -- scalars known to the symbol table become leaves, anything else fails.
      asPrimExp table v
        | Just (e, cs) <- M.lookup v table = e <$ tell cs
        | Just (Prim pt) <- ST.lookupType v vtable = pure $ LeafExp v pt
        | otherwise = lift Nothing
  indexOp _ _ _ _ = Nothing
typeCheckSOAC :: TC.Checkable rep => SOAC (Aliases rep) -> TC.TypeM rep ()
typeCheckSOAC :: forall rep. Checkable rep => SOAC (Aliases rep) -> TypeM rep ()
typeCheckSOAC (Stream SubExp
size [VName]
arrexps StreamForm (Aliases rep)
form [SubExp]
accexps Lambda (Aliases rep)
lam) = do
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
size
[Arg]
accargs <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
accexps
[Type]
arrargs <- (VName -> TypeM rep Type) -> [VName] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> TypeM rep Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
arrexps
[Arg]
_ <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
size [VName]
arrexps
let chunk :: Param (LParamInfo rep)
chunk = [Param (LParamInfo rep)] -> Param (LParamInfo rep)
forall a. [a] -> a
head ([Param (LParamInfo rep)] -> Param (LParamInfo rep))
-> [Param (LParamInfo rep)] -> Param (LParamInfo rep)
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases rep) -> [LParam (Aliases rep)]
forall {rep}. LambdaT rep -> [Param (LParamInfo rep)]
lambdaParams Lambda (Aliases rep)
lam
let asArg :: a -> (a, b)
asArg a
t = (a
t, b
forall a. Monoid a => a
mempty)
inttp :: TypeBase shape u
inttp = PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
lamarrs' :: [Type]
lamarrs' = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map (Type -> SubExp -> Type
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` VName -> SubExp
Var (Param (LParamInfo rep) -> VName
forall dec. Param dec -> VName
paramName Param (LParamInfo rep)
chunk)) [Type]
arrargs
let acc_len :: Int
acc_len = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
accexps
let lamrtp :: [Type]
lamrtp = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
acc_len ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ((Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
lamrtp) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError String
"Stream with inconsistent accumulator type in lambda."
()
_ <- case StreamForm (Aliases rep)
form of
Parallel StreamOrd
_ Commutativity
_ Lambda (Aliases rep)
lam0 -> do
let acct :: [Type]
acct = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
accargs
outerRetType :: [Type]
outerRetType = Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam0
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam0 ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
accargs
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
acct [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== [Type]
outerRetType) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
String
"Initial value is of type " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
acct
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
", but stream's reduce lambda returns type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
outerRetType
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"."
StreamForm (Aliases rep)
Sequential -> () -> TypeM rep ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
let fake_lamarrs' :: [Arg]
fake_lamarrs' = (Type -> Arg) -> [Type] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg [Type]
lamarrs'
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ Type -> Arg
forall {b} {a}. Monoid b => a -> (a, b)
asArg Type
forall {shape} {u}. TypeBase shape u
inttp Arg -> [Arg] -> [Arg]
forall a. a -> [a] -> [a]
: [Arg]
accargs [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
fake_lamarrs'
typeCheckSOAC (Scatter SubExp
w Lambda (Aliases rep)
lam [VName]
ivs [(Shape, Int, VName)]
as) = do
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w
let ([Shape]
as_ws, [Int]
as_ns, [VName]
_as_vs) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
as
indexes :: Int
indexes = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws
rts :: [Type]
rts = Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
lam
rtsI :: [Type]
rtsI = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take Int
indexes [Type]
rts
rtsV :: [Type]
rtsV = Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
drop Int
indexes [Type]
rts
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
rts Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum [Int]
as_ns Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((Int -> Int -> Int) -> [Int] -> [Int] -> [Int]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> Int -> Int
forall a. Num a => a -> a -> a
(*) [Int]
as_ns ([Int] -> [Int]) -> [Int] -> [Int]
forall a b. (a -> b) -> a -> b
$ (Shape -> Int) -> [Shape] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape]
as_ws)) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError String
"Scatter: number of index types, value types and array outputs do not match."
[Type] -> (Type -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtsI ((Type -> TypeM rep ()) -> TypeM rep ())
-> (Type -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \Type
rtI ->
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64 Type -> Type -> Bool
forall a. Eq a => a -> a -> Bool
== Type
rtI) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError String
"Scatter: Index return type must be i64."
[([Type], (Shape, Int, VName))]
-> (([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([[Type]]
-> [(Shape, Int, VName)] -> [([Type], (Shape, Int, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip ([Int] -> [Type] -> [[Type]]
forall a. [Int] -> [a] -> [[a]]
chunks [Int]
as_ns [Type]
rtsV) [(Shape, Int, VName)]
as) ((([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ())
-> (([Type], (Shape, Int, VName)) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \([Type]
rtVs, (Shape
aw, Int
_, VName
a)) -> do
(SubExp -> TypeM rep ()) -> Shape -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ([Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64]) Shape
aw
[Type] -> (Type -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Type]
rtVs ((Type -> TypeM rep ()) -> TypeM rep ())
-> (Type -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \Type
rtV -> [Type] -> VName -> TypeM rep ()
forall rep. Checkable rep => [Type] -> VName -> TypeM rep ()
TC.requireI [Type -> Shape -> Type
arrayOfShape Type
rtV Shape
aw] VName
a
Names -> TypeM rep ()
forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume (Names -> TypeM rep ()) -> TypeM rep Names -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM rep Names
forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
a
[Arg]
arrargs <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
ivs
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
lam [Arg]
arrargs
typeCheckSOAC (Hist SubExp
len [HistOp (Aliases rep)]
ops Lambda (Aliases rep)
bucket_fun [VName]
imgs) = do
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
len
[HistOp (Aliases rep)]
-> (HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [HistOp (Aliases rep)]
ops ((HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ())
-> (HistOp (Aliases rep) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
dest_w SubExp
rf [VName]
dests [SubExp]
nes Lambda (Aliases rep)
op) -> do
[Arg]
nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
nes
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
dest_w
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
rf
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
op ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
nes'
let nes_t :: [Type]
nes_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
nes'
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
nes_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
String
"Operator has return type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
op)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
nes_t
[(Type, VName)] -> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Type] -> [VName] -> [(Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Type]
nes_t [VName]
dests) (((Type, VName) -> TypeM rep ()) -> TypeM rep ())
-> ((Type, VName) -> TypeM rep ()) -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ \(Type
t, VName
dest) -> do
[Type] -> VName -> TypeM rep ()
forall rep. Checkable rep => [Type] -> VName -> TypeM rep ()
TC.requireI [Type
t Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
dest_w] VName
dest
Names -> TypeM rep ()
forall rep. Checkable rep => Names -> TypeM rep ()
TC.consume (Names -> TypeM rep ()) -> TypeM rep Names -> TypeM rep ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TypeM rep Names
forall rep. Checkable rep => VName -> TypeM rep Names
TC.lookupAliases VName
dest
[Arg]
img' <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
len [VName]
imgs
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
bucket_fun [Arg]
img'
[Type]
nes_ts <- [[Type]] -> [Type]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[Type]] -> [Type]) -> TypeM rep [[Type]] -> TypeM rep [Type]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (HistOp (Aliases rep) -> TypeM rep [Type])
-> [HistOp (Aliases rep)] -> TypeM rep [[Type]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((SubExp -> TypeM rep Type) -> [SubExp] -> TypeM rep [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType ([SubExp] -> TypeM rep [Type])
-> (HistOp (Aliases rep) -> [SubExp])
-> HistOp (Aliases rep)
-> TypeM rep [Type]
forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. HistOp (Aliases rep) -> [SubExp]
forall rep. HistOp rep -> [SubExp]
histNeutral) [HistOp (Aliases rep)]
ops
let bucket_ret_t :: [Type]
bucket_ret_t = Int -> Type -> [Type]
forall a. Int -> a -> [a]
replicate ([HistOp (Aliases rep)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp (Aliases rep)]
ops) (PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64) [Type] -> [Type] -> [Type]
forall a. [a] -> [a] -> [a]
++ [Type]
nes_ts
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
bucket_ret_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
bucket_fun) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
String
"Bucket function has return type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
bucket_fun)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but should have type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
bucket_ret_t
typeCheckSOAC (Screma SubExp
w [VName]
arrs (ScremaForm [Scan (Aliases rep)]
scans [Reduce (Aliases rep)]
reds Lambda (Aliases rep)
map_lam)) = do
[Type] -> SubExp -> TypeM rep ()
forall rep. Checkable rep => [Type] -> SubExp -> TypeM rep ()
TC.require [PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64] SubExp
w
[Arg]
arrs' <- SubExp -> [VName] -> TypeM rep [Arg]
forall rep. Checkable rep => SubExp -> [VName] -> TypeM rep [Arg]
TC.checkSOACArrayArgs SubExp
w [VName]
arrs
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
map_lam [Arg]
arrs'
[Arg]
scan_nes' <- ([[Arg]] -> [Arg]) -> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [[Arg]] -> [Arg]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat (TypeM rep [[Arg]] -> TypeM rep [Arg])
-> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall a b. (a -> b) -> a -> b
$
[Scan (Aliases rep)]
-> (Scan (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Scan (Aliases rep)]
scans ((Scan (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]])
-> (Scan (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall a b. (a -> b) -> a -> b
$ \(Scan Lambda (Aliases rep)
scan_lam [SubExp]
scan_nes) -> do
[Arg]
scan_nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
scan_nes
let scan_t :: [Type]
scan_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
scan_nes'
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
scan_lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
scan_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
scan_nes'
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
scan_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
scan_lam) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
String
"Scan function returns type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
scan_lam)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
scan_t
[Arg] -> TypeM rep [Arg]
forall (m :: * -> *) a. Monad m => a -> m a
return [Arg]
scan_nes'
[Arg]
red_nes' <- ([[Arg]] -> [Arg]) -> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [[Arg]] -> [Arg]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat (TypeM rep [[Arg]] -> TypeM rep [Arg])
-> TypeM rep [[Arg]] -> TypeM rep [Arg]
forall a b. (a -> b) -> a -> b
$
[Reduce (Aliases rep)]
-> (Reduce (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Reduce (Aliases rep)]
reds ((Reduce (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]])
-> (Reduce (Aliases rep) -> TypeM rep [Arg]) -> TypeM rep [[Arg]]
forall a b. (a -> b) -> a -> b
$ \(Reduce Commutativity
_ Lambda (Aliases rep)
red_lam [SubExp]
red_nes) -> do
[Arg]
red_nes' <- (SubExp -> TypeM rep Arg) -> [SubExp] -> TypeM rep [Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> TypeM rep Arg
forall rep. Checkable rep => SubExp -> TypeM rep Arg
TC.checkArg [SubExp]
red_nes
let red_t :: [Type]
red_t = (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType [Arg]
red_nes'
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
forall rep.
Checkable rep =>
Lambda (Aliases rep) -> [Arg] -> TypeM rep ()
TC.checkLambda Lambda (Aliases rep)
red_lam ([Arg] -> TypeM rep ()) -> [Arg] -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ (Arg -> Arg) -> [Arg] -> [Arg]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Arg
TC.noArgAliases ([Arg] -> [Arg]) -> [Arg] -> [Arg]
forall a b. (a -> b) -> a -> b
$ [Arg]
red_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
red_nes'
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless ([Type]
red_t [Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
red_lam) (TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
String
"Reduce function returns type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple (Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
red_lam)
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" but neutral element has type "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
red_t
[Arg] -> TypeM rep [Arg]
forall (m :: * -> *) a. Monad m => a -> m a
return [Arg]
red_nes'
let map_lam_ts :: [Type]
map_lam_ts = Lambda (Aliases rep) -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda (Aliases rep)
map_lam
Bool -> TypeM rep () -> TypeM rep ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless
( Int -> [Type] -> [Type]
forall a. Int -> [a] -> [a]
take ([Arg] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Arg]
scan_nes' Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Arg] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Arg]
red_nes') [Type]
map_lam_ts
[Type] -> [Type] -> Bool
forall a. Eq a => a -> a -> Bool
== (Arg -> Type) -> [Arg] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Arg -> Type
TC.argType ([Arg]
scan_nes' [Arg] -> [Arg] -> [Arg]
forall a. [a] -> [a] -> [a]
++ [Arg]
red_nes')
)
(TypeM rep () -> TypeM rep ()) -> TypeM rep () -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$ ErrorCase rep -> TypeM rep ()
forall rep a. ErrorCase rep -> TypeM rep a
TC.bad (ErrorCase rep -> TypeM rep ()) -> ErrorCase rep -> TypeM rep ()
forall a b. (a -> b) -> a -> b
$
String -> ErrorCase rep
forall rep. String -> ErrorCase rep
TC.TypeError (String -> ErrorCase rep) -> String -> ErrorCase rep
forall a b. (a -> b) -> a -> b
$
String
"Map function return type " String -> ShowS
forall a. [a] -> [a] -> [a]
++ [Type] -> String
forall a. Pretty a => [a] -> String
prettyTuple [Type]
map_lam_ts
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" wrong for given scan and reduction functions."
-- | Metrics for a SOAC.  Each SOAC is recorded under a label naming its
-- constructor.  The metrics of the lambdas it carries are gathered inside
-- that label, in the order listed below.  For 'Stream', only the main
-- lambda is visited; the stream form (and any lambda inside it) is ignored.
instance OpMetrics (Op rep) => OpMetrics (SOAC rep) where
  opMetrics soac = inside label $ mapM_ lambdaMetrics lams
    where
      -- The label to record under, and the lambdas visited (in order).
      (label, lams) = case soac of
        Stream _ _ _ _ lam ->
          ("Stream", [lam])
        Scatter _len lam _ivs _as ->
          ("Scatter", [lam])
        Hist _len ops bucket_fun _imgs ->
          -- Operator lambdas first, then the bucket function.
          ("Hist", map histOp ops ++ [bucket_fun])
        Screma _ _ (ScremaForm scans reds map_lam) ->
          -- Scan lambdas, then reduction lambdas, then the map lambda.
          ("Screma", map scanLambda scans ++ map redLambda reds ++ [map_lam])
-- | Pretty-printing of SOACs.
--
-- A 'Screma' with neither scans nor reductions is printed as @map@, one with
-- only reductions as @redomap@, and one with only scans as @scanomap@.  Any
-- other 'Screma' falls through to 'ppScrema'.  'Hist' is delegated to
-- 'ppHist'.
instance PrettyRep rep => PP.Pretty (SOAC rep) where
  ppr (Stream size arrs form acc lam) =
    case form of
      Parallel o comm lam0 ->
        text ("streamPar" ++ orderSuffix o ++ commSuffix comm)
          <> parens
            ( ppr size <> comma
                </> ppTuple' arrs <> comma
                </> ppr lam0 <> comma
                </> ppTuple' acc <> comma
                </> ppr lam
            )
      Sequential ->
        text "streamSeq"
          <> parens
            ( ppr size <> comma
                </> ppTuple' arrs <> comma
                </> ppTuple' acc <> comma
                </> ppr lam
            )
    where
      -- "Per" is appended only for a 'Disorder' stream.
      orderSuffix :: StreamOrd -> String
      orderSuffix ord
        | ord == Disorder = "Per"
        | otherwise = ""

      -- "Comm" is appended only for a commutative stream.
      commSuffix :: Commutativity -> String
      commSuffix Commutative = "Comm"
      commSuffix Noncommutative = ""
  ppr (Scatter w lam ivs dests) =
    "scatter"
      <> parens
        ( ppr w <> comma
            </> ppr lam <> comma
            </> commasep (ppTuple' ivs : map ppr dests)
        )
  ppr (Hist len ops bucket_fun imgs) =
    ppHist len ops bucket_fun imgs
  ppr (Screma w arrs (ScremaForm scans reds map_lam))
    | null scans && null reds =
        text "map"
          <> parens
            ( ppr w <> comma
                </> ppTuple' arrs <> comma
                </> ppr map_lam
            )
    | null scans = withOperators "redomap" (map ppr reds)
    | null reds = withOperators "scanomap" (map ppr scans)
    where
      -- Shared layout of redomap and scanomap: width, inputs, the braced
      -- list of already-rendered operators, then the map lambda.
      withOperators :: String -> [Doc] -> Doc
      withOperators name op_docs =
        text name
          <> parens
            ( ppr w <> comma
                </> ppTuple' arrs <> comma
                </> PP.braces (mconcat $ intersperse (comma <> PP.line) op_docs) <> comma
                </> ppr map_lam
            )
  ppr (Screma w arrs form) = ppScrema w arrs form
-- | Render a 'ScremaForm' applied to a width and a list of inputs as a
-- @screma@ call.  The arguments appear in this order: the width, the input
-- tuple, the braced scan operators, the braced reduction operators, and the
-- map lambda.
ppScrema ::
  (PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc
ppScrema w arrs (ScremaForm scans reds map_lam) =
  text "screma"
    <> parens
      ( ppr w <> comma
          </> ppTuple' arrs <> comma
          </> operatorBlock (map ppr scans) <> comma
          </> operatorBlock (map ppr reds) <> comma
          </> ppr map_lam
      )
  where
    -- Braced list whose elements are separated by a comma and 'PP.line'.
    operatorBlock :: [Doc] -> Doc
    operatorBlock = PP.braces . mconcat . intersperse (comma <> PP.line)
-- | A scan is printed as its operator lambda, a comma, and then its neutral
-- elements as a braced, comma-separated list.
instance PrettyRep rep => Pretty (Scan rep) where
  ppr (Scan lam nes) = ppr lam <> comma </> PP.braces neutral
    where
      neutral = commasep (map ppr nes)
-- | Prefix printed before a reduction.  It is the text "commutative "
-- (including the trailing space) for a commutative operator and empty
-- otherwise.
ppComm :: Commutativity -> Doc
ppComm comm =
  case comm of
    Commutative -> text "commutative "
    Noncommutative -> mempty
-- | A reduction is printed as its commutativity prefix (see 'ppComm'), its
-- operator lambda and a comma, followed by its neutral elements as a braced,
-- comma-separated list.
instance PrettyRep rep => Pretty (Reduce rep) where
  ppr (Reduce comm lam nes) = header </> PP.braces neutral
    where
      header = ppComm comm <> ppr lam <> comma
      neutral = commasep (map ppr nes)
-- | Render a histogram computation as a @hist@ call.  The arguments appear in
-- this order: the length, the braced list of histogram operators, the bucket
-- function, and the comma-separated inputs.
ppHist ::
  (PrettyRep rep, Pretty inp) =>
  SubExp ->
  [HistOp rep] ->
  Lambda rep ->
  [inp] ->
  Doc
ppHist len ops bucket_fun imgs =
  text "hist" <> parens args
  where
    args =
      ppr len <> comma
        </> PP.braces (mconcat (intersperse (comma <> PP.line) (map ppOp ops))) <> comma
        </> ppr bucket_fun <> comma
        </> commasep (map ppr imgs)

    -- One operator renders, in order: its width @w@, @rf@ (presumably the race
    -- factor; confirm against 'HistOp'), the braced destination arrays, the
    -- braced neutral elements, and the operator lambda.
    ppOp (HistOp w rf dests nes op) =
      ppr w <> comma <+> ppr rf <> comma <+> braced dests <> comma
        </> braced nes <> comma
        </> ppr op

    -- A braced, comma-separated rendering of a list.
    braced :: Pretty a => [a] -> Doc
    braced xs = PP.braces (commasep (map ppr xs))