{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{- |
A few ways to solve a system of linear equations in the ST monad.
The result is always computed in place.

Matrix element \( (i, j) \) of a @DataFrame t '[n,m]@ is read from the flat
offset @m*i + j@ (row-major layout).

If a diagonal element of the triangular matrix is exactly zero, the
corresponding component of the solution is set to zero instead of dividing
by zero.
 -}
module Numeric.Subroutine.SolveTriangular
  ( solveUpperTriangularR
  , solveUpperTriangularL
  , solveLowerTriangularR
  , solveLowerTriangularL
  ) where

import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions

{- |
Solve a system of linear equations \( Rx = b \) or a linear least squares
problem \( \min {|| Rx - b ||}^2 \), where \( R \) is an upper-triangular
matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b_m = x \).

NB: you can use `subDataFrameView` to truncate @b@ without performing a copy.

If \( R_{ii} = 0 \), row \( i \) of the result is set to zero.
 -}
solveUpperTriangularR
  :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
   . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
  => DataFrame t '[n,m] -- ^ \(R\)
  -> STDataFrame s t (m :+ ds) -- ^ Current state of \(b_m\)
                               --   (first @m@ rows of @b@)
  -> ST s ()
solveUpperTriangularR r bPtr
    | Dict <- Dict @(m <= n) = mapM_ go [m-1,m-2..0]
  where
    m = fromIntegral $ dimVal' @m :: Int
    -- Row stride of b: row i occupies offsets [k*i .. k*i + k - 1].
    k = fromIntegral kw :: Int
    CumulDims (_:kw:_) = getDataFrameSteps bPtr
    -- Back substitution step for row i (rows are visited bottom-up):
    -- finalize x_i, then eliminate it from all rows above.
    go :: Int -> ST s ()
    go i
      | rii == 0 = forM_ [0..k-1] $ \j -> writeDataFrameOff bPtr (ki + j) 0
      | otherwise = forM_ [0..k-1] $ \j -> do
          bij <- (rrii*) <$> readDataFrameOff bPtr (ki + j)
          writeDataFrameOff bPtr (ki + j) bij
          forM_ [0..i-1] $ \t -> do
            let rti = scalar (ixOff (m*t + i) r)
                ix = k*t + j
            btj <- readDataFrameOff bPtr ix
            writeDataFrameOff bPtr ix (btj - rti*bij)
      where
        mi = m*i
        ki = k*i
        rii = scalar (ixOff (mi + i) r)
        rrii = recip rii

{- |
Solve a system of linear equations \( xR = b \), where \( R \) is an
upper-triangular matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b = x_m \).
The \( (n - m) \) rows of \(R\) are not used.
Pad each dimension of \(x\) with \( (n - m) \) zeros if you want to get the
full solution.

If \( R_{ii} = 0 \), component \( i \) of every solution vector is set to zero.
 -}
solveUpperTriangularL
  :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
   . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
  => STDataFrame s t (ds +: m) -- ^ Current state of \(b\)
                               --   (first @m@ "columns" of x)
  -> DataFrame t '[n,m] -- ^ \(R\)
  -> ST s ()
solveUpperTriangularL bPtr r
    | Dict <- Dict @(m <= n) = mapM_ go [0..m-1]
  where
    m = fromIntegral $ dimVal' @m :: Int
    -- Number of length-m rows in b (first cumulative-dims entry divided by m).
    -- Only forced inside 'go', so m == 0 never reaches the division.
    k = fromIntegral mkw `quot` m
    CumulDims (mkw:_) = getDataFrameSteps bPtr
    -- Forward substitution step for component i (visited left to right):
    -- finalize x_i in every row of b, then eliminate it from components i+1..m-1.
    go :: Int -> ST s ()
    go i
      | rii == 0 = forM_ [0..k-1] $ \j -> writeDataFrameOff bPtr (m*j + i) 0
      | otherwise = forM_ [0..k-1] $ \j -> do
          let mj = m*j
          bji <- (rrii*) <$> readDataFrameOff bPtr (mj + i)
          writeDataFrameOff bPtr (mj + i) bji
          forM_ [i+1..m-1] $ \t -> do
            let rit = scalar (ixOff (mi + t) r)
            bjt <- readDataFrameOff bPtr (mj + t)
            writeDataFrameOff bPtr (mj + t) (bjt - rit*bji)
      where
        mi = m*i
        rii = scalar (ixOff (mi + i) r)
        rrii = recip rii

