{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Numeric.Matrix.SVD
  ( MatrixSVD (..), SVD (..)
  , svd1, svd2, svd3, svd3q
  ) where

import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.Basics
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Bidiagonal
import Numeric.Matrix.Internal
import Numeric.Quaternion.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.Sort
import Numeric.Tuple
import Numeric.Vector.Internal

-- | Result of SVD factorization
--   @ M = svdU %* asDiag svdS %* transpose svdV @.
--
--   Invariants:
--
--   * Singular values 'svdS' are in non-increasing order and are non-negative.
--   * 'svdU' and 'svdV' are orthogonal matrices.
--   * @det svdU == 1@.
--
--   NB: <https://en.wikipedia.org/wiki/Singular_value_decomposition SVD on wiki>
data SVD (t :: Type) (n :: Nat) (m :: Nat)
  = SVD
  { svdU :: Matrix t n n
    -- ^ Left-singular basis matrix
  , svdS :: Vector t (Min n m)
    -- ^ Vector of singular values
  , svdV :: Matrix t m m
    -- ^ Right-singular basis matrix
  }

-- Structural instances for 'SVD', derived field by field.
-- Rendering requires 'KnownDim' evidence for every dimension that occurs in
-- the three fields; equality only requires equality on each field type.
deriving instance
    ( PrimBytes t, Show t
    , KnownDim n, KnownDim m, KnownDim (Min n m)
    ) => Show (SVD t n m)

deriving instance
    ( Eq (Matrix t n n), Eq (Vector t (Min n m)), Eq (Matrix t m m)
    ) => Eq (SVD t n m)

-- | Element types and matrix sizes for which 'svd' is available.
class RealFloatExtras t => MatrixSVD (t :: Type) (n :: Nat) (m :: Nat) where
    -- | Compute the SVD factorization of a matrix;
    --   see 'SVD' for the layout and the invariants of the result.
    svd :: IterativeMethod => Matrix t n m -> SVD t n m

-- | Obvious dummy implementation of SVD for 1x1 matrices.
--
--   For @m = [x]@ the result is @U = [1]@, @S = [abs x]@ and
--   @V = [signum x]@ (or @V = [1]@ when @x == 0@), so that
--   @U * S * transpose V == m@ and @det U == 1@.
svd1 :: (PrimBytes t, Num t, Eq t) => Matrix t 1 1 -> SVD t 1 1
svd1 m = SVD
    { svdU = 1
    , svdS = broadcast $ abs x
      -- signum 0 == 0 would make V singular, hence the special case
    , svdV = broadcast $ if x == 0 then 1 else signum x
    }
  where
    -- the only element of the matrix
    x = ixOff 0 m

-- | SVD of a 2x2 matrix can be computed analytically
--
--   Related discussion:
--
--   https://scicomp.stackexchange.com/questions/8899/robust-algorithm-for-2-times-2-svd/
--
--   https://ieeexplore.ieee.org/document/486688
--
--   Both orthogonal factors are built as plane rotations from (cosine, sine)
--   pairs. When the second singular value comes out negative, its absolute
--   value is returned and the second column of V is negated to compensate,
--   so U stays a rotation (@det U == 1@).
svd2 :: forall t . RealFloatExtras t => Matrix t 2 2 -> SVD t 2 2
svd2 (DF2 (DF2 m00 m01) (DF2 m10 m11)) =
    SVD
    { svdU = DF2 (DF2         uc  us)
                 (DF2 (negate us) uc)
    , svdS = DF2 sigma1 (abs sigma2)
    , svdV = DF2 (DF2 vc (sg2 (negate vs)))
                 (DF2 vs (sg2 vc))
    }
  where
    -- the 2E/2F/2G/2H tags presumably follow the notation of the scicomp
    -- discussion referenced above
    x1 = m00 - m11 -- 2F
    x2 = m00 + m11 -- 2E
    y1 = m01 + m10 -- 2G
    y2 = m01 - m10 -- 2H
    yy = y1*y2
    h1 = hypot x1 y1 -- h1 >= abs x1
    h2 = hypot x2 y2 -- h2 >= abs x2
    sigma1 = 0.5 * (h2 + h1)  -- sigma1 >= abs sigma2
    sigma2 = 0.5 * (h2 - h1)  -- can be negative, which is accounted by sg2
    sg2 = negateUnless (sigma2 >= 0)
    hx1 = hPlusX h1 x1 y1
    hx2 = hPlusX h2 x2 y2
    hxhx = hx1*hx2
    hxy  = hx1*y2
    yhx  = y1*hx2
    -- unnormalized (cosine, sine) pairs of the rotations U and V
    (uc', us', vc', vs') = case (x1 > 0 || y1 /= 0, x2 > 0 || y2 /= 0) of
      (True , True ) -> (hxhx + yy, hxy - yhx, hxhx - yy, hxy + yhx)
      (True , False) -> (y1,  hx1, -y1, hx1)
      (False, True ) -> (y2, -hx2, -y2, hx2)
      (False, False) -> (1, 0, -1, 0)
    ru = recip $ hypot uc' us'
    rv = recip $ hypot vc' vs'
    uc = uc' * ru
    us = us' * ru
    vc = vc' * rv
    vs = vs' * rv
    -- @hPlusX h x y == h + x@ where @h == hypot x y@.
    -- For x < 0 the direct sum cancels catastrophically (e.g. x = -1,
    -- y = 1e-9 gives 0 instead of 5e-19), so use the exact identity
    -- h + x = y^2 / (h - x), evaluated as y * (y / (h - x)) to avoid
    -- overflow of y^2; h - x > 0 whenever x < 0.
    hPlusX :: Scalar t -> Scalar t -> Scalar t -> Scalar t
    hPlusX h x y
      | x >= 0    = h + x
      | otherwise = y * (y / (h - x))

-- | Get SVD decomposition of a 3x3 matrix using 'svd3q' function.
--
--   'svd3q' already orders the singular values so that @s1 >= s2 >= abs s3@,
--   but it may return a negative @s3@ in order to keep both U and V
--   rotations. This function additionally makes the last value non-negative:
--   when @s3 < 0@, it returns @abs s3@ and negates the third column of V to
--   compensate (V is then a reflection, @det V == -1@).
--   Thus @s1 >= s2 >= s3 >= 0@ holds for the result, at some overhead on
--   top of 'svd3q'.
svd3 :: forall t . (Quaternion t, RealFloatExtras t) => Matrix t 3 3 -> SVD t 3 3
svd3 m = SVD
        { svdU = toMatrix33 u
        , svdS = DF3 s1 s2 (abs s3)
        , svdV = neg3If (s3 < 0) (toMatrix33 v)
        }
  where
    (u, (DF3 s1 s2 s3), v) = svd3q m
    -- optionally negate the third column of a matrix
    -- (i.e. the last element of every row)
    neg3If :: Bool -> Matrix t 3 3 -> Matrix t 3 3
    neg3If False = id
    neg3If True  = ewmap @t @'[3] neg3
    neg3 :: Vector t 3 -> Vector t 3
    neg3 (DF3 a b c) = DF3 a b (negate c)


-- | Get SVD decomposition of a 3x3 matrix, with orthogonal matrices U and V
--   represented as quaternions.
--   Important: U and V are bound to be rotations at the expense of the last
--              singular value being possibly negative.
--
--   This is an adaptation of a specialized 3x3 SVD algorithm described in
--     "Computing the Singular Value Decomposition of 3x3 matrices
--      with minimal branching and elementary floating point operations",
--   by  A. McAdams, A. Selle, R. Tamstorf, J. Teran, E. Sifakis.
--
--   http://pages.cs.wisc.edu/~sifakis/papers/SVD_TR1690.pdf
--
--   Outline: @V = jacobiEigenQ (transpose m %* m)@; U and the singular values
--   come from @qrDecomposition3 (m %* V)@; finally signs and order are fixed
--   so that @s1 >= s2 >= abs s3@.
svd3q :: forall t . (Quaternion t, RealFloatExtras t)
      => Matrix t 3 3 -> (Quater t, Vector t 3, Quater t)
svd3q m = (u, s, v)
  where
    v = jacobiEigenQ (transpose m %* m)
    (s, u) = uncurry fixSigns $ qrDecomposition3 (m %* toMatrix33 v)
    -- last bit: make sure    s1 >= s2 >= 0
    --   every non-trivial case negates exactly two of the three values and
    --   adjusts the quaternion of U accordingly, so U remains a rotation
    fixSigns :: Vector t 3 -> Quater t -> (Vector t 3, Quater t)
    fixSigns (DF3 s1 s2 s3) q@(Quater a b c d) = case (s1 >= 0, s2 >= 0) of
      (True , True ) -> (mk3 s1 s2 s3, q)
      (False, True ) -> (mk3 (negate s1) s2 (negate s3), Quater (-c)  d   a  (-b))
      (True , False) -> (mk3 s1 (negate s2) (negate s3), Quater   d   c (-b) (-a))
      (False, False) -> (mk3 (negate s1) (negate s2) s3, Quater   b (-a)  d  (-c))
    -- one more thing:
    --   the singular values are ordered, but may have small errors;
    --   as a result adjacent values may seem to be out of order by a very
    --   small number. mk3 sorts them so that s1 >= s2 >= abs s3, and the
    --   value placed last carries the sign of the original third argument.
    --   NB: only the values are permuted, not the columns of U and V, so this
    --   is only exact when the swapped values are (nearly) equal.
    mk3 :: Scalar t -> Scalar t -> Scalar t -> Vector t 3
    mk3 s1 s2 s3' = case (s1 >= s2, s1 >= abs s3, s2 >= abs s3) of
        (True , True , True ) -> DF3 s1 s2 s3'     -- s1 >= s2 >= s3
        (True , True , False) -> DF3 s1 s3 (cs s2) -- s1 >= s3 >  s2
        (True , False, _    ) -> DF3 s3 s1 (cs s2) -- s3 >  s1 >= s2
        (False, True , True ) -> DF3 s2 s1 s3'     -- s2 >  s1 >= s3
        (False, _    , False) -> DF3 s3 s2 (cs s1) -- s3 >  s2 >  s1
        (False, False, True ) -> DF3 s2 s3 (cs s1) -- s2 >= s3 >  s1
      where
        s3 = abs s3'
        cs = negateUnless (s3' >= 0)


-- | Approximate values for cos (a/2) and sin (a/2) of a Givens rotation for
--    a 2x2 symmetric matrix. (Algorithm 2)
--
--   Arguments are the entries @a_ii@, @a_ij@, @a_jj@ of the symmetric block.
--   Since @g == 3 + sqrt 8 == cot^2 (pi/8)@, the guard holds exactly when the
--   exact half-angle is below @pi/8@; otherwise the fixed half-angle @pi/8@
--   is used instead. In particular, this fixed rotation is also used when
--   @a_ij == 0@ and @a_ii == a_jj@.
jacobiGivensQ :: forall t . RealFloatExtras t => t -> t -> t -> (t, t)
jacobiGivensQ aii aij ajj
    | g*sh*sh < ch*ch = (w * ch, w * sh)
    | otherwise       = (c', s')
  where
    ch = 2 * (aii-ajj)
    sh = aij
    w = recip $ hypot ch sh
    g  = 5.82842712474619 :: t  -- 3 + sqrt 8
    c' = 0.9238795325112867 :: t -- cos (pi/8)
    s' = 0.3826834323650898 :: t -- sin (pi/8)


-- | A quaternion for a QR Givens iteration.
--
--   Returns a normalized pair built from @a1@ and @a2@. When @a1 < 0@ the
--   two components are swapped. If @a1^2 + a2^2 <= M_EPS@, the sine-like
--   component is forced to zero.
qrGivensQ :: forall t . RealFloatExtras t => t -> t -> (t, t)
qrGivensQ a1 a2
    | a1 < 0    = (sh * w, ch * w)
    | otherwise = (ch * w, sh * w)
  where
    -- sqrt (a1^2 + a2^2) without the intermediate squares: squaring
    -- overflows for large inputs, which used to give ch = Infinity,
    -- w = 0 and NaN results.
    rho = hypot a1 a2
    -- tolerance on rho; equivalent to the test  a1^2 + a2^2 > M_EPS
    tol = sqrt M_EPS
    sh = if rho > tol then a2 else 0
    ch = abs a1 + max rho tol
    w = recip $ hypot ch sh


-- | One iteration of the Jacobi algorithm on a symmetric 3x3 matrix
--
--   The three 'Int' arguments are distinct row/column indices:
--     0 <= i /= j /= k <= 2
jacobiEigen3Iteration :: (Quaternion t, RealFloatExtras t)
                     => Int -> Int -> Int
                     -> STDataFrame s t '[3,3]
                     -> ST s (Quater t)
jacobiEigen3Iteration :: Int -> Int -> Int -> STDataFrame s t '[3, 3] -> ST s (Quater t)
jacobiEigen3Iteration Int
i Int
j Int
k STDataFrame s t '[3, 3]
sPtr = do
    DataFrame t '[]
sii <- STDataFrame s t '[3, 3] -> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
ii
    DataFrame t '[]
sij <- STDataFrame s t '[3, 3] -> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
ij
    DataFrame t '[]
sjj <- STDataFrame s t '[3, 3] -> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
jj
    DataFrame t '[]
sik <- STDataFrame s t '[3, 3] -> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
ik
    DataFrame t '[]
sjk <- STDataFrame s t '[3, 3] -> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
jk
    -- Coefficients for a quaternion corresponding to a Givens rotation
    let (DataFrame t '[]
ch, DataFrame t '[]
sh) = DataFrame t '[]
-> DataFrame t '[]
-> DataFrame t '[]
-> (DataFrame t '[], DataFrame t '[])
forall t. RealFloatExtras t => t -> t -> t -> (t, t)
jacobiGivensQ DataFrame t '[]
sii DataFrame t '[]
sij DataFrame t '[]
sjj
        a :: DataFrame t '[]
a = DataFrame t '[]
chDataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
*DataFrame t '[]
ch DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
- DataFrame t '[]
shDataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
*DataFrame t '[]
sh
        b :: DataFrame t '[]
b = DataFrame t '[]
2 DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
shDataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
*DataFrame t '[]
ch
        aa :: DataFrame t '[]
aa = DataFrame t '[]
a DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
a
        ab :: DataFrame t '[]
ab = DataFrame t '[]
a DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
b
        bb :: DataFrame t '[]
bb = DataFrame t '[]
b DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
b
    -- update the matrix
    STDataFrame s t '[3, 3] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
ii (DataFrame t '[] -> ST s ()) -> DataFrame t '[] -> ST s ()
forall a b. (a -> b) -> a -> b
$
      DataFrame t '[]
aa DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sii DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
+ DataFrame t '[]
2 DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
ab DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sij DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
+ DataFrame t '[]
bb DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sjj
    STDataFrame s t '[3, 3] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
ij (DataFrame t '[] -> ST s ()) -> DataFrame t '[] -> ST s ()
forall a b. (a -> b) -> a -> b
$
      DataFrame t '[]
ab DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* (DataFrame t '[]
sjj DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
- DataFrame t '[]
sii) DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
+ (DataFrame t '[]
aa DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
- DataFrame t '[]
bb) DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sij
    STDataFrame s t '[3, 3] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
jj (DataFrame t '[] -> ST s ()) -> DataFrame t '[] -> ST s ()
forall a b. (a -> b) -> a -> b
$
      DataFrame t '[]
bb DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sii DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
- DataFrame t '[]
2 DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
ab DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sij DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
+ DataFrame t '[]
aa DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sjj
    STDataFrame s t '[3, 3] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
ik (DataFrame t '[] -> ST s ()) -> DataFrame t '[] -> ST s ()
forall a b. (a -> b) -> a -> b
$ DataFrame t '[]
a DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sik DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
+ DataFrame t '[]
b DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sjk
    STDataFrame s t '[3, 3] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[3, 3]
sPtr Int
jk (DataFrame t '[] -> ST s ()) -> DataFrame t '[] -> ST s ()
forall a b. (a -> b) -> a -> b
$ DataFrame t '[]
a DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sjk DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
- DataFrame t '[]
b DataFrame t '[] -> DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a -> a
* DataFrame t '[]
sik

    -- write the quaternion
    STDataFrame s t '[4]
qPtr <- DataFrame t '[4] -> ST s (STDataFrame s t '[4])
forall k t (ns :: [k]) s.
(Dimensions ns, PrimArray t (DataFrame t ns)) =>
DataFrame t ns -> ST s (STDataFrame s t ns)
unsafeThawDataFrame DataFrame t '[4]
0
    STDataFrame s t '[4] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[4]
qPtr Int
k (DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a
negate DataFrame t '[]
sh)
    STDataFrame s t '[4] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[4]
qPtr Int
3 DataFrame t '[]
ch
    DataFrame t '[4] -> Quater t
forall t. Quaternion t => Vector t 4 -> Quater t
fromVec4 (DataFrame t '[4] -> Quater t)
-> ST s (DataFrame t '[4]) -> ST s (Quater t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> STDataFrame s t '[4] -> ST s (DataFrame t '[4])
forall k t (ns :: [k]) s.
PrimArray t (DataFrame t ns) =>
STDataFrame s t ns -> ST s (DataFrame t ns)
unsafeFreezeDataFrame STDataFrame s t '[4]
qPtr
  where
    ii :: Int
ii = Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i
    ij :: Int
ij = if Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
j then Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
j else Int
jInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i
    jj :: Int
jj = Int
jInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
j
    ik :: Int
ik = if Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
k then Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
k else Int
kInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i
    jk :: Int
jk = if Int
j Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
k then Int
jInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
k else Int
kInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
3 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
j


-- | Number of Givens rotations performed by the Jacobi eigendecomposition
--   part of the 3x3 SVD: 'jacobiEigenQ' applies exactly one rotation per
--   step, always pivoting on the largest off-diagonal element.
--
--   The value 12 gives a good precision for that largest-pivot strategy.
--   (The former cyclic variant, kept as a comment in 'jacobiEigenQ', applied
--   three rotations per step and needed only 6 steps, i.e. 18 rotations.)
eigenItersX3 :: Int
eigenItersX3 = 12

-- | Diagonalize a real-valued symmetric 3x3 matrix using a fixed number
--   ('eigenItersX3') of Jacobi (Givens) rotations.
--
--   The eigenvector basis of a symmetric matrix is orthogonal, so the whole
--   accumulated rotation is returned as one quaternion. A final permutation
--   quaternion is multiplied on the left so that the diagonal entries left by
--   the iteration appear in non-increasing order.
jacobiEigenQ :: forall t
              . (Quaternion t, RealFloatExtras t)
             => Matrix t 3 3 -> Quater t
jacobiEigenQ m = runST $ do
    work  <- thawDataFrame m
    jacob <- rotateN eigenItersX3 work 1
    d0    <- readDataFrameOff work 0
    d1    <- readDataFrameOff work 4
    d2    <- readDataFrameOff work 8
    pure (orderQ d0 d1 d2 * jacob)
  where
    -- Apply @n@ pivoted Givens rotations to the working matrix in place,
    -- multiplying each rotation quaternion onto the left of the accumulator.
    rotateN :: Int -> STDataFrame s t '[3,3] -> Quater t -> ST s (Quater t)
    rotateN 0 _ acc = pure acc

    -- -- primitive cyclic iteration;
    -- --   fast, but the convergence is not perfect
    -- --
    -- -- set eigenItersX3 = 6 for good precision
    -- rotateN n p acc = do
    --   q1 <- jacobiEigen3Iteration 0 1 2 p
    --   q2 <- jacobiEigen3Iteration 1 2 0 p
    --   q3 <- jacobiEigen3Iteration 2 0 1 p
    --   rotateN (n - 1) p (q3 * q2 * q1 * acc)

    -- Pick the largest off-diagonal element as the pivot;
    --   slower because of branching, but converges better.
    --   Set eigenItersX3 = 12 for good precision
    --   (slightly faster than the cyclic version with -O0).
    rotateN n p acc = do
      a01 <- abs <$> readDataFrameOff p 1
      a02 <- abs <$> readDataFrameOff p 2
      a12 <- abs <$> readDataFrameOff p 5
      let (i, j, k) = pivot n a01 a02 a12
      r <- jacobiEigen3Iteration i j k p
      rotateN (n - 1) p (r * acc)

    -- Index triple for 'jacobiEigen3Iteration' targeting the dominant
    -- off-diagonal magnitude. When none dominates (all equal, or
    -- incomparable), cycle through the three pivots using @n `mod` 3@.
    pivot :: Int -> Scalar t -> Scalar t -> Scalar t -> (Int, Int, Int)
    pivot n a01 a02 a12
      | dominates a01 a02 a12 = (0, 1, 2)
      | dominates a02 a01 a12 = (2, 0, 1)
      | dominates a12 a01 a02 = (1, 2, 0)
      | otherwise = case n `mod` 3 of
          0 -> (0, 1, 2)
          1 -> (2, 0, 1)
          _ -> (1, 2, 0)

    -- @a@ beats @b@ and is at least @c@, or ties with @b@ and strictly
    -- beats @c@ (decided through 'compare' on @a@ and @b@ first).
    dominates :: Scalar t -> Scalar t -> Scalar t -> Bool
    dominates a b c = case compare a b of
      GT -> a >= c
      EQ -> a >  c
      LT -> False

    -- Permutation quaternion that reorders the basis so the diagonal
    -- satisfies d0 >= d1 >= d2.
    -- NOTE(review): entries are compared directly (no 'abs'), so this orders
    -- by magnitude only when they are non-negative; confirm that this holds
    -- for every caller (e.g. when the input is A^T A).
    orderQ :: Scalar t -> Scalar t -> Scalar t -> Quater t
    orderQ d0 d1 d2 = case (d0 >= d1, d0 >= d2, d1 >= d2) of
      (True , True , True ) -> Quater 0 0 0 1                    -- d0 >= d1 >= d2
      (True , True , False) -> Quater M_SQRT1_2 0 0 (-M_SQRT1_2) -- d0 >= d2 >  d1
      (True , False, _    ) -> Quater 0.5 0.5 0.5 0.5            -- d2 >  d0 >= d1
      (False, True , True ) -> Quater 0 0 M_SQRT1_2 (-M_SQRT1_2) -- d1 >  d0 >= d2
      (False, _    , False) -> Quater 0 M_SQRT1_2 0 (-M_SQRT1_2) -- d2 >  d1 >  d0
      (False, False, True ) -> Quater 0.5 0.5 0.5 (-0.5)         -- d1 >= d2 >  d0


-- | Apply one Givens rotation of a QR sweep to a 3x3 matrix, in place.
--
--   The @Int@ arguments @i@, @j@, @k@ must be a permutation of @0, 1, 2@.
--   Only the six entries at offsets @3*i + x@ and @3*j + x@ (x in i, j, k)
--   are rewritten. The entry at offset @3*j + i@ is set to exactly zero.
--
--   The rotation is returned as a quaternion whose vector part is non-zero
--   only at position @k@. That component is negated unless @(i, j, k)@ is an
--   odd permutation of @(0, 1, 2)@.
--
--   Note from the original author: if @i < j@ then the eigen values are
--   already sorted.
qrDecomp3Iteration :: (Quaternion t, RealFloatExtras t)
                   => Int -> Int -> Int
                   -> STDataFrame s t '[3,3]
                   -> ST s (Quater t)
qrDecomp3Iteration i j k sPtr = do
    -- read every affected entry before any write
    sii <- load i i
    sij <- load i j
    sji <- load j i
    sjj <- load j j
    sik <- load i k
    sjk <- load j k
    -- half-angle Givens coefficients, and the full-angle cos/sin built from them
    let (ch, sh) = qrGivensQ sii sji
        c2 = ch*ch - sh*sh
        s2 = 2 * sh*ch
    -- rotate the two affected rows
    store i i $ c2 * sii + s2 * sji
    store i j $ c2 * sij + s2 * sjj
    store i k $ c2 * sik + s2 * sjk
    store j i 0 -- annihilated entry; not computed as c2 * sji - s2 * sii
    store j j $ c2 * sjj - s2 * sij
    store j k $ c2 * sjk - s2 * sik
    -- pack the rotation into a quaternion: vector part at k, scalar part ch.
    -- NOTE(review): the buffer comes from unsafeThawDataFrame of the literal 0;
    -- this relies on that literal not being a shared, already-materialized
    -- array. Confirm (thawDataFrame would be the conservative choice).
    qPtr <- unsafeThawDataFrame 0
    writeDataFrameOff qPtr k (negateUnless oddPermutation sh)
    writeDataFrameOff qPtr 3 ch
    fromVec4 <$> unsafeFreezeDataFrame qPtr
  where
    offset r c = 3 * r + c
    load r c = readDataFrameOff sPtr (offset r c)
    store r c = writeDataFrameOff sPtr (offset r c)
    -- True exactly for (0,2,1), (1,0,2), (2,1,0)
    oddPermutation = not (j - i == 1 || k - j == 1)

-- | QR decomposition step of the 3x3 SVD: factor the input @AV@ as @QR@.
--
--   In the SVD setting @AV = US@, so the upper-triangular @R@ is in fact the
--   diagonal matrix Sigma, and the orthogonal @Q@ is the matrix @U@,
--   returned here as a quaternion.
--
--   Three Givens rotations ('qrDecomp3Iteration') are applied to a copy of
--   the input, annihilating the entries at offsets 3, 6 and 7 in turn. The
--   result is the diagonal (offsets 0, 4, 8) paired with the product of the
--   three rotation quaternions, latest rotation leftmost.
qrDecomposition3 :: (Quaternion t, RealFloatExtras t)
                 => Matrix t 3 3 -> (Vector t 3, Quater t)
qrDecomposition3 m = runST $ do
    work <- thawDataFrame m
    r1 <- qrDecomp3Iteration 0 1 2 work
    r2 <- qrDecomp3Iteration 0 2 1 work
    r3 <- qrDecomp3Iteration 1 2 0 work
    sigma <- DF3 <$> readDataFrameOff work 0
                 <*> readDataFrameOff work 4
                 <*> readDataFrameOff work 8
    pure (sigma, r3 * r2 * r1)


-- | SVD of a 1x1 matrix: delegates to 'svd1'.
instance RealFloatExtras t => MatrixSVD t 1 1 where
    svd m = svd1 m

-- | SVD of a 2x2 matrix: delegates to 'svd2'.
instance RealFloatExtras t => MatrixSVD t 2 2 where
    svd m = svd2 m

-- | SVD of a 3x3 matrix: delegates to the quaternion-based 'svd3'.
instance (RealFloatExtras t, Quaternion t) => MatrixSVD t 3 3 where
    svd m = svd3 m

instance {-# INCOHERENT #-}
         ( RealFloatExtras t, KnownDim n, KnownDim m)
         => MatrixSVD t n m where
    svd :: Matrix t n m -> SVD t n m
svd Matrix t n m
a = (forall s. ST s (SVD t n m)) -> SVD t n m
forall a. (forall s. ST s a) -> a
runST ((forall s. ST s (SVD t n m)) -> SVD t n m)
-> (forall s. ST s (SVD t n m)) -> SVD t n m
forall a b. (a -> b) -> a -> b
$ do
      Dim (Min' n m (CmpNat n m))
D <- Dim (Min' n m (CmpNat n m)) -> ST s (Dim (Min' n m (CmpNat n m)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure Dim (Min' n m (CmpNat n m))
dnm
      Dict
  (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
   LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m))
Dict <- Dict
  (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
   LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m))
-> ST
     s
     (Dict
        (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
         LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m)))
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Dict
   (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
    LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m))
 -> ST
      s
      (Dict
         (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
          LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m))))
-> Dict
     (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
      LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m))
-> ST
     s
     (Dict
        (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
         LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m)))
forall a b. (a -> b) -> a -> b
$ Dim n
-> Dim m
-> Dict
     (LE (Min' n m (CmpNat n m)) n (CmpNat (Min' n m (CmpNat n m)) n),
      LE (Min' n m (CmpNat n m)) m (CmpNat (Min' n m (CmpNat n m)) m))
forall (n :: Nat) (m :: Nat).
Dim n -> Dim m -> Dict (Min n m <= n, Min n m <= m)
minIsSmaller Dim n
dn Dim m
dm -- GHC is not convinced :(
      STDataFrame s t '[Min' n m (CmpNat n m)]
alphas <- DataFrame t '[Min' n m (CmpNat n m)]
-> ST s (STDataFrame s t '[Min' n m (CmpNat n m)])
forall k t (ns :: [k]) s.
(Dimensions ns, PrimArray t (DataFrame t ns)) =>
DataFrame t ns -> ST s (STDataFrame s t ns)
unsafeThawDataFrame DataFrame t '[Min' n m (CmpNat n m)]
bdAlpha
      STDataFrame s t '[Min' n m (CmpNat n m)]
