{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
where
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Foreign.Marshal ( with )
import qualified Foreign.CUDA.Ptr as CUDA
import qualified Foreign.CUDA.BLAS as BLAS
import Control.Monad.Reader
-- | General matrix-matrix multiply, packaged as a foreign function for the
-- PTX backend.
--
-- The returned 'ForeignAcc' carries the label @"ptx.gemm"@ and delegates all
-- work to 'gemm''. Its argument is the nested tuple
-- @((((), alpha), matA), matB)@: a scalar array followed by the two operand
-- matrices.
--
gemm :: NumericR s e      -- ^ witness selecting the element type
     -> Transpose         -- ^ operation applied to the first matrix
     -> Transpose         -- ^ operation applied to the second matrix
     -> ForeignAcc (((((), Scalar e), Matrix e), Matrix e) -> Matrix e)
gemm nR opA opB = ForeignAcc "ptx.gemm" (gemm' nR opA opB)
-- | Worker for 'gemm': dispatches to the cuBLAS @?gemm@ routine matching the
-- element type witnessed by the 'NumericR' argument.
--
-- The scaling factor is read from index 0 of the scalar array and @beta@ is
-- always passed as 0, so (per the cuBLAS gemm contract) the freshly allocated
-- result receives @alpha * op(A) * op(B)@.
--
-- The operands are handed to cuBLAS in swapped order (@B@ before @A@, with
-- the dimensions given as @n m k@). cuBLAS works on column-major data; on the
-- assumption that Accelerate matrices are stored row-major, a row-major
-- @C = A B@ is the same memory as a column-major @C^T = B^T A^T@, which is
-- what the swapped call computes.
--
-- NOTE(review): the inner dimensions of the two operands are not checked for
-- agreement here; presumably the caller guarantees this.
--
gemm' :: NumericR s e
      -> Transpose
      -> Transpose
      -> ((((), Scalar e), Matrix e), Matrix e)
      -> Par PTX (Future (Matrix e))
gemm' nR opA opB ((((), alpha), matA), matB) = do
  let
      (((), rowsA), colsA) = shape matA
      (((), rowsB), colsB) = shape matB

      -- Logical extent of op(A) is m-by-k: taken as stored for 'N', swapped
      -- for any other 'Transpose' value.
      (m, k) = case opA of
                 N -> (rowsA, colsA)
                 _ -> (colsA, rowsA)

      -- Number of columns of op(B), and therefore of the result.
      n      = case opB of
                 N -> colsB
                 _ -> rowsB

      -- Leading dimensions are the column counts of the matrices as stored.
      lda    = colsA
      ldb    = colsB

      opA'   = encodeTranspose opA
      opB'   = encodeTranspose opB

      -- Representation of the result array, needed to allocate it.
      aR     = ArrayR dim2 eR
      eR     = case nR of
                 NumericRfloat32   -> eltR @Float
                 NumericRfloat64   -> eltR @Double
                 NumericRcomplex32 -> eltR @(Complex Float)
                 NumericRcomplex64 -> eltR @(Complex Double)
  --
  future <- new
  stream <- asks ptxStream
  matC   <- allocateRemote aR (((), m), n)
  alpha' <- indexRemote eR alpha 0
  ()     <- liftPar $
    withArray nR matA stream $ \ptr_A ->
    withArray nR matB stream $ \ptr_B ->
    withArray nR matC stream $ \ptr_C ->
    withBLAS                 $ \hdl   ->
      case nR of
        NumericRfloat32 -> liftIO $
          with alpha' $ \ptr_alpha ->
          with 0      $ \ptr_beta  ->
            BLAS.sgemm hdl opB' opA' n m k
                       ptr_alpha ptr_B ldb ptr_A lda
                       ptr_beta ptr_C n

        NumericRfloat64 -> liftIO $
          with alpha' $ \ptr_alpha ->
          with 0      $ \ptr_beta  ->
            BLAS.dgemm hdl opB' opA' n m k
                       ptr_alpha ptr_B ldb ptr_A lda
                       ptr_beta ptr_C n

        -- For the complex cases the scalar is marshalled with 'withV2' and
        -- the device pointers are cast to pointers to complex elements.
        NumericRcomplex32 -> liftIO $
          withV2 nR alpha' $ \ptr_alpha ->
          with 0           $ \ptr_beta  ->
            BLAS.cgemm hdl opB' opA' n m k
                       ptr_alpha (CUDA.castDevPtr ptr_B) ldb (CUDA.castDevPtr ptr_A) lda
                       ptr_beta (CUDA.castDevPtr ptr_C) n

        NumericRcomplex64 -> liftIO $
          withV2 nR alpha' $ \ptr_alpha ->
          with 0           $ \ptr_beta  ->
            BLAS.zgemm hdl opB' opA' n m k
                       ptr_alpha (CUDA.castDevPtr ptr_B) ldb (CUDA.castDevPtr ptr_A) lda
                       ptr_beta (CUDA.castDevPtr ptr_C) n
  --
  put future matC
  return future