{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
module Numeric.LAPACK.Orthogonal.Plain where
import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type as Matrix
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Private as MatrixPriv
import qualified Numeric.LAPACK.Matrix.Layout as LayoutPub
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Output ((/+/))
import Numeric.LAPACK.Matrix.Plain.Format (formatArray)
import Numeric.LAPACK.Matrix.Type (Matrix, FormatMatrix(formatMatrix))
import Numeric.LAPACK.Matrix.Triangular.Basic (Upper)
import Numeric.LAPACK.Matrix.Layout.Private
(Order(RowMajor, ColumnMajor), sideSwapFromOrder)
import Numeric.LAPACK.Matrix.Extent.Private (Extent)
import Numeric.LAPACK.Matrix.Modifier
(Transposition(NonTransposed, Transposed),
Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (RealOf, zero, isZero, absolute, conjugate)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked), deconsUnchecked)
import Numeric.LAPACK.Private
(fill, copySubMatrix, copyBlock, conjugateToTemp, caseRealComplexFunc,
withAutoWorkspaceInfo, errorCodeMsg)
import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Storable (Storable)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (liftA2)
import qualified Data.List as List
import Data.Monoid ((<>))
-- | Type tag selecting the Householder representation in the 'Matrix' data
-- family.
data Hh
-- | A matrix kept in factored Householder form: the scalar factors @tau@ of
-- the elementary reflectors, stored with a rectangular-diagonal banded shape,
-- next to the packed factorization ('SplitArray').  The constructor's result
-- type is 'HouseholderFlex', i.e. the @xl@ and @xu@ parameters are fixed to
-- @()@.
data instance Matrix Hh xl xu lower upper meas vert horiz height width a where
   Householder ::
      Banded.RectangularDiagonal meas vert horiz height width a ->
      SplitArray meas vert horiz height width a ->
      HouseholderFlex lower upper meas vert horiz height width a
-- Show instance for the factored representation (tau and packed storage).
deriving instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Storable a,
    Show height, Show width, Show a) =>
   Show (Matrix Hh xl xu lower upper meas vert horiz height width a)
-- | Packed factorization storage: a 'Split.Split' array whose layout tag is
-- 'Layout.Reflector'.  It holds the triangular factor together with the
-- reflector vectors.
type SplitArray meas vert horiz height width a =
   Split.Split Layout.Reflector meas vert horiz height width a
-- | Project out the packed factorization, discarding the reflector scalars.
split_ ::
   Matrix Hh xl xu lower upper meas vert horiz height width a ->
   SplitArray meas vert horiz height width a
split_ hh =
   case hh of
      Householder _ packed -> packed
-- | Householder factorization with the extra parameters @xl@ and @xu@
-- fixed to @()@.
type HouseholderFlex = Matrix Hh () ()
-- | Householder factorization with both @lower@ and @upper@ set to
-- 'Layout.Filled'.
type Householder = HouseholderFlex Layout.Filled Layout.Filled
-- | General extent: 'Extent.Size' measure, big in both directions.
type General height width =
   Householder Extent.Size Extent.Big Extent.Big height width
-- | Tall extent: vertical 'Extent.Big', horizontal 'Extent.Small'.
type Tall height width =
   Householder Extent.Size Extent.Big Extent.Small height width
-- | Wide extent: vertical 'Extent.Small', horizontal 'Extent.Big'.
type Wide height width =
   Householder Extent.Size Extent.Small Extent.Big height width
-- | Small in both directions with 'Extent.Size' measure; height and width
-- may be different shape types.
type LiberalSquare height width = SquareMeas Extent.Size height width
-- | Square with 'Extent.Shape' measure: height and width share one shape.
type Square sh = SquareMeas Extent.Shape sh sh
-- | Small in both directions, with the measure left as a parameter.
type SquareMeas meas height width =
   Householder meas Extent.Small Extent.Small height width