betas <- DataFrame t '[Min' n m (CmpNat n m)]
-> ST s (STDataFrame s t '[Min' n m (CmpNat n m)])
forall k t (ns :: [k]) s.
(Dimensions ns, PrimArray t (DataFrame t ns)) =>
DataFrame t ns -> ST s (STDataFrame s t ns)
unsafeThawDataFrame DataFrame t '[Min' n m (CmpNat n m)]
bdBeta
      STDataFrame s t '[n, n]
uPtr <- DataFrame t '[n, n] -> ST s (STDataFrame s t '[n, n])
forall k t (ns :: [k]) s.
(Dimensions ns, PrimArray t (DataFrame t ns)) =>
DataFrame t ns -> ST s (STDataFrame s t ns)
unsafeThawDataFrame DataFrame t '[n, n]
bdU
      STDataFrame s t '[m, m]
vPtr <- DataFrame t '[m, m] -> ST s (STDataFrame s t '[m, m])
forall k t (ns :: [k]) s.
(Dimensions ns, PrimArray t (DataFrame t ns)) =>
DataFrame t ns -> ST s (STDataFrame s t ns)
unsafeThawDataFrame DataFrame t '[m, m]
bdV

      -- remove last beta if m > n
      DataFrame t '[]
bLast <- STDataFrame s t '[Min' n m (CmpNat n m)]
-> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[Min' n m (CmpNat n m)]
betas Int
nm1
      Bool -> ST s () -> ST s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a
abs DataFrame t '[]
bLast DataFrame t '[] -> DataFrame t '[] -> Bool
forall a. Ord a => a -> a -> Bool
> DataFrame t '[]
forall a. Epsilon a => a
M_EPS) (ST s () -> ST s ()) -> ST s () -> ST s ()
forall a b. (a -> b) -> a -> b
$
        STDataFrame s t '[Min' n m (CmpNat n m)]
