{-# LANGUAGE TypeFamilies #-}
{- |
Hermitian matrices stored in packed triangular form,
with construction, conversion and BLAS-backed products.
-}
module Numeric.LAPACK.Matrix.Hermitian (
   Hermitian,
   fromList,
   autoFromList,
   identity,
   diagonal,
   getDiagonal,
   multiplyVector,
   square,
   multiplySquareLeft,
   multiplyGeneralLeft,
   multiplySquareRight,
   multiplyGeneralRight,
   outer,
   sumRank1,
   sumRank2,
   toSquare,
   covariance,
   addTransposed,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Matrix.Triangular.Private
         (forPointers, pack, unpack, unpackToTemp,
          diagonalPointers, rowMajorPointers, columnMajorPointers)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), flipOrder, uploFromOrder)
import Numeric.LAPACK.Matrix.Square (Square)
import Numeric.LAPACK.Matrix.Private (General, ZeroInt, zeroInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Private
         (RealOf, fill, zero, one, lacgv, fromReal, realPart, copyToTemp)

import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)

import qualified Data.NonEmpty as NonEmpty
import Data.Foldable (forM_)
import Data.Complex (Complex)


-- | A Hermitian matrix: only one triangle is stored, in packed form,
-- laid out according to the 'Order' kept in the shape.
type Hermitian sh = Array (MatrixShape.Hermitian sh)

-- | Build a Hermitian matrix from its packed triangle elements.
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Hermitian sh a
fromList order sh = Array.fromList (MatrixShape.Hermitian order sh)

-- | Like 'fromList', but the dimension is recovered from the number of
-- packed elements via 'MatrixShape.triangleExtent'.
autoFromList :: (Storable a) => Order -> [a] -> Hermitian ZeroInt a
autoFromList order xs =
   fromList order
      (zeroInt $
       MatrixShape.triangleExtent "Hermitian.autoFromList" $ length xs)
      xs

-- | Identity matrix: zero everywhere except ones on the diagonal.
identity :: (Shape.C sh, Class.Floating a) => Order -> sh -> Hermitian sh a
identity order sh =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr -> do
         fill zero triSize aPtr
         mapM_ (flip poke one . snd) $
            diagonalPointers order (Shape.size sh) aPtr aPtr

-- | Diagonal matrix from a vector of real diagonal entries.
-- The diagonal of a Hermitian matrix is real, hence the 'RealOf' input.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order =
   runDiagonal $
   Class.switchFloating
      (Diagonal $ diagonalAux order)
      (Diagonal $ diagonalAux order)
      (Diagonal $ diagonalAux order)
      (Diagonal $ diagonalAux order)

-- | Wrapper so 'Class.switchFloating' can select 'diagonalAux' per type.
newtype Diagonal sh a =
   Diagonal {runDiagonal :: Vector sh (RealOf a) -> Hermitian sh a}

-- | Implementation of 'diagonal' once the element type is concrete.
diagonalAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Order -> Vector sh ar -> Hermitian sh a
diagonalAux order (Array sh x) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \triSize aPtr -> do
         fill zero triSize aPtr
         withForeignPtr x $ \xPtr ->
            forM_ (diagonalPointers order (Shape.size sh) xPtr aPtr) $
               \(srcPtr,dstPtr) -> poke dstPtr . fromReal =<< peek srcPtr

-- | Extract the (real) diagonal of a Hermitian matrix.
getDiagonal ::
   (Shape.C sh, Class.Floating a) => Hermitian sh a -> Vector sh (RealOf a)
getDiagonal =
   runGetDiagonal $
   Class.switchFloating
      (GetDiagonal getDiagonalAux)
      (GetDiagonal getDiagonalAux)
      (GetDiagonal getDiagonalAux)
      (GetDiagonal getDiagonalAux)

-- | Wrapper so 'Class.switchFloating' can select 'getDiagonalAux' per type.
newtype GetDiagonal sh a =
   GetDiagonal {runGetDiagonal :: Hermitian sh a -> Vector sh (RealOf a)}

