{-# LANGUAGE TypeFamilies #-}
{- |
Solving, inverting and Cholesky-decomposing Hermitian positive definite
matrices using the LAPACK packed-storage routines @ppsv@, @pptrf@ and @pptri@.
-}
module Numeric.LAPACK.Linear.HermitianPositiveDefinite (
   solve,
   inverse,
   decompose,
   ) where

import Numeric.LAPACK.Matrix.Hermitian (Hermitian)
import Numeric.LAPACK.Matrix.Triangular (Upper)
import Numeric.LAPACK.Matrix (General)
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Matrix.Triangular.Private (copyTriangleToTemp)
import Numeric.LAPACK.Matrix.Shape.Private (Order(ColumnMajor), uploFromOrder)
import Numeric.LAPACK.Private (copyBlock, copyToColumnMajor)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative ((<$>))

import Text.Printf (printf)


{- |
Solve @A X = B@ for @X@, where @A@ is Hermitian positive definite.

The right-hand sides @B@ are copied into the result buffer, which is always
stored in column-major order, and LAPACK @ppsv@ overwrites them with the
solution.

The row shape of @B@ must equal the shape of @A@; this is checked with
@Call.assert@. A non-zero LAPACK status is turned into an 'error' by
'withInfo'.
-}
solve ::
   (Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Hermitian sh a -> General sh nrhs a -> General sh nrhs a
solve (Array (MatrixShape.Hermitian orderA shA) a)
      (Array (MatrixShape.General orderB heightB widthB) b) =
   Array.unsafeCreate (MatrixShape.General ColumnMajor heightB widthB) $ \xPtr -> do
      Call.assert "Hermitian.solve: height shapes mismatch" (shA == heightB)
      let n = Shape.size heightB
          nrhs = Shape.size widthB
          -- LAPACK requires LDB >= max(1,N). For an empty shape (n == 0),
          -- passing n itself would be rejected as an illegal argument.
          -- With N == 0, ppsv returns early without touching the buffer.
          ldb = max 1 n
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder orderA
         nPtr <- Call.cint n
         nrhsPtr <- Call.cint nrhs
         -- ppsv overwrites the packed matrix with its Cholesky factor.
         -- Presumably copyTriangleToTemp supplies a scratch copy so that
         -- @a@ stays untouched; TODO confirm.
         -- NOTE(review): for row-major input, LAPACK reads the packed
         -- triangle in column-major terms. For complex element types that
         -- reading denotes conj(A) unless copyTriangleToTemp compensates.
         -- Confirm that the row-major complex case solves A X = B.
         apPtr <- copyTriangleToTemp orderA n a
         bPtr <- ContT $ withForeignPtr b
         ldbPtr <- Call.cint ldb
         liftIO $ do
            copyToColumnMajor orderB n nrhs bPtr xPtr
            withInfo "ppsv" $
               LapackGen.ppsv uploPtr nPtr nrhsPtr apPtr xPtr ldbPtr

{- |
Inverse of a Hermitian positive definite matrix.

The packed triangle of the input is copied into the result buffer. It is then
factored in place by LAPACK @pptrf@ and inverted in place by @pptri@. The
result has the same shape, and hence the same storage order, as the input.
-}
inverse :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> Hermitian sh a
inverse (Array shape@(MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize shape $ \triSize bPtr -> evalContT $ do
      uploPtr <- Call.char $ uploFromOrder order
      nPtr <- Call.cint $ Shape.size sh
      aPtr <- ContT $ withForeignPtr a
      liftIO $ do
         copyBlock triSize aPtr bPtr
         withInfo "pptrf" $ LapackGen.pptrf uploPtr nPtr bPtr
         withInfo "pptri" $ LapackGen.pptri uploPtr nPtr bPtr

{- |
Cholesky decomposition.

The packed triangle of the input is copied into the result buffer and
factored in place by LAPACK @pptrf@. The result is returned as an upper
triangular matrix with the input's storage order. For column-major storage,
@pptrf@ documents this factor @U@ as satisfying @A = U^H U@.
-}
decompose :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> Upper sh a
decompose (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize
         (MatrixShape.Triangular MatrixShape.Upper order sh) $
      \triSize bPtr -> evalContT $ do
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint $ Shape.size sh
         aPtr <- ContT $ withForeignPtr a
         liftIO $ do
            copyBlock triSize aPtr bPtr
            withInfo "pptrf" $ LapackGen.pptrf uploPtr nPtr bPtr

{- |
Run a LAPACK routine that reports its status through an @INFO@ output
argument, and turn a non-zero status into an 'error' mentioning @name@.

A status of @-i@ means that the @i@-th argument had an illegal value. A
positive status is reported as a leading minor that is not positive definite,
which is what @pptrf@ and @ppsv@ signal. NOTE(review): for @pptri@, a positive
status instead means a zero diagonal element in the factor, so the message is
not exact there.
-}
withInfo :: String -> (Ptr CInt -> IO ()) -> IO ()
withInfo name computation = alloca $ \infoPtr -> do
   computation infoPtr
   info <- fromIntegral <$> peek infoPtr
   case compare info (0::Int) of
      EQ -> return ()
      LT -> error $ printf "%s: illegal value in %d-th argument" name (-info)
      GT -> error $ printf "%s: minor of order %d not positive definite" name info