{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
-- | Basic Hermitian matrix operations:
-- construction from a diagonal, extraction of the diagonal,
-- and sums of rank-1 and rank-2 Hermitian updates.
module Numeric.LAPACK.Matrix.Hermitian.Basic (
   Hermitian,
   HermitianP,
   Transposition(..),
   diagonal,
   takeDiagonal,
   sumRank1,
   sumRank2,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Hermitian.Private (Diagonal(..), TakeDiagonal(..))
import Numeric.LAPACK.Matrix.Symmetric.Unified (complement)
import Numeric.LAPACK.Matrix.Mosaic.Private
         (forPointers, diagonalPointerPairs,
          rowMajorPointers, columnMajorPointers,
          withPacking, noLabel, applyFuncPair, triArg)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor,ColumnMajor), uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed),
          Conjugation(Conjugated), conjugatedOnRowMajor)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero)
import Numeric.LAPACK.Private (fill, realPtr, condConjugate)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import Control.Monad.Trans.Cont (ContT, evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | Hermitian matrix stored according to a 'Layout.Hermitian' shape.
type Hermitian sh = Array (Layout.Hermitian sh)

-- | Hermitian matrix whose packing is selected by the type parameter @pack@.
type HermitianP pack sh = Array (Layout.HermitianP pack sh)


-- | Hermitian matrix that is zero everywhere except on the diagonal.
-- The diagonal entries are taken from the given vector.
--
-- The vector has element type @'RealOf' a@ because the diagonal
-- of a Hermitian matrix is real.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order =
   -- Dispatch over the four floating-point element types so that
   -- 'diagonalAux' is used at a concrete @RealOf a@.
   runDiagonal
      (Class.switchFloating
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order))
         (Diagonal (diagonalAux order)))

-- | Worker for 'diagonal'.
-- It zero-fills the whole storage and then writes each vector element
-- into the real part of a diagonal entry. Because of the zero fill,
-- the imaginary parts of the diagonal stay zero.
diagonalAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Order -> Vector sh ar -> Hermitian sh a
diagonalAux order (Array sh x) =
   Array.unsafeCreateWithSize (Layout.hermitian order sh) $ \triSize aPtr -> do
      fill zero triSize aPtr
      withForeignPtr x $ \xPtr ->
         -- presumably pairs (vector element, diagonal entry) pointers;
         -- see 'diagonalPointerPairs'
         forM_ (diagonalPointerPairs order (Shape.size sh) xPtr aPtr) $
            \(src, dst) -> peek src >>= poke (realPtr dst)

-- | Extract the diagonal of a Hermitian matrix as a real-valued vector.
-- Only the real part of each diagonal entry is read.
takeDiagonal ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh (RealOf a)
takeDiagonal =
   runTakeDiagonal
      (Class.switchFloating
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux)
         (TakeDiagonal takeDiagonalAux))

-- | Worker for 'takeDiagonal'.
-- It copies the real part of every diagonal entry into a new vector.
-- Only the order and the shape of the mosaic layout are needed.
takeDiagonalAux ::
   (Shape.C sh, Storable a, RealOf a ~ ar, Storable ar) =>
   Hermitian sh a -> Vector sh ar
takeDiagonalAux (Array (Layout.Mosaic _pack _mirror _upper order sh) a) =
   Array.unsafeCreateWithSize sh $ \n xPtr ->
      withForeignPtr a $ \aPtr ->
         forM_ (diagonalPointerPairs order n xPtr aPtr) $
            \(dst, src) -> peek (realPtr src) >>= poke dst


-- | Shared driver for 'sumRank1' and 'sumRank2'.
--
-- Steps, in order:
--
-- 1. Allocate the BLAS scalar arguments:
--    the @uplo@ character derived from the order, @n@, an increment of 1,
--    and the storage size.
-- 2. Zero-fill the @triSize@ elements at @aPtr@.
-- 3. Run the accumulating action @act@, passing it
--    the uplo pointer, @n@, a pointer to @n@, and the increment pointer.
-- 4. Conjugate the whole buffer when @'conjugatedOnRowMajor' order@ says so.
--    Presumably this is because BLAS works column-major, and a
--    row-major Hermitian triangle is the conjugate of the column-major one.
-- 5. Call 'complement', presumably to fill the mirrored triangle
--    with conjugated values.
withConjBuffer ::
   (Shape.C sh, Class.Floating a) =>
   Layout.PackingSingleton pack -> Order -> sh -> Int -> Ptr a ->
   (Ptr CChar -> Int -> Ptr CInt -> Ptr CInt -> IO ()) ->
   ContT r IO ()
withConjBuffer pack order sh triSize aPtr act = do
   let n = Shape.size sh
   uploPtr <- Call.char (uploFromOrder order)
   nPtr <- Call.cint n
   incPtr <- Call.cint 1
   sizePtr <- Call.cint triSize
   liftIO $ do
      fill zero triSize aPtr
      act uploPtr n nPtr incPtr
      condConjugate (conjugatedOnRowMajor order) sizePtr aPtr incPtr
      complement pack Conjugated order n aPtr

{- Not easy to generalize to Symmetric
because LapackComplex.spr and LapackComplex.syr
expect complex parameter 'alpha'.
-}
-- | Sum of the Hermitian rank-1 matrices @alpha * x * x^H@
-- over all pairs @(alpha, x)@ in the list.
--
-- The scalar @alpha@ has type @'RealOf' a@, matching the real @alpha@
-- expected by BLAS @her@/@hpr@.
-- Every vector must have shape @sh@; otherwise an assertion fails.
sumRank1 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(RealOf a, Vector sh a)] -> HermitianP pack sh a
sumRank1 =
   getSumRank1
      (Class.switchFloating
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux)
         (SumRank1 sumRank1Aux))

