module Numeric.CBLAS.FFI.Private (
   dotReal,
   sum,
   copyMatrix,
--   transferMatrix,
   addMatrix,
   ) where

import Numeric.CBLAS.FFI.Common (pointerSeq)

import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Modifier as Modi
import qualified Numeric.Netlib.Class as Class
import qualified Numeric.Netlib.Utility as Call

import qualified Control.Monad.Trans.Cont as MC
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import Foreign.Marshal.Array (advancePtr)
import Foreign.Ptr (Ptr, castPtr)

import Data.Complex (Complex((:+)))

import Prelude hiding (sum)



-- | Wrapper around a summation routine so that 'Class.switchFloating'
-- can pick one implementation per element type.
-- The wrapped function receives an element count, a pointer to the
-- first element and the stride between consecutive elements.
newtype Sum a =
   Sum {
      runSum :: Int -> Ptr a -> Int -> IO a
   }

-- | Sum the elements of a strided vector.
--
-- Arguments are the element count, the pointer to the first element
-- and the stride.  The implementation is chosen by
-- 'Class.switchFloating': the two real cases use 'sumReal' and the two
-- complex cases use 'sumComplex'.
sum :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
sum n xPtr incx =
   runSum
      (Class.switchFloating
         (Sum sumReal)
         (Sum sumReal)
         (Sum sumComplex)
         (Sum sumComplex))
      n xPtr incx

-- | Sum @n@ real elements starting at @xPtr@ with stride @incx@.
--
-- The sum is obtained from the BLAS dot product: the second operand is
-- a single value 1 that is passed with stride 0, so every element of
-- the input vector is multiplied by that same 1.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
   MC.evalContT $ do
      countPtr <- Call.cint n
      stridePtr <- Call.cint incx
      -- second dot operand: one cell holding 1, never advanced
      onePtr <- Call.real 1
      zeroStridePtr <- Call.cint 0
      liftIO (BlasReal.dot countPtr xPtr stridePtr onePtr zeroStridePtr)

-- | Sum @n@ complex elements starting at @xPtr@ with stride @incx@.
--
-- The complex array is reinterpreted as an array of real numbers.
-- The real parts are summed starting at the cast pointer and the
-- imaginary parts starting one real element later; both passes use
-- twice the complex stride.  Each pass is a real dot product against a
-- single 1 passed with stride 0.
sumComplex ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   MC.evalContT $ do
      countPtr <- Call.cint n
      -- stride measured in real elements, hence doubled
      stridePtr <- Call.cint (2*incx)
      onePtr <- Call.real 1
      zeroStridePtr <- Call.cint 0
      let realPtr = castPtr xPtr
          imagPtr = advancePtr realPtr 1
          component ptr =
             BlasReal.dot countPtr ptr stridePtr onePtr zeroStridePtr
      liftIO $ do
         realSum <- component realPtr
         imagSum <- component imagPtr
         return (realSum :+ imagSum)


-- | Dot product of two strided real vectors via the BLAS @dot@ routine.
--
-- Arguments: element count, first vector pointer and its stride,
-- second vector pointer and its stride.
dotReal :: Class.Real a => Int -> Ptr a -> Int -> Ptr a -> Int -> IO a
dotReal n xPtr incx yPtr incy =
   MC.evalContT $ do
      countPtr <- Call.cint n
      xStridePtr <- Call.cint incx
      yStridePtr <- Call.cint incy
      liftIO (BlasReal.dot countPtr xPtr xStridePtr yPtr yStridePtr)