-> STDataFrame s t '[Min' n m (CmpNat n m)]
-> STDataFrame s t '[m, m]
-> Int
-> ST s ()
forall s t (n :: Nat) (m :: Nat).
(RealFloatExtras t, KnownDim n, KnownDim m, n <= m) =>
STDataFrame s t '[n]
-> STDataFrame s t '[n]
-> STDataFrame s t '[m, m]
-> Int
-> ST s ()
svdGolubKahanZeroCol STDataFrame s t '[Min' n m (CmpNat n m)]
alphas STDataFrame s t '[Min' n m (CmpNat n m)]
betas STDataFrame s t '[m, m]
vPtr Int
nm1

      -- main routine for a bidiagonal matrix
      let maxIter :: Int
maxIter = Int
3Int -> Int -> Int
forall a. Num a => a -> a -> a
*Int
nm -- number of tries
      Bool
withinIters <- STDataFrame s t '[Min' n m (CmpNat n m)]
-> STDataFrame s t '[Min' n m (CmpNat n m)]
-> STDataFrame s t '[n, n]
-> STDataFrame s t '[m, m]
-> Int
-> Int
-> ST s Bool
forall s t (n :: Nat) (m :: Nat) (nm :: Nat).
(IterativeMethod, RealFloatExtras t, KnownDim n, KnownDim m,
 KnownDim nm, nm ~ Min n m) =>
