{-# LANGUAGE TypeFamilies #-}

-- | Code generation for 'SegMap' is quite straightforward.  The only
-- trick is virtualisation, which is needed when the physical number of
-- threads is not sufficient to cover the logical thread space.  This is
-- handled by having each physical workgroup run a loop that imitates
-- multiple logical workgroups.
module Futhark.CodeGen.ImpGen.GPU.SegMap (compileSegMap) where

import Control.Monad
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.Group
import Futhark.IR.GPUMem
import Futhark.Util.IntegralExp (divUp)
import Prelude hiding (quot, rem)

-- | Compile 'SegMap' instance code.
--
-- The kernel attributes are derived from the 'SegLevel', and the
-- generated code is bracketed by two 'Imp.DebugPrint' markers.  Both
-- supported levels launch a single kernel whose body is wrapped in
-- 'virtualiseGroups'.  This lets the physical workgroups iterate over
-- @virt_num_groups@ logical workgroups.
--
-- * 'SegThread' (kernel @segmap@): each logical workgroup has the group
--   size from the kernel attributes.  Every thread computes a flat
--   global thread ID (@group_id * group_size + local_tid@), unflattens
--   it into the segment indices with 'dIndexSpace', and, when
--   'isActive' holds for those indices, runs the body.  Its results are
--   then written with 'compileThreadResult'.
--
-- * 'SegGroup' (kernel @segmap_intragroup@): there is one logical
--   workgroup per point of the segment space.  The group ID itself is
--   unflattened into the segment indices, and the body's results are
--   written with 'compileGroupResult'.  'precomputeConstants' is run on
--   the body's statements before the kernel is generated.  Its result
--   is made available inside the kernel through 'precomputedConstants'.
--
-- 'SegThreadInGroup' is not handled here and results in a call to
-- 'error'.
compileSegMap ::
  Pat LetDecMem ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegMap pat lvl space kbody = do
  attrs <- lvlKernelAttrs lvl

  let (is, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      group_size' = pe64 <$> kAttrGroupSize attrs
      body_stms = kernelBodyStms kbody
      body_res = kernelBodyResult kbody

  emit $ Imp.DebugPrint "\n# SegMap" Nothing
  case lvl of
    SegThread {} -> do
      -- Enough logical workgroups to cover the whole space.  The count
      -- is rounded up, so the last one may extend past the end.
      virt_num_groups <-
        dPrimVE "virt_num_groups" $
          sExt32 $ product dims' `divUp` unCount group_size'
      sKernelThread "segmap" (segFlat space) attrs $
        virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
          local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv

          global_tid <-
            dPrimVE "global_tid" $
              sExt64 group_id * sExt64 (unCount group_size')
                + sExt64 local_tid

          dIndexSpace (zip is dims') global_tid

          -- Because of the rounding up above, some threads may map to
          -- indices that are not part of the segment space.
          sWhen (isActive $ unSegSpace space) $
            compileStms mempty body_stms $
              zipWithM_ (compileThreadResult space) (patElems pat) body_res
    SegGroup {} -> do
      pc <- precomputeConstants group_size' body_stms
      -- One logical workgroup per point of the segment space.
      virt_num_groups <- dPrimVE "virt_num_groups" $ sExt32 $ product dims'
      sKernelGroup "segmap_intragroup" (segFlat space) attrs $
        precomputedConstants pc $
          virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
            dIndexSpace (zip is dims') $ sExt64 group_id

            compileStms mempty body_stms $
              zipWithM_ (compileGroupResult space) (patElems pat) body_res
    SegThreadInGroup {} ->
      error "compileSegMap: SegThreadInGroup"
  emit $ Imp.DebugPrint "" Nothing