module Futhark.CodeGen.ImpGen.Multicore.Base
  ( extractAllocations,
    compileThreadResult,
    HostEnv (..),
    AtomicBinOp,
    MulticoreGen,
    decideScheduling,
    decideScheduling',
    groupResultArrays,
    renameSegBinOp,
    freeParams,
    renameHistOpLambda,
    atomicUpdateLocking,
    AtomicUpdate (..),
    Locking (..),
    getSpace,
    getIterationDomain,
    getReturnParams,
    segOpString,
  )
where

import Control.Monad
import Data.Bifunctor
import Data.List (elemIndex, find)
import Data.Maybe
import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Error
import Futhark.IR.MCMem
import Futhark.Transform.Rename
import Futhark.Util (maybeNth)
import Prelude hiding (quot, rem)

-- | Is there an atomic t'BinOp' corresponding to this t'BinOp'?
--
-- Returns 'Nothing' when the operator has no atomic counterpart.
-- Otherwise the result builds an 'Imp.AtomicOp' from (in the order used
-- by 'atomicUpdateLocking') the variable receiving the old value, the
-- array, the element offset into it, and the operand expression.
type AtomicBinOp =
  BinOp ->
  Maybe (VName -> VName -> Imp.Count Imp.Elements (Imp.TExp Int32) -> Imp.Exp -> Imp.AtomicOp)

-- | The backend-specific environment (the @r@ parameter of 'ImpM') used
-- by 'MulticoreGen'.
newtype HostEnv = HostEnv
  { -- | Which binary operators the target can perform atomically.
    hostAtomics :: AtomicBinOp
  }

-- | The ImpGen monad specialised to the multicore backend.
type MulticoreGen = ImpM MCMem HostEnv Imp.Multicore

-- | A short lowercase name identifying which kind of 'SegOp' this is.
segOpString :: SegOp () MCMem -> MulticoreGen String
segOpString op =
  pure $ case op of
    SegMap {} -> "segmap"
    SegRed {} -> "segred"
    SegScan {} -> "segscan"
    SegHist {} -> "seghist"

-- | Describe the variable @name@, of the given type, as an 'Imp.Param'.
--
-- Scalars become scalar parameters and memory blocks become memory
-- parameters.  An array is represented by the memory block it lives in
-- (always passed in 'DefaultSpace'), looked up from the variable's
-- entry; any other entry for an array-typed name is an internal error.
toParam :: VName -> TypeBase shape u -> MulticoreGen Imp.Param
toParam name t =
  case t of
    Prim pt -> pure $ Imp.ScalarParam name pt
    Mem space -> pure $ Imp.MemParam name space
    Array {} -> do
      entry <- lookupVar name
      case entry of
        ArrayVar _ (ArrayEntry (MemLocation mem _ _) _) ->
          pure $ Imp.MemParam mem DefaultSpace
        _ -> error $ "[toParam] Could not handle array for " ++ show name

-- | The iteration space of any kind of 'SegOp'.
getSpace :: SegOp () MCMem -> SegSpace
getSpace op =
  case op of
    SegMap _ space _ _ -> space
    SegRed _ space _ _ _ -> space
    SegScan _ space _ _ _ -> space
    SegHist _ space _ _ _ -> space

-- | The number of iterations over which a 'SegOp' is parallelised.
--
-- For a 'SegMap' this is the product of every dimension of @space@.  For
-- the other operators, a multi-dimensional space is segmented: the
-- innermost dimension is executed sequentially, so only the outer
-- dimensions are counted.  A one-dimensional space counts its single
-- dimension.  An empty space yields 1 (the empty product) instead of
-- crashing in a partial 'init'.
getIterationDomain :: SegOp () MCMem -> SegSpace -> MulticoreGen (Imp.TExp Int64)
getIterationDomain op space =
  pure . product $
    case (op, dims) of
      (SegMap {}, _) -> dims
      (_, [_]) -> dims
      -- A segmented SegOp is over the segments, so we drop the last
      -- dimension, which is executed sequentially.
      _ -> dropInnermost dims
  where
    dims = map (toInt64Exp . snd) (unSegSpace space)
    -- Total version of 'init': the empty list maps to itself rather
    -- than throwing.
    dropInnermost xs = take (length xs - 1) xs

