{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Execute
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Execute (

  executeAcc,
  executeOpenAcc,

) where

import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Lifetime
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.Execute

import Data.Array.Accelerate.LLVM.PTX.Analysis.Launch           ( multipleOf )
import Data.Array.Accelerate.LLVM.PTX.Array.Data
import Data.Array.Accelerate.LLVM.PTX.Array.Prim                ( memsetArrayAsync )
import Data.Array.Accelerate.LLVM.PTX.Execute.Async
import Data.Array.Accelerate.LLVM.PTX.Execute.Environment
import Data.Array.Accelerate.LLVM.PTX.Execute.Marshal
import Data.Array.Accelerate.LLVM.PTX.Execute.Stream            ( Stream )
import Data.Array.Accelerate.LLVM.PTX.Link
import Data.Array.Accelerate.LLVM.PTX.Target
import qualified Data.Array.Accelerate.LLVM.PTX.Debug           as Debug
import qualified Data.Array.Accelerate.LLVM.PTX.Execute.Event   as Event

import qualified Foreign.CUDA.Driver                            as CUDA

import Control.Monad                                            ( when, forM_ )
import Control.Monad.Reader                                     ( asks, local )
import Control.Monad.State                                      ( liftIO )
import Data.ByteString.Short.Char8                              ( ShortByteString, unpack )
import qualified Data.DList                                     as DL
import Data.List                                                ( find )
import Data.Maybe                                               ( fromMaybe )
import Text.Printf                                              ( printf )
import Prelude                                                  hiding ( exp, map, sum, scanl, scanr )


{-# SPECIALISE INLINE executeAcc     :: ExecAcc     PTX      a ->             Par PTX (FutureArraysR PTX a) #-}
{-# SPECIALISE INLINE executeOpenAcc :: ExecOpenAcc PTX aenv a -> Val aenv -> Par PTX (FutureArraysR PTX a) #-}

-- Array expression evaluation
-- ---------------------------

-- Computations are evaluated by traversing the AST bottom up, and for each node
-- distinguishing between three cases:
--
--  1. If it is a Use node, we return a reference to the array data. The data
--     will already have been copied to the device during compilation of the
--     kernels.
--
--  2. If it is a non-skeleton node, such as a let binding or shape conversion,
--     then execute directly by updating the environment or similar.
--
--  3. If it is a skeleton node, then we need to execute the generated LLVM
--     code.
--
instance Execute PTX where
  {-# INLINE map         #-}
  {-# INLINE generate    #-}
  {-# INLINE transform   #-}
  {-# INLINE backpermute #-}
  {-# INLINE fold        #-}
  {-# INLINE foldSeg     #-}
  {-# INLINE scan        #-}
  {-# INLINE scan'       #-}
  {-# INLINE permute     #-}
  {-# INLINE stencil1    #-}
  {-# INLINE stencil2    #-}
  {-# INLINE aforeign    #-}
  -- Each class method simply forwards to the PTX-specific operation of the
  -- corresponding name.
  --
  map           = mapOp
  generate      = generateOp
  transform     = transformOp
  backpermute   = backpermuteOp
  -- The 'HasInitialValue' flag selects between the two fold operations.
  fold True     = foldOp
  fold False    = fold1Op
  -- The 'HasInitialValue' flag is not inspected for segmented folds.
  foldSeg i _   = foldSegOp i
  -- The 'Direction' argument is not inspected here; only the
  -- 'HasInitialValue' flag selects the operation.
  scan _ True   = scanOp
  scan _ False  = scan1Op
  scan' _       = scan'Op
  permute       = permuteOp
  stencil1      = stencil1Op
  stencil2      = stencil2Op
  aforeign      = aforeignOp


-- Skeleton implementation
-- -----------------------

-- Simple kernels just need to know the shape of the output array.
--
-- The kernel called 'name' is looked up in the executable and run over the
-- shape 'sh'. A fresh array of that shape is allocated with 'allocateRemote'
-- and passed as the only kernel parameter; the same array is what the
-- returned future is filled with.
--
{-# INLINE simpleOp #-}
simpleOp
    :: HasCallStack
    => ShortByteString
    -> ArrayR (Array sh e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> sh
    -> Par PTX (Future (Array sh e))
simpleOp name repr exe gamma aenv sh =
  withExecutable exe $ \ptxExecutable -> do
    future <- new
    result <- allocateRemote repr sh
    --
    let paramR = TupRsingle $ ParamRarray repr
    executeOp (ptxExecutable !# name) gamma aenv (arrayRshape repr) sh paramR result
    put future result
    return future

-- Mapping over an array can ignore the dimensionality of the array and
-- treat it as its underlying linear representation.
--
-- When 'inplace' is 'Just Refl' the input and output element types coincide
-- and the input array itself is used as the result, so no new array is
-- allocated. Otherwise a fresh array of the input's shape is allocated with
-- element type 'tp'. The "map" kernel receives the pair (result, input).
--
{-# INLINE mapOp #-}
mapOp
    :: HasCallStack
    => Maybe (a :~: b)
    -> ArrayR (Array sh a)
    -> TypeR b
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Array sh a
    -> Par PTX (Future (Array sh b))
mapOp inplace repr tp exe gamma aenv input@(shape -> sh) =
  withExecutable exe $ \ptxExecutable -> do
    -- Output representation: same shape as the input, element type 'tp'
    let reprOut = ArrayR (arrayRshape repr) tp
    future <- new
    result <- case inplace of
                Just Refl -> return input
                Nothing   -> allocateRemote reprOut sh
    --
    let paramsR = TupRsingle (ParamRarray reprOut) `TupRpair` TupRsingle (ParamRarray repr)
    executeOp (ptxExecutable !# "map") gamma aenv (arrayRshape repr) sh paramsR (result, input)
    put future result
    return future

-- Generating an array is a simple kernel: run the "generate" kernel over the
-- requested output shape via 'simpleOp'.
--
{-# INLINE generateOp #-}
generateOp
    :: HasCallStack
    => ArrayR (Array sh e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> sh
    -> Par PTX (Future (Array sh e))
generateOp = simpleOp "generate"

-- Run the "transform" kernel. A fresh output array of shape 'sh'' is
-- allocated using the output representation 'repr''; the kernel is executed
-- over the output shape and receives the pair (result, input).
--
{-# INLINE transformOp #-}
transformOp
    :: HasCallStack
    => ArrayR (Array sh a)
    -> ArrayR (Array sh' b)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> sh'
    -> Array sh a
    -> Par PTX (Future (Array sh' b))
transformOp repr repr' exe gamma aenv sh' input =
  withExecutable exe $ \ptxExecutable -> do
    future <- new
    result <- allocateRemote repr' sh'
    let paramsR = TupRsingle (ParamRarray repr') `TupRpair` TupRsingle (ParamRarray repr)
    executeOp (ptxExecutable !# "transform") gamma aenv (arrayRshape repr') sh' paramsR (result, input)
    put future result
    return future

-- Backwards permutation is executed as a 'transformOp' in which the output
-- array has the new shape representation 'shr'' but the same element type as
-- the input.
--
{-# INLINE backpermuteOp #-}
backpermuteOp
    :: HasCallStack
    => ArrayR (Array sh e)
    -> ShapeR sh'
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> sh'
    -> Array sh e
    -> Par PTX (Future (Array sh' e))
backpermuteOp (ArrayR shr tp) shr' = transformOp (ArrayR shr tp) (ArrayR shr' tp)

-- There are two flavours of fold operation:
--
--   1. If we are collapsing to a single value, then multiple thread blocks are
--      working together. Since thread blocks synchronise with each other via
--      kernel launches, each block computes a partial sum and the kernel is
--      launched recursively until the final value is reached.
--
--   2. If this is a multidimensional reduction, then each inner dimension is
--      handled by a single thread block, so no global communication is
--      necessary. Furthermore, there are two kernel flavours: each innermost
--      dimension can be cooperatively reduced by (a) a thread warp; or (b) a
--      thread block. Currently we always use the first, but benchmarking is
--      required to determine when to select each.
--
-- Reduction without an initial value. The innermost extent 'sz' of the input
-- is checked with 'boundsCheck' (message "empty array") to be greater than
-- zero. If the input nevertheless contains no elements overall, a newly
-- allocated array of the outer shape 'sx' is returned directly; otherwise the
-- work is handed to 'foldCore'.
--
{-# INLINE fold1Op #-}
fold1Op
    :: HasCallStack
    => ArrayR (Array sh e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array sh e))
fold1Op repr exe gamma aenv arr@(delayedShape -> sh@(sx, sz))
  = boundsCheck "empty array" (sz > 0)
  $ case size (ShapeRsnoc $ arrayRshape repr) sh of
      0 -> newFull =<< allocateRemote repr sx  -- empty, but possibly with one or more non-zero dimensions
      _ -> foldCore repr exe gamma aenv arr

-- Reduction with an initial value. If the input contains no elements, the
-- result of outer shape 'sx' is produced by 'generateOp' (which runs the
-- "generate" kernel of the same executable); otherwise the work is handed to
-- 'foldCore'.
--
-- NOTE(review): presumably the "generate" kernel fills the result with the
-- initial value of the fold -- confirm against the code generator.
--
{-# INLINE foldOp #-}
foldOp
    :: HasCallStack
    => ArrayR (Array sh e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array sh e))
foldOp repr exe gamma aenv arr@(delayedShape -> sh@(sx, _))
  = case size (ShapeRsnoc $ arrayRshape repr) sh of
      0 -> generateOp repr exe gamma aenv sx
      _ -> foldCore repr exe gamma aenv arr

-- Select the reduction strategy from the shape of the result: when the
-- result shape representation is 'ShapeRz' (the input is a vector being
-- reduced to a scalar) use 'foldAllOp'; for any other result shape use
-- 'foldDimOp'.
--
{-# INLINE foldCore #-}
foldCore
    :: HasCallStack
    => ArrayR (Array sh e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array sh e))
foldCore repr exe gamma aenv arr
  -- Matching on 'ShapeRz' refines 'sh' to '()', so 'arr' is a delayed vector
  | ArrayR ShapeRz tp <- repr
  = foldAllOp tp exe gamma aenv arr
  --
  | otherwise
  = foldDimOp repr exe gamma aenv arr

{-# INLINE foldAllOp #-}
foldAllOp
    :: forall aenv e. HasCallStack
    => TypeR e
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Vector e)
    -> Par PTX (Future (Scalar e))
foldAllOp :: TypeR e
-> ExecutableR PTX
-> Gamma aenv
-> Val aenv
-> Delayed (Vector e)
-> Par PTX (Future (Scalar e))
foldAllOp TypeR e
tp ExecutableR PTX
exe Gamma aenv
gamma Val aenv
aenv Delayed (Vector e)
input =
  ExecutableR PTX
-> (FunctionTable -> Par PTX (Future (Scalar e)))
-> Par PTX (Future (Scalar e))
forall b.
HasCallStack =>
ExecutableR PTX -> (FunctionTable -> Par PTX b) -> Par PTX b
withExecutable ExecutableR PTX
exe ((FunctionTable -> Par PTX (Future (Scalar e)))
 -> Par PTX (Future (Scalar e)))
-> (FunctionTable -> Par PTX (Future (Scalar e)))
-> Par PTX (Future (Scalar e))
forall a b. (a -> b) -> a -> b
$ \FunctionTable
ptxExecutable -> do
    Future (Scalar e)
future <- Par PTX (Future (Scalar e))
forall arch a.
(Async arch, HasCallStack) =>
Par arch (FutureR arch a)
new
    let
        ks :: Kernel
ks        = FunctionTable
ptxExecutable HasCallStack => FunctionTable -> ShortByteString -> Kernel
FunctionTable -> ShortByteString -> Kernel
!# ShortByteString
"foldAllS"
        km1 :: Kernel
km1       = FunctionTable
ptxExecutable HasCallStack => FunctionTable -> ShortByteString -> Kernel
FunctionTable -> ShortByteString -> Kernel
!# ShortByteString
"foldAllM1"
        km2 :: Kernel
km2       = FunctionTable
ptxExecutable HasCallStack => FunctionTable -> ShortByteString -> Kernel
FunctionTable -> ShortByteString -> Kernel
!# ShortByteString
"foldAllM2"
        sh :: ((), Int)
sh@((), Int
n) = Delayed (Vector e) -> ((), Int)
forall sh e. Delayed (Array sh e) -> sh
delayedShape Delayed (Vector e)
input
        paramsRinput :: TupR (ParamR PTX) (Maybe (Vector e))
paramsRinput = ParamR PTX (Maybe (Vector e))
-> TupR (ParamR PTX) (Maybe (Vector e))
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle (ParamR PTX (Maybe (Vector e))
 -> TupR (ParamR PTX) (Maybe (Vector e)))
-> ParamR PTX (Maybe (Vector e))
-> TupR (ParamR PTX) (Maybe (Vector e))
forall a b. (a -> b) -> a -> b
$ ParamR PTX (Vector e) -> ParamR PTX (Maybe (Vector e))
forall arch a1. ParamR arch a1 -> ParamR arch (Maybe a1)
ParamRmaybe (ParamR PTX (Vector e) -> ParamR PTX (Maybe (Vector e)))
-> ParamR PTX (Vector e) -> ParamR PTX (Maybe (Vector e))
forall a b. (a -> b) -> a -> b
$ ArrayR (Vector e) -> ParamR PTX (Vector e)
forall sh e arch. ArrayR (Array sh e) -> ParamR arch (Array sh e)
ParamRarray (ArrayR (Vector e) -> ParamR PTX (Vector e))
-> ArrayR (Vector e) -> ParamR PTX (Vector e)
forall a b. (a -> b) -> a -> b
$ ShapeR ((), Int) -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ((), Int)
dim1 TypeR e
tp
        paramsRdim0 :: TupR (ParamR PTX) (Scalar e)
paramsRdim0  = ParamR PTX (Scalar e) -> TupR (ParamR PTX) (Scalar e)
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle (ParamR PTX (Scalar e) -> TupR (ParamR PTX) (Scalar e))
-> ParamR PTX (Scalar e) -> TupR (ParamR PTX) (Scalar e)
forall a b. (a -> b) -> a -> b
$ ArrayR (Scalar e) -> ParamR PTX (Scalar e)
forall sh e arch. ArrayR (Array sh e) -> ParamR arch (Array sh e)
ParamRarray (ArrayR (Scalar e) -> ParamR PTX (Scalar e))
-> ArrayR (Scalar e) -> ParamR PTX (Scalar e)
forall a b. (a -> b) -> a -> b
$ ShapeR () -> TypeR e -> ArrayR (Scalar e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ()
dim0 TypeR e
tp
        paramsRdim1 :: TupR (ParamR PTX) (Vector e)
paramsRdim1  = ParamR PTX (Vector e) -> TupR (ParamR PTX) (Vector e)
forall (s :: * -> *) a. s a -> TupR s a
TupRsingle (ParamR PTX (Vector e) -> TupR (ParamR PTX) (Vector e))
-> ParamR PTX (Vector e) -> TupR (ParamR PTX) (Vector e)
forall a b. (a -> b) -> a -> b
$ ArrayR (Vector e) -> ParamR PTX (Vector e)
forall sh e arch. ArrayR (Array sh e) -> ParamR arch (Array sh e)
ParamRarray (ArrayR (Vector e) -> ParamR PTX (Vector e))
-> ArrayR (Vector e) -> ParamR PTX (Vector e)
forall a b. (a -> b) -> a -> b
$ ShapeR ((), Int) -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ((), Int)
dim1 TypeR e
tp
    --
    if Kernel -> Int -> Int
kernelThreadBlocks Kernel
ks Int
n Int -> Int -> HasInitialValue
forall a. Eq a => a -> a -> HasInitialValue
== Int
1
      then do
        -- The array is small enough that we can compute it in a single step
        Scalar e
result <- ArrayR (Scalar e) -> () -> Par PTX (Scalar e)
forall arch sh e.
Remote arch =>
ArrayR (Array sh e) -> sh -> Par arch (Array sh e)
allocateRemote (ShapeR () -> TypeR e -> ArrayR (Scalar e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ()
dim0 TypeR e
tp) ()
        let paramsR :: TupR (ParamR PTX) (Scalar e, Maybe (Vector e))
paramsR = TupR (ParamR PTX) (Scalar e)
paramsRdim0 TupR (ParamR PTX) (Scalar e)
-> TupR (ParamR PTX) (Maybe (Vector e))
-> TupR (ParamR PTX) (Scalar e, Maybe (Vector e))
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` TupR (ParamR PTX) (Maybe (Vector e))
paramsRinput
        Kernel
-> Gamma aenv
-> Val aenv
-> ShapeR ((), Int)
-> ((), Int)
-> TupR (ParamR PTX) (Scalar e, Maybe (Vector e))
-> (Scalar e, Maybe (Vector e))
-> Par PTX ()
forall aenv sh params.
HasCallStack =>
Kernel
-> Gamma aenv
-> Val aenv
-> ShapeR sh
-> sh
-> ParamsR PTX params
-> params
-> Par PTX ()
executeOp Kernel
ks Gamma aenv
gamma Val aenv
aenv ShapeR ((), Int)
dim1 ((), Int)
sh TupR (ParamR PTX) (Scalar e, Maybe (Vector e))
paramsR (Scalar e
result, Delayed (Vector e) -> Maybe (Vector e)
forall sh e. Delayed (Array sh e) -> Maybe (Array sh e)
manifest Delayed (Vector e)
input)
        FutureR PTX (Scalar e) -> Scalar e -> Par PTX ()
forall arch a.
(Async arch, HasCallStack) =>
FutureR arch a -> a -> Par arch ()
put FutureR PTX (Scalar e)
Future (Scalar e)
future Scalar e
result

      else do
        -- Multi-kernel reduction to a single element. The first kernel integrates
        -- any delayed elements, and the second is called recursively until
        -- reaching a single element.
        let
            rec :: Vector e -> Par PTX ()
            rec :: Vector e -> Par PTX ()
rec tmp :: Vector e
tmp@(Array ((),Int
m) ArrayData e
adata)
              | Int
m Int -> Int -> HasInitialValue
forall a. Ord a => a -> a -> HasInitialValue
<= Int
1    = FutureR PTX (Scalar e) -> Scalar e -> Par PTX ()
forall arch a.
(Async arch, HasCallStack) =>
FutureR arch a -> a -> Par arch ()
put FutureR PTX (Scalar e)
Future (Scalar e)
future (() -> ArrayData e -> Scalar e
forall sh e. sh -> ArrayData e -> Array sh e
Array () ArrayData e
adata)
              | HasInitialValue
otherwise = do
                  let sh' :: ((), Int)
sh' = ((), Int
m Int -> Int -> Int
`multipleOf` Kernel -> Int
kernelThreadBlockSize Kernel
km2)
                  Vector e
out <- ArrayR (Vector e) -> ((), Int) -> Par PTX (Vector e)
forall arch sh e.
Remote arch =>
ArrayR (Array sh e) -> sh -> Par arch (Array sh e)
allocateRemote (ShapeR ((), Int) -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ((), Int)
dim1 TypeR e
tp) ((), Int)
sh'
                  let paramsR2 :: TupR (ParamR PTX) (Vector e, Vector e)
paramsR2 = TupR (ParamR PTX) (Vector e)
paramsRdim1 TupR (ParamR PTX) (Vector e)
-> TupR (ParamR PTX) (Vector e)
-> TupR (ParamR PTX) (Vector e, Vector e)
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` TupR (ParamR PTX) (Vector e)
paramsRdim1
                  Kernel
-> Gamma aenv
-> Val aenv
-> ShapeR ((), Int)
-> ((), Int)
-> TupR (ParamR PTX) (Vector e, Vector e)
-> (Vector e, Vector e)
-> Par PTX ()
forall aenv sh params.
HasCallStack =>
Kernel
-> Gamma aenv
-> Val aenv
-> ShapeR sh
-> sh
-> ParamsR PTX params
-> params
-> Par PTX ()
executeOp Kernel
km2 Gamma aenv
gamma Val aenv
aenv ShapeR ((), Int)
dim1 ((), Int)
sh' TupR (ParamR PTX) (Vector e, Vector e)
paramsR2 (Vector e
tmp, Vector e
out)
                  Vector e -> Par PTX ()
rec Vector e
out
        --
        let sh' :: ((), Int)
sh' = ((), Int
n Int -> Int -> Int
`multipleOf` Kernel -> Int
kernelThreadBlockSize Kernel
km1)
        Vector e
tmp <- ArrayR (Vector e) -> ((), Int) -> Par PTX (Vector e)
forall arch sh e.
Remote arch =>
ArrayR (Array sh e) -> sh -> Par arch (Array sh e)
allocateRemote (ShapeR ((), Int) -> TypeR e -> ArrayR (Vector e)
forall sh e. ShapeR sh -> TypeR e -> ArrayR (Array sh e)
ArrayR ShapeR ((), Int)
dim1 TypeR e
tp) ((), Int)
sh'
        let paramsR1 :: TupR (ParamR PTX) (Vector e, Maybe (Vector e))
paramsR1 = TupR (ParamR PTX) (Vector e)
paramsRdim1 TupR (ParamR PTX) (Vector e)
-> TupR (ParamR PTX) (Maybe (Vector e))
-> TupR (ParamR PTX) (Vector e, Maybe (Vector e))
forall (s :: * -> *) a1 b. TupR s a1 -> TupR s b -> TupR s (a1, b)
`TupRpair` TupR (ParamR PTX) (Maybe (Vector e))
paramsRinput
        Kernel
-> Gamma aenv
-> Val aenv
-> ShapeR ((), Int)
-> ((), Int)
-> TupR (ParamR PTX) (Vector e, Maybe (Vector e))
-> (Vector e, Maybe (Vector e))
-> Par PTX ()
forall aenv sh params.
HasCallStack =>
Kernel
-> Gamma aenv
-> Val aenv
-> ShapeR sh
-> sh
-> ParamsR PTX params
-> params
-> Par PTX ()
executeOp Kernel
km1 Gamma aenv
gamma Val aenv
aenv ShapeR ((), Int)
dim1 ((), Int)
sh' TupR (ParamR PTX) (Vector e, Maybe (Vector e))
paramsR1 (Vector e
tmp, Delayed (Vector e) -> Maybe (Vector e)
forall sh e. Delayed (Array sh e) -> Maybe (Array sh e)
manifest Delayed (Vector e)
input)
        Vector e -> Par PTX ()
rec Vector e
tmp
    --
    Future (Scalar e) -> Par PTX (Future (Scalar e))
forall (m :: * -> *) a. Monad m => a -> m a
return Future (Scalar e)
future


-- Reduce along the innermost dimension of the input, producing an array of
-- the outer shape 'sh'.
--
--  * If the innermost extent of the input is zero, no "fold" kernel is
--    launched; the result of shape 'sh' is produced by 'generateOp' instead.
--
--  * Otherwise a result array of shape 'sh' is obtained from
--    'allocateRemote', the "fold" kernel of the executable is launched over
--    'sh', and the result is written into a freshly created future which is
--    returned to the caller.
--
{-# INLINE foldDimOp #-}
foldDimOp
    :: HasCallStack
    => ArrayR (Array sh e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array sh e))
foldDimOp repr@(ArrayR shr tp) exe gamma aenv input@(delayedShape -> (sh, sz))
  | sz == 0   = generateOp repr exe gamma aenv sh
  | otherwise =
    withExecutable exe $ \ptxExecutable -> do
      future <- new
      result <- allocateRemote repr sh
      --
      -- Kernel parameters: the output array, followed by the (optional)
      -- manifest input array, whose representation has one more dimension
      -- than the output.
      let paramsR = TupRsingle (ParamRarray repr)
                      `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray $ ArrayR (ShapeRsnoc shr) tp)
      executeOp (ptxExecutable !# "fold") gamma aenv shr sh paramsR (result, manifest input)
      put future result
      return future


-- Segmented reduction along the innermost dimension.
--
-- The result has shape '(sh, n)' where 'n' is one less than the length of the
-- segments array, and the kernel is launched over a one-dimensional index
-- space of 'size shr' sh * n' elements.
--
-- One of two kernels is selected: "foldSeg_warp" when the quotient of the
-- innermost input extent by the segments-array length is below twice the
-- thread block size of "foldSeg_block", and "foldSeg_block" otherwise.
--
{-# INLINE foldSegOp #-}
foldSegOp
    :: HasCallStack
    => IntegralType i
    -> ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Delayed (Segments i)
    -> Par PTX (Future (Array (sh, Int) e))
foldSegOp intTp repr exe gamma aenv input@(delayedShape -> (sh, sz)) segments@(delayedShape -> ((), ss)) =
  withExecutable exe $ \ptxExecutable -> do
    let
        ArrayR (ShapeRsnoc shr') _ = repr
        -- Representation of the one-dimensional segments array
        reprSeg = ArrayR dim1 $ TupRsingle $ SingleScalarType $ NumSingleType $ IntegralNumType intTp
        n       = ss - 1  -- segments array has been 'scanl (+) 0'`ed
        m       = size shr' sh * n
        -- NOTE(review): 'quot' assumes ss >= 1; that should hold if the
        -- segments array is always the result of a 'scanl' -- confirm.
        foldseg = if (sz`quot`ss) < (2 * kernelThreadBlockSize foldseg_cta)
                    then foldseg_warp
                    else foldseg_cta
        --
        foldseg_cta   = ptxExecutable !# "foldSeg_block"
        foldseg_warp  = ptxExecutable !# "foldSeg_warp"
        -- qinit         = ptxExecutable !# "qinit"
    --
    future  <- new
    result  <- allocateRemote repr (sh, n)
    -- Kernel parameters: output array, (optional) manifest input array,
    -- (optional) manifest segments array.
    let paramsR = TupRsingle (ParamRarray repr)
                    `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr)
                    `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray reprSeg)
    executeOp foldseg gamma aenv dim1 ((), m) paramsR ((result, manifest input), manifest segments)
    put future result
    return future


-- Scan whose output is one element longer than the input along the innermost
-- dimension.
--
-- When the innermost extent of the input is zero, the result (of innermost
-- extent one) is produced by 'generateOp'; otherwise the work is handed to
-- 'scanCore' with an output extent of 'n + 1'.
--
{-# INLINE scanOp #-}
scanOp
    :: HasCallStack
    => ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array (sh, Int) e))
scanOp repr exe gamma aenv input@(delayedShape -> (sz, n)) =
  case n of
    0 -> generateOp repr exe gamma aenv (sz, 1)
    _ -> scanCore repr exe gamma aenv (n+1) input

-- Scan whose output has the same innermost extent as the input.
--
-- The innermost extent of the input must be positive; this is enforced with
-- 'boundsCheck' (message "empty array") before delegating to 'scanCore'.
--
{-# INLINE scan1Op #-}
scan1Op
    :: HasCallStack
    => ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array (sh, Int) e))
scan1Op repr exe gamma aenv input@(delayedShape -> (_, n))
  = boundsCheck "empty array" (n > 0)
  $ scanCore repr exe gamma aenv n input

-- Common entry point for the scan operations: dispatch on the rank of the
-- array representation.
--
--  * One-dimensional arrays (shape representation 'ShapeRsnoc ShapeRz') are
--    handled by 'scanAllOp'.
--
--  * Arrays of any other rank are handled by 'scanDimOp'.
--
{-# INLINE scanCore #-}
scanCore
    :: HasCallStack
    => ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Int                    -- output size of innermost dimension
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array (sh, Int) e))
scanCore repr exe gamma aenv m input
  -- Matching on 'ShapeRz' refines 'sh' to '()', so 'input' is a vector here
  | ArrayR (ShapeRsnoc ShapeRz) tp <- repr
  = scanAllOp tp exe gamma aenv m input
  --
  | otherwise
  = scanDimOp repr exe gamma aenv m input

-- Scan of a one-dimensional array, using up to three kernels from the
-- executable: "scanP1", "scanP2" and "scanP3".
--
-- The result vector has 'm' elements. A temporary vector of 's' elements is
-- allocated as well, where 's = n `multipleOf` c', 'n' is the input length
-- and 'c' is the thread block size of the first kernel. The second and third
-- kernels are only launched when 's > 1'.
--
{-# INLINE scanAllOp #-}
scanAllOp
    :: HasCallStack
    => TypeR e
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Int                    -- output size
    -> Delayed (Vector e)
    -> Par PTX (Future (Vector e))
scanAllOp tp exe gamma aenv m input@(delayedShape -> ((), n)) =
  withExecutable exe $ \ptxExecutable -> do
    let
        k1  = ptxExecutable !# "scanP1"
        k2  = ptxExecutable !# "scanP2"
        k3  = ptxExecutable !# "scanP3"
        --
        c   = kernelThreadBlockSize k1
        -- NOTE(review): presumably the number of thread blocks of size 'c'
        -- needed to cover 'n' elements -- confirm against 'multipleOf'.
        s   = n `multipleOf` c
        --
        repr     = ArrayR dim1 tp
        paramR   = TupRsingle $ ParamRarray repr
        -- (tmp, result) plus the optional manifest input
        paramsR1 = paramR `TupRpair` paramR `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr)
        -- (tmp, result) plus an integer argument
        paramsR3 = paramR `TupRpair` paramR `TupRpair` TupRsingle ParamRint
    --
    future  <- new
    result  <- allocateRemote repr ((), m)

    -- Step 1: Independent thread-block-wide scans of the input. Small arrays
    -- which can be computed by a single thread block will require no
    -- additional work.
    tmp     <- allocateRemote repr ((), s)
    executeOp k1 gamma aenv dim1 ((), s) paramsR1 ((tmp, result), manifest input)

    -- Step 2: Multi-block reductions need to compute the per-block prefix,
    -- then apply those values to the partial results.
    when (s > 1) $ do
      executeOp k2 gamma aenv dim1 ((), s)   paramR tmp
      executeOp k3 gamma aenv dim1 ((), s-1) paramsR3 ((tmp, result), c)

    put future result
    return future

{-# INLINE scanDimOp #-}
-- | Run the single-kernel @"scan"@ image for a scan along the innermost
-- dimension of an array.
--
-- The result array is allocated with shape @(sz, m)@, where @sz@ is the outer
-- shape of the input and @m@ is the innermost extent requested by the caller.
-- The kernel is launched over a one-dimensional index space of @size shr' sz@
-- elements, i.e. the number of elements of the outer shape.
-- NOTE(review): presumably each work item handles one innermost row — confirm
-- against the code generator.
--
-- The returned future is filled with the result array as soon as the kernel
-- has been submitted via 'executeOp'.
--
scanDimOp
    :: HasCallStack
    => ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Int
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array (sh, Int) e))
scanDimOp repr exe gamma aenv m input@(delayedShape -> (sz, _)) =
  withExecutable exe $ \ptxExecutable -> do
    -- Recover the representation of the outer shape, needed to compute its size
    let ArrayR (ShapeRsnoc shr') _ = repr
    future  <- new
    result  <- allocateRemote repr (sz, m)
    -- Kernel parameters: the output array, followed by the input array if it
    -- is manifest (a delayed input passes 'Nothing')
    let paramsR = TupRsingle (ParamRarray repr) `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr)
    executeOp (ptxExecutable !# "scan") gamma aenv dim1 ((), size shr' sz) paramsR (result, manifest input)
    put future result
    return future


{-# INLINE scan'Op #-}
-- | Entry point for scans which return both the scanned array and the array
-- of final reduction values.
--
-- When the innermost extent of the input is zero, no scan kernel is run: an
-- empty result of shape @(sz, 0)@ is allocated and the sums array (of the
-- outer shape @sz@) is produced by 'generateOp'. The returned future is filled
-- from a forked computation once that sums array is available.
--
-- Otherwise the work is delegated to 'scan'Core'.
--
scan'Op
    :: HasCallStack
    => ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array (sh, Int) e, Array sh e))
scan'Op repr exe gamma aenv input@(delayedShape -> (sz, n)) =
  case n of
    0 -> do
      future  <- new
      result  <- allocateRemote repr (sz, 0)
      sums    <- generateOp (reduceRank repr) exe gamma aenv sz
      -- Wait for the generated sums without blocking the caller
      fork $ do sums' <- get sums
                put future (result, sums')
      return future
    --
    _ -> scan'Core repr exe gamma aenv input

{-# INLINE scan'Core #-}
-- | Dispatch a scan that also returns the reduction values, based on the
-- dimensionality of the input.
--
-- If the outer shape is 'ShapeRz' (the input is a vector) the computation is
-- handed to 'scan'AllOp'; for every other shape it is handed to 'scan'DimOp'.
--
scan'Core
    :: HasCallStack
    => ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array (sh, Int) e, Array sh e))
scan'Core repr exe gamma aenv input
  -- Matching on 'ShapeRz' refines @sh ~ ()@, so the input is a 'Vector'
  | ArrayR (ShapeRsnoc ShapeRz) tp <- repr
  = scan'AllOp tp exe gamma aenv input
  --
  | otherwise
  = scan'DimOp repr exe gamma aenv input

{-# INLINE scan'AllOp #-}
-- | Scan of a vector which additionally returns the final reduction value as
-- a scalar array.
--
-- Uses up to three kernels looked up by name in the executable: @"scanP1"@,
-- @"scanP2"@ and @"scanP3"@. A temporary vector of @s@ elements is allocated
-- alongside the result, where @s = n \`multipleOf\` c@ and @c@ is the thread
-- block size of the first kernel.
-- NOTE(review): @s@ is presumably the number of thread blocks needed to cover
-- the @n@ input elements — confirm against 'multipleOf'.
--
-- If @s == 1@ only the first kernel is run, and the data of the temporary
-- array is re-wrapped as the scalar result. Otherwise a separate scalar is
-- allocated for the reduction value and the second and third kernels are run
-- as well.
--
scan'AllOp
    :: HasCallStack
    => TypeR e
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Vector e)
    -> Par PTX (Future (Vector e, Scalar e))
scan'AllOp tp exe gamma aenv input@(delayedShape -> ((), n)) =
  withExecutable exe $ \ptxExecutable -> do
    let
        repr       = ArrayR dim1 tp
        paramRdim0 = TupRsingle $ ParamRarray $ ArrayR dim0 tp
        paramRdim1 = TupRsingle $ ParamRarray repr
        k1  = ptxExecutable !# "scanP1"
        k2  = ptxExecutable !# "scanP2"
        k3  = ptxExecutable !# "scanP3"
        --
        c   = kernelThreadBlockSize k1
        s   = n `multipleOf` c
    --
    future  <- new
    result  <- allocateRemote repr ((), n)
    tmp     <- allocateRemote repr ((), s)

    -- Step 1: independent thread-block-wide scans. Each block stores its partial
    -- sum to a temporary array.
    let paramsR1 = paramRdim1 `TupRpair` paramRdim1 `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr)
    executeOp k1 gamma aenv dim1 ((), s) paramsR1 ((tmp, result), manifest input)

    -- If this was a small array that was processed by a single thread block then
    -- we are done, otherwise compute the per-block prefix and apply those values
    -- to the partial results.
    if s == 1
      then
        -- Reuse the one-element temporary as the scalar result by re-wrapping
        -- its payload with the unit shape.
        case tmp of
          Array _ ad -> put future (result, Array () ad)

      else do
        sums <- allocateRemote (ArrayR dim0 tp) ()
        let paramsR2 = paramRdim1 `TupRpair` paramRdim0
        let paramsR3 = paramRdim1 `TupRpair` paramRdim1 `TupRpair` TupRsingle ParamRint
        executeOp k2 gamma aenv dim1 ((), s)   paramsR2 (tmp, sums)
        executeOp k3 gamma aenv dim1 ((), s-1) paramsR3 ((tmp, result), c)
        put future (result, sums)
    --
    return future

{-# INLINE scan'DimOp #-}
-- | Scan along the innermost dimension of a multidimensional array, also
-- returning the array of final reduction values.
--
-- The result is allocated with the same shape as the input, and the sums
-- array with the outer shape @sz@. Both are passed to the @"scan"@ kernel,
-- which is launched over a one-dimensional index space of @size shr' sz@
-- elements, i.e. the number of elements of the outer shape.
--
-- The returned future is filled with both arrays as soon as the kernel has
-- been submitted via 'executeOp'.
--
scan'DimOp
    :: HasCallStack
    => ArrayR (Array (sh, Int) e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array (sh, Int) e)
    -> Par PTX (Future (Array (sh, Int) e, Array sh e))
scan'DimOp repr@(ArrayR (ShapeRsnoc shr') _) exe gamma aenv input@(delayedShape -> sh@(sz, _)) =
  withExecutable exe $ \ptxExecutable -> do
    future  <- new
    result  <- allocateRemote repr sh
    sums    <- allocateRemote (reduceRank repr) sz
    -- Kernel parameters: scanned output, per-row sums, then the input array
    -- if it is manifest (a delayed input passes 'Nothing')
    let paramsR = TupRsingle (ParamRarray repr) `TupRpair` TupRsingle (ParamRarray $ reduceRank repr) `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray repr)
    executeOp (ptxExecutable !# "scan") gamma aenv dim1 ((), size shr' sz) paramsR ((result, sums), manifest input)
    put future (result, sums)
    return future


{-# INLINE permuteOp #-}
-- | Forward permutation of the input array into an array of default values.
--
-- The first argument selects whether the defaults array is updated in place
-- ('True') or first copied with 'cloneArrayAsync' ('False'); either choice
-- is reported through 'Debug.trace' under the 'Debug.dump_exec' flag.
--
-- The first kernel in the executable's function table is used, and its name
-- determines how it is invoked:
--
--   * @"permute_rmw"@: launched directly with the result array and the
--     (optionally manifest) input.
--
--   * @"permute_mutex"@: additionally takes a 'Word32' vector with one entry
--     per element of the output, which is zero-filled asynchronously via
--     'memsetArrayAsync' and handed to the kernel as a future.
--
-- In both cases the kernel is launched over the @n@ elements of the input.
-- An 'internalError' is raised if the function table is empty or the kernel
-- name is not one of the two above.
--
permuteOp
    :: HasCallStack
    => Bool
    -> ArrayR (Array sh e)
    -> ShapeR sh'
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Array sh' e
    -> Delayed (Array sh e)
    -> Par PTX (Future (Array sh' e))
permuteOp inplace repr@(ArrayR shr tp) shr' exe gamma aenv defaults@(shape -> shOut) input@(delayedShape -> shIn) =
  withExecutable exe $ \ptxExecutable -> do
    let
        n        = size shr  shIn       -- number of input elements
        m        = size shr' shOut      -- number of output elements
        repr'    = ArrayR shr' tp
        reprLock = ArrayR dim1 $ TupRsingle scalarTypeWord32
        paramR   = TupRsingle $ ParamRmaybe $ ParamRarray repr
        paramR'  = TupRsingle $ ParamRarray repr'
        kernel   = case functionTable ptxExecutable of
                      k:_ -> k
                      _   -> internalError "no kernels found"
    --
    future  <- new
    result  <- if inplace
                 then Debug.trace Debug.dump_exec "exec: permute/inplace" $ return defaults
                 else Debug.trace Debug.dump_exec "exec: permute/clone"   $ get =<< cloneArrayAsync repr' defaults
    --
    case kernelName kernel of
      -- execute directly using atomic operations
      "permute_rmw"   ->
        let paramsR = paramR' `TupRpair` paramR
        in  executeOp kernel gamma aenv dim1 ((), n) paramsR (result, manifest input)

      -- a temporary array is required for spin-locks around the critical section
      "permute_mutex" -> do
        barrier     <- new :: Par PTX (Future (Vector Word32))
        Array _ ad  <- allocateRemote reprLock ((), m)
        -- Zero the lock array asynchronously; the kernel receives it as a
        -- future and so is ordered after the fill.
        fork $ do fill <- memsetArrayAsync (NumSingleType $ IntegralNumType TypeWord32) m 0 ad
                  put barrier . Array ((), m) =<< get fill
        --
        let paramsR = paramR' `TupRpair` TupRsingle (ParamRfuture $ ParamRarray reprLock) `TupRpair` paramR
        executeOp kernel gamma aenv dim1 ((), n) paramsR ((result, barrier), manifest input)

      _               -> internalError "unexpected kernel image"
    --
    put future result
    return future


-- Execute a stencil over a single input array.
--
-- The output has the same extent as the input (obtained via 'delayedShape').
-- The input is handed to the kernel as an optional manifest array: 'Nothing'
-- when the input is delayed (see 'manifest'). All of the actual work is
-- delegated to 'stencilCore'.
--
{-# INLINE stencil1Op #-}
stencil1Op
    :: HasCallStack
    => TypeR a
    -> ArrayR (Array sh b)
    -> sh
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array sh a)
    -> Par PTX (Future (Array sh b))
stencil1Op tp repr@(ArrayR shr _) halo exe gamma aenv input@(delayedShape -> sh) =
  stencilCore repr exe gamma aenv halo sh paramsR (manifest input)
  where
    -- Marshalling description for the (optional) manifest input array
    paramsR = TupRsingle $ ParamRmaybe $ ParamRarray $ ArrayR shr tp

-- Using the defaulting instances for stencil operations (for now).
--
-- Execute a stencil over two input arrays.
--
-- The output extent is the 'intersect'ion of the two input extents. Each input
-- is handed to the kernel as an optional manifest array ('Nothing' when that
-- input is delayed). All of the actual work is delegated to 'stencilCore'.
--
{-# INLINE stencil2Op #-}
stencil2Op
    :: HasCallStack
    => TypeR a
    -> TypeR b
    -> ArrayR (Array sh c)
    -> sh
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> Delayed (Array sh a)
    -> Delayed (Array sh b)
    -> Par PTX (Future (Array sh c))
stencil2Op tpA tpB repr@(ArrayR shr _) halo exe gamma aenv input1@(delayedShape -> sh1) input2@(delayedShape -> sh2) =
  stencilCore repr exe gamma aenv halo (intersect (arrayRshape repr) sh1 sh2) paramsR (manifest input1, manifest input2)
  where
    -- Marshalling description for the pair of (optional) manifest input arrays
    paramsR = TupRsingle (ParamRmaybe $ ParamRarray $ ArrayR shr tpA)
                `TupRpair` TupRsingle (ParamRmaybe $ ParamRarray $ ArrayR shr tpB)

-- Common implementation of the stencil operations.
--
-- Allocates the result array of extent @shOut@, then launches:
--
--   * the @stencil_inside@ kernel once over the interior region, whose extent
--     is @shOut - 2 * halo@ in every dimension; and
--
--   * the @stencil_border@ kernel once for every region returned by
--     'stencilBorders', each in its own forked computation.
--
-- The returned future is filled with the result array once all kernels have
-- been enqueued.
--
{-# INLINE stencilCore #-}
stencilCore
    :: forall aenv sh e params. HasCallStack
    => ArrayR (Array sh e)
    -> ExecutableR PTX
    -> Gamma aenv
    -> Val aenv
    -> sh                       -- border dimensions (i.e. index of first interior element)
    -> sh                       -- output array size
    -> ParamsR PTX params
    -> params
    -> Par PTX (Future (Array sh e))
stencilCore repr@(ArrayR shr _) exe gamma aenv halo shOut paramsR params =
  withExecutable exe $ \ptxExecutable -> do
    let
        inside  = ptxExecutable !# "stencil_inside"
        border  = ptxExecutable !# "stencil_border"

        -- extent of the interior region: strip the halo from both ends of
        -- every dimension
        shIn :: sh
        shIn = trav (\x y -> x - 2*y) shOut halo

        -- combine two shapes component-wise
        trav :: (Int -> Int -> Int) -> sh -> sh -> sh
        trav f a b = go (arrayRshape repr) a b
          where
            go :: ShapeR t -> t -> t -> t
            go ShapeRz           ()      ()      = ()
            go (ShapeRsnoc shr') (xa,xb) (ya,yb) = (go shr' xa ya, f xb yb)
    --
    future  <- new
    result  <- allocateRemote repr shOut
    parent  <- asks ptxStream

    -- interior (no bounds checking)
    let paramsRinside = TupRsingle (ParamRshape shr) `TupRpair` TupRsingle (ParamRarray repr) `TupRpair` paramsR
    executeOp inside gamma aenv shr shIn paramsRinside ((shIn, result), params)

    -- halo regions (bounds checking)
    -- executed in separate streams so that they might overlap the main stencil
    -- and each other, as individually they will not saturate the device
    forM_ (stencilBorders (arrayRshape repr) shOut halo) $ \(u, v) ->
      fork $ do
        -- launch in a separate stream
        -- (u, v) are the start and end indices of the region, so its extent
        -- is the component-wise difference
        let sh = trav (-) v u
        let paramsRborder = TupRsingle (ParamRshape shr) `TupRpair` TupRsingle (ParamRshape shr)
                              `TupRpair` TupRsingle (ParamRarray repr)
                              `TupRpair` paramsR
        executeOp border gamma aenv shr sh paramsRborder (((u, sh), result), params)

        -- synchronisation with main stream
        -- (only register the dependency if the border work has not already
        -- completed)
        child <- asks ptxStream
        event <- liftPar (Event.waypoint child)
        ready <- liftIO  (Event.query event)
        if ready then return ()
                 else liftIO (Event.after event parent)

    put future result
    return future

-- Compute the stencil border regions, where we may need to evaluate the
-- boundary conditions.
--
-- Given the array extent and the halo (border) size, this returns @2 * rank@
-- regions, each as a pair of (start index, end index). Counting dimensions
-- from the innermost (right-most) one as @d = 0@, faces @2*d@ and @2*d+1@ are
-- the low and high borders of dimension @d@ respectively. Within such a face:
--
--   * dimensions inner to @d@ are restricted to @[halo, size - halo)@, as
--     their borders are already covered by earlier faces;
--
--   * dimension @d@ itself spans @[0, halo)@ or @[size - halo, size)@; and
--
--   * dimensions outer to @d@ span their full range @[0, size)@.
--
{-# INLINE stencilBorders #-}
stencilBorders
    :: forall sh. HasCallStack
    => ShapeR sh
    -> sh
    -> sh
    -> [(sh, sh)]
stencilBorders shr sh halo = [ face i | i <- [0 .. (2 * rank shr - 1)] ]
  where
    face :: Int -> (sh, sh)
    face n = go n shr sh halo

    -- The face counter is decremented by two for every dimension we step
    -- outwards, so that it reads 0 or 1 exactly at the dimension this face
    -- belongs to, is >= 2 for inner dimensions, and negative for outer ones.
    go :: Int -> ShapeR t -> t -> t -> (t, t)
    go _ ShapeRz           ()         ()         = ((), ())
    go n (ShapeRsnoc shr') (sha, sza) (shb, szb)
      = let
            (sha', shb')  = go (n-2) shr' sha shb
            (sza', szb')
              | n <  0    = (0,       sza)
              | n == 0    = (0,       szb)
              | n == 1    = (sza-szb, sza)
              | otherwise = (szb,     sza-szb)
        in
        ((sha', sza'), (shb', szb'))


-- Foreign functions
--
-- Run the supplied foreign function @asm@ on its argument, wrapped in
-- 'Debug.monitorProcTime' on the current stream so that its execution time
-- can be recorded and reported. The array representation arguments are
-- unused; @name@ is only used in the debug trace message.
--
{-# INLINE aforeignOp #-}
aforeignOp
    :: HasCallStack
    => String
    -> ArraysR as
    -> ArraysR bs
    -> (as -> Par PTX (Future bs))
    -> as
    -> Par PTX (Future bs)
aforeignOp name _ _ asm arr = do
  stream <- asks ptxStream
  Debug.monitorProcTime query msg (Just (unsafeGetValue stream)) (asm arr)
  where
    -- Should we time this operation? Always when monitoring is enabled,
    -- otherwise only if the 'Debug.dump_exec' flag is set.
    query = if Debug.monitoringIsEnabled
              then return True
              else Debug.getFlag Debug.dump_exec

    -- Record the GPU time and emit a trace message
    msg wall cpu gpu = do
      Debug.addProcessorTime Debug.PTX gpu
      Debug.traceIO Debug.dump_exec $
        printf "exec: %s %s" name (Debug.elapsed wall cpu gpu)


-- Skeleton execution
-- ------------------

-- | Retrieve the named kernel
--
-- Raises an 'internalError' if the function table contains no kernel with the
-- given name.
--
(!#) :: HasCallStack => FunctionTable -> ShortByteString -> Kernel
(!#) exe name
  = fromMaybe (internalError ("function not found: " ++ unpack name))
  $ lookupKernel name exe

-- | Search the function table for the first kernel whose 'kernelName' equals
-- the given name.
--
lookupKernel :: ShortByteString -> FunctionTable -> Maybe Kernel
lookupKernel name ptxExecutable =
  find (\k -> kernelName k == name) (functionTable ptxExecutable)

-- | The extent of a possibly-delayed array: stored directly for a delayed
-- array, or read from the array itself when it is manifest.
--
delayedShape :: Delayed (Array sh e) -> sh
delayedShape (Delayed sh)  = sh
delayedShape (Manifest a)  = shape a

-- | Extract the underlying array if it is manifest; 'Nothing' for a delayed
-- array.
--
manifest :: Delayed (Array sh e) -> Maybe (Array sh e)
manifest (Manifest a) = Just a
manifest Delayed{}    = Nothing

-- | Execute some operation with the supplied executable functions
--
-- The continuation runs in a local environment where the second component of
-- the 'Par' state is set to this executable. The executable's lifetime is
-- touched after the continuation returns, so that it is kept alive at least
-- until then even though the continuation only sees the raw value obtained
-- via 'unsafeGetValue'.
--
withExecutable :: HasCallStack => ExecutableR PTX -> (FunctionTable -> Par PTX b) -> Par PTX b
withExecutable PTXR{..} f =
  local (\(s,_) -> (s,Just ptxExecutable)) $ do
    r <- f (unsafeGetValue ptxExecutable)
    liftIO $ touchLifetime ptxExecutable
    return r


-- Execute the function implementing this kernel.
--
-- The kernel is launched over @size shr sh@ elements on the current stream.
-- Nothing is launched when that size is zero. The kernel arguments are the
-- marshalled @params@ followed by the marshalled array environment.
--
executeOp
    :: HasCallStack
    => Kernel
    -> Gamma aenv
    -> Val aenv
    -> ShapeR sh
    -> sh
    -> ParamsR PTX params
    -> params
    -> Par PTX ()
executeOp kernel gamma aenv shr sh paramsR params =
  let n = size shr sh
  in  when (n > 0) $ do
        stream <- asks ptxStream
        argv   <- marshalParams' @PTX (paramsR `TupRpair` TupRsingle (ParamRenv gamma)) (params, aenv)
        liftIO  $ launch kernel stream n $ DL.toList argv


-- Execute a device function with the given thread configuration and function
-- parameters.
--
-- The launch configuration is one-dimensional: the grid has
-- @kernelThreadBlocks n@ blocks of @kernelThreadBlockSize@ threads each. The
-- stream is kept alive for the duration of the call via 'withLifetime', and
-- the launch is wrapped in 'Debug.monitorProcTime' so that its execution time
-- can be recorded and reported.
--
launch :: HasCallStack => Kernel -> Stream -> Int -> [CUDA.FunParam] -> IO ()
launch Kernel{..} stream n args =
  withLifetime stream $ \st ->
    Debug.monitorProcTime query msg (Just st) $
      CUDA.launchKernel kernelFun grid cta smem (Just st) args
  where
    cta   = (kernelThreadBlockSize, 1, 1)
    grid  = (kernelThreadBlocks n, 1, 1)
    smem  = kernelSharedMemBytes

    -- Debugging/monitoring support
    --
    -- Time the launch whenever monitoring is enabled, otherwise only if the
    -- 'Debug.dump_exec' flag is set.
    query = if Debug.monitoringIsEnabled
              then return True
              else Debug.getFlag Debug.dump_exec

    fst3 (x,_,_)      = x
    msg wall cpu gpu  = do
      Debug.addProcessorTime Debug.PTX gpu
      Debug.traceIO Debug.dump_exec $
        printf "exec: %s <<< %d, %d, %d >>> %s"
               (unpack kernelName) (fst3 grid) (fst3 cta) smem (Debug.elapsed wall cpu gpu)