{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.BLAS.Level2 (
Numeric, Vector, Matrix, Transpose(..),
gemv,
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Smart as A
import Data.Array.Accelerate.Data.Complex as A
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level2 as CPU
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
import qualified Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level2 as PTX
#endif
-- | General matrix-vector product (BLAS level 2 @gemv@).
--
-- Computes @y = alpha * op(A) * x@, where @op(A)@ is chosen by the
-- @Transpose@ argument:
--
--   * @N@ uses @A@ unchanged;
--   * @T@ uses the transpose of @A@;
--   * @H@ uses the conjugate transpose of @A@ (for the real element types
--     this is simply the transpose).
--
-- When built with @ACCELERATE_LLVM_NATIVE_BACKEND@ and/or
-- @ACCELERATE_LLVM_PTX_BACKEND@ defined, the computation is wrapped in
-- @foreignAcc@ layers for the corresponding @gemv@ implementations.
-- @foreignAcc@ falls through to its second argument on a backend that does
-- not support the given foreign function, so every other backend ends up
-- running the pure-Accelerate definition below.
--
gemv :: forall e. Numeric e
     => Exp e             -- ^ scaling factor @alpha@
     -> Transpose         -- ^ operation applied to the matrix
     -> Acc (Matrix e)    -- ^ matrix @A@
     -> Acc (Vector e)    -- ^ vector @x@
     -> Acc (Vector e)    -- ^ result @y = alpha * op(A) * x@
gemv alpha opA matA x = run $ lift (unit alpha, matA, x)
  where
    -- Each enabled backend adds one foreignAcc layer; the innermost layer
    -- is the pure fallback. Alpha travels in the argument tuple as a
    -- singleton array, presumably so the foreign kernels receive it as an
    -- array input (TODO confirm against the CPU/PTX modules).
    run =
#ifdef ACCELERATE_LLVM_NATIVE_BACKEND
      foreignAcc (CPU.gemv opA) $
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
      foreignAcc (PTX.gemv opA) $
#endif
      fallback

    -- Pure-Accelerate fallback. The singleton alpha array in the first
    -- slot is ignored; the captured @alpha@ expression is used instead.
    fallback :: Acc (Scalar e, Matrix e, Vector e) -> Acc (Vector e)
    fallback (unatup3 -> (_, mat, vec)) = matVec mat vec

    -- Dense matrix-vector product: replicate @x@ once per row of @op(A)@,
    -- multiply elementwise (scaling every term by @alpha@), then sum each
    -- row with a fold over the innermost dimension.
    --
    -- NOTE(review): zipWith operates on the intersection of its input
    -- shapes, so if the length of @x@ differs from the column count of
    -- @op(A)@ the surplus elements are dropped instead of being reported.
    matVec :: Acc (Matrix e) -> Acc (Vector e) -> Acc (Vector e)
    matVec mat vec =
      let opMat          = applyOp mat
          Z :. rows :. _ = unlift (shape opMat) :: Z :. Exp Int :. Exp Int
          xs             = replicate (lift (Z :. rows :. All)) vec
          terms          = zipWith (\aij xj -> alpha * aij * xj) opMat xs
      in  fold (+) 0 terms

    -- Materialise op(A) according to the requested operation.
    applyOp :: Acc (Matrix e) -> Acc (Matrix e)
    applyOp mat =
      case opA of
        N -> mat
        T -> transpose mat
        H -> conjTranspose mat

    -- Conjugate transpose: conjugation only applies to the complex element
    -- types; for the remaining element types it reduces to a transpose.
    conjTranspose :: Acc (Matrix e) -> Acc (Matrix e)
    conjTranspose mat =
      case numericR :: NumericR e of
        NumericRcomplex32 -> map conjugate (transpose mat)
        NumericRcomplex64 -> map conjugate (transpose mat)
        _                 -> transpose mat