-- | Multicore code generation for 'SegMap'.
module Futhark.CodeGen.ImpGen.Multicore.SegMap
  ( compileSegMap,
  )
where

import Control.Monad
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Transform.Rename

-- | Generate code that stores a single kernel result into the pattern
-- element that receives it.
--
-- * For a 'Returns' result, the value is copied into the pattern
--   element's array at the index formed by the variables in the first
--   argument (each turned into an expression with 'Imp.le64').
--
-- * For a 'WriteReturns' result, every @(slice, value)@ pair is
--   handled independently: the copy into the given slice is only
--   emitted under an 'inBounds' guard against the result shape, so a
--   pair whose slice is out of bounds writes nothing.
--
-- Any other kind of 'KernelResult' is not supported here and aborts
-- with 'error'.
writeResult ::
  [VName] ->
  PatElemT dec ->
  KernelResult ->
  MulticoreGen ()
writeResult is pe (Returns _ _ se) =
  copyDWIMFix (patElemName pe) (map Imp.le64 is) se []
writeResult _ pe (WriteReturns _ (Shape rws) _ idx_vals) = do
  let rws' = map toInt64Exp rws
  -- Iterate over the pairs directly; splitting them with 'unzip' only
  -- to 'zip' them back together yields the very same list.
  forM_ idx_vals $ \(slice, v) -> do
    let slice' = fmap toInt64Exp slice
    sWhen (inBounds slice' rws') $
      copyDWIM (patElemName pe) (unSlice slice') v []
writeResult _ _ res =
  error $ "writeResult: cannot handle " ++ pretty res

-- | Generate the code for the body of a 'SegMap' and return it as a
-- value (via 'collect') rather than emitting it into the enclosing
-- code.
--
-- The generated code first declares the flat index variable of the
-- segment space as an @int64@ and fills it in with 'Imp.GetTaskId'.
-- It then runs a chunk loop (see 'generateChunkLoop') whose body, for
-- each iteration value:
--
-- 1. binds the index variables of the segment space from the
--    iteration value and the space dimensions ('dIndexSpace'),
--
-- 2. compiles the (renamed) statements of the kernel body, and
--
-- 3. writes each kernel result to the corresponding pattern element
--    using 'writeResult'.
compileSegMapBody ::
  Pat MCMem ->
  SegSpace ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegMapBody pat space (KernelBody _ kstms kres) = collect $ do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map toInt64Exp ns
  dPrim_ (segFlat space) int64
  sOp $ Imp.GetTaskId (segFlat space)
  -- NOTE(review): the statements are renamed before being compiled,
  -- presumably so the names they bind are fresh -- confirm the reason.
  kstms' <- mapM renameStm kstms
  generateChunkLoop "SegMap" $ \i -> do
    dIndexSpace (zip is ns') i
    compileStms (freeIn kres) kstms' $
      zipWithM_ (writeResult is) (patElems pat) kres

-- | Compile a 'SegMap' to multicore imperative code.
--
-- The body is generated with 'compileSegMapBody', the parameters for
-- its free variables are computed with 'freeParams', and a single
-- 'Imp.ParLoop' operation named @"segmap"@ wrapping that body is
-- emitted.  The emitted code is gathered with 'collect' and returned
-- as the result.
compileSegMap ::
  Pat MCMem ->
  SegSpace ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegMap pat space kbody = collect $ do
  body <- compileSegMapBody pat space kbody
  free_params <- freeParams body
  emit $ Imp.Op $ Imp.ParLoop "segmap" body free_params