{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications    #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level3
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level3
  where

import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Sugar.Elt

import Data.Array.Accelerate.LLVM.Native.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Base

import Foreign.Ptr
import qualified Blas.Primitive.Types                               as C
import qualified Blas.Primitive.Unsafe                              as C


-- TODO: check whether it is faster to compute this as column-major order:
--
-- https://www.christophlassner.de/using-blas-from-c-with-row-major-data.html
--
-- | Package the native matrix-matrix multiply as a foreign Accelerate
-- function, registered under the name @"native.gemm"@.
--
-- The 'NumericR' witness selects the element type, and the two 'Transpose'
-- arguments describe the operation applied to the first and second matrix
-- operand respectively. The wrapped function receives, as a nested tuple,
-- the scalar multiplier followed by the two input matrices; all of the work
-- is delegated to 'gemm''.
--
gemm :: NumericR s e
     -> Transpose
     -> Transpose
     -> ForeignAcc (((((), Scalar e), Matrix e), Matrix e) -> Matrix e)
gemm nR opA opB = ForeignAcc "native.gemm" (gemm' nR opA opB)

-- | Worker for 'gemm': allocate the result matrix, call the BLAS @?gemm@
-- routine matching the element type, and return the result as a filled
-- future.
--
-- Arguments:
--
--   * the 'NumericR' witness, used both to recover the element 'TypeR' and
--     to pick @sgemm@ / @dgemm@ / @cgemm@ / @zgemm@;
--
--   * the operation to apply to the first and to the second matrix;
--
--   * a nested tuple of the scalar @alpha@ and the matrices @matA@, @matB@.
--
-- The @beta@ argument handed to BLAS is always @0@, so the result is
-- presumably @alpha * op(A) * op(B)@ with the freshly allocated output never
-- being read (standard GEMM semantics -- confirm against the BLAS binding).
--
-- NOTE(review): the inner dimension @k@ is taken from @matA@ only; nothing
-- in this function checks it against the corresponding extent of @matB@.
-- Presumably the caller guarantees conforming shapes -- verify.
--
gemm' :: NumericR s e
      -> Transpose
      -> Transpose
      -> ((((), Scalar e), Matrix e), Matrix e)
      -> Par Native (Future (Matrix e))
gemm' nR opA opB ((((), alpha), matA), matB) = do
  let
      (((), rowsA), colsA) = shape matA
      (((), rowsB), colsB) = shape matB

      -- Extent of op(A): m rows by k columns. Any operation other than 'N'
      -- swaps the stored extents.
      (m,k)   = case opA of
                  N -> (rowsA, colsA)
                  _ -> (colsA, rowsA)

      -- Number of columns of op(B), and hence of the result.
      n       = case opB of
                  N -> colsB
                  _ -> rowsB

      -- Leading dimensions are the stored column counts of the inputs,
      -- independent of the requested operation (the calls below pass
      -- 'C.RowMajor').
      lda     = colsA
      ldb     = colsB

      opA'    = encodeTranspose opA
      opB'    = encodeTranspose opB

      -- Read the single element out of the rank-0 array.
      alpha'  = indexArray (ArrayR dim0 eR) alpha ()

      aR      = ArrayR dim2 eR
      eR      = case nR of
                  NumericRfloat32   -> eltR @Float
                  NumericRfloat64   -> eltR @Double
                  NumericRcomplex32 -> eltR @(Complex Float)
                  NumericRcomplex64 -> eltR @(Complex Double)

  --
  future <- new
  matC   <- allocateRemote aR (((), m), n)
  ()     <- liftIO $ do
    withArray nR matA   $ \ptr_A -> do
     withArray nR matB  $ \ptr_B -> do
      withArray nR matC $ \ptr_C -> do
        case nR of
          NumericRfloat32   -> C.sgemm C.RowMajor opA' opB' m n k alpha' ptr_A lda ptr_B ldb 0 ptr_C n
          NumericRfloat64   -> C.dgemm C.RowMajor opA' opB' m n k alpha' ptr_A lda ptr_B ldb 0 ptr_C n
          -- For the complex cases the array pointers are pointers to the
          -- component type, so they are cast to pointers to 'Complex', and
          -- alpha is converted from its representation type with 'toElt'.
          -- (Assumes the array payload is laid out as BLAS expects complex
          -- data -- TODO confirm against 'withArray'.)
          NumericRcomplex32 -> C.cgemm C.RowMajor opA' opB' m n k (toElt alpha') (castPtr ptr_A) lda (castPtr ptr_B) ldb 0 (castPtr ptr_C) n
          NumericRcomplex64 -> C.zgemm C.RowMajor opA' opB' m n k (toElt alpha') (castPtr ptr_A) lda (castPtr ptr_B) ldb 0 (castPtr ptr_C) n
  --
  put future matC
  return future