{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Numeric.Subroutine.Householder
  ( householderReflectionInplaceR
  , householderReflectionInplaceL
  ) where


import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.Basics
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Scalar.Internal


{- | Run a Householder transformation inplace.

     Given some orthogonal matrix \(P\), some matrix \(R\) and index \((k,l)\),
     reflects \(R\) along some hyperplane, such that all elements of \(R\)
     below index \( (k, l) \) become zeros,
     then updates \(P\) with the inverse of the same transform as \(R\).

     Notes and invariants:

       1. The transformation happens inplace for both matrices \(P\) and \(R\);
          if \( R = P^\intercal A \), then \( R' = P^*PR = P'^\intercal A \), where
           \( P' \) and \( R' \) are the updated versions of the input matrices,
           \( P^* \) and \( A \) are implicit matrices.

       2. All elements below and to the left of index \(k,l\) in \(R\)
          are assumed (and not checked) to be zeros;
          these are not touched by the subroutine to save flops.

       3. A logical starting value for \(P\) is an identity matrix.
          The subroutine can be used for a QR decomposition:
            \( Q = P \).

     Returns @True@ if reflection has been performed, and @False@ if it was not needed.
     This can be used to track the sign of @det P@.
 -}
householderReflectionInplaceL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
    => STDataFrame s t '[n] -- ^ Temporary buffer for a Householder axis vector
    -> STDataFrame s t '[n,n]  -- ^ Current state of \(P^\intercal\)
    -> STDataFrame s t '[n,m]  -- ^ Current state of \(R\)
    -> Idxs '[n,m] -- ^ Pivot element
    -> ST s Bool
-- Thin typed wrapper: the type-level dimensions and the typed pivot index
-- are converted to plain 'Int's, and all the work is delegated to the
-- untyped worker householderReflectionInplaceL'.
householderReflectionInplaceL u p r (Idx i :* Idx j :* U)
    = householderReflectionInplaceL' u p r
      (fromIntegral $ dimVal' @n) -- number of rows of R
      (fromIntegral $ dimVal' @m) -- number of columns of R
      (fromIntegral i)            -- pivot row
      (fromIntegral j)            -- pivot column

{- | Worker for @householderReflectionInplaceL@ operating on plain @Int@
     sizes and indices.

     Elements of \(R\) are addressed by the flat offset @row * m + column@;
     elements of the second argument are addressed by @i * n + j@.
     Only the entries @k .. n-1@ of the temporary buffer are written and read.

     Returns @True@ if the squared norm of the Householder axis exceeds @M_EPS@
     and both matrices have been updated; returns @False@ otherwise
     (in that case the second argument is left untouched).
 -}
householderReflectionInplaceL' ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Epsilon t, Ord t)
    => STDataFrame s t '[n] -- ^ Temporary buffer for a Householder axis vector
    -> STDataFrame s t '[n,n]  -- ^ \(P^\intercal\)
    -> STDataFrame s t '[n,m]  -- ^ \(R\)
    -> Int -- ^ \(n\)
    -> Int -- ^ \(m\)
    -> Int -- ^ \( 0 \leq k < n \)
    -> Int -- ^ \( 0 \leq l < m \)
    -> ST s Bool
householderReflectionInplaceL' uPtr pPtr rPtr n m k l = do
    -- pivot element (k,l) of new R
    alpha <- getAlphaAndUpdateU
    u2 <- getU2
    -- u2 == 0 means the column is already zeroed
    if u2 > M_EPS
    then do
      let c = 2 / u2 -- a mult constant for updating matrices
      -- update R
      updateRl alpha
      -- columns to the right of the pivot column
      forM_ [l+1..m-1] $ updateRi c
      -- update P
      forM_ [0..n-1] $ updatePi c
      return True
    else return False
  where
    n' :: Int
    n' = n - k -- remaining rows
    rOff0 :: Int
    rOff0 = k*m + l -- offset of element (k,l) in matrix R

    -- u = Rk - alpha*ek
    getAlphaAndUpdateU :: ST s (Scalar t)
    getAlphaAndUpdateU = do
      -- Euclidean norm of the pivot column from row k downwards
      -- (stepping the offset by m moves one row down within a column).
      alpha' <- sqrt . fst <$> nTimesM n'
        (\(r, off) -> do
          x <- readDataFrameOff rPtr off
          -- tiny entries are flushed to zero in R itself;
          -- the value read above still contributes to the sum
          when (abs x <= M_EPS) $ writeDataFrameOff rPtr off 0
          return (r + x*x, off + m)
        ) (0, rOff0)
      -- re-read the pivot: it may have just been flushed to zero
      x0 <- readDataFrameOff rPtr rOff0
      -- alpha takes the sign opposite to x0, so that (x0 - alpha)
      -- adds magnitudes instead of cancelling
      let alpha = if x0 >= 0 then negate alpha' else alpha'
      -- update (lower part of) u
      writeDataFrameOff uPtr k (x0 - alpha)
      -- copy the entries below the pivot from R into u;
      -- the guard avoids passing a negative count to nTimesM
      when (n' >= 1) $ void $ nTimesM (n' - 1)
        (\(i, off) -> (i+1, off+m) <$
          (readDataFrameOff rPtr off >>= writeDataFrameOff uPtr i)
        ) (k+1, rOff0+m)
      return alpha

    -- l-th column of R zeroes below pivot
    updateRl :: Scalar t -> ST s ()
    updateRl alpha = do
      -- the pivot itself becomes alpha
      writeDataFrameOff rPtr rOff0 alpha
      -- the guard avoids passing a negative count to nTimesM
      when (n' >= 1) $ void $ nTimesM (n' - 1)
        (\off -> (off+m) <$ writeDataFrameOff rPtr off 0) (rOff0+m)

    -- update i-th column of R
    updateRi :: Scalar t -> Int -> ST s ()
    updateRi c i = do
      -- dot product of u and Ri
      uRi <- fmap fst . flip (nTimesM n') (0, (k, k*m+i)) $ \(r, (j, off)) -> do
        ju  <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr off
        return (r + ju * jiR, (j+1, off+m))
      let c' = c * uRi
      -- update each element
      void $ flip (nTimesM n') (k, k*m+i) $ \(j, off) -> do
        ju  <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr off
        writeDataFrameOff rPtr off $ jiR - c'*ju
        return (j+1, off+m)

    -- update i-th row of P
    updatePi :: Scalar t -> Int -> ST s ()
    updatePi c i = do
      let off0 = i*n -- offset of the first element of the i-th row
      -- dot product of u and Pi
      uPi <- fmap fst . flip (nTimesM n') (0, k) $ \(r, j) -> do
        ju  <- readDataFrameOff uPtr  j
        ijP <- readDataFrameOff pPtr (off0 + j)
        return (r + ju * ijP, j+1)
      let c' = c * uPi
      -- update each element
      forM_ [k..n-1] $ \j -> do
        ju  <- readDataFrameOff uPtr j
        ijP <- readDataFrameOff pPtr (off0 + j)
        writeDataFrameOff pPtr (off0 + j) $ ijP - c'*ju

    -- get module squared of u (for Q = I - 2 u*uT / |u|^2 )
    getU2 :: ST s (Scalar t)
    getU2 = fst <$> nTimesM n'
      (\(r, off) -> (\x -> (r + x*x, off + 1)) <$> readDataFrameOff uPtr off) (0, k)

{- | Run a Householder transformation inplace.

  Similar to `householderReflectionInplaceL`, but works from right to left
   - use to zero elements to the right from the pivot.

     Returns @True@ if reflection has been performed, and @False@ if it was not needed.
     This can be used to track the sign of @det P@.
 -}
householderReflectionInplaceR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
    => STDataFrame s t '[m] -- ^ Temporary buffer for a Householder axis vector
    -> STDataFrame s t '[m,m]  -- ^ Current state of \(P^\intercal\)
    -> STDataFrame s t '[n,m]  -- ^ Current state of \(R\)
    -> Idxs '[n,m] -- ^ Pivot element
    -> ST s Bool
-- Thin typed wrapper: the type-level dimensions and the typed pivot index
-- are converted to plain 'Int's, and all the work is delegated to the
-- untyped worker householderReflectionInplaceR'.
householderReflectionInplaceR u p r (Idx i :* Idx j :* U)
    = householderReflectionInplaceR' u p r
      (fromIntegral $ dimVal' @n) -- number of rows of R
      (fromIntegral $ dimVal' @m) -- number of columns of R
      (fromIntegral i)            -- pivot row
      (fromIntegral j)            -- pivot column

householderReflectionInplaceR' ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Epsilon t, Ord t)
    => STDataFrame s t '[m] -- ^ Temporary buffer for a Householder axis vector
    -> STDataFrame s t '[m,m]  -- ^ \(P^\intercal\)
    -> STDataFrame s t '[n,m]  -- ^ \(R\)
    -> Int -- ^ \(n\)
    -> Int -- ^ \(m\)
    -> Int -- ^ \( 0 \leq k < n \)
    -> Int -- ^ \( 0 \leq l < m \)
    -> ST s Bool
householderReflectionInplaceR' :: STDataFrame s t '[m]
-> STDataFrame s t '[m, m]
-> STDataFrame s t '[n, m]
-> Int
-> Int
-> Int
-> Int
-> ST s Bool
householderReflectionInplaceR' STDataFrame s t '[m]
uPtr STDataFrame s t '[m, m]
pPtr STDataFrame s t '[n, m]
rPtr Int
n Int
m Int
k Int
l = do
    -- pivot element (k,l) of new R
    Scalar t
alpha <- ST s (Scalar t)
getAlphaAndUpdateU
    Scalar t
u2 <- ST s (Scalar t)
getU2
    -- u2 == 0 means the column is already zeroed
    if Scalar t
u2 Scalar t -> Scalar t -> Bool
forall a. Ord a => a -> a -> Bool
> Scalar t
forall a. Epsilon a => a
M_EPS
    then do
      let c :: Scalar t
c = Scalar t
2 Scalar t -> Scalar t -> Scalar t
forall a. Fractional a => a -> a -> a
/ Scalar t
u2 -- a mult constant for updating matrices
      -- update R
      Scalar t -> ST s ()
updateRk Scalar t
alpha
      [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Int
kInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1..Int
nInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ Scalar t -> Int -> ST s ()
updateRi Scalar t
c
      -- update P
      [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Int
0..Int
mInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ Scalar t -> Int -> ST s ()
updatePi Scalar t
c
      Bool -> ST s Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True
    else Bool -> ST s Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False
  where
    m' :: Int
m' = Int
m Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
l -- remaining cols
    rOff0 :: Int
rOff0 = Int
kInt -> Int -> Int
forall a. Num a => a -> a -> a
*Int
m Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
l -- offset of element (k,l) in matrix R

    -- u = Rl - alpha*el
    getAlphaAndUpdateU :: ST s (Scalar t)
    getAlphaAndUpdateU :: ST s (Scalar t)
getAlphaAndUpdateU = do
      Scalar t
alpha' <- Scalar t -> Scalar t
forall a. Floating a => a -> a
sqrt (Scalar t -> Scalar t)
-> ((Scalar t, Int) -> Scalar t) -> (Scalar t, Int) -> Scalar t
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Scalar t, Int) -> Scalar t
forall a b. (a, b) -> a
fst ((Scalar t, Int) -> Scalar t)
-> ST s (Scalar t, Int) -> ST s (Scalar t)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Int
-> ((Scalar t, Int) -> ST s (Scalar t, Int))
-> (Scalar t, Int)
-> ST s (Scalar t, Int)
forall (m :: * -> *) a. Monad m => Int -> (a -> m a) -> a -> m a
nTimesM Int
m'
        (\(Scalar t
r, Int
off) -> do
          Scalar t
x <- STDataFrame s t '[n, m] -> Int -> ST s (Scalar t)
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[n, m]
rPtr Int
off
          Bool -> ST s () -> ST s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Scalar t -> Scalar t
forall a. Num a => a -> a
abs Scalar t
x Scalar t -> Scalar t -> Bool
forall a. Ord a => a -> a -> Bool
<= Scalar t
forall a. Epsilon a => a
M_EPS) (ST s () -> ST s ()) -> ST s () -> ST s ()
forall a b. (a -> b) -> a -> b
$ STDataFrame s t '[n, m] -> Int -> Scalar t -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[n, m]
rPtr Int
off Scalar t
0
          (Scalar t, Int) -> ST s (Scalar t, Int)
forall (m :: * -> *) a. Monad m => a -> m a
return (Scalar t
r Scalar t -> Scalar t -> Scalar t
forall a. Num a => a -> a -> a
+ Scalar t
xScalar t -> Scalar t -> Scalar t
forall a. Num a => a -> a -> a
*Scalar t
x, Int
off Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1)
        ) (Scalar t
0, Int
rOff0)
      Scalar t
x0 <- STDataFrame s t '[n, m] -> Int -> ST s (Scalar t)
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[n, m]
rPtr Int
rOff0
      let alpha :: Scalar t
alpha = if Scalar t
x0 Scalar t -> Scalar t -> Bool
forall a. Ord a => a -> a -> Bool
>= Scalar t
0 then Scalar t -> Scalar t
forall a. Num a => a -> a
negate Scalar t
alpha' else Scalar t
alpha'
      -- update (lower part of) u
      STDataFrame s t '[m] -> Int -> Scalar t -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[m]
uPtr Int
l (Scalar t
x0 Scalar t -> Scalar t -> Scalar t
forall a. Num a => a -> a -> a
- Scalar t
alpha)
      [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Int
1..Int
m'Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ \Int
i ->
        STDataFrame s t '[n, m] -> Int -> ST s (Scalar t)
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> ST s (DataFrame t '[])
readDataFrameOff STDataFrame s t '[n, m]
rPtr (Int
rOff0 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i) ST s (Scalar t) -> (Scalar t -> ST s ()) -> ST s ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= STDataFrame s t '[m] -> Int -> Scalar t -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[m]
uPtr (Int
l Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
i)
      Scalar t -> ST s (Scalar t)
forall (m :: * -> *) a. Monad m => a -> m a
return Scalar t
alpha

    -- k-th row of R zeroes below pivot
    updateRk :: Scalar t -> ST s ()
    updateRk :: Scalar t -> ST s ()
updateRk Scalar t
alpha = do
      STDataFrame s t '[n, m] -> Int -> Scalar t -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[n, m]
rPtr Int
rOff0 Scalar t
alpha
      [Int] -> (Int -> ST s ()) -> ST s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Int
rOff0Int -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1..Int
rOff0Int -> Int -> Int
forall a. Num a => a -> a -> a
+Int
m'Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] ((Int -> ST s ()) -> ST s ()) -> (Int -> ST s ()) -> ST s ()
forall a b. (a -> b) -> a -> b
$ (Int -> Scalar t -> ST s ()) -> Scalar t -> Int -> ST s ()
forall a b c. (a -> b -> c) -> b -> a -> c
flip (STDataFrame s t '[n, m] -> Int -> Scalar t -> ST s ()
forall k t (ns :: [k]) s.
PrimBytes (DataFrame t '[]) =>
STDataFrame s t ns -> Int -> DataFrame t '[] -> ST s ()
writeDataFrameOff STDataFrame s t '[n, m]
rPtr) Scalar t
0

    -- Apply the reflection to the i-th row of R:
    --   Ri := Ri - c * (u . Ri) * u,
    -- reading and writing only the columns from l onwards.
    updateRi :: Scalar t -> Int -> ST s ()
    updateRi c i = do
        -- dot product of u and Ri: m' terms, starting at column l
        uRi <- fst <$> nTimesM m' dotStep (0, l)
        let scale = c * uRi
        -- subtract the scaled axis vector from the row, element by element
        forM_ [l .. m - 1] $ \col -> do
          uj  <- readDataFrameOff uPtr col
          rij <- readDataFrameOff rPtr (rowOff + col)
          writeDataFrameOff rPtr (rowOff + col) (rij - scale * uj)
      where
        -- flat offset of the first element of row i in R
        rowOff = i * m
        -- one step of the dot product: add u[col] * R[i, col], move to the next column
        dotStep (acc, col) = do
          uj  <- readDataFrameOff uPtr col
          rij <- readDataFrameOff rPtr (rowOff + col)
          pure (acc + uj * rij, col + 1)

    -- Apply the reflection to the i-th row of P:
    --   Pi := Pi - c * (u . Pi) * u,
    -- reading and writing only the columns from l onwards.
    updatePi :: Scalar t -> Int -> ST s ()
    updatePi c i = do
        -- dot product of u and Pi: m' terms, starting at column l
        uPi <- fst <$> nTimesM m' dotStep (0, l)
        let scale = c * uPi
        -- subtract the scaled axis vector from the row, element by element
        forM_ [l .. m - 1] $ \col -> do
          uj  <- readDataFrameOff uPtr col
          pij <- readDataFrameOff pPtr (rowOff + col)
          writeDataFrameOff pPtr (rowOff + col) (pij - scale * uj)
      where
        -- flat offset of the first element of row i in P
        rowOff = i * m
        -- one step of the dot product: add u[col] * P[i, col], move to the next column
        dotStep (acc, col) = do
          uj  <- readDataFrameOff uPtr col
          pij <- readDataFrameOff pPtr (rowOff + col)
          pure (acc + uj * pij, col + 1)

    -- Squared norm of u (for Q = I - 2 u*uT / |u|^2 ):
    -- the sum of squares of m' consecutive entries of the u buffer,
    -- starting at index l.
    getU2 :: ST s (Scalar t)
    getU2 = fst <$> nTimesM m' addSquare (0, l)
      where
        -- add the square of u[ix] to the running total and advance the index
        addSquare (acc, ix) = do
          x <- readDataFrameOff uPtr ix
          pure (acc + x * x, ix + 1)


-- | Apply a monadic step function to a value the given number of times,
--   feeding the result of every step into the next one.
--
--   A non-positive count performs no steps and returns the seed unchanged.
--   (Matching only on a literal @0@ would make a negative count recurse
--   without ever reaching the base case.)
nTimesM :: Monad m => Int -> (a -> m a) -> a -> m a
nTimesM n step = go n
  where
    -- k is the number of steps still to run
    go k x
      | k <= 0    = pure x
      | otherwise = step x >>= go (k - 1)