STDataFrame s t '[nm]
-> STDataFrame s t '[nm]
-> STDataFrame s t '[n, n]
-> STDataFrame s t '[m, m]
-> Int
-> Int
-> ST s Bool
svdBidiagonalInplace STDataFrame s t '[Min' n m (CmpNat n m)]
alphas STDataFrame s t '[Min' n m (CmpNat n m)]
betas STDataFrame s t '[n, n]
uPtr STDataFrame s t '[m, m]
vPtr Int
nm Int
maxIter
      Bool -> ST s () -> ST s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
withinIters (ST s () -> ST s ()) -> (String -> ST s ()) -> String -> ST s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> ST s ()
forall a. IterativeMethod => String -> a
tooManyIterations
        (String -> ST s ()) -> String -> ST s ()
forall a b. (a -> b) -> a -> b
$ String
"SVD - Givens rotation sweeps for a bidiagonal matrix ("
           String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
maxIter String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" sweeps max)."

      -- sort singular values
      DataFrame t '[Min' n m (CmpNat n m)]
sUnsorted <- STDataFrame s t '[Min' n m (CmpNat n m)]
-> ST s (DataFrame t '[Min' n m (CmpNat n m)])
forall k t (ns :: [k]) s.
PrimArray t (DataFrame t ns) =>
STDataFrame s t ns -> ST s (DataFrame t ns)
unsafeFreezeDataFrame STDataFrame s t '[Min' n m (CmpNat n m)]
alphas
      let sSorted :: Vector (Tuple '[t, Word]) (Min n m)
          sSorted :: Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
sSorted = (DataFrame (Tuple '[t, Word]) '[]
 -> DataFrame (Tuple '[t, Word]) '[] -> Ordering)
-> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
-> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
forall k t (n :: k) (ns :: [k]).
(SortableDataFrame t (n : ns), SortBy n) =>
(DataFrame t ns -> DataFrame t ns -> Ordering)
-> DataFrame t (n : ns) -> DataFrame t (n : ns)
sortBy (\(S (y
x :! Tuple ys
_)) (S (y
y :! Tuple ys
_)) -> y -> y -> Ordering
forall a. Ord a => a -> a -> Ordering
compare y
y y
y
x)
                  (Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
 -> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m)))
-> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
-> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
forall a b. (a -> b) -> a -> b
$ (Idxs '[Min' n m (CmpNat n m)]
 -> DataFrame t '[] -> DataFrame (Tuple '[t, Word]) '[])
-> DataFrame t '[Min' n m (CmpNat n m)]
-> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]) s (bs' :: [k])
       (asbs' :: [k]).
(SubSpace t as bs asbs, SubSpace s as bs' asbs') =>
(Idxs as -> DataFrame s bs' -> DataFrame t bs)
-> DataFrame s asbs' -> DataFrame t asbs
iwmap @_ @_ @'[] (\(Idx Word
i :* TypedList Idx ys
U) (S x) -> Tuple '[t, Word] -> DataFrame (Tuple '[t, Word]) '[]
forall t. t -> DataFrame t '[]
S (t -> t
forall a. Num a => a -> a
abs t
x t -> Tuple '[Word] -> Tuple '[t, Word]
forall (xs :: [*]) y (ys :: [*]).
(xs ~ (y : ys)) =>
y -> Tuple ys -> Tuple xs
:! Word
i Word -> Tuple '[] -> Tuple '[Word]
forall (xs :: [*]) y (ys :: [*]).
(xs ~ (y : ys)) =>
y -> Tuple ys -> Tuple xs
:! Tuple '[]
forall k (f :: k -> *) (xs :: [k]). (xs ~ '[]) => TypedList f xs
U) ) DataFrame t '[Min' n m (CmpNat n m)]
sUnsorted
          svdS :: DataFrame t '[Min' n m (CmpNat n m)]
svdS = (DataFrame (Tuple '[t, Word]) '[] -> DataFrame t '[])
-> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
-> DataFrame t '[Min' n m (CmpNat n m)]
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]) s (bs' :: [k])
       (asbs' :: [k]).
(SubSpace t as bs asbs, SubSpace s as bs' asbs') =>
(DataFrame s bs' -> DataFrame t bs)
-> DataFrame s asbs' -> DataFrame t asbs
ewmap @t @_ @'[] (\(S (y
x :! Tuple ys
_)) -> y -> DataFrame y '[]
forall t. t -> DataFrame t '[]
S y
x) Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
sSorted
          perm :: DataFrame Word '[Min' n m (CmpNat n m)]