-- | Re-type the extent (measure, vertical and horizontal kind) of a
-- factorization.  The same extent map is applied to the reflector scalars
-- and to the packed storage.
mapExtent ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
   (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   ExtentPriv.Map measA vertA horizA measB vertB horizB height width ->
   HouseholderFlex lower upper measA vertA horizA height width a ->
   HouseholderFlex lower upper measB vertB horizB height width a
mapExtent f qr =
   case qr of
      Householder tau split ->
         Householder (Banded.mapExtent f tau) (Split.mapExtent f split)
-- | Replace the height shape of both parts of the factorization using @f@.
mapHeight ::
   (Extent.C vert, Extent.C horiz) =>
   (heightA -> heightB) ->
   HouseholderFlex lower upper Extent.Size vert horiz heightA width a ->
   HouseholderFlex lower upper Extent.Size vert horiz heightB width a
mapHeight f qr =
   case qr of
      Householder tau split ->
         Householder (Banded.mapHeight f tau) (Split.mapHeight f split)
-- | Replace the width shape of both parts of the factorization using @f@.
mapWidth ::
   (Extent.C vert, Extent.C horiz) =>
   (widthA -> widthB) ->
   HouseholderFlex lower upper Extent.Size vert horiz height widthA a ->
   HouseholderFlex lower upper Extent.Size vert horiz height widthB a
mapWidth f qr =
   case qr of
      Householder tau split ->
         Householder (Banded.mapWidth f tau) (Split.mapWidth f split)
-- | Wrap both height and width shapes in 'Unchecked', in the reflector
-- scalars as well as in the packed storage.
uncheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   HouseholderFlex lower upper meas vert horiz height width a ->
   HouseholderFlex lower upper meas vert horiz
      (Unchecked height) (Unchecked width) a
uncheck qr =
   case qr of
      Householder tau split ->
         Householder
            (Banded.mapExtentSizes
               (ExtentPriv.mapWrap Unchecked Unchecked) tau)
            (Split.uncheck split)
-- | Decide whether a factorization is tall or wide and re-type it
-- accordingly, using 'Layout.caseTallWideSplit' on the packed shape.
--
-- The storage of both parts is reused unchanged.  Only the shapes are
-- replaced: the packed array gets the refined split shape, and the banded
-- shape of @tau@ gets that shape's extent.
caseTallWide ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Householder meas vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (Householder tau (Array shape a)) =
   -- 'consHouse' receives tau and the storage as arguments instead of
   -- capturing them, which keeps it a closed binding.  GADTs enables
   -- MonoLocalBinds, under which only closed local bindings are generalised,
   -- and consHouse is needed at both the Tall and the Wide result type.
   let consHouse taub b newShape =
         Householder
            (Array.mapShape
               (\bandShape ->
                  bandShape{Layout.bandedExtent = Layout.splitExtent newShape})
               taub) $
         Array newShape b
   in either (Left . consHouse tau a) (Right . consHouse tau a) $
      Layout.caseTallWideSplit shape
-- | Formats the reflector scalars, with their shape flattened to a zero-based
-- index, combined via '/+/' with the formatted packed factorization.
instance FormatMatrix Hh where
   formatMatrix fmt (Householder tau m) =
      (/+/)
         (formatArray fmt (Array.mapShape (Shape.ZeroBased . Shape.size) tau))
         (formatArray fmt m)
-- | Compute the Householder factorization of a full matrix.
--
-- The input is copied into the packed result and then factored in place by
-- LAPACK.  Column-major storage uses @geqrf@ and row-major storage uses
-- @gelqf@.  The reflector scalars are written into a separate @tau@ array,
-- whose shape is the second component of 'Layout.rectangularDiagonal' of the
-- extent.
fromMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   Householder meas vert horiz height width a
fromMatrix (Array shape@(Layout.Full order extent) a) =
   uncurry Householder $
   Array.unsafeCreateWithSizeAndResult
      (snd $ Layout.rectangularDiagonal extent) $ \_ tauPtr ->
   ArrayIO.unsafeCreate
      (Layout.Split Layout.Reflector order extent) $ \qrPtr ->
   evalContT $ do
      -- (m,n) as returned by Layout.dimensions; m is also used as the
      -- leading dimension, so these are presumably the dimensions of the
      -- storage as seen column-major by LAPACK (TODO confirm for RowMajor).
      let (m,n) = Layout.dimensions shape
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim m
      liftIO $ do
         -- LAPACK overwrites its input, so factor a copy.
         copyBlock (m*n) aPtr qrPtr
         -- Row-major storage read column-major is the transpose of the
         -- matrix, hence an LQ decomposition instead of QR for it.
         case order of
            RowMajor ->
               withAutoWorkspaceInfo errorCodeMsg "gelqf" $
                  LapackGen.gelqf mPtr nPtr qrPtr ldaPtr tauPtr
            ColumnMajor ->
               withAutoWorkspaceInfo errorCodeMsg "geqrf" $
                  LapackGen.geqrf mPtr nPtr qrPtr ldaPtr tauPtr
-- | Determinant of the triangular factor of a factorization with a small
-- horizontal extent.  It is delegated to 'Split.determinantR' on the packed
-- storage.
determinantR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert Extent.Small height width a -> a
determinantR qr = Split.determinantR (split_ qr)
-- | Determinant of a square matrix in Householder form.
--
-- Starts from the determinant of the triangular factor and multiplies, left
-- to right, by one factor per non-zero reflector scalar @t@.  That factor is
-- @negate (signum t ^ 2)@, conjugated for row-major storage.  Zero scalars
-- contribute nothing (such a reflector is the identity).
determinant ::
   (Shape.C sh, Class.Floating a) =>
   HouseholderFlex lower upper Extent.Shape Extent.Small Extent.Small sh sh a ->
   a
determinant (Householder tau split) =
   List.foldl' (*) (Split.determinantR split)
      [ orient (negate (signum t ^ (2::Int)))
      | t <- Array.toList tau, not (isZero t) ]
   where
      -- Row-major storage was factored as LQ of the transpose, so each
      -- reflector factor is conjugated.
      orient =
         case Layout.splitOrder (Array.shape split) of
            RowMajor -> conjugate
            ColumnMajor -> id
-- | Absolute value of the determinant of the triangular factor when the
-- factorization is tall, and zero when it is wide (as decided by
-- 'caseTallWide').
determinantAbsolute ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width,
    Class.Floating a) =>
   Householder meas vert horiz height width a -> RealOf a
