{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-|
LU factorization with partial pivoting (\( PA = LU \)) of square matrices,
plus linear solvers, a determinant and an inverse built on top of it.
-}
module Numeric.Matrix.LU
  ( MatrixLU (..), LU (..)
  , luSolveR, luSolveL
  , detViaLU, inverseViaLU
  ) where

import Control.Monad
import Control.Monad.ST
import Data.Kind
import Data.List                            (foldl')
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.SolveTriangular


-- | Result of LU factorization with Partial Pivoting
--   \( PA = LU \).
data LU (t :: Type) (n :: Nat) = LU
  { luLower   :: Matrix t n n
    -- ^ Unit lower triangular matrix \(L\).
    --   All elements on the diagonal of @L@ equal @1@.
    --   The rest of the elements satisfy \(|l_{ij}| \leq 1\).
  , luUpper   :: Matrix t n n
    -- ^ Upper triangular matrix \(U\)
  , luPerm    :: Matrix t n n
    -- ^ Row permutation matrix \(P\)
  , luPermDet :: Scalar t
    -- ^ Sign of permutation @luPermDet == det . luPerm@; \(|P| = \pm 1\).
  }

deriving instance (Show t, PrimBytes t, KnownDim n) => Show (LU t n)
deriving instance (Eq (Matrix t n n), Eq t) => Eq (LU t n)


class (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
      => MatrixLU t (n :: Nat) where
    -- | Compute LU factorization with Partial Pivoting
    lu :: Matrix t n n -> LU t n

instance (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
         => MatrixLU t n where
    lu a = runST $ do
        -- Permutation vector p, initialised to the identity 0..n-1.
        -- NB: use the copying 'thawDataFrame' here. The identity vector does
        --     not depend on @a@, so the optimiser may float it out and share
        --     it between calls; thawing it unsafely and then swapping its
        --     entries in 'luInplace' could corrupt that shared value.
        pPtr <- thawDataFrame $ iwgen @_ @'[n] @'[] (\(Idx i :* U) -> S i)
        -- Working copy of the input; 'luInplace' packs U (upper part,
        -- including the diagonal) and L (strictly lower part) into it.
        uPtr <- thawDataFrame a
        lPtr <- newDataFrame
        -- Scratch row used by 'luInplace' when swapping rows.
        temp <- newDataFrame
        detPositive <- luInplace temp pPtr uPtr
        p <- unsafeFreezeDataFrame pPtr
        -- Split the packed result: move the strictly lower part into L
        -- (zeroing it in U) and put ones on the diagonal of L.
        forM_ [0 .. n - 1] $ \i -> do
          let ni = n * i
          forM_ [0 .. n - 1] $ \j -> case compare i j of
            GT -> do
              lij <- readDataFrameOff uPtr (ni + j)
              writeDataFrameOff uPtr (ni + j) 0
              writeDataFrameOff lPtr (ni + j) lij
            EQ -> writeDataFrameOff lPtr (ni + j) 1
            LT -> writeDataFrameOff lPtr (ni + j) 0
        luLower <- unsafeFreezeDataFrame lPtr
        luUpper <- unsafeFreezeDataFrame uPtr
        let luPermDet = if detPositive then 1 else -1
            -- Row i of P has its single 1 in column p[i].
            luPerm = iwgen @_ @'[n,n] @'[]
              (\(Idx i :* Idx j :* U) -> if S j == p ! i then 1 else 0)
        return LU {..}
      where
        n = fromIntegral (dimVal' @n) :: Int


-- | Solve @Ax = b@ problem given LU decomposition of A.
--
--   Uses \( PA = LU \): first solves \( Ly = Pb \), then \( Ux = y \).
luSolveR :: forall t (n :: Nat) (ds :: [Nat])
          . (MatrixLU t n, Dimensions ds)
         => LU t n -> DataFrame t (n :+ ds) -> DataFrame t (n :+ ds)
luSolveR LU {..} b = runST $ do
    -- NB: wasting resources! Multiplying by the dense permutation matrix
    --     does more work than merely reordering the rows of b.
    xPtr <- thawDataFrame (luPerm %* b)
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr

-- | Solve @xA = b@ problem given LU decomposition of A.
--
--   Uses \( A = P^{T}LU \): first solves \( zU = b \), then \( yL = z \),
--   and finally recovers \( x = yP \).
luSolveL :: forall t (n :: Nat) (ds :: [Nat])
          . (MatrixLU t n, Dimensions ds)
         => LU t n -> DataFrame t (ds +: n) -> DataFrame t (ds +: n)
luSolveL LU {..} b
  | dn <- dim @n
  , dds <- dims @ds
    -- The two matches below bring the type-level evidence about @ds +: n@
    -- required by the triangular solvers into scope.
  , Dims <- Snoc dds dn
  , Dict <- Dict @(SnocList ds n _)
  = runST $ do
    xPtr <- thawDataFrame b
    solveUpperTriangularL xPtr luUpper
    solveLowerTriangularL xPtr luLower
    (%* luPerm) <$> unsafeFreezeDataFrame xPtr
luSolveL _ _ = error "luSolveL: impossible pattern"


-- | Calculate inverse of a matrix via LU decomposition.
--
--   Solves \( LUX = P \), which gives \( X = U^{-1}L^{-1}P = A^{-1} \).
inverseViaLU :: forall (t :: Type) (n :: Nat)
              . MatrixLU t n => Matrix t n n -> Matrix t n n
inverseViaLU a = runST $ do
    -- luPerm is only ever used once (it comes from this very call of 'lu'),
    -- so its storage can be reused for the result.
    xPtr <- unsafeThawDataFrame luPerm
    solveLowerTriangularR luLower xPtr
    solveUpperTriangularR luUpper xPtr
    unsafeFreezeDataFrame xPtr
  where
    LU {..} = lu a
-- perfectly correct, but slightly slower versions:
-- inverseViaLU a = luSolveR (lu a) eye
-- inverseViaLU a = luSolveL (lu a) eye


-- | Calculate determinant of a matrix via LU decomposition.
--
--   Since \( PA = LU \), \( |L| = 1 \) and \( |P| = |P^{-1}| = \pm 1 \),
--   the determinant is \( |A| = |P| \prod_i u_{ii} \).
--   The determinant of a @0x0@ matrix is @1@.
detViaLU :: forall (t :: Type) (n :: Nat)
          . MatrixLU t n => Matrix t n n -> Scalar t
detViaLU m = foldl' mulDiag luPermDet [0 .. n - 1]
  where
    n = fromIntegral (dimVal' @n) :: Int
    LU {..} = lu m
    -- Multiply the accumulator by the i-th diagonal element of U;
    -- in the row-major n x n buffer it lives at offset i*(n+1).
    -- (Iterating over i rather than over offsets avoids reading
    --  offset 0 of an empty matrix when n == 0.)
    mulDiag x i = scalar (ixOff (i * (n + 1)) luUpper) * x


{- |
Run LU decomposition with partial pivoting in place, such that the upper
triangular part of matrix \(A\) becomes \(U\) and the lower triangular part
(without diagonal) of matrix \(A\) becomes \(L\).

\(U\) is upper triangular.
\(L\) is unit lower triangular; all diagonal elements of \(L\) are implicit
and equal to 1; the rest of the elements are at most one in absolute value,
\(|l_{ij}| \leq 1\).

If the chosen pivot of a column is zero, the elimination step for that column
is skipped (leaving a zero on the diagonal of \(U\)).

Pivoting is represented as a permutation vector \(p\);
the returned value is the sign of the permutation
(positive if @True@, negative otherwise).

NB: Initialize \(p\) with indices @0..n-1@.

Reference: Algorithm 3.4.1 on p.128 of
  "Matrix Computations" 4th edition by G. H. Golub and C. F. Van Loan.
-}
luInplace :: forall (s :: Type) (t :: Type) (n :: Nat)
           . (PrimBytes t, Fractional t, Ord t, KnownDim n)
          => STDataFrame s t '[n]
             -- ^ Temporary buffer
          -> STDataFrame s Word '[n]
             -- ^ Current state of permutation \(p\)
          -> STDataFrame s t '[n,n]
             -- ^ Current state of \(A\)
          -> ST s Bool
-- Each row swap flips the sign of the permutation: XOR it into the result.
luInplace temp pPtr aPtr = foldM (\b -> fmap (b /=) . go) True [0 .. n - 2]
  where
    n = fromIntegral (dimVal' @n) :: Int

    -- Runs the k-th step of the algorithm;
    -- returns whether there was a swap of rows.
    go :: Int -> ST s Bool
    go k = do
      mu <- findPivot k
      let swapped = k /= mu
      when swapped $ swapRows k mu
      akk <- readDataFrameOff aPtr (k * (n + 1))
      when (akk /= 0) $ do
        let rakk = recip akk
        forM_ [k + 1 .. n - 1] $ \i -> do
          let ni = n * i
          -- multiplier l_ik, stored in place of a_ik
          aik <- (rakk *) <$> readDataFrameOff aPtr (ni + k)
          writeDataFrameOff aPtr (ni + k) aik
          -- rank-one update of the trailing row: a_ij -= l_ik * a_kj
          forM_ [k + 1 .. n - 1] $ \j -> do
            akj <- readDataFrameOff aPtr (n * k + j)
            aij <- readDataFrameOff aPtr (ni + j)
            writeDataFrameOff aPtr (ni + j) (aij - aik * akj)
      return swapped

    -- Row index (>= k) of the element of column k with the largest
    -- absolute value; ties and an all-zero column resolve to the
    -- smallest such index (k in the all-zero case).
    findPivot :: Int -> ST s Int
    findPivot k = snd <$> foldM findPivotF (0, k) [k .. n - 1]
      where
        findPivotF :: (Scalar t, Int) -> Int -> ST s (Scalar t, Int)
        findPivotF aj@(a, _) i = do
          x <- abs <$> readDataFrameOff aPtr (n * i + k)
          return (if x > a then (x, i) else aj)

    -- Swap rows i and j of A (via the temporary buffer) together with
    -- entries i and j of the permutation vector.
    swapRows :: Int -> Int -> ST s ()
    swapRows i j = do
      let iPtr = subDataFrameView' (fromIntegral i :* U) aPtr
          jPtr = subDataFrameView' (fromIntegral j :* U) aPtr
      copyMutableDataFrame' U iPtr temp
      copyMutableDataFrame' U jPtr iPtr
      copyMutableDataFrame' U temp jPtr
      t <- readDataFrameOff pPtr i
      readDataFrameOff pPtr j >>= writeDataFrameOff pPtr i
      writeDataFrameOff pPtr j t