{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Private where

import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Scalar (zero, one, isZero)

import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CChar, CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import Text.Printf (printf)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT, runContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when, foldM)
import Control.Applicative ((<$>))

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)

import qualified Data.Complex as Complex
import Data.Complex (Complex)
import Data.Tuple.HT (swap)

import Prelude hiding (sum)


{- |
@fill a n dstPtr@ calls BLAS @copy@ with the single value @a@ as source,
a source increment of 0 and a destination increment of 1,
so that @a@ is replicated into @n@ consecutive elements at @dstPtr@.
-}
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr =
   evalContT $ do
      countPtr <- Call.cint n
      valuePtr <- Call.number a
      srcIncPtr <- Call.cint 0
      dstIncPtr <- Call.cint 1
      liftIO $ BlasGen.copy countPtr valuePtr srcIncPtr dstPtr dstIncPtr

{- |
@copyBlock n srcPtr dstPtr@ copies @n@ contiguous elements
from @srcPtr@ to @dstPtr@ via BLAS @copy@ with both increments equal to 1.
-}
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr =
   evalContT $ do
      countPtr <- Call.cint n
      srcIncPtr <- Call.cint 1
      dstIncPtr <- Call.cint 1
      liftIO $ BlasGen.copy countPtr srcPtr srcIncPtr dstPtr dstIncPtr

