{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
module Numeric.LAPACK.Orthogonal.Plain where

import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Private as MatrixPriv
import qualified Numeric.LAPACK.Matrix.Plain.Format as ArrFormat
import qualified Numeric.LAPACK.Matrix.Layout as LayoutPub
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Output ((/+/))
import Numeric.LAPACK.Matrix.Plain.Format (formatArray)
import Numeric.LAPACK.Matrix.Type.Private (Matrix)
import Numeric.LAPACK.Matrix.Triangular.Basic (Upper)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor, ColumnMajor), sideSwapFromOrder)
import Numeric.LAPACK.Matrix.Extent.Private (Extent)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed),
          Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (RealOf, zero, isZero, absolute, conjugate)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked), deconsUnchecked)
import Numeric.LAPACK.Private
         (fill, copySubMatrix, copyBlock, conjugateToTemp,
          caseRealComplexFunc, withAutoWorkspaceInfo, errorCodeMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (liftA2)

import qualified Data.List as List
import Data.Monoid ((<>))


-- | Constructor-less tag type.
-- It only serves as the first index of the 'Matrix' data family
-- and thereby selects the instance declared below.
data Hh

-- | 'Matrix' instance for the 'Hh' tag.
-- A value pairs two arrays:
-- the array that this module consistently binds as @tau@
-- and the 'SplitArray' that is accessible via 'split_'.
data instance
   Matrix Hh xl xu lower upper meas vert horiz height width a where
      -- The result type is written with the 'HouseholderFlex' synonym,
      -- i.e. the constructor fixes both @xl@ and @xu@ to @()@.
      Householder ::
         Banded.RectangularDiagonal meas vert horiz height width a ->
         SplitArray meas vert horiz height width a ->
         HouseholderFlex lower upper meas vert horiz height width a

-- 'Show' is derived stand-alone
-- because the data instance is declared in GADT syntax.
deriving instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Storable a,
    Show height, Show width, Show a) =>
   Show (Matrix Hh xl xu lower upper meas vert horiz height width a)

-- | Second component of a 'Householder' value:
-- a 'Split.Split' array tagged with 'Layout.Reflector'.
type SplitArray meas vert horiz height width a =
      Split.Split Layout.Reflector meas vert horiz height width a

-- | Project the 'SplitArray' component and drop the @tau@ array.
split_ ::
   Matrix Hh xl xu lower upper meas vert horiz height width a ->
   SplitArray meas vert horiz height width a
split_ (Householder _tau split) = split

-- | 'Matrix' 'Hh' with both extra parameters fixed to @()@;
-- @lower@ and @upper@ remain free.
type HouseholderFlex = Matrix Hh () ()
-- | 'HouseholderFlex' with both @lower@ and @upper@
-- set to 'Layout.Filled'.
type Householder = HouseholderFlex Layout.Filled Layout.Filled

-- Aliases of 'Householder' for particular extent tags:
--   General:       Size,  Big,   Big
--   Tall:          Size,  Big,   Small
--   Wide:          Size,  Small, Big
--   LiberalSquare: Size,  Small, Small (height and width may be distinct types)
--   Square:        Shape, Small, Small (height and width share one shape type)
type General height width =
      Householder Extent.Size Extent.Big Extent.Big height width
type Tall height width =
      Householder Extent.Size Extent.Big Extent.Small height width
type Wide height width =
      Householder Extent.Size Extent.Small Extent.Big height width
type LiberalSquare height width = SquareMeas Extent.Size height width
type Square sh = SquareMeas Extent.Shape sh sh
type SquareMeas meas height width =
      Householder meas Extent.Small Extent.Small height width
