{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.GPU.SegMap (compileSegMap) where
import Control.Monad.Except
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.Group
import Futhark.IR.GPUMem
import Futhark.Util.IntegralExp (divUp)
import Prelude hiding (quot, rem)
-- | Compile a @SegMap@ at the given level into a GPU kernel.
--
-- * At 'SegThread' level a @segmap@ kernel is generated in which every
--   thread handles one point of the iteration space.  The flat thread id
--   is @group_id * group_size + local_tid@.  The number of (virtual) groups
--   is the size of the space divided by the group size, rounded up with
--   'divUp', so the last group may contain surplus threads.  The body only
--   runs for threads where 'isActive' holds on the space.
--
-- * At 'SegGroup' level a @segmap_intragroup@ kernel is generated in which
--   every (virtual) group handles one point of the iteration space.  Its
--   results are written with 'compileGroupResult'.  Constants used by the
--   body are computed with 'precomputeConstants' before the kernel is
--   generated, and are made available inside it via 'precomputedConstants'.
--
-- * 'SegThreadInGroup' is rejected with 'error'.  NOTE(review): presumably
--   such a SegMap is expected to be handled inside an enclosing group
--   kernel rather than here — confirm.
--
-- The generated code is bracketed by 'Imp.DebugPrint' markers.
compileSegMap ::
  -- | Pattern bound to the SegMap's results.
  Pat LetDecMem ->
  -- | Level at which the SegMap runs.
  SegLevel ->
  -- | Iteration space (index variables and their extents).
  SegSpace ->
  -- | Body executed for each point of the space.
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegMap pat lvl space kbody = do
  attrs <- lvlKernelAttrs lvl

  let (is, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims
      group_size' = pe64 <$> kAttrGroupSize attrs

  emit $ Imp.DebugPrint "\n# SegMap" Nothing
  case lvl of
    SegThread {} -> do
      -- Round up so that every point of the space gets a thread.
      virt_num_groups <-
        dPrimVE "virt_num_groups" $
          sExt32 $
            product dims' `divUp` unCount group_size'
      sKernelThread "segmap" (segFlat space) attrs $
        virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
          local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv

          global_tid <-
            dPrimVE "global_tid" $
              sExt64 group_id * sExt64 (unCount group_size')
                + sExt64 local_tid

          dIndexSpace (zip is dims') global_tid

          -- The rounding up above can create threads beyond the space;
          -- only run the body where 'isActive' holds.
          sWhen (isActive $ unSegSpace space) $
            compileBody compileThreadResult
    SegGroup {} -> do
      pc <- precomputeConstants group_size' $ kernelBodyStms kbody
      -- One (virtual) group per point of the space.
      virt_num_groups <- dPrimVE "virt_num_groups" $ sExt32 $ product dims'
      sKernelGroup "segmap_intragroup" (segFlat space) attrs $
        precomputedConstants pc $
          virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
            dIndexSpace (zip is dims') $ sExt64 group_id
            compileBody compileGroupResult
    SegThreadInGroup {} ->
      error "compileSegMap: SegThreadInGroup"
  emit $ Imp.DebugPrint "" Nothing
  where
    -- Compile the kernel body's statements.  Then write each kernel result
    -- into the matching pattern element, using the level-specific result
    -- compiler passed in ('compileThreadResult' or 'compileGroupResult').
    compileBody compileResult =
      compileStms mempty (kernelBodyStms kbody) $
        zipWithM_ (compileResult space) (patElems pat) $
          kernelBodyResult kbody