{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.Distribution
  ( Target,
    Targets,
    ppTargets,
    singleTarget,
    outerTarget,
    innerTarget,
    pushInnerTarget,
    popInnerTarget,
    targetsScope,
    LoopNesting (..),
    ppLoopNesting,
    scopeOfLoopNesting,
    Nesting (..),
    Nestings,
    ppNestings,
    letBindInInnerNesting,
    singleNesting,
    pushInnerNesting,
    KernelNest,
    ppKernelNest,
    newKernel,
    innermostKernelNesting,
    pushKernelNesting,
    pushInnerKernelNesting,
    scopeOfKernelNest,
    kernelNestLoops,
    kernelNestWidths,
    boundInKernelNest,
    boundInKernelNests,
    flatKernel,
    constructKernel,
    tryDistribute,
    tryDistributeStm,
  )
where

import Control.Monad.RWS.Strict
import Control.Monad.Trans.Maybe
import Data.Bifunctor (second)
import Data.Foldable
import Data.List (elemIndex, sortOn)
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
  ( DistRep,
    KernelInput (..),
    MkSegLevel,
    mapKernel,
    readKernelInput,
  )
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util
import Futhark.Util.Log

-- | A pattern together with a result.  Functions such as
-- 'pushInnerTarget' and 'fixNestingPatOrder' pair the pattern elements
-- with the result entries positionally.
type Target = (Pat Type, Result)

-- | A stack of targets.  The first field is the very innermost
-- ("current") target.  In the list, the outermost target comes first.
-- Invariant: Every element of a pattern must be present as the result
-- of the immediately enclosing target.  This is ensured by
-- 'pushInnerTarget' by removing unused pattern elements.
data Targets = Targets
  { _innerTarget :: Target,
    _outerTargets :: [Target]
  }

-- | Render every target on its own line as @pat <- res@.  The
-- outermost target comes first and the innermost (current) target last.
ppTargets :: Targets -> String
ppTargets (Targets target targets) =
  unlines $ map ppTarget $ targets ++ [target]
  where
    ppTarget (pat, res) = prettyString pat ++ " <- " ++ prettyString res

-- | A target stack with the given target as its innermost (and only)
-- element.
singleTarget :: Target -> Targets
singleTarget target = Targets target []

-- | The outermost target.  When there are no outer targets, this is the
-- innermost target itself.
outerTarget :: Targets -> Target
outerTarget (Targets inner_target []) = inner_target
outerTarget (Targets _ (outer_target : _)) = outer_target

-- | The innermost ("current") target.
innerTarget :: Targets -> Target
innerTarget (Targets inner_target _) = inner_target

-- | Add a new outermost target.  It goes to the front of the outer
-- list; the innermost target is unchanged.
pushOuterTarget :: Target -> Targets -> Targets
pushOuterTarget target (Targets inner_target targets) =
  Targets inner_target (target : targets)

-- | Push a new innermost target.  The previous innermost target is
-- appended to the end of the outer list.
--
-- Pattern elements of the new target whose names are not free in the
-- result of the previous innermost target are dropped, together with
-- their positionally paired result entries.  This maintains the
-- 'Targets' invariant.
pushInnerTarget :: Target -> Targets -> Targets
pushInnerTarget (pat, res) (Targets inner_target targets) =
  Targets (Pat kept_pes, kept_res) (targets ++ [inner_target])
  where
    inner_used = freeIn $ snd inner_target
    used pe = patElemName pe `nameIn` inner_used
    (kept_pes, kept_res) =
      unzip $ filter (used . fst) $ zip (patElems pat) res

-- | Split off the innermost target.  The innermost of the outer targets
-- (the last list element) becomes the new current target.  Returns
-- 'Nothing' when there is no outer target left to promote.
popInnerTarget :: Targets -> Maybe (Target, Targets)
popInnerTarget (Targets t ts) =
  case reverse ts of
    x : xs -> Just (t, Targets x $ reverse xs)
    [] -> Nothing

-- | The scope bound by the pattern of a target.
targetScope :: DistRep rep => Target -> Scope rep
targetScope (pat, _) = scopeOfPat pat

-- | The combined scope bound by the patterns of all targets, innermost
-- target first, followed by the outer targets in list order.
targetsScope :: DistRep rep => Targets -> Scope rep
targetsScope (Targets t ts) = foldMap targetScope (t : ts)

-- | One level of map nesting.  It holds the pattern bound by the map,
-- its statement attributes, the map width, and each parameter paired
-- with the array it is read from (see 'flatKernel').
data LoopNesting = MapNesting
  { loopNestingPat :: Pat Type,
    loopNestingAux :: StmAux (),
    loopNestingWidth :: SubExp,
    loopNestingParamsAndArrs :: [(Param Type, VName)]
  }
  deriving (Show)

-- | The scope bound by the parameters of a loop nesting.
scopeOfLoopNesting :: (LParamInfo rep ~ Type) => LoopNesting -> Scope rep
scopeOfLoopNesting = scopeOfLParams . loopNestingParams

-- | Render a loop nesting as @params <- arrays@.
ppLoopNesting :: LoopNesting -> String
ppLoopNesting nest =
  prettyString params ++ " <- " ++ prettyString arrs
  where
    (params, arrs) = unzip $ loopNestingParamsAndArrs nest

-- | The parameters of a loop nesting, without their arrays.
loopNestingParams :: LoopNesting -> [Param Type]
loopNestingParams = map fst . loopNestingParamsAndArrs

-- | Free variables of every part of the nesting: the pattern, the
-- statement attributes, the width, and the parameter/array pairs.
instance FreeIn LoopNesting where
  freeIn' (MapNesting pat aux w params_and_arrs) =
    freeIn' pat <> freeIn' aux <> freeIn' w <> freeIn' params_and_arrs

-- | A loop nesting together with the names let-bound inside it.
data Nesting = Nesting
  { nestingLetBound :: Names,
    nestingLoop :: LoopNesting
  }
  deriving (Show)

-- | Record additional names as let-bound inside the nesting.
letBindInNesting :: Names -> Nesting -> Nesting
letBindInNesting newnames nesting =
  nesting {nestingLetBound = nestingLetBound nesting <> newnames}
-- | First pair element is the very innermost ("current") nest.  In
-- the list, the outermost nest comes first.
type Nestings = (Nesting, [Nesting])

-- | Render each nesting on its own line.  The outermost nesting comes
-- first and the innermost (current) nesting last.  Let-bound names are
-- not shown.
ppNestings :: Nestings -> String
ppNestings (nesting, nestings) =
  unlines $ map (ppLoopNesting . nestingLoop) $ nestings ++ [nesting]

-- | Nestings consisting of just the given (innermost) nesting.
singleNesting :: Nesting -> Nestings
singleNesting nesting = (nesting, [])

-- | Make the given nesting the new innermost one.  The previous
-- innermost nesting is appended to the end of the list, which is
-- ordered outermost first.
pushInnerNesting :: Nesting -> Nestings -> Nestings
pushInnerNesting nesting (inner_nesting, nestings) =
  (nesting, nestings ++ [inner_nesting])

-- | Names bound by a nesting: both the loop parameters and the
-- let-bound names.
boundInNesting :: Nesting -> Names
boundInNesting (Nesting let_bound loop) =
  namesFromList (map paramName (loopNestingParams loop)) <> let_bound

-- | Record additional let-bound names in the innermost nesting only.
letBindInInnerNesting :: Names -> Nestings -> Nestings
letBindInInnerNesting names (nest, nestings) =
  (letBindInNesting names nest, nestings)

-- | Note: first element is *outermost* nesting.  This is different
-- from the similar types elsewhere!  The list holds the remaining
-- nestings ordered from outer to inner: 'pushInnerKernelNesting'
-- appends to its end, and 'innermostKernelNesting' reads its last
-- element.
type KernelNest = (LoopNesting, [LoopNesting])

-- | Render each loop nesting of the kernel nest on its own line,
-- outermost first.
ppKernelNest :: KernelNest -> String
ppKernelNest (nesting, nestings) =
  unlines $ map ppLoopNesting $ nesting : nestings

-- | Retrieve the innermost kernel nesting: the last element of the
-- list, or the outermost nesting if the list is empty.
innermostKernelNesting :: KernelNest -> LoopNesting
innermostKernelNesting (nest, nests) =
  fromMaybe nest $ maybeHead $ reverse nests

-- | Add new outermost nesting, pushing the current outermost to the
-- list, also taking care to swap patterns if necessary.  The new
-- nesting's pattern is reordered with 'fixNestingPatOrder' to follow
-- the pattern of the current outermost nesting.
pushKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushKernelNesting target newnest (nest, nests) =
  (newnest', nest : nests)
  where
    newnest' = fixNestingPatOrder newnest target $ loopNestingPat nest

-- | Add new innermost nesting by appending it to the end of the list.
-- The outermost nesting stays in place.  It is important that the
-- 'Target' has the right order (non-permuted compared to what is
-- expected by the outer nests).  The new nesting's pattern is
-- reordered with 'fixNestingPatOrder' against the pattern of the
-- current innermost nesting.
pushInnerKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting target newnest kernel_nest@(nest, nests) =
  (nest, nests ++ [newnest'])
  where
    newnest' =
      fixNestingPatOrder newnest target $
        loopNestingPat (innermostKernelNesting kernel_nest)

-- | Reorder the pattern of @nest@ so that its elements follow the order
-- in which their paired variables appear in @inner_pat@.
--
-- Pattern elements are paired positionally with the target's result.
-- A result entry that is not a variable, or whose variable is absent
-- from @inner_pat@, sorts at position 0.  'sortOn' is stable, so ties
-- keep their original relative order.  Because the pairing uses 'zip',
-- pattern elements beyond the length of the result are dropped.
fixNestingPatOrder :: LoopNesting -> Target -> Pat Type -> LoopNesting
fixNestingPatOrder nest (_, res) inner_pat =
  nest {loopNestingPat = basicPat reordered}
  where
    inner_names = patNames inner_pat
    reordered =
      map fst $
        sortOn (positionOf . snd) $
          zip (patIdents (loopNestingPat nest)) res
    positionOf (SubExpRes _ (Var v)) = fromMaybe 0 $ elemIndex v inner_names
    positionOf _ = 0

-- | A kernel nest consisting of a single (outermost) loop nesting.
newKernel :: LoopNesting -> KernelNest
newKernel nest = (nest, [])

-- | All loop nestings of a kernel nest, outermost first.
kernelNestLoops :: KernelNest -> [LoopNesting]
kernelNestLoops (loop, loops) = loop : loops

-- | The combined scope bound by the parameters of every level of the
-- kernel nest.
scopeOfKernelNest :: LParamInfo rep ~ Type => KernelNest -> Scope rep
scopeOfKernelNest = foldMap scopeOfLoopNesting . kernelNestLoops

-- | All parameter names bound anywhere in the kernel nest.
boundInKernelNest :: KernelNest -> Names
boundInKernelNest = mconcat . boundInKernelNests

-- | The parameter names bound at each level of the kernel nest,
-- outermost first.
boundInKernelNests :: KernelNest -> [Names]
boundInKernelNests =
  map (namesFromList . map paramName . loopNestingParams) . kernelNestLoops

-- | The width of each level of the kernel nest, outermost first.
kernelNestWidths :: KernelNest -> [SubExp]
kernelNestWidths = map loopNestingWidth . kernelNestLoops

-- | Construct a segmented-map kernel from a kernel nest and the body
-- that runs at its innermost level.
--
-- Returns the kernel statement, which binds the pattern (and reuses the
-- attributes) of the outermost nesting.  It also returns the auxiliary
-- statements produced by 'mapKernel'; these are added to the builder
-- before the kernel statement is returned.
constructKernel ::
  (DistRep rep, MonadFreshNames m, LocalScope rep m) =>
  MkSegLevel rep m ->
  KernelNest ->
  Body rep ->
  m (Stm rep, Stms rep)
constructKernel mk_lvl kernel_nest inner_body = runBuilderT' $ do
  (ispace, inps) <- flatKernel kernel_nest
  let first_nest = fst kernel_nest
      aux = loopNestingAux first_nest
      pat = loopNestingPat first_nest
      -- Each index-space variable is an Int64 index inside the kernel.
      ispace_scope =
        M.fromList [(gtid, IndexName Int64) | (gtid, _) <- ispace]
      -- Strip one array dimension per index-space dimension from the
      -- outer pattern types to get the per-thread result types.
      rts = map (stripArray (length ispace)) $ patTypes pat
      -- Hoisted so the body's free variables are computed only once,
      -- not once per kernel input.
      free_in_body = freeIn inner_body
      inputIsUsed inp = kernelInputName inp `nameIn` free_in_body

  (kernel_res, kernel_stms) <-
    runBuilder . localScope ispace_scope $ do
      -- Only read the inputs that the body actually refers to.
      mapM_ readKernelInput $ filter inputIsUsed inps
      res <- bodyBind inner_body
      pure [Returns ResultMaySimplify cs se | SubExpRes cs se <- res]
  let inner_body' = KernelBody () kernel_stms kernel_res

  (segop, aux_stms) <- lift $ mapKernel mk_lvl ispace [] rts inner_body'
  addStms aux_stms

  pure $ Let pat aux $ Op $ segOp segop

-- | Flatten a kernel nesting to:
--
--  (1) The index space.
--
--  (2) The kernel inputs - note that some of these may be unused.
--
-- A fresh @gtid@ index is created for every nesting level, outermost
-- first.  Each level's parameters become inputs indexed by that level's
-- @gtid@.  Inner-level inputs that read from one of an outer level's
-- parameters are rewritten to read that parameter's array instead, with
-- the outer index prepended.
flatKernel ::
  MonadFreshNames m =>
  KernelNest ->
  m ([(VName, SubExp)], [KernelInput])
flatKernel (MapNesting _ _ nesting_w params_and_arrs, nests) = do
  i <- newVName "gtid"
  let own_inps =
        [ KernelInput pname ptype arr [Var i]
        | (Param _ pname ptype, arr) <- params_and_arrs
        ]
  case nests of
    [] ->
      pure ([(i, nesting_w)], own_inps)
    nest : nests' -> do
      (ispace, inps) <- flatKernel (nest, nests')
      pure ((i, nesting_w) : ispace, own_inps <> map (fixupInput i) inps)
  where
    -- The array behind this level's parameter, if the input reads from
    -- one of this level's parameters.
    outerArray inp =
      snd <$> find ((== kernelInputArray inp) . paramName . fst) params_and_arrs
    fixupInput i inp
      | Just arr <- outerArray inp =
          inp
            { kernelInputArray = arr,
              kernelInputIndices = Var i : kernelInputIndices inp
            }
      | otherwise = inp

-- | Description of distribution to do.
data DistributionBody = DistributionBody
  { distributionTarget :: Targets,
    -- | Names free in the distributed statements, excluding the names
    -- those statements bind themselves.
    distributionFreeInBody :: Names,
    distributionIdentityMap :: M.Map VName Ident,
    -- | Also related to avoiding identity mapping.
    distributionExpandTarget :: Target -> Target
  }

-- | The pattern of the innermost target of a distribution.
distributionInnerPat :: DistributionBody -> Pat Type
distributionInnerPat = fst . innerTarget . distributionTarget

-- | Describe the distribution of a sequence of statements.
--
-- The innermost target is passed through
-- @removeIdentityMappingGeneral@, using the names bound by @stms@; the
-- outer targets are kept as they are.  Also returns the rewritten
-- innermost result.
distributionBodyFromStms ::
  ASTRep rep =>
  Targets ->
  Stms rep ->
  (DistributionBody, Result)
distributionBodyFromStms (Targets (inner_pat, inner_res) targets) stms =
  (distrib_body, inner_res')
  where
    bound_by_stms = namesFromList $ M.keys $ scopeOf stms
    (inner_pat', inner_res', inner_identity_map, inner_expand_target) =
      removeIdentityMappingGeneral bound_by_stms inner_pat inner_res
    distrib_body =
      DistributionBody
        { distributionTarget = Targets (inner_pat', inner_res') targets,
          distributionFreeInBody =
            foldMap freeIn stms `namesSubtract` bound_by_stms,
          distributionIdentityMap = inner_identity_map,
          distributionExpandTarget = inner_expand_target
        }

-- | 'distributionBodyFromStms' specialised to a single statement.
distributionBodyFromStm ::
  ASTRep rep =>
  Targets ->
  Stm rep ->
  (DistributionBody, Result)
distributionBodyFromStm targets =
  distributionBodyFromStms targets . oneStm

createKernelNest ::
  forall rep m.
  (MonadFreshNames m, HasScope rep m) =>
  Nestings ->
  DistributionBody ->
  m (Maybe (Targets, KernelNest))
createKernelNest :: forall {k} (rep :: k) (m :: * -> *).
(MonadFreshNames m, HasScope rep m) =>
Nestings -> DistributionBody -> m (Maybe (Targets, KernelNest))
createKernelNest (Nesting
inner_nest, [Nesting]
nests) DistributionBody
distrib_body = do
  let Targets Target
target [Target]
targets = DistributionBody -> Targets
distributionTarget DistributionBody
distrib_body
  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (forall (t :: * -> *) a. Foldable t => t a -> Int
length [Nesting]
nests forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [Target]
targets) forall a b. (a -> b) -> a -> b
$
    forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
      [Char]
"Nests and targets do not match!\n"
        forall a. [a] -> [a] -> [a]
++ [Char]
"nests: "
        forall a. [a] -> [a] -> [a]
++ Nestings -> [Char]
ppNestings (Nesting
inner_nest, [Nesting]
nests)
        forall a. [a] -> [a] -> [a]
++ [Char]
"\ntargets:"
        forall a. [a] -> [a] -> [a]
++ Targets -> [Char]
ppTargets (Target -> [Target] -> Targets
Targets Target
target [Target]
targets)
  forall (m :: * -> *) a. MaybeT m a -> m (Maybe a)
runMaybeT forall a b. (a -> b) -> a -> b
$ forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall {b} {b} {a}. (b, b, a) -> (a, b)
prepare forall a b. (a -> b) -> a -> b
$ [(Nesting, Target)] -> MaybeT m (KernelNest, Names, Targets)
recurse forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [Nesting]
nests [Target]
targets
  where
    prepare :: (b, b, a) -> (a, b)
prepare (b
x, b
_, a
z) = (a
z, b
x)
    bound_in_nest :: Names
bound_in_nest = forall a. Monoid a => [a] -> a
mconcat forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map Nesting -> Names
boundInNesting forall a b. (a -> b) -> a -> b
$ Nesting
inner_nest forall a. a -> [a] -> [a]
: [Nesting]
nests
    distributableType :: Type -> Bool
distributableType =
      (forall a. Eq a => a -> a -> Bool
== forall a. Monoid a => a
mempty) forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> Names -> Names
namesIntersection Names
bound_in_nest forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. FreeIn a => a -> Names
freeIn forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall u. TypeBase Shape u -> [SubExp]
arrayDims

    distributeAtNesting ::
      Nesting ->
      Pat Type ->
      (LoopNesting -> KernelNest, Names) ->
      M.Map VName Ident ->
      [Ident] ->
      (Target -> Targets) ->
      MaybeT m (KernelNest, Names, Targets)
    distributeAtNesting :: Nesting
-> Pat Type
-> (LoopNesting -> KernelNest, Names)
-> Map VName Ident
-> [Ident]
-> (Target -> Targets)
-> MaybeT m (KernelNest, Names, Targets)
distributeAtNesting
      (Nesting Names
nest_let_bound LoopNesting
nest)
      Pat Type
pat
      (LoopNesting -> KernelNest
add_to_kernel, Names
free_in_kernel)
      Map VName Ident
identity_map
      [Ident]
inner_returned_arrs
      Target -> Targets
addTarget = do
        let nest' :: LoopNesting
nest'@(MapNesting Pat Type
_ StmAux ()
aux SubExp
w [(Param Type, VName)]
params_and_arrs) =
              Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts Names
free_in_kernel LoopNesting
nest
            ([Param Type]
params, [VName]
arrs) = forall a b. [(a, b)] -> ([a], [b])
unzip [(Param Type, VName)]
params_and_arrs
            param_names :: Names
param_names = [VName] -> Names
namesFromList forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
params
            free_in_kernel' :: Names
free_in_kernel' =
              (forall a. FreeIn a => a -> Names
freeIn LoopNesting
nest' forall a. Semigroup a => a -> a -> a
<> Names
free_in_kernel) Names -> Names -> Names
`namesSubtract` Names
param_names
            required_from_nest :: Names
required_from_nest =
              Names
free_in_kernel' Names -> Names -> Names
`namesIntersection` Names
nest_let_bound

        [Ident]
required_from_nest_idents <-
          forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM (Names -> [VName]
namesToList Names
required_from_nest) forall a b. (a -> b) -> a -> b
$ \VName
name -> do
            Type
t <- forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m Type
lookupType VName
name
            forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ VName -> Type -> Ident
Ident VName
name Type
t

        ([Param Type]
free_params, [Ident]
free_arrs, [Bool]
bind_in_target) <-
          forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall a b. (a -> b) -> a -> b
$
            forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Ident]
inner_returned_arrs forall a. [a] -> [a] -> [a]
++ [Ident]
required_from_nest_idents) forall a b. (a -> b) -> a -> b
$
              \(Ident VName
pname Type
ptype) ->
                case forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
pname Map VName Ident
identity_map of
                  Maybe Ident
Nothing -> do
                    Ident
arr <-
                      forall (m :: * -> *).
MonadFreshNames m =>
[Char] -> Type -> m Ident
newIdent (VName -> [Char]
baseString VName
pname forall a. [a] -> [a] -> [a]
++ [Char]
"_r") forall a b. (a -> b) -> a -> b
$ forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow Type
ptype SubExp
w
                    forall (f :: * -> *) a. Applicative f => a -> f a
pure
                      ( forall dec. Attrs -> VName -> dec -> Param dec
Param forall a. Monoid a => a
mempty VName
pname Type
ptype,
                        Ident
arr,
                        Bool
True
                      )
                  Just Ident
arr ->
                    forall (f :: * -> *) a. Applicative f => a -> f a
pure
                      ( forall dec. Attrs -> VName -> dec -> Param dec
Param forall a. Monoid a => a
mempty VName
pname Type
ptype,
                        Ident
arr,
                        Bool
False
                      )

        let free_arrs_pat :: Pat Type
free_arrs_pat =
              [Ident] -> Pat Type
basicPat forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd forall a b. (a -> b) -> a -> b
$ forall a. (a -> Bool) -> [a] -> [a]
filter forall a b. (a, b) -> a
fst forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [Bool]
bind_in_target [Ident]
free_arrs
            free_params_pat :: [Param Type]
free_params_pat =
              forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd forall a b. (a -> b) -> a -> b
$ forall a. (a -> Bool) -> [a] -> [a]
filter forall a b. (a, b) -> a
fst forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [Bool]
bind_in_target [Param Type]
free_params

            ([Param Type]
actual_params, [VName]
actual_arrs) =
              ( [Param Type]
params forall a. [a] -> [a] -> [a]
++ [Param Type]
free_params,
                [VName]
arrs forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
free_arrs
              )
            actual_param_names :: Names
actual_param_names =
              [VName] -> Names
namesFromList forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
actual_params

            nest'' :: LoopNesting
nest'' =
              Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts Names
free_in_kernel forall a b. (a -> b) -> a -> b
$
                Pat Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pat Type
pat StmAux ()
aux SubExp
w forall a b. (a -> b) -> a -> b
$
                  forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
actual_params [VName]
actual_arrs

            free_in_kernel'' :: Names
free_in_kernel'' =
              (forall a. FreeIn a => a -> Names
freeIn LoopNesting
nest'' forall a. Semigroup a => a -> a -> a
<> Names
free_in_kernel) Names -> Names -> Names
`namesSubtract` Names
actual_param_names

        forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless
          ( forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Type -> Bool
distributableType forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Typed dec => Param dec -> Type
paramType) forall a b. (a -> b) -> a -> b
$
              LoopNesting -> [Param Type]
loopNestingParams LoopNesting
nest''
          )
          forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) a. MonadFail m => [Char] -> m a
fail [Char]
"Would induce irregular array"
        forall (f :: * -> *) a. Applicative f => a -> f a
pure
          ( LoopNesting -> KernelNest
add_to_kernel LoopNesting
nest'',
            Names
free_in_kernel'',
            Target -> Targets
addTarget (Pat Type
free_arrs_pat, [VName] -> Result
varsRes forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
free_params_pat)
          )

    -- Distribute the body over a list of (nesting, target) pairs.  The
    -- tail of the list is handled first (recursively); the nesting at the
    -- head is then distributed around the kernel nest obtained for the
    -- tail, by way of 'pushKernelNesting'.  The empty list bottoms out at
    -- 'inner_nest'.  The computation fails in 'MaybeT' as soon as
    -- 'distributeAtNesting' fails at any level.
    recurse :: [(Nesting, Target)] -> MaybeT m (KernelNest, Names, Targets)
    recurse [] =
      -- Start a fresh kernel at 'inner_nest'.  The names it reports as
      -- free are those free in the distribution body that are also in
      -- 'bound_in_nest'.
      distributeAtNesting
        inner_nest
        (distributionInnerPat distrib_body)
        (newKernel, body_free_in_nest)
        (distributionIdentityMap distrib_body)
        []
        (singleTarget . distributionExpandTarget distrib_body)
      where
        body_free_in_nest =
          distributionFreeInBody distrib_body `namesIntersection` bound_in_nest
    recurse ((nest, (pat, res)) : rest) = do
      (kernel@(outer, _), kernel_free, kernel_targets) <- recurse rest

      -- Results of this nesting that are plain variables not bound by the
      -- pattern of 'outer' are identity mappings: strip them here, and let
      -- 'expand_target' re-attach them when the new target is pushed.
      let bound_by_outer = namesFromList $ patNames $ loopNestingPat outer
          (pat', res', identity_map, expand_target) =
            removeIdentityMappingFromNesting bound_by_outer pat res
          wrap_kernel k = pushKernelNesting (pat', res') k kernel
          kernel_returned = patIdents $ fst $ outerTarget kernel_targets
          add_target = (`pushOuterTarget` kernel_targets) . expand_target

      distributeAtNesting
        nest
        pat'
        (wrap_kernel, kernel_free)
        identity_map
        kernel_returned
        add_target

-- | Drop those (parameter, array) pairs of a map nesting whose parameter
-- name is not in @used@.  The pattern, auxiliary information and width
-- of the nesting are left untouched, and the surviving pairs keep their
-- original relative order.
removeUnusedNestingParts :: Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts used (MapNesting pat aux w params_and_arrs) =
  MapNesting pat aux w kept
  where
    kept =
      [ (param, arr)
        | (param, arr) <- params_and_arrs,
          paramName param `nameIn` used
      ]

-- | Split a pattern (zipped with its result) into identity mappings and
-- everything else.  A result counts as an identity mapping when it is a
-- plain variable that is not a member of @bound@.
--
-- Returns, in order:
--
-- * the pattern with the identity-mapped elements removed;
--
-- * the results belonging to the remaining pattern elements;
--
-- * a map from each identity-mapped source variable to the 'Ident' of
--   the pattern element it was bound to;
--
-- * a function that appends the removed pattern elements to a target's
--   pattern, and results referring to their source variables (carrying
--   the original certificates) to the target's result.
removeIdentityMappingGeneral ::
  Names ->
  Pat Type ->
  Result ->
  ( Pat Type,
    Result,
    M.Map VName Ident,
    Target -> Target
  )
removeIdentityMappingGeneral bound pat res =
  (Pat kept_pes, kept_res, identity_map, expandTarget)
  where
    (identities, kept) = mapEither classify $ zip (patElems pat) res
    (kept_pes, kept_res) = unzip kept
    (identity_pes, identity_srcs) = unzip identities

    -- Left: an identity mapping (with its certificates and source var).
    -- Right: anything else, kept as-is.
    classify ::
      (PatElem Type, SubExpRes) ->
      Either (PatElem Type, (Certs, VName)) (PatElem Type, SubExpRes)
    classify (pe, SubExpRes cs (Var v))
      | v `notNameIn` bound = Left (pe, (cs, v))
    classify other = Right other

    identity_map :: M.Map VName Ident
    identity_map =
      M.fromList $
        zipWith (\(_, v) pe -> (v, patElemIdent pe)) identity_srcs identity_pes

    expandTarget :: Target -> Target
    expandTarget (tpat, tres) =
      ( Pat $ patElems tpat ++ identity_pes,
        tres ++ [SubExpRes cs (Var v) | (cs, v) <- identity_srcs]
      )

-- | Remove identity mappings from the pattern and result of a nesting.
-- A result is treated as an identity mapping when it is a variable not
-- in @bound_in_nesting@.  This delegates directly to
-- 'removeIdentityMappingGeneral' and returns its result unchanged.
removeIdentityMappingFromNesting ::
  Names ->
  Pat Type ->
  Result ->
  ( Pat Type,
    Result,
    M.Map VName Ident,
    Target -> Target
  )
removeIdentityMappingFromNesting bound_in_nesting pat res =
  removeIdentityMappingGeneral bound_in_nesting pat res

-- | Attempt to distribute the statements @stms@ over the given nestings.
--
-- An empty @stms@ is never distributed: the targets come back unchanged
-- with no statements.  Otherwise, @Nothing@ is returned when
-- 'createKernelNest' cannot build a kernel nest.  On success, the result
-- is the new targets together with the statements produced by
-- 'constructKernel' followed by the (renamed) kernel statement.  The
-- distribution is also logged.
tryDistribute ::
  ( DistRep rep,
    MonadFreshNames m,
    LocalScope rep m,
    MonadLogger m
  ) =>
  MkSegLevel rep m ->
  Nestings ->
  Targets ->
  Stms rep ->
  m (Maybe (Targets, Stms rep))
tryDistribute mk_lvl nest targets stms
  | null stms =
      -- No point in distributing an empty kernel.
      pure $ Just (targets, mempty)
  | otherwise =
      createKernelNest nest dist_body >>= traverse distribute
  where
    (dist_body, inner_body_res) = distributionBodyFromStms targets stms

    -- Turn a successfully created kernel nest into kernel statements,
    -- with the new targets in scope while the kernel is constructed.
    distribute (targets', distributed) = do
      (kernel_stm, w_stms) <-
        localScope (targetsScope targets') . constructKernel mk_lvl distributed $
          mkBody stms inner_body_res
      kernel_stm' <- renameStm kernel_stm
      logMsg $
        concat
          [ "distributing\n",
            unlines (map prettyString (stmsToList stms)),
            prettyString (snd (innerTarget targets)),
            "\nas\n",
            prettyString kernel_stm',
            "\ndue to targets\n",
            ppTargets targets,
            "\nand with new targets\n",
            ppTargets targets'
          ]
      pure (targets', w_stms <> oneStm kernel_stm')

-- | Attempt to create a kernel nest for distributing the single statement
-- @stm@ over the given nestings.  On success, the result computed by
-- 'distributionBodyFromStm' is returned alongside the new targets and
-- the kernel nest.  @Nothing@ means 'createKernelNest' could not build
-- one.
tryDistributeStm ::
  (MonadFreshNames m, HasScope t m, ASTRep rep) =>
  Nestings ->
  Targets ->
  Stm rep ->
  m (Maybe (Result, Targets, KernelNest))
tryDistributeStm nest targets stm =
  fmap (fmap withBodyResult) (createKernelNest nest dist_body)
  where
    (dist_body, body_res) = distributionBodyFromStm targets stm
    withBodyResult (targets', kernel_nest) = (body_res, targets', kernel_nest)