{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.Matrix.LU
( MatrixLU (..), LU (..)
, luSolveR, luSolveL
, detViaLU, inverseViaLU
) where
import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.SolveTriangular
-- | Result of LU factorization with partial pivoting, \( PA = LU \).
--
--   Produced by 'lu'; up to rounding, @luPerm %* a == luLower %* luUpper@
--   for the factored matrix @a@.
data LU (t :: Type) (n :: Nat)
  = LU
  { luLower   :: Matrix t n n
    -- ^ Unit lower triangular matrix \(L\): ones on the diagonal, zeros
    --   above it, and the elimination multipliers below it.
  , luUpper   :: Matrix t n n
    -- ^ Upper triangular matrix \(U\): all entries below the diagonal are
    --   zero.
  , luPerm    :: Matrix t n n
    -- ^ Row permutation matrix \(P\): row @i@ has a single @1@, in the
    --   column equal to the original index of the row that pivoting moved
    --   to position @i@.
  , luPermDet :: Scalar t
    -- ^ Determinant of \(P\): @1@ after an even number of row swaps,
    --   @-1@ after an odd number.
  }

deriving instance (Show t, PrimBytes t, KnownDim n) => Show (LU t n)
deriving instance (Eq (Matrix t n n), Eq t) => Eq (LU t n)
-- | Square matrices that can be LU-factorized with partial pivoting.
class (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
      => MatrixLU t (n :: Nat) where
    -- | Compute the LU factorization with partial pivoting of a square
    --   matrix @a@: find \(P\), \(L\), \(U\) such that \( PA = LU \).
    lu :: Matrix t n n -> LU t n
-- | The single implementation: in-place elimination via 'luInplace',
--   followed by splitting the packed result into separate L and U matrices.
instance (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
         => MatrixLU t n where
    lu a = runST $ do
        -- Permutation vector, initially the identity: p[i] = i.
        pPtr <- unsafeThawDataFrame $ iwgen @_ @'[n] @'[] (\(Idx i :* U) -> S i)
        -- Work on a copy of a; after luInplace it holds L and U packed together.
        uPtr <- thawDataFrame a
        lPtr <- newDataFrame
        temp <- newDataFrame
        detPositive <- luInplace temp pPtr uPtr
        p <- unsafeFreezeDataFrame pPtr
        -- Move the multipliers stored below the diagonal of uPtr into lPtr,
        -- zeroing them in uPtr; L gets ones on the diagonal, zeros above it.
        forM_ [0 .. n - 1] $ \i -> do
          let ni = n * i
          forM_ [0 .. n - 1] $ \j -> case compare i j of
            GT -> do
              lij <- readDataFrameOff uPtr (ni + j)
              writeDataFrameOff uPtr (ni + j) 0
              writeDataFrameOff lPtr (ni + j) lij
            EQ -> writeDataFrameOff lPtr (ni + j) 1
            LT -> writeDataFrameOff lPtr (ni + j) 0
        luLower <- unsafeFreezeDataFrame lPtr
        luUpper <- unsafeFreezeDataFrame uPtr
        let luPermDet = if detPositive then 1 else -1
            -- P[i, j] == 1 exactly when row i of PA is row j of a.
            luPerm = iwgen @_ @'[n,n] @'[]
              (\(Idx i :* Idx j :* U) -> if S j == p ! i then 1 else 0)
        return LU {..}
      where
        n = fromIntegral (dimVal' @n) :: Int
-- | Solve the system \( Ax = b \) for \(x\), given the LU factorization
--   of \(A\).
--
--   Because \( PA = LU \), the solution is \( x = U^{-1} L^{-1} P b \):
--   permute @b@, solve with \(L\), then solve with \(U\). Trailing
--   dimensions @ds@ of @b@ are carried along (e.g. several right-hand
--   sides at once).
luSolveR :: forall t (n :: Nat) (ds :: [Nat])
          . (MatrixLU t n, Dimensions ds)
         => LU t n -> DataFrame t (n :+ ds) -> DataFrame t (n :+ ds)
luSolveR LU {..} b = runST $ do
    xPtr <- thawDataFrame (luPerm %* b)
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr
-- | Solve the system \( xA = b \) for \(x\), given the LU factorization
--   of \(A\).
--
--   Because \( A = P^{-1} L U \): solve \( y U = b \), then \( z L = y \),
--   and return \( x = z P \). Leading dimensions @ds@ of @b@ are carried
--   along.
luSolveL :: forall t (n :: Nat) (ds :: [Nat])
          . (MatrixLU t n, Dimensions ds)
         => LU t n -> DataFrame t (ds +: n) -> DataFrame t (ds +: n)
luSolveL LU {..} b
    -- These guards only bring Dimensions (ds +: n) and SnocList evidence
    -- into scope; they are expected to always match (the fallback equation
    -- below is marked impossible).
  | dn <- dim @n
  , dds <- dims @ds
  , Dims <- Snoc dds dn
  , Dict <- Dict @(SnocList ds n _)
  = runST $ do
      xPtr <- thawDataFrame b
      solveUpperTriangularL xPtr luUpper
      solveLowerTriangularL xPtr luLower
      (%* luPerm) <$> unsafeFreezeDataFrame xPtr
luSolveL _ _ = error "luSolveL: impossible pattern"
-- | Calculate the inverse of a matrix via its LU factorization.
--
--   Solves \( A X = I \), i.e. computes \( X = U^{-1} L^{-1} P \), by
--   running the two triangular solvers over the permutation matrix.
inverseViaLU :: forall (t :: Type) (n :: Nat)
              . MatrixLU t n => Matrix t n n -> Matrix t n n
inverseViaLU a = runST $ do
    -- luPerm is produced by 'lu' for this call only and is not used after
    -- this point. That is why it may be thawed with unsafeThawDataFrame
    -- (instead of thawDataFrame) and overwritten with the result.
    xPtr <- unsafeThawDataFrame luPerm
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr
  where
    LU {..} = lu a
-- | Calculate the determinant of a matrix via its LU factorization.
--
--   Because \( PA = LU \), \( \det L = 1 \) and \( \det P = \pm 1 \):
--   \( \det A = \det P \cdot \prod_i U_{ii} \).
--   A 0×0 matrix has determinant @1@.
detViaLU :: forall (t :: Type) (n :: Nat)
          . MatrixLU t n => Matrix t n n -> Scalar t
detViaLU m = go 0 luPermDet
  where
    n = fromIntegral (dimVal' @n) :: Int
    LU {..} = lu m
    -- Multiply the accumulator by U[i,i] (flat offset i*(n+1)) for
    -- i = 0 .. n-1, in the same order as before: U[i,i] * acc.
    -- The old offset range [0, n+1 .. n*n] was [0] when n = 0, which meant
    -- reading an element of an empty matrix; iterating over row indices
    -- reads nothing then and returns luPermDet. The accumulator is forced
    -- at each step so no chain of thunks builds up (the old code used a
    -- lazy foldl).
    go :: Int -> Scalar t -> Scalar t
    go i acc
      | i >= n    = acc
      | otherwise =
          let acc' = scalar (ixOff (i * (n + 1)) luUpper) * acc
          in  acc' `seq` go (i + 1) acc'
-- | Gaussian elimination with partial pivoting, performed in place.
--
--   On return, the upper triangle of @aPtr@ (including the diagonal) holds
--   \(U\), and the strict lower triangle holds the multipliers of the unit
--   lower triangular \(L\). Every row swap applied to @aPtr@ is also applied
--   to @pPtr@.
--
--   Returns 'True' if an even number of row swaps was performed, i.e. the
--   determinant of the permutation is positive.
luInplace :: forall (s :: Type) (t :: Type) (n :: Nat)
           . (PrimBytes t, Fractional t, Ord t, KnownDim n)
          => STDataFrame s t '[n]
             -- ^ Scratch buffer for one row, used when swapping rows
          -> STDataFrame s Word '[n]
             -- ^ Row permutation vector, updated together with row swaps
          -> STDataFrame s t '[n,n]
             -- ^ Matrix to factor; overwritten with the packed L and U
          -> ST s Bool
luInplace temp pPtr aPtr = foldM accumulate True [0 .. n - 2]
  where
    n = fromIntegral (dimVal' @n) :: Int

    -- Flip the parity flag whenever eliminating column k needed a row swap.
    -- The last column (k = n-1) has nothing below the diagonal, so it is
    -- skipped.
    accumulate :: Bool -> Int -> ST s Bool
    accumulate evenSwaps k = (evenSwaps /=) <$> eliminate k

    -- Bring the best pivot for column k to row k, then eliminate everything
    -- below the diagonal in column k. Returns whether a row swap happened.
    eliminate :: Int -> ST s Bool
    eliminate k = do
      mu <- findPivot k
      let swapped = k /= mu
      when swapped $ swapRows k mu
      akk <- readDataFrameOff aPtr (k * (n + 1))
      -- The pivot has the largest magnitude in the rest of the column, so a
      -- zero pivot means there is nothing to eliminate in this column.
      when (akk /= 0) $ do
        let rakk = recip akk
        forM_ [k + 1 .. n - 1] $ \i -> do
          let ni = n * i
          -- Multiplier l[i,k], stored in place of a[i,k].
          aik <- (rakk *) <$> readDataFrameOff aPtr (ni + k)
          writeDataFrameOff aPtr (ni + k) aik
          forM_ [k + 1 .. n - 1] $ \j -> do
            akj <- readDataFrameOff aPtr (n * k + j)
            aij <- readDataFrameOff aPtr (ni + j)
            writeDataFrameOff aPtr (ni + j) (aij - aik * akj)
      return swapped

    -- Row index i in [k .. n-1] with the largest |a[i,k]|. Ties go to the
    -- first such row, and k is returned when the whole column is zero.
    findPivot :: Int -> ST s Int
    findPivot k = snd <$> foldM pick (0, k) [k .. n - 1]
      where
        pick :: (Scalar t, Int) -> Int -> ST s (Scalar t, Int)
        pick best@(bestAbs, _) i = do
          x <- abs <$> readDataFrameOff aPtr (n * i + k)
          return (if x > bestAbs then (x, i) else best)

    -- Swap rows i and j of aPtr through the scratch buffer, and swap the
    -- matching entries of the permutation vector.
    swapRows :: Int -> Int -> ST s ()
    swapRows i j = do
      let iPtr, jPtr :: STDataFrame s t '[n]
          iPtr = subDataFrameView' (fromIntegral i :* U) aPtr
          jPtr = subDataFrameView' (fromIntegral j :* U) aPtr
      copyMutableDataFrame' U iPtr temp
      copyMutableDataFrame' U jPtr iPtr
      copyMutableDataFrame' U temp jPtr
      savedIdx <- readDataFrameOff pPtr i
      readDataFrameOff pPtr j >>= writeDataFrameOff pPtr i
      writeDataFrameOff pPtr j savedIdx