{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Code generation for 'SegMap' is quite straightforward.  The only
-- trick is virtualisation in case the physical number of threads is
-- not sufficient to cover the logical thread space.  This is handled
-- by having actual workgroups run a loop to imitate multiple workgroups.
module Futhark.CodeGen.ImpGen.GPU.SegMap (compileSegMap) where

import Control.Monad.Except
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.IR.GPUMem
import Futhark.Util.IntegralExp (divUp)
import Prelude hiding (quot, rem)

-- | Generate code for a 'SegMap'.
--
-- The iteration space is read from the 'SegSpace'; the number of groups
-- and the group size are read from the 'SegLevel'.  A debug-print marker
-- is emitted before and after the code generated for the kernel.
--
-- * At 'SegThread' level, each thread computes a flat global thread id
--   from its group id (as supplied by 'virtualiseGroups'), the group
--   size and its local thread id.  That id is handed to 'dIndexSpace'
--   together with the index variables and dimensions of the space, and
--   the body is only compiled under an 'isActive' guard.  Results are
--   written with 'compileThreadResult'.
--
-- * At 'SegGroup' level, the group id supplied by 'virtualiseGroups' is
--   itself handed to 'dIndexSpace' as the flat index, there is no
--   'isActive' guard, and results are written with 'compileGroupResult'.
compileSegMap ::
  Pat GPUMem ->
  SegLevel ->
  SegSpace ->
  KernelBody GPUMem ->
  CallKernelGen ()
compileSegMap pat lvl space kbody = do
  let -- Index variables of the space and their sizes, in matching order.
      (is, dims) = unzip $ unSegSpace space
      -- The sizes as 64-bit integer expressions.
      dims' = map toInt64Exp dims
      -- Launch configuration requested by the level, as 64-bit expressions.
      num_groups' = toInt64Exp <$> segNumGroups lvl
      group_size' = toInt64Exp <$> segGroupSize lvl

  emit $ Imp.DebugPrint "\n# SegMap" Nothing
  case lvl of
    SegThread {} -> do
      -- Logical group count: total size of the space divided by the group
      -- size using 'divUp' (presumably rounding up, so that a partial last
      -- group is still counted).
      -- NOTE(review): 'sExt32' narrows a 64-bit expression to 32 bits here;
      -- confirm the logical group count always fits in an Int32.
      let virt_num_groups =
            sExt32 $ product dims' `divUp` unCount group_size'
      sKernelThread "segmap" num_groups' group_size' (segFlat space) $
        virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
          local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv

          -- Flat thread id, computed in 64-bit arithmetic from 32-bit ids.
          global_tid <-
            dPrimVE "global_tid" $
              sExt64 group_id * sExt64 (unCount group_size')
                + sExt64 local_tid

          dIndexSpace (zip is dims') global_tid

          -- Only run the body when 'isActive' holds for the space
          -- (presumably: when every index is within its dimension).
          sWhen (isActive $ unSegSpace space) $
            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileThreadResult space) (patElems pat) $
                kernelBodyResult kbody
    SegGroup {} ->
      sKernelGroup "segmap_intragroup" num_groups' group_size' (segFlat space) $ do
        -- One logical group per point of the space.
        -- NOTE(review): same 64-to-32-bit narrowing as in the 'SegThread'
        -- case; confirm the product of the dimensions fits in an Int32.
        let virt_num_groups = sExt32 $ product dims'
        precomputeSegOpIDs (kernelBodyStms kbody) $
          virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
            -- The group id is used directly as the flat index.
            dIndexSpace (zip is dims') $ sExt64 group_id

            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileGroupResult space) (patElems pat) $
                kernelBodyResult kbody
  emit $ Imp.DebugPrint "" Nothing