-- | Apply an extent map to both components of the decomposition:
-- 'Banded.mapExtent' on @tau@ and 'Split.mapExtent' on the split array.
mapExtent ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
   (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   ExtentPriv.Map measA vertA horizA measB vertB horizB height width ->
   HouseholderFlex lower upper measA vertA horizA height width a ->
   HouseholderFlex lower upper measB vertB horizB height width a
mapExtent f (Householder tau split) =
   Householder (Banded.mapExtent f tau) $ Split.mapExtent f split

-- | Apply a function to the height shape in both components.
-- The signature restricts this to the 'Extent.Size' measure.
mapHeight ::
   (Extent.C vert, Extent.C horiz) =>
   (heightA -> heightB) ->
   HouseholderFlex lower upper Extent.Size vert horiz heightA width a ->
   HouseholderFlex lower upper Extent.Size vert horiz heightB width a
mapHeight f (Householder tau split) =
   Householder (Banded.mapHeight f tau) (Split.mapHeight f split)

-- | Apply a function to the width shape in both components.
-- The signature restricts this to the 'Extent.Size' measure.
mapWidth ::
   (Extent.C vert, Extent.C horiz) =>
   (widthA -> widthB) ->
   HouseholderFlex lower upper Extent.Size vert horiz height widthA a ->
   HouseholderFlex lower upper Extent.Size vert horiz height widthB a
mapWidth f (Householder tau split) =
   Householder (Banded.mapWidth f tau) (Split.mapWidth f split)

-- | Wrap both the height and the width shape in 'Unchecked',
-- in @tau@ as well as in the split array.
uncheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   HouseholderFlex lower upper meas vert horiz height width a ->
   HouseholderFlex lower upper meas vert horiz
      (Unchecked height) (Unchecked width) a
uncheck (Householder tau split) =
   Householder
      (Banded.mapExtentSizes (ExtentPriv.mapWrap Unchecked Unchecked) tau)
      (Split.uncheck split)

-- | Classify the decomposition as 'Tall' or 'Wide'.
-- The decision is made by 'Layout.caseTallWideSplit'
-- on the shape of the split array.
-- The element data of both arrays is reused; only the shapes are replaced.
caseTallWide ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Householder meas vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (Householder tau (Array shape a)) =
   -- consHouse rebuilds the pair around a new split shape
   -- and updates the extent stored in the shape of tau accordingly.
   -- NOTE(review): tau and the data pointer are passed as arguments,
   -- presumably to keep the local binding polymorphic
   -- such that it can be used for both result types - confirm.
   let consHouse taub b newShape =
          Householder
             (Array.mapShape
                (\bandShape ->
                   bandShape{Layout.bandedExtent = Layout.splitExtent newShape})
                taub)
             $ Array newShape b
   in either (Left . consHouse tau a) (Right . consHouse tau a) $
      Layout.caseTallWideSplit shape


-- Formats tau, reshaped to a 'Shape.ZeroBased' vector of the same size,
-- and the split array; both parts are combined with '/+/'.
instance Matrix.Format Hh where
   type FormatExtra Hh extra = extra ~ ()
   format fmt (Householder tau m) =
      formatArray fmt (Array.mapShape (Shape.ZeroBased . Shape.size) tau)
      /+/
      formatArray fmt m

-- The layout is derived from the split array alone; tau is ignored.
instance Matrix.Layout Hh where
   type LayoutExtra Hh extra = extra ~ ()
   layout (Householder _tau m) =
      ArrFormat.splitArrayFromList2 (Layout.splitExtent $ Array.shape m) $
      ArrFormat.layoutSplit m


-- | Compute the decomposition of a full matrix.
--
-- The input elements are copied into a freshly allocated split array
-- with the 'Layout.Reflector' tag and the same order and extent.
-- That copy is then factored in place,
-- with LAPACK's @gelqf@ for 'RowMajor' and @geqrf@ for 'ColumnMajor' input.
-- The routine writes its scalar factors to a second fresh array (@tau@)
-- whose shape is taken from 'Layout.rectangularDiagonal'.
fromMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   Householder meas vert horiz height width a
fromMatrix (Array shape@(Layout.Full order extent) a) =
   uncurry Householder $
   Array.unsafeCreateWithSizeAndResult
      (snd $ Layout.rectangularDiagonal extent) $ \_ tauPtr ->
   ArrayIO.unsafeCreate (Layout.Split Layout.Reflector order extent) $ \qrPtr ->

   evalContT $ do
      -- NOTE(review): presumably Layout.dimensions returns the dimensions
      -- in storage order, such that m is the leading dimension
      -- for both orders - confirm.
      let (m,n) = Layout.dimensions shape
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim m
      liftIO $ do
         -- The LAPACK routines work in place, thus operate on a copy.
         copyBlock (m*n) aPtr qrPtr
         case order of
            RowMajor ->
               withAutoWorkspaceInfo errorCodeMsg "gelqf" $
                  LapackGen.gelqf mPtr nPtr qrPtr ldaPtr tauPtr
            ColumnMajor ->
               withAutoWorkspaceInfo errorCodeMsg "geqrf" $
                  LapackGen.geqrf mPtr nPtr qrPtr ldaPtr tauPtr


-- | Apply 'Split.determinantR' to the split array; @tau@ is not consulted.
determinantR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert Extent.Small height width a -> a
determinantR = Split.determinantR . split_

{-
For complex numbers LAPACK uses not exactly reflections,
i.e. the determinants of the primitive transformations are not necessarily -1.

It holds: det(I-tau*v*v^H) = 1-tau*v^H*v
   because of https://en.wikipedia.org/wiki/Sylvester's_determinant_theorem
simple proof from:
   https://en.wikipedia.org/wiki/Matrix_determinant_lemma

   I  0  .  I+u*vt u  .   I  0  =    I+u*vt      u     .   I  0  =  I    u
   vt 1       0    1     -vt 1     vt+vt*u*vt  vt*u+1     -vt 1     0  vt*u+1

We already know:
   v^H*v is real and greater or equal to 1, because v[i] = 1,
   and determinant has absolute value 1.

Let k = v^H*v.
For which real k lies 1-tau*k on the unit circle?

(1-taur*k)^2 + (taui*k)^2 = 1
1-2*taur*k+(taur^2+taui^2)*k^2 = 1
(taur^2 + taui^2)*k^2 - 2*taur*k = 0
(k/=0)
(taur^2 + taui^2)*k - 2*taur = 0
k = 2*taur / (taur^2 + taui^2)

1-tau*k
   = (taur^2 + taui^2 - tau*2*taur) / (taur^2 + taui^2)
   = (taur^2 + taui^2 - 2*(taur+i*taui)*taur) / (taur^2 + taui^2)
   = (-taur^2 + taui^2 - 2*(i*taui)*taur) / (taur^2 + taui^2)
   = -(taur + i*taui)^2 / (taur^2 + taui^2)
-}
-- | Determinant of a square decomposition.
--
-- The result is the product of 'Split.determinantR' of the split array
-- and one factor @negate (signum t ^ 2)@ for every non-zero element @t@
-- of @tau@, following the derivation in the comment above.
-- Zero elements of @tau@ contribute no factor.
-- For 'RowMajor' order the factors are conjugated.
determinant ::
   (Shape.C sh, Class.Floating a) =>
   HouseholderFlex lower upper Extent.Shape Extent.Small Extent.Small sh sh a ->
   a
determinant (Householder tau split) =
   List.foldl' (*) (Split.determinantR split) $
   (case Layout.splitOrder $ Array.shape split of
      RowMajor -> map conjugate
      ColumnMajor -> id) $
   map (negate.(^(2::Int)).signum) $
   filter (not . isZero) $
   Array.toList tau

-- | Absolute value of the determinant.
-- If 'caseTallWide' classifies the decomposition as tall,
-- this is the absolute value of 'determinantR', otherwise it is zero.
determinantAbsolute ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert horiz height width a -> RealOf a
determinantAbsolute =
   absolute . either determinantR (const zero) . caseTallWide


-- | Front end of 'leastSquaresAux' that does not demand @Eq width@.
--
-- For a 'ExtentPriv.Square' extent the helper is called directly.
-- For a 'ExtentPriv.Separate' extent the width is wrapped in 'Unchecked'
-- before the call and the result height is unwrapped afterwards.
-- NOTE(review): presumably matching on the extent constructor refines
-- the types such that the respective branch type-checks - confirm.
leastSquares ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas horiz Extent.Small height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
leastSquares qr =
   case Matrix.extent qr of
      ExtentPriv.Square _ -> leastSquaresAux qr
      ExtentPriv.Separate _ _ ->
         Basic.mapHeight deconsUnchecked .
         leastSquaresAux (mapWidth Unchecked qr)

-- | First apply 'tallMultiplyQAdjoint' to the right-hand side,
-- then 'tallSolveR' with 'NonTransposed' and 'NonConjugated'.
leastSquaresAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas horiz Extent.Small height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
leastSquaresAux qr =
   tallSolveR NonTransposed NonConjugated qr . tallMultiplyQAdjoint qr

-- | Front end of 'minimumNormAux' that does not demand @Eq width@.
--
-- Note that the decomposition argument has its shapes in the order
-- @width height@.
-- For a 'ExtentPriv.Separate' extent the height of the decomposition
-- is wrapped in 'Unchecked' before the call
-- and the result height is unwrapped afterwards.
minimumNorm ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas vert Extent.Small width height a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
minimumNorm qr =
   case Matrix.extent qr of
      ExtentPriv.Square _ -> minimumNormAux qr
      ExtentPriv.Separate _ _ ->
         Basic.mapHeight deconsUnchecked .
         minimumNormAux (mapHeight Unchecked qr)

-- | First apply 'tallSolveR' with 'Transposed' and 'Conjugated'
-- to the right-hand side, then 'tallMultiplyQ'.
minimumNormAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas vert Extent.Small width height a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
minimumNormAux qr =
   tallMultiplyQ qr . tallSolveR Transposed Conjugated qr


-- | Reduce the height of a full matrix from @fuse@ to @height@.
--
-- The new extent is obtained by fusing the given extent with the one
-- of the matrix.
-- Calls 'error' if 'Extent.fuse' fails.
-- NOTE(review): the offset argument 0 of Basic.takeSub presumably selects
-- the leading rows - confirm.
takeRows ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq fuse, Shape.C fuse, Shape.C height, Shape.C width, Class.Floating a) =>
   Extent meas Extent.Small horiz height fuse ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
