{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Numeric.Matrix.LU
( MatrixLU (..), LU (..)
, luSolveR, luSolveL
, detViaLU, inverseViaLU
) where

import Control.Monad                        (foldM, forM_, when)
import Control.Monad.ST                     (ST, runST)
import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.SolveTriangular

-- | Result of LU factorization with Partial Pivoting
--   \( PA = LU \).
data LU (t :: Type) (n :: Nat)
  = LU
  { luLower   :: Matrix t n n
    -- ^ Unit lower triangular matrix \(L\).
    --   All elements on the diagonal of @L@ equal @1@.
    --   The rest of the elements satisfy \(|l_{ij}| \leq 1\).
  , luUpper   :: Matrix t n n
    -- ^ Upper triangular matrix \(U\).
  , luPerm    :: Matrix t n n
    -- ^ Row permutation matrix \(P\).
  , luPermDet :: Scalar t
    -- ^ Sign of the permutation, i.e. @luPermDet == det luPerm@;
    --   \(|P| = \pm 1\).
  }

deriving instance (Show t, PrimBytes t, KnownDim n) => Show (LU t n)
deriving instance (Eq (Matrix t n n), Eq t) => Eq (LU t n)

-- | Square matrices that admit an LU factorization with partial pivoting.
--
--   The superclass constraints are exactly what the in-place algorithm needs:
--   ordering (to pick the largest pivot), division (to form the multipliers),
--   and a known size and storage backend for the matrix.
class (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
      => MatrixLU t (n :: Nat) where
    -- | Compute LU factorization with Partial Pivoting
    lu :: Matrix t n n -> LU t n

-- | The factorization copies the input, runs 'luInplace' on the copy, splits
--   the packed result into explicit \(L\) and \(U\), and materializes \(P\)
--   from the permutation vector.
instance (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
         => MatrixLU t n where
    lu a = runST $ do
        -- permutation vector, initialized to the identity 0..n-1
        pPtr <- unsafeThawDataFrame $ iwgen @_ @'[n] @'[] (\(Idx i :* U) -> S i)
        uPtr <- thawDataFrame a
        lPtr <- newDataFrame
        temp <- newDataFrame
        detPositive <- luInplace temp pPtr uPtr
        p <- unsafeFreezeDataFrame pPtr
        -- split U and L: entries strictly below the diagonal move to L
        -- (and are zeroed in U); L gets ones on the diagonal, zeros above.
        forM_ [0 .. n - 1] $ \i -> do
          let ni = n * i
          forM_ [0 .. n - 1] $ \j -> case compare i j of
            GT -> do
              lij <- readDataFrameOff uPtr (ni + j)
              writeDataFrameOff uPtr (ni + j) 0
              writeDataFrameOff lPtr (ni + j) lij
            EQ -> writeDataFrameOff lPtr (ni + j) 1
            LT -> writeDataFrameOff lPtr (ni + j) 0
        luLower <- unsafeFreezeDataFrame lPtr
        luUpper <- unsafeFreezeDataFrame uPtr
        let luPermDet = if detPositive then 1 else -1
            -- row i of P selects row (p ! i) of A
            luPerm = iwgen @_ @'[n,n] @'[]
              (\(Idx i :* Idx j :* U) -> if S j == p ! i then 1 else 0)
        return LU {..}
      where
        n = fromIntegral (dimVal' @n) :: Int

-- | Solve @Ax = b@ problem given LU decomposition of A.
--
--   Since \(PA = LU\), this solves \(LUx = Pb\): first \(Ly = Pb\), then \(Ux = y\).
luSolveR :: forall t (n :: Nat) (ds :: [Nat])
          . (MatrixLU t n, Dimensions ds)
         => LU t n -> DataFrame t (n :+ ds) -> DataFrame t (n :+ ds)
luSolveR LU {..} b = runST $ do
    -- NB: P is applied as a dense matrix product; permuting the rows of b
    -- directly would be cheaper.
    xPtr <- thawDataFrame (luPerm %* b)
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr

-- | Solve @xA = b@ problem given LU decomposition of A.
--
--   Since \(A = P^{-1}LU\), this solves \(zU = b\), then \(yL = z\),
--   and finally returns \(x = yP\).
luSolveL :: forall t (n :: Nat) (ds :: [Nat])
          . (MatrixLU t n, Dimensions ds)
         => LU t n -> DataFrame t (ds +: n) -> DataFrame t (ds +: n)
luSolveL LU {..} b
    -- These guards only bring type-level evidence into scope (the 'Dims'
    -- pattern and the 'SnocList' dictionary); they always succeed, so the
    -- fallback equation below is unreachable.
  | dn <- dim @n
  , dds <- dims @ds
  , Dims <- Snoc dds dn
  , Dict <- Dict @(SnocList ds n _)
  = runST $ do
      xPtr <- thawDataFrame b
      solveUpperTriangularL xPtr luUpper
      solveLowerTriangularL xPtr luLower
      (%* luPerm) <$> unsafeFreezeDataFrame xPtr
luSolveL _ _ = error "luSolveL: impossible pattern"

-- | Calculate inverse of a matrix via LU decomposition
--
--   Since \(PA = LU\), \(A^{-1} = U^{-1} L^{-1} P\): the columns of \(P\) are
--   pushed through a lower and then an upper triangular solve.
inverseViaLU :: forall (t :: Type) (n :: Nat)
              . MatrixLU t n => Matrix t n n -> Matrix t n n
inverseViaLU a = runST $ do
    xPtr <- unsafeThawDataFrame luPerm -- luPerm is only ever used once
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr
  where
    LU {..} = lu a
    -- perfectly correct, but slightly slower versions:
    -- inverseViaLU a = luSolveR (lu a) eye
    -- inverseViaLU a = luSolveL (lu a) eye

-- | Calculate determinant of a matrix via LU decomposition
--
--   \( \det A = \det P \cdot \prod_i u_{ii} \), because \(\det L = 1\) and
--   \(\det P = \pm 1\). An empty (0×0) matrix has determinant 1.
detViaLU :: forall (t :: Type) (n :: Nat)
          . MatrixLU t n => Matrix t n n -> Scalar t
detViaLU m =
    -- Diagonal element i lives at flat offset i*(n+1). Enumerating i over
    -- [0 .. n-1] (rather than offsets up to n*n) never touches an element
    -- when n == 0.
    luPermDet * product [ scalar (ixOff (i * (n + 1)) luUpper) | i <- [0 .. n - 1] ]
  where
    n = fromIntegral (dimVal' @n) :: Int
    LU {..} = lu m

{- |
Run LU decomposition with partial pivoting in place, such that
the upper triangular part of matrix \(A\) becomes \(U\) and
the lower triangular part (without diagonal) of matrix \(A\) becomes \(L\).

\(U\) is upper triangular.
\(L\) is unit lower triangular; all diagonal elements of \(L\) are implicit and
equal to 1; the rest of the elements are at most one in absolute value,
\(|l_{ij}| \leq 1\).

Pivoting is represented as a permutation vector \(p\);
the returned value is the sign of the permutation (positive if @True@, negative otherwise).

NB: Initialize \(p\) with indices @0..n-1@.

Reference: Algorithm 3.4.1 on p.128
of "Matrix Computations" 4th edition by G. H. Golub and C. F. Van Loan.
-}
luInplace ::
       forall (s :: Type) (t :: Type) (n :: Nat)
     . (PrimBytes t, Fractional t, Ord t, KnownDim n)
    => STDataFrame s t '[n]    -- ^ Temporary buffer (holds one row during swaps)
    -> STDataFrame s Word '[n] -- ^ Current state of permutation \(p\)
    -> STDataFrame s t '[n,n]  -- ^ Current state of \(A\)
    -> ST s Bool
luInplace temp pPtr aPtr = foldM (\b -> fmap (b /=) . go) True [0 .. n - 2]
    -- every row swap flips the sign of the permutation, starting from positive
  where
    n :: Int
    n = fromIntegral (dimVal' @n)

    -- Runs one elimination step for column k;
    -- returns whether two rows were swapped.
    go :: Int -> ST s Bool
    go k = do
      mu <- findPivot k
      let swapped = k /= mu
      when swapped $ swapRows k mu
      akk <- readDataFrameOff aPtr (k * (n + 1))
      -- zero pivot (no candidate compared greater than 0): nothing to eliminate
      when (akk /= 0) $ do
        let rakk = recip akk
            nk   = n * k
        forM_ [k + 1 .. n - 1] $ \i -> do
          let ni = n * i
          -- the multiplier l_ik replaces the eliminated entry a_ik
          aik <- (rakk *) <$> readDataFrameOff aPtr (ni + k)
          writeDataFrameOff aPtr (ni + k) aik
          -- update the trailing part of row i: a_ij <- a_ij - l_ik * a_kj
          forM_ [k + 1 .. n - 1] $ \j -> do
            akj <- readDataFrameOff aPtr (nk + j)
            aij <- readDataFrameOff aPtr (ni + j)
            writeDataFrameOff aPtr (ni + j) (aij - aik * akj)
      return swapped

    -- Row index in [k .. n-1] whose entry in column k has the largest absolute
    -- value; the earliest such row wins ties, and k is returned if none
    -- exceeds zero.
    findPivot :: Int -> ST s Int
    findPivot k = snd <$> foldM findPivotF (0, k) [k .. n - 1]
      where
        findPivotF :: (Scalar t, Int) -> Int -> ST s (Scalar t, Int)
        findPivotF aj@(a, _) i = do
          x <- abs <$> readDataFrameOff aPtr (n * i + k)
          return (if x > a then (x, i) else aj)

    -- Swaps rows i and j of A (via the temporary buffer)
    -- and entries i and j of the permutation vector.
    swapRows :: Int -> Int -> ST s ()
    swapRows i j = do
      let iPtr, jPtr :: STDataFrame s t '[n]
          iPtr = subDataFrameView' (fromIntegral i :* U) aPtr
          jPtr = subDataFrameView' (fromIntegral j :* U) aPtr
      copyMutableDataFrame' U iPtr temp
      copyMutableDataFrame' U jPtr iPtr
      copyMutableDataFrame' U temp jPtr
      pAtI <- readDataFrameOff pPtr i
      readDataFrameOff pPtr j >>= writeDataFrameOff pPtr i
      writeDataFrameOff pPtr j pAtI