-- | Copy a matrix from @aPtr@ (leading dimension @lda@) to @bPtr@
-- (leading dimension @ldb@), optionally transposing it.
--
-- * 'Modi.NonTransposed': if both leading dimensions equal @rows@,
--   a single BLAS @copy@ of @rows*cols@ elements is issued.  Otherwise
--   @cols@ separate copies of @rows@ elements are performed, stepping
--   the source by @lda@ and the destination by @ldb@.
--
-- * 'Modi.Transposed': @rows@ copies of @cols@ elements each are
--   performed.  The source is read with stride @lda@ and its start is
--   stepped by 1; the destination is written with stride 1 and its
--   start is stepped by @ldb@.
--
-- NOTE(review): 'pointerSeq' is assumed to enumerate the pointers
-- obtained by repeatedly advancing by the given number of elements;
-- confirm against its definition in "Numeric.CBLAS.FFI.Common".
copyMatrix ::
   (Class.Floating a) =>
   Modi.Transposition ->
   Int -> Int ->
   Ptr a -> Int ->
   Ptr a -> Int ->
   IO ()
copyMatrix transp rows cols aPtr lda bPtr ldb =
   MC.evalContT $
   case transp of
      Modi.NonTransposed -> do
         unitPtr <- Call.cint 1
         let copyUnit lenPtr srcPtr dstPtr =
                Blas.copy lenPtr srcPtr unitPtr dstPtr unitPtr
         if rows == lda && rows == ldb
            then do
               -- no gaps between columns: one bulk copy suffices
               totalPtr <- Call.cint (rows*cols)
               liftIO $ copyUnit totalPtr aPtr bPtr
            else do
               rowsPtr <- Call.cint rows
               liftIO $
                  mapM_ (uncurry (copyUnit rowsPtr)) $
                  take cols $
                  zip (pointerSeq lda aPtr) (pointerSeq ldb bPtr)
      Modi.Transposed -> do
         colsPtr <- Call.cint cols
         srcStridePtr <- Call.cint lda
         dstStridePtr <- Call.cint 1
         liftIO $
            mapM_
               (\(srcPtr, dstPtr) ->
                  Blas.copy colsPtr srcPtr srcStridePtr dstPtr dstStridePtr) $
            take rows $
            zip (pointerSeq 1 aPtr) (pointerSeq ldb bPtr)

{-
transferMatrix ::
   (Class.Floating a) =>
   Modi.Transposition ->
   Modi.Conjugation ->
   Int -> Int ->
   a ->
   Ptr a -> Int ->
   Ptr a -> Int ->
   IO ()
transferMatrix transp conj rows cols alpha a lda b ldb =
-}


-- | Update the matrix at @bPtr@ in place: it is first scaled by @beta@
-- (BLAS @scal@) and then @alpha@ times the matrix at @aPtr@ is added to
-- it (BLAS @axpy@).
--
-- @lda@ and @ldb@ are the leading dimensions of the two matrices.
-- If both equal @rows@, the whole matrix is processed in one pass over
-- @rows*cols@ elements.  Otherwise @cols@ passes over @rows@ elements
-- are made, stepping the source by @lda@ and the destination by @ldb@.
--
-- NOTE(review): 'pointerSeq' is assumed to enumerate the pointers
-- obtained by repeatedly advancing by the given number of elements;
-- confirm against its definition in "Numeric.CBLAS.FFI.Common".
addMatrix ::
   (Class.Floating a) =>
   Int -> Int ->
   a -> Ptr a -> Int ->
   a -> Ptr a -> Int ->
   IO ()
addMatrix rows cols alpha aPtr lda beta bPtr ldb =
   MC.evalContT $ do
      unitPtr <- Call.cint 1
      alphaPtr <- Call.number alpha
      betaPtr <- Call.number beta
      -- scale the destination first, then accumulate the scaled source
      let scaleAndAdd lenPtr srcPtr dstPtr = do
             Blas.scal lenPtr betaPtr dstPtr unitPtr
             Blas.axpy lenPtr alphaPtr srcPtr unitPtr dstPtr unitPtr
      if rows == lda && rows == ldb
         then do
            totalPtr <- Call.cint (rows*cols)
            liftIO $ scaleAndAdd totalPtr aPtr bPtr
         else do
            rowsPtr <- Call.cint rows
            liftIO $
               mapM_ (uncurry (scaleAndAdd rowsPtr)) $
               take cols $
               zip (pointerSeq lda aPtr) (pointerSeq ldb bPtr)
