{-# LANGUAGE CPP                 #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.Permute
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.Permute (

  mkPermute,

) where

import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Permute
import Data.Array.Accelerate.LLVM.CodeGen.Ptr
import Data.Array.Accelerate.LLVM.CodeGen.Sugar

import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Loop
import Data.Array.Accelerate.LLVM.PTX.Target

import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.RMW                                as RMW
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation

import Foreign.CUDA.Analysis

import Control.Monad                                                ( void )
import Control.Monad.State                                          ( gets )
import Prelude


-- Forward permutation specified by an indexing mapping. The resulting array is
-- initialised with the given defaults, and any further values that are permuted
-- into the result array are added to the current value using the combination
-- function.
--
-- The combination function must be /associative/ and /commutative/. Elements
-- that are mapped to the magic index 'ignore' are dropped.
--
-- Parallel forward permutation has to take special care because different
-- threads could concurrently try to update the same memory location. Where
-- available we make use of special atomic instructions and other optimisations,
-- but in the general case each element of the output array has a lock which
-- must be obtained by the thread before it can update that memory location.
--
-- This function only selects the implementation strategy:
--
--   * if the combination function was recognised as an atomic
--     read-modify-write operation ('atomicRMW' is 'Just'), generate the
--     lock-free kernel via 'mkPermute_rmw';
--
--   * otherwise fall back to the lock-based kernel 'mkPermute_mutex', using
--     the general 'combine' function.
--
-- TODO: After too many failures to acquire the lock on an element, the thread
-- should back off and try a different element, adding this failed element to
-- a queue or some such.
--
mkPermute
    :: HasCallStack
    => Gamma            aenv
    -> ArrayR (Array sh e)
    -> ShapeR sh'
    -> IRPermuteFun PTX aenv (e -> e -> e)
    -> IRFun1       PTX aenv (sh -> PrimMaybe sh')
    -> MIRDelayed   PTX aenv (Array sh e)
    -> CodeGen      PTX      (IROpenAcc PTX aenv (Array sh' e))
mkPermute aenv repr shr' IRPermuteFun{..} project arr =
  -- The record wildcard brings 'atomicRMW' and 'combine' into scope.
  case atomicRMW of
    Just (rmw, f) -> mkPermute_rmw   aenv repr shr' rmw f   project arr
    Nothing       -> mkPermute_mutex aenv repr shr' combine project arr


-- Parallel forward permutation function which uses atomic instructions to
-- implement lock-free array updates.
--
-- Atomic instruction support on CUDA devices is a bit patchy, so depending on
-- the element type and compute capability of the target hardware we may need to
-- emulate the operation using atomic compare-and-swap.
--
--              Int32    Int64    Float16    Float32    Float64
--           +-------------------------------------------------
--    (+)    |  2.0       2.0       7.0        2.0        6.0
--    (-)    |  2.0       2.0        x          x          x
--    (.&.)  |  2.0       3.2
--    (.|.)  |  2.0       3.2
--    xor    |  2.0       3.2
--    min    |  2.0       3.2        x          x          x
--    max    |  2.0       3.2        x          x          x
--    CAS    |  2.0       2.0
--
-- Note that NVPTX requires at least compute 2.0, so we can always implement the
-- lockfree update operations in terms of compare-and-swap.
--
-- The generated kernel is named "permute_rmw" and takes the output array, the
-- (possibly delayed) input array and the array environment as parameters. For
-- every linear index of the input it applies the index projection; if that
-- yields a destination, the input element is read, passed through the 'update'
-- function, and merged into the destination with the given 'RMWOperation'.
--
mkPermute_rmw
    :: HasCallStack
    => Gamma aenv
    -> ArrayR (Array sh e)
    -> ShapeR sh'
    -> RMWOperation
    -> IRFun1     PTX aenv (e -> e)
    -> IRFun1     PTX aenv (sh -> PrimMaybe sh')
    -> MIRDelayed PTX aenv (Array sh e)
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Array sh' e))
mkPermute_rmw aenv (ArrayR shr tp) shr' rmw update project marr = do
  -- The device properties decide which atomics can be emitted natively
  dev <- liftCodeGen $ gets ptxDeviceProperties
  --
  let
      outR                = ArrayR shr' tp
      (arrOut, paramOut)  = mutableArray outR "out"
      (arrIn,  paramIn)   = delayedArray "in" marr
      paramEnv            = envParam aenv
      start               = liftInt 0
      --
      bytes               = bytesElt tp
      compute             = computeCapability dev
      compute32           = Compute 3 2
      compute60           = Compute 6 0
      compute70           = Compute 7 0
  --
  makeOpenAcc "permute_rmw" (paramOut ++ paramIn ++ paramEnv) $ do

    shIn  <- delayedExtent arrIn
    end   <- shapeSize shr shIn

    imapFromTo start end $ \i -> do

      -- Map the linear input index to its (optional) destination index
      ix  <- indexOfInt shr shIn i
      ix' <- app1 project ix

      when (isJust ix') $ do
        j <- intOfIndex shr' (irArrayShape arrOut) =<< fromJust ix'
        x <- app1 (delayedLinearIndex arrIn) i
        r <- app1 update x

        case rmw of
          -- Exchange is emitted as an ordinary array write
          Exchange
            -> writeArray TypeInt arrOut j r
          --
          -- All other operations are only handled for single scalar element
          -- types, where we can compute the address of the destination element
          _ | TupRsingle (SingleScalarType s)   <- tp
            , adata                             <- irArrayData arrOut
            -> do
                  addr <- instr' $ GetElementPtr (asPtr defaultAddrSpace (op s adata)) [op integralType j]
                  --
                  let
                      -- Integral update: use the native atomic instruction
                      -- when available, otherwise emulate via compare-and-swap
                      rmw_integral :: IntegralType t -> Operand (Ptr t) -> Operand t -> CodeGen PTX ()
                      rmw_integral t ptr val
                        | primOk    = void . instr' $ AtomicRMW (IntegralNumType t) NonVolatile rmw ptr val (CrossThread, AcquireRelease)
                        | otherwise =
                            case rmw of
                              RMW.And -> atomicCAS_rmw s' (A.band t (ir t val)) ptr
                              RMW.Or  -> atomicCAS_rmw s' (A.bor  t (ir t val)) ptr
                              RMW.Xor -> atomicCAS_rmw s' (A.xor  t (ir t val)) ptr
                              RMW.Min -> atomicCAS_cmp s' A.lt ptr val
                              RMW.Max -> atomicCAS_cmp s' A.gt ptr val
                              _       -> internalError "unexpected transition"
                        where
                          s'      = NumSingleType (IntegralNumType t)
                          -- Native support: any operation from compute 3.2,
                          -- any operation on 4-byte elements, or add/sub
                          -- regardless of element size (see table above)
                          primOk  = compute >= compute32
                                 || bytes == 4
                                 || case rmw of
                                      RMW.Add -> True
                                      RMW.Sub -> True
                                      _       -> False

                      -- Floating point update: only addition may have a
                      -- native instruction; everything else goes through
                      -- compare-and-swap
                      rmw_floating :: FloatingType t -> Operand (Ptr t) -> Operand t -> CodeGen PTX ()
                      rmw_floating t ptr val =
                        case rmw of
                          RMW.Min       -> atomicCAS_cmp s' A.lt ptr val
                          RMW.Max       -> atomicCAS_cmp s' A.gt ptr val
                          RMW.Sub       -> atomicCAS_rmw s' (A.sub n (ir t val)) ptr
                          RMW.Add
                            | primAdd   -> atomicAdd_f t ptr val
                            | otherwise -> atomicCAS_rmw s' (A.add n (ir t val)) ptr
                          _             -> internalError "unexpected transition"
                        where
                          n       = FloatingNumType t
                          s'      = NumSingleType n
                          -- Minimum compute capability for a native atomic
                          -- add at each floating point width
                          primAdd =
                            case t of
                              TypeHalf   -> compute >= compute70
                              TypeFloat  -> True
                              TypeDouble -> compute >= compute60
                  case s of
                    NumSingleType (IntegralNumType t) -> rmw_integral t addr (op t r)
                    NumSingleType (FloatingNumType t) -> rmw_floating t addr (op t r)
          --
          _ -> internalError "unexpected transition"

    return_


-- Parallel forward permutation function which uses a spinlock to acquire
-- a mutex before updating the value at that location.
--
-- The generated kernel is named "permute_mutex". In addition to the output
-- array, the (possibly delayed) input array and the array environment, it takes
-- a one-dimensional 'Word32' array "lock", which is indexed by the linear index
-- of the destination element. The read-combine-write sequence on an output
-- element is wrapped in 'atomically' on the corresponding lock entry.
--
mkPermute_mutex
    :: Gamma          aenv
    -> ArrayR (Array sh e)
    -> ShapeR sh'
    -> IRFun2     PTX aenv (e -> e -> e)
    -> IRFun1     PTX aenv (sh -> PrimMaybe sh')
    -> MIRDelayed PTX aenv (Array sh e)
    -> CodeGen    PTX      (IROpenAcc PTX aenv (Array sh' e))
mkPermute_mutex aenv (ArrayR shr tp) shr' combine project marr =
  let
      outR                  = ArrayR shr' tp
      lockR                 = ArrayR (ShapeRsnoc ShapeRz) (TupRsingle scalarTypeWord32)
      (arrOut,  paramOut)   = mutableArray outR "out"
      (arrLock, paramLock)  = mutableArray lockR "lock"
      (arrIn,   paramIn)    = delayedArray "in" marr
      paramEnv              = envParam aenv
      start                 = liftInt 0
  in
  makeOpenAcc "permute_mutex" (paramOut ++ paramLock ++ paramIn ++ paramEnv) $ do

    shIn  <- delayedExtent arrIn
    end   <- shapeSize shr shIn

    imapFromTo start end $ \i -> do

      -- Map the linear input index to its (optional) destination index
      ix  <- indexOfInt shr shIn i
      ix' <- app1 project ix

      -- project element onto the destination array and (atomically) update
      when (isJust ix') $ do
        j <- intOfIndex shr' (irArrayShape arrOut) =<< fromJust ix'
        x <- app1 (delayedLinearIndex arrIn) i

        -- Critical section: the new value 'x' is passed as the first argument
        -- of the combination function and the current value 'y' as the second
        atomically arrLock j $ do
          y <- readArray TypeInt arrOut j
          r <- app2 combine x y
          writeArray TypeInt arrOut j r

    return_


-- Atomically execute the critical section only when the lock at the given
-- array index is obtained.
--
atomically
    :: IRArray (Vector Word32)
    -> Operands Int
    -> CodeGen PTX a
    -> CodeGen PTX a
atomically barriers i action = do
  -- Look up the properties of the device this code is being generated for
  props <- liftCodeGen (gets ptxDeviceProperties)
  let
      -- Devices of compute capability 7.0 and above use the per-thread
      -- spin lock; everything older falls back to the warp-aware variant.
      spinlock
        | computeCapability props >= Compute 7 0 = atomically_thread
        | otherwise                              = atomically_warp
  --
  spinlock barriers i action


-- Atomically execute the critical section only when the lock at the given
-- array index is obtained. The thread spins waiting for the lock to be
-- released with exponential backoff on failure in case the lock is
-- contended.
--
-- > uint32_t ns = 8;
-- > while ( atomic_exchange(&lock[i], 1) == 1 ) {
-- >     __nanosleep(ns);
-- >     if ( ns < 256 ) {
-- >         ns *= 2;
-- >     }
-- > }
--
-- Requires independent thread scheduling features of SM7+.
--
atomically_thread
    :: IRArray (Vector Word32)
    -> Operands Int
    -> CodeGen PTX a
    -> CodeGen PTX a
atomically_thread barriers i action = do
  -- Basic blocks making up the spin loop, and the loop-carried sleep duration
  entry <- newBlock "spinlock.entry"
  sleep <- newBlock "spinlock.backoff"
  moar  <- newBlock "spinlock.backoff-moar"
  start <- newBlock "spinlock.critical-start"
  end   <- newBlock "spinlock.critical-end"
  exit  <- newBlock "spinlock.exit"
  ns    <- fresh tInt32

  -- Address of the lock slot at index 'i'. The block we branch away from is
  -- remembered because it supplies the initial value of the sleep duration.
  slot  <- instr' (GetElementPtr slots [op integralType i])
  pre   <- br entry

  -- Try to take the lock by swapping in the 'locked' value. If the slot
  -- previously held 'unlocked' then the lock is now ours and we proceed to the
  -- critical section; otherwise back off before trying again.
  setBlock entry
  prev     <- instr (AtomicRMW numType NonVolatile Exchange slot locked (CrossThread, Acquire))
  acquired <- A.eq singleType prev unlocked'
  void $ cbr acquired start sleep

  -- The lock was held by somebody else: sleep for the current duration, and
  -- double that duration for the next round while it is still below 256.
  setBlock sleep
  nanosleep ns
  grow  <- A.lt singleType ns (int32 256)
  void $ cbr grow moar entry

  setBlock moar
  ns'   <- A.mul numType ns (int32 2)
  -- The sleep duration is a phi node living in the 'entry' block, with one
  -- incoming value for each edge into that block (initial value is 8).
  void $ phi' tInt32 entry ns [(int32 8, pre), (ns, sleep), (ns', moar)]
  void $ br entry

  -- Lock acquired: run the critical section...
  setBlock start
  r     <- action
  void $ br end

  -- ...then hand the lock back by swapping in the 'unlocked' value
  setBlock end
  void $ instr (AtomicRMW numType NonVolatile Exchange slot unlocked (CrossThread, AcquireRelease))
  __threadfence_grid   -- TODO: why is this required?
  void $ br exit

  setBlock exit
  return r
  where
    -- Values stored in a lock slot
    locked    = integral integralType 1
    unlocked  = integral integralType 0
    unlocked' = ir TypeWord32 unlocked

    -- Base pointer of the array of lock slots
    slots     = asPtr defaultAddrSpace (op integralType (irArrayData barriers))

    -- Type of, and constants for, the sleep duration
    tInt32    = TupRsingle scalarTypeInt32
    int32 n   = ir TypeInt32 (integral integralType n)


-- Atomically execute the critical section only when the lock at the given array
-- index is obtained. The thread spins waiting for the lock to be released and
-- there is no backoff strategy in case the lock is contended.
--
-- The canonical implementation of a spin-lock looks like this:
--
-- > do {
-- >   old = atomic_exchange(&lock[i], 1);
-- > } while (old == 1);
-- >
-- > /* critical section */
-- >
-- > atomic_exchange(&lock[i], 0);
--
-- The initial loop repeatedly attempts to take the lock by writing a 1 (locked)
-- into the lock slot. Once the 'old' state of the lock returns 0 (unlocked),
-- then we just acquired the lock and the atomic section can be computed.
-- Finally, the lock is released by writing 0 back to the lock slot.
--
-- However, there is a complication with CUDA devices because all threads in
-- a warp must execute in lockstep (with predicated execution). In the above
-- setup, once a thread acquires a lock, then it will be disabled and stop
-- participating in the loop, waiting for all other threads (to acquire their
-- locks) before continuing program execution. If two threads in the same warp
-- attempt to acquire the same lock, then once the lock is acquired by one
-- thread then it will sit idle waiting while the second thread spins attempting
-- to grab a lock that will never be released because the first thread (which
-- holds the lock) can not make progress. DEADLOCK.
--
-- To prevent this situation we must invert the algorithm so that threads can
-- always make progress, until each thread in the warp has committed its
-- result.
--
-- > done = 0;
-- > do {
-- >   if ( atomic_exchange(&lock[i], 1) == 0 ) {
-- >
-- >     /* critical section */
-- >
-- >     done = 1;
-- >     atomic_exchange(&lock[i], 0);
-- >   }
-- > } while ( done == 0 );
--
atomically_warp
    :: IRArray (Vector Word32)
    -> Operands Int
    -> CodeGen PTX a
    -> CodeGen PTX a
atomically_warp barriers i action = do
  -- Basic blocks making up the spin loop
  entry <- newBlock "spinlock.entry"
  start <- newBlock "spinlock.critical-start"
  end   <- newBlock "spinlock.critical-end"
  exit  <- newBlock "spinlock.exit"

  -- Address of the lock slot at index 'i'
  slot  <- instr' (GetElementPtr slots [op integralType i])
  void $ br entry

  -- Try to take the lock by swapping in the 'locked' value. If the slot
  -- previously held 'unlocked' then the lock is now ours and we run the
  -- critical section; otherwise jump straight to the bottom of the loop.
  -- 'skipped' is the block that the "did not get the lock" edge leaves from.
  setBlock entry
  prev     <- instr (AtomicRMW numType NonVolatile Exchange slot locked (CrossThread, Acquire))
  acquired <- A.eq singleType prev unlocked'
  skipped  <- cbr acquired start end

  -- Lock acquired: run the critical section and immediately release the lock
  -- again. 'finished' is the block that the "completed" edge leaves from.
  setBlock start
  r        <- action
  void $ instr (AtomicRMW numType NonVolatile Exchange slot unlocked (CrossThread, AcquireRelease))
  finished <- br end

  -- Bottom of the loop. Which incoming edge a thread arrived on tells us
  -- whether it has completed its critical section: True when coming from the
  -- critical section, False when the lock was not obtained this round.
  setBlock end
  res   <- freshName
  done  <- phi1 end res [(boolean True, finished), (boolean False, skipped)]

  -- NOTE(review): the previous comment here described this step as a memory
  -- fence committing the lock state; the call made is '__syncthreads' --
  -- confirm which is intended.
  __syncthreads
  void $ cbr (OP_Bool done) exit entry

  -- NOTE(review): 'r' is produced in the critical-section block, which is not
  -- on every path into 'end' -- presumably callers only use a result carrying
  -- no operands (e.g. unit); confirm before returning real values from here.
  setBlock exit
  return r
  where
    -- Values stored in a lock slot
    locked    = integral integralType 1
    unlocked  = integral integralType 0
    unlocked' = ir TypeWord32 unlocked

    -- Base pointer of the array of lock slots
    slots     = asPtr defaultAddrSpace (op integralType (irArrayData barriers))