-- | Parameters through which the results of a 'SegRed' are passed back
-- (a call by value-result in the generated segop function).
--
-- For a 'SegRed', every element of @pat@ becomes a parameter via
-- 'toParam' (scalars directly, arrays through their memory block), in
-- pattern order.  Every other 'SegOp' gets no return parameters.
getReturnParams :: Pattern MCMem -> SegOp () MCMem -> MulticoreGen [Imp.Param]
getReturnParams pat SegRed {} =
  traverse lookupType retvals >>= zipWithM toParam retvals
  where
    retvals = map patElemName (patternElements pat)
getReturnParams _ _ = pure []

-- | Give the lambda of every operator fresh names, leaving the
-- commutativity, neutral elements and shape untouched.
renameSegBinOp :: [SegBinOp MCMem] -> MulticoreGen [SegBinOp MCMem]
renameSegBinOp segbinops = traverse renameOne segbinops
  where
    renameOne (SegBinOp comm lam ne shape) =
      (\lam' -> SegBinOp comm lam' ne shape) <$> renameLambda lam

-- | Write a single thread's kernel result into the pattern element @pe@,
-- indexed by the iteration variables of @space@.
--
-- Only 'Returns' is supported; every other kind of 'KernelResult' is
-- reported as a compiler bug.
compileThreadResult ::
  SegSpace ->
  PatElem MCMem ->
  KernelResult ->
  MulticoreGen ()
compileThreadResult space pe res =
  case res of
    Returns _ what ->
      let indices = [Imp.vi64 gtid | (gtid, _) <- unSegSpace space]
       in copyDWIMFix (patElemName pe) indices what []
    ConcatReturns {} ->
      compilerBugS "compileThreadResult: ConcatReturn unhandled."
    WriteReturns {} ->
      compilerBugS "compileThreadResult: WriteReturns unhandled."
    TileReturns {} ->
      compilerBugS "compileThreadResult: TileReturns unhandled."
    RegTileReturns {} ->
      compilerBugS "compileThreadResult: RegTileReturns unhandled."

-- | The names free in @code@, excluding those listed in @names@.
freeVariables :: Imp.Code -> [VName] -> [VName]
freeVariables code names =
  namesToList (namesSubtract (freeIn code) (namesFromList names))

