{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level3 (
Numeric, Matrix, Transpose(..),
gemm,
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex as A
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level3 as CPU
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3 as PTX
#endif
-- | General matrix-matrix multiply.
--
-- The pure Accelerate implementation below computes
--
-- > C = alpha * op(A) * op(B)
--
-- where @op(X)@ is chosen by the 'Transpose' flag that accompanies each
-- matrix:
--
--   * 'N': use the matrix as given;
--   * 'T': use its transpose;
--   * 'H': use its conjugate transpose when the element type is complex,
--          and the plain transpose for every other element type.
--
-- When the module is compiled with @ACCELERATE_LLVM_NATIVE_BACKEND@ and/or
-- @ACCELERATE_LLVM_PTX_BACKEND@ defined, the computation is wrapped in
-- 'foreignAcc' so that the backend-specific @gemm@ can be used in its place;
-- the pure implementation is then the fallback handed to 'foreignAcc'.
--
-- NOTE(review): the foreign implementations are presumably equivalent to the
-- fallback; they live in other modules and cannot be confirmed from here.
--
gemm :: forall e. Numeric e
     => Exp e                 -- ^ scaling factor @alpha@
     -> Transpose             -- ^ operation to apply to A
     -> Acc (Matrix e)        -- ^ matrix A
     -> Transpose             -- ^ operation to apply to B
     -> Acc (Matrix e)        -- ^ matrix B
     -> Acc (Matrix e)
gemm alpha opA matA opB matB = go (lift (unit alpha, matA, matB))
  where
    -- 'foreignAcc' takes a single array-level argument, so alpha is boxed
    -- into a scalar array with 'unit' and tupled with the two matrices.
    -- Each enabled backend adds one 'foreignAcc' layer whose fallback is
    -- everything that follows it; the innermost fallback is 'mXm'.
    go =
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
      foreignAcc (CPU.gemm nR opA opB) $
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
      foreignAcc (PTX.gemm nR opA opB) $
#endif
      (\(T3 _ arr brr) -> mXm arr brr)

    -- Runtime witness of the element type. It is passed to the foreign
    -- implementations and inspected below to decide whether 'H' needs a
    -- conjugation.
    nR = numericR @e

    -- Dense matrix-matrix multiply written in pure Accelerate. Both operands
    -- are replicated into a common rank-3 array, multiplied element-wise
    -- (together with alpha, which is taken from the enclosing scope rather
    -- than from the tuple), and the innermost dimension is summed away.
    --
    mXm :: Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
    mXm arr brr
      = fold (+) 0
      $ zipWith (\a b -> alpha * a * b) arrRepl brrRepl
      where
        -- Outer extents of the two prepared operands; these become the
        -- extents of the result.
        Z :. rowsA :. _ = unlift (shape arr') :: Z :. Exp Int :. Exp Int
        Z :. colsB :. _ = unlift (shape brr') :: Z :. Exp Int :. Exp Int

        -- Insert a new middle dimension into arr' and a new outer dimension
        -- into brr', so that both line up as (rowsA, colsB, inner).
        arrRepl = replicate (lift $ Z :. All   :. colsB :. All) arr'
        brrRepl = replicate (lift $ Z :. rowsA :. All   :. All) brr'

        -- op(A)
        arr' = case opA of
          N -> arr
          T -> transpose arr
          H -> case nR of
            NumericRcomplex32 -> map conjugate (transpose arr)
            NumericRcomplex64 -> map conjugate (transpose arr)
            _                 -> transpose arr

        -- op(B), additionally transposed: the replicate/zipWith/fold scheme
        -- above needs the shared dimension innermost in both operands. The
        -- two transposes cancel for 'T' and 'H', which is why 'N' is the
        -- only case that calls 'transpose' here.
        brr' = case opB of
          N -> transpose brr
          T -> brr
          H -> case nR of
            NumericRcomplex32 -> map conjugate brr
            NumericRcomplex64 -> map conjugate brr
            _                 -> brr