{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.Matrix.SVD
( MatrixSVD (..), SVD (..)
, svd1, svd2, svd3, svd3q
) where
import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.Basics
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Bidiagonal
import Numeric.Matrix.Internal
import Numeric.Quaternion.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.Sort
import Numeric.Tuple
import Numeric.Vector.Internal
-- | Singular value decomposition of an @n×m@ matrix.
--
--   The fields hold the left factor, the singular values and the right
--   factor. Every implementation in this module returns non-negative
--   singular values in non-increasing order.
--   NOTE(review): presumably the input is reconstructed as
--   @svdU %* diag svdS %* transpose svdV@ — confirm against the tests.
data SVD (t :: Type) (n :: Nat) (m :: Nat)
  = SVD
  { svdU :: Matrix t n n
    -- ^ Left factor, an @n×n@ matrix.
  , svdS :: Vector t (Min n m)
    -- ^ Singular values, one per element of the shorter matrix side.
  , svdV :: Matrix t m m
    -- ^ Right factor, an @m×m@ matrix.
  }

-- Stock 'Show'; requires the sizes of all three fields to be known.
deriving instance ( Show t, PrimBytes t
                  , KnownDim n, KnownDim m, KnownDim (Min n m))
                  => Show (SVD t n m)

-- Field-wise structural equality.
deriving instance ( Eq (Matrix t n n)
                  , Eq (Matrix t m m)
                  , Eq (Vector t (Min n m)))
                  => Eq (SVD t n m)
-- | Element types and matrix sizes for which an 'SVD' can be computed.
--
--   This module provides dedicated instances for 1×1, 2×2 and 3×3 matrices
--   and an incoherent catch-all instance for arbitrary known dimensions.
class RealFloatExtras t
    => MatrixSVD (t :: Type) (n :: Nat) (m :: Nat) where
  -- | Decompose a matrix into 'SVD' form.
  --   The 'IterativeMethod' constraint is needed by the generic instance,
  --   which calls @tooManyIterations@ when its sweep budget runs out.
  svd :: IterativeMethod => Matrix t n m -> SVD t n m
-- | SVD of a 1×1 matrix @[x]@.
--
--   The factors are @U = 1@, @S = |x|@ and @V = signum x@. When @x == 0@,
--   @V@ is chosen as @1@, so @U * S * V == x@ always holds.
svd1 :: (PrimBytes t, Num t, Eq t) => Matrix t 1 1 -> SVD t 1 1
svd1 mat =
    let x = ixOff 0 mat
        vSign | x == 0    = 1
              | otherwise = signum x
    in  SVD { svdU = 1
            , svdS = broadcast (abs x)
            , svdV = broadcast vSign
            }
-- | Closed-form SVD of a 2×2 matrix @[[a, b], [c, d]]@, with no iteration.
--
--   Let @rA = hypot (a - d) (b + c)@ and @rB = hypot (a + d) (b - c)@.
--   The singular values are @(rB + rA) / 2@ and @|rB - rA| / 2@.
--   Both factors are built from normalised (cos, sin) pairs. When the
--   smaller value @(rB - rA) / 2@ is negative, the second column of @V@ is
--   negated so that 'svdS' can store its absolute value.
svd2 :: forall t . RealFloatExtras t => Matrix t 2 2 -> SVD t 2 2
svd2 (DF2 (DF2 a b) (DF2 c d)) = SVD
    { svdU = DF2 (DF2 cu su)
                 (DF2 (negate su) cu)
    , svdS = DF2 sBig (abs sSmall)
    , svdV = DF2 (DF2 cv (flipV (negate sv)))
                 (DF2 sv (flipV cv))
    }
  where
    -- Sums and differences of the diagonal and off-diagonal entries.
    dDiff = a - d
    dSum  = a + d
    oSum  = b + c
    oDiff = b - c
    rA    = hypot dDiff oSum
    rB    = hypot dSum oDiff
    sBig   = 0.5 * (rB + rA)
    sSmall = 0.5 * (rB - rA)
    flipV  = negateUnless (sSmall >= 0)
    -- Half-angle denominators (r + x).
    pA = rA + dDiff
    pB = rB + dSum
    -- A half-angle pair is degenerate when both r + x and y vanish.
    okA = dDiff > 0 || oSum  /= 0
    okB = dSum  > 0 || oDiff /= 0
    -- Unnormalised (cos, sin) pairs for U and V.
    (cu0, su0, cv0, sv0)
      | okA, okB  = (pp + oo, po - op, pp - oo, po + op)
      | okA       = (oSum, pA, negate oSum, pA)
      | okB       = (oDiff, negate pB, negate oDiff, pB)
      | otherwise = (1, 0, -1, 0)
    pp = pA * pB
    oo = oSum * oDiff
    po = pA * oDiff
    op = oSum * pB
    nu = recip $ hypot cu0 su0
    nv = recip $ hypot cv0 sv0
    cu = cu0 * nu
    su = su0 * nu
    cv = cv0 * nv
    sv = sv0 * nv
-- | SVD of a 3×3 matrix, built on 'svd3q'.
--
--   'svd3q' may return a negative third singular value. Here it is replaced
--   by its absolute value. To compensate, the last entry of every row of
--   @V@ (its third column) is negated.
svd3 :: forall t . (Quaternion t, RealFloatExtras t) => Matrix t 3 3 -> SVD t 3 3
svd3 m =
    let (qu, DF3 s1 s2 s3, qv) = svd3q m
        flipZ :: Vector t 3 -> Vector t 3
        flipZ (DF3 x y z) = DF3 x y (negate z)
        vMat | s3 < 0    = ewmap @t @'[3] flipZ (toMatrix33 qv)
             | otherwise = toMatrix33 qv
    in  SVD { svdU = toMatrix33 qu
            , svdS = DF3 s1 s2 (abs s3)
            , svdV = vMat
            }
-- | Quaternion form of the 3×3 SVD, returned as @(u, s, v)@.
--
--   * @v@ diagonalises @transpose m %* m@ (via 'jacobiEigenQ').
--   * @u@ and @s@ come from the Givens-based QR decomposition of
--     @m %* toMatrix33 v@.
--
--   'fixSigns' then makes the first two entries of @s@ non-negative. The
--   third entry may remain negative; 'svd3' takes its absolute value and
--   compensates in @V@.
svd3q :: forall t . (Quaternion t, RealFloatExtras t)
      => Matrix t 3 3 -> (Quater t, Vector t 3, Quater t)
svd3q m = (u, s, v)
  where
    v = jacobiEigenQ (transpose m %* m)
    (s, u) = uncurry fixSigns $ qrDecomposition3 (m %* toMatrix33 v)
    -- Make s1 and s2 non-negative. Each non-trivial branch negates exactly
    -- two entries and right-multiplies q by the unit quaternion j, i or k
    -- respectively (a half-turn about the remaining axis). That half-turn
    -- flips the signs of the same two axes.
    fixSigns :: Vector t 3 -> Quater t -> (Vector t 3, Quater t)
    fixSigns (DF3 s1 s2 s3) q@(Quater a b c d) = case (s1 >= 0, s2 >= 0) of
      (True , True ) -> (mk3 s1 s2 s3, q)
      (False, True ) -> (mk3 (negate s1) s2 (negate s3), Quater (-c) d a (-b))
      (True , False) -> (mk3 s1 (negate s2) (negate s3), Quater d c (-b) (-a))
      (False, False) -> (mk3 (negate s1) (negate s2) s3, Quater b (-a) d (-c))
    -- Arrange the values as s1 >= s2 >= |s3'| in non-increasing order.
    -- Whichever value lands in the last slot takes the sign of s3' (via cs).
    -- NOTE(review): only the values are reordered, the quaternion is not.
    -- This presumably relies on jacobiEigenQ having already sorted them
    -- (up to rounding) — confirm.
    mk3 :: Scalar t -> Scalar t -> Scalar t -> Vector t 3
    mk3 s1 s2 s3' = case (s1 >= s2, s1 >= abs s3, s2 >= abs s3) of
        (True , True , True ) -> DF3 s1 s2 s3'
        (True , True , False) -> DF3 s1 s3 (cs s2)
        (True , False, _    ) -> DF3 s3 s1 (cs s2)
        (False, True , True ) -> DF3 s2 s1 s3'
        (False, _    , False) -> DF3 s3 s2 (cs s1)
        (False, False, True ) -> DF3 s2 s3 (cs s1)
      where
        s3 = abs s3'
        cs = negateUnless (s3' >= 0)
-- | Approximate Givens half-angle pair @(cos, sin)@ for one Jacobi step on
--   the symmetric 2×2 block @[[aii, aij], [aij, ajj]]@.
--
--   If the exact angle would be too large (test against @gamma = 3 + 2√2@),
--   it falls back to the fixed half-angle π/8: @(cos π/8, sin π/8)@.
jacobiGivensQ :: forall t . RealFloatExtras t => t -> t -> t -> (t, t)
jacobiGivensQ aii aij ajj =
    if gamma * s * s < c * c
      then let r = recip (hypot c s) in (r * c, r * s)
      else (cosPi8, sinPi8)
  where
    c = 2 * (aii - ajj)
    s = aij
    gamma  = 5.82842712474619   :: t
    cosPi8 = 0.9238795325112867 :: t
    sinPi8 = 0.3826834323650898 :: t
-- | Givens half-angle pair @(cos, sin)@ that rotates @(a1, a2)@ onto the
--   first axis. 'qrDecomp3Iteration' uses it to zero the sub-diagonal
--   entry @a2@.
--
--   When @a1 < 0@ the pair is swapped. This lets @|a1| + rho@ replace
--   @a1 + rho@, which would suffer cancellation.
--
--   The near-zero test is applied to the norm @rho = hypot a1 a2@ itself.
--   Thresholding the squared norm against 'M_EPS' effectively used
--   @sqrt M_EPS@ as the cut-off on @rho@. That skipped the rotation, and
--   skewed @ch@, for inputs smaller than roughly @sqrt M_EPS@, which broke
--   the 3×3 SVD of small-scale matrices.
qrGivensQ :: forall t . RealFloatExtras t => t -> t -> (t, t)
qrGivensQ a1 a2
    | a1 < 0    = (sh * w, ch * w)
    | otherwise = (ch * w, sh * w)
  where
    rho = hypot a1 a2
    -- A vector shorter than M_EPS is treated as zero: no rotation.
    sh  = if rho > M_EPS then a2 else 0
    ch  = abs a1 + max rho M_EPS
    w   = recip $ hypot ch sh
-- | One Jacobi rotation in the (i, j) plane of the symmetric 3×3 matrix in
--   @sPtr@; @k@ is the remaining axis.
--
--   Only upper-triangle offsets are read and written: each off-diagonal
--   offset is formed as @min * 3 + max@. The entries updated are (i,i),
--   (i,j), (j,j), (i,k) and (j,k).
--
--   The rotation comes from 'jacobiGivensQ' as a half-angle pair
--   @(ch, sh)@. It is returned as a quaternion whose component @k@ is @-sh@
--   and whose component 3 is @ch@. NOTE(review): component 3 is presumably
--   the real part — confirm the 'fromVec4' ordering.
jacobiEigen3Iteration :: (Quaternion t, RealFloatExtras t)
                      => Int -> Int -> Int
                      -> STDataFrame s t '[3,3]
                      -> ST s (Quater t)
jacobiEigen3Iteration i j k sPtr = do
    sii <- readDataFrameOff sPtr ii
    sij <- readDataFrameOff sPtr ij
    sjj <- readDataFrameOff sPtr jj
    sik <- readDataFrameOff sPtr ik
    sjk <- readDataFrameOff sPtr jk
    -- (a, b) are the cos and sin of the full angle (double-angle formulas
    -- applied to the half-angle pair).
    let (ch, sh) = jacobiGivensQ sii sij sjj
        a = ch*ch - sh*sh
        b = 2 * sh*ch
        aa = a * a
        ab = a * b
        bb = b * b
    -- Similarity transform of the (i, j) block ...
    writeDataFrameOff sPtr ii $
      aa * sii + 2 * ab * sij + bb * sjj
    writeDataFrameOff sPtr ij $
      ab * (sjj - sii) + (aa - bb) * sij
    writeDataFrameOff sPtr jj $
      bb * sii - 2 * ab * sij + aa * sjj
    -- ... and of its coupling to the k axis; (k, k) is unchanged.
    writeDataFrameOff sPtr ik $ a * sik + b * sjk
    writeDataFrameOff sPtr jk $ a * sjk - b * sik
    -- Build the quaternion in a zero-initialised 4-vector.
    qPtr <- unsafeThawDataFrame 0
    writeDataFrameOff qPtr k (negate sh)
    writeDataFrameOff qPtr 3 ch
    fromVec4 <$> unsafeFreezeDataFrame qPtr
  where
    ii = i*3 + i
    ij = if i < j then i*3 + j else j*3 + i
    jj = j*3 + j
    ik = if i < k then i*3 + k else k*3 + i
    jk = if j < k then j*3 + k else k*3 + j
-- | Total number of Jacobi rotations performed by 'jacobiEigenQ'.
--
--   Each rotation targets one of the three off-diagonal pairs of a 3×3
--   matrix, so this is 4 rounds of 3 rotations.
eigenItersX3 :: Int
eigenItersX3 = 4 * 3
-- | Rotation, as a quaternion, that diagonalises the symmetric 3×3 matrix
--   @m@ by cyclic Jacobi iteration.
--
--   It runs 'eigenItersX3' rotations. Each rotation is composed on the left
--   of the accumulated quaternion (@q' * q@). Finally the axis-permuting
--   rotation from 'sortQ' is prepended. That rotation is the identity when
--   the diagonal is already non-increasing, so it presumably orders the
--   eigenvalues from largest to smallest.
jacobiEigenQ :: forall t
              . (Quaternion t, RealFloatExtras t)
             => Matrix t 3 3 -> Quater t
jacobiEigenQ m = runST $ do
    mPtr <- thawDataFrame m
    q <- go eigenItersX3 mPtr 1
    s1 <- readDataFrameOff mPtr 0
    s2 <- readDataFrameOff mPtr 4
    s3 <- readDataFrameOff mPtr 8
    return $ sortQ s1 s2 s3 * q
  where
    -- Perform n rotations. Only upper-triangle off-diagonals are read:
    -- offsets 1 = (0,1), 2 = (0,2) and 5 = (1,2).
    go :: Int -> STDataFrame s t '[3,3] -> Quater t -> ST s (Quater t)
    go 0 _ q = pure q
    go n p q = do
      a10 <- abs <$> readDataFrameOff p 1
      a20 <- abs <$> readDataFrameOff p 2
      a21 <- abs <$> readDataFrameOff p 5
      q' <- jiter n p a10 a20 a21
      go (n - 1) p (q' * q)
    -- Pivot on the strictly dominant off-diagonal element. When none
    -- dominates (all equal), cycle through the three pairs using n mod 3.
    jiter :: Int -> STDataFrame s t '[3,3]
          -> Scalar t -> Scalar t -> Scalar t -> ST s (Quater t)
    jiter n p a10 a20 a21
      | gt2 a10 a20 a21
        = jacobiEigen3Iteration 0 1 2 p
      | gt2 a20 a10 a21
        = jacobiEigen3Iteration 2 0 1 p
      | gt2 a21 a10 a20
        = jacobiEigen3Iteration 1 2 0 p
      | otherwise
        = case mod n 3 of
            0 -> jacobiEigen3Iteration 0 1 2 p
            1 -> jacobiEigen3Iteration 2 0 1 p
            _ -> jacobiEigen3Iteration 1 2 0 p
    -- True when a >= b and a >= c, and a is strictly larger than at least
    -- one of them.
    gt2 :: Scalar t -> Scalar t -> Scalar t -> Bool
    gt2 a b c = case compare a b of
      GT -> a >= c
      EQ -> a > c
      LT -> False
    -- Axis-permuting rotation chosen from the pairwise order of the
    -- diagonal entries.
    sortQ :: Scalar t -> Scalar t -> Scalar t -> Quater t
    sortQ s1 s2 s3 = sortQ' (s1 >= s2) (s1 >= s3) (s2 >= s3)
    sortQ' :: Bool -> Bool -> Bool -> Quater t
    sortQ' True  True  True  = Quater 0 0 0 1
    sortQ' True  True  False = Quater M_SQRT1_2 0 0 (-M_SQRT1_2)
    sortQ' True  False _     = Quater 0.5 0.5 0.5 0.5
    sortQ' False True  True  = Quater 0 0 M_SQRT1_2 (-M_SQRT1_2)
    sortQ' False _     False = Quater 0 M_SQRT1_2 0 (-M_SQRT1_2)
    sortQ' False False True  = Quater 0.5 0.5 0.5 (-0.5)
-- | One Givens step of the 3×3 QR decomposition held in @sPtr@
--   (full matrix, offset @row * 3 + col@).
--
--   It rotates rows @i@ and @j@ (across columns i, j and k) so that entry
--   (j, i) becomes exactly 0. The rotation is returned as a quaternion about
--   axis @k@ with real part @ch@. Its vector part is @-sh@, except when
--   (i, j, k) is a "left" (odd) triple such as (0, 2, 1), where it is @sh@.
qrDecomp3Iteration :: (Quaternion t, RealFloatExtras t)
                   => Int -> Int -> Int
                   -> STDataFrame s t '[3,3]
                   -> ST s (Quater t)
qrDecomp3Iteration i j k sPtr = do
    sii <- readDataFrameOff sPtr ii
    sij <- readDataFrameOff sPtr ij
    sji <- readDataFrameOff sPtr ji
    sjj <- readDataFrameOff sPtr jj
    sik <- readDataFrameOff sPtr ik
    sjk <- readDataFrameOff sPtr jk
    -- (a, b) are the cos and sin of the full angle, from the half-angle pair.
    let (ch, sh) = qrGivensQ sii sji
        a = ch*ch - sh*sh
        b = 2 * sh*ch
    writeDataFrameOff sPtr ii $ a * sii + b * sji
    writeDataFrameOff sPtr ij $ a * sij + b * sjj
    writeDataFrameOff sPtr ik $ a * sik + b * sjk
    -- The eliminated entry is stored as an exact zero.
    writeDataFrameOff sPtr ji 0
    writeDataFrameOff sPtr jj $ a * sjj - b * sij
    writeDataFrameOff sPtr jk $ a * sjk - b * sik
    -- Build the quaternion in a zero-initialised 4-vector.
    qPtr <- unsafeThawDataFrame 0
    writeDataFrameOff qPtr k (negateUnless leftTriple sh)
    writeDataFrameOff qPtr 3 ch
    fromVec4 <$> unsafeFreezeDataFrame qPtr
  where
    -- True for index triples that are not cyclic successors, e.g. (0, 2, 1).
    leftTriple = (j - i) /= 1 && (k - j) /= 1
    i3 = i*3
    j3 = j*3
    ii = i3 + i
    ij = i3 + j
    ik = i3 + k
    ji = j3 + i
    jj = j3 + j
    jk = j3 + k
-- | QR decomposition of a 3×3 matrix using three Givens steps.
--
--   The steps zero the entries (1,0), (2,0) and (2,1), in that order. The
--   result is the diagonal of the upper-triangular @R@ together with the
--   accumulated rotation @q3 * q2 * q1@ as a quaternion.
qrDecomposition3 :: (Quaternion t, RealFloatExtras t)
                 => Matrix t 3 3 -> (Vector t 3, Quater t)
qrDecomposition3 m = runST $ do
    r <- thawDataFrame m
    -- The order matters: each step acts on the matrix left by the previous one.
    qa <- qrDecomp3Iteration 0 1 2 r
    qb <- qrDecomp3Iteration 0 2 1 r
    qc <- qrDecomp3Iteration 1 2 0 r
    rDiag <- DF3 <$> readDataFrameOff r 0
                 <*> readDataFrameOff r 4
                 <*> readDataFrameOff r 8
    pure (rDiag, qc * qb * qa)
-- 1×1 matrices: direct formula, see 'svd1'.
instance RealFloatExtras t => MatrixSVD t 1 1 where
  svd x = svd1 x

-- 2×2 matrices: closed form without iteration, see 'svd2'.
instance RealFloatExtras t => MatrixSVD t 2 2 where
  svd x = svd2 x

-- 3×3 matrices: quaternion Jacobi eigen-solve plus Givens QR, see 'svd3'.
instance (RealFloatExtras t, Quaternion t) => MatrixSVD t 3 3 where
  svd x = svd3 x
-- | Generic SVD for any known dimensions, using the Golub–Kahan method:
--
--   1. Reduce @a@ to upper-bidiagonal form with 'bidiagonalHouseholder'.
--   2. If the trailing superdiagonal entry @betas[nm-1]@ is non-negligible,
--      rotate it away with 'svdGolubKahanZeroCol'.
--   3. Diagonalise with 'svdBidiagonalInplace', allowing at most
--      @3 * min n m@ iterations; otherwise call @tooManyIterations@.
--   4. Sort the absolute singular values in non-increasing order, fix the
--      signs, and permute the columns of @U@ and @V@ to match.
--
--   The instance is INCOHERENT so that the dedicated 1×1, 2×2 and 3×3
--   instances are used where they apply.
instance {-# INCOHERENT #-}
         ( RealFloatExtras t, KnownDim n, KnownDim m)
         => MatrixSVD t n m where
  svd a = runST $ do
      -- Bring KnownDim (Min n m) and Min n m <= n, m into scope.
      D <- pure dnm
      Dict <- pure $ minIsSmaller dn dm
      -- Mutable views of the bidiagonal factors (unsafe thaw: no copy).
      alphas <- unsafeThawDataFrame bdAlpha
      betas <- unsafeThawDataFrame bdBeta
      uPtr <- unsafeThawDataFrame bdU
      vPtr <- unsafeThawDataFrame bdV
      -- betas[nm-1] couples column nm-1 with column nm.
      -- NOTE(review): presumably it is exactly 0 when n >= m;
      -- svdGolubKahanZeroCol rejects k = nm-1 in that case.
      bLast <- readDataFrameOff betas nm1
      when (abs bLast > M_EPS) $
        svdGolubKahanZeroCol alphas betas vPtr nm1
      let maxIter = 3*nm
      withinIters <- svdBidiagonalInplace alphas betas uPtr vPtr nm maxIter
      unless withinIters . tooManyIterations
        $ "SVD - Givens rotation sweeps for a bidiagonal matrix ("
        ++ show maxIter ++ " sweeps max)."
      -- NOTE(review): sUnsorted aliases alphas (unsafe freeze), and alphas
      -- is mutated below. This is harmless only because sSorted uses abs
      -- and the later writes change signs only.
      sUnsorted <- unsafeFreezeDataFrame alphas
      -- Pair each |alpha| with its original index and sort descending.
      let sSorted :: Vector (Tuple '[t, Word]) (Min n m)
          sSorted = sortBy (\(S (x :! _)) (S (y :! _)) -> compare y x)
            $ iwmap @_ @_ @'[] (\(Idx i :* U) (S x) -> S (abs x :! i :! U) ) sUnsorted
          svdS = ewmap @t @_ @'[] (\(S (x :! _)) -> S x) sSorted
          perm = ewmap @Word @_ @'[] (\(S (_ :! i :! U)) -> S i) sSorted
          -- Number of inversions of perm; its parity is the sign of the
          -- column permutation.
          pCount =
            if nm < 2
              then 0 :: Word
              else foldl (\s (i, j) -> if perm!i > perm!j then succ s else s)
                     0 [(i, j) | i <- [0..nm2w], j <- [i+1..nm2w+1]]
          pPositive = even pCount
      -- Negate column 0 of U together with alphas[0] when bdUDet's sign and
      -- the permutation parity disagree. This presumably aims for
      -- det U = +1 after the permutation — confirm.
      when ((bdUDet < 0) == pPositive) $ do
        readDataFrameOff alphas 0 >>= writeDataFrameOff alphas 0 . negate
        forM_ [0..n - 1] $ \i ->
          readDataFrameOff uPtr (i*n) >>= writeDataFrameOff uPtr (i*n) . negate
      -- Make every alpha non-negative by negating the matching column of V.
      forM_ [0..nm1] $ \i -> do
        s <- readDataFrameOff alphas i
        when (s < 0) $ do
          writeDataFrameOff alphas i $ negate s
          forM_ [0..m - 1] $ \j ->
            readDataFrameOff vPtr (j*m + i)
              >>= writeDataFrameOff vPtr (j*m + i) . negate
      if pCount == 0
        then do
          svdU <- unsafeFreezeDataFrame uPtr
          svdV <- unsafeFreezeDataFrame vPtr
          return SVD {..}
        else do
          svdU' <- unsafeFreezeDataFrame uPtr
          svdV' <- unsafeFreezeDataFrame vPtr
          -- Permute the first min n m columns; the remaining columns stay put.
          let svdU = iwgen @_ @_ @'[] $ \(i :* Idx j :* U) ->
                if j >= dimVal dnm
                  then index (i :* Idx j :* U) svdU'
                  else index (i :* Idx (unScalar $ perm!j) :* U) svdU'
              svdV = iwgen @_ @_ @'[] $ \(i :* Idx j :* U) ->
                if j >= dimVal dnm
                  then index (i :* Idx j :* U) svdV'
                  else index (i :* Idx (unScalar $ perm!j) :* U) svdV'
          return SVD {..}
    where
      n = fromIntegral $ dimVal dn :: Int
      m = fromIntegral $ dimVal dm :: Int
      dn = dim @n
      dm = dim @m
      dnm = minDim dn dm
      nm1 = nm - 1
      nm = fromIntegral (dimVal dnm) :: Int
      nm2w = fromIntegral (max (nm - 2) 0) :: Word
      BiDiag {..} = bidiagonalHouseholder a
-- | Diagonalise an upper-bidiagonal matrix in place by repeated
--   implicit-shift Golub–Kahan steps. Rotations are accumulated into @U@
--   and @V@.
--
--   @bPtr[k]@ is the superdiagonal entry between @aPtr[k]@ and
--   @aPtr[k+1]@. @q'@ is the size of the leading part still to be
--   processed, and @iter@ is the remaining budget (one unit per call).
--
--   Returns 'True' once the leading part is fully reduced. Returns 'False'
--   if the budget reaches 0 first.
svdBidiagonalInplace ::
     forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (nm :: Nat)
   . ( IterativeMethod, RealFloatExtras t
     , KnownDim n, KnownDim m, KnownDim nm, nm ~ Min n m)
  => STDataFrame s t '[nm]    -- ^ diagonal
  -> STDataFrame s t '[nm]    -- ^ superdiagonal
  -> STDataFrame s t '[n,n]   -- ^ U, updated in place
  -> STDataFrame s t '[m,m]   -- ^ V, updated in place
  -> Int                      -- ^ size of the part still to process
  -> Int                      -- ^ remaining iteration budget
  -> ST s Bool
svdBidiagonalInplace _ _ _ _ 0 _ = pure True
svdBidiagonalInplace _ _ _ _ 1 _ = pure True
svdBidiagonalInplace _ _ _ _ _ 0 = pure False
svdBidiagonalInplace aPtr bPtr uPtr vPtr q' iter = do
    Dict <- pure $ minIsSmaller (dim @n) (dim @m)
    -- Unreduced block is p .. q-1; q == 0 means everything is diagonal.
    (p, q) <- findCounters q'
    if (q /= 0)
      then do
        -- A zero on the block diagonal is removed by chasing its row or
        -- column; otherwise do one shifted Golub–Kahan step.
        findZeroDiagonal p q >>= \case
          Just k
            | k == q-1  -> svdGolubKahanZeroCol aPtr bPtr vPtr (k-1)
            | otherwise -> svdGolubKahanZeroRow aPtr bPtr uPtr k
          Nothing -> svdGolubKahanStep aPtr bPtr uPtr vPtr p q
        svdBidiagonalInplace aPtr bPtr uPtr vPtr q (iter - 1)
      else return True
  where
    findCounters :: Int -> ST s (Int, Int)
    findCounters = goQ
      where
        -- Is b[k-1] negligible relative to its neighbours a[k-1], a[k]
        -- (the scale is floored at 1)? If so, it is set to exactly 0.
        checkEps :: Int -> ST s Bool
        checkEps k = do
          b <- abs <$> readDataFrameOff bPtr (k-1)
          if b == 0
            then return True
            else do
              a1 <- abs <$> readDataFrameOff aPtr (k-1)
              a2 <- abs <$> readDataFrameOff aPtr k
              if b <= M_EPS * (max (a1 + a2) 1)
                then True <$ writeDataFrameOff bPtr (k-1) 0
                else return False
        -- Shrink k while b[k-2] is negligible; the first k with a
        -- non-negligible b[k-2] is the block end q.
        goQ :: Int -> ST s (Int, Int)
        goQ 0 = pure (0, 0)
        goQ 1 = pure (0, 0)
        goQ k = checkEps (k-1) >>= \case
          True  -> goQ (k-1)
          False -> flip (,) k <$> goP (k-2)
        -- Walk towards 0 until b[k-1] is negligible; k is the block start p.
        goP :: Int -> ST s Int
        goP 0 = pure 0
        goP k = checkEps k >>= \case
          True  -> return k
          False -> goP (k-1)
    -- Highest k in [p, q-1] whose diagonal entry is (or is snapped to) 0.
    findZeroDiagonal :: Int -> Int -> ST s (Maybe Int)
    findZeroDiagonal p q
      | k < p = pure Nothing
      | otherwise = do
          ak <- readDataFrameOff aPtr k
          if ak == 0
            then pure $ Just k
            else if abs ak <= M_EPS
              then Just k <$ writeDataFrameOff aPtr k 0
              else findZeroDiagonal p k
      where
        k = q - 1
-- | Zero the superdiagonal entry @b[k]@, which couples column @k@ to
--   column @k+1@.
--
--   Callers use it when column @k+1@ has no diagonal contribution: either
--   @a[k+1]@ is 0, or column @k+1@ lies beyond the square part.
--
--   It applies right Givens rotations of columns (i, k+1) for
--   i = k, k-1 .. 0 and accumulates them into @V@. Each rotation pushes the
--   fill-in one row up. The chain stops early once the carried value is 0.
svdGolubKahanZeroCol ::
     forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
   . (RealFloatExtras t, KnownDim n, KnownDim m, n <= m)
  => STDataFrame s t '[n]     -- ^ diagonal a
  -> STDataFrame s t '[n]     -- ^ superdiagonal b
  -> STDataFrame s t '[m,m]   -- ^ V, updated in place
  -> Int                      -- ^ index k of the entry to eliminate
  -> ST s ()
svdGolubKahanZeroCol aPtr bPtr vPtr k
    | k < 0 || k >= lim = error $ unwords
        [ "svdGolubKahanZeroCol: k =", show k
        , "is outside of a valid range 0 <= k <", show lim]
    | Dict <- Dict @(n <= m) = do
        b <- readDataFrameOff bPtr k
        writeDataFrameOff bPtr k 0
        foldM_ goGivens b [k, k-1 .. 0]
  where
    n = fromIntegral $ dimVal' @n :: Int
    m = fromIntegral $ dimVal' @m :: Int
    -- Column k+1 must exist in V.
    lim = min n (m-1)
    -- Eliminate the carried value b at (i, k+1) against a[i]. Returns the
    -- new fill-in at (i-1, k+1), or 0 at the top row.
    goGivens :: Scalar t -> Int -> ST s (Scalar t)
    goGivens 0 _ = return 0
    goGivens b i = do
      ai <- readDataFrameOff aPtr i
      let rab = recip $ hypot b ai
          c = ai*rab
          s = b *rab
      updateGivensMat vPtr i (k+1) c s
      writeDataFrameOff aPtr i $ ai*c + b*s
      if i == 0
        then return 0
        else do
          bi1 <- readDataFrameOff bPtr (i - 1)
          writeDataFrameOff bPtr (i - 1) $ bi1 * c
          return $ negate (bi1 * s)
-- | Zero the superdiagonal entry @b[k]@ when the diagonal entry @a[k]@ is 0.
--
--   Row @k@ is rotated against rows @k+1, k+2, ..@ using left Givens
--   rotations, which are accumulated into the third argument. Each rotation
--   moves the fill-in one column to the right. The chain stops early once
--   the carried value is 0.
--
--   Callers pass the @U@ factor as the third argument. Here the type
--   variable @m@ is therefore the size of @U@, and @n@ is the length of the
--   diagonal.
svdGolubKahanZeroRow ::
     forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
   . (RealFloatExtras t, KnownDim n, KnownDim m, n <= m)
  => STDataFrame s t '[n]     -- ^ diagonal a
  -> STDataFrame s t '[n]     -- ^ superdiagonal b
  -> STDataFrame s t '[m,m]   -- ^ left factor, updated in place
  -> Int                      -- ^ row k with a[k] == 0
  -> ST s ()
svdGolubKahanZeroRow aPtr bPtr uPtr k
    | k < 0 || k >= n1 = error $ unwords
        [ "svdGolubKahanZeroRow: k =", show k
        , "is outside of a valid range 0 <= k <", show n1]
    | Dict <- Dict @(n <= m) = do
        b <- readDataFrameOff bPtr k
        writeDataFrameOff bPtr k 0
        foldM_ goGivens b [k+1..n1]
  where
    n = fromIntegral $ dimVal' @n :: Int
    n1 = n - 1
    -- Eliminate the carried value b at (k, j) against a[j]. Returns the new
    -- fill-in at (k, j+1).
    goGivens :: Scalar t -> Int -> ST s (Scalar t)
    goGivens 0 _ = return 0
    goGivens b j = do
      aj <- readDataFrameOff aPtr j
      bj <- readDataFrameOff bPtr j
      let rab = recip $ hypot b aj
          c = aj*rab
          s = b*rab
      updateGivensMat uPtr k j c (negate s)
      writeDataFrameOff aPtr j $ b*s + aj*c
      writeDataFrameOff bPtr j $ bj*c
      return $ negate (bj * s)
-- | One implicit-shift Golub–Kahan SVD step on the unreduced block
--   @p .. q-1@ of an upper-bidiagonal matrix (diagonal @aPtr@,
--   superdiagonal @bPtr@).
--
--   The shift is the eigenvalue of the trailing 2×2 block of @BᵀB@ that is
--   closer to that block's last diagonal entry (Wilkinson shift). The
--   resulting bulge is chased down the block by alternating a right
--   rotation (accumulated into @V@) with a left rotation (accumulated into
--   @U@).
svdGolubKahanStep ::
     forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (nm :: Nat)
   . ( RealFloatExtras t
     , KnownDim n, KnownDim m, KnownDim nm, nm ~ Min n m)
  => STDataFrame s t '[nm]    -- ^ diagonal
  -> STDataFrame s t '[nm]    -- ^ superdiagonal
  -> STDataFrame s t '[n,n]   -- ^ U, updated in place
  -> STDataFrame s t '[m,m]   -- ^ V, updated in place
  -> Int                      -- ^ block start p
  -> Int                      -- ^ block end q (exclusive)
  -> ST s ()
svdGolubKahanStep aPtr bPtr uPtr vPtr p q
    | p > q - 2 || p < 0 || q > nm
    = error $ unwords
        [ "svdGolubKahanStep: p =", show p, "and q =", show q
        , "do not satisfy p <= q - 2 or 0 <= p < q <=", show nm]
    | Dict <- Dict @(nm ~ Min n m) = do
        (y,z) <- getWilkinsonShiftYZ
        goGivens2 y z p
  where
    nm = fromIntegral $ dimVal' @nm :: Int
    -- First column of BᵀB - mu·I restricted to rows p, p+1.
    getWilkinsonShiftYZ :: ST s (Scalar t, Scalar t)
    getWilkinsonShiftYZ = do
      a1 <- readDataFrameOff aPtr p
      b1 <- readDataFrameOff bPtr p
      am <- readDataFrameOff aPtr (q-2)
      an <- readDataFrameOff aPtr (q-1)
      -- b[q-3] belongs to the block only when q-3 >= p.
      bm <- if q >= p + 3
              then readDataFrameOff bPtr (q-3)
              else pure 0
      bn <- readDataFrameOff bPtr (q-2)
      -- Trailing 2×2 of BᵀB: [[tmm, tnm], [tnm, tnn]].
      let t11 = a1*a1
          t12 = a1*b1
          tmm = am*am + bm*bm
          tnn = an*an + bn*bn
          tnm = am*bn
          d = 0.5*(tmm - tnn)
          mu = tnn + d - negateUnless (d >= 0) (hypot d tnm)
      return (t11 - mu, t12)
    -- Chase step at column k. (yv, zv) are the values to be rotated into
    -- one entry by the right rotation (cv, sv) on columns k, k+1. The left
    -- rotation (cu, su) on rows k, k+1 then removes the fill-in at (k+1, k).
    goGivens2 :: Scalar t -> Scalar t -> Int -> ST s ()
    goGivens2 yv zv k = do
        a1 <- readDataFrameOff aPtr k
        a2 <- readDataFrameOff aPtr (k+1)
        b1 <- readDataFrameOff bPtr k
        -- After the right rotation on columns k, k+1:
        let a1' = a1*cv + b1*sv
            a2' = a2*cv
            b0' = yv*cv + zv*sv
            b1' = b1*cv - a1*sv
            -- Fill-in at (k+1, k) is zu.
            yu = a1'
            zu = a2*sv
            ryzu = recip $ hypot yu zu
            cu = yu * ryzu
            su = zu * ryzu
            -- After the left rotation on rows k, k+1:
            a1'' = yu *cu + zu *su
            a2'' = a2'*cu - b1'*su
            b1'' = b1'*cu + a2'*su
        updateGivensMat vPtr k (k+1) cv sv
        updateGivensMat uPtr k (k+1) cu su
        -- (k-1, k) is only updated after the first step (yv = b[k-1] there).
        when (k > p) $ writeDataFrameOff bPtr (k-1) b0'
        writeDataFrameOff bPtr k b1''
        writeDataFrameOff aPtr k a1''
        writeDataFrameOff aPtr (k+1) a2''
        -- The left rotation creates the fill-in (k, k+2) = zvn; chase it.
        when (k < q - 2) $ do
          b2 <- readDataFrameOff bPtr (k+1)
          let b2'' = b2*cu
              zvn = b2*su
          writeDataFrameOff bPtr (k+1) b2''
          goGivens2 b1'' zvn (k+1)
      where
        ryzv = recip $ hypot yv zv
        cv = yv * ryzv
        sv = zv * ryzv
-- | Apply a Givens rotation in place to columns @i@ and @j@ of the @n×n@
--   matrix @p@.
--
--   Element (row, col) is stored at offset @row * n + col@. For every row:
--   @(x_i, x_j) := (x_i*c + x_j*s, x_j*c - x_i*s)@.
updateGivensMat ::
     forall (s :: Type) (t :: Type) (n :: Nat)
   . (PrimBytes t, Num t, KnownDim n)
  => STDataFrame s t '[n,n]
  -> Int -> Int
  -> Scalar t -> Scalar t -> ST s ()
updateGivensMat p i j c s = mapM_ (rotateRow . (n *)) [0 .. n - 1]
  where
    n = fromIntegral (dimVal' @n) :: Int
    -- Rotate the (i, j) pair of the row that starts at offset base.
    rotateRow base = do
      let offI = base + i
          offJ = base + j
      xi <- readDataFrameOff p offI
      xj <- readDataFrameOff p offJ
      writeDataFrameOff p offI (xi * c + xj * s)
      writeDataFrameOff p offJ (xj * c - xi * s)
-- | Evidence that @Min n m@ is no larger than either @n@ or @m@.
--
--   Both comparisons are checked at runtime on the 'Dim' values. A failure
--   should be unreachable and is reported with 'error'.
minIsSmaller :: forall (n :: Nat) (m :: Nat)
              . Dim n -> Dim m -> Dict (Min n m <= n, Min n m <= m)
minIsSmaller dn dm =
    case (lessOrEqDim dMin dn, lessOrEqDim dMin dm) of
      (Just Dict, Just Dict) -> Dict
      _ -> error "minIsSmaller: impossible type-level comparison"
  where
    dMin = minDim dn dm