{-# LANGUAGE TypeFamilies #-}
{- |
Solving linear systems with, and inverting, Hermitian matrices.
The work is delegated to LAPACK's packed Hermitian routines
@?hpsv@, @?hptrf@ and @?hptri@.
-}
module Numeric.LAPACK.Linear.Hermitian (
   solve,
   inverse,
   ) where

import Numeric.LAPACK.Matrix.Hermitian (Hermitian)
import Numeric.LAPACK.Matrix (General)
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Matrix.Triangular.Private (copyTriangleToTemp)
import Numeric.LAPACK.Matrix.Shape.Private (Order(ColumnMajor), uploFromOrder)
import Numeric.LAPACK.Private (copyBlock, copyToColumnMajor)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative ((<$>))

import Text.Printf (printf)


-- | Solve @A·X = B@ for @X@, where @A@ is Hermitian.
--
-- The result is always produced in column-major order, whatever the
-- order of @B@ is.
--
-- Fails with 'error' if the height of @B@ differs from the size of @A@,
-- or if @hpsv@ reports an illegal argument or a zero diagonal entry in
-- its factorization.
solve ::
   (Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Hermitian sh a ->
   General sh nrhs a -> General sh nrhs a
solve (Array (MatrixShape.Hermitian orderA shA) a)
      (Array (MatrixShape.General orderB heightB widthB) b) =
   Array.unsafeCreate (MatrixShape.General ColumnMajor heightB widthB) $
      \xPtr -> do
         Call.assert "Hermitian.solve: height shapes mismatch" (shA == heightB)
         let n = Shape.size heightB
             nrhs = Shape.size widthB
             -- LAPACK requires LDB >= max(1,N).  Passing plain n would make
             -- hpsv reject an empty (n = 0) system as an illegal argument.
             ldb = max 1 n
         evalContT $ do
            uploPtr <- Call.char $ uploFromOrder orderA
            nPtr <- Call.cint n
            nrhsPtr <- Call.cint nrhs
            -- hpsv overwrites AP with its factorization, so work on a copy
            -- and leave the caller's matrix untouched.
            -- NOTE(review): the uplo flag comes from 'uploFromOrder orderA'.
            -- If that flips the triangle for row-major data, LAPACK reads
            -- the transpose of A, which for a Hermitian A is conj(A).
            -- Unless 'copyTriangleToTemp' conjugates in that case, complex
            -- row-major systems would be solved against conj(A) -- confirm.
            apPtr <- copyTriangleToTemp orderA n a
            ipivPtr <- Call.allocaArray n
            bPtr <- ContT $ withForeignPtr b
            ldbPtr <- Call.cint ldb
            liftIO $ do
               -- hpsv solves in place: the right-hand side goes into the
               -- output buffer, which then receives the solution.
               copyToColumnMajor orderB n nrhs bPtr xPtr
               withInfo "hpsv" $
                  LapackGen.hpsv uploPtr nPtr nrhsPtr apPtr ipivPtr xPtr ldbPtr


-- | Inverse of a Hermitian matrix, with the same shape and order as
-- the input.
--
-- The input is copied into the result buffer.  That buffer is then
-- factorized in place with @hptrf@ and inverted in place with @hptri@.
--
-- Fails with 'error' if LAPACK reports an illegal argument or a zero
-- diagonal entry in the factorization.
inverse :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> Hermitian sh a
inverse (Array shape@(MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize shape $ \triSize bPtr -> do
      let n = Shape.size sh
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         ipivPtr <- Call.allocaArray n
         -- workspace of n elements required by hptri
         workPtr <- Call.allocaArray n
         liftIO $ do
            copyBlock triSize aPtr bPtr
            withInfo "hptrf" $ LapackGen.hptrf uploPtr nPtr bPtr ipivPtr
            withInfo "hptri" $ LapackGen.hptri uploPtr nPtr bPtr ipivPtr workPtr


-- | Run a LAPACK routine that reports its status through a trailing
-- @INFO@ output pointer.  A non-zero status is turned into an 'error':
-- a negative value names the offending argument, and a positive value
-- names the diagonal entry found to be zero.
withInfo :: String -> (Ptr CInt -> IO ()) -> IO ()
withInfo name computation = alloca $ \infoPtr -> do
   computation infoPtr
   info <- fromIntegral <$> peek infoPtr
   case compare info (0::Int) of
      EQ -> return ()
      LT -> error $ printf "%s: illegal value in %d-th argument" name (-info)
      GT -> error $ printf "%s: %d-th diagonal value is zero" name info