determinantAbsolute qr =
   -- Both branches yield an @a@; 'absolute' is applied once outside so that
   -- 'zero' is fixed to the element type.
   absolute $
   case caseTallWide qr of
      Left tall -> determinantR tall
      Right _ -> zero
-- | Solve with a tall factorization: for each right-hand-side column, take
-- the leading rows of Qᴴ·b and solve with the triangular factor R (see
-- 'leastSquaresAux').  This is presumably the least-squares solution.
--
-- 'leastSquaresAux' needs @Eq width@, which this signature does not provide.
-- In the 'ExtentPriv.Square' case height and width coincide.  Otherwise the
-- width is wrapped in 'Unchecked' (presumably supplying the needed
-- instances) and unwrapped again from the result's height.
leastSquares ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas horiz Extent.Small height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
leastSquares qr =
   case Matrix.extent qr of
      ExtentPriv.Square _ -> leastSquaresAux qr
      ExtentPriv.Separate _ _ ->
         Basic.mapHeight deconsUnchecked .
         leastSquaresAux (mapWidth Unchecked qr)
-- | Core of 'leastSquares': apply Qᴴ and keep the leading rows, then solve
-- with the non-transposed, non-conjugated triangular factor.
leastSquaresAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas horiz Extent.Small height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
leastSquaresAux qr b =
   let qhb = tallMultiplyQAdjoint qr b
   in tallSolveR NonTransposed NonConjugated qr qhb
-- | Solve with the adjoint of a tall factorization: @qr@ factors a
-- @width × height@ matrix A, and the result is computed as Q·(Rᴴ⁻¹·b) (see
-- 'minimumNormAux').  This is presumably the minimum-norm solution of
-- Aᴴ·x = b.
--
-- 'minimumNormAux' needs @Eq width@, which this signature does not provide.
-- In the 'ExtentPriv.Square' case the shapes coincide.  Otherwise the
-- factorization's first shape is wrapped in 'Unchecked' and unwrapped again
-- from the result's height.
minimumNorm ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas vert Extent.Small width height a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
minimumNorm qr =
   case Matrix.extent qr of
      ExtentPriv.Square _ -> minimumNormAux qr
      ExtentPriv.Separate _ _ ->
         Basic.mapHeight deconsUnchecked .
         minimumNormAux (mapHeight Unchecked qr)