perm = (DataFrame (Tuple '[t, Word]) '[] -> DataFrame Word '[])
-> Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
-> DataFrame Word '[Min' n m (CmpNat n m)]
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]) s (bs' :: [k])
       (asbs' :: [k]).
(SubSpace t as bs asbs, SubSpace s as bs' asbs') =>
(DataFrame s bs' -> DataFrame t bs)
-> DataFrame s asbs' -> DataFrame t asbs
ewmap @Word @_ @'[] (\(S (y
_ :! y
i :! TypedList Id ys
U)) -> y -> DataFrame y '[]
forall t. t -> DataFrame t '[]
S y
i) Vector (Tuple '[t, Word]) (Min' n m (CmpNat n m))
sSorted
          pCount :: Word
pCount =
             if Int
nm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
2
             then Word
0 :: Word
             else (Word -> (Word, Word) -> Word) -> Word -> [(Word, Word)] -> Word
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\Word
s (Word
i, Word
j) -> if DataFrame Word '[Min' n m (CmpNat n m)]
permDataFrame Word '[Min' n m (CmpNat n m)]
-> Word -> DataFrame Word '[]
forall k k (t :: k) (d :: k) (ds :: [k]).
IndexFrame t d ds =>
DataFrame t (d : ds) -> Word -> DataFrame t ds
!Word
i DataFrame Word '[] -> DataFrame Word '[] -> Bool
forall a. Ord a => a -> a -> Bool
> DataFrame Word '[Min' n m (CmpNat n m)]
permDataFrame Word '[Min' n m (CmpNat n m)]
-> Word -> DataFrame Word '[]
forall k k (t :: k) (d :: k) (ds :: [k]).
IndexFrame t d ds =>
DataFrame t (d : ds) -> Word -> DataFrame t ds
!Word
j then Word -> Word
forall a. Enum a => a -> a
succ Word
s else Word
s)
                        Word
0 [(Word
i, Word
j) | Word
i <- [Word
0..Word
nm2w], Word
j <- [Word
iWord -> Word -> Word
forall a. Num a => a -> a -> a
+Word
1..Word
nm2wWord -> Word -> Word
forall a. Num a => a -> a -> a
+Word
1]]
          pPositive :: Bool
pPositive = Word -> Bool
forall a. Integral a => a -> Bool
even Word
pCount

      -- alphas and svdS are now out of sync, but that is not a problem

      -- make sure det U == 1
      Bool -> ST s () -> ST s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ((DataFrame t '[]
bdUDet DataFrame t '[] -> DataFrame t '[] -> Bool
forall a. Ord a => a -> a -> Bool
< DataFrame t '[]
0) Bool -> Bool -> Bool
forall a. Eq a => a -> a -> Bool
== Bool
pPositive) (ST s () -> ST s ()) -> ST s () -> ST s ()
forall a b. (a -> b) -> a -> b
$ do
        STDataFrame s t '[Min' n m (CmpNat n m)]
-> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[Min' n m (CmpNat n m)]
alphas Int
0 ST s (DataFrame t '[]) -> (DataFrame t '[] -> ST s ()) -> ST s ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= STDataFrame s t '[Min' n m (CmpNat n m)]
-> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[Min' n m (CmpNat n m)]
alphas Int
0 (DataFrame t '[] -> ST s ())
-> (DataFrame t '[] -> DataFrame t '[])
-> DataFrame t '[]
-> ST s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a
negate
        [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Int
0..Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ \Int
i ->
          STDataFrame s t '[n, n] -> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[n, n]
uPtr (Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
n) ST s (DataFrame t '[]) -> (DataFrame t '[] -> ST s ()) -> ST s ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= STDataFrame s t '[n, n] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[n, n]
uPtr (Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
n) (DataFrame t '[] -> ST s ())
-> (DataFrame t '[] -> DataFrame t '[])
-> DataFrame t '[]
-> ST s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a
negate

      -- negate negative singular values
      [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Int
0..Int
nm1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ \Int
i -> do
        DataFrame t '[]
s <- STDataFrame s t '[Min' n m (CmpNat n m)]
-> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[Min' n m (CmpNat n m)]
alphas Int
i
        Bool -> ST s () -> ST s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (DataFrame t '[]
s DataFrame t '[] -> DataFrame t '[] -> Bool
forall a. Ord a => a -> a -> Bool
< DataFrame t '[]
0) (ST s () -> ST s ()) -> ST s () -> ST s ()
forall a b. (a -> b) -> a -> b
$ do
          STDataFrame s t '[Min' n m (CmpNat n m)]
-> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[Min' n m (CmpNat n m)]
alphas Int
i (DataFrame t '[] -> ST s ()) -> DataFrame t '[] -> ST s ()
forall a b. (a -> b) -> a -> b
$ DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a
negate DataFrame t '[]
s
          [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Int
0..Int
m Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ \Int
j ->
            STDataFrame s t '[m, m] -> Int -> ST s (DataFrame t '[])
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[m, m]
vPtr (Int
jInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
m Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i)
              ST s (DataFrame t '[]) -> (DataFrame t '[] -> ST s ()) -> ST s ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= STDataFrame s t '[m, m] -> Int -> DataFrame t '[] -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[m, m]
vPtr (Int
jInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
m Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i) (DataFrame t '[] -> ST s ())
-> (DataFrame t '[] -> DataFrame t '[])
-> DataFrame t '[]
-> ST s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DataFrame t '[] -> DataFrame t '[]
forall a. Num a => a -> a
negate

      -- apply permutations if necessary
      if Word
pCount Word -> Word -> Bool
forall a. Eq a => a -> a -> Bool
== Word
0
      then do
        DataFrame t '[n, n]
svdU <- STDataFrame s t '[n, n] -> ST s (DataFrame t '[n, n])
forall k t (ns :: [k]) s.
PrimArray t (DataFrame t ns) =>
STDataFrame s t ns -> ST s (DataFrame t ns)
unsafeFreezeDataFrame STDataFrame s t '[n, n]
uPtr
        DataFrame t '[m, m]
svdV <- STDataFrame s t '[m, m] -> ST s (DataFrame t '[m, m])
forall k t (ns :: [k]) s.
PrimArray t (DataFrame t ns) =>
STDataFrame s t ns -> ST s (DataFrame t ns)
unsafeFreezeDataFrame STDataFrame s t '[m, m]
vPtr
        SVD t n m -> ST s (SVD t n m)
forall (m :: * -> *) a. Monad m => a -> m a
return SVD :: forall t (n :: Nat) (m :: Nat).
Matrix t n n -> Vector t (Min n m) -> Matrix t m m -> SVD t n m
SVD {DataFrame t '[n, n]
DataFrame t '[m, m]
DataFrame t '[Min' n m (CmpNat n m)]
svdV :: DataFrame t '[m, m]
svdU :: DataFrame t '[n, n]
svdS :: DataFrame t '[Min' n m (CmpNat n m)]
svdV :: DataFrame t '[m, m]
svdS :: DataFrame t '[Min' n m (CmpNat n m)]
svdU :: DataFrame t '[n, n]
..}
      else do
        DataFrame t '[n, n]
svdU' <- STDataFrame s t '[n, n] -> ST s (DataFrame t '[n, n])
forall k t (ns :: [k]) s.
PrimArray t (DataFrame t ns) =>
STDataFrame s t ns -> ST s (DataFrame t ns)
unsafeFreezeDataFrame STDataFrame s t '[n, n]
uPtr
        DataFrame t '[m, m]
svdV' <- STDataFrame s t '[m, m] -> ST s (DataFrame t '[m, m])
forall k t (ns :: [k]) s.
PrimArray t (DataFrame t ns) =>
STDataFrame s t ns -> ST s (DataFrame t ns)
unsafeFreezeDataFrame STDataFrame s t '[m, m]
vPtr
        let svdU :: DataFrame t '[n, n]
svdU = forall (asbs :: [Nat]).
(SubSpace t '[n, n] '[] asbs, Dimensions '[n, n]) =>
(Idxs '[n, n] -> DataFrame t '[]) -> DataFrame t asbs
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]).
(SubSpace t as bs asbs, Dimensions as) =>
(Idxs as -> DataFrame t bs) -> DataFrame t asbs
iwgen @_ @_ @'[] ((Idxs '[n, n] -> DataFrame t '[]) -> DataFrame t '[n, n])
-> (Idxs '[n, n] -> DataFrame t '[]) -> DataFrame t '[n, n]
forall a b. (a -> b) -> a -> b
$ \(Idx y
i :* Idx Word
j :* TypedList Idx ys
U) ->
              if Word
j Word -> Word -> Bool
forall a. Ord a => a -> a -> Bool
>= Dim (Min' n m (CmpNat n m)) -> Word
forall k (x :: k). Dim x -> Word
dimVal Dim (Min' n m (CmpNat n m))
dnm
              then Idxs '[n, n] -> DataFrame t '[n, n] -> DataFrame t '[]
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]).
SubSpace t as bs asbs =>
Idxs as -> DataFrame t asbs -> DataFrame t bs
index (Idx y
i Idx y -> TypedList Idx '[n] -> Idxs '[n, n]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* Word -> Idx n
forall k (d :: k). BoundedDim d => Word -> Idx d
Idx Word
j Idx n -> TypedList Idx '[] -> TypedList Idx '[n]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* TypedList Idx '[]
forall k (f :: k -> *) (xs :: [k]). (xs ~ '[]) => TypedList f xs
U) DataFrame t '[n, n]
svdU'
              else Idxs '[n, n] -> DataFrame t '[n, n] -> DataFrame t '[]
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]).
SubSpace t as bs asbs =>
Idxs as -> DataFrame t asbs -> DataFrame t bs
index (Idx y
i Idx y -> TypedList Idx '[n] -> Idxs '[n, n]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* Word -> Idx n
forall k (d :: k). BoundedDim d => Word -> Idx d
Idx (DataFrame Word '[] -> Word
forall t. DataFrame t '[] -> t
unScalar (DataFrame Word '[] -> Word) -> DataFrame Word '[] -> Word
forall a b. (a -> b) -> a -> b
$ DataFrame Word '[Min' n m (CmpNat n m)]
permDataFrame Word '[Min' n m (CmpNat n m)]
-> Word -> DataFrame Word '[]
forall k k (t :: k) (d :: k) (ds :: [k]).
IndexFrame t d ds =>
DataFrame t (d : ds) -> Word -> DataFrame t ds
!Word
j) Idx n -> TypedList Idx '[] -> TypedList Idx '[n]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* TypedList Idx '[]
forall k (f :: k -> *) (xs :: [k]). (xs ~ '[]) => TypedList f xs
U) DataFrame t '[n, n]
svdU'
            svdV :: DataFrame t '[m, m]
svdV = forall (asbs :: [Nat]).
(SubSpace t '[m, m] '[] asbs, Dimensions '[m, m]) =>
(Idxs '[m, m] -> DataFrame t '[]) -> DataFrame t asbs
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]).
(SubSpace t as bs asbs, Dimensions as) =>
(Idxs as -> DataFrame t bs) -> DataFrame t asbs
iwgen @_ @_ @'[] ((Idxs '[m, m] -> DataFrame t '[]) -> DataFrame t '[m, m])
-> (Idxs '[m, m] -> DataFrame t '[]) -> DataFrame t '[m, m]
forall a b. (a -> b) -> a -> b
$ \(Idx y
i :* Idx Word
j :* TypedList Idx ys
U) ->
              if Word
