{-# LANGUAGE TypeFamilies #-}

-- | Code generation for 'SegMap' is quite straightforward.  The only
-- trick is virtualisation, which is needed when the physical number of
-- threads is not sufficient to cover the logical thread space.  This is
-- handled by having each physical workgroup run a loop that imitates
-- multiple logical workgroups.
module Futhark.CodeGen.ImpGen.GPU.SegMap (compileSegMap) where

import Control.Monad
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.Group
import Futhark.IR.GPUMem
import Futhark.Util.IntegralExp (divUp)
import Prelude hiding (quot, rem)

-- | Compile 'SegMap' instance code.
--
-- The generated kernel depends on the level of the 'SegMap':
--
-- * 'SegThread': launched with 'sKernelThread'.  The number of virtual
--   workgroups is the size of the segment space divided (rounding up)
--   by the workgroup size.  Each thread computes its global thread
--   index, binds the index variables of the segment space from it with
--   'dIndexSpace', and only threads inside the space (see 'isActive')
--   run the body and write its results via 'compileThreadResult'.
--
-- * 'SegGroup': launched with 'sKernelGroup'.  There is one virtual
--   workgroup per point of the segment space; the workgroup index is
--   used to bind the index variables, and results are written via
--   'compileGroupResult'.
--
-- * 'SegThreadInGroup': not handled here; compilation aborts with
--   'error'.
compileSegMap ::
  -- | Pattern that receives the results of the map.
  Pat LetDecMem ->
  -- | Level at which the map is executed.
  SegLevel ->
  -- | Segment space: index variables paired with their bounds.
  SegSpace ->
  -- | Body evaluated for each point of the segment space.
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegMap pat lvl space kbody = do
  attrs <- lvlKernelAttrs lvl

  let (is, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      group_size' = pe64 <$> kAttrGroupSize attrs

  emit $ Imp.DebugPrint "\n# SegMap" Nothing
  case lvl of
    SegThread {} -> do
      -- Enough workgroups that every point of the space gets a thread.
      -- NOTE(review): the count is narrowed to 32 bits with 'sExt32'
      -- because 'virtualiseGroups' takes a 'TExp Int32'; confirm that
      -- very large segment spaces cannot overflow here.
      virt_num_groups <-
        dPrimVE "virt_num_groups" $
          sExt32 $
            product dims' `divUp` unCount group_size'
      sKernelThread "segmap" (segFlat space) attrs $
        virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
          local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv

          -- global_tid = group_id * group_size + local_tid, computed in
          -- 64 bits.
          global_tid <-
            dPrimVE "global_tid" $
              sExt64 group_id * sExt64 (unCount group_size')
                + sExt64 local_tid

          dIndexSpace (zip is dims') global_tid

          -- The last workgroup may be only partially inside the space,
          -- so threads past the end must not execute the body.
          sWhen (isActive $ unSegSpace space) $
            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileThreadResult space) (patElems pat) $
                kernelBodyResult kbody
    SegGroup {} -> do
      -- Presumably precomputes values derived from the body statements
      -- and the workgroup size outside the kernel; they are brought
      -- back into scope inside it with 'precomputedConstants'.
      pc <- precomputeConstants group_size' $ kernelBodyStms kbody
      -- One virtual workgroup per point of the segment space.
      virt_num_groups <-
        dPrimVE "virt_num_groups" $ sExt32 $ product dims'
      sKernelGroup "segmap_intragroup" (segFlat space) attrs $ do
        precomputedConstants pc $
          virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
            dIndexSpace (zip is dims') $ sExt64 group_id

            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileGroupResult space) (patElems pat) $
                kernelBodyResult kbody
    SegThreadInGroup {} ->
      error "compileSegMap: SegThreadInGroup"
  emit $ Imp.DebugPrint "" Nothing