{-# LANGUAGE ForeignFunctionInterface #-}
module Numeric.CBLAS.FFI.Private (
   Routine.dotu, Routine.dotc, Routine.sum,
   Order, rowMajor, columnMajor,
   Transpose, noTrans, trans, conjTrans, conjNoTrans,
   omatcopy,
   copyMatrix,
   transferMatrix,
   addMatrix,
   ) where

import qualified Numeric.CBLAS.FFI.Routine as Routine
import Numeric.CBLAS.FFI.Type

import qualified Numeric.Netlib.Modifier as Modi
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal (with)
-- import Foreign.Storable (peek)
import Foreign.Ptr (Ptr)
-- import Foreign.C.Types

import Data.Complex (Complex)



-- | Haskell-side type of the @cblas_?omatcopy@ routines bound below.
--
-- Arguments, in C order: storage order, transposition mode, row count,
-- column count, scaling factor (of type @c@), first matrix pointer and its
-- leading dimension, second matrix pointer and its leading dimension.
--
-- The scalar type @c@ is kept separate from the element type @a@ because
-- the real variants take the scaling factor by value (@c = a@) while the
-- complex variants take it through a pointer (@c = Ptr a@, see 'COMatCopy').
--
-- NOTE(review): by the usual naming of this (presumably OpenBLAS) extension
-- the first matrix is the source and the second the destination, i.e.
-- @B := alpha * op(A)@ — confirm against the BLAS implementation's docs.
type OMatCopy c a =
      Order -> Transpose ->
      CBlasInt -> CBlasInt ->
      c ->
      Ptr a -> CBlasInt ->
      Ptr a -> CBlasInt ->
      IO ()

-- | 'OMatCopy' for complex element types: the scaling factor is passed as
-- a pointer to a value of the element type.
type COMatCopy a = OMatCopy (Ptr a) a

-- Raw bindings (default, i.e. safe, @ccall@ imports) for single, double,
-- single-complex and double-complex elements.  @ccall@ does not check the
-- C prototype, so these Haskell signatures must match the C declarations
-- argument-for-argument by hand.
foreign import ccall "cblas_somatcopy" somatcopy :: OMatCopy Float Float
foreign import ccall "cblas_domatcopy" domatcopy :: OMatCopy Double Double
foreign import ccall "cblas_comatcopy" comatcopy :: COMatCopy (Complex Float)
foreign import ccall "cblas_zomatcopy" zomatcopy :: COMatCopy (Complex Double)

-- | Wrapper that lets 'Class.switchFloating' pick an 'OMatCopy' per
-- element type.
newtype OMATCOPY a = OMATCOPY {getOMATCOPY :: OMatCopy a a}

-- | Element-type-generic @cblas_?omatcopy@.
--
-- Every element type is exposed with the same signature @'OMatCopy' a a@,
-- i.e. the scaling factor is always taken by value.  For the two complex
-- variants, whose C routines expect the scaling factor by reference, the
-- value is marshalled into temporary storage for the duration of the call.
omatcopy :: (Class.Floating a) => OMatCopy a a
omatcopy =
   getOMATCOPY $
   Class.switchFloating
      (OMATCOPY somatcopy)
      (OMATCOPY domatcopy)
      (scalarByRef comatcopy)
      (scalarByRef zomatcopy)
   where
      -- Turn a binding that wants the scaling factor behind a pointer into
      -- one that takes it by value.  The pointer is only valid inside the
      -- 'with' scope, which spans exactly the foreign call.
      scalarByRef cRoutine =
         OMATCOPY $ \order transp rows cols alpha a lda b ldb ->
            with alpha $ \alphaPtr ->
               cRoutine order transp rows cols alphaPtr a lda b ldb


-- | Unscaled copy of a column-major matrix, optionally transposed.
--
-- This is 'transferMatrix' without conjugation and with a scaling factor
-- of one: 'Modi.NonTransposed' selects 'noTrans' and 'Modi.Transposed'
-- selects 'trans'.  @rows@/@cols@ are the dimensions handed to
-- @cblas_?omatcopy@, and @lda@/@ldb@ are the leading dimensions of the
-- first and second buffer respectively.
copyMatrix ::
   (Class.Floating a) =>
   Modi.Transposition ->
   Int -> Int ->
   Ptr a -> Int ->
   Ptr a -> Int ->
   IO ()
copyMatrix transp rows cols =
   transferMatrix transp Modi.NonConjugated rows cols 1

-- | Scaled copy of a column-major matrix with optional transposition
-- and/or conjugation, via the element-generic 'omatcopy'.
--
-- The ('Modi.Transposition', 'Modi.Conjugation') pair selects one of the
-- four CBLAS transpose codes.  @alpha@ is the scaling factor, and
-- @rows@/@cols@/@lda@/@ldb@ are forwarded after conversion to 'CBlasInt'.
--
-- NOTE(review): 'fromIntegral' silently wraps if an 'Int' argument does not
-- fit in 'CBlasInt' (e.g. a 32-bit BLAS interface on a 64-bit host).
transferMatrix ::
   (Class.Floating a) =>
   Modi.Transposition ->
   Modi.Conjugation ->
   Int -> Int ->
   a ->
   Ptr a -> Int ->
   Ptr a -> Int ->
   IO ()
transferMatrix transp conj rows cols alpha a lda b ldb =
   omatcopy columnMajor mode
      (toCBlasInt rows) (toCBlasInt cols)
      alpha a (toCBlasInt lda) b (toCBlasInt ldb)
   where
      toCBlasInt :: Int -> CBlasInt
      toCBlasInt = fromIntegral

      -- Transposition is inspected first, then conjugation.
      mode :: Transpose
      mode =
         case transp of
            Modi.NonTransposed ->
               case conj of
                  Modi.NonConjugated -> noTrans
                  Modi.Conjugated -> conjNoTrans
            Modi.Transposed ->
               case conj of
                  Modi.NonConjugated -> trans
                  Modi.Conjugated -> conjTrans



-- | Haskell-side type of the @cblas_?geadd@ routines bound below.
--
-- Arguments, in C order: storage order, row count, column count, then a
-- scalar (type @c@) with the first matrix pointer and its leading dimension,
-- then a second scalar (type @c@) with the second matrix pointer and its
-- leading dimension.  @c@ is the element type itself for the real variants
-- and a pointer to the element type for the complex variants.
--
-- NOTE(review): presumably computes @B := alpha*A + beta*B@ (the usual
-- semantics of this BLAS extension) — confirm against the implementation's
-- documentation.
type GeAdd c a =
      Order ->
      CBlasInt -> CBlasInt ->
      c -> Ptr a -> CBlasInt ->
      c -> Ptr a -> CBlasInt ->
      IO ()

-- Raw bindings (default, i.e. safe, @ccall@ imports).  The real variants
-- take both scalars by value; the complex variants take them by pointer.
-- @ccall@ does not check the C prototype, so these signatures must match
-- the C declarations argument-for-argument by hand.
foreign import ccall "cblas_sgeadd" sgeadd :: GeAdd Float Float
foreign import ccall "cblas_dgeadd" dgeadd :: GeAdd Double Double

foreign import ccall "cblas_cgeadd"
   cgeadd :: GeAdd (Ptr (Complex Float)) (Complex Float)

foreign import ccall "cblas_zgeadd"
   zgeadd :: GeAdd (Ptr (Complex Double)) (Complex Double)


-- | Wrapper that lets 'Class.switchFloating' pick a 'GeAdd' per element
-- type.
newtype GEADD a = GEADD {getGEADD :: GeAdd a a}

-- | Element-type-generic @cblas_?geadd@ with both scalars taken by value.
--
-- For the complex variants, @alpha@ and @beta@ are marshalled into
-- temporary storage (alpha first, then beta) because the C routines expect
-- them by reference.
geadd :: (Class.Floating a) => GeAdd a a
geadd =
   getGEADD $
   Class.switchFloating
      (GEADD sgeadd)
      (GEADD dgeadd)
      (scalarsByRef cgeadd)
      (scalarsByRef zgeadd)
   where
      -- Adapt a binding that wants both scalars behind pointers; the
      -- pointers stay valid only for the duration of the foreign call.
      scalarsByRef cRoutine =
         GEADD $ \order rows cols alpha a lda beta b ldb ->
            with alpha $ \alphaPtr ->
            with beta $ \betaPtr ->
               cRoutine order rows cols alphaPtr a lda betaPtr b ldb

-- | Column-major matrix addition through the element-generic 'geadd'.
--
-- @alpha@/@a@/@lda@ and @beta@/@b@/@ldb@ are forwarded unchanged, except
-- that all sizes are converted to 'CBlasInt'.
--
-- NOTE(review): 'fromIntegral' silently wraps if an 'Int' argument does not
-- fit in 'CBlasInt'.
addMatrix ::
   (Class.Floating a) =>
   Int -> Int ->
   a -> Ptr a -> Int ->
   a -> Ptr a -> Int ->
   IO ()
addMatrix rows cols alpha a lda beta b ldb =
   let dim :: Int -> CBlasInt
       dim = fromIntegral
   in  geadd columnMajor (dim rows) (dim cols)
          alpha a (dim lda)
          beta b (dim ldb)