j Word -> Word -> Bool
forall a. Ord a => a -> a -> Bool
>= Dim (Min' n m (CmpNat n m)) -> Word
forall k (x :: k). Dim x -> Word
dimVal Dim (Min' n m (CmpNat n m))
dnm
              then Idxs '[m, m] -> DataFrame t '[m, m] -> DataFrame t '[]
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]).
SubSpace t as bs asbs =>
Idxs as -> DataFrame t asbs -> DataFrame t bs
index (Idx y
i Idx y -> TypedList Idx '[m] -> Idxs '[m, m]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* Word -> Idx m
forall k (d :: k). BoundedDim d => Word -> Idx d
Idx Word
j Idx m -> TypedList Idx '[] -> TypedList Idx '[m]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* TypedList Idx '[]
forall k (f :: k -> *) (xs :: [k]). (xs ~ '[]) => TypedList f xs
U) DataFrame t '[m, m]
svdV'
              else Idxs '[m, m] -> DataFrame t '[m, m] -> DataFrame t '[]
forall k t (as :: [k]) (bs :: [k]) (asbs :: [k]).
SubSpace t as bs asbs =>
Idxs as -> DataFrame t asbs -> DataFrame t bs
index (Idx y
i Idx y -> TypedList Idx '[m] -> Idxs '[m, m]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* Word -> Idx m
forall k (d :: k). BoundedDim d => Word -> Idx d
Idx (DataFrame Word '[] -> Word
forall t. DataFrame t '[] -> t
unScalar (DataFrame Word '[] -> Word) -> DataFrame Word '[] -> Word
forall a b. (a -> b) -> a -> b
$ DataFrame Word '[Min' n m (CmpNat n m)]
permDataFrame Word '[Min' n m (CmpNat n m)]
-> Word -> DataFrame Word '[]
forall k k (t :: k) (d :: k) (ds :: [k]).
IndexFrame t d ds =>
DataFrame t (d : ds) -> Word -> DataFrame t ds
!Word
j) Idx m -> TypedList Idx '[] -> TypedList Idx '[m]
forall k (f :: k -> *) (xs :: [k]) (y :: k) (ys :: [k]).
(xs ~ (y : ys)) =>
f y -> TypedList f ys -> TypedList f xs
:* TypedList Idx '[]
forall k (f :: k -> *) (xs :: [k]). (xs ~ '[]) => TypedList f xs
U) DataFrame t '[m, m]
svdV'
        SVD t n m -> ST s (SVD t n m)
forall (m :: * -> *) a. Monad m => a -> m a
return SVD :: forall t (n :: Nat) (m :: Nat).
Matrix t n n -> Vector t (Min n m) -> Matrix t m m -> SVD t n m
SVD {DataFrame t '[n, n]
DataFrame t '[m, m]
DataFrame t '[Min' n m (CmpNat n m)]
svdV :: DataFrame t '[m, m]
svdU :: DataFrame t '[n, n]
svdS :: DataFrame t '[Min' n m (CmpNat n m)]
svdV :: DataFrame t '[m, m]
svdS :: DataFrame t '[Min' n m (CmpNat n m)]
svdU :: DataFrame t '[n, n]
..}
      where
        n :: Int
n = Word -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Word -> Int) -> Word -> Int
forall a b. (a -> b) -> a -> b
$ Dim n -> Word
forall k (x :: k). Dim x -> Word
dimVal Dim n
dn :: Int
        m :: Int
m = Word -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Word -> Int) -> Word -> Int
forall a b. (a -> b) -> a -> b
$ Dim m -> Word
forall k (x :: k). Dim x -> Word
dimVal Dim m
dm :: Int
        dn :: Dim n
dn = KnownDim n => Dim n
forall k (n :: k). KnownDim n => Dim n
dim @n
        dm :: Dim m
dm = KnownDim m => Dim m
forall k (n :: k). KnownDim n => Dim n
dim @m
        dnm :: Dim (Min' n m (CmpNat n m))
dnm = Dim n -> Dim m -> Dim (Min' n m (CmpNat n m))
forall (n :: Nat) (m :: Nat). Dim n -> Dim m -> Dim (Min n m)
minDim Dim n
dn Dim m
dm
        nm1 :: Int
nm1 = Int
nm Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1
        nm :: Int
nm = Word -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Dim (Min' n m (CmpNat n m)) -> Word
forall k (x :: k). Dim x -> Word
dimVal Dim (Min' n m (CmpNat n m))
dnm) :: Int
        nm2w :: Word
nm2w = Int -> Word
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Int -> Int
forall a. Ord a => a -> a -> a
max (Int
nm Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
2) Int
0) :: Word
        -- compute the bidiagonal form b first, solve svd for b.
        BiDiag {DataFrame t '[n, n]
DataFrame t '[m, m]
DataFrame t '[Min' n m (CmpNat n m)]
DataFrame t '[]
bdVDet :: forall t (n :: Nat) (m :: Nat). BiDiag t n m -> Scalar t
bdV :: forall t (n :: Nat) (m :: Nat). BiDiag t n m -> Matrix t m m
bdBeta :: forall t (n :: Nat) (m :: Nat). BiDiag t n m -> Vector t (Min n m)
bdAlpha :: forall t (n :: Nat) (m :: Nat). BiDiag t n m -> Vector t (Min n m)
bdUDet :: forall t (n :: Nat) (m :: Nat). BiDiag t n m -> Scalar t
bdU :: forall t (n :: Nat) (m :: Nat). BiDiag t n m -> Matrix t n n
bdVDet :: DataFrame t '[]
bdUDet :: DataFrame t '[]
bdV :: DataFrame t '[m, m]
bdU :: DataFrame t '[n, n]
bdBeta :: DataFrame t '[Min' n m (CmpNat n m)]
bdAlpha :: DataFrame t '[Min' n m (CmpNat n m)]
..} = Matrix t n m -> BiDiag t n m
forall t (n :: Nat) (m :: Nat).
(PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m) =>
Matrix t n m -> BiDiag t n m
bidiagonalHouseholder Matrix t n m
a



{- Compute the SVD of a square bidiagonal matrix in place:
   \( B = U S V^\intercal \)

@
  B = | a1 b1 0     ... 0 |
      | 0  a2 b2 0  ... 0 |
      | 0  0 a3 b3  ... 0 |
      | ................. |
      | 0  0  ... an1 bn1 |
      | 0  0  ...  0  an  | bn (this extra entry exists only when n < m)
@
 -}
svdBidiagonalInplace ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (nm :: Nat)
     . ( IterativeMethod, RealFloatExtras t
       , KnownDim n, KnownDim m, KnownDim nm, nm ~ Min n m)
    => STDataFrame s t '[nm]
       -- ^ the main diagonal of \(B\); on return it holds the singular values
       --   (the caller still has to sort them and fix their signs).
    -> STDataFrame s t '[nm]
       -- ^ the first upper diagonal of \(B\); driven to zero by the algorithm.
    -> STDataFrame s t '[n,n]
       -- ^ \(U\), updated in place.
    -> STDataFrame s t '[m,m]
       -- ^ \(V\), updated in place.
    -> Int
       -- ^ \(0 \le q \le nm\): size of the leading block that may still be
       --   non-diagonal (everything after it is already diagonal).
    -> Int
       -- ^ number of iterations (sweeps) still allowed.
    -> ST s Bool
       -- ^ 'True' iff \(B\) became diagonal within the iteration budget.
