{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level3 (
Numeric, Matrix, Transpose(..),
gemm,
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Smart as A
import Data.Array.Accelerate.Data.Complex as A
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level3 as CPU
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3 as PTX
#endif
-- | General matrix-matrix multiply.
--
-- Computes @alpha * op(A) * op(B)@, where each @op@ is selected by the
-- corresponding 'Transpose' argument:
--
--   * 'N': use the matrix as given
--   * 'T': use its transpose
--   * 'H': use its conjugate transpose; for the non-complex element
--          representations this is just the transpose
--
-- When the library is built with @ACCELERATE_LLVM_NATIVE_BACKEND@ and/or
-- @ACCELERATE_LLVM_PTX_BACKEND@ defined, the computation is wrapped in
-- 'foreignAcc' with the matching backend routine. The pure Accelerate
-- implementation below is always supplied as the innermost alternative.
--
gemm :: forall e. Numeric e
     => Exp e                 -- ^ scalar multiplier @alpha@
     -> Transpose             -- ^ operation applied to @A@
     -> Acc (Matrix e)        -- ^ @A@
     -> Transpose             -- ^ operation applied to @B@
     -> Acc (Matrix e)        -- ^ @B@
     -> Acc (Matrix e)        -- ^ @alpha * op(A) * op(B)@
gemm alpha opA matA opB matB = dispatch (lift (unit alpha, matA, matB))
  where
    -- Stack the foreign implementations that were enabled at compile time on
    -- top of the pure implementation. If both backends are enabled, the
    -- native one is outermost and the PTX one is nested as its alternative.
    dispatch =
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
      foreignAcc (CPU.gemm opA opB) $
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
      foreignAcc (PTX.gemm opA opB) $
#endif
      -- The packed scalar is ignored here: the pure version closes over
      -- 'alpha' directly.
      (\(unatup3 -> (_, lhs, rhs)) -> matMul lhs rhs)

    -- Element-wise complex conjugate for the two complex representations;
    -- returns the matrix unchanged for every other element type.
    conjugated :: Acc (Matrix e) -> Acc (Matrix e)
    conjugated mat =
      case numericR :: NumericR e of
        NumericRcomplex32 -> map conjugate mat
        NumericRcomplex64 -> map conjugate mat
        _                 -> mat

    -- Dense matrix-matrix multiply written in pure Accelerate. Both operands
    -- are brought into a layout where the contraction index is innermost,
    -- replicated to a common three-dimensional shape, multiplied
    -- element-wise (scaled by 'alpha') and then summed along the innermost
    -- dimension.
    matMul :: Acc (Matrix e) -> Acc (Matrix e) -> Acc (Matrix e)
    matMul lhs rhs =
      let
          -- op(A)
          lhs' = case opA of
                   N -> lhs
                   T -> transpose lhs
                   H -> conjugated (transpose lhs)

          -- op(B), additionally transposed so that its rows line up with the
          -- rows of op(A). Hence 'N' needs an explicit transpose while 'T'
          -- and 'H' need none.
          rhs' = case opB of
                   N -> transpose rhs
                   T -> rhs
                   H -> conjugated rhs

          -- Outer extents of the two prepared operands; these become the
          -- extents of the result.
          Z :. m :. _ = unlift (shape lhs') :: Z :. Exp Int :. Exp Int
          Z :. n :. _ = unlift (shape rhs') :: Z :. Exp Int :. Exp Int

          -- Replicate both operands to the common shape (m, n, inner).
          lhsCube = replicate (lift (Z :. All :. n   :. All)) lhs'
          rhsCube = replicate (lift (Z :. m   :. All :. All)) rhs'
      in
      fold (+) 0 (zipWith (\x y -> alpha * x * y) lhsCube rhsCube)