{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Pass.ExtractKernels.Distribution
(
Target
, Targets
, ppTargets
, singleTarget
, outerTarget
, innerTarget
, pushInnerTarget
, popInnerTarget
, targetsScope
, LoopNesting (..)
, ppLoopNesting
, scopeOfLoopNesting
, Nesting (..)
, Nestings
, ppNestings
, letBindInInnerNesting
, singleNesting
, pushInnerNesting
, KernelNest
, ppKernelNest
, newKernel
, pushKernelNesting
, pushInnerKernelNesting
, kernelNestLoops
, kernelNestWidths
, boundInKernelNest
, boundInKernelNests
, flatKernel
, constructKernel
, tryDistribute
, tryDistributeStm
)
where
import Control.Monad.RWS.Strict
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.Maybe
import Data.List (elemIndex, sortOn)
import Futhark.Representation.AST
import Futhark.Representation.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util
import Futhark.Transform.Rename
import Futhark.Util.Log
import Futhark.Pass.ExtractKernels.BlockedKernel
(DistLore, mapKernel, KernelInput(..), readKernelInput, MkSegLevel)
-- | A distribution target: a pattern paired with the result (one 'SubExp'
-- per pattern element) that is bound to it.
type Target = (PatternT Type, Result)

-- | A non-empty stack of targets.  The first field is the innermost
-- ("current") target; the list holds the enclosing targets, ordered from
-- outermost to innermost (see 'outerTarget' and 'pushInnerTarget').
data Targets = Targets
  { _innerTarget  :: Target
  , _outerTargets :: [Target]
  }
-- | Render every target on its own line as @pattern <- result@, starting
-- with the outermost target and ending with the innermost one.
ppTargets :: Targets -> String
ppTargets (Targets innermost enclosing) =
  unlines [ pretty pat ++ " <- " ++ pretty res
          | (pat, res) <- enclosing ++ [innermost] ]
-- | A target stack containing exactly one target, which is then both the
-- innermost and the outermost target.
singleTarget :: Target -> Targets
singleTarget target = Targets target []
-- | The outermost target: the head of the enclosing list, or the innermost
-- target itself when there are no enclosing targets.
outerTarget :: Targets -> Target
outerTarget (Targets innermost enclosing) =
  case enclosing of
    []            -> innermost
    outermost : _ -> outermost
-- | The innermost ("current") target.
innerTarget :: Targets -> Target
innerTarget = _innerTarget
-- | Add a new outermost target; the innermost target is left unchanged.
pushOuterTarget :: Target -> Targets -> Targets
pushOuterTarget target ts =
  ts { _outerTargets = target : _outerTargets ts }
-- | Make the given target the new innermost target.  The previous innermost
-- target is appended to the end (innermost position) of the enclosing list.
--
-- Pattern elements of the new target whose names are not free in the
-- result of the previous innermost target are dropped, together with their
-- corresponding results.  The resulting pattern has no context elements.
pushInnerTarget :: Target -> Targets -> Targets
pushInnerTarget (pat, res) (Targets inner_target targets) =
  Targets (Pattern [] kept_pes, kept_res) (targets ++ [inner_target])
  where inner_used = freeIn $ snd inner_target
        isUsed (pe, _) = patElemName pe `nameIn` inner_used
        (kept_pes, kept_res) =
          unzip [ pair | pair <- zip (patternElements pat) res, isUsed pair ]
-- | Split off the innermost target.  The last element of the enclosing list
-- becomes the new innermost target.  Returns 'Nothing' when there are no
-- enclosing targets.
popInnerTarget :: Targets -> Maybe (Target, Targets)
popInnerTarget (Targets t ts) = do
  (remaining, new_inner) <- splitLast ts
  pure (t, Targets new_inner remaining)
  where -- Split a list into everything but its last element, and that
        -- last element; 'Nothing' for an empty list.
        splitLast = foldr step Nothing
        step x Nothing        = Just ([], x)
        step x (Just (xs, y)) = Just (x : xs, y)
-- | The scope bound by the pattern of a single target.
targetScope :: DistLore lore => Target -> Scope lore
targetScope target = scopeOfPattern $ fst target
-- | The combined scope of the patterns of all targets.  The scopes are
-- combined with the 'Monoid' of 'Scope' (a 'Data.Map.Map', so left-biased
-- union), with the innermost target first.
targetsScope :: DistLore lore => Targets -> Scope lore
targetsScope (Targets t ts) = foldMap targetScope (t : ts)
-- | One level of a map nest: the pattern it binds, the certificates
-- attached to it, its width, and each parameter paired with the array it
-- reads from.
data LoopNesting = MapNesting
  { loopNestingPattern       :: PatternT Type
  , loopNestingCertificates  :: Certificates
  , loopNestingWidth         :: SubExp
  , loopNestingParamsAndArrs :: [(Param Type, VName)]
  }
  deriving (Show)
-- | The scope bound by the parameters of a loop nesting.
scopeOfLoopNesting :: DistLore lore => LoopNesting -> Scope lore
scopeOfLoopNesting = scopeOfLParams . loopNestingParams
-- | Render a loop nesting as @params <- arrays@.
ppLoopNesting :: LoopNesting -> String
ppLoopNesting nest =
  let (params, arrs) = unzip $ loopNestingParamsAndArrs nest
  in pretty params ++ " <- " ++ pretty arrs
-- | The parameters of a loop nesting, in their original order.
loopNestingParams :: LoopNesting -> [Param Type]
loopNestingParams nest = fst <$> loopNestingParamsAndArrs nest
-- | The free variables of a loop nesting are those of its pattern,
-- certificates, width, and parameter/array pairs.
instance FreeIn LoopNesting where
  freeIn' nest =
    freeIn' (loopNestingPattern nest)
    <> freeIn' (loopNestingCertificates nest)
    <> freeIn' (loopNestingWidth nest)
    <> freeIn' (loopNestingParamsAndArrs nest)
-- | A loop nesting together with the names recorded as let-bound inside it
-- (extended by 'letBindInNesting').
data Nesting = Nesting
  { nestingLetBound :: Names
  , nestingLoop     :: LoopNesting
  }
  deriving (Show)
-- | Record additional let-bound names in a nesting.  The new names are
-- appended after the names already recorded.
letBindInNesting :: Names -> Nesting -> Nesting
letBindInNesting newnames nesting =
  nesting { nestingLetBound = nestingLetBound nesting <> newnames }
-- | A stack of nestings: the first component is the innermost nesting; the
-- list holds the enclosing nestings, outermost first (see
-- 'pushInnerNesting').
type Nestings = (Nesting, [Nesting])

-- | Render the loop of each nesting on its own line, outermost first.
ppNestings :: Nestings -> String
ppNestings (innermost, enclosing) =
  unlines [ ppLoopNesting (nestingLoop n) | n <- enclosing ++ [innermost] ]
-- | A nesting stack with a single level and no enclosing nestings.
singleNesting :: Nesting -> Nestings
singleNesting nesting = (nesting, [])
-- | Push a new innermost nesting.  The previous innermost nesting is
-- appended to the end (innermost position) of the enclosing list.
pushInnerNesting :: Nesting -> Nestings -> Nestings
pushInnerNesting new_inner (old_inner, enclosing) =
  (new_inner, enclosing ++ [old_inner])
-- | All names bound by a nesting: the names of its loop parameters plus the
-- names recorded as let-bound inside it.
boundInNesting :: Nesting -> Names
boundInNesting (Nesting let_bound loop) =
  param_names <> let_bound
  where param_names = namesFromList [ paramName p | p <- loopNestingParams loop ]
-- | Record let-bound names in the innermost nesting of a stack; the
-- enclosing nestings are left unchanged.
letBindInInnerNesting :: Names -> Nestings -> Nestings
letBindInInnerNesting names (innermost, enclosing) =
  (letBindInNesting names innermost, enclosing)
-- | A kernel nest.  Unlike 'Targets' and 'Nestings', the first component is
-- the OUTERMOST loop nesting; the list holds the nestings inside it,
-- ordered outermost to innermost (see 'pushKernelNesting' and
-- 'pushInnerKernelNesting').
type KernelNest = (LoopNesting, [LoopNesting])

-- | Render each loop nesting of a kernel nest on its own line, outermost
-- first.
ppKernelNest :: KernelNest -> String
ppKernelNest = unlines . map ppLoopNesting . kernelNestLoops
-- | Wrap a kernel nest in a new outermost loop nesting.  The pattern of the
-- new nesting is reordered with 'fixNestingPatternOrder' against the
-- pattern of the current outermost nesting.
pushKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushKernelNesting target newnest (nest, nests) =
  (reordered, nest : nests)
  where reordered =
          fixNestingPatternOrder newnest target $ loopNestingPattern nest
-- | Add a new innermost loop nesting to a kernel nest.  Its pattern is
-- reordered with 'fixNestingPatternOrder' against the pattern of the
-- current innermost nesting.
pushInnerKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting target newnest (nest, nests) =
  (nest, nests ++ [fixNestingPatternOrder newnest target innermost_pat])
  where -- 'last' is total here: the list always contains at least 'nest'.
        innermost_pat = loopNestingPattern $ last (nest : nests)
-- | Replace the pattern of a nesting with its value identifiers, reordered
-- to follow @inner_pat@.
--
-- Each value identifier is paired (via 'zip', so any surplus identifiers
-- are dropped) with the corresponding result of the target.  An identifier
-- whose result is a variable occurring in the names of @inner_pat@ is
-- ordered by that variable's position; all others get position 0.  The
-- sort is stable ('sortOn').  The new pattern has no context elements.
fixNestingPatternOrder :: LoopNesting -> Target -> PatternT Type -> LoopNesting
fixNestingPatternOrder nest (_, res) inner_pat =
  nest { loopNestingPattern = basicPattern [] $ map fst sorted }
  where inner_names = patternNames inner_pat
        innerPosition (Var v) = fromMaybe 0 $ elemIndex v inner_names
        innerPosition _       = 0
        sorted = sortOn (innerPosition . snd) $
                 zip (patternValueIdents $ loopNestingPattern nest) res
-- | A kernel nest consisting of just one loop nesting.
newKernel :: LoopNesting -> KernelNest
newKernel = (, [])
-- | All loop nestings of a kernel nest, outermost first.
kernelNestLoops :: KernelNest -> [LoopNesting]
kernelNestLoops (outermost, inner_loops) = outermost : inner_loops
-- | The union of the parameter names bound by every loop nesting of a
-- kernel nest.
boundInKernelNest :: KernelNest -> Names
boundInKernelNest nest = fold $ boundInKernelNests nest
-- | For each loop nesting of a kernel nest (outermost first), the set of
-- parameter names it binds.
--
-- Uses 'loopNestingParams', like 'boundInNesting', rather than
-- re-deriving the parameters from 'loopNestingParamsAndArrs' inline.
boundInKernelNests :: KernelNest -> [Names]
boundInKernelNests =
  map (namesFromList . map paramName . loopNestingParams) . kernelNestLoops
-- | The width of every loop nesting of a kernel nest, outermost first.
kernelNestWidths :: KernelNest -> [SubExp]
kernelNestWidths nest = [ loopNestingWidth loop | loop <- kernelNestLoops nest ]
-- | Build a single segmented operation from a kernel nest and the body of
-- its innermost nesting.
--
-- The nest is flattened with 'flatKernel' into an index space and kernel
-- inputs.  Inside the new kernel body (with every index variable in scope
-- as an 'Int32' index), only the inputs whose names are free in
-- @inner_body@ are read, after which all results of @inner_body@ are
-- returned.  The returned statement binds the pattern of the outermost
-- nesting, carrying its certificates, to the operation built by
-- 'mapKernel'.  The accompanying statements are the auxiliary statements
-- produced by 'mapKernel'.
constructKernel :: (DistLore lore, MonadFreshNames m, LocalScope lore m) =>
                   MkSegLevel lore m -> KernelNest -> Body lore
                -> m (Stm lore, Stms lore)
constructKernel mk_lvl kernel_nest inner_body = runBinderT' $ do
  (ispace, inps) <- flatKernel kernel_nest

  let outermost   = fst kernel_nest
      pat         = loopNestingPattern outermost
      stm_aux     = StmAux (loopNestingCertificates outermost) ()
      index_scope = M.fromList [ (gtid, IndexInfo Int32) | (gtid, _) <- ispace ]
      -- Pattern types with one outer array dimension removed per
      -- dimension of the index space.
      ret_types   = map (stripArray $ length ispace) $ patternTypes pat
      isUsed inp  = kernelInputName inp `nameIn` freeIn inner_body

  (kernel_res, kernel_stms) <- runBinder $ localScope index_scope $ do
    mapM_ readKernelInput $ filter isUsed inps
    map (Returns ResultMaySimplify) <$> bodyBind inner_body

  (segop, aux_stms) <-
    lift $ mapKernel mk_lvl ispace [] ret_types $
    KernelBody () kernel_stms kernel_res
  addStms aux_stms

  pure $ Let pat stm_aux $ Op $ segOp segop
-- | Flatten a kernel nest into an index space and a list of kernel inputs.
--
-- A fresh @gtid@ index variable is created for every loop nesting and paired
-- with that nesting's width; the index space is ordered outermost first.
-- Every parameter of a nesting becomes a 'KernelInput' that reads its array
-- at the nesting's index.  An input of an inner nesting that reads from an
-- array which is a parameter of the enclosing nesting is redirected to that
-- parameter's source array, with the enclosing index prepended to its
-- indices.  The enclosing nesting's own inputs come before the inner ones.
--
-- Both the innermost and the enclosing levels are handled by one equation,
-- so the parameter-to-input construction is written only once.
flatKernel :: MonadFreshNames m =>
              KernelNest
           -> m ([(VName, SubExp)],
                 [KernelInput])
flatKernel (MapNesting _ _ nesting_w params_and_arrs, inner_nests) = do
  -- Create this level's index before recursing, so outer levels receive
  -- their fresh names first.
  i <- newVName "gtid"
  (inner_ispace, inner_inps) <-
    case inner_nests of
      []           -> return ([], [])
      nest : nests -> flatKernel (nest, nests)

  let own_inps = [ KernelInput pname ptype arr [Var i]
                 | (Param pname ptype, arr) <- params_and_arrs ]
      -- The source array of this level's parameter named by the input's
      -- array, if there is one.
      isParam inp =
        snd <$> find ((== kernelInputArray inp) . paramName . fst) params_and_arrs
      fixupInput inp
        | Just arr <- isParam inp =
            inp { kernelInputArray = arr
                , kernelInputIndices = Var i : kernelInputIndices inp }
        | otherwise =
            inp

  return ((i, nesting_w) : inner_ispace, own_inps <> map fixupInput inner_inps)
-- | A body being distributed: its targets, the names free in its
-- statements, and the identity-map and target-expansion results of
-- 'removeIdentityMappingGeneral' (see 'distributionBodyFromStms').
data DistributionBody = DistributionBody
  { distributionTarget       :: Targets
  , distributionFreeInBody   :: Names
  , distributionIdentityMap  :: M.Map VName Ident
  , distributionExpandTarget :: Target -> Target
  }
-- | The pattern of the innermost target of a distribution body.
distributionInnerPattern :: DistributionBody -> PatternT Type
distributionInnerPattern body = fst $ innerTarget $ distributionTarget body
-- | Build a 'DistributionBody' for a sequence of statements.
--
-- The innermost target is passed through 'removeIdentityMappingGeneral',
-- given the names bound by @stms@; the enclosing targets are kept as-is.
-- The body's free names are those free in any statement, minus the names
-- the statements bind.  The second component of the result is the
-- innermost target's result as returned by 'removeIdentityMappingGeneral'.
distributionBodyFromStms :: Attributes lore =>
                            Targets -> Stms lore -> (DistributionBody, Result)
distributionBodyFromStms (Targets (inner_pat, inner_res) targets) stms =
  (body, inner_res')
  where bound_by_stms = namesFromList $ M.keys $ scopeOf stms
        (inner_pat', inner_res', identity_map, expand_target) =
          removeIdentityMappingGeneral bound_by_stms inner_pat inner_res
        body = DistributionBody
          { distributionTarget       = Targets (inner_pat', inner_res') targets
          , distributionFreeInBody   = foldMap freeIn stms `namesSubtract` bound_by_stms
          , distributionIdentityMap  = identity_map
          , distributionExpandTarget = expand_target
          }
-- | Single-statement convenience wrapper around 'distributionBodyFromStms'.
distributionBodyFromStm :: Attributes lore =>
                           Targets -> Stm lore -> (DistributionBody, Result)
distributionBodyFromStm targets = distributionBodyFromStms targets . oneStm
-- | Try to build a kernel nest for a distribution body.
--
-- Each outer nesting is paired with one outer target; it is an internal
-- error (via 'error') if their counts differ.  The innermost nesting is
-- handled first, and each enclosing nesting is then added around the
-- kernel built so far.  The result is 'Nothing' if some parameter of a
-- resulting nesting has an array type whose dimensions are bound inside
-- the nest.  That is, distributing would create an irregular array.
createKernelNest :: (MonadFreshNames m, HasScope t m) =>
                    Nestings
                 -> DistributionBody
                 -> m (Maybe (Targets, KernelNest))
createKernelNest (inner_nest, nests) distrib_body = do
  let Targets target targets = distributionTarget distrib_body
  when (length nests /= length targets) $
    error $ concat
      [ "Nests and targets do not match!\n"
      , "nests: ", ppNestings (inner_nest, nests)
      , "\ntargets:", ppTargets (Targets target targets)
      ]
  runMaybeT $ dropFree <$> recurse (zip nests targets)

  where dropFree (kernel, _, new_targets) = (new_targets, kernel)

        -- Names bound by any nesting level, innermost included.
        bound_in_nest = foldMap boundInNesting (inner_nest : nests)

        -- A type can be moved out of the nest only if none of its
        -- dimensions mention a name bound inside the nest.
        distributableType :: Type -> Bool
        distributableType t =
          namesIntersection bound_in_nest (freeIn (arrayDims t)) == mempty

        -- Wrap one nesting level around the kernel under construction.
        -- Unused map parameters are dropped.  Values the kernel needs that
        -- are let-bound in this nesting become extra map parameters.
        -- Returned values not covered by the identity map get fresh "_r"
        -- arrays, which are bound in the new target.
        distributeAtNesting :: (HasScope t m, MonadFreshNames m) =>
                               Nesting
                            -> PatternT Type
                            -> (LoopNesting -> KernelNest, Names)
                            -> M.Map VName Ident
                            -> [Ident]
                            -> (Target -> Targets)
                            -> MaybeT m (KernelNest, Names, Targets)
        distributeAtNesting
          (Nesting nest_let_bound nest)
          pat
          (add_to_kernel, free_in_kernel)
          identity_map
          inner_returned_arrs
          addTarget = do
          let nest'@(MapNesting _ cs w params_and_arrs) =
                removeUnusedNestingParts free_in_kernel nest
              (params, arrs) = unzip params_and_arrs
              free_in_kernel' =
                (freeIn nest' <> free_in_kernel)
                `namesSubtract` namesFromList (map paramName params)
              -- Free names of the kernel that this nesting itself binds.
              required_from_nest =
                free_in_kernel' `namesIntersection` nest_let_bound

          required_from_nest_idents <-
            lift $ mapM (\v -> Ident v <$> lookupType v) $
            namesToList required_from_nest

          -- Each returned/required value becomes a parameter.  If it is an
          -- identity mapping, the existing array is reused (not bound in
          -- the target); otherwise a fresh array of rows is created.
          let bindFree (Ident pname ptype) =
                case M.lookup pname identity_map of
                  Just arr ->
                    return (Param pname ptype, arr, False)
                  Nothing -> do
                    arr <- newIdent (baseString pname ++ "_r") $
                           arrayOfRow ptype w
                    return (Param pname ptype, arr, True)

          (free_params, free_arrs, bind_in_target) <-
            unzip3 <$> mapM bindFree (inner_returned_arrs ++ required_from_nest_idents)

          let free_arrs_pat =
                basicPattern []
                [ arr | (True, arr) <- zip bind_in_target free_arrs ]
              free_params_pat =
                [ p | (True, p) <- zip bind_in_target free_params ]
              actual_params = params ++ free_params
              actual_arrs = arrs ++ map identName free_arrs
              nest'' =
                removeUnusedNestingParts free_in_kernel $
                MapNesting pat cs w $ zip actual_params actual_arrs
              free_in_kernel'' =
                (freeIn nest'' <> free_in_kernel)
                `namesSubtract` namesFromList (map paramName actual_params)

          if all (distributableType . paramType) (loopNestingParams nest'')
            then return ( add_to_kernel nest''
                        , free_in_kernel''
                        , addTarget ( free_arrs_pat
                                    , map (Var . paramName) free_params_pat ) )
            else fail "Would induce irregular array"

        -- Build the kernel from the innermost nesting outwards.  The head
        -- of the list is processed after the rest has been built.
        recurse :: (HasScope t m, MonadFreshNames m) =>
                   [(Nesting,Target)]
                -> MaybeT m (KernelNest, Names, Targets)
        recurse [] =
          distributeAtNesting
            inner_nest
            (distributionInnerPattern distrib_body)
            ( newKernel
            , distributionFreeInBody distrib_body `namesIntersection` bound_in_nest )
            (distributionIdentityMap distrib_body)
            []
            (singleTarget . distributionExpandTarget distrib_body)

        recurse ((nest, (pat, res)) : nests') = do
          (kernel@(outer, _), kernel_free, kernel_targets) <- recurse nests'
          let outer_names =
                namesFromList $ patternNames $ loopNestingPattern outer
              (pat', res', identity_map, expand_target) =
                removeIdentityMappingFromNesting outer_names pat res
              returned = patternIdents $ fst $ outerTarget kernel_targets
          distributeAtNesting
            nest
            pat'
            (\k -> pushKernelNesting (pat', res') k kernel, kernel_free)
            identity_map
            returned
            (\tgt -> pushOuterTarget (expand_target tgt) kernel_targets)
-- | Keep only the map parameters (with their input arrays) whose names are
-- in the given set.  Pattern, certificates and width are left untouched.
removeUnusedNestingParts :: Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts used (MapNesting pat cs w params_and_arrs) =
  MapNesting pat cs w $ filter (isUsed . fst) params_and_arrs
  where isUsed p = paramName p `nameIn` used
-- | Split a pattern/result pair into identity mappings and everything else.
--
-- A pair is an identity mapping when its result is a variable that is not
-- in @bound@.  Returns the following, in order:
--
-- * the pattern without the identity-mapped elements (with an empty
--   context);
-- * the corresponding results;
-- * a map from each identity-mapped variable to the identifier of its
--   pattern element;
-- * a function that appends the removed elements and variables onto a
--   target.
removeIdentityMappingGeneral :: Names -> PatternT Type -> Result
                             -> (PatternT Type,
                                 Result,
                                 M.Map VName Ident,
                                 Target -> Target)
removeIdentityMappingGeneral bound pat res =
  (Pattern [] kept_pes, kept_res, identity_map, expandTarget)
  where (identities, kept) = mapEither classify $ zip (patternElements pat) res
        (kept_pes, kept_res) = unzip kept
        (identity_pes, identity_vs) = unzip identities

        identity_map =
          M.fromList $ zip identity_vs $ map patElemIdent identity_pes

        expandTarget (tpat, tres) =
          ( Pattern [] (patternElements tpat ++ identity_pes)
          , tres ++ map Var identity_vs )

        classify (pe, Var v)
          | not (v `nameIn` bound) = Left (pe, v)
        classify other = Right other
-- | Remove identity mappings relative to the names bound by a nesting.
-- This currently delegates to 'removeIdentityMappingGeneral' and returns
-- its result unchanged.
removeIdentityMappingFromNesting :: Names -> PatternT Type -> Result
                                 -> (PatternT Type,
                                     Result,
                                     M.Map VName Ident,
                                     Target -> Target)
removeIdentityMappingFromNesting = removeIdentityMappingGeneral
-- | Attempt to distribute a sequence of statements over the given nestings.
--
-- An empty sequence trivially succeeds with the targets unchanged and no
-- statements.  Otherwise the function builds a kernel nest and constructs
-- the kernel statement in the scope of the new targets.  That statement is
-- renamed, the distribution is logged, and the result is the new targets
-- plus the width statements followed by the kernel.  Yields 'Nothing' when
-- no kernel nest can be built.
tryDistribute :: (DistLore lore, MonadFreshNames m,
                  LocalScope lore m, MonadLogger m) =>
                 MkSegLevel lore m -> Nestings -> Targets -> Stms lore
              -> m (Maybe (Targets, Stms lore))
tryDistribute _ _ targets stms
  | null stms = return $ Just (targets, mempty)
tryDistribute mk_lvl nest targets stms = do
  maybe_kernel <- createKernelNest nest dist_body
  case maybe_kernel of
    Nothing ->
      return Nothing
    Just (targets', distributed) -> do
      (kernel_stm, w_stms) <-
        localScope (targetsScope targets') $
          constructKernel mk_lvl distributed $ mkBody stms inner_body_res
      distributed' <- renameStm kernel_stm
      logMsg $ concat
        [ "distributing\n"
        , unlines $ map pretty $ stmsToList stms
        , pretty $ snd $ innerTarget targets
        , "\nas\n"
        , pretty distributed'
        , "\ndue to targets\n"
        , ppTargets targets
        , "\nand with new targets\n"
        , ppTargets targets'
        ]
      return $ Just (targets', w_stms <> oneStm distributed')
  where (dist_body, inner_body_res) = distributionBodyFromStms targets stms
-- | Attempt to build a kernel nest for a single statement.  On success,
-- returns the statement's inner result (after identity-mapping removal)
-- together with the new targets and the kernel nest.
tryDistributeStm :: (MonadFreshNames m, HasScope t m, Attributes lore) =>
                    Nestings -> Targets -> Stm lore
                 -> m (Maybe (Result, Targets, KernelNest))
tryDistributeStm nest targets stm =
  (fmap . fmap) attachRes $ createKernelNest nest body
  where (body, inner_res) = distributionBodyFromStm targets stm
        attachRes (new_targets, kernel_nest) = (inner_res, new_targets, kernel_nest)