takeRows extentA (Array (Layout.Full order extentB) b) =
   case Extent.fuse (ExtentPriv.weakenWide extentA) extentB of
      Nothing -> error "Householder.takeRows: heights mismatch"
      Just extentC ->
         Basic.takeSub (Extent.height extentB) 0 b (Layout.Full order extentC)

-- | Enlarge the height of a full matrix from @fuse@ to @height@.
-- The existing elements are copied, the additional elements are set to zero.
--
-- Calls 'error' if 'Extent.fuse' fails.
addRows ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq fuse, Shape.C fuse, Shape.C height, Shape.C width, Class.Floating a) =>
   Extent meas vert Extent.Small height fuse ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
addRows extentA (Array shapeB@(Layout.Full order extentB) b) =
   case Extent.fuse (ExtentPriv.weakenTall extentA) extentB of
      Nothing -> error "Householder.addRows: heights mismatch"
      Just extentC ->
         Array.unsafeCreateWithSize (Layout.Full order extentC) $
               \cSize cPtr ->
         withForeignPtr b $ \bPtr ->
         case order of
            RowMajor -> do
               -- Copy the source as one block
               -- and zero the remaining tail of the target.
               let bSize = Shape.size shapeB
               copyBlock bSize bPtr cPtr
               fill zero (cSize - bSize) (advancePtr cPtr bSize)
            ColumnMajor -> do
               let n = Shape.size $ Extent.width extentB
                   mb = Shape.size $ Extent.height extentB
                   mc = Shape.size $ Extent.height extentC
               -- Copy with source leading dimension mb
               -- and target leading dimension mc.
               copySubMatrix mb n mb bPtr mc cPtr
               -- Let laset write zero to the (mc-mb) x n region
               -- that starts mb elements behind the target start.
               evalContT $ do
                  uploPtr <- Call.char 'A'
                  mPtr <- Call.cint (mc-mb)
                  nPtr <- Call.cint n
                  ldcPtr <- Call.leadingDim mc
                  zPtr <- Call.number zero
                  liftIO $
                     LapackGen.laset
                        uploPtr mPtr nPtr zPtr zPtr (advancePtr cPtr mb) ldcPtr