-- | Core of 'minimumNorm': solve with the adjoint (transposed, conjugated)
-- triangular factor, then pad with zero rows and apply Q.
minimumNormAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   HouseholderFlex lower upper meas vert Extent.Small width height a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz width nrhs a
minimumNormAux qr b =
   tallMultiplyQ qr $ tallSolveR Transposed Conjugated qr b
-- | Keep only the leading rows of a matrix, as given by the height of
-- @extentA@.  Calls 'error' if the heights cannot be fused.
takeRows ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq fuse, Shape.C fuse, Shape.C height, Shape.C width, Class.Floating a) =>
   Extent meas Extent.Small horiz height fuse ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
takeRows extentA (Array (Layout.Full order extentB) b) =
   maybe
      (error "Householder.takeRows: heights mismatch")
      (\extentC ->
         Basic.takeSub (Extent.height extentB) 0 b (Layout.Full order extentC))
      (Extent.fuse (ExtentPriv.weakenWide extentA) extentB)
-- | Embed a matrix into a taller one by appending zero rows below it.  The
-- new height comes from @extentA@; calls 'error' if the heights cannot be
-- fused.
addRows ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq fuse, Shape.C fuse, Shape.C height, Shape.C width, Class.Floating a) =>
   Extent meas vert Extent.Small height fuse ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
addRows extentA (Array shapeB@(Layout.Full order extentB) b) =
   case Extent.fuse (ExtentPriv.weakenTall extentA) extentB of
      Nothing -> error "Householder.addRows: heights mismatch"
      Just extentC ->
         Array.unsafeCreateWithSize (Layout.Full order extentC) $
         \cSize cPtr ->
         withForeignPtr b $ \bPtr ->
         case order of
            -- Rows are contiguous: copy the input to the front and zero-fill
            -- the remaining elements.
            RowMajor -> do
               let bSize = Shape.size shapeB
               copyBlock bSize bPtr cPtr
               fill zero (cSize - bSize) (advancePtr cPtr bSize)
            -- Columns are contiguous: copy the mb×n input into the taller
            -- leading dimension mc, then set rows mb..mc-1 of every column
            -- to zero with laset ('A' = whole block, both values zero).
            ColumnMajor -> do
               let n = Shape.size $ Extent.width extentB
                   mb = Shape.size $ Extent.height extentB
                   mc = Shape.size $ Extent.height extentC
               copySubMatrix mb n mb bPtr mc cPtr
               evalContT $ do
                  uploPtr <- Call.char 'A'
                  mPtr <- Call.cint (mc-mb)
                  nPtr <- Call.cint n
                  ldcPtr <- Call.leadingDim mc
                  zPtr <- Call.number zero
                  liftIO $
                     LapackGen.laset uploPtr mPtr nPtr zPtr zPtr
                        (advancePtr cPtr mb) ldcPtr
-- | Form the full square orthogonal/unitary factor Q, whose side is the
-- factorization's height.
extractQ ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert horiz height width a -> MatrixPriv.Square height a
extractQ qr =
   case qr of
      Householder tau (Array (Layout.Split _ order extent) packed) ->
         extractQAux tau (Extent.width extent) order
            (Extent.square (Extent.height extent)) packed
-- | Form the thin Q factor of a tall factorization, with the same extent as
-- the factored matrix.
tallExtractQ ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert Extent.Small height width a ->
   Full meas vert Extent.Small height width a
tallExtractQ qr =
   case qr of
      Householder tau (Array (Layout.Split _ order extent) packed) ->
         extractQAux tau (Extent.width extent) order extent packed
-- | Form an explicit orthogonal/unitary factor from packed reflectors.
--
-- The result has shape @Layout.Full order extent@.  @k@ is the size of the
-- @tau@ shape.  Only the k reflector columns (column-major) or rows
-- (row-major) are copied into the result buffer; the rest of it is not
-- written here and is left to LAPACK @ungqr@ / @unglq@, which expand the
-- reflectors in place.  @widthQR@ is the width of the packed factorization,
-- used as the source leading dimension in the row-major case.
extractQAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Shape.C height, Shape.C width, Shape.C widthQR,
    Class.Floating a) =>
   Banded.RectangularDiagonal measA vertA horizA height widthQR a ->
   widthQR ->
   Order -> Extent meas vert horiz height width -> ForeignPtr a ->
   Full meas vert horiz height width a
