{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level2
-- Copyright   : [2017] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level2
  where

import Data.Array.Accelerate                                        as A
import Data.Array.Accelerate.Array.Sugar                            ( Array(..) )
import Data.Array.Accelerate.Data.Complex
import Data.Array.Accelerate.LLVM.PTX.Foreign
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Base
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Context
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Level3
import Data.Array.Accelerate.Numeric.LinearAlgebra.LLVM.PTX.Twine
import Data.Array.Accelerate.Numeric.LinearAlgebra.Type

import Foreign.Marshal                                              ( with )
import Foreign.Storable.Complex                                     ( )

import qualified Foreign.CUDA.Ptr                                   as CUDA
import qualified Foreign.CUDA.BLAS                                  as BLAS


-- NOTE: cuBLAS requires matrices to be stored in column-major order
-- (Fortran-style), but Accelerate uses C-style arrays in row-major order.
--
-- If the operation is N or T, we can just swap the operation. For
-- conjugate-transpose (H) operations (on complex valued arguments), since there
-- is no conjugate-no-transpose operation, we implement that via 'gemm', which
-- I assume is more efficient than ?geam followed by ?gemv.
--
-- | General dense matrix-vector multiply, y = alpha * op(A) * x, exposed as
-- a foreign-function call into cuBLAS for the PTX backend.
--
gemv :: Numeric e
     => Transpose
     -> ForeignAcc ((Scalar e, Matrix e, Vector e) -> Vector e)
gemv opA = ForeignAcc "ptx.gemv" $ gemv' numericR opA

-- | Dispatch on the element-type representation. cuBLAS has no
-- conjugate-no-transpose ?gemv, so the conjugate-transpose (H) operation on
-- complex-valued arguments is expressed via 'gemm' instead ('as_gemm');
-- every other combination goes straight to ?gemv ('as_gemv').
--
gemv' :: Numeric e
      => NumericR e
      -> Transpose
      -> Stream
      -> (Scalar e, Matrix e, Vector e)
      -> LLVM PTX (Vector e)
gemv' eR opA =
  case (eR, opA) of
    (NumericRcomplex32, H) -> as_gemm H
    (NumericRcomplex64, H) -> as_gemm H
    _                      -> as_gemv opA


-- | Implement matrix-vector multiply in terms of 'gemm'', by viewing the
-- input vector as an (n x 1) matrix and the (m x 1) result as a vector.
-- Used for the conjugate-transpose cases, which cuBLAS ?gemv cannot express
-- directly.
--
as_gemm
    :: Numeric e
    => Transpose
    -> Stream
    -> (Scalar e, Matrix e, Vector e)
    -> LLVM PTX (Vector e)
as_gemm trans stream (alpha, matA, Array shX datX) = do
  -- Reinterpret x as a single-column matrix; the underlying array data is
  -- shared, only the shape component changes.
  let matX = Array (shX, 1) datX
  -- The gemm result is a single-column matrix; drop the unit column extent
  -- to recover a vector, again without copying the payload.
  Array (shY, 1) datY <- gemm' trans N stream (alpha, matA, matX)
  return (Array shY datY)

-- | Implement matrix-vector multiply directly using the cuBLAS ?gemv family.
--
-- cuBLAS requires column-major (Fortran-style) storage while Accelerate
-- arrays are row-major, so the matrix dimensions are passed swapped and the
-- transposition flag is flipped (see the note at the top of this file).
-- Conjugate-transpose (H) on complex arguments is routed through 'as_gemm'
-- by 'gemv'' and never reaches this function.
--
as_gemv
    :: forall e. Numeric e
    => Transpose
    -> Stream
    -> (Scalar e, Matrix e, Vector e)
    -> LLVM PTX (Vector e)
as_gemv opA stream (alpha, matA, vecx) = do
  let
      Z :. rowsA :. colsA = arrayShape matA
      Z :. sizeX          = arrayShape vecx

      -- Total number of elements of A; used to size the interleave step in
      -- the complex-valued cases below.
      sizeA   = rowsA * colsA
      -- Length of the result vector y = alpha * op(A) * x
      sizeY   = case opA of
                  N -> rowsA
                  _ -> colsA

      -- Flip N <-> T to compensate for the row-major vs. column-major
      -- storage mismatch (H never occurs here; see 'gemv'').
      opA'    = encodeTranspose
              $ case opA of
                  N -> T
                  _ -> N
  --
  vecy    <- allocateRemote (Z :. sizeY) :: LLVM PTX (Vector e)
  alpha'  <- indexRemote alpha 0
  -- The explicit () bind pins every branch of the case below to the same
  -- (unit) result type.
  ()      <- do
    withArray matA stream   $ \ptr_A -> do
     withArray vecx stream  $ \ptr_x -> do
      withArray vecy stream $ \ptr_y -> do
       withBLAS             $ \hdl   -> do
         case numericR :: NumericR e of
           -- Real-valued cases call ?gemv directly. Passing the dimensions
           -- as (colsA, rowsA) with lda = colsA presents the row-major
           -- (rowsA x colsA) matrix to cuBLAS as its column-major transpose,
           -- which the flipped opA' then corrects for. beta = 0, so the
           -- uninitialised contents of y are never read.
           NumericRfloat32 -> liftIO $
            with alpha' $ \ptr_alpha ->
             with 0     $ \ptr_beta  ->
               BLAS.sgemv hdl opA' colsA rowsA ptr_alpha ptr_A colsA ptr_x 1 ptr_beta ptr_y 1

           NumericRfloat64 -> liftIO $
            with alpha' $ \ptr_alpha ->
             with 0     $ \ptr_beta  ->
               BLAS.dgemv hdl opA' colsA rowsA ptr_alpha ptr_A colsA ptr_x 1 ptr_beta ptr_y 1

           -- Complex-valued cases: cuBLAS expects interleaved (re,im) pairs,
           -- so A and x are first packed into interleaved temporaries (see
           -- the accompanying Twine module), the result is written into a
           -- scratch buffer of 2*sizeY floats — viewed as sizeY complex
           -- values via castDevPtr — and finally unpacked back into the
           -- storage of the result array vecy (via ptr_y).
           NumericRcomplex32 -> do
            tmpy <- allocateRemote (Z :. sizeY * 2) :: LLVM PTX (Vector Float)
            withArray tmpy stream           $ \ptr_y' -> do
             interleave ptr_A stream sizeA  $ \ptr_A' -> do
              interleave ptr_x stream sizeX $ \ptr_x' -> do
               liftIO $ do
                with alpha' $ \ptr_alpha ->
                 with 0     $ \ptr_beta  -> do
                  BLAS.cgemv hdl opA' colsA rowsA ptr_alpha ptr_A' colsA ptr_x' 1 ptr_beta (CUDA.castDevPtr ptr_y' :: CUDA.DevicePtr (Complex Float))  1
               deinterleave ptr_y (CUDA.castDevPtr ptr_y' :: CUDA.DevicePtr (Complex Float)) stream sizeY

           -- As above, at double precision.
           NumericRcomplex64 -> do
            tmpy <- allocateRemote (Z :. sizeY * 2) :: LLVM PTX (Vector Double)
            withArray tmpy stream           $ \ptr_y' -> do
             interleave ptr_A stream sizeA  $ \ptr_A' -> do
              interleave ptr_x stream sizeX $ \ptr_x' -> do
               liftIO $ do
                with alpha' $ \ptr_alpha ->
                 with 0     $ \ptr_beta  -> do
                  BLAS.zgemv hdl opA' colsA rowsA ptr_alpha ptr_A' colsA ptr_x' 1 ptr_beta (CUDA.castDevPtr ptr_y' :: CUDA.DevicePtr (Complex Double))  1
               deinterleave ptr_y (CUDA.castDevPtr ptr_y' :: CUDA.DevicePtr (Complex Double)) stream sizeY
  --
  return vecy