-- | Generate Q as a square matrix over the height shape,
-- see 'extractQAux'.
extractQ ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert horiz height width a -> MatrixPriv.Square height a
extractQ (Householder tau (Array (Layout.Split _ order extent) qr)) =
   extractQAux tau (Extent.width extent) order
      (Extent.square $ Extent.height extent) qr

-- | Generate Q with the same extent as the decomposition,
-- see 'extractQAux'.
tallExtractQ ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert Extent.Small height width a ->
   Full meas vert Extent.Small height width a
tallExtractQ (Householder tau (Array (Layout.Split _ order extent) qr)) =
   extractQAux tau (Extent.width extent) order extent qr

-- | Shared worker of 'extractQ' and 'tallExtractQ'.
--
-- Allocates a full matrix of the requested extent,
-- copies reflector data from the split array into it
-- and generates Q in place,
-- with LAPACK's @unglq@ for 'RowMajor' and @ungqr@ for 'ColumnMajor'.
-- The number of reflectors passed to LAPACK is the size of the shape
-- of @tau@.
-- @widthQR@ is the width of the split array
-- and is only consulted in the 'RowMajor' case,
-- namely as source leading dimension of the copy.
extractQAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Shape.C height, Shape.C width, Shape.C widthQR, Class.Floating a) =>
   Banded.RectangularDiagonal measA vertA horizA height widthQR a ->
   widthQR -> Order -> Extent meas vert horiz height width ->
   ForeignPtr a -> Full meas vert horiz height width a
