{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Level2
where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.LLVM.Native.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.Native.Base
import Foreign.Marshal.Alloc
import Foreign.Storable
import Foreign.Storable.Complex ( )
import qualified Blas.Primitive.Types as C
import qualified Blas.Primitive.Unsafe as C
-- | Foreign (native CPU) implementation of the BLAS level-2 matrix-vector
-- product.
--
-- The wrapped function receives a scalar @alpha@, a matrix @A@ and a vector
-- @x@, allocates a fresh result vector @y@, and calls the @?gemv@ routine
-- matching the element type with row-major storage, unit strides and
-- @beta = 0@ (so the initial contents of @y@ should not contribute to the
-- result, per the BLAS @gemv@ contract).
--
-- The result has one element per row of @A@ when the requested operation is
-- @N@, and one element per column of @A@ for every other 'Transpose' value.
--
-- NOTE(review): the length of @x@ is not checked against the shape of @A@
-- here; presumably the caller guarantees compatible sizes -- confirm.
gemv :: forall e. Numeric e
     => Transpose
     -> ForeignAcc ((Scalar e, Matrix e, Vector e) -> Vector e)
gemv opA = ForeignAcc "native.gemv" gemv'
  where
    gemv' (alpha, matA, vecx) = do
      let
          Z :. rowsA :. colsA = arrayShape matA
          Z :. sizeX          = arrayShape vecx
          sizeA               = rowsA * colsA

          -- Length of the output vector: rows of A for the plain product,
          -- columns of A for any other (transposing) operation.
          sizeY   = case opA of
                      N -> rowsA
                      _ -> colsA

          opA'    = encodeTranspose opA
          alpha'  = indexArray alpha Z
      --
      vecy  <- allocateRemote (Z :. sizeY) :: LLVM Native (Vector e)
      ()    <- liftIO $ do
        withArray matA $ \ptr_A -> do
          withArray vecx $ \ptr_x -> do
            withArray vecy $ \ptr_y -> do
              case numericR :: NumericR e of
                -- Real element types: the array pointers are handed to BLAS
                -- directly. The leading dimension of A is its column count.
                NumericRfloat32   -> C.sgemv C.RowMajor opA' rowsA colsA alpha' ptr_A colsA ptr_x 1 0 ptr_y 1
                NumericRfloat64   -> C.dgemv C.RowMajor opA' rowsA colsA alpha' ptr_A colsA ptr_x 1 0 ptr_y 1

                -- Complex element types: the inputs are passed through
                -- 'interleave' before the call and the result is copied back
                -- with 'deinterleave' (presumably converting between the
                -- array representation and the interleaved layout BLAS
                -- expects -- see the Base module).
                --
                -- In these branches @e@ is already the complex element type
                -- (it is what 'C.cgemv' / 'C.zgemv' accept for @alpha'@), so
                -- the temporary output buffer needs @sizeOf (undefined :: e)@
                -- bytes per element. Sizing it by @Complex e@ would allocate
                -- twice as much as is ever used.
                NumericRcomplex32 -> do
                  allocaBytesAligned (sizeY * sizeOf (undefined :: e)) 16 $ \ptr_y' -> do
                    interleave ptr_A sizeA $ \ptr_A' -> do
                      interleave ptr_x sizeX $ \ptr_x' -> do
                        C.cgemv C.RowMajor opA' rowsA colsA alpha' ptr_A' colsA ptr_x' 1 0 ptr_y' 1
                        deinterleave ptr_y ptr_y' sizeY

                NumericRcomplex64 -> do
                  allocaBytesAligned (sizeY * sizeOf (undefined :: e)) 16 $ \ptr_y' -> do
                    interleave ptr_A sizeA $ \ptr_A' -> do
                      interleave ptr_x sizeX $ \ptr_x' -> do
                        C.zgemv C.RowMajor opA' rowsA colsA alpha' ptr_A' colsA ptr_x' 1 0 ptr_y' 1
                        deinterleave ptr_y ptr_y' sizeY
      --
      return vecy