{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{- |
A few ways to solve a system of linear equations in the ST monad.
The result is always computed in-place.
 -}
module Numeric.Subroutine.SolveTriangular
  ( solveUpperTriangularR
  , solveUpperTriangularL
  , solveLowerTriangularR
  , solveLowerTriangularL
  ) where


import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions


{- |
Solve a system of linear equations \( Rx = b \)
   or a linear least squares problem \( \min {|| Rx - b ||}^2 \),
where \( R \) is an upper-triangular matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b_m = x \).
Every trailing index of @b@ (ranging over @ds@) is treated as an independent
right-hand side.

Only the leading \( m \times m \) block of \(R\) is read.
If a diagonal element \( R_{ii} \) is zero, row \( i \) of the result is set
to zero.

NB: you can use `subDataFrameView` to truncate @b@ without performing a copy.
 -}
solveUpperTriangularR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
    => DataFrame t '[n,m]         -- ^ \(R\)
    -> STDataFrame s t  (m :+ ds) -- ^ Current state of \(b_m\)
                                  --   (first @m@ rows of @b@)
    -> ST s ()
solveUpperTriangularR r bPtr = case Dict @(m <= n) of
    -- Back substitution: the last unknown is resolved first.
    Dict -> forM_ [cols - 1, cols - 2 .. 0] eliminate
  where
    -- Number of columns of R (equal to the number of rows of b).
    cols :: Int
    cols = fromIntegral (dimVal' @m)

    -- Row stride of b, i.e. the number of right-hand sides (product of ds).
    nrhs :: Int
    nrhs = fromIntegral rowStep
    CumulDims (_ : rowStep : _) = getDataFrameSteps bPtr

    -- Element R[row, col]; the last dimension of R is contiguous.
    rAt :: Int -> Int -> DataFrame t '[]
    rAt row col = scalar (ixOff (cols * row + col) r)

    -- Flat offset of b[row, j].
    bOff :: Int -> Int -> Int
    bOff row j = nrhs * row + j

    -- Turn row i of b into row i of x (rows below i are already solved),
    -- then remove its contribution from every row above it.
    eliminate :: Int -> ST s ()
    eliminate i =
      if pivot == 0
        then forM_ [0 .. nrhs - 1] $ \j -> writeDataFrameOff bPtr (bOff i j) 0
        else forM_ [0 .. nrhs - 1] $ \j -> do
          bij <- readDataFrameOff bPtr (bOff i j)
          let xij = invPivot * bij
          writeDataFrameOff bPtr (bOff i j) xij
          forM_ [0 .. i - 1] $ \row -> do
            let off = bOff row j
            brj <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (brj - rAt row i * xij)
      where
        pivot, invPivot :: DataFrame t '[]
        pivot    = rAt i i
        invPivot = recip pivot


{- |
Solve a system of linear equations \( xR = b \),
where \( R \) is an upper-triangular matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b = x_m \).
Every leading index of @b@ (ranging over @ds@) is treated as an independent
right-hand side.
The \( (n - m) \) rows of \(R\) are not used.
If a diagonal element \( R_{ii} \) is zero, element \( i \) of each solution
is set to zero.
Pad each dimension of \(x\) with \( (n - m) \) zeros if you want to get the full
solution.
 -}
solveUpperTriangularL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
    => STDataFrame s t  (ds +: m) -- ^ Current state of \(b\)
                                  --   (first @m@ "columns" of x)
    -> DataFrame t '[n,m]         -- ^ \(R\)
    -> ST s ()
solveUpperTriangularL bPtr r = case Dict @(m <= n) of
    -- Forward substitution: the first unknown is resolved first.
    Dict -> forM_ [0 .. cols - 1] eliminate
  where
    -- Number of columns of R (equal to the row length of b).
    cols :: Int
    cols = fromIntegral (dimVal' @m)

    -- Number of right-hand sides: total size of b divided by its row length.
    -- Only demanded inside 'eliminate', which never runs when cols == 0,
    -- so the division never sees a zero divisor.
    nrhs :: Int
    nrhs = fromIntegral totalSize `quot` cols
    CumulDims (totalSize : _) = getDataFrameSteps bPtr

    -- Element R[row, col]; the last dimension of R is contiguous.
    rAt :: Int -> Int -> DataFrame t '[]
    rAt row col = scalar (ixOff (cols * row + col) r)

    -- Flat offset of b[j, col]; the last dimension of b (size m) is contiguous.
    bOff :: Int -> Int -> Int
    bOff j col = cols * j + col

    -- Turn column i of b into column i of x (columns before i are already
    -- solved), then remove its contribution from every column after it.
    eliminate :: Int -> ST s ()
    eliminate i =
      if pivot == 0
        then forM_ [0 .. nrhs - 1] $ \j -> writeDataFrameOff bPtr (bOff j i) 0
        else forM_ [0 .. nrhs - 1] $ \j -> do
          bji <- readDataFrameOff bPtr (bOff j i)
          let xji = invPivot * bji
          writeDataFrameOff bPtr (bOff j i) xji
          forM_ [i + 1 .. cols - 1] $ \col -> do
            let off = bOff j col
            bjc <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (bjc - rAt i col * xji)
      where
        pivot, invPivot :: DataFrame t '[]
        pivot    = rAt i i
        invPivot = recip pivot

{- |
Solve a system of linear equations \( Lx = b \),
where \( L \) is a lower-triangular matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b = x_n \).
Every trailing index of @b@ (ranging over @ds@) is treated as an independent
right-hand side.
The \( (m - n) \) columns of \(L\) are not used.
If a diagonal element \( L_{ii} \) is zero, row \( i \) of the result is set
to zero.
Pad \(x\) with \( (m - n) \) zero elements if you want to get the full solution.
 -}