-- | One 'Imp.Param' (built by 'toParam' from the variable's type) for
-- every variable free in @code@ that is not among @names@.
freeParams :: Imp.Code -> [VName] -> MulticoreGen [Imp.Param]
freeParams code names =
  traverse lookupType fvs >>= zipWithM toParam fvs
  where
    fvs = freeVariables code names

-- | Arrays for storing group results shared between threads.
--
-- For each operator in @reds@ and each return type @t@ of its lambda,
-- allocates an array in 'DefaultSpace' (name hint @s@) with @t@'s
-- element type and shape @[num_threads] ++ operator shape ++ shape of t@.
-- The result is grouped per operator.
groupResultArrays ::
  String ->
  SubExp ->
  [SegBinOp MCMem] ->
  MulticoreGen [[VName]]
groupResultArrays s num_threads reds = mapM arraysFor reds
  where
    arraysFor (SegBinOp _ lam _ op_shape) =
      mapM (allocFor op_shape) (lambdaReturnType lam)
    allocFor op_shape t =
      sAllocArray
        s
        (elemType t)
        (Shape [num_threads] <> op_shape <> arrayShape t)
        DefaultSpace

-- | Whether @code@ is considered load balanced: it is 'False' as soon as
-- a 'Imp.While' loop appears anywhere in the inspected structure
-- (sequences, for-loops, branches, comments and the body of nested
-- 'Imp.ParLoop's), and 'True' otherwise.
isLoadBalanced :: Imp.Code -> Bool
isLoadBalanced code =
  case code of
    a Imp.:>>: b -> all isLoadBalanced [a, b]
    Imp.For _ _ body -> isLoadBalanced body
    Imp.If _ tbranch fbranch -> all isLoadBalanced [tbranch, fbranch]
    Imp.Comment _ body -> isLoadBalanced body
    Imp.While {} -> False
    Imp.Op (Imp.ParLoop _ _ _ body _ _ _) -> isLoadBalanced body
    _ -> True

-- | The commutativity of a group of operators, combining each operator's
-- 'segBinOpComm' with the 'Monoid' instance of 'Commutativity' (an
-- empty list yields 'mempty').
segBinOpComm' :: [SegBinOp lore] -> Commutativity
segBinOpComm' ops = foldMap segBinOpComm ops

-- | Choose a scheduling strategy for a 'SegOp' whose body is @code@.
--
-- Histograms, scans and non-commutative reductions are always scheduled
-- statically; maps and commutative reductions defer to
-- 'decideScheduling'.
decideScheduling' :: SegOp () lore -> Imp.Code -> Imp.Scheduling
decideScheduling' op code =
  case op of
    SegMap {} -> decideScheduling code
    SegRed _ _ reds _ _ ->
      case segBinOpComm' reds of
        Commutative -> decideScheduling code
        Noncommutative -> Imp.Static
    SegScan {} -> Imp.Static
    SegHist {} -> Imp.Static

-- | Static scheduling for load-balanced code (see 'isLoadBalanced'),
-- dynamic scheduling otherwise.
decideScheduling :: Imp.Code -> Imp.Scheduling
decideScheduling code
  | isLoadBalanced code = Imp.Static
  | otherwise = Imp.Dynamic

-- | Try to extract invariant allocations.  If we assume that the
-- given 'Imp.Code' is the body of a 'SegOp', then it is always safe
-- to move the immediate allocations to the prebody.
--
-- Returns @(hoisted, remaining)@:
--
-- * memory declarations are always hoisted;
-- * an allocation is hoisted only when its size mentions no name
--   declared inside @segop_code@;
-- * loops are left in place in their entirety;
-- * 'Imp.Free' statements are dropped;
-- * nested 'Imp.ParLoop' bodies are processed recursively, and free
--   parameters declared by the extracted allocations are removed from
--   the loop.
extractAllocations :: Imp.Code -> (Imp.Code, Imp.Code)
extractAllocations segop_code = hoist segop_code
  where
    declaredInside = Imp.declaredIn segop_code

    -- Hoisting declarations out is always safe.
    hoist (Imp.DeclareMem name space) =
      (Imp.DeclareMem name space, mempty)
    hoist (Imp.Allocate name size space)
      | not (freeIn size `namesIntersect` declaredInside) =
        (Imp.Allocate name size space, mempty)
    hoist (x Imp.:>>: y) =
      hoist x <> hoist y
    -- Nothing is hoisted out of a loop.
    hoist loop@Imp.While {} =
      (mempty, loop)
    hoist loop@Imp.For {} =
      (mempty, loop)
    hoist (Imp.Comment s inner) =
      second (Imp.Comment s) (hoist inner)
    hoist Imp.Free {} =
      (mempty, mempty)
    hoist (Imp.If cond tcode fcode) =
      let (t_allocs, tcode') = hoist tcode
          (f_allocs, fcode') = hoist fcode
       in (t_allocs <> f_allocs, Imp.If cond tcode' fcode')
    hoist (Imp.Op (Imp.ParLoop s i prebody body postbody free info)) =
      let (body_allocs, body') = extractAllocations body
          (outer_allocs, local_allocs) = hoist body_allocs
          body_declared = Imp.declaredIn body_allocs
          -- Names now declared by the extracted allocations are no
          -- longer passed in as free parameters.
          free' =
            [ p
              | p <- free,
                not (Imp.paramName p `nameIn` body_declared)
            ]
          loop = Imp.Op (Imp.ParLoop s i prebody body' postbody free' info)
       in (outer_allocs, local_allocs <> loop)
    hoist other =
      (mempty, other)

-------------------------------
------- SegHist helpers -------
-------------------------------
-- | Give the lambda of every histogram operator fresh names, leaving all
-- other fields untouched.
renameHistOpLambda :: [HistOp MCMem] -> MulticoreGen [HistOp MCMem]
renameHistOpLambda hist_ops = traverse renameOne hist_ops
  where
    renameOne (HistOp w rf dest neutral shape lam) =
      HistOp w rf dest neutral shape <$> renameLambda lam

-- | Locking strategy used for an atomic update.
data Locking = Locking
  { -- | Array containing the lock.
    lockingArray :: VName,
    -- | Value for us to consider the lock free.
    lockingIsUnlocked :: Imp.TExp Int32,
    -- | What to write when we lock it.
    lockingToLock :: Imp.TExp Int32,
    -- | What to write when we unlock it.
    lockingToUnlock :: Imp.TExp Int32,
    -- | A transformation from the logical lock index to the
    -- physical position in the array.  This can also be used
    -- to make the lock array smaller.
    lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64]
  }

-- | A function for generating code for an atomic update.  Assumes
-- that the bucket is in-bounds.
--
-- The arguments are the destination arrays and the bucket index into
-- them.  The @lore@ and @r@ parameters are not used on the right-hand
-- side.
type DoAtomicUpdate lore r =
  [VName] -> [Imp.TExp Int64] -> MulticoreGen ()

-- | The mechanism that will be used for performing the atomic update.
-- Approximates how efficient it will be.  Ordered from most to least
-- efficient.
data AtomicUpdate lore r
  = -- | Every operator has a direct atomic implementation (see
    -- 'AtomicBinOp').
    AtomicPrim (DoAtomicUpdate lore r)
  | -- | Can be done by efficient swaps.
    AtomicCAS (DoAtomicUpdate lore r)
  | -- | Requires explicit locking.
    AtomicLocking (Locking -> DoAtomicUpdate lore r)

atomicUpdateLocking ::
  AtomicBinOp ->
  Lambda MCMem ->
  AtomicUpdate MCMem ()
atomicUpdateLocking :: AtomicBinOp -> Lambda MCMem -> AtomicUpdate MCMem ()
atomicUpdateLocking AtomicBinOp
atomicBinOp Lambda MCMem
lam
  | Just [(BinOp, PrimType, VName, VName)]
ops_and_ts <- Lambda MCMem -> Maybe [(BinOp, PrimType, VName, VName)]
forall lore.
ASTLore lore =>
Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp Lambda MCMem
lam,
    ((BinOp, PrimType, VName, VName) -> Bool)
-> [(BinOp, PrimType, VName, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(BinOp
_, PrimType
t, VName
_, VName
_) -> Int -> Bool
supportedPrims (Int -> Bool) -> Int -> Bool
forall a b. (a -> b) -> a -> b
$ PrimType -> Int
primBitSize PrimType
t) [(BinOp, PrimType, VName, VName)]
ops_and_ts =
    [(BinOp, PrimType, VName, VName)]
-> DoAtomicUpdate MCMem () -> AtomicUpdate MCMem ()
forall (t :: * -> *) b c d lore r.
Foldable t =>
t (BinOp, b, c, d)
-> DoAtomicUpdate MCMem () -> AtomicUpdate lore r
primOrCas [(BinOp, PrimType, VName, VName)]
ops_and_ts (DoAtomicUpdate MCMem () -> AtomicUpdate MCMem ())
-> DoAtomicUpdate MCMem () -> AtomicUpdate MCMem ()
forall a b. (a -> b) -> a -> b
$ \[VName]
arrs [TExp Int64]
bucket ->
      -- If the operator is a vectorised binary operator on 32-bit values,
      -- we can use a particularly efficient implementation. If the
      -- operator has an atomic implementation we use that, otherwise it
      -- is still a binary operator which can be implemented by atomic
      -- compare-and-swap if 32 bits.
      [(VName, (BinOp, PrimType, VName, VName))]
-> ((VName, (BinOp, PrimType, VName, VName)) -> MulticoreGen ())
-> MulticoreGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(BinOp, PrimType, VName, VName)]
-> [(VName, (BinOp, PrimType, VName, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs [(BinOp, PrimType, VName, VName)]
ops_and_ts) (((VName, (BinOp, PrimType, VName, VName)) -> MulticoreGen ())
 -> MulticoreGen ())
-> ((VName, (BinOp, PrimType, VName, VName)) -> MulticoreGen ())
-> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
a, (BinOp
op, PrimType
t, VName
x, VName
y)) -> do
        -- Common variables.
        TV Any
old <- String -> PrimType -> ImpM MCMem HostEnv Multicore (TV Any)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"old" PrimType
t

        (VName
arr', Space
_a_space, Count Elements (TExp Int64)
bucket_offset) <- VName
-> [TExp Int64]
-> ImpM
     MCMem HostEnv Multicore (VName, Space, Count Elements (TExp Int64))
forall lore r op.
VName
-> [TExp Int64]
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
a [TExp Int64]
bucket

        case VName
-> VName
-> Count Elements (TExp Int32)
-> BinOp
-> Maybe (Exp -> Multicore)
opHasAtomicSupport (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
old) VName
arr' (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TExp Int64 -> TExp Int32)
-> Count Elements (TExp Int64) -> Count Elements (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Count Elements (TExp Int64)
bucket_offset) BinOp
op of
          Just Exp -> Multicore
f -> Multicore -> MulticoreGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (Multicore -> MulticoreGen ()) -> Multicore -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$ Exp -> Multicore
f (Exp -> Multicore) -> Exp -> Multicore
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
y PrimType
t
          Maybe (Exp -> Multicore)
Nothing ->
            PrimType
-> VName
-> VName
-> [TExp Int64]
-> VName
-> MulticoreGen ()
-> MulticoreGen ()
atomicUpdateCAS PrimType
t VName
a (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
old) [TExp Int64]
bucket VName
x (MulticoreGen () -> MulticoreGen ())
-> MulticoreGen () -> MulticoreGen ()
forall a b. (a -> b) -> a -> b
$
              VName
x VName -> Exp -> MulticoreGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
op (VName -> PrimType -> Exp
Imp.var VName
x PrimType
t) (VName -> PrimType -> Exp
Imp.var VName
y PrimType
t)
  where
    opHasAtomicSupport :: VName
-> VName
-> Count Elements (TExp Int32)
-> BinOp
-> Maybe (Exp -> Multicore)
opHasAtomicSupport VName
old VName
arr' Count Elements (TExp Int32)
bucket' BinOp
bop = do
      let atomic :: (VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp)
-> a -> Multicore
atomic VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp
f = AtomicOp -> Multicore
Imp.Atomic (AtomicOp -> Multicore) -> (a -> AtomicOp) -> a -> Multicore
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp
f VName
old VName
arr' Count Elements (TExp Int32)
bucket'
      (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Exp -> Multicore
forall a.
(VName -> VName -> Count Elements (TExp Int32) -> a -> AtomicOp)
-> a -> Multicore
atomic ((VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
 -> Exp -> Multicore)
-> Maybe
     (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Maybe (Exp -> Multicore)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> AtomicBinOp
atomicBinOp BinOp
bop

    primOrCas :: t (BinOp, b, c, d)
-> DoAtomicUpdate MCMem () -> AtomicUpdate lore r
primOrCas t (BinOp, b, c, d)
ops
      | ((BinOp, b, c, d) -> Bool) -> t (BinOp, b, c, d) -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (BinOp, b, c, d) -> Bool
forall b c d. (BinOp, b, c, d) -> Bool
isPrim t (BinOp, b, c, d)
ops = DoAtomicUpdate MCMem () -> AtomicUpdate lore r
forall lore r. DoAtomicUpdate MCMem () -> AtomicUpdate lore r
AtomicPrim
      | Bool
otherwise = DoAtomicUpdate MCMem () -> AtomicUpdate lore r
forall lore r. DoAtomicUpdate MCMem () -> AtomicUpdate lore r
AtomicCAS

    isPrim :: (BinOp, b, c, d) -> Bool
isPrim (BinOp
op, b
_, c
_, d
_) = Maybe
  (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Bool
forall a. Maybe a -> Bool
isJust (Maybe
   (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
 -> Bool)
-> Maybe
     (VName -> VName -> Count Elements (TExp Int32) -> Exp -> AtomicOp)
-> Bool
forall a b. (a -> b) -> a -> b
$ AtomicBinOp
atomicBinOp BinOp
op
-- A single-result operator on a primitive type whose bit width is
-- accepted by 'supportedPrims' becomes a compare-and-swap loop: the
-- operator body is compiled with its first parameter standing in for
-- the value being updated, and 'atomicUpdateCAS' retries it until the
-- exchange succeeds.
atomicUpdateLocking _ lam
  | [Prim t] <- lambdaReturnType lam,
    [acc_p, _] <- lambdaParams lam,
    supportedPrims (primBitSize t) =
    AtomicCAS $ \[arr] bucket -> do
      old <- dPrim "old" t
      let update = compileBody' [acc_p] $ lambdaBody lam
      atomicUpdateCAS t arr (tvVar old) bucket (paramName acc_p) update
-- General fallback: protect the update with a lock.  A compare-exchange
-- on the lock word is retried until it reports success; the lock holder
-- then copies the current array values into the accumulator parameters,
-- runs the operator body, writes the results back, and releases the
-- lock with another compare-exchange.
atomicUpdateLocking _ lam = AtomicLocking $ \locking arrs bucket -> do
  old <- dPrim "old" int32
  continue <- dPrimVol "continue" int32 (0 :: Imp.TExp Int32)

  -- Correctly index into locks.
  (locks', _locks_space, locks_offset) <-
    fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket

  -- Critical section.  Both taking and releasing the lock go through an
  -- atomic compare-exchange on the lock word; even the release is done
  -- with an atomic rather than a simple write, for memory coherency
  -- reasons.  The outcome of the exchange is stored in 'continue', which
  -- is read below as "lock acquired".
  let lockCas expected desired = do
        old <-- expected
        sOp . Imp.Atomic $
          Imp.AtomicCmpXchg
            int32
            (tvVar old)
            locks'
            (sExt32 <$> locks_offset)
            (tvVar continue)
            (untyped desired)
      try_acquire_lock = lockCas (0 :: Imp.TExp Int32) (lockingToLock locking)
      release_lock = lockCas (lockingToLock locking) (lockingToUnlock locking)

  -- Preparing parameters. It is assumed that the caller has already
  -- filled the arr_params. We copy the current value to the
  -- accumulator parameters.
  let acc_params = take (length arrs) $ lambdaParams lam
      bind_acc_params =
        everythingVolatile . sComment "bind lhs" $
          zipWithM_
            (\acc_p arr -> copyDWIMFix (paramName acc_p) [] (Var arr) bucket)
            acc_params
            arrs
      op_body =
        sComment "execute operation" $ compileBody' acc_params (lambdaBody lam)
      do_hist =
        everythingVolatile . sComment "update global result" $
          zipWithM_
            (\arr acc_p -> copyDWIMFix arr bucket (Var (paramName acc_p)) [])
            arrs
            acc_params

  -- While-loop: Try to insert your value
  sWhile (tvExp continue .==. 0) $ do
    try_acquire_lock
    sUnless (tvExp continue .==. 0) $ do
      dLParams acc_params
      bind_acc_params
      op_body
      do_hist
      release_lock

-- | Emit a compare-and-swap retry loop that updates the element of @arr@
-- at index @bucket@ by running @do_op@.
--
-- * @t@ is the element type.  The exchange itself is performed on the
--   integer type of the same bit width (see 'toIntegral').
-- * @old@ is seeded with the current array element and is passed as the
--   expected value of every exchange.  Whether it is refreshed on a
--   failed exchange depends on the backend's @AtomicCmpXchg@ semantics
--   (NOTE(review): confirm).
-- * @x@ is set to @old@ at the start of each iteration.  @do_op@ must
--   leave the new value in @x@.
-- * The loop runs while @run_loop@ is 0.  The exchange writes its outcome
--   into @run_loop@.
--
-- Floating-point values are reinterpreted as bits via the functions
-- named by 'getBitConvertFunc'.  Other types are exchanged unchanged.
atomicUpdateCAS ::
  PrimType ->
  VName ->
  VName ->
  [Imp.TExp Int64] ->
  VName ->
  MulticoreGen () ->
  MulticoreGen ()
atomicUpdateCAS t arr old bucket x do_op = do
  -- Code generation target:
  --
  -- old = d_his[idx];
  -- do {
  --   assumed = old;
  --   x = do_op(assumed, y);
  --   old = atomicCAS(&d_his[idx], assumed, tmp);
  -- } while(assumed != old);
  run_loop <- dPrimV "run_loop" (0 :: Imp.TExp Int32)
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket
  (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket

  -- The exchange operates on an integer of the same width as t.
  bytes <- toIntegral $ primBitSize t

  -- Only floating-point values need converting to their bit pattern
  -- before the exchange.  Asking for conversion functions for other
  -- types is wrong: 'getBitConvertFunc' only handles 32 and 64 bits, so
  -- calling it unconditionally made the 8- and 16-bit widths that
  -- 'supportedPrims' accepts fail at code generation time.
  toBits <-
    case t of
      FloatType _ -> do
        (to, _from) <- getBitConvertFunc $ primBitSize t
        return $ \v -> Imp.FunExp to [v] bytes
      _ -> return id

  -- While-loop: Try to insert your value
  sWhile (tvExp run_loop .==. 0) $ do
    x <~~ Imp.var old t
    do_op -- Writes result into x
    sOp $
      Imp.Atomic $
        Imp.AtomicCmpXchg
          bytes
          old
          arr'
          (sExt32 <$> bucket_offset)
          (tvVar run_loop)
          (toBits (Imp.var x t))

-- | Horizontally fission a lambda that models a binary operator.
--
-- Succeeds only when every result of the lambda meets all of these
-- conditions:
--
-- * it is a variable @res@ bound by a statement @res = x \`op\` y@;
-- * @x@ and @y@ are the lambda's @i@'th and @(n+i)@'th parameters, where
--   @i@ is the position of @res@ among the results and @n@ is the number
--   of results;
-- * the bound value has a primitive type @t@.
--
-- Each such result contributes @(op, t, x, y)@.  Any other shape yields
-- 'Nothing'.
splitOp :: ASTLore lore => Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp lam = traverse splitRes results
  where
    body = lambdaBody lam
    results = bodyResult body
    params = lambdaParams lam
    num_results = length $ lambdaReturnType lam

    -- The statement whose pattern binds exactly the given name, if any.
    bindingOf v =
      find ((== [v]) . patternNames . stmPattern) $ stmsToList $ bodyStms body

    splitRes (Var res) = do
      Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <- bindingOf res
      i <- elemIndex (Var res) results
      xp <- maybeNth i params
      yp <- maybeNth (num_results + i) params
      guard $ paramName xp == x && paramName yp == y
      case patElemType pe of
        Prim t -> Just (op, t, paramName xp, paramName yp)
        _ -> Nothing
    splitRes _ = Nothing

-- | Names of the functions that convert a value of the given bit width
-- to and from its bit pattern, as a pair @(to_bits, from_bits)@.  Only
-- 32 and 64 bits are handled; any other width is a fatal error.
--
-- TODO: supporting 8 and 16 bits (and 128) requires corresponding
-- functions for converting to and from bits.
getBitConvertFunc :: Int -> MulticoreGen (String, String)
getBitConvertFunc bits
  | bits == 32 = return ("to_bits32", "from_bits32")
  | bits == 64 = return ("to_bits64", "from_bits64")
  | otherwise = error $ "number of bytes is not supported " ++ pretty bits

-- | Whether a primitive of the given bit width may be updated through a
-- same-width integer compare-and-swap.  The accepted widths are the ones
-- 'toIntegral' maps to an integer type.
supportedPrims :: Int -> Bool
supportedPrims = (`elem` casWidths)
  where
    casWidths = [8, 16, 32, 64]

-- | The integer 'PrimType' of the given bit width, used as the operand
-- type of compare-and-swap.  The handled widths are those supported by
-- the GCC (and clang) compiler.  Any other width is a fatal error.
toIntegral :: Int -> MulticoreGen PrimType
toIntegral bits = maybe unsupported return $ lookup bits widths
  where
    widths = [(8, int8), (16, int16), (32, int32), (64, int64)]
    unsupported = error $ "number of bytes is not supported for CAS - " ++ pretty bits