module Futhark.CodeGen.ImpGen.Multicore.Base
  ( extractAllocations,
    compileThreadResult,
    Locks (..),
    HostEnv (..),
    AtomicBinOp,
    MulticoreGen,
    decideScheduling,
    decideScheduling',
    groupResultArrays,
    renameSegBinOp,
    freeParams,
    renameHistOpLambda,
    atomicUpdateLocking,
    AtomicUpdate (..),
    Locking (..),
    getSpace,
    getIterationDomain,
    getReturnParams,
    segOpString,
  )
where

import Control.Monad
import Data.Bifunctor
import Data.List (elemIndex, find)
import qualified Data.Map as M
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Error
import Futhark.IR.MCMem
import Futhark.Transform.Rename
import Futhark.Util (maybeNth)
import Prelude hiding (quot, rem)

-- | Is there an atomic t'BinOp' corresponding to this t'BinOp'?
--
-- When there is, the result builds an 'Imp.AtomicOp' from two names,
-- an element offset and an operand expression.  In
-- 'atomicUpdateLocking' these are supplied as the variable receiving
-- the old value, the target array, the bucket offset and the operand.
type AtomicBinOp =
  BinOp ->
  Maybe
    ( VName ->
      VName ->
      Imp.Count Imp.Elements (Imp.TExp Int32) ->
      Imp.Exp ->
      Imp.AtomicOp
    )

-- | Information about the locks available for accumulators.
data Locks = Locks
  { -- | Name of the array that holds the locks.
    locksArray :: VName,
    -- | How many locks are available (presumably the number of
    -- elements in 'locksArray' -- confirm at the allocation site).
    locksCount :: Int
  }

-- | The environment of the 'MulticoreGen' code generation monad.
data HostEnv = HostEnv
  { -- | Lookup of atomic counterparts for binary operators.
    hostAtomics :: AtomicBinOp,
    -- | Available 'Locks', keyed by name (presumably the name of the
    -- accumulator they protect -- confirm against callers).
    hostLocks :: M.Map VName Locks
  }

-- | Code generation monad of the multicore backend: 'ImpM' over the
-- 'MCMem' representation, with 'HostEnv' as its environment and
-- 'Imp.Multicore' as the backend-specific operation type.
type MulticoreGen =
  ImpM MCMem HostEnv Imp.Multicore

-- | A short lower-case label naming the kind of 'SegOp'.
segOpString :: SegOp () MCMem -> MulticoreGen String
segOpString op =
  case op of
    SegMap {} -> pure "segmap"
    SegRed {} -> pure "segred"
    SegScan {} -> pure "segscan"
    SegHist {} -> pure "seghist"

-- | Turn an array variable into a memory parameter for the memory
-- block that backs it.  Fails with 'error' if the name is not bound
-- to an array.
--
-- NOTE(review): the parameter is always placed in 'DefaultSpace',
-- regardless of the space the memory block was allocated in.
arrParam :: VName -> MulticoreGen Imp.Param
arrParam arr = lookupVar arr >>= toMemParam
  where
    toMemParam (ArrayVar _ (ArrayEntry (MemLocation mem _ _) _)) =
      pure $ Imp.MemParam mem DefaultSpace
    toMemParam _ =
      error $ "arrParam: could not handle array " ++ show arr

-- | The function parameters needed to pass a variable of the given type:
--
-- * a scalar becomes a 'Imp.ScalarParam';
--
-- * a memory block becomes a 'Imp.MemParam' in its own space;
--
-- * an array becomes the memory parameter of its backing block (see
--   'arrParam').
--
-- Accumulators are not supported and cause an 'error'.
toParam :: VName -> TypeBase shape u -> MulticoreGen [Imp.Param]
toParam name t =
  case t of
    Prim pt -> pure [Imp.ScalarParam name pt]
    Mem space -> pure [Imp.MemParam name space]
    Array {} -> (: []) <$> arrParam name
    Acc {} -> error $ "toParam Acc: " ++ pretty name

-- | The 'SegSpace' a 'SegOp' iterates over.
getSpace :: SegOp () MCMem -> SegSpace
getSpace op =
  case op of
    SegHist _ space _ _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegMap _ space _ _ -> space

