{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
  where

import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Sugar.Elt

import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type

import Foreign.Marshal                                              ( with )
import qualified Foreign.CUDA.Ptr                                   as CUDA
import qualified Foreign.CUDA.BLAS                                  as BLAS

import Control.Monad.Reader


-- NOTE: cuBLAS requires that matrices are stored in column-major order
-- (Fortran-style), but Accelerate uses a C-style convention where matrices are
-- stored in row-major order.
--
-- At least for matrix-matrix multiply, we can get around this problem by making
-- use of the equivalence \( B^T \cdot A^T = (A \cdot B)^T \).
--
-- | Foreign-function entry point for general matrix-matrix multiply on the
-- PTX backend, registered under the name @"ptx.gemm"@.
--
-- The 'NumericR' witness selects the element type, and the two 'Transpose'
-- arguments are the operations to apply to the first and second input matrix
-- respectively. The resulting foreign function receives its arguments as a
-- nested tuple of a scalar (the @alpha@ scaling factor) followed by the two
-- input matrices, and all of the work is delegated to 'gemm''.
--
gemm :: NumericR s e
     -> Transpose
     -> Transpose
     -> ForeignAcc (((((), Scalar e), Matrix e), Matrix e) -> Matrix e)
gemm nR opA opB = ForeignAcc "ptx.gemm" (gemm' nR opA opB)

-- | Implementation of 'gemm': allocate a fresh result matrix on the device and
-- fill it by calling the cuBLAS @gemm@ routine matching the element type.
--
-- Arguments, in order:
--
--   * the 'NumericR' witness, which selects between the @sgemm@, @dgemm@,
--     @cgemm@ and @zgemm@ variants;
--
--   * the operation to apply to the first matrix (@opA@);
--
--   * the operation to apply to the second matrix (@opB@);
--
--   * the nested tuple @((((), alpha), matA), matB)@ holding the scaling
--     factor and the two input matrices.
--
-- The result is a 'Future' which is filled with the newly allocated output
-- matrix of shape @m x n@, where @op(A)@ is @m x k@ and @op(B)@ is @k x n@.
-- The @beta@ parameter passed to cuBLAS is always zero, so the previous
-- contents of the output buffer do not contribute to the result.
--
-- NOTE(review): the inner dimension @k@ is taken from @matA@ only; nothing
-- here checks that it agrees with the corresponding extent of @matB@ --
-- presumably callers guarantee this, confirm against the call sites.
--
gemm' :: NumericR s e
      -> Transpose
      -> Transpose
      -> ((((), Scalar e), Matrix e), Matrix e)
      -> Par PTX (Future (Matrix e))
gemm' nR opA opB ((((), alpha), matA), matB) = do
  let
      (((), rowsA), colsA) = shape matA
      (((), rowsB), colsB) = shape matB

      -- Logical extents after applying the requested operations:
      -- op(A) is (m x k) and op(B) is (k x n).
      (m,k)   = case opA of
                  N -> (rowsA, colsA)
                  _ -> (colsA, rowsA)
      n       = case opB of
                  N -> colsB
                  _ -> rowsB

      -- The leading dimension is the stored row length of each (row-major)
      -- input, independent of the operation applied to it.
      lda     = colsA
      ldb     = colsB

      opA'    = encodeTranspose opA
      opB'    = encodeTranspose opB

      -- Representation of the result array; the element representation is
      -- recovered by matching on the numeric witness.
      aR      = ArrayR dim2 eR
      eR      = case nR of
                  NumericRfloat32   -> eltR @Float
                  NumericRfloat64   -> eltR @Double
                  NumericRcomplex32 -> eltR @(Complex Float)
                  NumericRcomplex64 -> eltR @(Complex Double)

  --
  future <- new
  stream <- asks ptxStream
  matC   <- allocateRemote aR (((), m), n)
  alpha' <- indexRemote eR alpha 0      -- element 0 of the scalar array
  ()     <- liftPar $
    withArray nR matA stream   $ \ptr_A -> do
     withArray nR matB stream  $ \ptr_B -> do
      withArray nR matC stream $ \ptr_C -> do
        withBLAS               $ \hdl   -> do
          -- cuBLAS expects column-major storage whereas these arrays are
          -- row-major. Using B^T . A^T = (A . B)^T, every call below passes
          -- the operands (and their operations) in swapped order, with the
          -- output extents given as (n, m) and leading dimension n.
          case nR of
            NumericRfloat32 -> liftIO $
              with alpha' $ \ptr_alpha ->
               with 0     $ \ptr_beta  ->
                BLAS.sgemm hdl opB' opA' n m k
                           ptr_alpha ptr_B ldb ptr_A lda
                           ptr_beta  ptr_C n

            NumericRfloat64 -> liftIO $
              with alpha' $ \ptr_alpha ->
               with 0     $ \ptr_beta  ->
                BLAS.dgemm hdl opB' opA' n m k
                           ptr_alpha ptr_B ldb ptr_A lda
                           ptr_beta  ptr_C n

            -- For the complex cases the device pointers are typed at the
            -- component type, so they are cast to pointers to 'Complex'
            -- before being handed to cuBLAS; 'withV2' marshals the Vec2
            -- scaling factor as a pointer to 'Complex'.
            NumericRcomplex32 -> liftIO $
              withV2 nR alpha' $ \ptr_alpha ->
               with 0          $ \ptr_beta  ->
                BLAS.cgemm hdl opB' opA' n m k
                           ptr_alpha (CUDA.castDevPtr ptr_B) ldb (CUDA.castDevPtr ptr_A) lda
                           ptr_beta  (CUDA.castDevPtr ptr_C) n

            NumericRcomplex64 -> liftIO $
              withV2 nR alpha' $ \ptr_alpha ->
               with 0          $ \ptr_beta  ->
                BLAS.zgemm hdl opB' opA' n m k
                           ptr_alpha (CUDA.castDevPtr ptr_B) ldb (CUDA.castDevPtr ptr_A) lda
                           ptr_beta  (CUDA.castDevPtr ptr_C) n
  --
  put future matC
  return future