module Numeric.CBLAS.FFI.Extra (sum, copyMatrix, addMatrix) where

import qualified Numeric.CBLAS.FFI.Private as Priv
import Numeric.Netlib.Modifier (Transposition(..))

import qualified Numeric.Netlib.Class as Class

import Foreign.Ptr (Ptr)

import Control.Monad (when)

{-
import Foreign.Ptr (minusPtr, nullPtr)
import Text.Printf (printf)
-}

import Prelude hiding (sum)


{- |
Thin wrapper that forwards all three arguments unchanged to @Priv.sum@
and returns its result.

NOTE(review): the arguments are presumably the element count,
a pointer to the first element and the stride between elements -
confirm against @Priv.sum@.
-}
sum :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
sum n xPtr incx = Priv.sum n xPtr incx


{-
Do not simply re-export Priv.copyMatrix.
Instead ensure consistent interface across all implementations.
-}
{- |
Guarded wrapper around @Priv.copyMatrix@.

The underlying routine is only invoked when both @rows@ and @cols@
are strictly positive; otherwise the action does nothing.
All arguments are passed through unchanged.

NOTE(review): @lda@ and @ldb@ are presumably the leading dimensions
belonging to the buffers @a@ and @b@, respectively -
confirm against @Priv.copyMatrix@.
-}
copyMatrix ::
   (Class.Floating a) =>
   Transposition ->
   Int -> Int ->
   Ptr a -> Int ->
   Ptr a -> Int ->
   IO ()
copyMatrix transp rows cols a lda b ldb = do
{-
   printf "trans %s, rows %d, cols %d, a %x, lda %d, b %x, ldb %d\n"
      (show transp) rows cols (minusPtr a nullPtr) lda (minusPtr b nullPtr) ldb
-}
   -- The smaller of the two extents is positive exactly if both are positive.
   when (min rows cols > 0) $
      Priv.copyMatrix transp rows cols a lda b ldb
{-
   putStrLn "trans end"
-}


{- |
Thin wrapper that forwards all arguments unchanged to @Priv.addMatrix@.

Unlike 'copyMatrix', no check of the matrix extents is performed here.

NOTE(review): the arguments presumably are the two matrix extents,
followed by a scaling factor, buffer and leading dimension
for each of the two operands - confirm against @Priv.addMatrix@.
-}
addMatrix ::
   (Class.Floating a) =>
   Int -> Int ->
   a -> Ptr a -> Int ->
   a -> Ptr a -> Int ->
   IO ()
addMatrix m n alpha a lda beta b ldb =
   Priv.addMatrix m n alpha a lda beta b ldb


{-
ToDo:
batched matrix multiplication

MKL provides it as batched, strided multiplication

https://www.intel.com/content/www/us/en/developer/articles/technical/introducing-batch-gemm-operations.html

flame-blis:
void BLIS_EXPORT_BLAS cblas_sgemm_batch(enum CBLAS_ORDER Order,
                 enum CBLAS_TRANSPOSE *TransA_array,
                 enum CBLAS_TRANSPOSE *TransB_array,
                 f77_int *M_array, f77_int *N_array,
                 f77_int *K_array, const float *alpha_array, const float **A,
                 f77_int *lda_array, const float **B, f77_int *ldb_array,
                 const float *beta_array, float **C, f77_int *ldc_array,
                 f77_int group_count, f77_int *group_size);

Missing in OpenBLAS:
https://github.com/OpenMathLib/OpenBLAS/discussions/4707
-}