-- | The number of iterations to divide among threads for a 'SegOp'.
--
-- For a 'SegMap' this is the product of all dimensions of the
-- 'SegSpace'.  For the other 'SegOp's, a one-dimensional space is used
-- as-is.  A segmented (multi-dimensional) space is parallelised over
-- the segments only: the innermost dimension runs sequentially and is
-- therefore left out of the product.
--
-- An empty space yields the empty product (1) in every case, instead
-- of crashing on 'init'.
getIterationDomain :: SegOp () MCMem -> SegSpace -> MulticoreGen (Imp.TExp Int64)
getIterationDomain op space =
  return $
    product $
      case op of
        SegMap {} -> dims
        _ -> parallelDims dims
  where
    dims = map (toInt64Exp . snd) $ unSegSpace space

    -- A single dimension is parallelised as a whole.
    parallelDims [d] = [d]
    -- Otherwise drop the innermost (sequential) dimension.  This is a
    -- total replacement for 'init': the empty list maps to itself.
    parallelDims ds = zipWith const ds (drop 1 ds)

-- | Extra parameters through which a 'SegOp' function returns results.
--
-- For a 'SegRed', every element of the pattern is turned into
-- parameters with 'toParam' (scalar results are thus passed by
-- value-result).  All other 'SegOp's get no return parameters.
--
-- NOTE(review): non-scalar pattern elements are converted too; confirm
-- that this is intended.
getReturnParams :: Pattern MCMem -> SegOp () MCMem -> MulticoreGen [Imp.Param]
getReturnParams pat op =
  case op of
    SegRed {} -> do
      types <- mapM lookupType result_names
      fmap concat $ zipWithM toParam result_names types
    _ -> pure []
  where
    result_names = map patElemName (patternElements pat)

-- | Give the lambda of every reduction operator fresh names, keeping
-- commutativity, neutral elements and shape unchanged.
renameSegBinOp :: [SegBinOp MCMem] -> MulticoreGen [SegBinOp MCMem]
renameSegBinOp segbinops = mapM renameOne segbinops
  where
    renameOne (SegBinOp comm lam ne shape) =
      (\lam' -> SegBinOp comm lam' ne shape) <$> renameLambda lam

-- | Store a single thread's result into a pattern element.
--
-- For 'Returns', the returned value is copied into @pe@ at the index
-- given by the 'SegSpace' iteration variables.  Every other
-- 'KernelResult' form is unsupported here and reported as a compiler
-- bug.
compileThreadResult ::
  SegSpace ->
  PatElem MCMem ->
  KernelResult ->
  MulticoreGen ()
compileThreadResult space pe result =
  case result of
    Returns _ what ->
      copyDWIMFix (patElemName pe) thread_index what []
    ConcatReturns {} ->
      compilerBugS "compileThreadResult: ConcatReturn unhandled."
    WriteReturns {} ->
      compilerBugS "compileThreadResult: WriteReturns unhandled."
    TileReturns {} ->
      compilerBugS "compileThreadResult: TileReturns unhandled."
    RegTileReturns {} ->
      compilerBugS "compileThreadResult: RegTileReturns unhandled."
  where
    thread_index = map (Imp.vi64 . fst) (unSegSpace space)

-- | Names free in the code, excluding the given names.
freeVariables :: Imp.Code -> [VName] -> [VName]
freeVariables code names =
  let excluded = namesFromList names
   in namesToList $ freeIn code `namesSubtract` excluded

-- | Parameters for all variables free in the code, except the given
-- names.  Each variable's type is looked up and converted with
-- 'toParam'.
freeParams :: Imp.Code -> [VName] -> MulticoreGen [Imp.Param]
freeParams code names = do
  types <- mapM lookupType free_vars
  params <- zipWithM toParam free_vars types
  pure $ concat params
  where
    free_vars = freeVariables code names

-- | Arrays for storing group results shared between threads.
--
-- For each operator, one array is allocated per lambda return type, in
-- 'DefaultSpace', using the given string as the name.  Each array has
-- shape @[num_threads] ++ operator shape ++ return type shape@.
groupResultArrays ::
  String ->
  SubExp ->
  [SegBinOp MCMem] ->
  MulticoreGen [[VName]]
groupResultArrays s num_threads reds = mapM arraysFor reds
  where
    arraysFor (SegBinOp _ lam _ shape) =
      mapM (allocFor shape) (lambdaReturnType lam)

    allocFor shape t =
      sAllocArray s (elemType t) (Shape [num_threads] <> shape <> arrayShape t) DefaultSpace