extractQAux (Array widthTau tau) widthQR order extent qr =
   Array.unsafeCreate (Layout.Full order extent) $ \qPtr -> do
      let (height,width) = Extent.dimensions extent
      let m = Shape.size height
      let n = Shape.size width
      let k = Shape.size widthTau
      evalContT $ do
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         kPtr <- Call.cint k
         qrPtr <- ContT $ withForeignPtr qr
         tauPtr <- ContT $ withForeignPtr tau
         case order of
            RowMajor -> do
               ldaPtr <- Call.leadingDim n
               liftIO $ do
                  copySubMatrix k m (Shape.size widthQR) qrPtr n qPtr
                  -- dimensions are passed in swapped order (n, m)
                  withAutoWorkspaceInfo errorCodeMsg "unglq" $
                     LapackGen.unglq nPtr mPtr kPtr qPtr ldaPtr tauPtr
            ColumnMajor -> do
               ldaPtr <- Call.leadingDim m
               liftIO $ do
                  copyBlock (m*k) qrPtr qPtr
                  withAutoWorkspaceInfo errorCodeMsg "ungqr" $
                     LapackGen.ungqr mPtr nPtr kPtr qPtr ldaPtr tauPtr


-- | Extend the operand by zero rows using 'addRows',
-- then apply 'multiplyQ' with 'NonTransposed' and 'NonConjugated'.
tallMultiplyQ ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   HouseholderFlex lower upper meas vert Extent.Small height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallMultiplyQ qr =
   multiplyQ NonTransposed NonConjugated qr . addRows (Matrix.extent qr)

-- | Apply 'multiplyQ' with 'Transposed' and 'Conjugated',
-- then shrink the result using 'takeRows'.
tallMultiplyQAdjoint ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   HouseholderFlex lower upper meas horiz Extent.Small fuse height a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallMultiplyQAdjoint qr =
   takeRows (Extent.transpose $ Matrix.extent qr) .
   multiplyQ Transposed Conjugated qr


-- | Apply the transformation stored in the decomposition to a full matrix
-- of the same height.
--
-- The operand is copied to a fresh array that is updated in place
-- by LAPACK's @unmqr@ ('ColumnMajor' decomposition)
-- or @unmlq@ ('RowMajor' decomposition).
-- The transposition and conjugation arguments select the variant
-- of the transformation.
-- They are combined with the storage orders of both arguments
-- in order to determine the @trans@ character for LAPACK
-- and in order to decide whether conjugated temporary copies
-- of the reflectors and of @tau@ are needed.
--
-- Checks via 'Call.assert' that both height shapes are equal.
multiplyQ ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA, Shape.C widthA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB, Shape.C widthB,
    Shape.C height, Eq height, Class.Floating a) =>
   Transposition -> Conjugation ->
   HouseholderFlex lower upper measA vertA horizA height widthA a ->
   Full measB vertB horizB height widthB a ->
   Full measB vertB horizB height widthB a
