-- | Multicore code generation for 'SegMap'.
module Futhark.CodeGen.ImpGen.Multicore.SegMap
  ( compileSegMap,
  )
where

import Control.Monad
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.IR.MCMem
import Futhark.Transform.Rename

-- | Emit code that stores one kernel result into its pattern element.
--
-- * 'Returns': the result @se@ is written (via 'copyDWIMFix') to the
--   pattern element at the position given by the segment-space indices
--   @is@.
-- * 'WriteReturns': each @(slice, v)@ pair is written (via 'copyDWIM') to
--   the pattern element at @slice@, guarded by 'sWhen' so the write only
--   happens when 'inBounds' holds for @slice@ against the destination
--   shape @rws@.
--
-- Any other kind of 'KernelResult' triggers an 'error'.
writeResult ::
  -- | Segment-space index variables of the current iteration.
  [VName] ->
  -- | Pattern element receiving the result.
  PatElemT dec ->
  -- | The kernel result to write.
  KernelResult ->
  MulticoreGen ()
writeResult is pe (Returns _ _ se) =
  copyDWIMFix (patElemName pe) (map Imp.le64 is) se []
writeResult _ pe (WriteReturns _ (Shape rws) _ idx_vals) = do
  let rws' = map toInt64Exp rws
  -- Walk the (slice, value) pairs directly; splitting them with 'unzip'
  -- only to 'zip' them back together is a redundant round-trip.
  forM_ idx_vals $ \(slice, v) -> do
    let slice' = fmap toInt64Exp slice
    sWhen (inBounds slice' rws') $
      copyDWIM (patElemName pe) (unSlice slice') v []
writeResult _ _ res =
  error $ "writeResult: cannot handle " ++ pretty res

-- | Generate the code for one iteration of a 'SegMap' body.
--
-- The segment-space index variables are declared from the flat iteration
-- index (via 'dIndexSpace'). The kernel statements are then compiled, and
-- each kernel result is written to its pattern element with 'writeResult'.
-- The generated code is returned (via 'collect') rather than emitted into
-- the surrounding code.
compileSegMapBody ::
  -- | Flat iteration index supplied by the enclosing loop.
  TV Int64 ->
  -- | Pattern receiving the kernel results.
  Pat MCMem ->
  -- | Segment space: index variables paired with their dimension sizes.
  SegSpace ->
  -- | Kernel statements and results to compile.
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegMapBody flat_idx pat space (KernelBody _ kstms kres) = do
  let (is, ns) = unzip $ unSegSpace space
      ns' = map toInt64Exp ns
  -- NOTE(review): statements are renamed before compilation, presumably so
  -- that names bound inside them are fresh in the generated code.
  kstms' <- mapM renameStm kstms
  collect $ do
    emit $ Imp.DebugPrint "SegMap fbody" Nothing
    dIndexSpace (zip is ns') $ tvExp flat_idx
    -- The free variables of the results are handed to 'compileStms';
    -- presumably these are the names that must remain available after the
    -- statements have been compiled.
    compileStms (freeIn kres) kstms' $
      zipWithM_ (writeResult is) (patElems pat) kres

-- | Compile a 'SegMap' into a multicore 'Imp.ParLoop' named @"segmap"@.
--
-- A fresh @int64@ variable named @iter@ is declared as the loop's
-- iteration index, and the loop body is produced by 'compileSegMapBody'.
-- Allocations are split out of the body with 'extractAllocations' and
-- passed to the loop separately from the remaining body code. The
-- resulting code is returned (via 'collect') rather than emitted into the
-- surrounding code.
compileSegMap ::
  Pat MCMem ->
  SegSpace ->
  KernelBody MCMem ->
  MulticoreGen Imp.Code
compileSegMap pat space kbody =
  collect $ do
    flat_par_idx <- dPrim "iter" int64
    body <- compileSegMapBody flat_par_idx pat space kbody
    -- The flat segment index and the iteration variable are passed to
    -- 'freeParams' as names to leave out; presumably because the loop
    -- itself provides them (TODO confirm against 'freeParams').
    free_params <- freeParams body [segFlat space, tvVar flat_par_idx]
    let (body_allocs, body') = extractAllocations body
    emit . Imp.Op $
      Imp.ParLoop
        "segmap"
        (tvVar flat_par_idx)
        body_allocs
        body'
        mempty
        free_params
        (segFlat space)