-- | Implementation of 'getDiagonal'; keeps only the real part of each
-- diagonal element.
getDiagonalAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> Vector sh ar
getDiagonalAux (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreateWithSize sh $ \n xPtr ->
      withForeignPtr a $ \aPtr ->
         forM_ (diagonalPointers order n xPtr aPtr) $
            \(dstPtr,srcPtr) -> poke dstPtr . realPart =<< peek srcPtr


-- | Matrix-vector product @A * x@ via BLAS @?hpmv@.
--
-- For 'RowMajor' storage the packed data, read as a column-major triangle
-- (uplo from 'uploFromOrder'), represents @A^T = conj A@.
-- We therefore compute @conj (conj A * conj x) = A * x@:
-- conjugate a temporary copy of @x@ (no copy for real types,
-- see 'conjugateToTemp'), call @?hpmv@, then conjugate the result.
multiplyVector ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh a -> Vector sh a
multiplyVector (Array (MatrixShape.Hermitian order shA) a) (Array shX x) =
   Array.unsafeCreateWithSize shX $ \n yPtr -> do
      Call.assert "Hermitian.multiplyVector: width shapes mismatch" (shA == shX)
      let conj = order == RowMajor
      evalContT $ do
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint n
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr a
         xPtr <-
            if conj
               then conjugateToTemp n x
               else ContT $ withForeignPtr x
         incxPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incyPtr <- Call.cint 1
         liftIO $ do
            BlasGen.hpmv uploPtr nPtr
               alphaPtr aPtr xPtr incxPtr betaPtr yPtr incyPtr
            -- undo the conjugation introduced for row-major storage
            when conj $ lacgv nPtr yPtr incyPtr

-- | @A * A@, computed with @?hemm@ on the fully expanded matrix and then
-- packed again in the original order.
square ::
   (Shape.C sh, Eq sh, Class.Floating a) => Hermitian sh a -> Hermitian sh a
square (Array shape@(MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate shape $ \cpPtr -> do
      let n = Shape.size sh
      evalContT $ do
         sidePtr <- Call.char 'L'
         uploPtr <- Call.char 'U'
         nPtr <- Call.cint n
         ldPtr <- leadingDim n
         bPtr <- unpackToTemp (unpackFull order) n a
         cPtr <- Call.allocaArray (n*n)
         alphaPtr <- Call.number one
         betaPtr <- Call.number zero
         liftIO $ do
            BlasGen.hemm sidePtr uploPtr nPtr nPtr
               alphaPtr bPtr ldPtr bPtr ldPtr betaPtr cPtr ldPtr
            pack order n cPtr cpPtr

-- | @B * A@ for a square @B@ (first argument) and Hermitian @A@.
multiplySquareLeft ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Square sh a -> Hermitian sh a -> Square sh a
multiplySquareLeft
   (Array shapeB@(MatrixShape.Square orderB shB) b)
   (Array (MatrixShape.Hermitian orderA shA) a) =
      Array.unsafeCreate shapeB $ \cPtr -> do
         Call.assert "Hermitian.multiplySquareLeft: shapes mismatch" (shA == shB)
         let n = Shape.size shB
         multiplyAux True orderA n a (flipOrder orderB) n b cPtr

-- | @B * A@ for a general @B@ (first argument) and Hermitian @A@.
multiplyGeneralLeft ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Hermitian width a -> General height width a
multiplyGeneralLeft
   (Array shapeB@(MatrixShape.General orderB height width) b)
   (Array (MatrixShape.Hermitian orderA shA) a) =
      Array.unsafeCreate shapeB $ \cPtr -> do
         Call.assert "Hermitian.multiplyGeneralLeft: shapes mismatch"
            (shA == width)
         multiplyAux True
            orderA (Shape.size width) a
            (flipOrder orderB) (Shape.size height) b cPtr

-- | @A * B@ for a Hermitian @A@ (first argument) and square @B@.
multiplySquareRight ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Hermitian sh a -> Square sh a -> Square sh a
multiplySquareRight
   (Array (MatrixShape.Hermitian orderA shA) a)
   (Array shapeB@(MatrixShape.Square orderB shB) b) =
      Array.unsafeCreate shapeB $ \cPtr -> do
         Call.assert "Hermitian.multiplySquareRight: shapes mismatch" (shA == shB)
         let n = Shape.size shB
         multiplyAux False orderA n a orderB n b cPtr

-- | @A * B@ for a Hermitian @A@ (first argument) and general @B@.
multiplyGeneralRight ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Hermitian height a -> General height width a -> General height width a
multiplyGeneralRight
   (Array (MatrixShape.Hermitian orderA shA) a)
   (Array shapeB@(MatrixShape.General orderB height width) b) =
      Array.unsafeCreate shapeB $ \cPtr -> do
         Call.assert "Hermitian.multiplyGeneralRight: shapes mismatch"
            (shA == height)
         multiplyAux False
            orderA (Shape.size height) a
            orderB (Shape.size width) b cPtr

{- |
Shared worker for the four Hermitian-times-matrix products, via @?hemm@.

The packed Hermitian matrix (dimension @m0@) is unpacked into a temporary,
which is read column-major with uplo from 'uploFromOrder'.
@orderB@ decides whether @?hemm@ multiplies from the left or the right.
When the orders of the two operands disagree (xor @extraConjugate@),
the temporary is conjugated, because a transposed Hermitian matrix
equals its conjugate.
-}
multiplyAux ::
   Class.Floating a =>
   Bool -> Order -> Int -> ForeignPtr a ->
   Order -> Int -> ForeignPtr a -> Ptr a -> IO ()
multiplyAux extraConjugate orderA m0 a orderB n0 b cPtr = do
   let size = m0*m0
   evalContT $ do
      let (side,(m,n)) =
            case orderB of
               ColumnMajor -> ('L',(m0,n0))
               RowMajor -> ('R',(n0,m0))
      sidePtr <- Call.char side
      uploPtr <- Call.char $ uploFromOrder orderA
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      alphaPtr <- Call.number one
      aPtr <- unpackToTemp (unpack orderA) m0 a
      ldaPtr <- leadingDim m0
      incaPtr <- Call.cint 1
      sizePtr <- Call.cint size
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- leadingDim m
      betaPtr <- Call.number zero
      ldcPtr <- leadingDim m
      liftIO $ do
         when ((orderA/=orderB) /= extraConjugate) $
            lacgv sizePtr aPtr incaPtr
         BlasGen.hemm sidePtr uploPtr mPtr nPtr
            alphaPtr aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


-- | Outer product @x * x^H@ of a vector with itself.
outer :: (Shape.C sh, Class.Floating a) => Vector sh a -> Hermitian sh a
outer =
   getMap $
   Class.switchFloating
      (Map outerAux) (Map outerAux) (Map outerAux) (Map outerAux)

-- | Implementation of 'outer': one packed rank-1 update of a zero matrix
-- with real scale factor 1.
outerAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector sh a -> Hermitian sh a
outerAux (Array sh x) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian ColumnMajor sh) $
      \triSize aPtr -> do
         let n = Shape.size sh
         evalContT $ do
            uploPtr <- Call.char $ uploFromOrder ColumnMajor
            nPtr <- Call.cint n
            alphaPtr <- Call.real one
            xPtr <- ContT $ withForeignPtr x
            incxPtr <- Call.cint 1
            liftIO $ fill zero triSize aPtr
            liftIO $ hpr uploPtr nPtr alphaPtr xPtr incxPtr aPtr

-- | Sum of weighted outer products @sum alpha_i * x_i * x_i^H@ with real
-- weights. All vectors must have the shape of the first one.
sumRank1 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   NonEmpty.T [] (RealOf a, Vector sh a) -> Hermitian sh a
sumRank1 =
   getSumRank1 $
   Class.switchFloating
      (SumRank1 sumRank1Aux) (SumRank1 sumRank1Aux)
      (SumRank1 sumRank1Aux) (SumRank1 sumRank1Aux)

-- | Function type of 'sumRank1'.
type SumRank1_ sh a = NonEmpty.T [] (RealOf a, Vector sh a) -> Hermitian sh a

-- | Wrapper so 'Class.switchFloating' can select 'sumRank1Aux' per type.
newtype SumRank1 sh a = SumRank1 {getSumRank1 :: SumRank1_ sh a}

-- | Implementation of 'sumRank1': repeated packed rank-1 updates
-- accumulated into a zero-initialised column-major triangle.
sumRank1Aux ::
   (Shape.C sh, Eq sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   SumRank1_ sh a
sumRank1Aux xs@(NonEmpty.Cons (_, Array sh _) _) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian ColumnMajor sh) $
      \triSize aPtr -> do
         let n = Shape.size sh
         evalContT $ do
            uploPtr <- Call.char $ uploFromOrder ColumnMajor
            nPtr <- Call.cint n
            alphaPtr <- Call.alloca
            incxPtr <- Call.cint 1
            liftIO $ do
               fill zero triSize aPtr
               forM_ xs $ \(alpha, Array shX x) ->
                  withForeignPtr x $ \xPtr -> do
                     Call.assert
                        "Hermitian.sumRank1: non-matching vector size"
                        (sh==shX)
                     poke alphaPtr alpha
                     hpr uploPtr nPtr alphaPtr xPtr incxPtr aPtr

-- | Signature shared by BLAS @?spr@ (real) and @?hpr@ (complex):
-- a packed rank-1 update whose scale factor is always a real number.
type HPR_ a =
   Ptr CChar -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt -> Ptr a -> IO ()

-- | Wrapper so 'Class.switchFloating' can select the BLAS routine per type.
newtype HPR a = HPR {getHPR :: HPR_ a}

-- | Type-dispatched packed rank-1 update: @?spr@ for real and @?hpr@ for
-- complex element types.
hpr :: Class.Floating a => HPR_ a
hpr =
   getHPR $
   Class.switchFloating
      (HPR BlasReal.spr) (HPR BlasReal.spr)
      (HPR BlasComplex.hpr) (HPR BlasComplex.hpr)

-- | Sum of rank-2 updates. Following BLAS @?hpr2@, each element
-- @(alpha, (x, y))@ contributes @alpha * x * y^H + conj alpha * y * x^H@.
-- All vectors must have the shape of the first @x@.
sumRank2 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   NonEmpty.T [] (a, (Vector sh a, Vector sh a)) -> Hermitian sh a
sumRank2 xys@(NonEmpty.Cons (_, (Array sh _, _)) _) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian ColumnMajor sh) $
      \triSize aPtr -> do
         let n = Shape.size sh
         evalContT $ do
            uploPtr <- Call.char $ uploFromOrder ColumnMajor
            nPtr <- Call.cint n
            alphaPtr <- Call.alloca
            incPtr <- Call.cint 1
            liftIO $ do
               fill zero triSize aPtr
               forM_ xys $ \(alpha, (Array shX x, Array shY y)) ->
                  withForeignPtr x $ \xPtr ->
                  withForeignPtr y $ \yPtr -> do
                     Call.assert
                        "Hermitian.sumRank2: non-matching x vector size"
                        (sh==shX)
                     Call.assert
                        "Hermitian.sumRank2: non-matching y vector size"
                        (sh==shY)
                     poke alphaPtr alpha
                     BlasGen.hpr2 uploPtr nPtr
                        alphaPtr xPtr incPtr yPtr incPtr aPtr


{-
It is not strictly necessary to keep the 'order'.
Changing the order during the conversion would be
neither more complicated nor less efficient.
-}
-- | Expand to a full square matrix with the same 'Order'.
-- @_toSquare@ is an unused draft that the author marked as incorrect.
toSquare, _toSquare ::
   (Shape.C sh, Class.Floating a) => Hermitian sh a -> Square sh a
_toSquare (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate (MatrixShape.Square order sh) $ \bPtr ->
   evalContT $ do
      let n = Shape.size sh
      aPtr <- ContT $ withForeignPtr a
      conjPtr <- conjugateToTemp (MatrixShape.triangleSize n) a
      liftIO $ do
         unpack (flipOrder order) n conjPtr bPtr -- wrong
         unpack order n aPtr bPtr

toSquare (Array (MatrixShape.Hermitian order sh) a) =
   Array.unsafeCreate (MatrixShape.Square order sh) $ \bPtr ->
   withForeignPtr a $ \aPtr ->
      unpackFull order (Shape.size sh) aPtr bPtr

{- |
Make a temporary conjugated copy only for complex element types.
For real types, conjugation is the identity, so the original buffer is used
directly; callers must then not modify it.
-}
conjugateToTemp ::
   (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
conjugateToTemp n =
   runCopyToTemp $
   Class.switchFloating
      (CopyToTemp $ ContT . withForeignPtr)
      (CopyToTemp $ ContT . withForeignPtr)
      (CopyToTemp $ complexConjugateToTemp n)
      (CopyToTemp $ complexConjugateToTemp n)

-- | Wrapper so 'Class.switchFloating' can select the copy strategy per type.
newtype CopyToTemp r a =
   CopyToTemp {runCopyToTemp :: ForeignPtr a -> ContT r IO (Ptr a)}

-- | Copy @n@ complex elements to a temporary buffer and conjugate them
-- in place using LAPACK @?lacgv@.
complexConjugateToTemp ::
   Class.Real a =>
   Int -> ForeignPtr (Complex a) -> ContT r IO (Ptr (Complex a))
complexConjugateToTemp n x = do
   nPtr <- Call.cint n
   xPtr <- copyToTemp n x
   incxPtr <- Call.cint 1
   liftIO $ LapackComplex.lacgv nPtr xPtr incxPtr
   return xPtr


{- |
A^H * A
-}
covariance ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Hermitian width a
covariance =
   getMap $
   Class.switchFloating
      (Map covarianceAux) (Map covarianceAux)
      (Map covarianceAux) (Map covarianceAux)

-- | Generic wrapper around a function @f a -> g a@, used with
-- 'Class.switchFloating'.
newtype Map f g a = Map {getMap :: f a -> g a}

{- |
Implementation of 'covariance' via @?herk@ (@?syrk@ for real types).

Column-major input is read as-is (@trans = 'C'@, upper triangle).
Row-major input, read column-major, is @A^T@, so @trans = 'N'@ yields
@A^T * conj A = (A^H * A)^T@. That result is packed as the lower
column-major triangle, which is the upper row-major triangle.
The scale factors of @?herk@ are real, hence 'Call.real'.
-}
covarianceAux ::
   (Shape.C height, Shape.C width,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   General height width a -> Hermitian width a
covarianceAux (Array (MatrixShape.General order height width) a) =
   Array.unsafeCreate (MatrixShape.Hermitian order width) $ \bPtr -> do
      let n = Shape.size width
      let k = Shape.size height
      evalContT $ do
         nPtr <- Call.cint n
         kPtr <- Call.cint k
         alphaPtr <- Call.real one
         aPtr <- ContT $ withForeignPtr a
         betaPtr <- Call.real zero
         cPtr <- Call.allocaArray (n*n)
         ldcPtr <- leadingDim n
         case order of
            ColumnMajor -> do
               uploPtr <- Call.char 'U'
               transPtr <- Call.char 'C'
               ldaPtr <- leadingDim k
               liftIO $ do
                  herk uploPtr transPtr
                     nPtr kPtr alphaPtr aPtr ldaPtr betaPtr cPtr ldcPtr
                  pack ColumnMajor n cPtr bPtr
            RowMajor -> do
               uploPtr <- Call.char 'L'
               transPtr <- Call.char 'N'
               ldaPtr <- leadingDim n
               liftIO $ do
                  herk uploPtr transPtr
                     nPtr kPtr alphaPtr aPtr ldaPtr betaPtr cPtr ldcPtr
                  pack RowMajor n cPtr bPtr

-- | Signature shared by BLAS @?syrk@ (real) and @?herk@ (complex):
-- the alpha and beta scale factors are real in both cases.
type HERK_ a =
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt ->
   Ptr (RealOf a) -> Ptr a -> Ptr CInt -> IO ()

-- | Wrapper so 'Class.switchFloating' can select the BLAS routine per type.
newtype HERK a = HERK {getHERK :: HERK_ a}

-- | Type-dispatched rank-k update: @?syrk@ for real and @?herk@ for
-- complex element types.
herk :: Class.Floating a => HERK_ a
herk =
   getHERK $
   Class.switchFloating
      (HERK BlasReal.syrk) (HERK BlasReal.syrk)
      (HERK BlasComplex.herk) (HERK BlasComplex.herk)


{- |
A^H + A
-}
-- @_addTransposed@ is an unused draft that the author marked as incorrect.
addTransposed, _addTransposed ::
   (Shape.C sh, Class.Floating a) => Square sh a -> Hermitian sh a
_addTransposed (Array (MatrixShape.Square order sh) a) =
   Array.unsafeCreateWithSize (MatrixShape.Hermitian order sh) $
      \bSize bPtr -> do
         let n = Shape.size sh
         evalContT $ do
            alphaPtr <- Call.number one
            incxPtr <- Call.cint 1
            aPtr <- ContT $ withForeignPtr a
            sizePtr <- Call.cint bSize
            conjPtr <- Call.allocaArray bSize
            liftIO $ do
               pack order n aPtr bPtr
               pack (flipOrder order) n aPtr conjPtr -- wrong
               lacgv sizePtr conjPtr incxPtr
               BlasGen.axpy sizePtr alphaPtr conjPtr incxPtr bPtr incxPtr

-- Each stored row (row-major) or column (column-major) of the triangle is
-- built as: copy the mirrored strip of A (stride n), conjugate it,
-- then add the strip of A itself (stride 1).
addTransposed (Array (MatrixShape.Square order sh) a) =
   Array.unsafeCreate (MatrixShape.Hermitian order sh) $ \bPtr -> do
      let n = Shape.size sh
      evalContT $ do
         alphaPtr <- Call.number one
         incxPtr <- Call.cint 1
         incnPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         liftIO $
            case order of
               RowMajor ->
                  forPointers (rowMajorPointers n aPtr bPtr) $
                     \nPtr (srcPtr,dstPtr) -> do
                        BlasGen.copy nPtr srcPtr incnPtr dstPtr incxPtr
                        lacgv nPtr dstPtr incxPtr
                        BlasGen.axpy nPtr alphaPtr srcPtr incxPtr dstPtr incxPtr
               ColumnMajor ->
                  forPointers (columnMajorPointers n aPtr bPtr) $
                     \nPtr ((srcRowPtr,srcColumnPtr),dstPtr) -> do
                        BlasGen.copy nPtr srcRowPtr incnPtr dstPtr incxPtr
                        lacgv nPtr dstPtr incxPtr
                        BlasGen.axpy nPtr alphaPtr
                           srcColumnPtr incxPtr dstPtr incxPtr


{- |
Expand a packed Hermitian matrix into a full @n*n@ array with the given
'Order'. Each packed strip is written twice:
first conjugated into the mirrored position (stride n),
then unchanged into its own position (stride 1).
The second write leaves the diagonal with the unconjugated value.
-}
unpackFull :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpackFull order n packedPtr fullPtr = evalContT $ do
   incxPtr <- Call.cint 1
   incyPtr <- Call.cint n
   liftIO $
      case order of
         RowMajor ->
            forPointers (rowMajorPointers n fullPtr packedPtr) $
               \nPtr (dstPtr,srcPtr) -> do
                  BlasGen.copy nPtr srcPtr incxPtr dstPtr incyPtr
                  lacgv nPtr dstPtr incyPtr
                  BlasGen.copy nPtr srcPtr incxPtr dstPtr incxPtr
         ColumnMajor ->
            forPointers (columnMajorPointers n fullPtr packedPtr) $
               \nPtr ((dstRowPtr,dstColumnPtr),srcPtr) -> do
                  BlasGen.copy nPtr srcPtr incxPtr dstRowPtr incyPtr
                  lacgv nPtr dstRowPtr incyPtr
                  BlasGen.copy nPtr srcPtr incxPtr dstColumnPtr incxPtr

-- | Unused alternative packing routine: copies the triangle strips of a
-- full matrix into packed storage.
_pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_pack order n fullPtr packedPtr = evalContT $ do
   incxPtr <- Call.cint 1
   liftIO $
      case order of
         ColumnMajor ->
            forPointers (columnMajorPointers n fullPtr packedPtr) $
               \nPtr ((_,srcPtr),dstPtr) ->
                  BlasGen.copy nPtr srcPtr incxPtr dstPtr incxPtr
         RowMajor ->
            forPointers (rowMajorPointers n fullPtr packedPtr) $
               \nPtr (srcPtr,dstPtr) ->
                  BlasGen.copy nPtr srcPtr incxPtr dstPtr incxPtr

-- | Allocate a leading-dimension argument for a BLAS level-3 routine.
-- BLAS rejects @LDA < max(1, rows)@ even when a dimension is zero,
-- so the value is clamped to at least 1 to support empty matrices.
leadingDim :: Int -> ContT r IO (Ptr CInt)
leadingDim = Call.cint . max 1