multiplyQ transposed conjugated
   (Householder
      (Array widthTau tau)
      (Array shapeA@(Layout.Split _ orderA extentA) qr))
   (Array shapeB@(Layout.Full orderB extentB) b) =

   Array.unsafeCreateWithSize shapeB $ \cSize cPtr -> do
      let (heightA,widthA) = Extent.dimensions extentA
      let (height,width) = Extent.dimensions extentB
      Call.assert "Householder.multiplyQ: height shapes mismatch"
         (heightA == height)

      -- The order of the operand determines the side character
      -- and the order of the dimensions passed to LAPACK.
      let (side,(m,n)) =
            sideSwapFromOrder orderB (Shape.size height, Shape.size width)

      evalContT $ do
         sidePtr <- Call.char side
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         let k = Shape.size widthTau
         kPtr <- Call.cint k
         -- Differing orders contribute one additional transposition.
         -- qr only serves as proxy for the element type here.
         transPtr <-
            Call.char $ adjointFromTranspose qr $
            transposed <> if orderA==orderB then NonTransposed else Transposed
         -- Either use the stored arrays directly
         -- or temporary conjugated copies of them.
         (qrPtr,tauPtr) <-
            if (orderA==orderB)
                  ==
               (transposed==NonTransposed && conjugated==NonConjugated
                ||
                transposed==Transposed && conjugated==Conjugated)
               then
                  liftA2 (,)
                     (ContT $ withForeignPtr qr)
                     (ContT $ withForeignPtr tau)
               else
                  liftA2 (,)
                     (conjugateToTemp (Shape.size shapeA) qr)
                     (conjugateToTemp k tau)
         bPtr <- ContT $ withForeignPtr b
         ldcPtr <- Call.leadingDim m
         -- LAPACK overwrites its operand, thus work on a copy of b.
         liftIO $ copyBlock cSize bPtr cPtr
         case orderA of
            ColumnMajor -> do
               ldaPtr <- Call.leadingDim $ Shape.size heightA
               liftIO $
                  withAutoWorkspaceInfo errorCodeMsg "unmqr" $
                     LapackGen.unmqr
                        sidePtr transPtr mPtr nPtr kPtr
                        qrPtr ldaPtr tauPtr cPtr ldcPtr
            RowMajor -> do
               ldaPtr <- Call.leadingDim $ Shape.size widthA
               -- work-around for https://github.com/Reference-LAPACK/lapack/issues/260
               liftIO $ when (k>0) $
                  withAutoWorkspaceInfo errorCodeMsg "unmlq" $
                     LapackGen.unmlq
                        sidePtr transPtr mPtr nPtr kPtr
                        qrPtr ldaPtr tauPtr cPtr ldcPtr


-- | Translate a 'Transposition' into a LAPACK @trans@ character:
-- @\'N\'@ for 'NonTransposed', the result of 'invChar' for 'Transposed'.
-- The first argument only fixes the element type.
adjointFromTranspose ::
   (Class.Floating a) => f a -> Transposition -> Char
adjointFromTranspose qr Transposed = invChar qr
adjointFromTranspose _ NonTransposed = 'N'

-- | Select between @\'T\'@ and @\'C\'@ using 'caseRealComplexFunc'.
-- NOTE(review): presumably @\'T\'@ is chosen for real
-- and @\'C\'@ for complex element types - confirm.
invChar :: (Class.Floating a) => f a -> Char
invChar f = caseRealComplexFunc f 'T' 'C'


-- | Extract a full matrix from the split array using
-- 'Split.extractTriangle' with the @Right Layout.Triangle@ selector.
extractR ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   HouseholderFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
extractR = Split.extractTriangle (Right Layout.Triangle) . split_

-- | Apply 'Split.tallExtractR' to the split array,
-- which yields an 'Upper' matrix over the width shape.
tallExtractR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert Extent.Small height width a -> Upper width a
tallExtractR = Split.tallExtractR . split_

-- | Apply 'Split.tallMultiplyR' to the split array.
tallMultiplyR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C heightA, Shape.C widthB,
    Class.Floating a) =>
   Transposition ->
   HouseholderFlex lower upper measA vertA Extent.Small heightA height a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