-- | Conservative check of whether code can be scheduled statically.
--
-- A reachable 'Imp.While' makes the code unbalanced.  Reachable means
-- through sequencing, 'Imp.For' bodies, both 'Imp.If' branches,
-- comments, or the body of a nested 'Imp.ParLoop'.  Every other
-- statement counts as balanced.
isLoadBalanced :: Imp.Code -> Bool
isLoadBalanced code =
  case code of
    a Imp.:>>: b -> all isLoadBalanced [a, b]
    Imp.For _ _ body -> isLoadBalanced body
    Imp.If _ tbranch fbranch -> all isLoadBalanced [tbranch, fbranch]
    Imp.Comment _ inner -> isLoadBalanced inner
    Imp.While {} -> False
    Imp.Op (Imp.ParLoop _ _ _ body _ _ _) -> isLoadBalanced body
    _ -> True

-- | Combined commutativity of a list of operators, merged with the
-- 'Monoid' instance of 'Commutativity'.
segBinOpComm' :: [SegBinOp lore] -> Commutativity
segBinOpComm' = foldMap segBinOpComm

-- | Choose a scheduling strategy for the code of a 'SegOp'.
--
-- * 'SegHist' and 'SegScan' are always 'Imp.Static'.
--
-- * 'SegRed' is 'Imp.Static' when the combined commutativity of its
--   operators is 'Noncommutative'.  Otherwise it defers to
--   'decideScheduling'.
--
-- * 'SegMap' always defers to 'decideScheduling'.
decideScheduling' :: SegOp () lore -> Imp.Code -> Imp.Scheduling
decideScheduling' op code =
  case op of
    SegHist {} -> Imp.Static
    SegScan {} -> Imp.Static
    SegRed _ _ reds _ _ ->
      case segBinOpComm' reds of
        Noncommutative -> Imp.Static
        Commutative -> decideScheduling code
    SegMap {} -> decideScheduling code

-- | Static scheduling for load-balanced code (see 'isLoadBalanced'),
-- dynamic scheduling otherwise.
decideScheduling :: Imp.Code -> Imp.Scheduling
decideScheduling code
  | isLoadBalanced code = Imp.Static
  | otherwise = Imp.Dynamic