{- |
Duplicate the first @n@ elements behind the foreign pointer
into a buffer obtained from @Call.allocaArray@
and return the pointer to that buffer.
-}
copyToTemp :: (Storable a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToTemp n fptr = do
   originalPtr <- ContT $ withForeignPtr fptr
   bufferPtr <- Call.allocaArray n
   liftIO $ copyArray bufferPtr originalPtr n
   return bufferPtr


{- |
Provide a pointer to the conjugated array.
For the two real element types the original memory is handed out unchanged;
only for the two complex element types
a conjugated temporary copy is made.
-}
conjugateToTemp ::
   (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
conjugateToTemp n =
   runCopyToTemp $
   Class.switchFloating
      (CopyToTemp (ContT . withForeignPtr))
      (CopyToTemp (ContT . withForeignPtr))
      (CopyToTemp (complexConjugateToTemp n))
      (CopyToTemp (complexConjugateToTemp n))

{- |
Wrapper that lets 'Class.switchFloating' select
the implementation of 'conjugateToTemp' per element type.
-}
newtype CopyToTemp r a =
   CopyToTemp {runCopyToTemp :: ForeignPtr a -> ContT r IO (Ptr a)}

{- |
Copy @n@ complex elements to a temporary buffer using 'copyToTemp'
and apply LAPACK @lacgv@ with increment 1 to that buffer.
The source array is not modified.
-}
complexConjugateToTemp ::
   Class.Real a =>
   Int -> ForeignPtr (Complex a) -> ContT r IO (Ptr (Complex a))
complexConjugateToTemp n x = do
   countPtr <- Call.cint n
   bufferPtr <- copyToTemp n x
   stridePtr <- Call.cint 1
   liftIO $ LapackComplex.lacgv countPtr bufferPtr stridePtr
   return bufferPtr


{- |
BLAS @copy@ from @x@ to @y@ followed by 'lacgv' on the destination @y@.
All sizes and increments are passed as pointers, like in the FFI calls.
-}
copyConjugate ::
   (Class.Floating a) =>
   Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
   >>
   lacgv nPtr yPtr incyPtr

{- |
Like 'copyConjugate', but 'lacgv' is applied to the destination
only if the flag @conj@ is set.
-}
copyCondConjugate ::
   (Class.Floating a) =>
   Bool -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyCondConjugate conj nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
   >>
   when conj (lacgv nPtr yPtr incyPtr)

{- |
With @conj@ set this is 'conjugateToTemp',
otherwise the pointer to the original array is passed on without a copy.
-}
condConjugateToTemp ::
   (Class.Floating a) =>
   Bool -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
condConjugateToTemp conj n x
   | conj = conjugateToTemp n x
   | otherwise = ContT $ withForeignPtr x

{- |
Always copy @n@ elements to a temporary buffer
and conjugate the copy if @conj@ is set.
The nested 'evalContT' confines the source pointer
and the temporary size and increment cells to the copy itself,
such that only the destination buffer is part of the outer continuation.
-}
copyCondConjugateToTemp ::
   (Class.Floating a) =>
   Bool -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyCondConjugateToTemp conj n a = do
   dstPtr <- Call.allocaArray n
   liftIO $ evalContT $ do
      srcPtr <- ContT $ withForeignPtr a
      countPtr <- Call.cint n
      -- the same unit increment is used for source and destination
      unitIncPtr <- Call.cint 1
      liftIO $
         copyCondConjugate conj countPtr srcPtr unitIncPtr dstPtr unitIncPtr
   return dstPtr


{- |
In ColumnMajor:
Copy an m-by-n-matrix with lda>=m and ldb>=m.
-}
copySubMatrix ::
   (Class.Floating a) =>
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix =
   -- 'A' is handed to lacpy as its uplo argument;
   -- presumably it selects the full matrix - see the LAPACK lacpy docs.
   copySubTrapezoid 'A'

{- |
@copySubTrapezoid side m n lda aPtr ldb bPtr@
forwards to LAPACK @lacpy@ with @side@ as its first (uplo) argument,
the dimensions @m@ and @n@,
source @aPtr@ with leading dimension @lda@ and
destination @bPtr@ with leading dimension @ldb@.
-}
copySubTrapezoid ::
   (Class.Floating a) =>
   Char -> Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubTrapezoid side m n lda aPtr ldb bPtr =
   evalContT $ do
      partPtr <- Call.char side
      rowsPtr <- Call.cint m
      colsPtr <- Call.cint n
      srcLdPtr <- Call.leadingDim lda
      dstLdPtr <- Call.leadingDim ldb
      liftIO $
         LapackGen.lacpy partPtr rowsPtr colsPtr aPtr srcLdPtr bPtr dstLdPtr

{- |
@copyTransposed n m aPtr ldb bPtr@ performs @m@ BLAS copies
of @n@ elements each.
The copy number @k@ (counting from zero) reads from @aPtr@ advanced by @k@
with increment @m@ and writes to @bPtr@ advanced by @k*ldb@
with increment 1.
-}
copyTransposed ::
   (Class.Floating a) =>
   Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed n m aPtr ldb bPtr =
   evalContT $ do
      countPtr <- Call.cint n
      srcIncPtr <- Call.cint m
      dstIncPtr <- Call.cint 1
      liftIO $
         mapM_
            (\(srcPtr, dstPtr) ->
               BlasGen.copy countPtr srcPtr srcIncPtr dstPtr dstIncPtr) $
         take m $
         zip (pointerSeq 1 aPtr) (pointerSeq ldb bPtr)

{- |
Copy an m-by-n-matrix to ColumnMajor order.
A RowMajor source is reordered by 'copyTransposed',
a ColumnMajor source is copied as one block of @m*n@ elements.
-}
copyToColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyToColumnMajor RowMajor m n aPtr bPtr =
   copyTransposed m n aPtr m bPtr
copyToColumnMajor ColumnMajor m n aPtr bPtr =
   copyBlock (m*n) aPtr bPtr

{- |
Like 'copyToColumnMajor' but the destination has leading dimension @ldb@.
For a ColumnMajor source a single block copy is used if @m==ldb@,
otherwise 'copySubMatrix' with source leading dimension @m@.
-}
copyToSubColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyToSubColumnMajor RowMajor m n aPtr ldb bPtr =
   copyTransposed m n aPtr ldb bPtr
copyToSubColumnMajor ColumnMajor m n aPtr ldb bPtr
   | m == ldb = copyBlock (m*n) aPtr bPtr
   | otherwise = copySubMatrix m n m aPtr ldb bPtr

{- |
Infinite list of pointers:
the start pointer and every further pointer
advanced by @k@ elements relative to its predecessor.
-}
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k = iterate (`advancePtr` k)


{- |
Create an array of shape @shapeX@ and let @act@ fill it.
The action gets a pointer together with the leading dimension to use.
If @m>n@ the action works on a temporary buffer of @m*nrhs@ elements
with leading dimension @m@ and afterwards
the first @n@ rows of the @nrhs@ columns are copied into the array
(leading dimension @n@).
Otherwise the action writes to the array directly
with leading dimension @n@.
The result of the action is returned along with the array.
-}
createHigherArray ::
   (Shape.C sh, Class.Floating a) =>
   sh -> Int -> Int -> Int ->
   ((Ptr a, Int) -> IO rank) -> IO (rank, Array sh a)
createHigherArray shapeX m n nrhs act =
   swap <$> ArrayIO.unsafeCreateWithSizeAndResult shapeX compute
   where
      compute _ xPtr
         | m > n =
            runContT (Call.allocaArray (m*nrhs)) $ \tmpPtr -> do
               result <- act (tmpPtr, m)
               copySubMatrix n nrhs m tmpPtr n xPtr
               return result
         | otherwise = act (xPtr, n)


{- |
Wrapper that lets 'Class.switchFloating' select
the implementation of 'sum' per element type.
-}
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

{- |
@sum n xPtr incx@ adds up @n@ elements starting at @xPtr@
with increment @incx@.
Real element types are handled by 'sumReal',
complex ones by 'sumComplex'.
-}
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum $
   Class.switchFloating
      (Sum sumReal)
      (Sum sumReal)
      (Sum sumComplex)
      (Sum sumComplex)

{- |
Sum of a real vector, computed as BLAS @dot@ product
of the vector with a single 'one' that is addressed with increment 0.
-}
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
   evalContT $ do
      countPtr <- Call.cint n
      strideXPtr <- Call.cint incx
      unitPtr <- Call.real one
      zeroStridePtr <- Call.cint 0
      liftIO $ BlasReal.dot countPtr xPtr strideXPtr unitPtr zeroStridePtr

{- |
Sum of a complex vector.
A temporary array of @n@ ones is set up by a BLAS @copy@
with source increment 0.
It is then multiplied as 1-by-n-matrix with the vector
by means of the local 'gemv' wrapper and
the single element of the product is read back.
-}
sumComplex ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   evalContT $ do
      transPtr <- Call.char 'N'
      rowsPtr <- Call.cint 1
      colsPtr <- Call.cint n
      alphaPtr <- Call.number one
      unitPtr <- Call.number one
      zeroStridePtr <- Call.cint 0
      onesPtr <- Call.allocaArray n
      ldaPtr <- Call.leadingDim 1
      strideXPtr <- Call.cint incx
      betaPtr <- Call.number zero
      resultPtr <- Call.alloca
      unitStridePtr <- Call.cint 1
      liftIO $ do
         -- fill the 1-by-n-matrix with ones
         BlasGen.copy colsPtr unitPtr zeroStridePtr onesPtr unitStridePtr
         gemv
            transPtr rowsPtr colsPtr alphaPtr onesPtr ldaPtr
            xPtr strideXPtr betaPtr resultPtr unitStridePtr
         peek resultPtr


{- |
@product n xPtr incx@ multiplies @n@ elements starting at @xPtr@
with increment @incx@.
The elements are read one by one in Haskell
and the accumulator is forced in every step.
-}
product :: Class.Floating a => Int -> Ptr a -> Int -> IO a
product n xPtr incx =
   foldM step one $ take n $ pointerSeq incx xPtr
   where
      step acc ptr = do
         y <- peek ptr
         return $! acc*y


{- |
Wrapper that lets 'Class.switchFloating' select
the implementation of 'lacgv' per element type.
-}
newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

{- |
Conjugation for all element types:
a no-op for the two real types and
LAPACK @lacgv@ for the two complex types.
-}
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV $
   Class.switchFloating
      (LACGV (\_ _ _ -> return ()))
      (LACGV (\_ _ _ -> return ()))
      (LACGV LapackComplex.lacgv)
      (LACGV LapackComplex.lacgv)


{-
Work around an inconsistency of BLAS.
In case of a zero-column matrix
BLAS's gemv and gbmv do not initialize the target vector.
In contrast, these work-arounds do.
-}
{-# INLINE gemv #-}
{- |
BLAS @gemv@ preceded by 'initializeMV' on the target vector.
The arguments are the ones of @BlasGen.gemv@ in the same order.
-}
gemv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   >>
   BlasGen.gemv
      transPtr mPtr nPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr

{-# INLINE gbmv #-}
{- |
BLAS @gbmv@ preceded by 'initializeMV' on the target vector.
The arguments are the ones of @BlasGen.gbmv@ in the same order.
-}
gbmv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   >>
   BlasGen.gbmv
      transPtr mPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr

{- |
Explicit initialization of the target vector of a matrix-vector product.
For transposition character @N@ the target has @m@ elements
and the summation runs over @n@ elements,
for any other character the roles of @m@ and @n@ are swapped.
If the summation length is zero and @beta@ is zero
then @beta@ is written to every target element
by a BLAS @copy@ with source increment 0.
In all other cases nothing is written.
-}
initializeMV ::
   Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
   trans <- peek transPtr
   let notTransposed = trans == CStr.castCharToCChar 'N'
       (targetLenPtr, innerLenPtr) =
          if notTransposed then (mPtr, nPtr) else (nPtr, mPtr)
   innerLen <- peek innerLenPtr
   beta <- peek betaPtr
   when (innerLen == 0 && isZero beta) $
      Marshal.with 0 $ \zeroIncPtr ->
         BlasGen.copy targetLenPtr betaPtr zeroIncPtr yPtr incyPtr


{- |
@multiplyMatrix orderA orderB m k n a b cPtr@ calls BLAS @gemm@
with @alpha = one@ and @beta = zero@
and writes the result to @cPtr@ with leading dimension @m@.
The transposition flags are derived from the orders
by 'transposeFromOrder'.
The leading dimension of @a@ is @k@ for RowMajor and @m@ for ColumnMajor,
the one of @b@ is @n@ for RowMajor and @k@ for ColumnMajor.
-}
multiplyMatrix ::
   (Class.Floating a) =>
   Order -> Order -> Int -> Int -> Int ->
   ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyMatrix orderA orderB m k n a b cPtr =
   evalContT $ do
      transaPtr <- Call.char $ transposeFromOrder orderA
      transbPtr <- Call.char $ transposeFromOrder orderB
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.leadingDim ldb
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim ldc
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
   where
      lda = byOrder orderA k m
      ldb = byOrder orderB n k
      ldc = m
      -- select a dimension depending on the storage order
      byOrder order rowMajorDim columnMajorDim =
         case order of
            RowMajor -> rowMajorDim
            ColumnMajor -> columnMajorDim


{- |
Combination of 'withInfo' and 'withAutoWorkspace':
the computation gets the workspace pointer, the pointer to its size
and the pointer to the info cell.
-}
withAutoWorkspaceInfo ::
   (Class.Floating a) =>
   String -> String -> (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspaceInfo msg name computation =
   withInfo msg name $ \infoPtr ->
      withAutoWorkspace
         (\workPtr lworkPtr -> computation workPtr lworkPtr infoPtr)

{- |
Run the computation twice.
The first run gets a size of @-1@ and a single-element workspace;
afterwards the first workspace element is read,
rounded up by 'ceilingSize' and clamped to at least 1.
The second run gets a workspace of that many elements.
Presumably this follows the workspace query convention of LAPACK.
-}
withAutoWorkspace ::
   (Class.Floating a) =>
   (Ptr a -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspace computation =
   evalContT $ do
      lworkPtr <- Call.cint (-1)
      lwork <-
         liftIO $ alloca $ \queryPtr -> do
            computation queryPtr lworkPtr
            fmap (max 1 . ceilingSize) (peek queryPtr)
      workPtr <- Call.allocaArray lwork
      liftIO $ do
         pokeCInt lworkPtr lwork
         computation workPtr lworkPtr

{- |
Run the computation with a pointer to an info cell
and inspect the cell afterwards.
Zero means success.
A negative value raises an error built from 'argMsg',
a positive value raises an error that consists of @name@
and the format string @msg@ applied to the value.
-}
withInfo :: String -> String -> (Ptr CInt -> IO ()) -> IO ()
withInfo msg name computation =
   alloca $ \infoPtr -> do
      computation infoPtr
      peekCInt infoPtr >>= checkInfo
   where
      checkInfo :: Int -> IO ()
      checkInfo info
         | info == 0 = return ()
         | info < 0 = error $ printf argMsg name (-info)
         | otherwise = error $ name ++ ": " ++ printf msg info

{- | Format string for negative info values: routine name and argument index. -}
argMsg :: String
argMsg = "%s: illegal value in %d-th argument"

{- | Format string for an info value without a more specific meaning. -}
errorCodeMsg :: String
errorCodeMsg = "unknown error code %d"

{- | Format string for an info value that reports a rank. -}
rankMsg :: String
rankMsg = "deficient rank %d"

{- | Format string for an info value that reports the order of a minor. -}
definiteMsg :: String
definiteMsg = "minor of order %d not positive definite"

{- | Format string for an info value that counts off-diagonal elements. -}
eigenMsg :: String
eigenMsg = "%d off-diagonal elements not converging"


{- | Write an 'Int' to a 'CInt' cell, converted by 'fromIntegral'. -}
pokeCInt :: Ptr CInt -> Int -> IO ()
pokeCInt ptr value = poke ptr (fromIntegral value)

{- | Read a 'CInt' cell as 'Int', converted by 'fromIntegral'. -}
peekCInt :: Ptr CInt -> IO Int
peekCInt ptr = fmap fromIntegral (peek ptr)

{- |
Round a number up to the next integer.
For the complex types only the real part is considered.
-}
ceilingSize :: (Class.Floating a) => a -> Int
ceilingSize =
   getFlip $
   Class.switchFloating
      (Flip ceiling)
      (Flip ceiling)
      (Flip (ceiling . Complex.realPart))
      (Flip (ceiling . Complex.realPart))