{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{- |
Basic operations on Hermitian band matrices that are stored as
@(offDiag+1) * size@ element arrays in either row-major or column-major order.

Most functions here work directly on the raw storage via BLAS calls
and pointer arithmetic.
-}
module Numeric.LAPACK.Matrix.BandedHermitian.Basic (
   BandedHermitian,
   StaticVector,
   Transposition(..),
   fromList,
   identity,
   identityFatOrder,
   diagonal,
   takeDiagonal,
   toHermitian,
   toBanded,
   forceOrder,
   takeTopLeft,
   takeBottomRight,
   multiplyVector,
   multiplyFull,
   gramian,
   sumRank1,

   takeUpper,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Mosaic.Private as Mos
import qualified Numeric.LAPACK.Matrix.RowMajor as RowMajor
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Vector.Private as VectorPriv
import qualified Numeric.LAPACK.Vector as Vector
import qualified Data.Array.Comfort.Shape.Static as ShapeStatic
import Numeric.LAPACK.Matrix.Hermitian.Private (TakeDiagonal(..))
import Numeric.LAPACK.Matrix.Hermitian.Basic (Hermitian)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor,ColumnMajor), uploFromOrder, UnaryProxy, natFromProxy)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed), transposeOrder,
          Conjugation(NonConjugated), conjugatedOnRowMajor)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, one)
import Numeric.LAPACK.Private
         (fill, lacgv, caseRealComplexFunc, realPtr,
          copyBlock, copyConjugate, condConjugate, condConjugateToTemp,
          pointerSeq, pokeCInt, copySubMatrix)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary.Literal as TypeNum
import qualified Type.Data.Num.Unary.Proof as Proof
import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num.Unary ((:+:))
import Type.Data.Num (integralFromProxy)
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((::+)((::+)))

import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek, peekElemOff)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (for_)
import Data.Tuple.HT (mapPair)
import Data.Complex (Complex, conjugate)


-- | Storable array with a 'Layout.BandedHermitian' shape:
-- @offDiag@ is the type-level number of off-diagonals,
-- @size@ is the shape of the (square) matrix dimension.
type BandedHermitian offDiag size =
        Array (Layout.BandedHermitian offDiag size)

-- | Banded Hermitian matrix with zero off-diagonals.
type Diagonal size = BandedHermitian TypeNum.U0 size


-- | Build a banded Hermitian matrix from a flat list of elements,
-- using the checked 'CheckedArray.fromList'.
-- The element order is the raw storage order of the given layout.
fromList ::
   (Unary.Natural offDiag, Shape.C size, Storable a) =>
   UnaryProxy offDiag -> Order -> size -> [a] ->
   BandedHermitian offDiag size a
fromList numOff order size =
   CheckedArray.fromList (Layout.BandedHermitian numOff order size)

-- | Identity matrix with an arbitrary number of (zero) off-diagonals.
--
-- It is built as a tensor product of an all-ones vector over @sh@
-- with a unit vector over the band.
-- The position of the one within the band depends on the order:
-- first band element for 'RowMajor', last band element for 'ColumnMajor'.
identityFatOrder ::
   (Unary.Natural offDiag, Shape.C sh, Class.Floating a) =>
   Order -> sh -> BandedHermitian offDiag sh a
identityFatOrder order sh =
   case order of
      RowMajor ->
         fromRowMajor $
         RowMajor.tensorProduct (Left NonConjugated)
            (Vector.one sh) (Vector.unit Shape.static $ Left ())
      ColumnMajor ->
         fromColumnMajor $
         RowMajor.tensorProduct (Left NonConjugated)
            (Vector.one sh) (Vector.unit Shape.static $ Right ())

-- | Relabel a row-major matrix with @1 + offDiag@ columns
-- (diagonal element first) as row-major banded Hermitian storage.
-- Only the shape is mapped.
fromRowMajor ::
   (Unary.Natural offDiag, Shape.C size) =>
   RowMajor.Matrix size (() ::+ ShapeStatic.ZeroBased offDiag) a ->
   BandedHermitian offDiag size a
