-- TODO: a "minimum" routine might be available in Intel Performance
-- Primitives (IPP) or OpenCV.
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE CApiFFI #-}
module Numeric.CBLAS.FFI.Private (
   Routine.dotu, Routine.dotc, Routine.sum,
--   omatcopy, ccharFromChar,
   copyMatrix,
   transferMatrix,
   addMatrix,
   ) where

import qualified Numeric.CBLAS.FFI.Routine as Routine
import qualified Numeric.Netlib.Modifier as Modi
import qualified Numeric.Netlib.Class as Class

import Foreign.Ptr (Ptr)
import Foreign.C.Types

import Data.Complex (Complex((:+)))



-- | Haskell-side signature of the @?omatcopy@ family as used by
-- 'omatcopy'. The argument roles follow the parameter names used when
-- this module calls it.
type OMatCopy a =
   CChar ->            -- ordering flag
   CChar ->            -- operation code (transpose and conjugate selection)
   CSize ->            -- rows
   CSize ->            -- cols
   a ->                -- scalar factor alpha
   Ptr a -> CSize ->   -- source matrix and its leading dimension
   Ptr a -> CSize ->   -- destination matrix and its leading dimension
   IO ()

-- Single-precision real binding of MKL's out-of-place matrix copy.
-- Imported via CApiFFI, so GHC compiles the call in a C stub that
-- includes mkl_trans.h, i.e. against the header's own prototype.
foreign import capi "mkl_trans.h mkl_somatcopy"
   somatcopy :: OMatCopy Float

-- Double-precision real counterpart of 'somatcopy', also bound via CApiFFI.
foreign import capi "mkl_trans.h mkl_domatcopy"
   domatcopy :: OMatCopy Double

-- | Signature of the complex @?omatcopy@ bindings. Unlike 'OMatCopy',
-- the scalar factor arrives as two separate real components, real part
-- first and then imaginary part. This is presumably because the Haskell
-- FFI cannot pass a C complex value by value.
-- NOTE(review): confirm against the C wrappers.
type COMatCopy a =
   CChar -> CChar ->                     -- ordering flag, operation code
   CSize -> CSize ->                     -- rows, cols
   a -> a ->                             -- alpha: real part, imaginary part
   Ptr (Complex a) -> CSize ->           -- source and leading dimension
   Ptr (Complex a) -> CSize ->           -- destination and leading dimension
   IO ()

-- Complex single-precision binding. The "_hs" symbol suffix suggests a
-- C wrapper shipped with this package around MKL's complex routine.
-- NOTE(review): confirm where it is defined. It takes alpha as separate
-- real and imaginary components (see 'COMatCopy').
foreign import ccall "mkl_comatcopy_hs"
   comatcopy :: COMatCopy Float

-- Complex double-precision counterpart of 'comatcopy'.
foreign import ccall "mkl_zomatcopy_hs"
   zomatcopy :: COMatCopy Double

-- | Wraps 'OMatCopy' so the signature can be handed to
-- 'Class.switchFloating' as a one-parameter type constructor.
newtype OMATCOPY a =
   OMATCOPY {getOMATCOPY :: OMatCopy a}

-- | Element-type-polymorphic front end to the @?omatcopy@ bindings.
--
-- 'Class.switchFloating' selects the binding for 'Float', 'Double',
-- complex 'Float' and complex 'Double', in that order. The complex
-- bindings expect alpha split into real and imaginary parts, so the
-- scalar is unpacked before the call.
omatcopy :: (Class.Floating a) => OMatCopy a
omatcopy =
   getOMATCOPY
      (Class.switchFloating
         (OMATCOPY somatcopy)
         (OMATCOPY domatcopy)
         (OMATCOPY (splitAlpha comatcopy))
         (OMATCOPY (splitAlpha zomatcopy)))
   where
      -- Adapt a binding that takes alpha as a (re, im) pair to the
      -- Complex-valued interface; all other arguments pass through.
      splitAlpha copy order trans m n (alphaR:+alphaI) src ldSrc dst ldDst =
         copy order trans m n alphaR alphaI src ldSrc dst ldDst


