{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver -fplugin=GHC.TypeLits.Normalise -fconstraint-solver-iterations=10 #-}
-- | Vectors and Matrices with statically-typed dimensions based on boxed vectors.

module Goal.Core.Vector.Boxed
    ( -- * Vector
      module Data.Vector.Sized
      -- ** Construction
    , doubleton
    , range
    , breakStream
    , breakEvery
    -- ** Deconstruction
    , toPair
    , concat
    -- * Matrix
    , Matrix
    , nRows
    , nColumns
    -- ** Construction
    , fromRows
    , fromColumns
    , matrixIdentity
    , outerProduct
    , diagonalConcat
    -- ** Deconstruction
    , toRows
    , toColumns
    -- ** Manipulation
    , columnVector
    , rowVector
    -- ** BLAS
    , dotProduct
    , matrixVectorMultiply
    , matrixMatrixMultiply
    , inverse
    , transpose
    ) where


--- Imports ---


-- Goal --

import Goal.Core.Util hiding (breakEvery,range)
import qualified Goal.Core.Util (breakEvery)

import qualified Goal.Core.Vector.Generic as G

-- Qualified imports --

import Prelude hiding (concat,zipWith,(++),replicate)
import qualified Data.Vector as B
import qualified Data.Vector.Mutable as BM
import qualified Control.Monad.ST as ST
import qualified Data.Vector.Generic.Sized.Internal as I

-- Unqualified imports --

import Data.Vector.Sized
import GHC.TypeNats
import Data.Proxy

--- Vectors ---

-- | Flatten a 'Vector' of @m@ 'Vector's of length @n@ into a single 'Vector'
-- of length @m*n@. Boxed specialisation of 'G.concat'.
concat :: KnownNat n => Vector m (Vector n x) -> Vector (m*n) x
{-# INLINE concat #-}
concat = G.concat

-- | Create a 'Vector' of length 2 from two elements. Boxed specialisation of
-- 'G.doubleton'.
doubleton :: x -> x -> Vector 2 x
{-# INLINE doubleton #-}
doubleton = G.doubleton

-- | Partition of an interval into @n@ values, given the two bounds of the
-- interval. Boxed specialisation of 'G.range'.
range :: (KnownNat n, Fractional x) => x -> x -> Vector n x
{-# INLINE range #-}
range = G.range

-- | Cycles a list of elements and breaks it up into an infinite list of
-- 'Vector's of length @n@.
--
-- The input is passed through 'cycle', which raises an error when given an
-- empty list.
breakStream :: forall n a. KnownNat n => [a] -> [Vector n a]
{-# INLINE breakStream #-}
breakStream as =
    let chunkSize = natValInt (Proxy :: Proxy n)
        -- Each chunk is wrapped with the raw sized-vector constructor, so the
        -- length invariant relies on 'Goal.Core.Util.breakEvery' yielding
        -- chunks of exactly @chunkSize@ elements.
        chunks = Goal.Core.Util.breakEvery chunkSize (cycle as)
     in I.Vector . B.fromList <$> chunks

-- | Converts a length two 'Vector' into a pair of its elements. Boxed
-- specialisation of 'G.toPair'.
toPair :: Vector 2 x -> (x,x)
{-# INLINE toPair #-}
toPair = G.toPair

-- | Breaks a 'Vector' of length @n*k@ into a 'Vector' of @n@ 'Vector's of
-- length @k@. Boxed specialisation of 'G.breakEvery'.
breakEvery :: (KnownNat n, KnownNat k) => Vector (n*k) a -> Vector n (Vector k a)
{-# INLINE breakEvery #-}
breakEvery = G.breakEvery


--- Matrices ---


-- | Matrices with static dimensions backed by boxed vectors: 'G.Matrix'
-- specialised to 'B.Vector'.
type Matrix = G.Matrix B.Vector

-- | The number of rows in the 'Matrix', as a runtime 'Int'. Boxed
-- specialisation of 'G.nRows'.
nRows :: forall m n a . KnownNat m => Matrix m n a -> Int
{-# INLINE nRows #-}
nRows = G.nRows

-- | The number of columns in the 'Matrix', as a runtime 'Int'. Boxed
-- specialisation of 'G.nColumns'.
nColumns :: forall m n a . KnownNat n => Matrix m n a -> Int
{-# INLINE nColumns #-}
nColumns = G.nColumns

-- | Convert a 'Matrix' into a 'Vector' of its row 'Vector's. Boxed
-- specialisation of 'G.toRows'.
toRows :: (KnownNat m, KnownNat n) => Matrix m n x -> Vector m (Vector n x)
{-# INLINE toRows #-}
toRows = G.toRows

-- | Convert a 'Matrix' into a 'Vector' of its column 'Vector's. Boxed
-- specialisation of 'G.toColumns'.
toColumns :: (KnownNat m, KnownNat n) => Matrix m n x -> Vector n (Vector m x)
{-# INLINE toColumns #-}
toColumns = G.toColumns


-- | Turn a 'Vector' of length @n@ into a single column ('n' by 1) 'Matrix'.
-- Boxed specialisation of 'G.columnVector'.
columnVector :: Vector n a -> Matrix n 1 a
{-# INLINE columnVector #-}
columnVector = G.columnVector

-- | Turn a 'Vector' of length @n@ into a single row (1 by 'n') 'Matrix'.
-- Boxed specialisation of 'G.rowVector'.
rowVector :: Vector n a -> Matrix 1 n a
{-# INLINE rowVector #-}
rowVector = G.rowVector

-- | Create an @m@ by @n@ 'Matrix' from a 'Vector' of @m@ row 'Vector's.
-- Boxed specialisation of 'G.fromRows'.
fromRows :: KnownNat n => Vector m (Vector n x) -> Matrix m n x
{-# INLINE fromRows #-}
fromRows = G.fromRows

-- | Create an @m@ by @n@ 'Matrix' from a 'Vector' of @n@ column 'Vector's.
-- Boxed specialisation of 'G.fromColumns'.
fromColumns :: (KnownNat n, KnownNat m) => Vector n (Vector m x) -> Matrix m n x
{-# INLINE fromColumns #-}
fromColumns = G.fromColumns

-- | Diagonally concatenate two matrices, padding the gaps with zeroes (pure
-- implementation). The first matrix occupies the top-left block of the result
-- and the second matrix the bottom-right block.
diagonalConcat
    :: (KnownNat n, KnownNat m, KnownNat o, KnownNat p, Num a)
    => Matrix n m a -> Matrix o p a -> Matrix (n+o) (m+p) a
{-# INLINE diagonalConcat #-}
diagonalConcat mtx1 mtx2 =
    let -- Rows of the first matrix, padded on the right with @p@ zeroes.
        rws1 = (++ replicate 0) <$> toRows mtx1
        -- Rows of the second matrix, padded on the left with @m@ zeroes.
        rws2 = (replicate 0 ++) <$> toRows mtx2
     in fromRows $ rws1 ++ rws2

-- | Pure implementation of the dot product of two equal-length 'Vector's.
-- Boxed specialisation of 'G.dotProduct'.
dotProduct :: Num x => Vector n x -> Vector n x -> x
{-# INLINE dotProduct #-}
dotProduct = G.dotProduct

-- | Pure implementation of the outer product: combines a 'Vector' of length
-- @m@ and a 'Vector' of length @n@ into an @m@ by @n@ 'Matrix'. Boxed
-- specialisation of 'G.outerProduct'.
outerProduct
    :: (KnownNat m, KnownNat n, Num x)
    => Vector m x -> Vector n x -> Matrix m n x
{-# INLINE outerProduct #-}
outerProduct = G.outerProduct

-- | Pure implementation of 'Matrix' transposition. Boxed specialisation of
-- 'G.transpose'.
--
-- Transposition only rearranges elements, so no numeric constraint is
-- required of the element type.
transpose
    :: (KnownNat m, KnownNat n)
    => Matrix m n x -> Matrix n m x
{-# INLINE transpose #-}
transpose = G.transpose

-- | Pure 'Matrix' x 'Vector' multiplication.
--
-- The vector is lifted to an @n@ by 1 column matrix with 'columnVector',
-- multiplied with 'matrixMatrixMultiply', and the resulting @m@ by 1 matrix is
-- unwrapped to its underlying 'Vector' via 'G.toVector'.
matrixVectorMultiply
    :: (KnownNat m, KnownNat n, Num x)
    => Matrix m n x -> Vector n x -> Vector m x
{-# INLINE matrixVectorMultiply #-}
matrixVectorMultiply mtx = G.toVector . matrixMatrixMultiply mtx . columnVector

-- | The @n@ by @n@ identity 'Matrix': 1 where the row and column indices
-- coincide and 0 everywhere else.
matrixIdentity :: (KnownNat n, Num a) => Matrix n n a
{-# INLINE matrixIdentity #-}
matrixIdentity =
    -- Row and column indices share the same 'Finite' type, so they can be
    -- compared directly without converting to 'Int'.
    fromRows . generate $ \i -> generate $ \j -> if i == j then 1 else 0

-- | Pure implementation of matrix inversion.
--
-- Each row of the input is extended with the matching row of
-- 'matrixIdentity', 'eliminateRow' is folded over the pivot indices
-- @0 .. n-1@, and the right-hand @n@ columns of the resulting rows are
-- returned as the inverse. Returns 'Nothing' as soon as 'eliminateRow' fails
-- for some pivot index.
inverse :: forall a n. (Fractional a, Ord a, KnownNat n) => Matrix n n a -> Maybe (Matrix n n a)
{-# INLINE inverse #-}
inverse mtx =
    let n = natValInt (Proxy :: Proxy n)
        -- Augmented rows, each of length 2n, converted to unsized vectors.
        rws = fromSized $ fromSized <$> zipWith (++) (toRows mtx) (toRows matrixIdentity)
        -- Process the pivots in order; the 'Maybe' monad aborts on failure.
        rws' = B.foldM' eliminateRow rws $ B.generate n id
        -- Drop the left n columns of every row and flatten the remainder into
        -- the matrix's underlying storage.
     in G.Matrix . I.Vector . B.concatMap (B.drop n) <$> rws'

-- | Pure 'Matrix' x 'Matrix' multiplication.
--
-- The right-hand matrix is transposed once so that both operands can be read
-- as contiguous length-@n@ slices of their flat storage; each output element
-- is then the dot product of one slice from each operand.
matrixMatrixMultiply
    :: forall m n o a. (KnownNat m, KnownNat n, KnownNat o, Num a)
    => Matrix m n a -> Matrix n o a -> Matrix m o a
{-# INLINE matrixMatrixMultiply #-}
matrixMatrixMultiply (G.Matrix (I.Vector v)) wm =
    let n = natValInt (Proxy :: Proxy n)
        o = natValInt (Proxy :: Proxy o)
        (G.Matrix (I.Vector w')) = G.transpose wm
        -- A flat output index k decomposes into (row i, column j) of the
        -- m-by-o result.
        f k = let (i,j) = divMod (finiteInt k) o
                  -- Slice i of the left operand: n elements starting at i*n.
                  slc1 = B.unsafeSlice (i*n) n v
                  -- Slice j of the transposed right operand.
                  slc2 = B.unsafeSlice (j*n) n w'
               in G.weakDotProduct slc1 slc2
     in G.Matrix $ G.generate f


--- Internal ---


-- Performs one elimination step for pivot index @k@: select and swap in a
-- pivot row ('pivotRow'), scale it ('normalizePivot'), then subtract multiples
-- of it from every other row ('nullifyRows'). Yields 'Nothing' when 'pivotRow'
-- does.
eliminateRow :: (Ord a, Fractional a) => B.Vector (B.Vector a) -> Int -> Maybe (B.Vector (B.Vector a))
eliminateRow mtx k = nullifyRows k . normalizePivot k <$> pivotRow k mtx

-- Partial pivoting for column @k@: among rows @k@ and below, find the row @l@
-- whose entry in column @k@ has the largest absolute value and swap it with
-- row @k@. Yields 'Nothing' when that entry's magnitude is below @1e-10@.
--
-- The magnitude check reads the entry at row @l@, column @k@, i.e. the element
-- that was actually selected as the pivot. Reading row @k@, column @l@ instead
-- would reject invertible matrices whenever @l /= k@ and that transposed
-- position happens to hold a (near) zero.
pivotRow :: (Fractional a, Ord a) => Int -> B.Vector (B.Vector a) -> Maybe (B.Vector (B.Vector a))
pivotRow k rws =
    let -- Magnitudes of column k for rows k and below.
        candidates = abs . flip B.unsafeIndex k <$> B.drop k rws
        -- 'B.maxIndex' is relative to the dropped prefix, hence the offset.
        l = k + B.maxIndex candidates
        ak = B.unsafeIndex (B.unsafeIndex rws l) k
     in if abs ak < 1e-10
           then Nothing
           else ST.runST $ do
               mrws <- B.thaw rws
               BM.unsafeSwap mrws k l
               Just <$> B.freeze mrws

-- Scales row @k@ by the reciprocal of its own entry in column @k@; every
-- other row is left untouched.
normalizePivot :: Fractional a => Int -> B.Vector (B.Vector a) -> B.Vector (B.Vector a)
normalizePivot k rws = ST.runST $ do
    let scale = recip $ B.unsafeIndex (B.unsafeIndex rws k) k
    mrws <- B.thaw rws
    BM.modify mrws (fmap (* scale)) k
    B.freeze mrws

-- Subtracts from each row @i /= k@ the multiple of row @k@ given by that
-- row's entry in column @k@ divided by row @k@'s entry in column @k@. Row @k@
-- itself is left unchanged (its factor is 0).
nullifyRows :: Fractional a => Int -> B.Vector (B.Vector a) -> B.Vector (B.Vector a)
nullifyRows k rws =
    let rwk = B.unsafeIndex rws k
        ak = B.unsafeIndex rwk k
        -- Multiple of the pivot row to subtract from row i.
        factor i
            | i == k = 0
            | otherwise = B.unsafeIndex (B.unsafeIndex rws i) k / ak
        factors = B.generate (B.length rws) factor
        -- One scaled copy of the pivot row per row of the matrix.
        scaledPivots = (\c -> (* c) <$> rwk) <$> factors
     in B.zipWith (B.zipWith (-)) rws scaledPivots