svdBidiagonalInplace _ _ _ _ 0 _ = pure True -- nothing left to reduce
svdBidiagonalInplace _ _ _ _ 1 _ = pure True -- a 1x1 block is always diagonal
svdBidiagonalInplace aPtr bPtr uPtr vPtr q' iter = do
    -- Bring (nm <= n, nm <= m) into scope; the Golub-Kahan helpers need it.
    Dict <- pure $ minIsSmaller (dim @n) (dim @m)
    -- Deflate first: this flushes negligible superdiagonal entries to zero
    -- and tells us whether anything is left to do. Doing this BEFORE looking
    -- at the iteration budget means a matrix that converged on the very last
    -- allowed sweep is reported as a success rather than a failure.
    (p, q) <- findCounters q'
    if q == 0
      then pure True            -- the whole matrix is diagonal: converged
      else if iter <= 0
        then pure False         -- budget exhausted before convergence
        else do
          findZeroDiagonal p q >>= \case
            -- the last diagonal entry of the active block is zero:
            -- chase its superdiagonal neighbour out with column rotations
            Just k
              | k == q - 1 -> svdGolubKahanZeroCol aPtr bPtr vPtr (k - 1)
            -- an inner diagonal entry is zero: clear its row instead
              | otherwise  -> svdGolubKahanZeroRow aPtr bPtr uPtr k
            -- no zero on the diagonal: regular implicit-shift sweep on B22
            Nothing        -> svdGolubKahanStep aPtr bPtr uPtr vPtr p q
          svdBidiagonalInplace aPtr bPtr uPtr vPtr q (iter - 1)
  where
    -- Check whether off-diagonal elements are close to zero and nullify them
    -- along the way if they are small enough.
    -- Then find indices p and q that satisfy the condition of Algorithm 8.6.2
    -- on p. 492 of "Matrix Computations" (4th edition), except that they are
    -- inverted here:
    --   p -- the starting index of B22 (last submatrix with a non-zero
    --        superdiagonal);
    --   q -- the starting index of B33 (diagonal submatrix).
    --
    -- That is, p and q determine the position and the size of the next work
    -- piece. (0, 0) means the whole block is already diagonal.
    findCounters :: Int -> ST s (Int, Int)
    findCounters = goQ
      where
        -- Is the superdiagonal entry b[k-1] (coupling a[k-1] and a[k])
        -- negligible, i.e. |b| <= eps * max (|a[k-1]| + |a[k]|) 1 ?
        -- A negligible entry is overwritten with an exact zero.
        checkEps :: Int -> ST s Bool
        checkEps k = do
          b <- abs <$> readDataFrameOff bPtr (k - 1)
          if b == 0
            then pure True
            else do
              a1 <- abs <$> readDataFrameOff aPtr (k - 1)
              a2 <- abs <$> readDataFrameOff aPtr k
              if b <= M_EPS * max (a1 + a2) 1
                then True <$ writeDataFrameOff bPtr (k - 1) 0
                else pure False

        -- Walk backwards from the end of a k-sized block, skipping trailing
        -- negligible superdiagonal entries; the first non-negligible one
        -- fixes q, then goP finds p.
        goQ :: Int -> ST s (Int, Int)
        goQ 0 = pure (0, 0) -- guard against calling with q == 0
        goQ 1 = pure (0, 0) -- 1x1 matrix is always diagonal
        goQ k = checkEps (k - 1) >>= \case
          True  -> goQ (k - 1)
          False -> flip (,) k <$> goP (k - 2)

        -- Largest j <= k such that j == 0 or b[j-1] is negligible.
        goP :: Int -> ST s Int
        goP 0 = pure 0
        goP k = checkEps k >>= \case
          True  -> pure k
          False -> goP (k - 1)

    -- For indices p < q, find the biggest index k with p <= k < q such that
    -- a[k] == 0. An entry with |a[k]| <= eps is treated as zero and is
    -- overwritten with an exact zero.
    findZeroDiagonal :: Int -> Int -> ST s (Maybe Int)
    findZeroDiagonal p q
      | k < p     = pure Nothing
      | otherwise = do
        ak <- readDataFrameOff aPtr k
        if ak == 0
          then pure (Just k)
          else if abs ak <= M_EPS
            then Just k <$ writeDataFrameOff aPtr k 0
            else findZeroDiagonal p k
      where
        k = q - 1


-- | Apply a series of column transformations to make b[k] (and the whole
--   column k+1) zero (page 491, 1st paragraph) when a[k+1] == 0.
--   To make this element zero, I apply a series of Givens transforms on columns
--   (multiply on the right), chasing the fill-in element from row k up to row 0.
--
--   Notation: @a[i] = B[i,i]@, @b[i] = B[i,i+1]@.
--
--   Prerequisites:
--     * a[k+1] == 0 (not checked: a[k+1] is never read here)
--     * 0 <= k < min n (m-1) (checked: 'error' otherwise)
--   Invariants:
--     * matrix \(B :: n \times m \) is bidiagonal, represented by two diagonals;
--     * matrix V is orthogonal
--     * \( A = B V^\intercal \), where \(A :: n \times m\) is an implicit
--       original matrix (B and V are multiplied on the right by the same
--       rotations, so this product does not change)
--   Results:
--     * Same bidiagonal matrix with b[k] == 0; i.e. (k+1)-th column is zero.
--     * matrix V is updated (multiplied on the right)
--
--   NB: All changes are made inplace.
--
svdGolubKahanZeroCol ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (RealFloatExtras t, KnownDim n, KnownDim m, n <= m)
    => STDataFrame s t '[n] -- ^ the main diagonal of \(B\)
    -> STDataFrame s t '[n] -- ^ first upper diagonal of \(B\)
    -> STDataFrame s t '[m,m] -- ^ \(V\)
    -> Int -- ^ 0 <= k < min n (m-1)
    -> ST s ()
svdGolubKahanZeroCol aPtr bPtr vPtr k
  | k < 0 || k >= lim = error $ unwords
      [ "svdGolubKahanZeroCol: k =", show k
      , "is outside of a valid range 0 <= k <", show lim]
    -- this trick is to convince GHC that constraint (n <= m) is not redundant
  | Dict <- Dict @(n <= m) = do
    b <- readDataFrameOff bPtr k
    writeDataFrameOff bPtr k 0
    -- rotate column pairs (k, k+1), (k-1, k+1), ..., (0, k+1);
    -- each step hands the new fill-in B[i-1,k+1] to the next one.
    foldM_ goGivens b [k, k-1 .. 0]
  where
    n :: Int
    n = fromIntegral $ dimVal' @n
    m :: Int
    m = fromIntegral $ dimVal' @m
    -- b[k] must exist (k < n) and column k+1 must exist in V (k+1 < m)
    lim :: Int
    lim = min n (m-1)
    -- Eliminate @b = B[i,k+1]@ using the diagonal element @B[i,i]@;
    -- returns the fill-in created at @B[i-1,k+1]@ (0 when i == 0).
    goGivens :: Scalar t -> Int -> ST s (Scalar t)
    goGivens 0 _ = return 0 -- non-diagonal element is nullified prematurely
    goGivens b i = do
      ai <- readDataFrameOff aPtr i
      let rab = recip $ hypot b ai -- b /= 0 here, so hypot b ai > 0
          c   = ai * rab
          s   = b  * rab
      updateGivensMat vPtr i (k+1) c s
      writeDataFrameOff aPtr i $ ai*c + b*s -- B[i,i]
      if i == 0
      then return 0
      else do
        bi1 <- readDataFrameOff bPtr (i - 1) -- B[i-1,i]
        writeDataFrameOff bPtr (i - 1) $ bi1 * c
        return $ negate (bi1 * s)           -- new B[i-1,k+1]

-- | Apply a series of row transformations to make b[k] (and the whole row k)
--   zero (page 490, last paragraph) when a[k] == 0.
--   To make this element zero, I apply a series of Givens transforms on rows
--   (multiply on the left), chasing the fill-in element from column k+1
--   to the right.
--
--   Notation: @a[i] = B[i,i]@, @b[i] = B[i,i+1]@.
--
--   Prerequisites:
--     * a[k] == 0 (not checked: a[k] is never read here)
--     * 0 <= k < n - 1 (checked: 'error' otherwise)
--   Invariants:
--     * matrix \(B :: m \times n \) is bidiagonal, represented by two diagonals;
--     * matrix U is orthogonal
--     * \( A = U B \), where \(A :: m \times n \) is an implicit original
--       matrix (B is multiplied on the left by a rotation \(G\) and U on the
--       right by \(G^\intercal\), so this product does not change)
--   Results:
--     * Same bidiagonal matrix with b[k] == 0; i.e. k-th row is zero.
--     * matrix U is updated (multiplied on the right)
--
--   NB: All changes are made inplace.
--
svdGolubKahanZeroRow ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (RealFloatExtras t, KnownDim n, KnownDim m, n <= m)
    => STDataFrame s t '[n] -- ^ the main diagonal of B
    -> STDataFrame s t '[n] -- ^ first upper diagonal of B
    -> STDataFrame s t '[m,m] -- ^ U
    -> Int -- ^ 0 <= k < n - 1
    -> ST s ()
