{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Pass.ExtractKernels.Distribution
       (
         Target
       , Targets
       , ppTargets
       , singleTarget
       , outerTarget
       , innerTarget
       , pushInnerTarget
       , popInnerTarget
       , targetsScope

       , LoopNesting (..)
       , ppLoopNesting
       , scopeOfLoopNesting

       , Nesting (..)
       , Nestings
       , ppNestings
       , letBindInInnerNesting
       , singleNesting
       , pushInnerNesting

       , KernelNest
       , ppKernelNest
       , newKernel
       , pushKernelNesting
       , pushInnerKernelNesting
       , kernelNestLoops
       , kernelNestWidths
       , boundInKernelNest
       , boundInKernelNests
       , flatKernel
       , constructKernel

       , tryDistribute
       , tryDistributeStm
       )
       where

import Control.Monad.RWS.Strict
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.Maybe
import Data.List (elemIndex, sortOn)

import Futhark.Representation.AST
import Futhark.Representation.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util
import Futhark.Transform.Rename
import Futhark.Util.Log
import Futhark.Pass.ExtractKernels.BlockedKernel
  (DistLore, mapKernel, KernelInput(..), readKernelInput, MkSegLevel)

-- | A distribution target: a pattern paired with a result.  The
-- elements of the pattern correspond positionally to the 'SubExp's of
-- the result.
type Target = (PatternT Type, Result)

-- | A stack of targets.  The '_innerTarget' field holds the very
-- innermost ("current") target; in '_outerTargets', the outermost
-- target comes first.  Invariant: every element of a pattern must be
-- present as the result of the immediately enclosing target.  This is
-- ensured by 'pushInnerTarget' by removing unused pattern elements.
data Targets = Targets { _innerTarget :: Target
                         -- ^ The innermost ("current") target.
                       , _outerTargets :: [Target]
                         -- ^ Enclosing targets, outermost first.
                       }

-- | Pretty-print the targets one per line as @pattern <- result@,
-- starting with the outermost target and ending with the innermost.
ppTargets :: Targets -> String
ppTargets (Targets inner outers) =
  unlines [ showTarget target | target <- outers ++ [inner] ]
  where showTarget (pat, res) = pretty pat ++ " <- " ++ pretty res

-- | A target stack containing only the given target, which is then
-- the innermost target and there are no outer targets.
singleTarget :: Target -> Targets
singleTarget target = Targets target []

-- | The outermost target: the head of the outer list, or the inner
-- target itself when there are no outer targets.
outerTarget :: Targets -> Target
outerTarget (Targets inner outers) = fromMaybe inner (listToMaybe outers)

-- | The innermost ("current") target.
innerTarget :: Targets -> Target
innerTarget = _innerTarget

-- | Add a new outermost target; the innermost target is unchanged.
pushOuterTarget :: Target -> Targets -> Targets
pushOuterTarget target targets =
  targets { _outerTargets = target : _outerTargets targets }

-- | Push a new innermost target.  The previous innermost target becomes
-- the last (innermost) element of the outer list.
--
-- To maintain the 'Targets' invariant, pattern elements of the new
-- target, together with their positionally corresponding results, are
-- dropped unless the element's name is free in the result of the
-- previous innermost target.  The rebuilt pattern has an empty first
-- ('Pattern' context) component.
pushInnerTarget :: Target -> Targets -> Targets
pushInnerTarget (pat, res) (Targets old_inner outers) =
  Targets (Pattern [] kept_pes, kept_res) (outers ++ [old_inner])
  where needed = freeIn (snd old_inner)
        (kept_pes, kept_res) =
          unzip [ (pe, se)
                | (pe, se) <- zip (patternElements pat) res
                , patElemName pe `nameIn` needed ]

-- | Remove the innermost target and return it together with the
-- remaining targets.  The last element of the outer list becomes the
-- new innermost target.  Returns 'Nothing' when there are no outer
-- targets.
popInnerTarget :: Targets -> Maybe (Target, Targets)
popInnerTarget (Targets current outers) =
  fmap (\(new_inner, rest) -> (current, Targets new_inner rest))
       (peelLast outers)
  where peelLast xs = case reverse xs of
                        []     -> Nothing
                        y : ys -> Just (y, reverse ys)

-- | The scope of the names bound by a target's pattern.
targetScope :: DistLore lore => Target -> Scope lore
targetScope (pat, _) = scopeOfPattern pat

-- | The combined scope of all target patterns, innermost first and then
-- the outer targets in list order.
targetsScope :: DistLore lore => Targets -> Scope lore
targetsScope (Targets inner outers) = foldMap targetScope (inner : outers)

-- | One level of a loop nest, built with the 'MapNesting' constructor.
data LoopNesting = MapNesting
  { loopNestingPattern :: PatternT Type
    -- ^ Pattern bound by this nesting.  'constructKernel' uses the
    -- outermost one as the pattern of the statement it builds.
  , loopNestingCertificates :: Certificates
    -- ^ 'constructKernel' attaches the outermost nesting's
    -- certificates to the statement it builds.
  , loopNestingWidth :: SubExp
    -- ^ Width of this level; 'flatKernel' uses it as the extent of
    -- this level's dimension of the index space.
  , loopNestingParamsAndArrs :: [(Param Type, VName)]
    -- ^ Each parameter paired with the array it is read from.
  }
  deriving (Show)

-- | The scope introduced by the parameters of a loop nesting.
scopeOfLoopNesting :: DistLore lore => LoopNesting -> Scope lore
scopeOfLoopNesting = scopeOfLParams . loopNestingParams

-- | Pretty-print a loop nesting as @params <- arrays@.
ppLoopNesting :: LoopNesting -> String
ppLoopNesting nest = pretty params ++ " <- " ++ pretty arrs
  where (params, arrs) = unzip $ loopNestingParamsAndArrs nest

-- | The parameters of a loop nesting, without their arrays.
loopNestingParams :: LoopNesting -> [Param Type]
loopNestingParams nest = [ param | (param, _) <- loopNestingParamsAndArrs nest ]

-- | Free variables of every component of the nesting: pattern,
-- certificates, width, and the parameter/array pairs.
instance FreeIn LoopNesting where
  freeIn' nest =
    freeIn' (loopNestingPattern nest) <>
    freeIn' (loopNestingCertificates nest) <>
    freeIn' (loopNestingWidth nest) <>
    freeIn' (loopNestingParamsAndArrs nest)

-- | A loop nesting together with the names let-bound within it.
data Nesting = Nesting
  { nestingLetBound :: Names
    -- ^ Let-bound names, extended by 'letBindInNesting'.
  , nestingLoop :: LoopNesting
    -- ^ The loop level itself.
  }
  deriving (Show)

-- | Record additional let-bound names in a nesting (new names are
-- combined after the existing ones).
letBindInNesting :: Names -> Nesting -> Nesting
letBindInNesting newnames nest =
  nest { nestingLetBound = nestingLetBound nest <> newnames }

-- | First pair element is the very innermost ("current") nest.  In
-- the list, the outermost nest comes first.
type Nestings = (Nesting, [Nesting])

-- | Pretty-print the loop of each nesting on its own line, outermost
-- first and the innermost ("current") nesting last.
ppNestings :: Nestings -> String
ppNestings (inner, outers) =
  unlines [ ppLoopNesting (nestingLoop n) | n <- outers ++ [inner] ]

-- | Nestings consisting of just the given (innermost) nesting.
singleNesting :: Nesting -> Nestings
singleNesting nest = (nest, [])

-- | Make the given nesting the new innermost one.  The previous
-- innermost nesting is appended to the end of the outer list.
pushInnerNesting :: Nesting -> Nestings -> Nestings
pushInnerNesting new_inner (old_inner, outers) =
  (new_inner, outers <> [old_inner])

-- | Names bound by a nesting: both its loop parameters and the names
-- let-bound inside it.
boundInNesting :: Nesting -> Names
boundInNesting (Nesting let_bound loop) =
  namesFromList [ paramName p | p <- loopNestingParams loop ] <> let_bound

-- | Record let-bound names in the innermost nesting only; the outer
-- nestings are left unchanged.
letBindInInnerNesting :: Names -> Nestings -> Nestings
letBindInInnerNesting names (inner, outers) = (inner', outers)
  where inner' = letBindInNesting names inner


-- | A nest of loops from which a kernel is built.  Note: the first
-- element is the *outermost* nesting, and the list continues inwards,
-- ending with the innermost.  This is the opposite of 'Targets' and
-- 'Nestings', whose distinguished element is the innermost one!
type KernelNest = (LoopNesting, [LoopNesting])

-- | Pretty-print every loop of the kernel nest on its own line,
-- outermost first.
ppKernelNest :: KernelNest -> String
ppKernelNest = unlines . map ppLoopNesting . kernelNestLoops

-- | Add a new outermost nesting, moving the current outermost nesting
-- to the front of the list.  The new nesting's pattern is reordered by
-- 'fixNestingPatternOrder' to follow the previous outermost nesting's
-- pattern.
pushKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushKernelNesting target newnest (nest, nests) = (newnest', nest : nests)
  where newnest' =
          fixNestingPatternOrder newnest target $ loopNestingPattern nest

-- | Add a new innermost nesting by appending it to the end of the list;
-- the outermost nesting stays in place.  The new nesting's pattern is
-- reordered by 'fixNestingPatternOrder' to follow the pattern of the
-- current innermost nesting (the last list element, or the outermost
-- nesting if the list is empty).  It is important that the 'Target'
-- has the right order (non-permuted compared to what is expected by
-- the outer nests).
pushInnerKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting target newnest (outermost, nests) =
  (outermost, nests ++ [newnest'])
  where innermost = fromMaybe outermost $ listToMaybe $ reverse nests
        newnest' =
          fixNestingPatternOrder newnest target $ loopNestingPattern innermost

-- | Reorder the value elements of @nest@'s pattern according to
-- @inner_pat@.  Each pattern element is paired positionally with a
-- 'SubExp' of the target's result.  A result that is a variable is
-- ranked by its index among @inner_pat@'s names.  Constants, and
-- variables not found in @inner_pat@, get rank 0.  'sortOn' is stable,
-- so ties keep their original relative order.  The rebuilt pattern
-- passes an empty list as the first argument of 'basicPattern'.
fixNestingPatternOrder :: LoopNesting -> Target -> PatternT Type -> LoopNesting
fixNestingPatternOrder nest (_, res) inner_pat =
  nest { loopNestingPattern = basicPattern [] reordered }
  where inner_names = patternNames inner_pat
        innerPos (Var v) = fromMaybe 0 $ elemIndex v inner_names
        innerPos _       = 0
        reordered = map fst $ sortOn (innerPos . snd) $
                    zip (patternValueIdents $ loopNestingPattern nest) res

-- | A kernel nest consisting of a single (outermost) loop.
newKernel :: LoopNesting -> KernelNest
newKernel = (,[])

-- | All loops of a kernel nest as a list, outermost first.
kernelNestLoops :: KernelNest -> [LoopNesting]
kernelNestLoops (outermost, rest) = outermost : rest

-- | Union of the parameter names bound at every level of the nest.
boundInKernelNest :: KernelNest -> Names
boundInKernelNest = foldr (<>) mempty . boundInKernelNests

-- | The parameter names bound at each level of the nest, outermost
-- level first.
boundInKernelNests :: KernelNest -> [Names]
boundInKernelNests nest =
  [ namesFromList $ map paramName $ loopNestingParams loop
  | loop <- kernelNestLoops nest ]

-- | The width of each level of the nest, outermost first.
kernelNestWidths :: KernelNest -> [SubExp]
kernelNestWidths nest = [ loopNestingWidth loop | loop <- kernelNestLoops nest ]

-- | Construct a kernel for the given nest, with @inner_body@ as the body
-- run at each point of the index space.
--
-- The index space and kernel inputs come from 'flatKernel'.  Only
-- inputs whose names are free in @inner_body@ are read inside the
-- kernel body.  The returned statement binds the outermost nesting's
-- pattern (carrying its certificates) to the 'SegOp' built by
-- 'mapKernel'.  It is paired with the auxiliary statements that
-- 'mapKernel' produced.
constructKernel :: (DistLore lore, MonadFreshNames m, LocalScope lore m) =>
                   MkSegLevel lore m -> KernelNest -> Body lore
                -> m (Stm lore, Stms lore)
constructKernel mk_lvl kernel_nest inner_body = runBinderT' $ do
  (ispace, inps) <- flatKernel kernel_nest

  let outermost = fst kernel_nest
      pat = loopNestingPattern outermost
      -- Every index-space variable is in scope as an Int32 index.
      ispace_scope = M.fromList [ (gtid, IndexInfo Int32) | (gtid, _) <- ispace ]
      -- Per-thread result types: strip one outer array dimension per
      -- index-space dimension.
      rts = map (stripArray (length ispace)) (patternTypes pat)
      used_inps = filter ((`nameIn` freeIn inner_body) . kernelInputName) inps

  (kres, kstms) <- runBinder $ localScope ispace_scope $ do
    mapM_ readKernelInput used_inps
    map (Returns ResultMaySimplify) <$> bodyBind inner_body
  let kbody = KernelBody () kstms kres

  (segop, aux_stms) <- lift $ mapKernel mk_lvl ispace [] rts kbody
  addStms aux_stms

  return $ Let pat (StmAux (loopNestingCertificates outermost) ()) $
    Op $ segOp segop

-- | Flatten a kernel nesting to:
--
--  (1) The index space, outermost dimension first.  Each level gets a
--      fresh @gtid@ name paired with that level's width.
--
--  (2) The kernel inputs - note that some of these may be unused.
--      Each level's parameters become inputs indexed by that level's
--      @gtid@.  Inner inputs whose array is a parameter of an enclosing
--      level are redirected to that level's array, with the enclosing
--      @gtid@ prepended to their indices.
flatKernel :: MonadFreshNames m =>
              KernelNest
           -> m ([(VName, SubExp)],
                 [KernelInput])
flatKernel (MapNesting _ _ nesting_w params_and_arrs, inner_nests) = do
  -- The name for this level is created before recursing into inner levels.
  gtid <- newVName "gtid"
  let own_inps = [ KernelInput pname ptype arr [Var gtid]
                 | (Param pname ptype, arr) <- params_and_arrs ]
  case inner_nests of
    [] ->
      return ([(gtid, nesting_w)], own_inps)
    nest : nests -> do
      (ispace, inps) <- flatKernel (nest, nests)
      let arrayOfParam inp =
            snd <$> find ((== kernelInputArray inp) . paramName . fst)
                         params_and_arrs
          reindex inp = case arrayOfParam inp of
            Just arr -> inp { kernelInputArray = arr
                            , kernelInputIndices = Var gtid : kernelInputIndices inp }
            Nothing -> inp
      return ((gtid, nesting_w) : ispace, own_inps ++ map reindex inps)

-- | Description of distribution to do.
data DistributionBody =
  DistributionBody
  { distributionTarget :: Targets
    -- ^ The targets the distributed body writes to.  As built by
    -- 'distributionBodyFromStms', identity-mapped elements have already
    -- been removed from the innermost target.
  , distributionFreeInBody :: Names
    -- ^ Names free in the body's statements, excluding the names those
    -- statements bind themselves.
  , distributionIdentityMap :: M.Map VName Ident
    -- ^ For each result variable that is merely passed through (an
    -- identity mapping), the identifier of the pattern element it was
    -- bound to.
  , distributionExpandTarget :: Target -> Target
    -- ^ Also related to avoiding identity mapping: appends the removed
    -- identity-mapped elements back onto a target.
  }

-- | The pattern of the innermost target of a distribution body.
distributionInnerPattern :: DistributionBody -> PatternT Type
distributionInnerPattern body =
  let (inner_pat, _) = innerTarget (distributionTarget body)
  in inner_pat

-- | Describe the distribution of the given statements into the given
-- targets.  Innermost-target elements whose result is a variable not
-- bound by the statements (identity mappings) are removed from the
-- target and recorded in the identity map instead.  The second
-- component is the innermost result after that removal.
distributionBodyFromStms :: Attributes lore =>
                            Targets -> Stms lore -> (DistributionBody, Result)
distributionBodyFromStms (Targets (inner_pat, inner_res) outer_targets) stms =
  (body, pruned_res)
  where
    -- Every name brought into scope by the statements.
    stms_bound = namesFromList $ M.keys $ scopeOf stms

    (pruned_pat, pruned_res, identities, expand) =
      removeIdentityMappingGeneral stms_bound inner_pat inner_res

    body = DistributionBody
      { distributionTarget       = Targets (pruned_pat, pruned_res) outer_targets
      , distributionFreeInBody   = foldMap freeIn stms `namesSubtract` stms_bound
      , distributionIdentityMap  = identities
      , distributionExpandTarget = expand
      }

-- | Single-statement variant of 'distributionBodyFromStms'.
distributionBodyFromStm :: Attributes lore =>
                           Targets -> Stm lore -> (DistributionBody, Result)
distributionBodyFromStm targets = distributionBodyFromStms targets . oneStm

-- | Attempt to build a 'KernelNest' that distributes the body described
-- by the 'DistributionBody' across the innermost nesting and every
-- enclosing nesting.  Each enclosing nesting is paired with the outer
-- target at the same position; a mismatch in their number is an
-- internal error ('error').  Returns 'Nothing' when distribution would
-- require a kernel parameter whose type has a dimension bound inside
-- the nest (i.e. it would induce an irregular array).  On success, the
-- new 'Targets' are returned together with the constructed 'KernelNest'.
createKernelNest :: (MonadFreshNames m, HasScope t m) =>
                    Nestings
                 -> DistributionBody
                 -> m (Maybe (Targets, KernelNest))
createKernelNest (inner_nest, nests) distrib_body = do
  let Targets target targets = distributionTarget distrib_body
  when (length nests /= length targets) $
    error $ "Nests and targets do not match!\n" ++
            "nests: " ++ ppNestings (inner_nest, nests) ++
            "\ntargets:" ++ ppTargets (Targets target targets)
  runMaybeT $ prepare <$> recurse (zip nests targets)
  where
    -- Reorder the triple produced by 'recurse' into the caller's
    -- (targets, kernel) pair, dropping the free-name set.
    prepare (kernel, _, new_targets) = (new_targets, kernel)

    -- Names bound by any of the nestings, the innermost one included.
    bound_in_nest = foldMap boundInNesting (inner_nest : nests)

    -- Can something of this type be taken outside the nest?  I.e. are
    -- none of its dimensions bound inside the nest.
    distributableType ty =
      namesIntersection bound_in_nest (freeIn (arrayDims ty)) == mempty

    -- Distribute at a single nesting.  Parameters the kernel does not
    -- use are pruned; arrays returned by the inner kernel, plus any
    -- names the kernel needs from those let-bound in this nesting,
    -- become new map parameters.  Those not found in the identity map
    -- get a fresh "<name>_r" array (rows of the parameter type, outer
    -- size equal to the nesting width), which is bound by the target
    -- handed to @addTarget@.
    distributeAtNesting :: (HasScope t m, MonadFreshNames m) =>
                           Nesting
                        -> PatternT Type
                        -> (LoopNesting -> KernelNest, Names)
                        -> M.Map VName Ident
                        -> [Ident]
                        -> (Target -> Targets)
                        -> MaybeT m (KernelNest, Names, Targets)
    distributeAtNesting (Nesting nest_let_bound nest) pat
                        (add_to_kernel, free_in_kernel)
                        identity_map inner_returned_arrs addTarget = do
      let pruned_nest@(MapNesting _ cs w params_and_arrs) =
            removeUnusedNestingParts free_in_kernel nest
          (params, arrs) = unzip params_and_arrs
          free_in_pruned =
            (freeIn pruned_nest <> free_in_kernel)
            `namesSubtract` namesFromList (map paramName params)
          required_from_nest =
            free_in_pruned `namesIntersection` nest_let_bound

      required_from_nest_idents <-
        forM (namesToList required_from_nest) $ \name ->
          Ident name <$> lift (lookupType name)

      (free_params, free_arrs, bind_in_target) <-
        fmap unzip3 $
        forM (inner_returned_arrs ++ required_from_nest_idents) $
        \(Ident pname ptype) ->
          case M.lookup pname identity_map of
            Just arr ->
              -- Already available through an identity mapping, so
              -- nothing new needs to be bound in the target.
              return (Param pname ptype, arr, False)
            Nothing -> do
              arr <- newIdent (baseString pname ++ "_r") $ arrayOfRow ptype w
              return (Param pname ptype, arr, True)

      let free_arrs_pat =
            basicPattern [] [ arr | (True, arr) <- zip bind_in_target free_arrs ]
          free_params_pat =
            [ p | (True, p) <- zip bind_in_target free_params ]
          actual_params = params ++ free_params
          actual_arrs = arrs ++ map identName free_arrs
          new_nest =
            removeUnusedNestingParts free_in_kernel $
            MapNesting pat cs w $ zip actual_params actual_arrs
          new_free =
            (freeIn new_nest <> free_in_kernel)
            `namesSubtract` namesFromList (map paramName actual_params)

      unless (all (distributableType . paramType) $ loopNestingParams new_nest) $
        fail "Would induce irregular array"

      return ( add_to_kernel new_nest
             , new_free
             , addTarget (free_arrs_pat, map (Var . paramName) free_params_pat) )

    recurse :: (HasScope t m, MonadFreshNames m) =>
               [(Nesting,Target)]
            -> MaybeT m (KernelNest, Names, Targets)
    recurse [] =
      -- Base case: distribute the body at the innermost nesting.
      distributeAtNesting
        inner_nest
        (distributionInnerPattern distrib_body)
        (newKernel,
         distributionFreeInBody distrib_body `namesIntersection` bound_in_nest)
        (distributionIdentityMap distrib_body)
        []
        (singleTarget . distributionExpandTarget distrib_body)
    recurse ((nest, (pat, res)) : nests') = do
      -- Build the kernel for the rest of the list first, then
      -- distribute at this nesting on top of it.
      (kernel@(outer, _), kernel_free, kernel_targets) <- recurse nests'
      let outer_pat_names = namesFromList $ patternNames $ loopNestingPattern outer
          (pat', res', identity_map, expand_target) =
            removeIdentityMappingFromNesting outer_pat_names pat res
      distributeAtNesting
        nest
        pat'
        (\k -> pushKernelNesting (pat', res') k kernel, kernel_free)
        identity_map
        (patternIdents $ fst $ outerTarget kernel_targets)
        ((`pushOuterTarget` kernel_targets) . expand_target)

-- | Keep only those (parameter, array) pairs of a map nesting whose
-- parameter name is in the given set, preserving their order.  The
-- pattern, certificates and width are left untouched.
removeUnusedNestingParts :: Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts used (MapNesting pat cs w params_and_arrs) =
  MapNesting pat cs w
    [ (param, arr)
    | (param, arr) <- params_and_arrs
    , paramName param `nameIn` used
    ]

-- | Split a pattern/result pair into identity mappings and the rest.
-- An element is an identity mapping when its result is a variable that
-- is not in @bound@.  Returns:
--
--   * the pattern with identity elements removed (as value elements of
--     a pattern with no context),
--   * the correspondingly pruned result,
--   * a map from each passed-through variable to the 'Ident' of the
--     pattern element it was bound to, and
--   * a function that appends the removed pattern elements and their
--     variables back onto a target.
--
-- Elements are paired with 'zip', so excess elements on either side
-- are dropped.
removeIdentityMappingGeneral :: Names -> PatternT Type -> Result
                             -> (PatternT Type,
                                 Result,
                                 M.Map VName Ident,
                                 Target -> Target)
removeIdentityMappingGeneral bound pat res =
  (Pattern [] kept_pes, kept_res, identity_map, expandTarget)
  where
    (identities, kept) = mapEither classify $ zip (patternElements pat) res
    (kept_pes, kept_res) = unzip kept
    (identity_pes, identity_vs) = unzip identities

    classify (pe, Var v)
      | not (v `nameIn` bound) = Left (pe, v)
    classify other = Right other

    identity_map =
      M.fromList [ (v, patElemIdent pe) | (pe, v) <- identities ]

    expandTarget (tpat, tres) =
      (Pattern [] (patternElements tpat ++ identity_pes),
       tres ++ map Var identity_vs)

-- | Identity-mapping removal for a nesting's pattern and result.  This
-- behaves exactly like 'removeIdentityMappingGeneral', with the given
-- names treated as the ones bound inside the nesting.
removeIdentityMappingFromNesting :: Names -> PatternT Type -> Result
                                 -> (PatternT Type,
                                     Result,
                                     M.Map VName Ident,
                                     Target -> Target)
removeIdentityMappingFromNesting = removeIdentityMappingGeneral

tryDistribute :: (DistLore lore, MonadFreshNames m,
                  LocalScope lore m, MonadLogger m) =>
                 MkSegLevel lore m -> Nestings -> Targets -> Stms lore
              -> m (Maybe (Targets, Stms lore))
tryDistribute :: MkSegLevel lore m
-> Nestings
-> Targets
-> Stms lore
-> m (Maybe (Targets, Stms lore))
tryDistribute MkSegLevel lore m
_ Nestings
_ Targets
targets Stms lore
stms | Stms lore -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Stms lore
stms =
  -- No point in distributing an empty kernel.
  Maybe (Targets, Stms lore) -> m (Maybe (Targets, Stms lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (Targets, Stms lore) -> m (Maybe (Targets, Stms lore)))
-> Maybe (Targets, Stms lore) -> m (Maybe (Targets, Stms lore))
forall a b. (a -> b) -> a -> b
$ (Targets, Stms lore) -> Maybe (Targets, Stms lore)
forall a. a -> Maybe a
Just (Targets
targets, Stms lore
forall a. Monoid a => a
mempty)
tryDistribute MkSegLevel lore m
mk_lvl Nestings
nest Targets
targets Stms lore
stms =
  Nestings -> DistributionBody -> m (Maybe (Targets, KernelNest))
forall (m :: * -> *) t.
(MonadFreshNames m, HasScope t m) =>
Nestings -> DistributionBody -> m (Maybe (Targets, KernelNest))
createKernelNest Nestings
nest DistributionBody
dist_body m (Maybe (Targets, KernelNest))
-> (Maybe (Targets, KernelNest) -> m (Maybe (Targets, Stms lore)))
-> m (Maybe (Targets, Stms lore))
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
  \case
    Just (Targets
targets', KernelNest
distributed) -> do
      (Stm lore
kernel_bnd, Stms lore
w_bnds) <-
        Scope lore -> m (Stm lore, Stms lore) -> m (Stm lore, Stms lore)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (Targets -> Scope lore
forall lore. DistLore lore => Targets -> Scope lore
targetsScope Targets
targets') (m (Stm lore, Stms lore) -> m (Stm lore, Stms lore))
-> m (Stm lore, Stms lore) -> m (Stm lore, Stms lore)
forall a b. (a -> b) -> a -> b
$
        MkSegLevel lore m
-> KernelNest -> Body lore -> m (Stm lore, Stms lore)
forall lore (m :: * -> *).
(DistLore lore, MonadFreshNames m, LocalScope lore m) =>
MkSegLevel lore m
-> KernelNest -> Body lore -> m (Stm lore, Stms lore)
constructKernel MkSegLevel lore m
mk_lvl KernelNest
distributed (Body lore -> m (Stm lore, Stms lore))
-> Body lore -> m (Stm lore, Stms lore)
forall a b. (a -> b) -> a -> b
$ Stms lore -> Result -> Body lore
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms lore
stms Result
inner_body_res
      Stm lore
distributed' <- Stm lore -> m (Stm lore)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Stm lore -> m (Stm lore)
renameStm Stm lore
kernel_bnd
      String -> m ()
forall (m :: * -> *) a. (MonadLogger m, ToLog a) => a -> m ()
logMsg (String -> m ()) -> String -> m ()
forall a b. (a -> b) -> a -> b
$ String
"distributing\n" String -> String -> String
forall a. [a] -> [a] -> [a]
++
        [String] -> String
unlines ((Stm lore -> String) -> [Stm lore] -> [String]
forall a b. (a -> b) -> [a] -> [b]
map Stm lore -> String
forall a. Pretty a => a -> String
pretty ([Stm lore] -> [String]) -> [Stm lore] -> [String]
forall a b. (a -> b) -> a -> b
$ Stms lore -> [Stm lore]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms lore
stms) String -> String -> String
forall a. [a] -> [a] -> [a]
++
        Result -> String
forall a. Pretty a => a -> String
pretty (Target -> Result
forall a b. (a, b) -> b
snd (Target -> Result) -> Target -> Result
forall a b. (a -> b) -> a -> b
$ Targets -> Target
innerTarget Targets
targets) String -> String -> String
forall a. [a] -> [a] -> [a]
++
        String
"\nas\n" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Stm lore -> String
forall a. Pretty a => a -> String
pretty Stm lore
distributed' String -> String -> String
forall a. [a] -> [a] -> [a]
++
        String
"\ndue to targets\n" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Targets -> String
ppTargets Targets
targets String -> String -> String
forall a. [a] -> [a] -> [a]
++
        String
"\nand with new targets\n" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Targets -> String
ppTargets Targets
targets'
      Maybe (Targets, Stms lore) -> m (Maybe (Targets, Stms lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (Targets, Stms lore) -> m (Maybe (Targets, Stms lore)))
-> Maybe (Targets, Stms lore) -> m (Maybe (Targets, Stms lore))
forall a b. (a -> b) -> a -> b
$ (Targets, Stms lore) -> Maybe (Targets, Stms lore)
forall a. a -> Maybe a
Just (Targets
targets', Stms lore
w_bnds Stms lore -> Stms lore -> Stms lore
forall a. Semigroup a => a -> a -> a
<> Stm lore -> Stms lore
forall lore. Stm lore -> Stms lore
oneStm Stm lore
distributed')
    Maybe (Targets, KernelNest)
Nothing ->
      Maybe (Targets, Stms lore) -> m (Maybe (Targets, Stms lore))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (Targets, Stms lore)
forall a. Maybe a
Nothing
  where (DistributionBody
dist_body, Result
inner_body_res) = Targets -> Stms lore -> (DistributionBody, Result)
forall lore.
Attributes lore =>
Targets -> Stms lore -> (DistributionBody, Result)
distributionBodyFromStms Targets
targets Stms lore
stms

-- | Try to distribute a single statement against the given nesting
-- and targets.
--
-- The statement is first turned into a distribution body (plus its
-- result) by 'distributionBodyFromStm'; that body is then handed to
-- 'createKernelNest'.  If a kernel nest is produced, the statement's
-- result is paired with the new targets and the nest; otherwise the
-- 'Nothing' from 'createKernelNest' is passed through unchanged.
tryDistributeStm :: (MonadFreshNames m, HasScope t m, Attributes lore) =>
                    Nestings -> Targets -> Stm lore
                 -> m (Maybe (Result, Targets, KernelNest))
tryDistributeStm nest targets bnd =
  fmap (fmap prependResult) (createKernelNest nest body)
  where
    (body, stm_res) = distributionBodyFromStm targets bnd

    -- Attach the statement's result to a successful nesting outcome.
    prependResult :: (Targets, KernelNest) -> (Result, Targets, KernelNest)
    prependResult (new_targets, knest) = (stm_res, new_targets, knest)