{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Futhark.CodeGen.ImpGen.Kernels.SegMap
  ( compileSegMap )
where

import Control.Monad.Except

import Prelude hiding (quot, rem)

import Futhark.Representation.KernelsMem
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.CodeGen.ImpGen
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.Util.IntegralExp (quotRoundingUp)

-- | Compile a 'SegMap' into imperative kernel code.
--
-- The index variables and dimension sizes are taken from the 'SegSpace';
-- the number of groups and the group size are taken from the 'SegLevel'.
-- What is generated depends on the level:
--
-- * 'SegThread': the body is wrapped with 'sKernelThread' under the name
--   @"segmap"@.  A flat thread index is computed from the (virtual)
--   group id, the group size and the local thread id, and unflattened
--   over the dimensions of the space.  The body statements and the
--   result writes ('compileThreadResult') are guarded by 'isActive'.
--
-- * 'SegGroup': the body is wrapped with 'sKernelGroup' under the name
--   @"segmap_intragroup"@.  The (virtual) group id itself is unflattened
--   over the dimensions of the space, and the results are written with
--   'compileGroupResult'.
compileSegMap :: Pattern KernelsMem
              -> SegLevel
              -> SegSpace
              -> KernelBody KernelsMem
              -> CallKernelGen ()
compileSegMap pat lvl space kbody = do
  let space_dims = unSegSpace space
      (is, dims) = unzip space_dims
  dims' <- mapM toExp dims

  num_groups' <- traverse toExp $ segNumGroups lvl
  group_size' <- traverse toExp $ segGroupSize lvl

  case lvl of
    SegThread{} -> do
      emit $ Imp.DebugPrint "\n# SegMap" Nothing
      -- One thread per point of the space; round up so that a final,
      -- partially filled group is still counted.
      let virt_num_groups =
            product dims' `quotRoundingUp` unCount group_size'
      sKernelThread "segmap" num_groups' group_size' (segFlat space) $
        virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
          local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv
          let global_tid =
                Imp.vi32 group_id * unCount group_size' + local_tid

          -- Bind each index variable of the space to its component of
          -- the flat thread index.
          zipWithM_ dPrimV_ is $ unflattenIndex dims' global_tid

          -- NOTE(review): 'isActive' presumably tests that every index
          -- is within its dimension, masking the surplus threads
          -- introduced by the rounding above - confirm in Base.
          sWhen (isActive space_dims) $
            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileThreadResult space) (patternElements pat) $
                kernelBodyResult kbody

    SegGroup{} ->
      sKernelGroup "segmap_intragroup" num_groups' group_size' (segFlat space) $ do
        -- One (virtual) group per point of the space.
        let virt_num_groups = product dims'
        precomputeSegOpIDs (kernelBodyStms kbody) $
          virtualiseGroups (segVirt lvl) virt_num_groups $ \group_id -> do
            -- Bind each index variable of the space to its component of
            -- the group id.
            zipWithM_ dPrimV_ is $ unflattenIndex dims' $ Imp.vi32 group_id

            compileStms mempty (kernelBodyStms kbody) $
              zipWithM_ (compileGroupResult space) (patternElements pat) $
                kernelBodyResult kbody