extractQAux (Array widthTau tau) widthQR order extent qr =
   Array.unsafeCreate (Layout.Full order extent) $ \qPtr -> do
      let (height,width) = Extent.dimensions extent
      let m = Shape.size height
      let n = Shape.size width
      let k = Shape.size widthTau
      evalContT $ do
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         kPtr <- Call.cint k
         qrPtr <- ContT $ withForeignPtr qr
         tauPtr <- ContT $ withForeignPtr tau
         case order of
            -- Row-major: the storage is the n×m transpose, so copy k rows
            -- (source leading dimension widthQR, target n) and use unglq
            -- with the dimensions swapped.
            RowMajor -> do
               ldaPtr <- Call.leadingDim n
               liftIO $ do
                  copySubMatrix k m (Shape.size widthQR) qrPtr n qPtr
                  withAutoWorkspaceInfo errorCodeMsg "unglq" $
                     LapackGen.unglq nPtr mPtr kPtr qPtr ldaPtr tauPtr
            -- Column-major: the first k columns are one contiguous block
            -- of m*k elements with leading dimension m.
            ColumnMajor -> do
               ldaPtr <- Call.leadingDim m
               liftIO $ do
                  copyBlock (m*k) qrPtr qPtr
                  withAutoWorkspaceInfo errorCodeMsg "ungqr" $
                     LapackGen.ungqr mPtr nPtr kPtr qPtr ldaPtr tauPtr
-- | Multiply by Q of a tall factorization: pad the input with zero rows up
-- to the factorization's height, then apply Q.
tallMultiplyQ ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   HouseholderFlex lower upper meas vert Extent.Small height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallMultiplyQ qr b =
   let padded = addRows (Matrix.extent qr) b
   in multiplyQ NonTransposed NonConjugated qr padded
-- | Multiply by Qᴴ of a tall factorization and keep only the leading rows
-- (as many as the factorization's width).
tallMultiplyQAdjoint ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse, Class.Floating a) =>
   HouseholderFlex lower upper meas horiz Extent.Small fuse height a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallMultiplyQAdjoint qr b =
   takeRows (Extent.transpose (Matrix.extent qr)) $
   multiplyQ Transposed Conjugated qr b
-- | Apply the orthogonal/unitary factor Q of a Householder factorization,
-- optionally transposed and/or conjugated, to a full matrix B without forming
-- Q explicitly.  The heights of the factorization and of B must agree; this
-- is checked with 'Call.assert'.
--
-- B is copied into the result, which is then updated in place by LAPACK
-- @unmqr@ (column-major factorization) or @unmlq@ (row-major
-- factorization).
multiplyQ ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA, Shape.C widthA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB, Shape.C widthB,
    Shape.C height, Eq height, Class.Floating a) =>
   Transposition -> Conjugation ->
   HouseholderFlex lower upper measA vertA horizA height widthA a ->
   Full measB vertB horizB height widthB a ->
   Full measB vertB horizB height widthB a