{- |
Solve a system of linear equations \( Lx = b \), where \( L \) is a
lower-triangular matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b = x_n \).
The \( (m - n) \) columns of \(L\) are not used.
Pad \(x\) with \( (m - n) \) zero elements if you want to get the full
solution.

If \( L_{ii} = 0 \), row \( i \) of the result is set to zero.
 -}
solveLowerTriangularR
  :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
   . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
  => DataFrame t '[n,m] -- ^ \(L\)
  -> STDataFrame s t (n :+ ds) -- ^ Current state of \(b\)
                               --   (first @n@ elements of x)
  -> ST s ()
solveLowerTriangularR l bPtr
    | Dict <- Dict @(n <= m) = mapM_ go [0..n-1]
  where
    m = fromIntegral $ dimVal' @m :: Int
    n = fromIntegral $ dimVal' @n :: Int
    -- Row stride of b: row i occupies offsets [k*i .. k*i + k - 1].
    k = fromIntegral kw :: Int
    CumulDims (_:kw:_) = getDataFrameSteps bPtr
    -- Forward substitution step for row i (rows are visited top-down):
    -- finalize x_i, then eliminate it from all rows below (up to n-1).
    go :: Int -> ST s ()
    go i
      | lii == 0 = forM_ [0..k-1] $ \j -> writeDataFrameOff bPtr (ki + j) 0
      | otherwise = forM_ [0..k-1] $ \j -> do
          bij <- (rlii*) <$> readDataFrameOff bPtr (ki + j)
          writeDataFrameOff bPtr (ki + j) bij
          forM_ [i+1..n-1] $ \t -> do
            let lti = scalar (ixOff (m*t + i) l)
                ix = k*t + j
            btj <- readDataFrameOff bPtr ix
            writeDataFrameOff bPtr ix (btj - lti*bij)
      where
        mi = m*i
        ki = k*i
        lii = scalar (ixOff (mi + i) l)
        rlii = recip lii

{- |
Solve a system of linear equations \( xL = b \) or a linear least squares
problem \( \min {|| xL - b ||}^2 \), where \( L \) is a lower-triangular
matrix.

DataFrame \( b \) is modified in-place; by the end of the process \( b_n = x \).
The last \( (m - n) \) columns of \(L\) and \(b\) are not touched.

If \( L_{ii} = 0 \), component \( i \) of every solution vector is set to zero.
 -}
solveLowerTriangularL
  :: forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
   . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
  => STDataFrame s t (ds +: m) -- ^ Current state of \(b\)
  -> DataFrame t '[n,m] -- ^ \(L\)
  -> ST s ()
solveLowerTriangularL bPtr l
    | Dict <- Dict @(n <= m) = mapM_ go [n-1,n-2..0]
  where
    m = fromIntegral $ dimVal' @m :: Int
    n = fromIntegral $ dimVal' @n :: Int
    -- Number of length-m rows in b (first cumulative-dims entry divided by m).
    -- Only forced inside 'go', so m == 0 never reaches the division.
    k = fromIntegral kmw `quot` m
    CumulDims (kmw:_) = getDataFrameSteps bPtr
    -- Back substitution step for component i (visited right to left):
    -- finalize x_i in every row of b, then eliminate it from components 0..i-1.
    go :: Int -> ST s ()
    go i
      | lii == 0 = forM_ [0..k-1] $ \j -> writeDataFrameOff bPtr (m*j + i) 0
      | otherwise = forM_ [0..k-1] $ \j -> do
          let mj = m*j
          bji <- (rlii*) <$> readDataFrameOff bPtr (mj + i)
          writeDataFrameOff bPtr (mj + i) bji
          forM_ [0..i-1] $ \t -> do
            let lit = scalar (ixOff (mi + t) l)
                ix = mj + t
            bjt <- readDataFrameOff bPtr ix
            writeDataFrameOff bPtr ix (bjt - lit*bji)
      where
        mi = m*i
        lii = scalar (ixOff (mi + i) l)
        rlii = recip lii