{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Foreign.Marshal ( with )
import qualified Foreign.CUDA.Ptr as CUDA
import qualified Foreign.CUDA.BLAS as BLAS
-- | General matrix-matrix multiply, exposed as a foreign function for the PTX
-- backend.
--
-- The two 'Transpose' flags are forwarded unchanged to the worker @gemm'@,
-- which receives the scaling factor and the two operand matrices as a triple
-- and produces the result matrix. The foreign function is registered under
-- the label @"ptx.gemm"@.
--
gemm :: Numeric e
     => Transpose     -- ^ operation applied to the first matrix
     -> Transpose     -- ^ operation applied to the second matrix
     -> ForeignAcc ((Scalar e, Matrix e, Matrix e) -> Matrix e)
gemm transA transB = ForeignAcc "ptx.gemm" $ gemm' transA transB
-- | Worker for 'gemm': allocates the result matrix on the device, reads the
-- scaling factor out of its scalar array, and dispatches to the cuBLAS
-- @?gemm@ routine that matches the element type, with @beta@ fixed to zero.
--
-- Note that the operands are handed to cuBLAS in swapped order (@B@ before
-- @A@, and @n@ before @m@), and every leading dimension is a column count.
-- Presumably this is because Accelerate matrices are stored row-major whereas
-- cuBLAS expects column-major data, so the call computes the transposed
-- product in cuBLAS terms -- TODO confirm against the array representation.
--
-- NOTE(review): for a matrix with zero columns the leading dimension passed
-- to cuBLAS is 0; verify how cuBLAS treats that input.
--
gemm'
    :: forall e. Numeric e
    => Transpose
    -> Transpose
    -> Stream
    -> (Scalar e, Matrix e, Matrix e)
    -> LLVM PTX (Matrix e)
gemm' opA opB stream (alpha, matA, matB) = do
  result <- allocateRemote (Z :. m :. n) :: LLVM PTX (Matrix e)
  scale  <- indexRemote alpha 0
  ()     <- withArray matA   stream $ \devA ->
            withArray matB   stream $ \devB ->
            withArray result stream $ \devC ->
            withBLAS                $ \handle ->
              -- Matching on the representation witness refines 'e' to a
              -- concrete element type in each alternative.
              case numericR :: NumericR e of
                NumericRfloat32 -> liftIO $
                  withScalars scale 0 $ \pAlpha pBeta ->
                    BLAS.sgemm handle transB transA n m k pAlpha devB ldb devA lda pBeta devC n

                NumericRfloat64 -> liftIO $
                  withScalars scale 0 $ \pAlpha pBeta ->
                    BLAS.dgemm handle transB transA n m k pAlpha devB ldb devA lda pBeta devC n

                NumericRcomplex32 -> liftIO $
                  withScalars scale 0 $ \pAlpha pBeta ->
                    BLAS.cgemm handle transB transA n m k pAlpha (CUDA.castDevPtr devB) ldb (CUDA.castDevPtr devA) lda pBeta (CUDA.castDevPtr devC) n

                NumericRcomplex64 -> liftIO $
                  withScalars scale 0 $ \pAlpha pBeta ->
                    BLAS.zgemm handle transB transA n m k pAlpha (CUDA.castDevPtr devB) ldb (CUDA.castDevPtr devA) lda pBeta (CUDA.castDevPtr devC) n
  --
  return result
  where
    -- Extents of the operands exactly as they are stored.
    Z :. rowsA :. colsA = arrayShape matA
    Z :. rowsB :. colsB = arrayShape matB

    -- Extents after applying the requested operation: the result is (m x n)
    -- and k is the contracted dimension taken from the first operand.
    (m, k) = orient opA (rowsA, colsA)
    (_, n) = orient opB (rowsB, colsB)

    -- Leading dimensions are the stored column counts, independent of the
    -- transpose flags.
    lda = colsA
    ldb = colsB

    transA = encodeTranspose opA
    transB = encodeTranspose opB

    -- Keep (rows, columns) for 'N'; swap them for any other operation.
    orient N dims   = dims
    orient _ (r, c) = (c, r)

    -- Place the two scalar arguments in temporary host memory for the
    -- duration of the continuation.
    withScalars a b cont =
      with a $ \pa ->
      with b $ \pb ->
        cont pa pb