multiplyQ transposed conjugated
      (Householder
         (Array widthTau tau)
         (Array shapeA@(Layout.Split _ orderA extentA) qr))
      (Array shapeB@(Layout.Full orderB extentB) b) =
   Array.unsafeCreateWithSize shapeB $ \cSize cPtr -> do
      let (heightA,widthA) = Extent.dimensions extentA
      let (height,width) = Extent.dimensions extentB
      Call.assert "Householder.multiplyQ: height shapes mismatch"
         (heightA == height)
      -- LAPACK side and (m,n) depend on B's storage order.
      let (side,(m,n)) =
            sideSwapFromOrder orderB (Shape.size height, Shape.size width)
      evalContT $ do
         sidePtr <- Call.char side
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         let k = Shape.size widthTau
         kPtr <- Call.cint k
         -- The requested transposition is combined with another one when
         -- the storage orders differ (presumably '<>' composes
         -- transpositions).  'adjointFromTranspose' maps it to 'N', or to
         -- 'T' / 'C' depending on the element type.
         transPtr <-
            Call.char $ adjointFromTranspose qr $
            transposed <> if orderA==orderB then NonTransposed else Transposed
         -- Reflectors and tau are used as stored exactly when "orders agree"
         -- coincides with "plain or adjoint operation requested"
         -- (NonTransposed/NonConjugated or Transposed/Conjugated).
         -- Otherwise conjugated temporary copies are used.
         (qrPtr,tauPtr) <-
            if (orderA==orderB)
               ==
               (transposed==NonTransposed && conjugated==NonConjugated
                ||
                transposed==Transposed && conjugated==Conjugated)
            then
               liftA2 (,)
                  (ContT $ withForeignPtr qr)
                  (ContT $ withForeignPtr tau)
            else
               liftA2 (,)
                  (conjugateToTemp (Shape.size shapeA) qr)
                  (conjugateToTemp k tau)
         bPtr <- ContT $ withForeignPtr b
         ldcPtr <- Call.leadingDim m
         liftIO $ copyBlock cSize bPtr cPtr
         case orderA of
            ColumnMajor -> do
               ldaPtr <- Call.leadingDim $ Shape.size heightA
               liftIO $ withAutoWorkspaceInfo errorCodeMsg "unmqr" $
                  LapackGen.unmqr sidePtr transPtr
                     mPtr nPtr kPtr qrPtr ldaPtr tauPtr cPtr ldcPtr
            -- unmlq is skipped when there are no reflectors (k == 0); the
            -- copied input is then already the result.  Unlike unmqr above,
            -- it is guarded.  NOTE(review): presumably a work-around for a
            -- LAPACK problem with unmlq and k = 0 - confirm.
            RowMajor -> do
               ldaPtr <- Call.leadingDim $ Shape.size widthA
               liftIO $ when (k>0) $
                  withAutoWorkspaceInfo errorCodeMsg "unmlq" $
                  LapackGen.unmlq sidePtr transPtr
                     mPtr nPtr kPtr qrPtr ldaPtr tauPtr cPtr ldcPtr
-- | LAPACK @trans@ character for a 'Transposition': @'N'@ when not
-- transposed, otherwise the character chosen by 'invChar'.
adjointFromTranspose :: (Class.Floating a) => f a -> Transposition -> Char
adjointFromTranspose qr trans =
   case trans of
      NonTransposed -> 'N'
      Transposed -> invChar qr
-- | @'T'@ in the real case and @'C'@ in the complex case, following the
-- argument order of 'caseRealComplexFunc'.  The argument presumably only
-- fixes the element type.
invChar :: (Class.Floating a) => f a -> Char
invChar proxy = caseRealComplexFunc proxy 'T' 'C'
-- | Extract the triangular part of the packed factorization as a full
-- matrix of the same extent.
extractR ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   HouseholderFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
extractR qr = Split.extractTriangle (Right Layout.Triangle) (split_ qr)
-- | Extract the square upper-triangular factor of a tall factorization.
tallExtractR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Householder meas vert Extent.Small height width a -> Upper width a
tallExtractR qr = Split.tallExtractR (split_ qr)
-- | Multiply by the triangular factor of a tall factorization, optionally
-- transposed.  It is delegated to 'Split.tallMultiplyR'.
tallMultiplyR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height,
    Shape.C heightA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   HouseholderFlex lower upper measA vertA Extent.Small heightA height a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
tallMultiplyR transposed qr b = Split.tallMultiplyR transposed (split_ qr) b
-- | Solve with the triangular factor of a tall factorization, optionally
-- transposed and/or conjugated.  It is delegated to 'Split.tallSolveR'.
tallSolveR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   HouseholderFlex lower upper measA vertA Extent.Small height width a ->
   Full meas vert horiz width nrhs a -> Full meas vert horiz width nrhs a
tallSolveR transposed conjugated qr b =
   Split.tallSolveR transposed conjugated (split_ qr) b
-- | The extent of a factorization is the extent of its packed split shape.
instance Matrix.Box Hh where
   extent qr = Layout.splitExtent (Array.shape (split_ qr))