-- | Type of 'sumRank1' with the real scalar type @ar@ made explicit.
type SumRank1_ pack sh ar a =
   Order -> sh -> [(ar, Vector sh a)] -> HermitianP pack sh a

-- | Wrapper that lets 'Class.switchFloating' select 'sumRank1Aux'
-- per element type.
newtype SumRank1 pack sh a =
   SumRank1 {getSumRank1 :: SumRank1_ pack sh (RealOf a) a}

-- | Worker for 'sumRank1'.
-- It accumulates one BLAS rank-1 update per list element into a
-- zero-initialised buffer prepared by 'withConjBuffer'.
sumRank1Aux ::
   (Layout.Packing pack, Shape.C sh, Eq sh,
    Class.Floating a, RealOf a ~ ar, Storable ar) =>
   SumRank1_ pack sh ar a
sumRank1Aux order sh xs =
   let pack = Layout.autoPacking in
   Array.unsafeCreateWithSize (Layout.hermitianP pack order sh) $
      \triSize aPtr -> evalContT $ do
         alphaPtr <- Call.alloca
         withConjBuffer pack order sh triSize aPtr $
            \uploPtr n nPtr incPtr ->
               forM_ xs $ \(alpha, Array shX x) ->
                  withForeignPtr x $ \xPtr -> do
                     Call.assert
                        "Hermitian.sumRank1: non-matching vector size"
                        (sh==shX)
                     poke alphaPtr alpha
                     -- 'applyFuncPair' receives both the packed (spr/hpr)
                     -- and the full-storage (syr/her) routine; presumably
                     -- 'withPacking' picks the one matching @pack@.
                     evalContT $ withPacking pack $
                        case Scalar.complexSingletonOfFunctor aPtr of
                           Scalar.Real ->
                              applyFuncPair
                                 (noLabel BlasReal.spr)
                                 (noLabel BlasReal.syr)
                                 uploPtr nPtr alphaPtr xPtr incPtr
                                 (triArg aPtr n)
                           Scalar.Complex ->
                              applyFuncPair
                                 (noLabel BlasComplex.hpr)
                                 (noLabel BlasComplex.her)
                                 uploPtr nPtr alphaPtr xPtr incPtr
                                 (triArg aPtr n)

{- Not easy to generalize to Symmetric
because there are no Complex.spr2 and Complex.syr2.
However, there is BlasComplex.syr2k.
-}
-- | Sum of the Hermitian rank-2 matrices
-- @alpha * x * y^H + conj(alpha) * y * x^H@
-- over all triples @(alpha, (x, y))@ in the list, as accumulated by
-- BLAS @her2@/@hpr2@.
--
-- Both @x@ and @y@ must have shape @sh@; otherwise an assertion fails.
-- Zero-initialisation and the row-major conjugation are done by
-- 'withConjBuffer'.
sumRank2 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, (Vector sh a, Vector sh a))] -> HermitianP pack sh a
sumRank2 order sh xys =
   let pack = Layout.autoPacking in
   Array.unsafeCreateWithSize (Layout.hermitianP pack order sh) $
      \triSize aPtr -> evalContT $ do
         alphaPtr <- Call.alloca
         withConjBuffer pack order sh triSize aPtr $
            \uploPtr n nPtr stridePtr ->
               forM_ xys $ \(alpha, (Array shX x, Array shY y)) ->
                  withForeignPtr x $ \xPtr ->
                     withForeignPtr y $ \yPtr -> do
                        Call.assert
                           "Hermitian.sumRank2: non-matching x vector size"
                           (sh==shX)
                        Call.assert
                           "Hermitian.sumRank2: non-matching y vector size"
                           (sh==shY)
                        poke alphaPtr alpha
                        -- packed (hpr2) vs. full-storage (her2) routine;
                        -- presumably chosen by 'withPacking' from @pack@
                        evalContT $ withPacking pack $
                           applyFuncPair
                              (noLabel BlasGen.hpr2)
                              (noLabel BlasGen.her2)
                              uploPtr nPtr alphaPtr
                              xPtr stridePtr yPtr stridePtr
                              (triArg aPtr n)

-- | Copy one triangle of a full @n x n@ matrix at @fullPtr@ into the
-- storage at @packedPtr@, using one BLAS @copy@ per segment yielded by
-- 'columnMajorPointers' / 'rowMajorPointers'.
-- NOTE(review): the exact segment layout is defined by those helpers;
-- confirm it there.
--
-- The leading underscore keeps GHC from warning that this binding is
-- unused.
_pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
_pack order n fullPtr packedPtr =
   evalContT $ do
      unitIncPtr <- Call.cint 1
      case order of
         ColumnMajor ->
            liftIO $
            forPointers (columnMajorPointers n fullPtr packedPtr) $
               \lenPtr ((_, src), dst) ->
                  BlasGen.copy lenPtr src unitIncPtr dst unitIncPtr
         RowMajor ->
            liftIO $
            forPointers (rowMajorPointers n fullPtr packedPtr) $
               \lenPtr (src, dst) ->
                  BlasGen.copy lenPtr src unitIncPtr dst unitIncPtr