tallMultiplyR transposed = Split.tallMultiplyR transposed . split_

-- | Apply 'Split.tallSolveR' to the split array.
tallSolveR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   Transposition -> Conjugation ->
   HouseholderFlex lower upper measA vertA Extent.Small height width a ->
   Full meas vert horiz width nrhs a -> Full meas vert horiz width nrhs a
tallSolveR transposed conjugated =
   Split.tallSolveR transposed conjugated . split_


-- The extent is read from the shape of the split array.
instance Matrix.Box Hh where
   type BoxExtra Hh extra = extra ~ ()
   extent = Layout.splitExtent . Array.shape . split_

-- Both components are converted:
-- the split array via the corresponding Split function,
-- tau by replacing its shape with a diagonal layout
-- over the retained dimension.
instance Matrix.ToQuadratic Hh where
   heightToQuadratic (Householder tau split) =
      Householder
         (Array.mapShape (layoutTauSquare . Layout.bandedHeight) tau)
         (Split.heightToQuadratic split)
   widthToQuadratic (Householder tau split) =
      Householder
         (Array.mapShape (layoutTauSquare . Layout.bandedWidth) tau)
         (Split.widthToQuadratic split)

-- | Diagonal layout over the given shape,
-- constructed with 'Layout.ColumnMajor' order.
layoutTauSquare :: sh -> Layout.Diagonal sh
layoutTauSquare = LayoutPub.diagonal Layout.ColumnMajor

-- The right-hand side refers to the top-level mapExtent of this module.
instance Matrix.MapExtent Hh where
   type MapExtentExtra Hh extra = extra ~ ()
   type MapExtentStrip Hh strip = ()
   mapExtent = mapExtent . ExtentStrict.apply

instance Multiply.MultiplyVector Hh where
   type MultiplyVectorExtra Hh extra = extra ~ ()
   -- Multiply the vector by extractR first, then apply multiplyQ.
   -- The shapes are wrapped in Unchecked around the call of multiplyQ
   -- and unwrapped afterwards.
   matrixVector qr x =
      Array.mapShape deconsUnchecked $
      Basic.unliftColumn Layout.ColumnMajor
         (multiplyQ NonTransposed NonConjugated $ uncheck qr) $
      Array.mapShape Unchecked $
      Basic.multiplyVector (extractR qr) x
   -- Apply the transposed multiplyQ to the vector first,
   -- then multiply by the transposed extractR.
   vectorMatrix x qr =
      Basic.multiplyVector (Basic.transpose $ extractR qr) $
      Basic.unliftColumn Layout.ColumnMajor
         (multiplyQ Transposed NonConjugated qr) x

instance Multiply.MultiplySquare Hh where
   type MultiplySquareExtra Hh extra = extra ~ ()
   -- tallMultiplyR first, then multiplyQ
   squareFull qr =
      ArrMatrix.lift1 $
         multiplyQ NonTransposed NonConjugated qr .
         tallMultiplyR NonTransposed qr
   -- Transpose the operand, apply the transposed factors
   -- in the order multiplyQ, tallMultiplyR, and transpose back.
   fullSquare = flip $ \qr ->
      ArrMatrix.lift1 $
         Basic.transpose .
         tallMultiplyR Transposed qr .
         multiplyQ Transposed NonConjugated qr .
         Basic.transpose

-- The right-hand side refers to the top-level determinant of this module.
instance Divide.Determinant Hh where
   type DeterminantExtra Hh extra = extra ~ ()
   determinant = determinant

instance Divide.Solve Hh where
   type SolveExtra Hh extra = extra ~ ()
   -- leastSquares on the decomposition
   -- converted with ExtentPriv.fromSquare
   solveRight = ArrMatrix.lift1 . leastSquares . mapExtent ExtentPriv.fromSquare
   -- minimumNorm applied between two Basic.adjoint conversions
   solveLeft = flip $ \a ->
      ArrMatrix.lift1 $
         Basic.adjoint .
         minimumNorm (mapExtent ExtentPriv.fromSquare a) .
         Basic.adjoint