-- | Multicore code generation for 'SegMap'.
module Futhark.CodeGen.ImpGen.Multicore.SegMap
  ( compileSegMap,
  )
where

import Control.Monad
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Transform.Rename

-- | Write one kernel result of the current iteration into its
-- destination pattern element.
--
-- * 'Returns': the result @se@ is copied (with 'copyDWIMFix') into @pe@
--   at the position given by the segment index variables @is@.
--
-- * 'WriteReturns': each @(slice, value)@ pair is copied (with
--   'copyDWIM') into @pe@, but only when the slice is in bounds with
--   respect to the 'Shape' carried by the 'WriteReturns'. Out-of-bounds
--   pairs produce no write.
--
-- Any other 'KernelResult' triggers 'error'.
writeResult ::
  -- | Segment index variables of the current iteration.
  [VName] ->
  -- | Destination the result is written to.
  PatElem dec ->
  -- | The kernel result to write.
  KernelResult ->
  MulticoreGen ()
writeResult is pe (Returns _ _ se) =
  copyDWIMFix (patElemName pe) (map Imp.le64 is) se []
writeResult _ pe (WriteReturns _ (Shape rws) _ idx_vals) = do
  let rws' = map toInt64Exp rws
  -- Iterate the (slice, value) pairs directly; there is no need to
  -- unzip them and zip them back together.
  forM_ idx_vals $ \(slice, v) -> do
    let slice' = fmap toInt64Exp slice
    -- Guard every write: slices outside the bounds are skipped.
    sWhen (inBounds slice' rws') $
      copyDWIM (patElemName pe) (unSlice slice') v []
writeResult _ _ res =
  error $ "writeResult: cannot handle " ++ pretty res

-- | Generate the code run by each task of a 'SegMap'.
--
-- The generated code proceeds in three steps:
--
-- 1. It declares the space's 'segFlat' variable as an 'int64' and
--    assigns it the value of 'Imp.GetTaskId'.
--
-- 2. It renames the kernel statements with 'renameStm'.
--
-- 3. It emits a chunk loop (via 'generateChunkLoop'). For every loop
--    index @i@, the loop:
--
--    * binds the segment index variables from @i@ and the dimension
--      sizes (via 'dIndexSpace');
--    * compiles the kernel statements;
--    * writes each kernel result to the corresponding pattern element
--      with 'writeResult'.
--
-- The code is returned as a value via 'collect'.
compileSegMapBody ::
  -- | Pattern receiving the results of the 'SegMap'.
  Pat LetDecMem ->
  -- | Segment space: index variables paired with their dimension sizes.
  SegSpace ->
  -- | Kernel body: statements and results.
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
compileSegMapBody pat space (KernelBody _ kstms kres) = collect $ do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map toInt64Exp ns
  dPrim_ (segFlat space) int64
  sOp $ Imp.GetTaskId (segFlat space)
  -- NOTE(review): statements are presumably renamed so that this
  -- generated copy cannot clash with names bound elsewhere; confirm.
  kstms' <- mapM renameStm kstms
  generateChunkLoop "SegMap" $ \i -> do
    dIndexSpace (zip is ns') i
    compileStms (freeIn kres) kstms' $
      zipWithM_ (writeResult is) (patElems pat) kres

-- | Compile a 'SegMap' to multicore code.
--
-- The per-task body produced by 'compileSegMapBody' is wrapped in an
-- 'Imp.ParLoop' named @"segmap"@. The loop carries the parameters
-- computed by 'freeParams' for that body. The resulting code is
-- returned as a value via 'collect'.
compileSegMap ::
  -- | Pattern receiving the results of the 'SegMap'.
  Pat LetDecMem ->
  -- | Segment space of the 'SegMap'.
  SegSpace ->
  -- | Kernel body of the 'SegMap'.
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
compileSegMap pat space kbody = collect $ do
  body <- compileSegMapBody pat space kbody
  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "segmap" body free_params