{-# LANGUAGE ForeignFunctionInterface #-}
module Numeric.CBLAS.FFI.Routine where

import Numeric.CBLAS.FFI.Type (CBlasInt)
import Numeric.CBLAS.FFI.Common (pointerSeq)

import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal (with, alloca)
import Foreign.Storable (Storable, peek)
import Foreign.Ptr (Ptr)

import Data.Complex (Complex)
import Data.Int

import Prelude hiding (sum)



-- | Run an action that delivers its result through an output pointer, then
-- read that result back.
--
-- Storage for exactly one value of type @a@ is obtained with 'alloca'.
-- Its address is handed to the action, and once the action has finished
-- the value is fetched with 'peek'. The storage is only valid inside this
-- call, so the action must not retain the pointer.
withResult :: (Storable a) => (Ptr a -> IO ()) -> IO a
withResult act =
   alloca $ \resultPtr -> do
      act resultPtr
      peek resultPtr


-- | Signature shared by the real CBLAS dot routines and by the Haskell-level
-- 'dotu' / 'dotc': element count, pointer and stride of @x@, pointer and
-- stride of @y@. The result is returned directly.
type Dot a =
   CBlasInt              -- n
   -> Ptr a -> CBlasInt  -- x, incx
   -> Ptr a -> CBlasInt  -- y, incy
   -> IO a

-- | Like 'Dot', except that the result is written through an extra trailing
-- output pointer instead of being returned. This is the shape of the complex
-- @cblas_?dot?_sub@ routines.
type DotSub a =
   CBlasInt              -- n
   -> Ptr a -> CBlasInt  -- x, incx
   -> Ptr a -> CBlasInt  -- y, incy
   -> Ptr a              -- where the result is stored
   -> IO ()


-- Real single/double precision dot products: the C routine returns the
-- scalar result directly, matching 'Dot'.
foreign import ccall "cblas_sdot" sdot :: Dot Float
foreign import ccall "cblas_ddot" ddot :: Dot Double

-- Complex single/double precision dot products, bound through the @_sub@
-- entry points. These deliver the result via the trailing output pointer of
-- 'DotSub' rather than as a return value; 'withResult' is used to read it
-- back. NOTE(review): presumably chosen because the Haskell FFI cannot
-- receive a C complex value returned by value.
-- The @u@ / @c@ suffixes follow the BLAS naming for the unconjugated and
-- conjugated variants.
foreign import ccall "cblas_cdotu_sub" cdotu :: DotSub (Complex Float)
foreign import ccall "cblas_cdotc_sub" cdotc :: DotSub (Complex Float)
foreign import ccall "cblas_zdotu_sub" zdotu :: DotSub (Complex Double)
foreign import ccall "cblas_zdotc_sub" zdotc :: DotSub (Complex Double)


-- | Wraps a 'Dot' routine so that one implementation per element type can be
-- passed to 'Class.switchFloating' (as 'dotu' and 'dotc' do). The selected
-- implementation is then unwrapped with 'getDOT'.
-- NOTE(review): presumably needed because the type synonym 'Dot' cannot be
-- partially applied as the type-constructor argument that
-- 'Class.switchFloating' expects.
newtype DOT a = DOT {getDOT :: Dot a}

-- | Unconjugated dot product of two strided vectors, dispatched on the
-- element type via 'Class.switchFloating'.
--
-- @Float@ and @Double@ call @cblas_sdot@ / @cblas_ddot@ directly. The complex
-- types call @cblas_cdotu_sub@ / @cblas_zdotu_sub@, whose result is written
-- into a temporary that 'withResult' allocates and reads back.
dotu :: (Class.Floating a) => Dot a
dotu =
   getDOT $
   Class.switchFloating
      (DOT sdot)
      (DOT ddot)
      (DOT (viaSub cdotu))
      (DOT (viaSub zdotu))
   where
      -- Adapt a routine that stores its result through a trailing pointer
      -- into one that returns the result.
      viaSub sub n xPtr incx yPtr incy =
         withResult (sub n xPtr incx yPtr incy)

-- | Conjugated dot product of two strided vectors, dispatched on the element
-- type via 'Class.switchFloating'.
--
-- For the real types this is the same call as 'dotu' (@cblas_sdot@ /
-- @cblas_ddot@). The complex types use @cblas_cdotc_sub@ / @cblas_zdotc_sub@,
-- whose result is read back through 'withResult'.
dotc :: (Class.Floating a) => Dot a
dotc =
   getDOT $
   Class.switchFloating (DOT sdot) (DOT ddot) (subDot cdotc) (subDot zdotc)
   where
      -- Wrap a result-through-pointer routine as a 'DOT' that returns it.
      subDot f = DOT $ \n x incx y incy -> withResult (f n x incx y incy)


-- | Sum of @n@ elements of the vector at @xPtr@, taken with stride @incx@.
--
-- This is computed as the unconjugated dot product ('dotu') of @x@ with a
-- one-element buffer holding @1@. That buffer is read with stride 0, so every
-- element of @x@ is multiplied by the same stored @1@.
-- NOTE(review): this relies on the underlying BLAS accepting a @y@ stride
-- of 0.
sum :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
sum n xPtr incx =
   with 1 onesDot
   where
      onesDot onePtr =
         dotu (fromIntegral n) xPtr (fromIntegral incx) onePtr 0




-- | Shape of the @?scal@ routines: element count, scalar @alpha@ (of type
-- @c@), then pointer and stride of the vector @x@.
type Scal c a =
   CBlasInt              -- n
   -> c                  -- alpha
   -> Ptr a -> CBlasInt  -- x, incx
   -> IO ()

-- | Shape of the @?axpy@ routines: element count, scalar @alpha@ (of type
-- @c@), then pointer and stride of @x@, then pointer and stride of @y@.
type Axpy c a =
   CBlasInt              -- n
   -> c                  -- alpha
   -> Ptr a -> CBlasInt  -- x, incx
   -> Ptr a -> CBlasInt  -- y, incy
   -> IO ()

-- | Variants whose scalar is passed by pointer, as used by the complex
-- routines.
type CScal a = Scal (Ptr a) a
type CAxpy a = Axpy (Ptr a) a

-- Vector scaling (?scal). The real variants take the scalar by value
-- ('Scal'); the complex variants take a pointer to it ('CScal').
foreign import ccall "cblas_sscal" sscal :: Scal Float Float
foreign import ccall "cblas_dscal" dscal :: Scal Double Double
foreign import ccall "cblas_cscal" cscal :: CScal (Complex Float)
foreign import ccall "cblas_zscal" zscal :: CScal (Complex Double)

-- Scaled vector addition (?axpy). The real variants take the scalar by
-- value ('Axpy'); the complex variants take a pointer to it ('CAxpy').
foreign import ccall "cblas_saxpy" saxpy :: Axpy Float Float
foreign import ccall "cblas_daxpy" daxpy :: Axpy Double Double
foreign import ccall "cblas_caxpy" caxpy :: CAxpy (Complex Float)
foreign import ccall "cblas_zaxpy" zaxpy :: CAxpy (Complex Double)

-- | Wraps a 'Scal' routine (scalar by value) so that one implementation per
-- element type can be chosen with 'Class.switchFloating' in 'scal'.
-- The chosen implementation is then unwrapped with 'getSCAL'.
newtype SCAL a = SCAL {getSCAL :: Scal a a}

-- | Scale a strided vector by a scalar, dispatched on the element type.
--
-- The real types call @cblas_sscal@ / @cblas_dscal@ with the scalar by value.
-- The complex types copy the scalar into temporary storage with 'with' and
-- pass its address to @cblas_cscal@ / @cblas_zscal@.
scal :: (Class.Floating a) => Scal a a
scal =
   getSCAL $
   Class.switchFloating
      (SCAL sscal)
      (SCAL dscal)
      (SCAL (byRef cscal))
      (SCAL (byRef zscal))
   where
      -- Adapt a routine taking the scalar by pointer to take it by value.
      byRef f n alpha x incx =
         with alpha (\alphaPtr -> f n alphaPtr x incx)

-- | Wraps an 'Axpy' routine (scalar by value) so that one implementation per
-- element type can be chosen with 'Class.switchFloating' in 'axpy'.
-- The chosen implementation is then unwrapped with 'getAXPY'.
newtype AXPY a = AXPY {getAXPY :: Axpy a a}

-- | Scaled vector addition (BLAS @?axpy@), dispatched on the element type.
--
-- The real types call @cblas_saxpy@ / @cblas_daxpy@ with the scalar by value.
-- The complex types store the scalar temporarily via 'with' and pass its
-- address to @cblas_caxpy@ / @cblas_zaxpy@.
axpy :: (Class.Floating a) => Axpy a a
axpy =
   getAXPY $
   Class.switchFloating
      (AXPY saxpy)
      (AXPY daxpy)
      (scalarByRef caxpy)
      (scalarByRef zaxpy)
   where
      -- Build an 'AXPY' from a routine that expects the scalar by pointer.
      scalarByRef f =
         AXPY $ \n alpha x incx y incy ->
            with alpha (\alphaPtr -> f n alphaPtr x incx y incy)


-- | In-place update @B := alpha*A + beta*B@ on two @rows@-by-@cols@ matrices
-- with leading dimensions @lda@ and @ldb@.
--
-- Each matrix is processed as @cols@ runs of @rows@ elements with stride 1.
-- Consecutive runs start @lda@ (resp. @ldb@) elements apart.
-- NOTE(review): this assumes 'pointerSeq' steps by that many elements;
-- confirm.
-- Every run of B is first scaled by @beta@ ('scal'), and then @alpha@ times
-- the matching run of A is added to it ('axpy').
-- When both leading dimensions equal @rows@ the runs are adjacent, so each
-- matrix is handled as a single vector of @rows*cols@ elements.
-- NOTE(review): when @beta == 0@ this still goes through 'scal', so NaN/Inf
-- values already in B may survive depending on the BLAS implementation.
addMatrix ::
   (Class.Floating a) =>
   Int -> Int ->
   a -> Ptr a -> Int ->
   a -> Ptr a -> Int ->
   IO ()
addMatrix rows cols alpha aPtr lda beta bPtr ldb
   | rows == lda && rows == ldb =
        update (fromIntegral (rows*cols)) aPtr bPtr
   | otherwise =
        mapM_ (uncurry (update (fromIntegral rows))) $
        take cols $
        zip (pointerSeq lda aPtr) (pointerSeq ldb bPtr)
   where
      inc = 1
      -- dst := alpha*src + beta*dst over n contiguous elements
      update n src dst = do
         scal n beta dst inc
         axpy n alpha src inc dst inc
