{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
-- | Linear-algebra operations for Hermitian positive definite matrices,
-- implemented with the LAPACK Cholesky routines
-- (ppsv/posv, pptrs/potrs, pptrf/potrf, pptri/potri).
--
-- Each operation passes the packing singleton @pack@ to 'withPackingLinear'
-- together with a pair of routines: the @pp*@ variant (packed storage)
-- and the @po*@ variant (full, unpacked storage).
module Numeric.LAPACK.Matrix.HermitianPositiveDefinite.Linear (
   solve,
   solveDecomposed,
   inverse,
   decompose,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Symmetric.Unified as Symmetric
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Hermitian.Private (Determinant(..))
import Numeric.LAPACK.Matrix.Mosaic.Private
         (withPackingLinear, label, applyFuncPair, triArg, copyTriangleToTemp)
import Numeric.LAPACK.Matrix.Mosaic.Basic (takeDiagonal)
import Numeric.LAPACK.Matrix.Layout.Private (uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver)
import Numeric.LAPACK.Scalar (RealOf, realPart, zero)
import Numeric.LAPACK.Private
         (copySubTrapezoid, copyBlock, fill, rankMsg, definiteMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | A Hermitian matrix with packing @pack@ over shape @sh@.
type Hermitian pack sh = Array (Layout.HermitianP pack sh)

-- | An upper triangular matrix with packing @pack@ over shape @sh@;
-- 'decompose' returns its Cholesky factor in this form.
type Upper pack sh = Array (Layout.UpperTriangularP pack sh)


-- | Solve @A X = B@ for a Hermitian positive definite @A@.
--
-- Uses LAPACK ppsv (packed) or posv (unpacked), which factor @A@ by
-- Cholesky and then solve.  The right-hand sides are overwritten in the
-- working buffer provided by 'solver'.  'definiteMsg' is handed to
-- 'withPackingLinear', presumably as the error text used when LAPACK
-- reports that @A@ is not positive definite (INFO > 0).
solve ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Hermitian pack sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve (Array shape@(Layout.Mosaic pack _mirror _upper orderA shA) a) =
   solver "Hermitian.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- LAPACK's UPLO character is derived from the storage order of A.
      uploPtr <- Call.char $ uploFromOrder orderA
      -- ppsv/posv overwrite A with its factor, so work on a temporary copy.
      -- Conjugated: a row-major triangle read column-major is the transpose,
      -- which for a Hermitian matrix equals the conjugate.
      -- NOTE(review): presumably applied only for row-major orderA;
      -- confirm in copyTriangleToTemp.
      aPtr <- copyTriangleToTemp Conjugated orderA (Shape.size shape) a
      withPackingLinear definiteMsg pack $
         applyFuncPair (label "ppsv" LapackGen.ppsv) (label "posv" LapackGen.posv)
            uploPtr nPtr nrhsPtr (triArg aPtr n) xPtr ldxPtr

-- | Solve @A X = B@, given the Cholesky factor of @A@.
--
-- The factor is an 'Upper' matrix as produced by 'decompose'.  Uses LAPACK
-- pptrs (packed) or potrs (unpacked).  The pattern requires the factor's
-- layout to be 'Layout.NoMirror'.
solveDecomposed ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Upper pack sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solveDecomposed
      (Array shape@(Layout.Mosaic pack Layout.NoMirror _upper orderA shA) a) =
   solver "Hermitian.solveDecomposed" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder orderA
      -- Same conjugation reasoning as in 'solve'.  For a factor U with
      -- A = U^H U stored row-major, the column-major view is U^T, and
      -- conjugating it yields the L = U^H that LAPACK expects.
      aPtr <- copyTriangleToTemp Conjugated orderA (Shape.size shape) a
      -- NOTE(review): per LAPACK docs, xPOTRS/xPPTRS only report INFO < 0,
      -- so rankMsg is presumably never used here.  Confirm.
      withPackingLinear rankMsg pack $
         applyFuncPair (label "pptrs" LapackGen.pptrs) (label "potrs" LapackGen.potrs)
            uploPtr nPtr nrhsPtr (triArg aPtr n) xPtr ldxPtr

-- | Inverse of a Hermitian positive definite matrix.
--
-- The result has the same layout (shape) as the input.  It is computed in
-- two steps: a Cholesky factorization (pptrf/potrf), then inversion from
-- the factor (pptri/potri).
inverse ::
   (Shape.C sh, Class.Floating a) => Hermitian pack sh a -> Hermitian pack sh a
inverse (Array shape@(Layout.Mosaic pack _mirror _upper order sh) a) =
   Array.unsafeCreateWithSize shape $ \triSize bPtr -> do
      let n = Shape.size sh
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         -- The LAPACK routines below work in place on the output buffer.
         liftIO $ copyBlock triSize aPtr bPtr
         -- Failure here (INFO > 0) means A is not positive definite.
         withPackingLinear definiteMsg pack $
            applyFuncPair (label "pptrf" LapackGen.pptrf) (label "potrf" LapackGen.potrf)
               uploPtr nPtr (triArg bPtr n)
         -- Failure here (INFO > 0) means a zero diagonal entry in the factor.
         withPackingLinear rankMsg pack $
            applyFuncPair (label "pptri" LapackGen.pptri) (label "potri" LapackGen.potri)
               uploPtr nPtr (triArg bPtr n)
      -- LAPACK only writes the UPLO triangle.
      -- NOTE(review): presumably complement fills the remaining (mirrored,
      -- conjugated) triangle where the layout stores it.  Confirm in
      -- Symmetric.Unified.
      Symmetric.complement pack Conjugated order n bPtr

-- | Cholesky decomposition of a Hermitian positive definite matrix.
--
-- Returns U with A = U^H U.  U is an 'Upper' matrix with the same packing,
-- order and shape as the input, and with 'Layout.NoMirror'.  Uses LAPACK
-- pptrf (packed) or potrf (unpacked).
decompose ::
   (Shape.C sh, Class.Floating a) => Hermitian pack sh a -> Upper pack sh a
decompose (Array (Layout.Mosaic pack _mirror upper order sh) a) =
   Array.unsafeCreateWithSize (Layout.Mosaic pack Layout.NoMirror upper order sh) $
      \triSize bPtr -> do
         evalContT $ do
            let uplo = uploFromOrder order
            uploPtr <- Call.char uplo
            let n = Shape.size sh
            nPtr <- Call.cint n
            aPtr <- ContT $ withForeignPtr a
            -- Collapse the packing singleton to a Bool before branching.
            -- NOTE(review): this looks deliberate (it keeps the GADT match
            -- out of the IO branches); keep it unless the build is verified.
            let packed = case pack of Layout.Packed -> True; Layout.Unpacked -> False
            liftIO $
               if packed
                  then copyBlock triSize aPtr bPtr
                  else do
                     -- Unpacked: potrf leaves the other triangle untouched.
                     -- Zero the whole n*n buffer first, then copy only the
                     -- UPLO triangle, so the factor's unused triangle is zero.
                     fill zero (n*n) bPtr
                     copySubTrapezoid uplo n n n aPtr n bPtr
            -- No conjugation is needed: factoring the conjugated
            -- column-major view and reading the result back row-major
            -- yields U directly.
            withPackingLinear definiteMsg pack $
               applyFuncPair (label "pptrf" LapackGen.pptrf) (label "potrf" LapackGen.potrf)
                  uploPtr nPtr (triArg bPtr n)

-- | Determinant of a Hermitian positive definite matrix.
--
-- The result is real.  'Class.switchFloating' instantiates 'determinantAux'
-- separately for each of the four floating types.  Presumably this is so
-- the @Class.Real (RealOf a)@ constraint can be discharged per concrete
-- type; confirm against "Numeric.Netlib.Class".
determinant :: (Shape.C sh, Class.Floating a) => Hermitian pack sh a -> RealOf a
determinant =
   getDeterminant $
   Class.switchFloating
      (Determinant determinantAux) (Determinant determinantAux)
      (Determinant determinantAux) (Determinant determinantAux)

-- | Determinant via the Cholesky factor.
--
-- With A = U^H U, det A = |det U|^2, which is the squared product of the
-- diagonal of U.  LAPACK's Cholesky factor has a real diagonal, so only
-- the real parts are used.
determinantAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Hermitian pack sh a -> ar
determinantAux =
   (^(2::Int)) . Vector.product . Array.map realPart . takeDiagonal . decompose