fromRowMajor =
   Array.mapShape
      (\(size, () ::+ ShapeStatic.ZeroBased k) ->
         Layout.BandedHermitian k RowMajor size)

-- | Relabel a row-major matrix with @offDiag + 1@ columns
-- (diagonal element last) as column-major banded Hermitian storage.
-- Only the shape is mapped.
fromColumnMajor ::
   (Unary.Natural offDiag, Shape.C size) =>
   RowMajor.Matrix size (ShapeStatic.ZeroBased offDiag ::+ ()) a ->
   BandedHermitian offDiag size a
fromColumnMajor =
   Array.mapShape
      (\(size, ShapeStatic.ZeroBased k ::+ ()) ->
         Layout.BandedHermitian k ColumnMajor size)

-- | Identity matrix without off-diagonals,
-- i.e. an all-ones vector relabelled as column-major diagonal matrix.
identity ::
   (Shape.C sh, Class.Floating a) =>
   sh -> Diagonal sh a
identity =
   Array.mapShape (Layout.BandedHermitian Proxy ColumnMajor) . Vector.one

-- | Diagonal matrix from a real vector;
-- the vector is converted with 'Vector.fromReal' and relabelled.
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh (RealOf a) -> Diagonal sh a
diagonal =
   Array.mapShape (Layout.BandedHermitian Proxy ColumnMajor) . Vector.fromReal


-- | Extract the (real) main diagonal.
-- Dispatches on the element type and delegates to 'takeDiagonalAux'.
takeDiagonal ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a -> Vector size (RealOf a)
takeDiagonal =
   runTakeDiagonal $
   Class.switchFloating
      (TakeDiagonal takeDiagonalAux) (TakeDiagonal takeDiagonalAux)
      (TakeDiagonal takeDiagonalAux) (TakeDiagonal takeDiagonalAux)

