{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Futhark.Pass.ExtractKernels.Distribution
       (
         Target
       , Targets
       , ppTargets
       , singleTarget
       , outerTarget
       , innerTarget
       , pushInnerTarget
       , popInnerTarget
       , targetsScope

       , LoopNesting (..)
       , ppLoopNesting

       , Nesting (..)
       , Nestings
       , ppNestings
       , letBindInInnerNesting
       , singleNesting
       , pushInnerNesting

       , KernelNest
       , ppKernelNest
       , newKernel
       , pushKernelNesting
       , pushInnerKernelNesting
       , kernelNestLoops
       , kernelNestWidths
       , boundInKernelNest
       , boundInKernelNests
       , flatKernel
       , constructKernel

       , tryDistribute
       , tryDistributeStm
       )
       where

import Control.Monad.RWS.Strict
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.Maybe
import Data.List (elemIndex, sortOn)

import Futhark.Representation.Kernels
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util
import Futhark.Transform.Rename
import Futhark.Util.Log
import Futhark.Pass.ExtractKernels.BlockedKernel
  (mapKernel, KernelInput(..), readKernelInput, MkSegLevel)

-- | A pattern paired with a result.  'ppTargets' renders it as
-- @pat <- res@, and 'pushInnerTarget' relates the two element-wise.
type Target = (Pattern Kernels, Result)

-- | The first field is the very innermost ("current") target.  In the
-- list of the second field, the outermost target comes first.
-- Invariant: Every element of a pattern must be present as the result
-- of the immediately enclosing target.  This is ensured by
-- 'pushInnerTarget' by removing unused pattern elements.
data Targets = Targets { _innerTarget :: Target
                       , _outerTargets :: [Target]
                       }

-- | Render all targets, one per line, as @pat <- res@, starting with
-- the outermost target and ending with the innermost one.
ppTargets :: Targets -> String
ppTargets (Targets target targets) =
  unlines $ map ppTarget $ targets ++ [target]
  where ppTarget (pat, res) =
          pretty pat ++ " <- " ++ pretty res

-- | A 'Targets' consisting of only the given (innermost) target.
singleTarget :: Target -> Targets
singleTarget target = Targets target []

-- | The outermost target.  When there are no outer targets, this is
-- the innermost target.
outerTarget :: Targets -> Target
outerTarget (Targets inner_target []) = inner_target
outerTarget (Targets _ (outer_target : _)) = outer_target

-- | The innermost ("current") target.
innerTarget :: Targets -> Target
innerTarget (Targets inner_target _) = inner_target

-- | Add a new outermost target.  Unlike 'pushInnerTarget', the target
-- is added as-is, without pruning its pattern.
pushOuterTarget :: Target -> Targets -> Targets
pushOuterTarget target (Targets inner_target targets) =
  Targets inner_target (target : targets)

-- | Add a new innermost target.  The previous innermost target becomes
-- the innermost of the outer targets.
--
-- To maintain the invariant on 'Targets', pattern elements of the new
-- target that are not free in the result of the previous innermost
-- target are dropped, together with their corresponding results.  The
-- resulting pattern has an empty context part.
pushInnerTarget :: Target -> Targets -> Targets
pushInnerTarget (pat, res) (Targets inner_target targets) =
  Targets (pat', res') (targets ++ [inner_target])
  where (pes', res') = unzip $ filter (used . fst) $
                       zip (patternElements pat) res
        pat' = Pattern [] pes'
        -- Names referenced by the result of the previous innermost target.
        inner_used = freeIn $ snd inner_target
        used pe = patElemName pe `nameIn` inner_used

-- | Remove the innermost target and return it.  The innermost of the
-- outer targets becomes the new innermost target.
--
-- Returns 'Nothing' when there are no outer targets, since a 'Targets'
-- always holds an innermost target.
popInnerTarget :: Targets -> Maybe (Target, Targets)
popInnerTarget (Targets t ts) =
  case reverse ts of
    x : xs -> Just (t, Targets x $ reverse xs)
    []     -> Nothing

-- | The scope bound by the pattern of a target.
targetScope :: Target -> Scope Kernels
targetScope (pat, _) = scopeOfPattern pat

-- | The combined scope bound by the patterns of all targets, innermost
-- target first.
targetsScope :: Targets -> Scope Kernels
targetsScope (Targets t ts) = mconcat $ map targetScope $ t : ts

-- | One map nesting level.  It consists of:
--
-- * the pattern bound by the level;
-- * its certificates;
-- * its width;
-- * its parameters, each paired with the array it reads from (see
--   'ppLoopNesting' and 'flatKernel').
data LoopNesting = MapNesting { loopNestingPattern :: Pattern Kernels
                              , loopNestingCertificates :: Certificates
                              , loopNestingWidth :: SubExp
                              , loopNestingParamsAndArrs :: [(Param Type, VName)]
                              }
                 deriving (Show)

-- | A loop nesting brings its parameters into scope.
instance Scoped Kernels LoopNesting where
  scopeOf = scopeOfLParams . map fst . loopNestingParamsAndArrs

-- | Render a loop nesting as @params <- arrays@.
ppLoopNesting :: LoopNesting -> String
ppLoopNesting (MapNesting _ _ _ params_and_arrs) =
  pretty params ++ " <- " ++ pretty arrs
  where (params, arrs) = unzip params_and_arrs

-- | The parameters of a loop nesting, in order.
loopNestingParams :: LoopNesting -> [LParam Kernels]
loopNestingParams = map fst . loopNestingParamsAndArrs

-- | The free variables of a loop nesting are those of its pattern,
-- certificates, width, and parameters/arrays.
instance FreeIn LoopNesting where
  freeIn' (MapNesting pat cs w params_and_arrs) =
    freeIn' pat <>
    freeIn' cs <>
    freeIn' w <>
    freeIn' params_and_arrs

-- | A loop nesting together with the set of names recorded as
-- let-bound in it (see 'letBindInNesting').
data Nesting = Nesting { nestingLetBound :: Names
                       , nestingLoop :: LoopNesting
                       }
             deriving (Show)

-- | Record additional names as let-bound in the nesting.
letBindInNesting :: Names -> Nesting -> Nesting
letBindInNesting newnames nesting =
  nesting { nestingLetBound = nestingLetBound nesting <> newnames }

-- | First pair element is the very innermost ("current") nest.  In
-- the list, the outermost nest comes first.
type Nestings = (Nesting, [Nesting])

-- | Render all nestings, one per line, from the outermost to the
-- innermost.  The let-bound names are not shown.
ppNestings :: Nestings -> String
ppNestings (nesting, nestings) =
  unlines $ map (ppLoopNesting . nestingLoop) $ nestings ++ [nesting]

-- | 'Nestings' consisting of only the given (innermost) nesting.
singleNesting :: Nesting -> Nestings
singleNesting = (,[])

-- | Add a new innermost nesting.  The previous innermost nesting is
-- appended to the end (innermost side) of the outer list.
pushInnerNesting :: Nesting -> Nestings -> Nestings
pushInnerNesting nesting (inner_nesting, nestings) =
  (nesting, nestings ++ [inner_nesting])

-- | Names bound by a nesting: both its loop parameters and the names
-- recorded as let-bound in it.
boundInNesting :: Nesting -> Names
boundInNesting (Nesting let_bound loop) =
  namesFromList (map paramName (loopNestingParams loop)) <> let_bound

-- | Record additional names as let-bound in the innermost nesting.
letBindInInnerNesting :: Names -> Nestings -> Nestings
letBindInInnerNesting names (nest, nestings) =
  (letBindInNesting names nest, nestings)


-- | Note: first element is *outermost* nesting, and the list then
-- proceeds inwards, so its last element is the innermost nesting.
-- This is different from the similar types elsewhere!
type KernelNest = (LoopNesting, [LoopNesting])

-- | Render a kernel nest, one loop nesting per line, from the
-- outermost to the innermost.
ppKernelNest :: KernelNest -> String
ppKernelNest (nesting, nestings) =
  unlines $ map ppLoopNesting $ nesting : nestings

-- | Add new outermost nesting, pushing the current outermost to the
-- front of the list.  The pattern of the new nesting is reordered with
-- 'fixNestingPatternOrder' so that it agrees with the pattern of the
-- previous outermost nesting.
pushKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushKernelNesting target newnest (nest, nests) =
  (fixNestingPatternOrder newnest target (loopNestingPattern nest),
   nest : nests)

-- | Add new innermost nesting, appending it to the end of the list.
-- The outermost nesting stays in place.  The pattern of the new
-- nesting is reordered with 'fixNestingPatternOrder' against the
-- pattern of the current innermost nesting.  It is important that the
-- 'Target' has the right order (non-permuted compared to what is
-- expected by the outer nests).
pushInnerKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting target newnest (nest, nests) =
  (nest, nests ++ [fixNestingPatternOrder newnest target (loopNestingPattern innermost)])
  where
    -- The current innermost nesting: the last element of the list, or
    -- the outermost nesting when the list is empty.
    innermost = case reverse nests of
      []    -> nest
      n : _ -> n

-- | Reorder the value identifiers of the pattern of @nest@ to follow
-- the order of @inner_pat@.
--
-- Each identifier is paired positionally with the corresponding
-- element of the target's result.  An identifier whose result is a
-- variable named in @inner_pat@ is keyed by that variable's first
-- position there; every other identifier is keyed as position 0.  The
-- sort is stable.  The rebuilt pattern has an empty context part.
fixNestingPatternOrder :: LoopNesting -> Target -> Pattern Kernels -> LoopNesting
fixNestingPatternOrder nest (_, res) inner_pat =
  nest { loopNestingPattern = basicPattern [] pat' }
  where pat = loopNestingPattern nest
        pat' = map fst fixed_target
        fixed_target = sortOn posInInnerPat $ zip (patternValueIdents pat) res
        -- Computed once, not per pattern element.
        inner_names = patternNames inner_pat
        posInInnerPat (_, Var v) = fromMaybe 0 $ elemIndex v inner_names
        posInInnerPat _          = 0

-- | A kernel nest consisting of a single loop nesting.
newKernel :: LoopNesting -> KernelNest
newKernel nest = (nest, [])

-- | All loop nestings of a kernel nest, from outermost to innermost.
kernelNestLoops :: KernelNest -> [LoopNesting]
kernelNestLoops (loop, loops) = loop : loops

-- | The names of all loop parameters anywhere in the kernel nest.
boundInKernelNest :: KernelNest -> Names
boundInKernelNest = mconcat . boundInKernelNests

-- | For each loop nesting of the kernel nest, outermost first, the
-- names of its parameters.
boundInKernelNests :: KernelNest -> [Names]
boundInKernelNests = map paramNames . kernelNestLoops
  where paramNames =
          namesFromList . map (paramName . fst) . loopNestingParamsAndArrs

-- | The widths of the loop nestings of a kernel nest, outermost first.
kernelNestWidths :: KernelNest -> [SubExp]
kernelNestWidths = map loopNestingWidth . kernelNestLoops

-- | Construct a kernel statement from a kernel nest and the body to
-- run at the innermost level.
--
-- The nest is flattened with 'flatKernel' into an index space and
-- kernel inputs.  Only the inputs that are free in the body are read
-- inside the kernel, and every result of the body is returned.
--
-- The resulting statement:
--
-- * binds the pattern of the outermost nesting;
-- * carries that nesting's certificates;
-- * has return types equal to the pattern types with one array
--   dimension stripped per index-space dimension.
--
-- The second component holds the auxiliary statements produced by
-- 'mapKernel'.
constructKernel :: (MonadFreshNames m, LocalScope Kernels m) =>
                   MkSegLevel m -> KernelNest -> Body Kernels
                -> m (Stm Kernels, Stms Kernels)
constructKernel mk_lvl kernel_nest inner_body = runBinderT' $ do
  (ispace, inps) <- flatKernel kernel_nest
  let cs = loopNestingCertificates first_nest
      ispace_scope = M.fromList $ zip (map fst ispace) $ repeat $ IndexInfo Int32
      pat = loopNestingPattern first_nest
      rts = map (stripArray (length ispace)) $ patternTypes pat

  -- Read the used kernel inputs and bind the body inside the scope of
  -- the index-space variables.
  (kres, kstms) <- runBinder $ localScope ispace_scope $ do
    mapM_ readKernelInput $ filter inputIsUsed inps
    map (Returns ResultMaySimplify) <$> bodyBind inner_body
  let inner_body' = KernelBody () kstms kres

  (segop, aux_stms) <- lift $ mapKernel mk_lvl ispace [] rts inner_body'

  addStms aux_stms

  return $ Let pat (StmAux cs ()) $ Op $ SegOp segop
  where
    first_nest = fst kernel_nest
    -- Computed once rather than once per kernel input.
    inner_body_free = freeIn inner_body
    inputIsUsed input = kernelInputName input `nameIn` inner_body_free

-- | Flatten a kernel nesting to:
--
--  (1) The index space.
--
--  (2) The kernel inputs - note that some of these may be unused.
--
-- Each loop nesting, outermost first, contributes a fresh @gtid@
-- index variable paired with its width.  Each loop parameter becomes a
-- kernel input that indexes its array with the variable of its own
-- level.
--
-- An input of an inner level may read an array that is itself a
-- parameter of an enclosing level.  Such an input is rewritten to read
-- from the array that parameter is drawn from, with the enclosing
-- level's index variable prepended to its indices.
flatKernel :: MonadFreshNames m =>
              KernelNest
           -> m ([(VName, SubExp)],
                 [KernelInput])
flatKernel (MapNesting _ _ nesting_w params_and_arrs, nests) = do
  -- The fresh name for this level is generated before recursing.
  i <- newVName "gtid"
  let own_inps = [ KernelInput pname ptype arr [Var i]
                 | (Param pname ptype, arr) <- params_and_arrs ]
  case nests of
    [] ->
      return ([(i, nesting_w)], own_inps)
    nest : nests' -> do
      (ispace, inps) <- flatKernel (nest, nests')
      let isParam inp =
            snd <$> find ((== kernelInputArray inp) . paramName . fst) params_and_arrs
          fixupInput inp
            | Just arr <- isParam inp =
                inp { kernelInputArray = arr
                    , kernelInputIndices = Var i : kernelInputIndices inp }
            | otherwise =
                inp
      return ((i, nesting_w) : ispace, own_inps <> map fixupInput inps)

-- | Description of distribution to do.
data DistributionBody = DistributionBody
  { distributionTarget :: Targets
    -- ^ The targets of the distribution.  When built by
    -- 'distributionBodyFromStms', identity mappings have been removed
    -- from the innermost target.
  , distributionFreeInBody :: Names
    -- ^ Names used by the distributed statements but not bound by them.
  , distributionIdentityMap :: M.Map VName Ident
    -- ^ Identity-mapped results: each result variable that was passed
    -- through unchanged, mapped to the 'Ident' of its pattern element.
  , distributionExpandTarget :: Target -> Target
    -- ^ Also related to avoiding identity mapping: re-appends the removed
    -- identity-mapped pattern elements and results to a target.
  }

-- | The pattern of the innermost target of a distribution body.
distributionInnerPattern :: DistributionBody -> Pattern Kernels
distributionInnerPattern body =
  let (inner_pat, _) = innerTarget $ distributionTarget body
  in inner_pat

-- | Describe the distribution of the given statements towards the given
-- targets.  Results of the innermost target that are variables not bound
-- by the statements (identity mappings) are split out of that target and
-- recorded in the identity map.  The innermost result, after that removal,
-- is returned alongside the body.
distributionBodyFromStms :: Attributes lore =>
                            Targets -> Stms lore -> (DistributionBody, Result)
distributionBodyFromStms (Targets (inner_pat, inner_res) outer_targets) stms =
  (body, inner_res')
  where bound_by_stms = namesFromList $ M.keys $ scopeOf stms

        (inner_pat', inner_res', identity_map, expand_target) =
          removeIdentityMappingGeneral bound_by_stms inner_pat inner_res

        body = DistributionBody
          { distributionTarget =
              Targets (inner_pat', inner_res') outer_targets
          , distributionFreeInBody =
              foldMap freeIn stms `namesSubtract` bound_by_stms
          , distributionIdentityMap = identity_map
          , distributionExpandTarget = expand_target
          }

-- | Like 'distributionBodyFromStms', but for a single statement.
distributionBodyFromStm :: Attributes lore =>
                           Targets -> Stm lore -> (DistributionBody, Result)
distributionBodyFromStm targets = distributionBodyFromStms targets . oneStm

-- | Try to build a kernel nest that distributes the given body out of the
-- given nestings.  Construction starts at the innermost nesting and
-- proceeds outwards.
--
-- At each level:
--
--  * The loop nesting is trimmed to the parameters the kernel uses.
--  * Two kinds of variables are added as extra map parameters: the arrays
--    returned from the level below, and the names let-bound in this level
--    that the kernel needs.
--  * Each added variable that is not an identity mapping gets a fresh
--    result array, named after it with an @_r@ suffix, which is bound in
--    the new target.
--
-- Returns 'Nothing' if the resulting nest would have a parameter whose
-- type has a dimension bound inside the nest (an irregular array).  Calls
-- 'error' if the number of nestings and targets do not match.
createKernelNest :: (MonadFreshNames m, HasScope t m) =>
                    Nestings
                 -> DistributionBody
                 -> m (Maybe (Targets, KernelNest))
createKernelNest (inner_nest, nests) distrib_body = do
  let Targets target targets = distributionTarget distrib_body
  unless (length nests == length targets) $
    error $ "Nests and targets do not match!\n" ++
    "nests: " ++ ppNestings (inner_nest, nests) ++
    "\ntargets:" ++ ppTargets (Targets target targets)
  runMaybeT $ targetsFirst <$> recurse (zip nests targets)

  where targetsFirst (kernel, _, kernel_targets) = (kernel_targets, kernel)

        -- Every name bound by some level of the nest.
        bound_in_nest = foldMap boundInNesting (inner_nest : nests)

        -- Can something of this type be taken outside the nest?  That is,
        -- are none of its dimensions bound inside the nest?
        distributableType t =
          namesIntersection bound_in_nest (freeIn (arrayDims t)) == mempty

        distributeAtNesting :: (HasScope t m, MonadFreshNames m) =>
                               Nesting
                            -> Pattern Kernels
                            -> (LoopNesting -> KernelNest, Names)
                            -> M.Map VName Ident
                            -> [Ident]
                            -> (Target -> Targets)
                            -> MaybeT m (KernelNest, Names, Targets)
        distributeAtNesting (Nesting nest_let_bound nest) pat
                            (add_to_kernel, free_in_kernel)
                            identity_map inner_returned_arrs addTarget = do
          let nest'@(MapNesting _ cs w params_and_arrs) =
                removeUnusedNestingParts free_in_kernel nest
              param_names =
                namesFromList $ map (paramName . fst) params_and_arrs
              free_in_kernel' =
                (freeIn nest' <> free_in_kernel) `namesSubtract` param_names
              -- Names the kernel needs that are let-bound in this level.
              required_from_nest =
                free_in_kernel' `namesIntersection` nest_let_bound

          required_from_nest_idents <-
            forM (namesToList required_from_nest) $ \name ->
              Ident name <$> lift (lookupType name)

          -- Each needed variable becomes an extra map parameter.  Unless it
          -- is identity-mapped, it also gets a fresh array that must be bound
          -- in the target; the Bool records which case applies.
          (free_params, free_arrs, bind_in_target) <-
            fmap unzip3 $
            forM (inner_returned_arrs ++ required_from_nest_idents) $
            \(Ident pname ptype) -> do
              let param = Param pname ptype
              case M.lookup pname identity_map of
                Just arr -> return (param, arr, False)
                Nothing  -> do
                  arr <- newIdent (baseString pname ++ "_r") $
                         arrayOfRow ptype w
                  return (param, arr, True)

          let (free_params_pat, bound_arrs) =
                unzip [ (param, arr)
                      | (param, arr, True) <-
                          zip3 free_params free_arrs bind_in_target ]
              free_arrs_pat = basicPattern [] bound_arrs

              nest'' =
                removeUnusedNestingParts free_in_kernel $
                MapNesting pat cs w $
                params_and_arrs ++ zip free_params (map identName free_arrs)

              actual_param_names =
                namesFromList $ map paramName $
                map fst params_and_arrs ++ free_params

              free_in_kernel'' =
                (freeIn nest'' <> free_in_kernel)
                `namesSubtract` actual_param_names

          unless (all (distributableType . paramType) $
                  loopNestingParams nest'') $
            fail "Would induce irregular array"

          return (add_to_kernel nest'',
                  free_in_kernel'',
                  addTarget (free_arrs_pat,
                             map (Var . paramName) free_params_pat))

        recurse :: (HasScope t m, MonadFreshNames m) =>
                   [(Nesting,Target)]
                -> MaybeT m (KernelNest, Names, Targets)
        recurse [] =
          distributeAtNesting inner_nest
            (distributionInnerPattern distrib_body)
            (newKernel,
             distributionFreeInBody distrib_body `namesIntersection` bound_in_nest)
            (distributionIdentityMap distrib_body)
            []
            (singleTarget . distributionExpandTarget distrib_body)

        recurse ((nest, (pat, res)) : outer_nests) = do
          (kernel@(outer, _), kernel_free, kernel_targets) <- recurse outer_nests

          let (pat', res', identity_map, expand_target) =
                removeIdentityMappingFromNesting
                  (namesFromList $ patternNames $ loopNestingPattern outer)
                  pat res

          distributeAtNesting nest pat'
            (\k -> pushKernelNesting (pat', res') k kernel, kernel_free)
            identity_map
            (patternIdents $ fst $ outerTarget kernel_targets)
            ((`pushOuterTarget` kernel_targets) . expand_target)

-- | Drop every parameter (together with the array it maps over) whose name
-- is not in the given set of used names.
removeUnusedNestingParts :: Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts used (MapNesting pat cs w params_and_arrs) =
  MapNesting pat cs w [ (param, arr)
                      | (param, arr) <- params_and_arrs
                      , paramName param `nameIn` used ]

-- | Split a pattern/result pair into identity-mapped and remaining parts.
-- A pattern element is an identity mapping when its result is a variable
-- that is not in @bound@.  Returns:
--
--  * a pattern (with empty context) of the remaining pattern elements;
--  * the results corresponding to those elements;
--  * a map from each identity-mapped result variable to the 'Ident' of its
--    pattern element;
--  * a function that appends the identity-mapped pattern elements and
--    their result variables to a target.  The target's pattern is rebuilt
--    with an empty context.
removeIdentityMappingGeneral :: Names -> Pattern Kernels -> Result
                             -> (Pattern Kernels,
                                 Result,
                                 M.Map VName Ident,
                                 Target -> Target)
removeIdentityMappingGeneral bound pat res =
  (Pattern [] kept_pes, kept_res, identity_map, expandTarget)
  where pes_and_res = zip (patternElements pat) res

        -- The passed-through variable, if this result is an identity mapping.
        identityVar (Var v) | not (v `nameIn` bound) = Just v
        identityVar _ = Nothing

        (identity_pes, identity_vs) =
          unzip [ (pe, v) | (pe, se) <- pes_and_res, Just v <- [identityVar se] ]
        (kept_pes, kept_res) =
          unzip [ (pe, se) | (pe, se) <- pes_and_res, isNothing (identityVar se) ]

        identity_map =
          M.fromList $ zip identity_vs $ map patElemIdent identity_pes

        expandTarget (tpat, tres) =
          (Pattern [] $ patternElements tpat ++ identity_pes,
           tres ++ map Var identity_vs)

-- | Remove identity mappings from the pattern and result of a nesting.
-- The names in the first argument are treated as bound inside the nesting,
-- so results referring to them are not identity mappings.  This behaves
-- exactly like 'removeIdentityMappingGeneral'.
removeIdentityMappingFromNesting :: Names -> Pattern Kernels -> Result
                                 -> (Pattern Kernels,
                                     Result,
                                     M.Map VName Ident,
                                     Target -> Target)
removeIdentityMappingFromNesting = removeIdentityMappingGeneral

-- | Try to distribute the statements @stms@ into a kernel.
--
-- The kernel nesting comes from @nest@ and @targets@.
--
-- * An empty statement sequence is trivially "distributed": the targets
--   are returned unchanged together with no new statements.
-- * Otherwise a kernel nest is requested via 'createKernelNest'.  If none
--   can be formed, the result is 'Nothing'.
-- * If a nest is formed, a kernel statement is constructed in the scope of
--   the new targets and renamed.  A description of the distribution is
--   logged.  The result is the new targets, followed by the auxiliary
--   statements from 'constructKernel' and then the kernel statement.
tryDistribute :: (MonadFreshNames m, LocalScope Kernels m, MonadLogger m) =>
                 MkSegLevel m -> Nestings -> Targets -> Stms Kernels
              -> m (Maybe (Targets, Stms Kernels))
tryDistribute mk_lvl nest targets stms
  -- No point in distributing an empty kernel.
  | null stms = pure $ Just (targets, mempty)
  | otherwise = do
      maybe_nest <- createKernelNest nest dist_body
      case maybe_nest of
        Nothing -> pure Nothing
        Just (new_targets, kernel_nest) -> do
          (kernel_stm, aux_stms) <-
            localScope (targetsScope new_targets) $
              constructKernel mk_lvl kernel_nest $
                mkBody stms inner_body_res
          -- Give the kernel statement fresh names before it is used.
          kernel_stm' <- renameStm kernel_stm
          logMsg $ concat
            [ "distributing\n"
            , unlines $ map pretty $ stmsToList stms
            , pretty $ snd $ innerTarget targets
            , "\nas\n"
            , pretty kernel_stm'
            , "\ndue to targets\n"
            , ppTargets targets
            , "\nand with new targets\n"
            , ppTargets new_targets
            ]
          pure $ Just (new_targets, aux_stms <> oneStm kernel_stm')
  where (dist_body, inner_body_res) = distributionBodyFromStms targets stms

-- | Try to build a kernel nest for distributing the single statement @bnd@.
--
-- On success, the result holds three things:
--
--   * the 'Result' computed by 'distributionBodyFromStm' for @bnd@;
--   * the new targets;
--   * the kernel nest produced by 'createKernelNest'.
--
-- If no kernel nest can be created, the result is 'Nothing'.
tryDistributeStm :: (MonadFreshNames m, HasScope t m, Attributes lore) =>
                    Nestings -> Targets -> Stm lore
                 -> m (Maybe (Result, Targets, KernelNest))
tryDistributeStm nest targets bnd = do
  outcome <- createKernelNest nest dist_body
  pure $ withResult <$> outcome
  where (dist_body, res) = distributionBodyFromStm targets bnd
        -- Attach the statement's result to the created targets and nest.
        withResult (targets', kernel_nest) = (res, targets', kernel_nest)