{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.Pass.ExplicitAllocations.SegOp
  ( allocInKernelBody,
    allocInBinOpLambda,
  )
where

import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Pass.ExplicitAllocations

-- | A 'SegOp' contributes no size substitutions: 'opSizeSubst' ignores
-- both the pattern and the operation and always returns the empty
-- 'ChunkMap'.
instance SizeSubst (SegOp lvl rep) where
  opSizeSubst _ _ = mempty

-- | Perform memory allocation inside a kernel body.
--
-- The body's statements are processed with 'allocInStms', and every
-- statement emitted while doing so is gathered (via 'collectStms') into
-- the statements of the new body.  The kernel results are produced by
-- the inner @pure res@ action, so they are carried over as-is.
allocInKernelBody ::
  Allocable fromrep torep inner =>
  KernelBody fromrep ->
  AllocM fromrep torep (KernelBody torep)
allocInKernelBody (KernelBody () stms res) = do
  -- 'collectStms' pairs the action's result with the statements it
  -- emitted, in that order.
  (res', stms') <- collectStms $ allocInStms stms $ pure res
  pure $ KernelBody () stms' res'

-- | Build a lambda from already memory-annotated parameters and a body
-- that has not yet been through allocation.
--
-- The body's statements are processed with 'allocInStms' inside
-- 'mkLambda', and the body's original result is reused as the lambda's
-- result.
allocInLambda ::
  Allocable fromrep torep inner =>
  [LParam torep] ->
  Body fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInLambda params body =
  mkLambda params $
    allocInStms (bodyStms body) $
      pure $ bodyResult body

-- | Assign memory information to the two parameter groups of a binary
-- operator lambda.
--
-- @xs@ and @ys@ are processed pairwise with 'zipWithM' (so if their
-- lengths differ, the longer list is truncated).  Each pair receives the
-- same kind of memory annotation, chosen from the type of the element
-- taken from @xs@:
--
-- * Arrays: a single memory block is allocated in 'DefaultSpace' for the
--   parameter type extended (via 'arrayOfRow') with an outer dimension of
--   size @2 * num_threads@.  The @xs@ parameter is placed at row @my_id@
--   of that block and the @ys@ parameter at row @other_id@.
--
-- * Primitives, memory blocks and accumulators: the corresponding
--   'MemInfo' constructor is used for both parameters.
allocInBinOpParams ::
  Allocable fromrep torep inner =>
  SubExp ->
  TPrimExp Int64 VName ->
  TPrimExp Int64 VName ->
  [LParam fromrep] ->
  [LParam fromrep] ->
  AllocM fromrep torep ([LParam torep], [LParam torep])
allocInBinOpParams num_threads my_id other_id xs ys =
  unzip <$> zipWithM alloc xs ys
  where
    alloc x y =
      case paramType x of
        Array pt shape u -> do
          twice_num_threads <-
            letSubExp "twice_num_threads" $
              BasicOp $
                BinOp (Mul Int64 OverflowUndef) num_threads $
                  intConst Int64 2
          let t = paramType x `arrayOfRow` twice_num_threads
          mem <- allocForArray t DefaultSpace
          -- XXX: this iota index function is a bit inefficient, as it
          -- leads to uncoalesced access.
          let base_dims = map pe64 $ arrayDims t
              ixfun_base = IxFun.iota base_dims
              -- Index function for a single row of the shared block.
              rowIxFun i =
                IxFun.slice ixfun_base $ fullSliceNum base_dims [DimFix i]
              arrayDec i = MemArray pt shape u $ ArrayIn mem $ rowIxFun i
          pure (withDec x $ arrayDec my_id, withDec y $ arrayDec other_id)
        Prim bt ->
          pure $ both $ MemPrim bt
        Mem space ->
          pure $ both $ MemMem space
        -- This next case will never happen.
        Acc acc ispace ts u ->
          pure $ both $ MemAcc acc ispace ts u
      where
        -- Replace a parameter's decoration (type-changing record update).
        withDec p dec = p {paramDec = dec}
        -- Give both parameters of the pair the same decoration.
        both dec = (withDec x dec, withDec y dec)

-- | Perform memory allocation for a binary-operator lambda used inside a
-- segmented operation.
--
-- The lambda's parameters are split in half: the first half (bound here
-- as @acc_params@) and the second half (@arr_params@) are annotated
-- pairwise by 'allocInBinOpParams'.  The first half is indexed by the
-- 'SegSpace''s flat index variable, and the second half by that index
-- plus @num_threads@.  The lambda body is then processed by
-- 'allocInLambda' with the annotated parameters.
--
-- NOTE(review): this assumes an even number of lambda parameters.  With
-- an odd count, the second half is one longer, and 'zipWithM' in
-- 'allocInBinOpParams' would silently drop its last parameter.
allocInBinOpLambda ::
  Allocable fromrep torep inner =>
  SubExp ->
  SegSpace ->
  Lambda fromrep ->
  AllocM fromrep torep (Lambda torep)
allocInBinOpLambda num_threads (SegSpace flat _) lam = do
  let params = lambdaParams lam
      (acc_params, arr_params) = splitAt (length params `div` 2) params
      index_x, index_y :: TPrimExp Int64 VName
      index_x = TPrimExp $ LeafExp flat int64
      index_y = index_x + pe64 num_threads
  (acc_params', arr_params') <-
    allocInBinOpParams num_threads index_x index_y acc_params arr_params
  allocInLambda (acc_params' ++ arr_params') (lambdaBody lam)