{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExplicitAllocations.SegOp
       ( allocInKernelBody
       , allocInBinOpLambda
       )
where

import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.KernelsMem
import Futhark.Pass.ExplicitAllocations

-- | Perform memory allocation on the statements of a 'KernelBody'.
--
-- The statements are processed by 'allocInStms'. The kernel results are
-- carried over unchanged, and the new body decoration is @()@.
allocInKernelBody :: Allocable fromlore tolore =>
                     KernelBody fromlore
                  -> AllocM fromlore tolore (KernelBody tolore)
allocInKernelBody (KernelBody () stms res) =
  allocInStms stms $ \stms' ->
    return $ KernelBody () stms' res

-- | Construct a memory-annotated lambda from parameters that already carry
-- memory information, an un-annotated body, and the return types.
--
-- The parameters are brought into scope while the body's statements are
-- processed by 'allocInStms'. The body result and the return types are
-- reused exactly as given.
allocInLambda :: Allocable fromlore tolore =>
                 [LParam tolore] -> Body fromlore -> [Type]
              -> AllocM fromlore tolore (Lambda tolore)
allocInLambda params body rettype = do
  body' <- localScope (scopeOfLParams params) $
    allocInStms (bodyStms body) $ \stms' ->
      return $ Body () stms' $ bodyResult body
  return $ Lambda params body' rettype

-- | Attach memory information to the two halves of a binary-operator
-- lambda's parameters, pairing each @xs@ parameter with the @ys@ parameter
-- at the same position.
--
-- * Array-typed pair: an array with an extra outer dimension of
--   @num_threads * 2@ is allocated in 'DefaultSpace'. The @x@ parameter
--   views row @my_id@ of that memory and the @y@ parameter views row
--   @other_id@ of the same memory.
-- * Primitive-typed pair: both parameters get 'MemPrim'.
-- * Memory-typed pair: both parameters get 'MemMem'.
--
-- The annotation for @y@ is built from @x@'s type. NOTE(review): this
-- presumes paired parameters have identical types — confirm at call sites.
allocInBinOpParams :: Allocable fromlore tolore =>
                      SubExp
                   -> PrimExp VName -> PrimExp VName
                   -> [LParam fromlore]
                   -> [LParam fromlore]
                   -> AllocM fromlore tolore ([LParam tolore], [LParam tolore])
allocInBinOpParams num_threads my_id other_id xs ys =
  unzip <$> zipWithM alloc xs ys
  where
    alloc x y =
      case paramType x of
        Array bt shape u -> do
          -- One row per thread for each half, hence twice the thread count.
          twice_num_threads <-
            letSubExp "twice_num_threads" $
              BasicOp $ BinOp (Mul Int32 OverflowUndef) num_threads $
                intConst Int32 2
          let t = paramType x `arrayOfRow` twice_num_threads
          mem <- allocForArray t DefaultSpace
          -- XXX: this iota ixfun is a bit inefficient; leading to
          -- uncoalesced access.
          let base_dims = map (primExpFromSubExp int32) (arrayDims t)
              ixfun_base = IxFun.iota base_dims
              ixfun_x = IxFun.slice ixfun_base $
                          fullSliceNum base_dims [DimFix my_id]
              ixfun_y = IxFun.slice ixfun_base $
                          fullSliceNum base_dims [DimFix other_id]
          return ( x { paramDec = MemArray bt shape u $ ArrayIn mem ixfun_x }
                 , y { paramDec = MemArray bt shape u $ ArrayIn mem ixfun_y }
                 )
        Prim bt ->
          return ( x { paramDec = MemPrim bt }
                 , y { paramDec = MemPrim bt }
                 )
        Mem space ->
          return ( x { paramDec = MemMem space }
                 , y { paramDec = MemMem space }
                 )

-- | Perform memory allocation on a binary-operator lambda executed inside a
-- segmented operation.
--
-- The lambda's parameters are split into two equally sized halves: the
-- first half are the accumulator parameters and the second half the array
-- parameters. Both halves are annotated by 'allocInBinOpParams'.
--
-- * The accumulator half is indexed by the flat thread index of the
--   'SegSpace'.
-- * The array half is indexed by that index plus @num_threads@.
--
-- The body is then processed by 'allocInLambda' with the original return
-- types.
allocInBinOpLambda :: Allocable fromlore tolore =>
                      SubExp -> SegSpace -> Lambda fromlore
                   -> AllocM fromlore tolore (Lambda tolore)
allocInBinOpLambda num_threads (SegSpace flat _) lam = do
  let params = lambdaParams lam
      (acc_params, arr_params) = splitAt (length params `div` 2) params
      index_x = LeafExp flat int32
      index_y = index_x + primExpFromSubExp int32 num_threads
  (acc_params', arr_params') <-
    allocInBinOpParams num_threads index_x index_y acc_params arr_params

  allocInLambda (acc_params' ++ arr_params')
    (lambdaBody lam) (lambdaReturnType lam)