-- | Try to extract invariant allocations.  If we assume that the
-- given 'Imp.Code' is the body of a 'SegOp', then it is always safe
-- to move the immediate allocations to the prebody.
--
-- Returns a pair of (extracted code, remaining code):
--
-- * 'Imp.DeclareMem' is always extracted.
--
-- * 'Imp.Allocate' is extracted only when its size mentions none of
--   the names declared anywhere in the given code.  Otherwise it stays.
--
-- * 'Imp.Free' is dropped from both halves.
--
-- * Sequencing, comments and both branches of an 'Imp.If' are
--   searched.  'Imp.While' and 'Imp.For' are left intact.
--
-- * For a nested 'Imp.ParLoop', allocations are extracted from its body
--   recursively.  Those that can move further out are returned as
--   extracted code.  The rest are placed directly in front of the loop.
--   Names declared by the moved code are removed from the loop's free
--   parameters.
extractAllocations :: Imp.Code -> (Imp.Code, Imp.Code)
extractAllocations segop_code = go segop_code
  where
    -- Everything declared somewhere inside the SegOp body; an
    -- allocation whose size depends on one of these cannot be hoisted.
    declared = Imp.declaredIn segop_code

    hoist c = (c, mempty)
    keep c = (mempty, c)

    go code =
      case code of
        Imp.DeclareMem {} ->
          -- Hoisting declarations out is always safe.
          hoist code
        Imp.Allocate _ size _
          | not $ freeIn size `namesIntersect` declared ->
              hoist code
        x Imp.:>>: y ->
          go x <> go y
        Imp.While {} ->
          keep code
        Imp.For {} ->
          keep code
        Imp.Comment s inner ->
          second (Imp.Comment s) (go inner)
        Imp.Free {} ->
          (mempty, mempty)
        Imp.If cond tcode fcode ->
          let (t_hoisted, tcode') = go tcode
              (f_hoisted, fcode') = go fcode
           in (t_hoisted <> f_hoisted, Imp.If cond tcode' fcode')
        Imp.Op (Imp.ParLoop s i prebody body postbody free info) ->
          let (body_allocs, body') = extractAllocations body
              (outer_allocs, local_allocs) = go body_allocs
              inner_declared = Imp.declaredIn body_allocs
              -- Names now declared outside the loop body are no longer
              -- passed in as free parameters.
              free' =
                filter (not . (`nameIn` inner_declared) . Imp.paramName) free
              loop = Imp.Op $ Imp.ParLoop s i prebody body' postbody free' info
           in (outer_allocs, local_allocs <> loop)
        -- Includes allocations whose size depends on a declared name.
        _ ->
          keep code

-------------------------------
------- SegHist helpers -------
-------------------------------
-- | Give the lambda of every histogram operator fresh names, keeping
-- all other fields of the 'HistOp' unchanged.
renameHistOpLambda :: [HistOp MCMem] -> MulticoreGen [HistOp MCMem]
renameHistOpLambda hist_ops = mapM renameOne hist_ops
  where
    renameOne (HistOp w rf dest neutral shape lam) =
      HistOp w rf dest neutral shape <$> renameLambda lam

-- | Locking strategy used for an atomic update.
data Locking = Locking
  { -- | Name of the array that holds the locks.
    lockingArray :: VName,
    -- | The value indicating that a lock is free.
    lockingIsUnlocked :: Imp.TExp Int32,
    -- | The value written to take a lock.
    lockingToLock :: Imp.TExp Int32,
    -- | The value written to release a lock.
    lockingToUnlock :: Imp.TExp Int32,
    -- | Maps a logical lock index to its physical position in
    -- 'lockingArray'.  It can also be used to shrink the lock array.
    lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64]
  }

-- | A function for generating code for an atomic update.  Assumes
-- that the bucket is in-bounds.
--
-- It takes the arrays to update and the bucket index.  The @lore@
-- and @r@ parameters do not occur on the right-hand side.
type DoAtomicUpdate lore r =
  [VName] ->
  [Imp.TExp Int64] ->
  MulticoreGen ()

-- | The mechanism that will be used for performing the atomic update.
-- It approximates how efficient the update will be.  Constructors are
-- ordered from most to least efficient.
data AtomicUpdate lore r
  = -- | Every operator has a direct atomic primitive.
    AtomicPrim (DoAtomicUpdate lore r)
  | -- | Can be done by efficient swaps.
    AtomicCAS (DoAtomicUpdate lore r)
  | -- | Requires explicit locking.
    AtomicLocking (Locking -> DoAtomicUpdate lore r)

atomicUpdateLocking ::
  AtomicBinOp ->
  Lambda MCMem ->
  AtomicUpdate MCMem ()
atomicUpdateLocking :: AtomicBinOp -> Lambda MCMem -> AtomicUpdate MCMem ()
atomicUpdateLocking AtomicBinOp
atomicBinOp Lambda MCMem
lam
  | Just [(BinOp, PrimType, VName, VName)]
ops_and_ts <- Lambda MCMem -> Maybe [(BinOp, PrimType, VName, VName)]
forall lore.
ASTLore lore =>
Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp Lambda MCMem
lam,
    ((BinOp, PrimType, VName, VName) -> Bool)
-> [(BinOp, PrimType, VName, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(BinOp
_, PrimType
t, VName
_, VName
_) -> Int -> Bool
supportedPrims (Int -> Bool) -> Int -> Bool
forall a b. (a -> b) -> a -> b
$ PrimType -> Int
primBitSize PrimType
t) [(BinOp, PrimType, VName, VName)]
ops_and_ts =
    [(BinOp, PrimType, VName, VName)]
-> DoAtomicUpdate MCMem () -> AtomicUpdate MCMem ()
forall {t :: * -> *} {b} {c} {d} {lore} {r}.
Foldable t =>
t (BinOp, b, c, d)
-> DoAtomicUpdate MCMem () -> AtomicUpdate lore r
primOrCas [(BinOp, PrimType, VName, VName)]
ops_and_ts (DoAtomicUpdate MCMem () -> AtomicUpdate MCMem ())
-> DoAtomicUpdate MCMem () -> AtomicUpdate MCMem ()
forall a b. (a -> b) -> a -> b
$ \[VName]
arrs [TExp Int64]
bucket ->
      -- If the operator is a vectorised binary operator on 32-bit values,
      -- we can use a particularly efficient implementation. If the
      -- operator has an atomic implementation we use that, otherwise it
      -- is still a binary operator which can be implemented by atomic
      -- compare-and-swap if 32 bits.
      [(VName, (BinOp, PrimType, VName, VName))]
-> ((VName, (BinOp, PrimType, VName, VName)) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(BinOp, PrimType, VName, VName)]
-> [(VName, (BinOp, PrimType, VName, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs [(BinOp, PrimType, VName, VName)]
ops_and_ts) (((VName, (BinOp, PrimType, VName, VName)) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((VName, (BinOp, PrimType, VName, VName)) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
a, (BinOp
op, PrimType
t, VName
x, VName
y)) -> do
        -- Common variables.
        TV Any
old <- String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Any)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"old" PrimType
t

        (VName
arr', Space
_a_space, Count Elements (TExp Int64)
bucket_offset) <- VName
-> [TExp Int64]
-> ImpM
     MCMem HostEnv Multicore (VName, Space, Count Elements (TExp Int64))
forall lore r op.
VName
-> [TExp Int64]
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
a [TExp Int64]
bucket

        case VName
-> VName
-> Count Elements (TExp Int32)
-> BinOp
-> Maybe (Exp -> Multicore)
opHasAtomicSupport (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
old) VName
arr' (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32)
-> Count Elements (TExp Int64) -> Count Elements (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count Elements (TExp Int64)
bucket_offset) BinOp
op of
          Just Exp -> Multicore
f -> Multicore -> MulticoreGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (Multicore -> MulticoreGen ()) -> Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ Exp -> Multicore
f (Exp -> Multicore) -> Exp -> Multicore
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
y PrimType
t
          Maybe (Exp -> Multicore)
Nothing ->
            PrimType
-> VName
-> VName
-> [TExp Int64]
-> VName
-> MulticoreGen ()
-> MulticoreGen ()
atomicUpdateCAS PrimType
t VName
a (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
old) [TExp Int64]
bucket VName
x (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
              VName
x VName -> Exp -> MulticoreGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
op (VName -> PrimType -> Exp
Imp.var VName
x PrimType
t) (VName -> PrimType -> Exp
Imp.var VName
y PrimType
t)
  where
    opHasAtomicSupport :: VName
-> VName
-> Count Elements (TExp Int32)
-> BinOp
-> Maybe (Exp -> Multicore)
opHasAtomicSupport VName
old VName
arr' Count Elements (TExp Int32)
bucket' BinOp
bop = do
      let atomic :: (VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp)
-> a -> Multicore
atomic VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp
f = AtomicOp -> Multicore
Imp.Atomic (AtomicOp -> Multicore) -> (a -> AtomicOp) -> a -> Multicore
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp
f VName
old VName
arr' Count Elements (TExp Int32)
bucket'
      (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Exp -> Multicore
forall {a}.
(VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp)
-> a -> Multicore
atomic ((VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
 -> Exp -> Multicore)
-> Maybe
     (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Maybe (Exp -> Multicore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> AtomicBinOp
atomicBinOp BinOp
bop

    primOrCas :: t (BinOp, b, c, d)
-> DoAtomicUpdate MCMem () -> AtomicUpdate lore r
primOrCas t (BinOp, b, c, d)
ops
      | ((BinOp, b, c, d) -> Bool) -> t (BinOp, b, c, d) -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (BinOp, b, c, d) -> Bool
forall {b} {c} {d}. (BinOp, b, c, d) -> Bool
isPrim t (BinOp, b, c, d)
ops = DoAtomicUpdate MCMem () -> AtomicUpdate lore r
forall lore r. DoAtomicUpdate MCMem () -> AtomicUpdate lore r
AtomicPrim
      | Bool
otherwise = DoAtomicUpdate MCMem () -> AtomicUpdate lore r
forall lore r. DoAtomicUpdate MCMem () -> AtomicUpdate lore r
AtomicCAS

    isPrim :: (BinOp, b, c, d) -> Bool
isPrim (BinOp
op, b
_, c
_, d
_) = Maybe
  (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Bool
forall a. Maybe a -> Bool
isJust (Maybe
   (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
 -> Bool)
-> Maybe
     (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Bool
forall a b. (a -> b) -> a -> b
$ AtomicBinOp
atomicBinOp BinOp
op
-- A lambda with exactly two parameters and a single primitive result
-- whose bit width is accepted by 'supportedPrims' is applied through a
-- compare-and-swap retry loop ('atomicUpdateCAS').
atomicUpdateLocking _ op
  | [Prim t] <- lambdaReturnType op,
    [xp, _] <- lambdaParams op,
    supportedPrims (primBitSize t) =
    AtomicCAS $ \[arr] bucket -> do
      -- Scratch scalar that atomicUpdateCAS loads the current value into.
      cur <- dPrim "old" t
      -- The operator body, bound to the first parameter; atomicUpdateCAS
      -- sets that parameter from the current value before each attempt
      -- and reads the result back from it.
      let update = compileBody' [xp] (lambdaBody op)
      atomicUpdateCAS t arr (tvVar cur) bucket (paramName xp) update
-- Fallback for every other operator: serialise updates of a bucket with
-- a lock. Spin until the lock word for the bucket is acquired, then load
-- the current destination values into the accumulator parameters, run the
-- operator body, store the results and release the lock.
atomicUpdateLocking _ op = AtomicLocking $ \locking arrs bucket -> do
  -- Expected lock value handed to each compare-and-swap.
  expected <- dPrim "old" int32
  -- Success flag written by the most recent compare-and-swap; the loop
  -- below keeps spinning while it is zero.
  done <- dPrimVol "continue" int32 (0 :: Imp.TExp Int32)

  -- Correctly index into locks.
  (lock_mem, _lock_space, lock_idx) <-
    fullyIndexArray (lockingArray locking) (lockingMapping locking bucket)

  let -- Atomically replace the lock word by @to@ if it currently equals
      -- @from@, recording whether that succeeded in @done@.
      swapLock from to = do
        expected <-- from
        sOp . Imp.Atomic $
          Imp.AtomicCmpXchg
            int32
            (tvVar expected)
            lock_mem
            (sExt32 <$> lock_idx)
            (tvVar done)
            (untyped to)
      acquireLock = swapLock (0 :: Imp.TExp Int32) (lockingToLock locking)
      -- Even releasing goes through an atomic rather than a plain write,
      -- for memory coherency reasons.
      releaseLock = swapLock (lockingToLock locking) (lockingToUnlock locking)

      -- The caller is assumed to have filled in the array parameters
      -- already; only the leading accumulator parameters (one per
      -- destination array) are loaded from memory here.
      accParams = take (length arrs) (lambdaParams op)
      loadAccs =
        everythingVolatile . sComment "bind lhs" $
          forM_ (zip accParams arrs) $ \(p, dest) ->
            copyDWIMFix (paramName p) [] (Var dest) bucket
      runOp =
        sComment "execute operation" $
          compileBody' accParams (lambdaBody op)
      storeResults =
        everythingVolatile . sComment "update global result" $
          sequence_ $
            zipWith (writeArray bucket) arrs (map (Var . paramName) accParams)

  -- While-loop: Try to insert your value
  sWhile (tvExp done .==. 0) $ do
    acquireLock
    -- Only the thread whose CAS succeeded performs the update.
    sUnless (tvExp done .==. 0) $ do
      dLParams accParams
      loadAccs
      runOp
      storeResults
      releaseLock
  where
    writeArray bucket arr val = copyDWIMFix arr bucket val []

-- | Emit a compare-and-swap retry loop that updates @arr[bucket]@.
--
-- Arguments, in order:
--
-- * @t@ is the element type.
-- * @arr@ is the destination array and @bucket@ the index into it.
-- * @old@ is a scalar of type @t@ that receives the current value.
-- * @x@ is the variable that @do_op@ reads its current-value input from
--   and writes its result into.
--
-- The CAS itself is performed on an integer type of the same bit width as
-- @t@ (see 'toIntegral'). Floating-point results are first converted with
-- the function named by 'getBitConvertFunc'.
atomicUpdateCAS ::
  PrimType ->
  VName ->
  VName ->
  [Imp.TExp Int64] ->
  VName ->
  MulticoreGen () ->
  MulticoreGen ()
atomicUpdateCAS t arr old bucket x do_op = do
  -- Code generation target:
  --
  -- old = d_his[idx];
  -- do {
  --   assumed = old;
  --   x = do_op(assumed, y);
  --   old = atomicCAS(&d_his[idx], assumed, tmp);
  -- } while(assumed != old);
  run_loop <- dPrimV "run_loop" (0 :: Imp.TExp Int32)
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket
  (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket

  -- Integer type of the same width as t, used as the CAS operand type.
  cas_t <- toIntegral $ primBitSize t

  -- Only floating-point values need converting before the CAS; integer
  -- values are passed through unchanged. Look up the conversion function
  -- only in the float case: getBitConvertFunc handles just 32 and 64 bits,
  -- so calling it unconditionally would abort code generation for the
  -- 8- and 16-bit integer types that supportedPrims admits.
  toBits <-
    case t of
      FloatType _ -> do
        (to, _from) <- getBitConvertFunc $ primBitSize t
        pure $ \v -> Imp.FunExp to [v] cas_t
      _ -> pure id

  -- While-loop: Try to insert your value
  sWhile (tvExp run_loop .==. 0) $ do
    x <~~ Imp.var old t
    do_op -- Writes result into x
    sOp $
      Imp.Atomic $
        Imp.AtomicCmpXchg
          cas_t
          old
          arr'
          (sExt32 <$> bucket_offset)
          (tvVar run_loop)
          (toBits (Imp.var x t))

-- | Horizontally fission a lambda that models a binary operator.
--
-- Let @n@ be the number of results. The split succeeds only if every
-- result @i@ meets all of these conditions:
--
-- * it is a variable bound by a statement with a single pattern element
--   of the form @x `op` y@;
-- * @x@ is parameter @i@ and @y@ is parameter @n + i@;
-- * the bound value has a primitive type.
--
-- On success it yields, per result, the operator, its primitive type and
-- the names of the two parameters. Otherwise it returns 'Nothing'.
splitOp :: ASTLore lore => Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp lam = zipWithM splitStm [0 ..] $ bodyResult body
  where
    body = lambdaBody lam
    params = lambdaParams lam
    n = length $ lambdaReturnType lam
    stms = stmsToList $ bodyStms body

    -- @i@ is the result's own position. Searching the result list for the
    -- variable instead would always find its first occurrence, so a
    -- variable returned at several positions would be checked against the
    -- parameters of the wrong position.
    splitStm :: Int -> SubExp -> Maybe (BinOp, PrimType, VName, VName)
    splitStm i (Var res) = do
      Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
        find (([res] ==) . patternNames . stmPattern) stms
      xp <- maybeNth i params
      yp <- maybeNth (n + i) params
      guard $ paramName xp == x
      guard $ paramName yp == y
      Prim t <- Just $ patElemType pe
      return (op, t, paramName xp, paramName yp)
    splitStm _ _ = Nothing

-- | Names of the functions that convert a value of the given bit width to
-- and from its integer bit representation, as a @(to, from)@ pair.
-- 'atomicUpdateCAS' uses the first one for floating-point values.
--
-- Only 32 and 64 bits are handled; any other width is a compiler error.
-- TODO: supporting 8 and 16 bits (and 128) needs additional conversion
-- functions.
getBitConvertFunc :: Int -> MulticoreGen (String, String)
getBitConvertFunc bits =
  case bits of
    32 -> pure ("to_bits32", "from_bits32")
    64 -> pure ("to_bits64", "from_bits64")
    _ -> error $ "number of bytes is not supported " ++ pretty bits

-- | Whether a primitive of the given bit width (8, 16, 32 or 64) may be
-- updated with a compare-and-swap loop.
supportedPrims :: Int -> Bool
supportedPrims bits = bits `elem` [8, 16, 32, 64]

-- | The signed integer type with the given bit width, used as the operand
-- type of a compare-and-swap. Only the widths supported by the GCC (and
-- clang) compilers are accepted: 8, 16, 32 and 64 bits. Any other width
-- is a compiler error.
toIntegral :: Int -> MulticoreGen PrimType
toIntegral bits =
  case lookup bits widths of
    Just int_t -> pure int_t
    Nothing -> error $ "number of bytes is not supported for CAS - " ++ pretty bits
  where
    widths = [(8, int8), (16, int16), (32, int32), (64, int64)]