-- | Multicore code generation for 'SegMap'.
module Futhark.CodeGen.ImpGen.Multicore.SegMap
  ( compileSegMap,
  )
where

import Control.Monad
import Futhark.CodeGen.ImpCode.Multicore qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Transform.Rename

-- | Emit code that stores a single kernel result into the pattern
-- element that receives it.
--
-- * For a 'Returns' result, the value is copied into the pattern
--   element at the position given by the index variables @is@.
--
-- * For a 'WriteReturns' result, every @(slice, value)@ pair is
--   written to the pattern element at that slice, but only when
--   'inBounds' holds for the slice with respect to the shape @rws@.
--
-- * Any other kind of result is not supported here and aborts with an
--   'error'.
writeResult ::
  [VName] ->
  PatElem dec ->
  KernelResult ->
  MulticoreGen ()
writeResult is pe (Returns _ _ se) =
  copyDWIMFix (patElemName pe) (map Imp.le64 is) se []
writeResult _ pe (WriteReturns _ (Shape rws) _ idx_vals) = do
  let rws' = map pe64 rws
  -- Each write is guarded individually; a slice for which the bounds
  -- check fails produces no copy.
  forM_ idx_vals $ \(slice, v) -> do
    let slice' = fmap pe64 slice
    sWhen (inBounds slice' rws') $
      copyDWIM (patElemName pe) (unSlice slice') v []
writeResult _ _ res =
  error $ "writeResult: cannot handle " ++ prettyString res

-- | Generate the code for the body of a 'SegMap' and return it as a
-- collected code fragment instead of emitting it directly.
--
-- The generated code declares the flat index variable of the segment
-- space as an @int64@, fills it using 'Imp.GetTaskId', and then runs a
-- chunk loop (labelled @"SegMap"@, marked 'Vectorized', wrapped in
-- 'inISPC').  Each iteration sets up the index space from the loop
-- index, compiles the kernel statements, and writes each kernel result
-- to its pattern element with 'writeResult'.
compileSegMapBody ::
  Pat LetDecMem ->
  SegSpace ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
compileSegMapBody pat space (KernelBody _ kstms kres) = collect $ do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map pe64 ns
  dPrim_ (segFlat space) int64
  sOp $ Imp.GetTaskId (segFlat space)
  -- The statements are renamed before being compiled.
  -- NOTE(review): presumably this keeps names unique in the generated
  -- code; confirm against 'renameStm'.
  kstms' <- mapM renameStm kstms
  inISPC $
    generateChunkLoop "SegMap" Vectorized $ \i -> do
      dIndexSpace (zip is ns') i
      -- Pattern elements and kernel results are paired positionally.
      compileStms (freeIn kres) kstms' $
        zipWithM_ (writeResult is) (patElems pat) kres

-- | Compile a 'SegMap' to multicore code.
--
-- The body is first generated with 'compileSegMapBody'.  The free
-- parameters of that body are then computed with 'freeParams', and the
-- two are wrapped in a single 'Imp.ParLoop' operation named
-- @"segmap"@.  The resulting code is collected and returned.
compileSegMap ::
  Pat LetDecMem ->
  SegSpace ->
  KernelBody MCMem ->
  MulticoreGen Imp.MCCode
compileSegMap pat space kbody = collect $ do
  body <- compileSegMapBody pat space kbody
  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "segmap" body free_params