{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Numeric.Matrix.LU
  ( MatrixLU (..), LU (..)
  , luSolveR, luSolveL
  , detViaLU, inverseViaLU
  ) where

import Control.Monad
import Control.Monad.ST
import Data.Foldable                        (foldl')
import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.SolveTriangular


-- | Result of LU factorization with Partial Pivoting
--   \( PA = LU \).
data LU (t :: Type) (n :: Nat)
  = LU
  { luLower   :: Matrix t n n
    -- ^ Unit lower triangular matrix \(L\).
    --   All elements on the diagonal of @L@ equal @1@.
    --   The rest of the elements satisfy \(|l_{ij}| \leq 1\).
  , luUpper   :: Matrix t n n
    -- ^ Upper triangular matrix \(U\)
  , luPerm    :: Matrix t n n
    -- ^ Row permutation matrix \(P\)
  , luPermDet :: Scalar t
    -- ^ Sign of permutation @luPermDet == det . luPerm@; \(|P| = \pm 1\).
  }

-- Standalone-derived instances for 'LU'; the contexts are spelled out
-- explicitly because the record fields are 'Matrix' and 'Scalar' values
-- parameterised by @t@ and @n@.
deriving instance (Show t, PrimBytes t, KnownDim n) => Show (LU t n)
deriving instance (Eq (Matrix t n n), Eq t) => Eq (LU t n)

-- | Square matrices of element type @t@ and size @n@ that admit
--   an LU factorization with partial pivoting.
class (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
      => MatrixLU t (n :: Nat) where
    -- | Compute LU factorization with Partial Pivoting
    --   and return all of its parts as an 'LU' record.
    lu :: Matrix t n n -> LU t n


-- | The single, generic implementation of 'MatrixLU':
--   factorize a copy of the input in place via 'luInplace',
--   then split the packed result into separate @L@ and @U@ matrices
--   and expand the permutation vector into a permutation matrix.
instance (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
         => MatrixLU t n where
    lu a = runST $ do
        -- Permutation vector, initialized with indices 0..n-1
        -- as required by 'luInplace'.
        pPtr <- unsafeThawDataFrame $ iwgen  @_ @'[n] @'[] (\(Idx i :* U) -> S i)
        -- Mutable copy of the input; it holds packed L and U after 'luInplace'.
        uPtr <- thawDataFrame a
        lPtr <- newDataFrame
        -- Scratch row buffer used by 'luInplace' to swap rows.
        temp <- newDataFrame
        detPositive <- luInplace temp pPtr uPtr
        p <- unsafeFreezeDataFrame pPtr
        -- split U and L
        forM_ [0..n-1] $ \i -> do
          let ni = n*i
          forM_ [0..n-1] $ \j -> case compare i j of
              -- Below the diagonal: move the multiplier from U to L.
              GT -> do
                lij <- readDataFrameOff uPtr (ni + j)
                writeDataFrameOff uPtr (ni + j) 0
                writeDataFrameOff lPtr (ni + j) lij
              -- The diagonal of L is all ones.
              EQ -> writeDataFrameOff lPtr (ni + j) 1
              -- Above the diagonal: L is zero, U keeps its element.
              LT -> writeDataFrameOff lPtr (ni + j) 0
        luLower <- unsafeFreezeDataFrame lPtr
        luUpper <- unsafeFreezeDataFrame uPtr
        let luPermDet = if detPositive then 1 else -1
            -- Element (i, j) of P is one iff j is the i-th entry of p.
            luPerm = iwgen @_ @'[n,n] @'[]
              (\(Idx i :* Idx j :* U) -> if S j == p ! i then 1 else 0)
        return LU {..}
      where
        n = fromIntegral (dimVal' @n) :: Int


-- | Solve @Ax = b@ problem given LU decomposition of A.
--
--   The right-hand side is first permuted (@P b@); then the two triangular
--   systems with @L@ and @U@ are solved in place, in this order.
luSolveR ::
       forall t (n :: Nat) (ds :: [Nat])
     . (MatrixLU t n, Dimensions ds)
    => LU t n -> DataFrame t (n :+ ds) -> DataFrame t (n :+ ds)
luSolveR LU {..} b = runST $ do
    -- NB: wasting resources! The permutation is applied as
    --     a full matrix product with luPerm.
    xPtr <- thawDataFrame (luPerm %* b)
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr

-- | Solve @xA = b@ problem given LU decomposition of A.
--
--   The triangular systems with @U@ and @L@ are solved in place, in this
--   order; the result is then multiplied by @P@ on the right.
luSolveL ::
       forall t (n :: Nat) (ds :: [Nat])
     . (MatrixLU t n, Dimensions ds)
    => LU t n -> DataFrame t (ds +: n) -> DataFrame t (ds +: n)
luSolveL LU {..} b
    -- The pattern guards only bring the evidence about (ds +: n)
    -- into scope; they are expected to always succeed.
  | dn  <- dim @n
  , dds <- dims @ds
  , Dims <- Snoc dds dn
  , Dict <- Dict @(SnocList ds n _)
  = runST $ do
    xPtr <- thawDataFrame b
    solveUpperTriangularL xPtr luUpper
    solveLowerTriangularL xPtr luLower
    (%* luPerm) <$> unsafeFreezeDataFrame xPtr
-- Fallback for the (presumably unreachable) case of a failed guard.
luSolveL _ _ = error "luSolveL: impossible pattern"

-- | Calculate inverse of a matrix via LU decomposition
--
--   Solves \( L U X = P \) for \(X\) in place,
--   reusing the permutation matrix as the storage for the result.
inverseViaLU :: forall (t :: Type) (n :: Nat)
              . MatrixLU t n => Matrix t n n -> Matrix t n n
inverseViaLU a = runST $ do
    -- luPerm is only ever used once, so it is thawed without copying
    -- and overwritten by the solvers below;
    -- do not refer to luPerm after this line.
    xPtr <- unsafeThawDataFrame luPerm
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr
  where
    LU {..} = lu a
-- perfectly correct, but slightly slower versions:
-- inverseViaLU a = luSolveR (lu a) eye
-- inverseViaLU a = luSolveL (lu a) eye

-- | Calculate determinant of a matrix via LU decomposition
--
--   The result is the product of the diagonal elements of @U@
--   multiplied by the sign of the permutation ('luPermDet').
--   For @n == 0@ no element is read and the result is 'luPermDet'.
detViaLU :: forall (t :: Type) (n :: Nat)
          . MatrixLU t n => Matrix t n n -> Scalar t
detViaLU m = foldl' step luPermDet diagOffs
  where
    n = fromIntegral (dimVal' @n) :: Int
    LU {..} = lu m
    -- Offsets of the diagonal elements: 0, n+1, 2(n+1), ..., n*n-1.
    -- The upper bound is the last valid offset, so that the list is empty
    -- (rather than [0]) when n == 0.
    diagOffs = [0, n+1 .. n*n-1]
    -- Strict accumulation of the product (hence foldl').
    step x off = scalar (ixOff off luUpper) * x

{- |
Run LU decomposition with partial pivoting inplace, such that
the upper triangular part of matrix \(A\) becomes \(U\) and
the lower triangular part (without diagonal) of matrix  \(A\) becomes \(L\).

\(U\) is upper triangular.
\(L\) is unit lower triangular; all diagonal elements of \(L\) are implicit and
equal to 1; the rest of the elements do not exceed one in absolute value
\(|l_{ij}| \leq 1\).

Pivoting is represented as a permutation vector \(p\);
returned value is the sign of the permutation (positive if @True@, negative otherwise).

NB: Initialize \(p\) with indices @0..n-1@.

Reference: Algorithm 3.4.1 on p.128
       of "Matrix Computations" 4th edition by G. H. Golub and C. F. Van Loan.
 -}
luInplace ::
       forall (s :: Type) (t :: Type) (n :: Nat)
     . (PrimBytes t, Fractional t, Ord t, KnownDim n)
    => STDataFrame s t '[n]    -- ^ Temporary buffer
    -> STDataFrame s Word '[n] -- ^ Current state of permutation \(p\)
    -> STDataFrame s t '[n,n]  -- ^ Current state of \(A\)
    -> ST s Bool
-- Every row swap flips the sign flag: (b /= swapped) negates b iff swapped.
luInplace temp pPtr aPtr = foldM (\b -> fmap (b /=) . go) True [0..n-2]
  where
    n = fromIntegral (dimVal' @n) :: Int

    -- Runs an iteration of the algorithm;
    -- returns whether there was a swap of rows.
    go :: Int -> ST s Bool
    go k = do
      mu <- findPivot k
      let swapped = k /= mu
      when swapped $ swapRows k mu
      -- Offset k*(n+1) is the diagonal element (k, k).
      akk <- readDataFrameOff aPtr (k*(n+1))
      -- A zero pivot means there is nothing to eliminate in this column.
      when (akk /= 0) $ do
        let rakk = recip akk
        forM_ [k+1..n-1] $ \i -> do
          let ni = n*i
          -- The multiplier l_ik overwrites a_ik.
          aik <- (rakk *) <$> readDataFrameOff aPtr (ni + k)
          writeDataFrameOff aPtr (ni + k) aik
          -- Update the rest of row i: a_ij -= l_ik * a_kj.
          forM_ [k+1..n-1] $ \j -> do
            akj <- readDataFrameOff aPtr (n*k + j)
            aij <- readDataFrameOff aPtr (ni + j)
            writeDataFrameOff aPtr (ni + j) (aij - aik*akj)
      return swapped

    -- Index of the row (among k..n-1) with the largest absolute value
    -- in column k; returns k if all those values are zero.
    findPivot :: Int -> ST s Int
    findPivot k = snd <$> foldM findPivotF (0, k) [k..n-1]
      where
        findPivotF :: (Scalar t, Int) -> Int -> ST s (Scalar t, Int)
        findPivotF aj@(a, _) i = do
          x <- abs <$> readDataFrameOff aPtr (n*i + k)
          return (if x > a then (x, i) else aj)

    -- Swaps rows i and j of the matrix (via the temporary buffer)
    -- and the corresponding entries of the permutation vector.
    swapRows :: Int -> Int -> ST s ()
    swapRows i j = do
      let iPtr = subDataFrameView' (fromIntegral i :* U) aPtr
          jPtr = subDataFrameView' (fromIntegral j :* U) aPtr
      copyMutableDataFrame' U iPtr temp
      copyMutableDataFrame' U jPtr iPtr
      copyMutableDataFrame' U temp jPtr
      t <- readDataFrameOff pPtr i
      readDataFrameOff pPtr j >>= writeDataFrameOff pPtr i
      writeDataFrameOff pPtr j t