-- | Copy a @rows@-by-@cols@ matrix from @a@ (leading dimension @lda@) to
-- @b@ (leading dimension @ldb@) with MKL's @?omatcopy@. The copy is
-- transposed on the way when the modifier asks for it.
--
-- The ordering flag is C (MKL's column-major ordering) and the scaling
-- factor is 1, so this is a plain, optionally transposed, copy.
copyMatrix ::
   (Class.Floating a) =>
   Modi.Transposition ->
   Int -> Int ->
   Ptr a -> Int ->
   Ptr a -> Int ->
   IO ()
copyMatrix transp rows cols a lda b ldb =
   omatcopy orderFlag opFlag
      (fromIntegral rows) (fromIntegral cols)
      1 a (fromIntegral lda) b (fromIntegral ldb)
   where
      orderFlag = ccharFromChar 'C'
      opFlag =
         ccharFromChar $
         case transp of
            Modi.NonTransposed -> 'N'
            Modi.Transposed -> 'T'

-- | Scaled out-of-place copy @b := alpha * op(a)@ using MKL's
-- @?omatcopy@ with ordering flag C (column-major).
--
-- @op@ is chosen from the transposition and conjugation modifiers,
-- following MKL's operation codes:
--
-- * N: @a@ unchanged
--
-- * T: @a@ transposed
--
-- * R: @a@ conjugated only
--
-- * C: @a@ conjugate-transposed
--
-- @rows@ and @cols@ describe the source matrix @a@. @lda@ and @ldb@ are
-- the leading dimensions of @a@ and @b@.
transferMatrix ::
   (Class.Floating a) =>
   Modi.Transposition ->
   Modi.Conjugation ->
   Int -> Int ->
   a ->
   Ptr a -> Int ->
   Ptr a -> Int ->
   IO ()
transferMatrix transp conj rows cols alpha a lda b ldb =
   omatcopy
      (ccharFromChar 'C')
      (ccharFromChar opCode)
      (fromIntegral rows)
      (fromIntegral cols)
      alpha a (fromIntegral lda) b (fromIntegral ldb)
   where
      -- In MKL, R means "conjugate without transposing" and C means
      -- "conjugate transpose". The two were previously swapped here.
      -- For real data MKL treats R like N and C like T, so the
      -- conjugation flag is now a no-op for real matrices, as it should be.
      opCode =
         case (transp, conj) of
            (Modi.NonTransposed, Modi.NonConjugated) -> 'N'
            (Modi.Transposed,    Modi.NonConjugated) -> 'T'
            (Modi.NonTransposed, Modi.Conjugated)    -> 'R'
            (Modi.Transposed,    Modi.Conjugated)    -> 'C'


-- | Haskell-side signature of the @?omatadd@ family as used by
-- 'omatadd'. The argument roles follow the parameter names used when
-- this module calls it.
type OMatAdd a =
   CChar ->                  -- ordering flag
   CChar -> CChar ->         -- operation codes for the first and second input
   CSize -> CSize ->         -- rows, cols
   a -> Ptr a -> CSize ->    -- alpha, first input, its leading dimension
   a -> Ptr a -> CSize ->    -- beta, second input, its leading dimension
   Ptr a -> CSize ->         -- output and its leading dimension
   IO ()

-- Single-precision real binding of MKL's out-of-place scaled matrix sum.
-- Imported via CApiFFI, so the call is compiled in a C stub that includes
-- mkl_trans.h and therefore uses the header's prototype.
foreign import capi "mkl_trans.h mkl_somatadd"
   somatadd :: OMatAdd Float

-- Double-precision real counterpart of 'somatadd', also bound via CApiFFI.
foreign import capi "mkl_trans.h mkl_domatadd"
   domatadd :: OMatAdd Double

-- | Signature of the complex @?omatadd@ bindings. Unlike 'OMatAdd', each
-- scalar factor is passed as two real components, real part first. This
-- is presumably because the Haskell FFI cannot pass a C complex value by
-- value. NOTE(review): confirm against the C wrappers.
type COMatAdd a =
   CChar -> CChar -> CChar -> CSize -> CSize ->   -- flags and shape
   a -> a -> Ptr (Complex a) -> CSize ->          -- alpha (re, im), first input
   a -> a -> Ptr (Complex a) -> CSize ->          -- beta (re, im), second input
   Ptr (Complex a) -> CSize -> IO ()              -- output

-- Complex single-precision binding. The "_hs" symbol suffix suggests a
-- C wrapper shipped with this package around MKL's complex routine.
-- NOTE(review): confirm where it is defined. It takes alpha and beta as
-- separate real and imaginary components (see 'COMatAdd').
foreign import ccall "mkl_comatadd_hs"
   comatadd :: COMatAdd Float

-- Complex double-precision counterpart of 'comatadd'.
foreign import ccall "mkl_zomatadd_hs"
   zomatadd :: COMatAdd Double

-- | Wraps 'OMatAdd' so the signature can be handed to
-- 'Class.switchFloating' as a one-parameter type constructor.
newtype OMATADD a =
   OMATADD {getOMATADD :: OMatAdd a}

-- | Element-type-polymorphic front end to the @?omatadd@ bindings.
--
-- 'Class.switchFloating' selects the binding for 'Float', 'Double',
-- complex 'Float' and complex 'Double', in that order. The complex
-- bindings expect each scalar as a (real, imaginary) pair, so alpha and
-- beta are unpacked before the call.
omatadd :: (Class.Floating a) => OMatAdd a
omatadd =
   getOMATADD
      (Class.switchFloating
         (OMATADD somatadd)
         (OMATADD domatadd)
         (OMATADD (unpackScalars comatadd))
         (OMATADD (unpackScalars zomatadd)))
   where
      -- Adapt a binding taking (re, im) scalars to the Complex-valued
      -- interface. The output pointer and its leading dimension are left
      -- off (eta-reduced) and flow straight through to the binding.
      unpackScalars add order transa transb m n
            (alphaR:+alphaI) pa lda (betaR:+betaI) pb ldb =
         add order transa transb m n alphaR alphaI pa lda betaR betaI pb ldb


-- | Update @b@ in place with @alpha * a + beta * b@ for
-- @rows@-by-@cols@ matrices, using MKL's @?omatadd@. The ordering flag
-- is C (column-major) and neither operand is transposed.
--
-- @b@ is passed both as the second input and as the output, so the
-- output leading dimension is @ldb@.
--
-- NOTE(review): this relies on @?omatadd@ tolerating an output that
-- aliases one of its inputs. Confirm against the MKL documentation.
addMatrix ::
   (Class.Floating a) =>
   Int -> Int ->
   a -> Ptr a -> Int ->
   a -> Ptr a -> Int ->
   IO ()
addMatrix rows cols alpha a lda beta b ldb =
   omatadd columnMajor noTrans noTrans
      (size rows) (size cols)
      alpha a (size lda)
      beta b (size ldb)
      b (size ldb)
   where
      columnMajor = ccharFromChar 'C'
      noTrans = ccharFromChar 'N'
      size k = fromIntegral k



-- | Convert a character to a C @char@ by its code point.
--
-- The conversion goes through 'toEnum' at 'CChar', so it is only valid
-- for code points representable in 'CChar'. The callers in this module
-- pass plain ASCII letters.
ccharFromChar :: Char -> CChar
ccharFromChar c = toEnum (fromEnum c)
