{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} module Numeric.Matrix.SVD ( MatrixSVD (..), SVD (..) , svd1, svd2, svd3, svd3q ) where import Control.Monad import Control.Monad.ST import Data.Kind import Numeric.Basics import Numeric.DataFrame.Internal.PrimArray import Numeric.DataFrame.ST import Numeric.DataFrame.SubSpace import Numeric.DataFrame.Type import Numeric.Dimensions import Numeric.Matrix.Bidiagonal import Numeric.Matrix.Internal import Numeric.Quaternion.Internal import Numeric.Scalar.Internal import Numeric.Subroutine.Sort import Numeric.Tuple import Numeric.Vector.Internal -- | Result of SVD factorization -- @ M = svdU %* asDiag svdS %* transpose svdV @. -- -- Invariants: -- -- * Singular values `svdS` are in non-increasing order and are non-negative. -- * svdU and svdV are orthogonal matrices -- * det svdU == 1 -- -- NB: data SVD (t :: Type) (n :: Nat) (m :: Nat) = SVD { svdU :: Matrix t n n -- ^ Left-singular basis matrix , svdS :: Vector t (Min n m) -- ^ Vector of singular values , svdV :: Matrix t m m -- ^ Right-singular basis matrix } deriving instance ( Show t, PrimBytes t , KnownDim n, KnownDim m, KnownDim (Min n m)) => Show (SVD t n m) deriving instance ( Eq (Matrix t n n) , Eq (Matrix t m m) , Eq (Vector t (Min n m))) => Eq (SVD t n m) class RealFloatExtras t => MatrixSVD (t :: Type) (n :: Nat) (m :: Nat) where -- | Compute SVD factorization of a matrix svd :: IterativeMethod => Matrix t n m -> SVD t n m -- | Obvious dummy implementation of SVD for 1x1 matrices svd1 :: (PrimBytes t, Num t, Eq t) => Matrix t 1 1 -> SVD t 1 1 svd1 m = SVD { svdU = 1 , svdS = broadcast $ abs x , svdV = broadcast $ if x == 0 then 1 else signum x } where x = ixOff 0 m -- | SVD of a 2x2 matrix can be computed analytically -- -- Related discussion: -- -- https://scicomp.stackexchange.com/questions/8899/robust-algorithm-for-2-times-2-svd/ -- -- https://ieeexplore.ieee.org/document/486688 svd2 :: forall t . RealFloatExtras t => Matrix t 2 2 -> SVD t 2 2 svd2 (DF2 (DF2 m00 m01) (DF2 m10 m11)) = SVD { svdU = DF2 (DF2 uc us) (DF2 (negate us) uc) , svdS = DF2 sigma1 (abs sigma2) , svdV = DF2 (DF2 vc (sg2 (negate vs))) (DF2 vs (sg2 vc )) } where x1 = m00 - m11 -- 2F x2 = m00 + m11 -- 2E y1 = m01 + m10 -- 2G y2 = m01 - m10 -- 2H yy = y1*y2 h1 = hypot x1 y1 -- h1 >= abs x1 h2 = hypot x2 y2 -- h2 >= abs x2 sigma1 = 0.5 * (h2 + h1) -- sigma1 >= abs sigma2 sigma2 = 0.5 * (h2 - h1) -- can be negative, which is accounted by sg2 sg2 = negateUnless (sigma2 >= 0) hx1 = h1 + x1 hx2 = h2 + x2 hxhx = hx1*hx2 hxy = hx1*y2 yhx = y1*hx2 (uc', us', vc', vs') = case (x1 > 0 || y1 /= 0, x2 > 0 || y2 /= 0) of (True , True ) -> (hxhx + yy, hxy - yhx, hxhx - yy, hxy + yhx) (True , False) -> (y1, hx1, -y1, hx1) (False, True ) -> (y2, -hx2, -y2, hx2) (False, False) -> (1, 0, -1, 0) ru = recip $ hypot uc' us' rv = recip $ hypot vc' vs' uc = uc' * ru us = us' * ru vc = vc' * rv vs = vs' * rv -- | Get SVD decomposition of a 3x3 matrix using `svd3q` function. -- -- This function reorders the singular components under the hood to make sure -- @s1 >= s2 >= s3 >= 0@. -- Thus, it has some overhead on top of `svd3q`. svd3 :: forall t . (Quaternion t, RealFloatExtras t) => Matrix t 3 3 -> SVD t 3 3 svd3 m = SVD { svdU = toMatrix33 u , svdS = DF3 s1 s2 s3' , svdV = neg3If (s3 < 0) (toMatrix33 v) } where (u, (DF3 s1 s2 s3), v) = svd3q m s3' = abs s3 neg3If :: Bool -> Matrix t 3 3 -> Matrix t 3 3 neg3If False = id neg3If True = ewmap @t @'[3] neg3 neg3 :: Vector t 3 -> Vector t 3 neg3 (DF3 a b c) = DF3 a b (negate c) -- | Get SVD decomposition of a 3x3 matrix, with orthogonal matrices U and V -- represented as quaternions. -- Important: U and V are bound to be rotations at the expense of the last -- singular value being possibly negative. -- -- This is an adoptation of a specialized 3x3 SVD algorithm described in -- "Computing the Singular Value Decomposition of 3x3 matrices -- with minimal branching and elementary floating point operations", -- by A. McAdams, A. Selle, R. Tamstorf, J. Teran, E. Sifakis. -- -- http://pages.cs.wisc.edu/~sifakis/papers/SVD_TR1690.pdf svd3q :: forall t . (Quaternion t, RealFloatExtras t) => Matrix t 3 3 -> (Quater t, Vector t 3, Quater t) svd3q m = (u, s, v) where v = jacobiEigenQ (transpose m %* m) (s, u) = uncurry fixSigns $ qrDecomposition3 (m %* toMatrix33 v) -- last bit: make sure s1 >= s2 >= 0 fixSigns :: Vector t 3 -> Quater t -> (Vector t 3, Quater t) fixSigns (DF3 s1 s2 s3) q@(Quater a b c d) = case (s1 >= 0, s2 >= 0) of (True , True ) -> (mk3 s1 s2 s3, q) (False, True ) -> (mk3 (negate s1) s2 (negate s3), Quater (-c) d a (-b)) (True , False) -> (mk3 s1 (negate s2) (negate s3), Quater d c (-b) (-a)) (False, False) -> (mk3 (negate s1) (negate s2) s3, Quater b (-a) d (-c)) -- one more thing: -- the singular values are ordered, but may have small errors; -- as a result adjacent values may seem to be out of order by a very small number mk3 :: Scalar t -> Scalar t -> Scalar t -> Vector t 3 mk3 s1 s2 s3' = case (s1 >= s2, s1 >= abs s3, s2 >= abs s3) of (True , True , True ) -> DF3 s1 s2 s3' -- s1 >= s2 >= s3 (True , True , False) -> DF3 s1 s3 (cs s2) -- s1 >= s3 > s2 (True , False, _ ) -> DF3 s3 s1 (cs s2) -- s3 > s1 >= s2 (False, True , True ) -> DF3 s2 s1 s3' -- s2 > s1 >= s3 (False, _ , False) -> DF3 s3 s2 (cs s1) -- s3 > s2 > s1 (False, False, True ) -> DF3 s2 s3 (cs s1) -- s2 >= s3 > s1 where s3 = abs s3' cs = negateUnless (s3' >= 0) -- | Approximate values for cos (a/2) and sin (a/2) of a Givens rotation for -- a 2x2 symmetric matrix. (Algorithm 2) jacobiGivensQ :: forall t . RealFloatExtras t => t -> t -> t -> (t, t) jacobiGivensQ aii aij ajj | g*sh*sh < ch*ch = (w * ch, w * sh) | otherwise = (c', s') where ch = 2 * (aii-ajj) sh = aij w = recip $ hypot ch sh g = 5.82842712474619 :: t -- 3 + sqrt 8 c' = 0.9238795325112867 :: t -- cos (pi/8) s' = 0.3826834323650898 :: t -- sin (pi/8) -- | A quaternion for a QR Givens iteration qrGivensQ :: forall t . RealFloatExtras t => t -> t -> (t, t) qrGivensQ a1 a2 | a1 < 0 = (sh * w, ch * w) | otherwise = (ch * w, sh * w) where rho2 = a1*a1 + a2*a2 sh = if rho2 > M_EPS then a2 else 0 ch = abs a1 + sqrt (max rho2 M_EPS) w = recip $ hypot ch sh -- TODO: consider something like a hypot -- | One iteration of the Jacobi algorithm on a symmetric 3x3 matrix -- -- The three words arguments are indices: -- 0 <= i /= j /= k <= 2 jacobiEigen3Iteration :: (Quaternion t, RealFloatExtras t) => Int -> Int -> Int -> STDataFrame s t '[3,3] -> ST s (Quater t) jacobiEigen3Iteration i j k sPtr = do sii <- readDataFrameOff sPtr ii sij <- readDataFrameOff sPtr ij sjj <- readDataFrameOff sPtr jj sik <- readDataFrameOff sPtr ik sjk <- readDataFrameOff sPtr jk -- Coefficients for a quaternion corresponding to a Givens rotation let (ch, sh) = jacobiGivensQ sii sij sjj a = ch*ch - sh*sh b = 2 * sh*ch aa = a * a ab = a * b bb = b * b -- update the matrix writeDataFrameOff sPtr ii $ aa * sii + 2 * ab * sij + bb * sjj writeDataFrameOff sPtr ij $ ab * (sjj - sii) + (aa - bb) * sij writeDataFrameOff sPtr jj $ bb * sii - 2 * ab * sij + aa * sjj writeDataFrameOff sPtr ik $ a * sik + b * sjk writeDataFrameOff sPtr jk $ a * sjk - b * sik -- write the quaternion qPtr <- unsafeThawDataFrame 0 writeDataFrameOff qPtr k (negate sh) writeDataFrameOff qPtr 3 ch fromVec4 <$> unsafeFreezeDataFrame qPtr where ii = i*3 + i ij = if i < j then i*3 + j else j*3 + i jj = j*3 + j ik = if i < k then i*3 + k else k*3 + i jk = if j < k then j*3 + k else k*3 + j -- | Total number of the Givens rotations during the Jacobi eigendecomposition -- part of the 3x3 SVD equals eigenItersX3*3. -- Value `eigenItersX3 = 6` corresponds to 18 iterations and gives a good precision. eigenItersX3 :: Int eigenItersX3 = 12 -- | Run a few iterations of the Jacobi algorithm on a real-valued 3x3 symmetric matrix. -- The eigenvectors basis of such matrix is orthogonal, and can be represented as -- a quaternion. jacobiEigenQ :: forall t . (Quaternion t, RealFloatExtras t) => Matrix t 3 3 -> Quater t jacobiEigenQ m = runST $ do mPtr <- thawDataFrame m q <- go eigenItersX3 mPtr 1 s1 <- readDataFrameOff mPtr 0 s2 <- readDataFrameOff mPtr 4 s3 <- readDataFrameOff mPtr 8 return $ sortQ s1 s2 s3 * q where go :: Int -> STDataFrame s t '[3,3] -> Quater t -> ST s (Quater t) go 0 _ q = pure q -- -- primitive cyclic iteration; -- -- fast, but the convergence is not perfect -- -- -- -- set eigenItersX3 = 6 for good precision -- go n p q = do -- q1 <- jacobiEigen3Iteration 0 1 2 p -- q2 <- jacobiEigen3Iteration 1 2 0 p -- q3 <- jacobiEigen3Iteration 2 0 1 p -- go (n - 1) p (q3 * q2 * q1 * q) -- Pick the largest element on lower triangle; -- slow because of branching, but has a better convergence -- -- set eigenItersX3 = 12 for good precision -- (slightly faster than the cyclic version with -O0) go n p q = do a10 <- abs <$> readDataFrameOff p 1 a20 <- abs <$> readDataFrameOff p 2 a21 <- abs <$> readDataFrameOff p 5 q' <- jiter n p a10 a20 a21 go (n - 1) p (q' * q) jiter :: Int -> STDataFrame s t '[3,3] -> Scalar t -> Scalar t -> Scalar t -> ST s (Quater t) jiter n p a10 a20 a21 | gt2 a10 a20 a21 = jacobiEigen3Iteration 0 1 2 p | gt2 a20 a10 a21 = jacobiEigen3Iteration 2 0 1 p | gt2 a21 a10 a20 = jacobiEigen3Iteration 1 2 0 p | otherwise = case mod n 3 of 0 -> jacobiEigen3Iteration 0 1 2 p 1 -> jacobiEigen3Iteration 2 0 1 p _ -> jacobiEigen3Iteration 1 2 0 p gt2 :: Scalar t -> Scalar t -> Scalar t -> Bool gt2 a b c = case compare a b of GT -> a >= c EQ -> a > c LT -> False -- Make such a quaternion that rotates the matrix so that: -- abs s1 >= abs s2 >= abs s3 -- Note, the corresponding singular values may be negative, which must be -- taken into account later. sortQ :: Scalar t -> Scalar t -> Scalar t -> Quater t sortQ s1 s2 s3 = sortQ' (s1 >= s2) (s1 >= s3) (s2 >= s3) sortQ' :: Bool -> Bool -> Bool -> Quater t sortQ' True True True = Quater 0 0 0 1 -- s1 >= s2 >= s3 sortQ' True True False = Quater M_SQRT1_2 0 0 (-M_SQRT1_2) -- s1 >= s3 > s2 sortQ' True False _ = Quater 0.5 0.5 0.5 0.5 -- s3 > s1 >= s2 sortQ' False True True = Quater 0 0 M_SQRT1_2 (-M_SQRT1_2) -- s2 > s1 >= s3 sortQ' False _ False = Quater 0 M_SQRT1_2 0 (-M_SQRT1_2) -- s3 > s2 > s1 sortQ' False False True = Quater 0.5 0.5 0.5 (-0.5) -- s2 >= s3 > s1 -- | One Givens rotation for a QR algorithm on a 3x3 matrix -- -- The three words arguments are indices: -- 0 <= i /= j /= k <= 2 -- -- if i < j then the eigen values are already sorted! qrDecomp3Iteration :: (Quaternion t, RealFloatExtras t) => Int -> Int -> Int -> STDataFrame s t '[3,3] -> ST s (Quater t) qrDecomp3Iteration i j k sPtr = do sii <- readDataFrameOff sPtr ii sij <- readDataFrameOff sPtr ij sji <- readDataFrameOff sPtr ji sjj <- readDataFrameOff sPtr jj sik <- readDataFrameOff sPtr ik sjk <- readDataFrameOff sPtr jk -- Coefficients for a quaternion corresponding to a Givens rotation let (ch, sh) = qrGivensQ sii sji a = ch*ch - sh*sh b = 2 * sh*ch -- update the matrix writeDataFrameOff sPtr ii $ a * sii + b * sji writeDataFrameOff sPtr ij $ a * sij + b * sjj writeDataFrameOff sPtr ik $ a * sik + b * sjk writeDataFrameOff sPtr ji 0 -- a * sji - b * sii writeDataFrameOff sPtr jj $ a * sjj - b * sij writeDataFrameOff sPtr jk $ a * sjk - b * sik -- write the quaternion qPtr <- unsafeThawDataFrame 0 writeDataFrameOff qPtr k (negateUnless leftTriple sh) writeDataFrameOff qPtr 3 ch fromVec4 <$> unsafeFreezeDataFrame qPtr where leftTriple = (j - i) /= 1 && (k - j) /= 1 i3 = i*3 j3 = j*3 ii = i3 + i ij = i3 + j ik = i3 + k ji = j3 + i jj = j3 + j jk = j3 + k -- | Run QR decomposition in context of 3x3 svd: AV = US = QR -- The input here is matrix AV -- The R upper-triangular matrix here is in fact a diagonal matrix Sigma; -- The Q orthogonal matrix is matrix U in the svd decomposition, -- represented here as a quaternion. qrDecomposition3 :: (Quaternion t, RealFloatExtras t) => Matrix t 3 3 -> (Vector t 3, Quater t) qrDecomposition3 m = runST $ do mPtr <- thawDataFrame m q1 <- qrDecomp3Iteration 0 1 2 mPtr q2 <- qrDecomp3Iteration 0 2 1 mPtr q3 <- qrDecomp3Iteration 1 2 0 mPtr sig0 <- readDataFrameOff mPtr 0 sig1 <- readDataFrameOff mPtr 4 sig2 <- readDataFrameOff mPtr 8 return (DF3 sig0 sig1 sig2, q3 * q2 * q1) instance RealFloatExtras t => MatrixSVD t 1 1 where svd = svd1 instance RealFloatExtras t => MatrixSVD t 2 2 where svd = svd2 instance (RealFloatExtras t, Quaternion t) => MatrixSVD t 3 3 where svd = svd3 instance {-# INCOHERENT #-} ( RealFloatExtras t, KnownDim n, KnownDim m) => MatrixSVD t n m where svd a = runST $ do D <- pure dnm Dict <- pure $ minIsSmaller dn dm -- GHC is not convinced :( alphas <- unsafeThawDataFrame bdAlpha betas <- unsafeThawDataFrame bdBeta uPtr <- unsafeThawDataFrame bdU vPtr <- unsafeThawDataFrame bdV -- remove last beta if m > n bLast <- readDataFrameOff betas nm1 when (abs bLast > M_EPS) $ svdGolubKahanZeroCol alphas betas vPtr nm1 -- main routine for a bidiagonal matrix let maxIter = 3*nm -- number of tries withinIters <- svdBidiagonalInplace alphas betas uPtr vPtr nm maxIter unless withinIters . tooManyIterations $ "SVD - Givens rotation sweeps for a bidiagonal matrix (" ++ show maxIter ++ " sweeps max)." -- sort singular values sUnsorted <- unsafeFreezeDataFrame alphas let sSorted :: Vector (Tuple '[t, Word]) (Min n m) sSorted = sortBy (\(S (x :! _)) (S (y :! _)) -> compare y x) $ iwmap @_ @_ @'[] (\(Idx i :* U) (S x) -> S (abs x :! i :! U) ) sUnsorted svdS = ewmap @t @_ @'[] (\(S (x :! _)) -> S x) sSorted perm = ewmap @Word @_ @'[] (\(S (_ :! i :! U)) -> S i) sSorted pCount = if nm < 2 then 0 :: Word else foldl (\s (i, j) -> if perm!i > perm!j then succ s else s) 0 [(i, j) | i <- [0..nm2w], j <- [i+1..nm2w+1]] pPositive = even pCount -- alphas and svdS are now out of sync, but that is not a problem -- make sure det U == 1 when ((bdUDet < 0) == pPositive) $ do readDataFrameOff alphas 0 >>= writeDataFrameOff alphas 0 . negate forM_ [0..n - 1] $ \i -> readDataFrameOff uPtr (i*n) >>= writeDataFrameOff uPtr (i*n) . negate -- negate negative singular values forM_ [0..nm1] $ \i -> do s <- readDataFrameOff alphas i when (s < 0) $ do writeDataFrameOff alphas i $ negate s forM_ [0..m - 1] $ \j -> readDataFrameOff vPtr (j*m + i) >>= writeDataFrameOff vPtr (j*m + i) . negate -- apply permutations if necessary if pCount == 0 then do svdU <- unsafeFreezeDataFrame uPtr svdV <- unsafeFreezeDataFrame vPtr return SVD {..} else do svdU' <- unsafeFreezeDataFrame uPtr svdV' <- unsafeFreezeDataFrame vPtr let svdU = iwgen @_ @_ @'[] $ \(i :* Idx j :* U) -> if j >= dimVal dnm then index (i :* Idx j :* U) svdU' else index (i :* Idx (unScalar $ perm!j) :* U) svdU' svdV = iwgen @_ @_ @'[] $ \(i :* Idx j :* U) -> if j >= dimVal dnm then index (i :* Idx j :* U) svdV' else index (i :* Idx (unScalar $ perm!j) :* U) svdV' return SVD {..} where n = fromIntegral $ dimVal dn :: Int m = fromIntegral $ dimVal dm :: Int dn = dim @n dm = dim @m dnm = minDim dn dm nm1 = nm - 1 nm = fromIntegral (dimVal dnm) :: Int nm2w = fromIntegral (max (nm - 2) 0) :: Word -- compute the bidiagonal form b first, solve svd for b. BiDiag {..} = bidiagonalHouseholder a {- Compute svd for a square bidiagonal matrix inplace \( B = U S V^\intercal \) @ B = | a1 b1 0 ... 0 | | 0 a2 b2 0 ... 0 | | 0 0 a3 b3 ... 0 | | ................. | | 0 0 ... an1 bn1 | | 0 0 ... 0 an | bn? (in case if n > m) @ -} svdBidiagonalInplace :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (nm :: Nat) . ( IterativeMethod, RealFloatExtras t , KnownDim n, KnownDim m, KnownDim nm, nm ~ Min n m) => STDataFrame s t '[nm] -- ^ the main diagonal of B and then the singular values. -> STDataFrame s t '[nm] -- ^ first upper diagonal of B -> STDataFrame s t '[n,n] -- ^ U -> STDataFrame s t '[m,m] -- ^ V -> Int -- ^ 0 < q <= nm -- size of a reduced matrix, such that leftover is diagonal -> Int -- iters -> ST s Bool -- whether the algorithm succeeds within the maxIter svdBidiagonalInplace _ _ _ _ 0 _ = pure True svdBidiagonalInplace _ _ _ _ 1 _ = pure True svdBidiagonalInplace _ _ _ _ _ 0 = pure False svdBidiagonalInplace aPtr bPtr uPtr vPtr q' iter = do Dict <- pure $ minIsSmaller (dim @n) (dim @m) (p, q) <- findCounters q' if (q /= 0) then do findZeroDiagonal p q >>= \case Just k | k == q-1 -> svdGolubKahanZeroCol aPtr bPtr vPtr (k-1) | otherwise -> svdGolubKahanZeroRow aPtr bPtr uPtr k Nothing -> svdGolubKahanStep aPtr bPtr uPtr vPtr p q svdBidiagonalInplace aPtr bPtr uPtr vPtr q (iter - 1) else return True where -- nm = fromIntegral $ dimVal' @nm :: Int -- Check if off-diagonal elements are close to zero and nullify them -- if they are small along the way. -- And find such indices p and q that satisfy condition in alg. 8.6.2 -- on p. 492. of "Matrix Computations " (4-th edition). -- Except these are inverted: -- p -- is the starting index of B22 (last submatrix with non-zero superdiagonal) -- q -- is the starting index of B33 (diagonal submatrix) -- -- that is, p and q determine the index and the size of next work piece. findCounters :: Int -> ST s (Int, Int) findCounters = goQ where checkEps :: Int -> ST s Bool checkEps k = do b <- abs <$> readDataFrameOff bPtr (k-1) if b == 0 then return True else do a1 <- abs <$> readDataFrameOff aPtr (k-1) a2 <- abs <$> readDataFrameOff aPtr k if b <= M_EPS * (max (a1 + a2) 1) then True <$ writeDataFrameOff bPtr (k-1) 0 else return False goQ :: Int -> ST s (Int, Int) goQ 0 = pure (0, 0) -- guard against calling with q == 0 goQ 1 = pure (0, 0) -- 1x1 matrix is always diagonal goQ k = checkEps (k-1) >>= \case True -> goQ (k-1) False -> flip (,) k <$> goP (k-2) goP :: Int -> ST s Int goP 0 = pure 0 goP k = checkEps k >>= \case True -> return k False -> goP (k-1) -- For indices p and q (p < q), find the biggest index (< q) such that -- a[p] == 0 findZeroDiagonal :: Int -> Int -> ST s (Maybe Int) findZeroDiagonal p q | k < p = pure Nothing | otherwise = do ak <- readDataFrameOff aPtr k if ak == 0 then pure $ Just k else if abs ak <= M_EPS then Just k <$ writeDataFrameOff aPtr k 0 else findZeroDiagonal p k where k = q - 1 -- | Apply a series of column transformations to make b[k] (and whole column k+1) zero -- (page 491, 1st paragraph) when a[k+1] == 0. -- To make this element zero, I apply a series of Givens transforms on columns -- (multiply on the right). -- -- Prerequisites: -- * a[k+1] == 0 -- * 0 <= k < min n (m-1) -- Invariants: -- * matrix \(B :: n \times m \) is bidiagonal, represented by two diagonals; -- * matrix V is orthogonal -- * \( B = A V^\intercal \), where \(A :: n \times m\) is an implicit original matrix -- Results: -- * Same bidiagonal matrix with b[k] == 0; i.e. (k+1)-th column is zero. -- * matrix V is updated (multiplied on the right) -- -- NB: All changes are made inplace. -- svdGolubKahanZeroCol :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) . (RealFloatExtras t, KnownDim n, KnownDim m, n <= m) => STDataFrame s t '[n] -- ^ the main diagonal of \(B\) -> STDataFrame s t '[n] -- ^ first upper diagonal of \(B\) -> STDataFrame s t '[m,m] -- ^ \(V\) -> Int -- ^ 0 <= k < min (n+1) m -> ST s () svdGolubKahanZeroCol aPtr bPtr vPtr k | k < 0 || k >= lim = error $ unwords [ "svdGolubKahanZeroCol: k =", show k , "is outside of a valid range 0 <= k <", show lim] -- this trick is to convince GHC that constraint (n <= m) is not redundant | Dict <- Dict @(n <= m) = do b <- readDataFrameOff bPtr k writeDataFrameOff bPtr k 0 foldM_ goGivens b [k, k-1 .. 0] where n = fromIntegral $ dimVal' @n :: Int m = fromIntegral $ dimVal' @m :: Int lim = min n (m-1) goGivens :: Scalar t -> Int -> ST s (Scalar t) goGivens 0 _ = return 0 -- non-diagonal element is nullified prematurely goGivens b i = do ai <- readDataFrameOff aPtr i let rab = recip $ hypot b ai c = ai*rab s = b *rab updateGivensMat vPtr i (k+1) c s writeDataFrameOff aPtr i $ ai*c + b*s -- B[i,i] if i == 0 then return 0 else do bi1 <- readDataFrameOff bPtr (i - 1) -- B[i,i-1] writeDataFrameOff bPtr (i - 1) $ bi1 * c return $ negate (bi1 * s) -- | Apply a series of row transformations to make b[k] (and whole column k) zero -- (page 490, last paragraph) when a[k] == 0. -- To make this element zero, I apply a series of Givens transforms on rows -- (multiply on the left). -- -- Prerequisites: -- * a[k] == 0 -- * 0 <= k < n - 1 -- Invariants: -- * matrix \(B :: m \times n \) is bidiagonal, represented by two diagonals; -- * matrix U is orthogonal -- * \( B = U A \), where \(A :: m \times n \) is an implicit original matrix -- Results: -- * Same bidiagonal matrix with b[k] == 0; i.e. k-th column is zero. -- * matrix U is updated (multiplied on the right) -- -- NB: All changes are made inplace. -- svdGolubKahanZeroRow :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) . (RealFloatExtras t, KnownDim n, KnownDim m, n <= m) => STDataFrame s t '[n] -- ^ the main diagonal of B -> STDataFrame s t '[n] -- ^ first upper diagonal of B -> STDataFrame s t '[m,m] -- ^ U -> Int -- ^ 0 <= k < n - 1 -> ST s () svdGolubKahanZeroRow aPtr bPtr uPtr k | k < 0 || k >= n1 = error $ unwords [ "svdGolubKahanZeroRow: k =", show k , "is outside of a valid range 0 <= k <", show n1] -- this trick is to convince GHC that constraint (n <= m) is not redundant | Dict <- Dict @(n <= m) = do b <- readDataFrameOff bPtr k writeDataFrameOff bPtr k 0 foldM_ goGivens b [k+1..n1] where n = fromIntegral $ dimVal' @n :: Int n1 = n - 1 goGivens :: Scalar t -> Int -> ST s (Scalar t) goGivens 0 _ = return 0 -- non-diagonal element is nullified prematurely goGivens b j = do aj <- readDataFrameOff aPtr j bj <- readDataFrameOff bPtr j let rab = recip $ hypot b aj c = aj*rab s = b*rab updateGivensMat uPtr k j c (negate s) writeDataFrameOff aPtr j $ b*s + aj*c writeDataFrameOff bPtr j $ bj*c return $ negate (bj * s) -- | A Golub-Kahan bidiagonal SVD step on an unreduced matrix svdGolubKahanStep :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (nm :: Nat) . ( RealFloatExtras t , KnownDim n, KnownDim m, KnownDim nm, nm ~ Min n m) => STDataFrame s t '[nm] -- ^ the main diagonal of B and then the singular values. -> STDataFrame s t '[nm] -- ^ first upper diagonal of B -> STDataFrame s t '[n,n] -- ^ U -> STDataFrame s t '[m,m] -- ^ V -> Int -- ^ p : 0 <= p < q <= nm; p <= q - 2 -> Int -- ^ q : 0 <= p < q <= nm; p <= q - 2 -> ST s () svdGolubKahanStep aPtr bPtr uPtr vPtr p q | p > q - 2 || p < 0 || q > nm = error $ unwords [ "svdGolubKahanStep: p =", show p, "and q =", show q , "do not satisfy p <= q - 2 or 0 <= p < q <=", show nm] | Dict <- Dict @(nm ~ Min n m) = do (y,z) <- getWilkinsonShiftYZ goGivens2 y z p where nm = fromIntegral $ dimVal' @nm :: Int -- get initial values for one recursion sweep. -- Note, input must satisfy: q >= p+2 getWilkinsonShiftYZ :: ST s (Scalar t, Scalar t) getWilkinsonShiftYZ = do a1 <- readDataFrameOff aPtr p b1 <- readDataFrameOff bPtr p am <- readDataFrameOff aPtr (q-2) an <- readDataFrameOff aPtr (q-1) bm <- if q >= p + 3 then readDataFrameOff bPtr (q-3) else pure 0 bn <- readDataFrameOff bPtr (q-2) let t11 = a1*a1 t12 = a1*b1 tmm = am*am + bm*bm tnn = an*an + bn*bn tnm = am*bn d = 0.5*(tmm - tnn) mu = tnn + d - negateUnless (d >= 0) (hypot d tnm) return (t11 - mu, t12) -- yv = b[k-1]; zv = B[k-1,k+1] -- to be eliminated by 1st Givens r -- yu = a[k]; zu = B[k+1,k-1] -- to be eliminated by 2nd Givens r goGivens2 :: Scalar t -> Scalar t -> Int -> ST s () goGivens2 yv zv k = do a1 <- readDataFrameOff aPtr k -- B[k,k] a2 <- readDataFrameOff aPtr (k+1) -- B[k+1,k+1] b1 <- readDataFrameOff bPtr k -- B[k,k+1] let a1' = a1*cv + b1*sv -- B[k,k] == yu a2' = a2*cv -- B[k+1,k+1] b0' = yv*cv + zv*sv -- B[k-1,k] b1' = b1*cv - a1*sv -- B[k,k+1] yu = a1' -- B[k,k] zu = a2*sv -- B[k+1,k] ryzu = recip $ hypot yu zu cu = yu * ryzu su = zu * ryzu a1'' = yu *cu + zu *su a2'' = a2'*cu - b1'*su b1'' = b1'*cu + a2'*su updateGivensMat vPtr k (k+1) cv sv updateGivensMat uPtr k (k+1) cu su when (k > p) $ writeDataFrameOff bPtr (k-1) b0' writeDataFrameOff bPtr k b1'' writeDataFrameOff aPtr k a1'' writeDataFrameOff aPtr (k+1) a2'' when (k < q - 2) $ do b2 <- readDataFrameOff bPtr (k+1) -- B[k+1,k+2] let b2'' = b2*cu zvn = b2*su writeDataFrameOff bPtr (k+1) b2'' goGivens2 b1'' zvn (k+1) where ryzv = recip $ hypot yv zv cv = yv * ryzv sv = zv * ryzv -- | Update a transformation matrix with a Givens transform (on the right) updateGivensMat :: forall (s :: Type) (t :: Type) (n :: Nat) . (PrimBytes t, Num t, KnownDim n) => STDataFrame s t '[n,n] -> Int -> Int -> Scalar t -> Scalar t -> ST s () updateGivensMat p i j c s = forM_ [0..n-1] $ \k -> do let nk = n*k ioff = nk + i joff = nk + j uki <- readDataFrameOff p ioff ukj <- readDataFrameOff p joff writeDataFrameOff p ioff $ uki*c + ukj*s writeDataFrameOff p joff $ ukj*c - uki*s where n = fromIntegral $ dimVal' @n :: Int minIsSmaller :: forall (n :: Nat) (m :: Nat) . Dim n -> Dim m -> Dict (Min n m <= n, Min n m <= m) minIsSmaller dn dm | Just Dict <- lessOrEqDim dnm dn , Just Dict <- lessOrEqDim dnm dm = Dict | otherwise = error "minIsSmaller: impossible type-level comparison" where dnm = minDim dn dm