{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Private where

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (advancePtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek)

import Control.Monad.Trans.Cont (evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Functor.Identity (Identity(Identity, runIdentity))
import Data.Complex (Complex)

import Prelude hiding (sum)


-- | Maps each supported element type to its real counterpart:
-- real types map to themselves, @Complex r@ maps to @r@.
type family RealOf x

type instance RealOf Float = Float
type instance RealOf Double = Double
type instance RealOf (Complex Float) = Float
type instance RealOf (Complex Double) = Double


-- | The constants 0, 1 and -1 for any of the four supported element types,
-- selected via 'Class.switchFloating'.
zero, one, minusOne :: Class.Floating a => a
zero =
   runIdentity $
   Class.switchFloating (Identity 0) (Identity 0) (Identity 0) (Identity 0)
one =
   runIdentity $
   Class.switchFloating (Identity 1) (Identity 1) (Identity 1) (Identity 1)
minusOne =
   runIdentity $
   Class.switchFloating
      (Identity (-1)) (Identity (-1)) (Identity (-1)) (Identity (-1))

-- | The constant 1 for the real element types, selected via
-- 'Class.switchReal'.
oneReal :: Class.Real a => a
oneReal = runIdentity $ Class.switchReal (Identity 1) (Identity 1)


-- | @fill a n dstPtr@ writes @n@ copies of @a@ to consecutive elements
-- starting at @dstPtr@.
--
-- Implemented as a BLAS @copy@ whose source is the single value @a@
-- read with increment 0, so the same element is broadcast to every target.
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr = evalContT $ do
   nPtr <- Call.cint n
   srcPtr <- Call.number a
   incxPtr <- Call.cint 0
   incyPtr <- Call.cint 1
   liftIO $ BlasGen.copy nPtr srcPtr incxPtr dstPtr incyPtr

-- | @copyBlock n srcPtr dstPtr@ copies @n@ consecutive elements from
-- @srcPtr@ to @dstPtr@ using BLAS @copy@ with unit increments.
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr = evalContT $ do
   nPtr <- Call.cint n
   incxPtr <- Call.cint 1
   incyPtr <- Call.cint 1
   liftIO $ BlasGen.copy nPtr srcPtr incxPtr dstPtr incyPtr

{- |
In ColumnMajor: copy an m-by-n matrix from @aPtr@ (leading dimension
@lda >= m@) to @bPtr@ (leading dimension @ldb >= m@).

Uses LAPACK @lacpy@ with @uplo = 'A'@, i.e. the whole matrix is copied,
not just a triangle.
-}
copySubMatrix ::
   (Storable a, Class.Floating a) =>
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix m n lda aPtr ldb bPtr = evalContT $ do
   uploPtr <- Call.char 'A'
   mPtr <- Call.cint m
   nPtr <- Call.cint n
   ldaPtr <- Call.cint lda
   ldbPtr <- Call.cint ldb
   liftIO $ LapackGen.lacpy uploPtr mPtr nPtr aPtr ldaPtr bPtr ldbPtr

{- |
@copyTransposed n m aPtr ldb bPtr@ transposes while copying.

For every @k@ in @[0, m)@, it copies the @n@ elements
@aPtr[k], aPtr[k+m], aPtr[k+2m], ...@ (stride @m@)
to the contiguous run @bPtr[k*ldb], bPtr[k*ldb+1], ...@.

In ColumnMajor terms, this turns an m-by-n matrix with leading
dimension @m@ into its n-by-m transpose with leading dimension @ldb@
(presumably @ldb >= n@ — not checked here).
-}
copyTransposed ::
   (Storable a, Class.Floating a) =>
   Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed n m aPtr ldb bPtr = evalContT $ do
   nPtr <- Call.cint n
   incaPtr <- Call.cint m
   incbPtr <- Call.cint 1
   liftIO $
      sequence_ $ take m $
      zipWith
         (\akPtr bkPtr -> BlasGen.copy nPtr akPtr incaPtr bkPtr incbPtr)
         (pointerSeq 1 aPtr)
         (pointerSeq ldb bPtr)

-- | Infinite list @ptr, ptr+k, ptr+2k, ...@.
-- The step @k@ is counted in elements, not bytes (see 'advancePtr').
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k ptr = iterate (flip advancePtr k) ptr


-- | Wrapper that lets 'Class.switchFloating' pick a type-specific
-- summation routine.
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

-- | @sum n xPtr incx@ adds up the @n@ elements read from @xPtr@ with
-- stride @incx@.
-- Dispatches to 'sumReal' for real types and to 'sumComplex' for complex
-- types.
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum $
   Class.switchFloating
      (Sum sumReal) (Sum sumReal) (Sum sumComplex) (Sum sumComplex)

-- | Real summation as a BLAS dot product of @x@ with a vector of ones.
-- The ones vector is a single 1 read with increment 0.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx = evalContT $ do
   nPtr <- Call.cint n
   incxPtr <- Call.cint incx
   yPtr <- Call.real oneReal
   incyPtr <- Call.cint 0
   liftIO $ BlasReal.dot nPtr xPtr incxPtr yPtr incyPtr

-- | Complex summation.
--
-- Computes @y := 1 * A * x + 0 * y@ with BLAS @gemv@, where @A@ is a
-- 1-by-n matrix of ones (leading dimension 1), so @y@ is the sum of the
-- selected elements of @x@.
--
-- NOTE(review): reference BLAS @gemv@ rejects @incx == 0@, whereas the
-- @dot@ used by 'sumReal' accepts it. Confirm callers never pass a zero
-- stride for complex data.
sumComplex :: Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx = evalContT $ do
   transPtr <- Call.char 'N'
   mPtr <- Call.cint 1
   nPtr <- Call.cint n
   alphaPtr <- Call.number one
   aPtr <- Call.allocaArray n
   ldaPtr <- Call.cint 1
   incxPtr <- Call.cint incx
   betaPtr <- Call.number zero
   -- Initialise the result to zero. gemv returns early without touching
   -- y when n == 0, and an uninitialised buffer would then be read back
   -- as garbage instead of the empty sum 0.
   yPtr <- Call.number zero
   incyPtr <- Call.cint 1
   liftIO $ fill one n aPtr
   liftIO $
      BlasGen.gemv transPtr mPtr nPtr
         alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr
   liftIO $ peek yPtr