-- | Worker of 'takeDiagonal': a single strided BLAS @copy@.
takeDiagonalAux ::
   (Unary.Natural offDiag, Shape.C size,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   BandedHermitian offDiag size a -> Vector size ar
takeDiagonalAux (Array (Layout.BandedHermitian numOff order size) a) =
   let k = integralFromProxy numOff
   in Array.unsafeCreateWithSize size $ \n yPtr -> evalContT $ do
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      -- The diagonal element is the first element of every band slice
      -- in row-major order and the last one (offset k) in column-major order.
      let xPtr =
            realPtr $ advancePtr aPtr $
            case order of
               RowMajor -> 0
               ColumnMajor -> k
      -- The source is read through a real-typed pointer,
      -- hence the stride (k+1) is scaled by 'caseRealComplexFunc'
      -- (1 or 2; presumably 1 for real and 2 for complex elements).
      incxPtr <- Call.cint (caseRealComplexFunc aPtr 1 2 * (k+1))
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr


-- | Convert to a (packed) Hermitian matrix of the same order
-- using 'Mos.fromBanded'.
toHermitian ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a -> Hermitian size a
toHermitian (Array (Layout.BandedHermitian numOff order size) a) =
   Array.unsafeCreateWithSize (Layout.hermitian order size) $
   Mos.fromBanded (integralFromProxy numOff) order (Shape.size size) a

-- | Convert to a general banded square matrix
-- with @offDiag@ sub- and super-diagonals and the same order.
--
-- The stored band is copied with 'copySubMatrix',
-- the other triangle is filled with conjugated elements
-- using 'columnToRowMajor' or 'rowToColumnMajor' with 'copyConjugate',
-- and @k@ remaining elements at the end (column-major)
-- or at the beginning (row-major) of the target are set to zero.
--
-- For an empty matrix nothing is written.
toBanded ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a ->
   Banded.Square offDiag offDiag size a
toBanded (Array (Layout.BandedHermitian numOff order sh) a) =
   Array.unsafeCreate
      (Layout.Banded (numOff,numOff) order (Extent.square sh)) $ \bPtr ->
   withForeignPtr a $ \aPtr ->
      let n = Shape.size sh
          k = integralFromProxy numOff
          lda = k+1
          ldb = 2*k+1
      in -- Guard against the empty matrix:
         -- The padding 'fill' calls below write k elements
         -- at @bPtr + (ldb*n-k)@ or at @bPtr@, respectively.
         -- With n == 0 that is @bPtr - k@ or the start of a target
         -- that is created for a shape of size zero,
         -- i.e. an out-of-bounds write whenever k > 0.
         if n <= 0
            then return ()
            else case order of
               ColumnMajor -> do
                  copySubMatrix lda n lda aPtr ldb bPtr
                  columnToRowMajor copyConjugate
                     (n-1) k lda (advancePtr aPtr (k+1))
                     ldb (advancePtr bPtr (k+1))
                  -- lower part of the last column
                  fill zero k (advancePtr bPtr (ldb*n-k))
               RowMajor -> do
                  copySubMatrix lda n lda aPtr ldb (advancePtr bPtr k)
                  -- lower part of the first row
                  fill zero k bPtr
                  rowToColumnMajor copyConjugate
                     (n-1) k lda (advancePtr aPtr 1)
                     ldb (advancePtr bPtr ldb)

-- | Return the matrix in the requested storage order.
-- The input is returned unchanged if it already has that order,
-- otherwise the storage is converted by 'flipOrder'.
forceOrder ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   Order ->
   BandedHermitian offDiag size a ->
   BandedHermitian offDiag size a
forceOrder newOrder a =
   if newOrder == Layout.bandedHermitianOrder (Array.shape a)
      then a
      else flipOrder a

-- | Convert the storage to the opposite order
-- using plain (non-conjugating) BLAS @copy@.
flipOrder ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a ->
   BandedHermitian offDiag size a
flipOrder (Array (Layout.BandedHermitian numOff order sh) a) =
   Array.unsafeCreate
      (Layout.BandedHermitian numOff (Layout.flipOrder order) sh) $ \bPtr ->
   withForeignPtr a $ \aPtr ->
      let n = Shape.size sh
          -- Note: here k is the band width including the diagonal.
          k = integralFromProxy numOff + 1
      in case order of
            ColumnMajor -> columnToRowMajor BlasGen.copy n k k aPtr k bPtr
            RowMajor -> rowToColumnMajor BlasGen.copy n k k aPtr k bPtr

-- | Band rearrangement helper used by 'toBanded' and 'flipOrder'.
--
-- For every @i@ in @[0..n-1]@ it writes @k@ consecutive elements
-- at @bPtr + i*ldb@:
-- first @min k (n-i)@ elements gathered from @aPtr + i*lda + k-1@
-- with stride @lda-1@ using the given @copy@ routine
-- (BLAS @copy@ calling convention), then zeros for the remainder.
columnToRowMajor ::
   (Class.Floating a) =>
   (Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()) ->
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
columnToRowMajor copy n k lda aPtr ldb bPtr = evalContT $ do
   incxPtr <- Call.cint (lda-1)
   incyPtr <- Call.cint 1
   -- stride 0 over a single zero value serves as zero fill
   inczPtr <- Call.cint 0
   zPtr <- Call.number zero
   nPtr <- Call.alloca
   liftIO $ for_ (take n [0..]) $ \i -> do
      let split = min k $ n-i
      let xPtr = advancePtr aPtr (i*lda + k-1)
      let yPtr = advancePtr bPtr (i*ldb)
      pokeCInt nPtr split
      copy nPtr xPtr incxPtr yPtr incyPtr
      pokeCInt nPtr (k - split)
      BlasGen.copy nPtr zPtr inczPtr (advancePtr yPtr split) incyPtr

-- | Band rearrangement helper used by 'toBanded' and 'flipOrder',
-- counterpart of 'columnToRowMajor'.
--
-- For every @i@ in @[0..n-1]@ it writes @k@ consecutive elements
-- at @bPtr + i*ldb@:
-- first @max 0 (k-i-1)@ zeros,
-- then the remaining elements gathered with stride @lda-1@
-- using the given @copy@ routine (BLAS @copy@ calling convention).
rowToColumnMajor ::
   (Class.Floating a) =>
   (Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()) ->
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
rowToColumnMajor copy n k lda aPtr ldb bPtr = evalContT $ do
   incxPtr <- Call.cint (lda-1)
   incyPtr <- Call.cint 1
   -- stride 0 over a single zero value serves as zero fill
   inczPtr <- Call.cint 0
   zPtr <- Call.number zero
   nPtr <- Call.alloca
   liftIO $ for_ (take n [0..]) $ \i -> do
      let split = max 0 (k-i-1)
      let xPtr = advancePtr aPtr (i*lda + (split+1-k)*(lda-1))
      let yPtr = advancePtr bPtr (i*ldb)
      pokeCInt nPtr split
      BlasGen.copy nPtr zPtr inczPtr yPtr incyPtr
      pokeCInt nPtr (k - split)
      copy nPtr xPtr incxPtr (advancePtr yPtr split) incyPtr


-- | Take the leading diagonal block belonging to @sh0@.
-- This copies the leading part of the storage;
-- the number of copied elements is the size of the result layout.
takeTopLeft ::
   (Unary.Natural offDiag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   BandedHermitian offDiag (sh0 ::+ sh1) a ->
   BandedHermitian offDiag sh0 a
takeTopLeft (Array (Layout.BandedHermitian numOff order (sh0 ::+ _sh1)) a) =
   Array.unsafeCreateWithSize (Layout.BandedHermitian numOff order sh0) $
      \n bPtr -> withForeignPtr a $ \aPtr -> copyBlock n aPtr bPtr

-- | Take the trailing diagonal block belonging to @sh1@.
-- This copies the trailing part of the storage,
-- skipping @(offDiag+1) * size sh0@ elements of the source.
takeBottomRight ::
   (Unary.Natural offDiag, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   BandedHermitian offDiag (sh0 ::+ sh1) a ->
   BandedHermitian offDiag sh1 a
takeBottomRight (Array (Layout.BandedHermitian numOff order (sh0 ::+ sh1)) a) =
   Array.unsafeCreateWithSize (Layout.BandedHermitian numOff order sh1) $
      \n bPtr -> withForeignPtr a $ \aPtr ->
   copyBlock n
      (advancePtr aPtr $ (integralFromProxy numOff + 1) * Shape.size sh0) bPtr


-- | Multiply a banded Hermitian matrix (optionally transposed)
-- with a vector using BLAS @hbmv@ with @alpha = one@ and @beta = zero@.
--
-- Asserts that matrix and vector shapes are equal.
-- Depending on the combination of transposition and storage order,
-- the input is conjugated into a temporary before the call
-- and the result is conjugated afterwards.
multiplyVector ::
   (Unary.Natural offDiag, Shape.C size, Eq size, Class.Floating a) =>
   Transposition -> BandedHermitian offDiag size a ->
   Vector size a -> Vector size a
multiplyVector transposed
   (Array (Layout.BandedHermitian numOff order size) a) (Array sizeX x) =
      Array.unsafeCreateWithSize size $ \n yPtr -> do
   Call.assert "BandedHermitian.multiplyVector: shapes mismatch"
      (size == sizeX)
   let k = integralFromProxy numOff
   evalContT $ do
      let conj = conjugatedOnRowMajor $ transposeOrder transposed order
      uploPtr <- Call.char $ uploFromOrder order
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim $ k+1
      xPtr <- condConjugateToTemp conj n x
      incxPtr <- Call.cint 1
      betaPtr <- Call.number zero
      incyPtr <- Call.cint 1
      liftIO $ do
         BlasGen.hbmv uploPtr nPtr kPtr
            alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr
         condConjugate conj nPtr yPtr incyPtr


-- | Gramian of a banded square matrix:
-- the upper part of @adjoint a * a@ as banded Hermitian matrix
-- with @sub + super@ off-diagonals.
gramian ::
   (Shape.C size, Eq size, Class.Floating a,
    Unary.Natural sub, Unary.Natural super) =>
   Banded.Square sub super size a ->
   BandedHermitian (sub :+: super) size a
gramian a =
   case mapPair (natFromProxy,natFromProxy) $
        Layout.bandedOffDiagonals $ Array.shape a of
      (sub,super) ->
         -- The proofs provide @Natural (sub :+: super)@
         -- and the commutativity of the sum,
         -- which the type checker needs for the product below.
         case (Proof.addNat sub super, Proof.addComm sub super) of
            (Proof.Nat, Proof.AddComm) ->
               fromUpperPart $ Banded.multiply (Banded.adjoint a) a

-- | Keep the diagonal and the super-diagonals
-- of a symmetrically banded square matrix.
-- Copies @ku+1@ out of @kl+1+ku@ band elements per slice;
-- in row-major order the first @kl@ elements of the source are skipped.
fromUpperPart ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   Banded.Square offDiag offDiag size a -> BandedHermitian offDiag size a
fromUpperPart (Array (Layout.Banded (sub,super) order extent) a) =
   let sh = Extent.squareSize extent
       n = Shape.size sh
       kl = integralFromProxy sub
       ku = integralFromProxy super
       lda = kl+1+ku
       ldb = ku+1
   in Array.unsafeCreate (Layout.BandedHermitian super order sh) $ \bPtr ->
      withForeignPtr a $ \aPtr ->
      case order of
         ColumnMajor -> copySubMatrix ldb n lda aPtr ldb bPtr
         RowMajor -> copySubMatrix ldb n lda (advancePtr aPtr kl) ldb bPtr


-- | Multiply a banded Hermitian matrix (optionally transposed)
-- with a full matrix.
-- Column-major right operands are handled by 'multiplyFullSpecial',
-- row-major ones by 'multiplyFullGeneric'.
multiplyFull ::
   (Unary.Natural offDiag, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Transposition -> BandedHermitian offDiag height a ->
   Matrix.Full meas vert horiz height width a ->
   Matrix.Full meas vert horiz height width a
multiplyFull transposed a b =
   case Layout.fullOrder $ Array.shape b of
      ColumnMajor -> multiplyFullSpecial transposed a b
      RowMajor -> multiplyFullGeneric transposed a b

-- | @hbmv@ based multiplication.
-- Asserts that the matrix size matches the height of the right operand
-- and dispatches on the order of the right operand.
-- Only the column-major variant is implemented.
multiplyFullSpecial ::
   (Unary.Natural offDiag, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq height, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Transposition -> BandedHermitian offDiag height a ->
   Matrix.Full meas vert horiz height width a ->
   Matrix.Full meas vert horiz height width a
multiplyFullSpecial transposed
   (Array (Layout.BandedHermitian numOff orderA sizeA) a)
   (Array (Layout.Full orderB extentB) b) =
      Array.unsafeCreate (Layout.Full orderB extentB) $ \cPtr -> do
   Call.assert "BandedHermitian.multiplyFull: shapes mismatch"
      (sizeA == Extent.height extentB)
   let (height,width) = Extent.dimensions extentB
   case orderB of
      ColumnMajor ->
         multiplyFullColumnMajor
            transposed numOff (height,width) orderA a b cPtr
      RowMajor ->
         multiplyFullRowMajor
            transposed numOff (height,width) orderA a b cPtr

-- | Multiply with a column-major full matrix
-- by one @hbmv@ call per column of the right operand.
-- If the effective order is row-major,
-- every column is conjugated into a temporary vector before
-- and the result column is conjugated ('lacgv') after the call.
multiplyFullColumnMajor ::
   (Unary.Natural offDiag, Shape.C height, Shape.C width, Class.Floating a) =>
   Transposition ->
   UnaryProxy offDiag -> (height, width) ->
   Order -> ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyFullColumnMajor transposed numOff (height,width) order a b cPtr = do
   let n = Shape.size height
   let nrhs = Shape.size width
   let k = integralFromProxy numOff
   evalContT $ do
      uploPtr <- Call.char $ uploFromOrder order
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim $ k+1
      bPtr <- ContT $ withForeignPtr b
      incxPtr <- Call.cint 1
      betaPtr <- Call.number zero
      incyPtr <- Call.cint 1
      -- pairs of pointers to corresponding input and output columns
      let pointers = take nrhs $ zip (pointerSeq n bPtr) (pointerSeq n cPtr)
      case transposeOrder transposed order of
         RowMajor -> do
            xPtr <- Call.allocaArray n
            liftIO $ for_ pointers $ \(biPtr,yPtr) -> do
               copyConjugate nPtr biPtr incxPtr xPtr incxPtr
               BlasGen.hbmv uploPtr nPtr kPtr
                  alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr
               lacgv nPtr yPtr incyPtr
         ColumnMajor ->
            liftIO $ for_ pointers $ \(xPtr,yPtr) ->
               BlasGen.hbmv uploPtr nPtr kPtr
                  alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

-- | Placeholder for a row-major variant of 'multiplyFullColumnMajor'.
-- Not implemented; evaluating it raises an 'error'.
-- 'multiplyFull' routes row-major operands to 'multiplyFullGeneric' instead.
multiplyFullRowMajor ::
   (Unary.Natural offDiag, Shape.C height, Shape.C width, Class.Floating a) =>
   Transposition ->
   UnaryProxy offDiag -> (height, width) ->
   Order -> ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyFullRowMajor =
   error "BandedHermitian.multiplyFullRowMajor: not implemented"

-- | Multiplication via general banded matrices:
-- the matrix is split into its strict lower and its upper part
-- (swapped and transposed for 'Transposed'),
-- both parts are multiplied with the right operand
-- and the products are combined by 'VectorPriv.mac' with factor 'one'.
multiplyFullGeneric ::
   (Unary.Natural offDiag, Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Transposition -> BandedHermitian offDiag height a ->
   Matrix.Full meas vert horiz height width a ->
   Matrix.Full meas vert horiz height width a
multiplyFullGeneric transposed a b =
   let (lower,upper) = (takeStrictLower a, takeUpper a)
       (lowerT,upperT) =
         case transposed of
            Transposed -> (Banded.transpose upper, Banded.transpose lower)
            NonTransposed -> (lower,upper)
   in VectorPriv.mac one
         (Banded.multiplyFull (Banded.mapExtent ExtentPriv.fromSquare lowerT) b)
         (Banded.multiplyFull (Banded.mapExtent ExtentPriv.fromSquare upperT) b)


-- | Interpret the stored band as a banded square matrix
-- with no sub-diagonals and @offDiag@ super-diagonals.
-- Only the shape is mapped.
takeUpper ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a ->
   Banded.Square TypeNum.U0 offDiag size a
takeUpper =
   Array.mapShape
      (\(Layout.BandedHermitian numOff order sh) ->
         Layout.bandedSquare (Proxy,numOff) order sh)

-- | Banded square matrix with @offDiag@ sub-diagonals,
-- no super-diagonals and flipped storage order.
-- All stored elements are copied conjugated,
-- then @n@ elements with stride @k+1@ are overwritten with zero
-- (presumably the main diagonal, which makes the result strictly lower).
takeStrictLower ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a ->
   Banded.Square offDiag TypeNum.U0 size a
takeStrictLower (Array (Layout.BandedHermitian numOff order sh) x) =
   Array.unsafeCreateWithSize
      (Layout.bandedSquare (numOff,Proxy) (Layout.flipOrder order) sh) $
         \size yPtr -> evalContT $ do
   let k = integralFromProxy numOff
   nPtr <- Call.cint $ Shape.size sh
   xPtr <- ContT $ withForeignPtr x
   sizePtr <- Call.cint size
   incxPtr <- Call.cint 1
   incyPtr <- Call.cint 1
   inczPtr <- Call.cint 0
   ldbPtr <- Call.leadingDim $ k+1
   zPtr <- Call.number zero
   liftIO $ do
      copyConjugate sizePtr xPtr incxPtr yPtr incyPtr
      -- same offsets as used for the diagonal in 'takeDiagonalAux'
      let offset = case order of ColumnMajor -> k; RowMajor -> 0
      BlasGen.copy nPtr zPtr inczPtr (advancePtr yPtr offset) ldbPtr


-- | Vector with a statically known, zero-based size.
type StaticVector n = Vector (ShapeStatic.ZeroBased n)

-- | Weighted sum of rank-1 updates.
-- Every list element consists of a real weight,
-- the index where the vector is placed along the diagonal
-- and a vector of @k+1@ elements.
--
-- Asserts for every element that the vector fits into the matrix,
-- i.e. @offset + k < size@.
sumRank1 ::
   (Unary.Natural k, Shape.Indexed sh, Class.Floating a) =>
   Order -> sh ->
   [(RealOf a, (Shape.Index sh, StaticVector (Unary.Succ k) a))] ->
   BandedHermitian k sh a
sumRank1 =
   getSumRank1 $
   Class.switchFloating
      (SumRank1 $ sumRank1Aux Proxy) (SumRank1 $ sumRank1Aux Proxy)
      (SumRank1 $ sumRank1Aux Proxy) (SumRank1 $ sumRank1Aux Proxy)

-- | Wrapper for dispatching 'sumRank1' on the element type.
newtype SumRank1 k sh a = SumRank1 {getSumRank1 :: SumRank1_ k sh (RealOf a) a}

-- | Type of 'sumRank1' with an explicit type @ar@ for the real weights.
type SumRank1_ k sh ar a =
   Order -> sh ->
   [(ar, (Shape.Index sh, StaticVector (Unary.Succ k) a))] ->
   BandedHermitian k sh a

-- | Worker of 'sumRank1':
-- zero the target, accumulate all updates with 'hbr'
-- and finally conjugate the whole array for row-major order.
sumRank1Aux ::
   (Unary.Natural k, Shape.Indexed sh,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   UnaryProxy k -> SumRank1_ k sh ar a
sumRank1Aux numOff order size xs =
   Array.unsafeCreateWithSize (Layout.BandedHermitian numOff order size) $
      \bSize aPtr ->
   evalContT $ do
      let k = integralFromProxy numOff
      let n = Shape.size size
      let lda = k+1
      uploPtr <- Call.char $ uploFromOrder order
      mPtr <- Call.cint lda
      -- scratch cell; 'hbr' pokes the scaling factors into it
      alphaPtr <- Call.alloca
      incxPtr <- Call.cint 1
      kPtr <- Call.cint k
      -- leading dimension k (not k+1) is passed to the rank-1 routine
      ldbPtr <- Call.leadingDim k
      bSizePtr <- Call.cint bSize
      liftIO $ do
         fill zero bSize aPtr
         for_ xs $ \(alpha, (offset, Array _shX x)) ->
            withForeignPtr x $ \xPtr -> do
               let i = Shape.offset size offset
               Call.assert "BandedHermitian.sumRank1: index too large" (i+k < n)
               let bPtr = advancePtr aPtr (lda*i)
               hbr order k alpha uploPtr
                  mPtr kPtr alphaPtr xPtr incxPtr bPtr incxPtr ldbPtr
         condConjugate (conjugatedOnRowMajor order) bSizePtr aPtr incxPtr


-- | Signature shared by 'syr' and 'her'.
-- In order of the arguments:
-- storage order, number of off-diagonals, real weight, uplo character,
-- vector length @k+1@, @k@, scratch cell for scaling factors,
-- input vector and its increment,
-- start of the target block and its increment,
-- leading dimension for the rank-1 routine.
type HBR_ ar a =
   Order -> Int -> ar -> Ptr CChar ->
   Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr CInt -> IO ()

-- | Wrapper for dispatching 'hbr' on the element type.
newtype HBR a = HBR {getHBR :: HBR_ (RealOf a) a}

-- | Rank-1 update of a block of the band:
-- 'syr' for real and 'her' for complex element types.
hbr :: Class.Floating a => HBR_ (RealOf a) a
hbr =
   getHBR $
   Class.switchFloating (HBR syr) (HBR syr) (HBR her) (HBR her)

-- | Real rank-1 update of a @(k+1) x (k+1)@ block in band storage.
-- BLAS @syr@ handles a @k x k@ part,
-- the remaining @k+1@ elements are updated by an @axpy@
-- scaled with @alpha@ times the last (column-major)
-- or first (row-major) vector element.
--
-- Statement order matters:
-- @alphaPtr@ is a shared scratch cell that is overwritten between the calls.
syr :: (Class.Real a) => HBR_ a a
syr order k alpha uploPtr nPtr kPtr alphaPtr xPtr incxPtr a0Ptr incaPtr ldaPtr =
   case order of
      ColumnMajor -> do
         let aPtr = advancePtr a0Ptr k
         poke alphaPtr alpha
         BlasReal.syr uploPtr kPtr alphaPtr xPtr incxPtr aPtr ldaPtr
         poke alphaPtr . (alpha*) =<< peekElemOff xPtr k
         BlasGen.axpy nPtr alphaPtr xPtr incxPtr (advancePtr aPtr (k*k)) incaPtr
      RowMajor -> do
         let aPtr = a0Ptr
         poke alphaPtr . (alpha*) =<< peek xPtr
         BlasGen.axpy nPtr alphaPtr xPtr incxPtr aPtr incaPtr
         poke alphaPtr alpha
         BlasReal.syr uploPtr kPtr alphaPtr
            (advancePtr xPtr 1) incxPtr (advancePtr aPtr (k+1)) ldaPtr

-- | Complex counterpart of 'syr' based on BLAS @her@.
-- The real weight for @her@ is written through a real-typed view
-- ('realPtr') of the scratch cell,
-- whereas the @axpy@ factor is @alpha@ times the conjugate
-- of the last (column-major) or first (row-major) vector element.
--
-- Statement order matters:
-- @alphaPtr@ is a shared scratch cell that is overwritten between the calls.
her :: (Class.Real a) => HBR_ a (Complex a)
her order k alpha uploPtr nPtr kPtr alphaPtr xPtr incxPtr a0Ptr incaPtr ldaPtr =
   case order of
      ColumnMajor -> do
         let aPtr = advancePtr a0Ptr k
         let alphaRealPtr = realPtr alphaPtr
         poke alphaRealPtr alpha
         BlasComplex.her uploPtr kPtr alphaRealPtr xPtr incxPtr aPtr ldaPtr
         poke alphaPtr . fmap (alpha*) . conjugate =<< peekElemOff xPtr k
         BlasGen.axpy nPtr alphaPtr xPtr incxPtr (advancePtr aPtr (k*k)) incaPtr
      RowMajor -> do
         let aPtr = a0Ptr
         let alphaRealPtr = realPtr alphaPtr
         poke alphaPtr . fmap (alpha*) . conjugate =<< peek xPtr
         BlasGen.axpy nPtr alphaPtr xPtr incxPtr aPtr incaPtr
         poke alphaRealPtr alpha
         BlasComplex.her uploPtr kPtr alphaRealPtr
            (advancePtr xPtr 1) incxPtr (advancePtr aPtr (k+1)) ldaPtr