```{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{- |
A few ways to solve a system of linear equations with a triangular matrix in the ST monad.
The result is always computed in-place.
-}
module Numeric.Subroutine.SolveTriangular
( solveUpperTriangularR
, solveUpperTriangularL
, solveLowerTriangularR
, solveLowerTriangularL
) where

import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions

{- |
Solve a system of linear equations \( Rx = b \)
or a linear least squares problem \( \min {|| Rx - b ||}^2 \),
where \( R \) is an upper-triangular matrix.

The solution is obtained by back substitution, starting from row \( m - 1 \)
and moving up to row \( 0 \).
If a diagonal element \( R_{ii} \) is exactly zero, every element of row
\( x_i \) is set to zero.
Only the top-left \( m \times m \) block of \( R \) is read.

DataFrame \( b \) is modified in-place; by the end of the process \( b_m = x \).

NB: you can use `subDataFrameView` to truncate @b@ without performing a copy.
-}
solveUpperTriangularR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
    => DataFrame t '[n,m]         -- ^ \(R\)
    -> STDataFrame s t  (m :+ ds) -- ^ Current state of \(b_m\)
                                  --   (first @m@ rows of @b@)
    -> ST s ()
solveUpperTriangularR r bPtr = case Dict @(m <= n) of
    -- Matching on the Dict makes use of the @m <= n@ constraint
    -- (presumably to keep -Wredundant-constraints quiet).
    Dict -> forM_ (reverse [0 .. dimM - 1]) backSubstitute
  where
    -- Row stride of R; also the number of rows of b processed.
    dimM :: Int
    dimM = fromIntegral (dimVal' @m)
    -- The second cumulative step of b is the distance between its rows.
    CumulDims (_ : rowStepW : _) = getDataFrameSteps bPtr
    rowLen :: Int
    rowLen = fromIntegral rowStepW

    -- Solve for row i of x (stored in place of row i of b), then
    -- subtract its contribution from all rows above it.
    backSubstitute :: Int -> ST s ()
    backSubstitute i =
      if pivot == 0
        then forM_ [0 .. rowLen - 1] $ \c ->
          writeDataFrameOff bPtr (rowI + c) 0
        else forM_ [0 .. rowLen - 1] $ \c -> do
          raw <- readDataFrameOff bPtr (rowI + c)
          let xic = invPivot * raw
          writeDataFrameOff bPtr (rowI + c) xic
          forM_ [0 .. i - 1] $ \row -> do
            let rRowI :: DataFrame t '[]
                rRowI = scalar (ixOff (dimM * row + i) r)
                off = rowLen * row + c
            bRowC <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (bRowC - rRowI * xic)
      where
        rowI = rowLen * i
        pivot :: DataFrame t '[]
        pivot = scalar (ixOff (dimM * i + i) r)
        invPivot :: DataFrame t '[]
        invPivot = recip pivot

{- |
Solve a system of linear equations \( xR = b \),
where \( R \) is an upper-triangular matrix.

Every length-\( m \) vector along the last dimension of \( b \) is solved
independently by forward substitution over \( i = 0, \dots, m - 1 \).
If a diagonal element \( R_{ii} \) is exactly zero, the corresponding
component \( x_i \) is set to zero.

DataFrame \( b \) is modified in-place; by the end of the process \( b = x_m \).
The \( (n - m) \) rows of \(R\) are not used.
Pad each dimension of \(x\) with \( (n - m) \) zeros if you want to get the full
solution.
-}
solveUpperTriangularL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
    => STDataFrame s t  (ds +: m) -- ^ Current state of \(b\)
                                  --   (first @m@ "columns" of x)
    -> DataFrame t '[n,m]         -- ^ \(R\)
    -> ST s ()
solveUpperTriangularL bPtr r = case Dict @(m <= n) of
    -- Matching on the Dict makes use of the @m <= n@ constraint
    -- (presumably to keep -Wredundant-constraints quiet).
    Dict -> forM_ [0 .. dimM - 1] substitute
  where
    -- Length of each vector in b; also the row stride of R.
    dimM :: Int
    dimM = fromIntegral (dimVal' @m)
    -- The first cumulative step is the total element count of b.
    CumulDims (totalW : _) = getDataFrameSteps bPtr
    -- Number of independent length-dimM vectors stored in b.
    vecCount :: Int
    vecCount = fromIntegral totalW `quot` dimM

    -- Solve for component i of every vector, then subtract its
    -- contribution from the components to its right.
    substitute :: Int -> ST s ()
    substitute i
      | pivot == 0 = forM_ [0 .. vecCount - 1] $ \v ->
          writeDataFrameOff bPtr (dimM * v + i) 0
      | otherwise = forM_ [0 .. vecCount - 1] $ \v -> do
          let base = dimM * v
          raw <- readDataFrameOff bPtr (base + i)
          let xvi = invPivot * raw
          writeDataFrameOff bPtr (base + i) xvi
          forM_ [i + 1 .. dimM - 1] $ \col -> do
            let rICol :: DataFrame t '[]
                rICol = scalar (ixOff (rowI + col) r)
            bVCol <- readDataFrameOff bPtr (base + col)
            writeDataFrameOff bPtr (base + col) (bVCol - rICol * xvi)
      where
        rowI = dimM * i
        pivot :: DataFrame t '[]
        pivot = scalar (ixOff (rowI + i) r)
        invPivot :: DataFrame t '[]
        invPivot = recip pivot

{- |
Solve a system of linear equations \( Lx = b \),
where \( L \) is a lower-triangular matrix.

The solution is obtained by forward substitution over rows
\( i = 0, \dots, n - 1 \).
If a diagonal element \( L_{ii} \) is exactly zero, every element of row
\( x_i \) is set to zero.

DataFrame \( b \) is modified in-place; by the end of the process \( b = x_n \).
The \( (m - n) \) columns of \(L\) are not used.
Pad \(x\) with \( (m - n) \) zero elements if you want to get the full solution.
-}
solveLowerTriangularR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
    => DataFrame t '[n,m]         -- ^ \(L\)
    -> STDataFrame s t  (n :+ ds) -- ^ Current state of \(b\)
                                  --   (first @n@ elements of x)
    -> ST s ()
solveLowerTriangularR l bPtr = case Dict @(n <= m) of
    -- Matching on the Dict makes use of the @n <= m@ constraint
    -- (presumably to keep -Wredundant-constraints quiet).
    Dict -> forM_ [0 .. dimN - 1] forwardStep
  where
    -- Row stride of L.
    dimM :: Int
    dimM = fromIntegral (dimVal' @m)
    -- Number of rows of L (and of b) processed.
    dimN :: Int
    dimN = fromIntegral (dimVal' @n)
    -- The second cumulative step of b is the distance between its rows.
    CumulDims (_ : rowStepW : _) = getDataFrameSteps bPtr
    rowLen :: Int
    rowLen = fromIntegral rowStepW

    -- Solve for row i of x (stored in place of row i of b), then
    -- subtract its contribution from all rows below it.
    forwardStep :: Int -> ST s ()
    forwardStep i =
      if pivot == 0
        then forM_ [0 .. rowLen - 1] $ \c ->
          writeDataFrameOff bPtr (rowI + c) 0
        else forM_ [0 .. rowLen - 1] $ \c -> do
          raw <- readDataFrameOff bPtr (rowI + c)
          let xic = invPivot * raw
          writeDataFrameOff bPtr (rowI + c) xic
          forM_ [i + 1 .. dimN - 1] $ \row -> do
            let lRowI :: DataFrame t '[]
                lRowI = scalar (ixOff (dimM * row + i) l)
                off = rowLen * row + c
            bRowC <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (bRowC - lRowI * xic)
      where
        rowI = rowLen * i
        pivot :: DataFrame t '[]
        pivot = scalar (ixOff (dimM * i + i) l)
        invPivot :: DataFrame t '[]
        invPivot = recip pivot

{- |
Solve a system of linear equations \( xL = b \)
or a linear least squares problem \( \min {|| xL - b ||}^2 \),
where \( L \) is a lower-triangular matrix.

Every length-\( m \) vector along the last dimension of \( b \) is solved
independently by backward substitution over \( i = n - 1, \dots, 0 \).
If a diagonal element \( L_{ii} \) is exactly zero, the corresponding
component \( x_i \) is set to zero.

DataFrame \( b \) is modified in-place; by the end of the process \( b_n = x \).
The last \( (m - n) \) columns of \(L\) and \(b\) are not touched.
-}
solveLowerTriangularL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
    => STDataFrame s t  (ds +: m) -- ^ Current state of \(b\)
    -> DataFrame t '[n,m]         -- ^ \(L\)
    -> ST s ()
solveLowerTriangularL bPtr l = case Dict @(n <= m) of
    -- Matching on the Dict makes use of the @n <= m@ constraint
    -- (presumably to keep -Wredundant-constraints quiet).
    Dict -> forM_ (reverse [0 .. dimN - 1]) substitute
  where
    -- Length of each vector in b; also the row stride of L.
    dimM :: Int
    dimM = fromIntegral (dimVal' @m)
    -- Number of unknowns solved per vector.
    dimN :: Int
    dimN = fromIntegral (dimVal' @n)
    -- The first cumulative step is the total element count of b.
    CumulDims (totalW : _) = getDataFrameSteps bPtr
    -- Number of independent length-dimM vectors stored in b.
    vecCount :: Int
    vecCount = fromIntegral totalW `quot` dimM

    -- Solve for component i of every vector, then subtract its
    -- contribution from the components to its left.
    substitute :: Int -> ST s ()
    substitute i
      | pivot == 0 = forM_ [0 .. vecCount - 1] $ \v ->
          writeDataFrameOff bPtr (dimM * v + i) 0
      | otherwise = forM_ [0 .. vecCount - 1] $ \v -> do
          let base = dimM * v
          raw <- readDataFrameOff bPtr (base + i)
          let xvi = invPivot * raw
          writeDataFrameOff bPtr (base + i) xvi
          forM_ [0 .. i - 1] $ \col -> do
            let lICol :: DataFrame t '[]
                lICol = scalar (ixOff (rowI + col) l)
            bVCol <- readDataFrameOff bPtr (base + col)
            writeDataFrameOff bPtr (base + col) (bVCol - lICol * xvi)
      where
        rowI = dimM * i
        pivot :: DataFrame t '[]
        pivot = scalar (ixOff (rowI + i) l)
        invPivot :: DataFrame t '[]
        invPivot = recip pivot
```