{-# LANGUAGE CPP                 #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NoImplicitPrelude   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
{-# LANGUAGE TypeApplications    #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level3
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Level 3 (matrix-matrix) BLAS operations.
--

module Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level3 (

  -- Types
  Numeric, Matrix, Transpose(..),

  -- Matrix-matrix operations
  gemm,

) where

import Data.Array.Accelerate                                        as A
import Data.Array.Accelerate.Data.Complex                           as A
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type

#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level3 as CPU
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3    as PTX
#endif


-- | General matrix-matrix multiply
--
-- \[
-- C = \alpha * \mathrm{op}(A) * \mathrm{op}(B)
-- \]
--
-- where:
--
--   * 'shape' \(\mathrm{op}(A)\) @= Z :. m :. k@
--   * 'shape' \(\mathrm{op}(B)\) @= Z :. k :. n@
--   * 'shape' \(C\) @= Z :. m :. n@
--
-- <https://software.intel.com/en-us/mkl-developer-reference-c-cblas-gemm>
--
gemm :: forall e. Numeric e
     => Exp e                 -- ^ \( \alpha \)
     -> Transpose             -- ^ operation to apply to A
     -> Acc (Matrix e)        -- ^ A
     -> Transpose             -- ^ operation to apply to B
     -> Acc (Matrix e)        -- ^ B
     -> Acc (Matrix e)        -- ^ C
gemm alpha opA matA opB matB = go (lift (unit alpha, matA, matB))
  where
    -- Wrap the pure Accelerate implementation in one 'foreignAcc' layer per
    -- backend that was enabled at compile time. When neither CPP flag is
    -- defined, 'go' is just the pure implementation. The scalar component of
    -- the tuple (alpha) is ignored by the lambda because 'mXm' refers to
    -- 'alpha' from the enclosing scope directly.
    go :: Acc (Scalar e, Matrix e, Matrix e) -> Acc (Matrix e)
    go =
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
      foreignAcc (CPU.gemm nR opA opB) $
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
      foreignAcc (PTX.gemm nR opA opB) $
#endif
      (\(T3 _ arr brr) -> mXm arr brr)

    -- Runtime witness of the element type; used to select the foreign
    -- routine and to detect the complex cases for conjugation below.
    nR = numericR @e

    -- General dense matrix-matrix multiply written in pure Accelerate. This is
    -- not efficient due to the memory access patterns. We could probably
    -- improve this a little bit with a divide-and-conquer algorithm, for
    -- example, but using a foreign implementation will be best.
    --
    mXm :: Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
    mXm arr brr
      = fold (+) 0
      $ zipWith (\a b -> alpha * a * b) arrRepl brrRepl
      where
        -- arr' is op(A); brr' is op(B) already transposed, so the leading
        -- extent of each is the row count of op(A) and the column count of
        -- op(B) respectively.
        Z :. rowsA :. _ = unlift (shape arr') :: Z :. Exp Int :. Exp Int
        Z :. colsB :. _ = unlift (shape brr') :: Z :. Exp Int :. Exp Int
        --
        -- Replicate both operands to a common rank-3 shape so that the
        -- innermost dimension can be multiplied element-wise and then
        -- reduced by the 'fold' above.
        arrRepl         = replicate (lift $ Z :. All   :. colsB :. All) arr'
        brrRepl         = replicate (lift $ Z :. rowsA :. All   :. All) brr'

        -- apply opA
        arr' :: Acc (Matrix e)
        arr' = case opA of
                 N -> arr
                 T -> transpose arr
                 H -> case nR of
                        NumericRcomplex32 -> map conjugate (transpose arr)
                        NumericRcomplex64 -> map conjugate (transpose arr)
                        _                 -> transpose arr

        -- apply opB and transpose at the same time, which is required for this
        -- algorithm
        brr' :: Acc (Matrix e)
        brr' = case opB of
                 N -> transpose brr
                 T -> brr
                 H -> case nR of
                        NumericRcomplex32 -> map conjugate brr
                        NumericRcomplex64 -> map conjugate brr
                        _                 -> brr