-- | Quadratic views restricted to the height, or to the width.  The packed
-- part is converted by the corresponding 'Split' function.  @tau@ gets a
-- column-major diagonal layout over the chosen dimension.
instance Matrix.ToQuadratic Hh where
   heightToQuadratic qr =
      case qr of
         Householder tau split ->
            Householder
               (Array.mapShape
                  (\sh -> layoutTauSquare $ Layout.bandedHeight sh) tau)
               (Split.heightToQuadratic split)
   widthToQuadratic qr =
      case qr of
         Householder tau split ->
            Householder
               (Array.mapShape
                  (\sh -> layoutTauSquare $ Layout.bandedWidth sh) tau)
               (Split.widthToQuadratic split)
-- | Column-major diagonal layout over a single shape, used for @tau@ in the
-- quadratic views.
layoutTauSquare :: sh -> Layout.Diagonal sh
layoutTauSquare sh = LayoutPub.diagonal Layout.ColumnMajor sh
-- | Extent mapping for factorizations.
instance (xl ~ (), xu ~ ()) => Matrix.MapExtent Hh xl xu lower upper where
   -- The unqualified 'mapExtent' on the right is this module's top-level
   -- function (the class is only imported qualified).  The strict extent
   -- map is first converted with 'ExtentStrict.apply'.
   mapExtent f = mapExtent (ExtentStrict.apply f)
-- | Matrix-vector products using the factors directly, without
-- reconstructing the matrix.
instance (xl ~ (), xu ~ ()) => Multiply.MultiplyVector Hh xl xu where
   -- Apply the triangular part ('extractR') to x, then Q via 'multiplyQ'.
   -- 'uncheck' and the 'Unchecked' / 'deconsUnchecked' wrapping let
   -- multiplyQ work on Unchecked shapes (presumably to satisfy its
   -- @Eq height@ constraint).
   matrixVector qr x =
      Array.mapShape deconsUnchecked $
      Basic.unliftColumn Layout.ColumnMajor
         (multiplyQ NonTransposed NonConjugated $ uncheck qr) $
      Array.mapShape Unchecked $
      Basic.multiplyVector (extractR qr) x
   -- Apply Q transposed (not conjugated) to x, then the transposed
   -- triangular part.
   vectorMatrix x qr =
      Basic.multiplyVector (Basic.transpose $ extractR qr) $
      Basic.unliftColumn Layout.ColumnMajor
         (multiplyQ Transposed NonConjugated qr) x
-- | Products of a square factorization with full matrices.
instance (xl ~ (), xu ~ ()) => Multiply.MultiplySquare Hh xl xu where
   -- A·B: multiply by the triangular factor, then by Q.
   squareFull qr =
      ArrMatrix.lift1 $
      multiplyQ NonTransposed NonConjugated qr .
      tallMultiplyR NonTransposed qr
   -- B·A via transposition: transpose B, apply Qᵀ (not conjugated), then
   -- the transposed triangular factor, and transpose back.
   fullSquare = flip $ \qr ->
      ArrMatrix.lift1 $
      Basic.transpose .
      tallMultiplyR Transposed qr .
      multiplyQ Transposed NonConjugated qr .
      Basic.transpose
-- | Determinant of a square factorization.
instance (xl ~ (), xu ~ ()) => Divide.Determinant Hh xl xu where
   -- The right-hand 'determinant' is this module's top-level function (the
   -- class is only imported qualified); this is not a recursive call.
   determinant = determinant
-- | Linear solvers for square factorizations.
instance (xl ~ (), xu ~ ()) => Divide.Solve Hh xl xu where
   -- Solve A·X = B with 'leastSquares', after relaxing the square extent
   -- via 'ExtentPriv.fromSquare'.
   solveRight = ArrMatrix.lift1 . leastSquares . mapExtent ExtentPriv.fromSquare
   -- Solve X·A = B through adjoints: take Bᴴ, solve with Aᴴ using
   -- 'minimumNorm' (which applies the adjoint triangular solve and then Q),
   -- and take the adjoint of the result.
   solveLeft =
      flip $ \a -> ArrMatrix.lift1 $
         Basic.adjoint .
         minimumNorm (mapExtent ExtentPriv.fromSquare a) .
         Basic.adjoint