svdGolubKahanZeroRow aPtr bPtr uPtr k
  | k < 0 || k >= n1 = error $ unwords
      [ "svdGolubKahanZeroRow: k =", show k
      , "is outside of a valid range 0 <= k <", show n1]
    -- this trick is to convince GHC that constraint (n <= m) is not redundant
  | Dict <- Dict @(n <= m) = do
    b <- readDataFrameOff bPtr k
    writeDataFrameOff bPtr k 0
    -- rotate row pairs (k, k+1), (k, k+2), ..., (k, n-1);
    -- each step hands the new fill-in B[k,j+1] to the next one.
    foldM_ goGivens b [k+1 .. n1]
  where
    n :: Int
    n = fromIntegral $ dimVal' @n
    n1 :: Int
    n1 = n - 1
    -- Eliminate @b = B[k,j]@ using the diagonal element @B[j,j]@;
    -- returns the fill-in created at @B[k,j+1]@.
    goGivens :: Scalar t -> Int -> ST s (Scalar t)
    goGivens 0 _ = return 0 -- non-diagonal element is nullified prematurely
    goGivens b j = do
      aj <- readDataFrameOff aPtr j
      bj <- readDataFrameOff bPtr j
      let rab = recip $ hypot b aj -- b /= 0 here, so hypot b aj > 0
          c   = aj * rab
          s   = b  * rab
      -- U is multiplied by the transposed row rotation, hence (negate s)
      updateGivensMat uPtr k j c (negate s)
      writeDataFrameOff aPtr j $ b*s + aj*c -- B[j,j]
      writeDataFrameOff bPtr j $ bj*c       -- B[j,j+1]
      return $ negate (bj * s)              -- new B[k,j+1]

-- | A Golub-Kahan bidiagonal SVD step on an unreduced matrix.
--
--   Works on the part of \(B\) with indices in @[p, q)@
--   (@a[i] = B[i,i]@, @b[i] = B[i,i+1]@):
--
--     1. computes the Wilkinson shift from the trailing 2x2 block of
--        \(B^\intercal B\) restricted to that range;
--     2. chases the resulting bulge down the diagonal with alternating
--        column (right) and row (left) Givens rotations for k = p .. q-2,
--        accumulating the column rotations into V and the row rotations into U.
--
--   A degenerate rotation (both inputs zero) is replaced by the identity
--   rotation rather than producing NaNs in B, U, and V.
--
--   NB: All changes are made inplace.
svdGolubKahanStep ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (nm :: Nat)
     . ( RealFloatExtras t
       , KnownDim n, KnownDim m, KnownDim nm, nm ~ Min n m)
    => STDataFrame s t '[nm] -- ^ the main diagonal of B and then the singular values.
    -> STDataFrame s t '[nm] -- ^ first upper diagonal of B
    -> STDataFrame s t '[n,n] -- ^ U
    -> STDataFrame s t '[m,m] -- ^ V
    -> Int -- ^ p : 0 <= p < q <= nm; p <= q - 2
    -> Int -- ^ q : 0 <= p < q <= nm; p <= q - 2
    -> ST s ()
svdGolubKahanStep aPtr bPtr uPtr vPtr p q
  | p > q - 2 || p < 0 || q > nm
    = error $ unwords
        [ "svdGolubKahanStep: p =", show p, "and q =", show q
        , "do not satisfy p <= q - 2 or 0 <= p < q <=", show nm]
    -- this trick is to convince GHC that constraint (nm ~ Min n m) is not redundant
  | Dict <- Dict @(nm ~ Min n m) = do
    (y, z) <- getWilkinsonShiftYZ
    goGivens2 y z p
  where
    nm :: Int
    nm = fromIntegral $ dimVal' @nm

    -- Cosine and sine of the rotation mapping (y, z) to (hypot y z, 0).
    -- Falls back to the identity rotation when y == z == 0, where
    -- @recip (hypot y z)@ would be infinite and c, s would become NaN.
    givens :: Scalar t -> Scalar t -> (Scalar t, Scalar t)
    givens y z
      | r == 0    = (1, 0)
      | otherwise = (y * rr, z * rr)
      where
        r  = hypot y z
        rr = recip r

    -- get initial values for one recursion sweep:
    -- first column of (B^T B - mu I) restricted to [p, q), mu = Wilkinson shift.
    -- Note, input must satisfy: q >= p+2
    getWilkinsonShiftYZ :: ST s (Scalar t, Scalar t)
    getWilkinsonShiftYZ = do
      a1 <- readDataFrameOff aPtr p
      b1 <- readDataFrameOff bPtr p
      am <- readDataFrameOff aPtr (q-2)
      an <- readDataFrameOff aPtr (q-1)
      -- B[q-3,q-2] lies inside the range only when q >= p + 3
      bm <- if q >= p + 3
            then readDataFrameOff bPtr (q-3)
            else pure 0
      bn <- readDataFrameOff bPtr (q-2)
      let t11 = a1*a1
          t12 = a1*b1
          tmm = am*am + bm*bm
          tnn = an*an + bn*bn
          tnm = am*bn
          d   = 0.5*(tmm - tnn)
          mu  = tnn + d - negateUnless (d >= 0) (hypot d tnm)
      return (t11 - mu, t12)

    -- yv = b[k-1]; zv = B[k-1,k+1] -- to be eliminated by 1st Givens r
    -- yu = a[k];   zu = B[k+1,k]   -- to be eliminated by 2nd Givens r
    goGivens2 :: Scalar t -> Scalar t -> Int -> ST s ()
    goGivens2 yv zv k = do
      a1 <- readDataFrameOff aPtr k     -- B[k,k]
      a2 <- readDataFrameOff aPtr (k+1) -- B[k+1,k+1]
      b1 <- readDataFrameOff bPtr k     -- B[k,k+1]
      -- right rotation of columns k, k+1: eliminates zv, creates zu
      let (cv, sv) = givens yv zv
          a1'  = a1*cv + b1*sv  -- B[k,k] == yu
          a2'  = a2*cv          -- B[k+1,k+1]
          b0'  = yv*cv + zv*sv  -- B[k-1,k]
          b1'  = b1*cv - a1*sv  -- B[k,k+1]
          yu   = a1'            -- B[k,k]
          zu   = a2*sv          -- B[k+1,k]
      -- left rotation of rows k, k+1: eliminates zu
          (cu, su) = givens yu zu
          a1'' = yu *cu + zu *su
          a2'' = a2'*cu - b1'*su
          b1'' = b1'*cu + a2'*su
      updateGivensMat vPtr k (k+1) cv sv
      updateGivensMat uPtr k (k+1) cu su

      -- b[p-1] is outside of the active range and must not be touched
      when (k > p) $ writeDataFrameOff bPtr (k-1) b0'
      writeDataFrameOff bPtr k b1''
      writeDataFrameOff aPtr k a1''
      writeDataFrameOff aPtr (k+1) a2''
      when (k < q - 2) $ do
        b2 <- readDataFrameOff bPtr (k+1) -- B[k+1,k+2]
        let b2'' = b2*cu
            zvn  = b2*su                   -- new bulge at B[k,k+2]
        writeDataFrameOff bPtr (k+1) b2''
        goGivens2 b1'' zvn (k+1)

-- | Multiply a square transformation matrix in place, on the right, by a
--   Givens rotation acting on columns @i@ and @j@.
--
--   Offset @n*r + x@ is treated as element @[r,x]@. For every row @r@:
--
--   > P[r,i] := P[r,i]*c + P[r,j]*s
--   > P[r,j] := P[r,j]*c - P[r,i]*s
--
--   (both right-hand sides use the values from before the update).
updateGivensMat ::
       forall (s :: Type) (t :: Type) (n :: Nat)
     . (PrimBytes t, Num t, KnownDim n)
    => STDataFrame s t '[n,n]
    -> Int -> Int
    -> Scalar t -> Scalar t -> ST s ()
updateGivensMat mat i j c s = mapM_ rotateRow rowStarts
  where
    dimN :: Int
    dimN = fromIntegral (dimVal' @n)
    -- flat offset of the first element of each row
    rowStarts :: [Int]
    rowStarts = map (dimN *) [0 .. dimN - 1]
    rotateRow :: Int -> ST s ()
    rotateRow rowStart = do
      let offI = rowStart + i
          offJ = rowStart + j
      xi <- readDataFrameOff mat offI
      xj <- readDataFrameOff mat offJ
      writeDataFrameOff mat offI (xi * c + xj * s)
      writeDataFrameOff mat offJ (xj * c - xi * s)


-- | Evidence that @Min n m@ is not greater than either argument.
--   Both comparisons are checked at runtime via 'lessOrEqDim'; failing either
--   one is reported with 'error'.
minIsSmaller :: forall (n :: Nat) (m :: Nat)
              . Dim n -> Dim m -> Dict (Min n m <= n, Min n m <= m)
minIsSmaller dn dm =
    case (lessOrEqDim dMin dn, lessOrEqDim dMin dm) of
      (Just Dict, Just Dict) -> Dict
      _ -> error "minIsSmaller: impossible type-level comparison"
  where
    dMin :: Dim (Min n m)
    dMin = minDim dn dm