{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level3
where
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.LLVM.Native.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Base
import Foreign.Ptr
import qualified Blas.Primitive.Types as C
import qualified Blas.Primitive.Unsafe as C
-- | Native (CPU) implementation of the BLAS level-3 general matrix-matrix
-- product
--
-- > C = alpha * op(A) * op(B)
--
-- backed by the CBLAS @?gemm@ family (@sgemm@, @dgemm@, @cgemm@, @zgemm@,
-- chosen by the 'NumericR' witness).
--
-- The two 'Transpose' flags select @op(A)@ and @op(B)@. 'N' uses the matrix
-- as stored. Any other flag uses the form selected by 'encodeTranspose'.
--
-- All matrices are passed to BLAS as row-major. The result @C@ is a freshly
-- allocated @m × n@ matrix. @beta@ is passed as zero, so per the BLAS
-- contract the initial contents of @C@ are not used.
gemm :: NumericR s e
     -> Transpose
     -> Transpose
     -> ForeignAcc (((((), Scalar e), Matrix e), Matrix e) -> Matrix e)
gemm nR opA opB = ForeignAcc "native.gemm" (gemm' nR opA opB)
  where
    -- Worker run by the native backend. It:
    --   1. allocates the output C,
    --   2. calls the precision-specific ?gemm routine on the raw buffers,
    --   3. resolves the returned future with C.
    gemm' :: NumericR s e
          -> Transpose
          -> Transpose
          -> ((((), Scalar e), Matrix e), Matrix e)
          -> Par Native (Future (Matrix e))
    gemm' nR opA opB ((((), alpha), matA), matB) = do
      let
          (((), rowsA), colsA) = shape matA
          (((), rowsB), colsB) = shape matB

          -- op(A) is m × kA and op(B) is kB × n.
          (m, kA) = case opA of
                      N -> (rowsA, colsA)
                      _ -> (colsA, rowsA)
          (kB, n) = case opB of
                      N -> (rowsB, colsB)
                      _ -> (colsB, rowsB)

          -- Shared inner dimension.
          -- Taking it from A alone would let BLAS read past the end of B
          -- whenever op(B) has fewer rows than op(A) has columns.
          -- Using the smaller of the two keeps every access in bounds.
          -- This is identical to the original whenever the shapes agree.
          k = min kA kB

          -- Row-major leading dimension = stored row width.
          -- BLAS requires LDx >= max(1, ...), so a zero-width matrix would
          -- be rejected as an illegal parameter and no product computed.
          -- Clamping to 1 keeps those calls valid; with beta = 0 the
          -- output is then zero-filled.
          lda = max 1 colsA
          ldb = max 1 colsB
          ldc = max 1 n

          opA' = encodeTranspose opA
          opB' = encodeTranspose opB

          -- Extract the scalar multiplier from its singleton array.
          alpha' = indexArray (ArrayR dim0 eR) alpha ()
          aR     = ArrayR dim2 eR
          eR     = case nR of
                     NumericRfloat32   -> eltR @Float
                     NumericRfloat64   -> eltR @Double
                     NumericRcomplex32 -> eltR @(Complex Float)
                     NumericRcomplex64 -> eltR @(Complex Double)

      future <- new
      matC   <- allocateRemote aR (((), m), n)
      liftIO $
        withArray nR matA $ \ptr_A ->
        withArray nR matB $ \ptr_B ->
        withArray nR matC $ \ptr_C ->
          case nR of
            NumericRfloat32   ->
              C.sgemm C.RowMajor opA' opB' m n k
                      alpha' ptr_A lda ptr_B ldb
                      0 ptr_C ldc
            NumericRfloat64   ->
              C.dgemm C.RowMajor opA' opB' m n k
                      alpha' ptr_A lda ptr_B ldb
                      0 ptr_C ldc
            -- Complex buffers are exposed as pointers to their underlying
            -- real components, so they are cast to the complex pointer
            -- type the binding expects.
            NumericRcomplex32 ->
              C.cgemm C.RowMajor opA' opB' m n k
                      (toElt alpha') (castPtr ptr_A) lda (castPtr ptr_B) ldb
                      0 (castPtr ptr_C) ldc
            NumericRcomplex64 ->
              C.zgemm C.RowMajor opA' opB' m n k
                      (toElt alpha') (castPtr ptr_A) lda (castPtr ptr_B) ldb
                      0 (castPtr ptr_C) ldc
      put future matC
      return future