solveLowerTriangularR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
    => DataFrame t '[n,m]         -- ^ \(L\)
    -> STDataFrame s t  (n :+ ds) -- ^ Current state of \(b\)
                                  --   (first @n@ elements of x)
    -> ST s ()
solveLowerTriangularR l bPtr = case Dict @(n <= m) of
    -- Forward substitution: the first unknown is resolved first.
    Dict -> forM_ [0 .. rows - 1] eliminate
  where
    -- Shape of L: rows x cols.
    rows, cols :: Int
    rows = fromIntegral (dimVal' @n)
    cols = fromIntegral (dimVal' @m)

    -- Row stride of b, i.e. the number of right-hand sides (product of ds).
    nrhs :: Int
    nrhs = fromIntegral rowStep
    CumulDims (_ : rowStep : _) = getDataFrameSteps bPtr

    -- Element L[row, col]; the last dimension of L is contiguous.
    lAt :: Int -> Int -> DataFrame t '[]
    lAt row col = scalar (ixOff (cols * row + col) l)

    -- Flat offset of b[row, j].
    bOff :: Int -> Int -> Int
    bOff row j = nrhs * row + j

    -- Turn row i of b into row i of x (rows above i are already solved),
    -- then remove its contribution from every row below it.
    eliminate :: Int -> ST s ()
    eliminate i =
      if pivot == 0
        then forM_ [0 .. nrhs - 1] $ \j -> writeDataFrameOff bPtr (bOff i j) 0
        else forM_ [0 .. nrhs - 1] $ \j -> do
          bij <- readDataFrameOff bPtr (bOff i j)
          let xij = invPivot * bij
          writeDataFrameOff bPtr (bOff i j) xij
          forM_ [i + 1 .. rows - 1] $ \row -> do
            let off = bOff row j
            brj <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (brj - lAt row i * xij)
      where
        pivot, invPivot :: DataFrame t '[]
        pivot    = lAt i i
        invPivot = recip pivot

{- |
Solve a system of linear equations \( xL = b \)
   or a linear least squares problem \( \min {|| xL - b ||}^2 \),
where \( L \) is a lower-triangular matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b_n = x \).
Every leading index of @b@ (ranging over @ds@) is treated as an independent
right-hand side.
The last \( (m - n) \) columns of \(L\) and \(b\) are not touched.
If a diagonal element \( L_{ii} \) is zero, element \( i \) of each solution
is set to zero.
 -}
solveLowerTriangularL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
    => STDataFrame s t  (ds +: m) -- ^ Current state of \(b\)
    -> DataFrame t '[n,m]         -- ^ \(L\)
    -> ST s ()
solveLowerTriangularL bPtr l = case Dict @(n <= m) of
    -- Back substitution: the last unknown is resolved first.
    Dict -> forM_ [rows - 1, rows - 2 .. 0] eliminate
  where
    -- Shape of L: rows x cols (cols is also the row length of b).
    rows, cols :: Int
    rows = fromIntegral (dimVal' @n)
    cols = fromIntegral (dimVal' @m)

    -- Number of right-hand sides: total size of b divided by its row length.
    -- Only demanded inside 'eliminate', which never runs when rows == 0
    -- (and cols == 0 implies rows == 0 here, since n <= m).
    nrhs :: Int
    nrhs = fromIntegral totalSize `quot` cols
    CumulDims (totalSize : _) = getDataFrameSteps bPtr

    -- Element L[row, col]; the last dimension of L is contiguous.
    lAt :: Int -> Int -> DataFrame t '[]
    lAt row col = scalar (ixOff (cols * row + col) l)

    -- Flat offset of b[j, col]; the last dimension of b (size m) is contiguous.
    bOff :: Int -> Int -> Int
    bOff j col = cols * j + col

    -- Turn column i of b into column i of x (columns after i are already
    -- solved), then remove its contribution from every column before it.
    eliminate :: Int -> ST s ()
    eliminate i =
      if pivot == 0
        then forM_ [0 .. nrhs - 1] $ \j -> writeDataFrameOff bPtr (bOff j i) 0
        else forM_ [0 .. nrhs - 1] $ \j -> do
          bji <- readDataFrameOff bPtr (bOff j i)
          let xji = invPivot * bji
          writeDataFrameOff bPtr (bOff j i) xji
          forM_ [0 .. i - 1] $ \col -> do
            let off = bOff j col
            bjc <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (bjc - lAt i col * xji)
      where
        pivot, invPivot :: DataFrame t '[]
        pivot    = lAt i i
        invPivot = recip pivot