{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.Matrix.LU
( MatrixLU (..), LU (..)
, luSolveR, luSolveL
, detViaLU, inverseViaLU
) where
import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.SolveTriangular
-- | Result of an LU factorization with row pivoting, as produced by 'lu'.
--
--   The solvers in this module ('luSolveR', 'luSolveL', 'inverseViaLU')
--   consume the three matrix fields; 'detViaLU' consumes 'luUpper' and
--   'luPermDet'.
data LU (t :: Type) (n :: Nat)
  = LU
  { luLower :: Matrix t n n
    -- ^ Lower triangular factor. 'lu' writes @1@ on the diagonal, @0@ above
    --   it, and the elimination multipliers below it.
  , luUpper :: Matrix t n n
    -- ^ Upper triangular factor. 'lu' zeroes every entry below the diagonal.
  , luPerm :: Matrix t n n
    -- ^ Row permutation encoded as a matrix of zeros and ones; 'lu' builds it
    --   from the pivot index vector recorded during factorization.
  , luPermDet :: Scalar t
    -- ^ Sign associated with the permutation: 'lu' sets it to @1@ when the
    --   number of row swaps is even and to @-1@ otherwise.
  }
-- Standalone deriving is used so that the instance contexts can be given
-- explicitly.
deriving instance (Show t, PrimBytes t, KnownDim n) => Show (LU t n)
deriving instance (Eq (Matrix t n n), Eq t) => Eq (LU t n)
-- | Square @n@-by-@n@ matrices with elements of type @t@ that support
--   LU factorization.
class (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
      => MatrixLU t (n :: Nat) where
    -- | Factorize a square matrix, returning both triangular factors together
    --   with the row permutation and its sign (see 'LU').
    lu :: Matrix t n n -> LU t n
-- | Generic instance: the factorization is computed in place on a mutable
--   copy of the input by 'luInplace', then split into separate factors.
instance (KnownDim n, Ord t, Fractional t, PrimBytes t, KnownBackend t '[n,n])
         => MatrixLU t n where
    lu a = runST $ do
        -- The pivot bookkeeping vector is generated from its own indices.
        permPtr  <- unsafeThawDataFrame $ iwgen @_ @'[n] @'[] (\(Idx i :* U) -> S i)
        -- Work on a copy so the caller's matrix is left untouched.
        upperPtr <- thawDataFrame a
        lowerPtr <- newDataFrame
        rowBuf   <- newDataFrame
        evenPerm <- luInplace rowBuf permPtr upperPtr
        perm     <- unsafeFreezeDataFrame permPtr
        -- After 'luInplace' both factors share one buffer; move the entries
        -- below the diagonal into the lower factor and clear them in the
        -- upper one, filling the rest of the lower factor with 1 / 0.
        let split i j
              | i > j = do
                  v <- readDataFrameOff upperPtr off
                  writeDataFrameOff upperPtr off 0
                  writeDataFrameOff lowerPtr off v
              | i == j    = writeDataFrameOff lowerPtr off 1
              | otherwise = writeDataFrameOff lowerPtr off 0
              where
                off = n*i + j
        forM_ [0 .. n-1] $ \i -> forM_ [0 .. n-1] (split i)
        lower <- unsafeFreezeDataFrame lowerPtr
        upper <- unsafeFreezeDataFrame upperPtr
        return LU
          { luLower   = lower
          , luUpper   = upper
            -- Entry (i, j) is 1 exactly when the recorded pivot index of
            -- row i equals j.
          , luPerm    = iwgen @_ @'[n,n] @'[]
              (\(Idx i :* Idx j :* U) -> if S j == perm ! i then 1 else 0)
          , luPermDet = if evenPerm then 1 else -1
          }
      where
        n = fromIntegral (dimVal' @n) :: Int
-- | Right-hand-side solve using a precomputed factorization.
--
--   The input @b@ is first multiplied from the left by 'luPerm'; the result
--   is then overwritten by a lower triangular solve with 'luLower' followed
--   by an upper triangular solve with 'luUpper'. For a factorization obtained
--   from @'lu' a@ this is the usual sequence for solving @a x = b@.
luSolveR ::
       forall t (n :: Nat) (ds :: [Nat])
     . (MatrixLU t n, Dimensions ds)
    => LU t n -> DataFrame t (n :+ ds) -> DataFrame t (n :+ ds)
luSolveR (LU {luLower = lower, luUpper = upper, luPerm = perm}) rhs = runST $ do
    sol <- thawDataFrame (perm %* rhs)
    solveLowerTriangularR lower sol
    solveUpperTriangularR upper sol
    unsafeFreezeDataFrame sol
-- | Left-hand-side counterpart of 'luSolveR'.
--
--   A mutable copy of @b@ is overwritten by an upper triangular solve with
--   'luUpper' followed by a lower triangular solve with 'luLower'; the result
--   is finally multiplied from the right by 'luPerm'. For a factorization
--   obtained from @'lu' a@ this is the usual sequence for solving @x a = b@.
luSolveL ::
       forall t (n :: Nat) (ds :: [Nat])
     . (MatrixLU t n, Dimensions ds)
    => LU t n -> DataFrame t (ds +: n) -> DataFrame t (ds +: n)
luSolveL (LU {luLower = lower, luUpper = upper, luPerm = perm}) rhs
    -- The guards only exist to bring evidence about @ds +: n@ into scope.
  | lastDim <- dim @n
  , initDims <- dims @ds
  , Dims <- Snoc initDims lastDim
  , Dict <- Dict @(SnocList ds n _)
  = runST $ do
      sol <- thawDataFrame rhs
      solveUpperTriangularL sol upper
      solveLowerTriangularL sol lower
      solved <- unsafeFreezeDataFrame sol
      return (solved %* perm)
luSolveL _ _ = error "luSolveL: impossible pattern"
-- | Matrix inverse computed through the LU factorization.
--
--   Starts from the permutation matrix of @'lu' a@ and overwrites it in place
--   with a lower triangular solve followed by an upper triangular solve,
--   i.e. the same steps as 'luSolveR' applied to an identity right-hand side.
inverseViaLU :: forall (t :: Type) (n :: Nat)
              . MatrixLU t n => Matrix t n n -> Matrix t n n
inverseViaLU a = runST $ do
    -- The permutation matrix is thawed without copying and used as the
    -- output buffer.
    inv <- unsafeThawDataFrame (luPerm factors)
    solveLowerTriangularR (luLower factors) inv
    solveUpperTriangularR (luUpper factors) inv
    unsafeFreezeDataFrame inv
  where
    factors = lu a
-- | Determinant computed through the LU factorization.
--
--   Multiplies the sign of the row permutation ('luPermDet') by every
--   diagonal element of the upper triangular factor.
--
--   For a @0@-by-@0@ matrix no element is read and the result is the
--   permutation sign alone (the empty product).
detViaLU :: forall (t :: Type) (n :: Nat)
          . MatrixLU t n => Matrix t n n -> Scalar t
detViaLU m = foldl (\acc off -> scalar (ixOff off luUpper) * acc) luPermDet diagOffsets
  where
    n = fromIntegral (dimVal' @n) :: Int
    -- Linear offsets of the diagonal elements: 0, n+1, .., n*n-1.
    -- The upper bound must be the last valid offset (n*n-1), not n*n:
    -- with n*n the sequence for n == 0 would be [0] and an element of an
    -- empty matrix would be read. For n >= 1 both bounds give the same list.
    diagOffsets = [0, n+1 .. n*n - 1]
    LU {..} = lu m
-- | In-place elimination with row pivoting.
--
--   Arguments, in order:
--
--   1. a scratch buffer of length @n@, used only while swapping rows;
--   2. the pivot index vector, whose entries are swapped together with the
--      matrix rows;
--   3. the matrix itself, which is overwritten: entries below the diagonal
--      receive the elimination multipliers, the remaining entries receive
--      the eliminated values.
--
--   The result starts as 'True' and is flipped once for every column in
--   which a row swap took place, so it is 'True' exactly when the number of
--   swaps is even.
luInplace ::
       forall (s :: Type) (t :: Type) (n :: Nat)
     . (PrimBytes t, Fractional t, Ord t, KnownDim n)
    => STDataFrame s t '[n]
    -> STDataFrame s Word '[n]
    -> STDataFrame s t '[n,n]
    -> ST s Bool
luInplace rowBuf permPtr matPtr = foldM step True [0 .. n-2]
  where
    n = fromIntegral (dimVal' @n) :: Int

    -- Process one column and toggle the parity flag if rows were swapped.
    step :: Bool -> Int -> ST s Bool
    step parity k = (parity /=) <$> eliminate k

    -- Pivot on column k and eliminate it from the rows below;
    -- reports whether a row swap was needed.
    eliminate :: Int -> ST s Bool
    eliminate k = do
      pivotRow <- findPivot k
      let didSwap = k /= pivotRow
      when didSwap $ swapRows k pivotRow
      pivot <- readDataFrameOff matPtr (k*(n+1))
      -- A zero pivot means the whole remaining column is zero; skip it.
      when (pivot /= 0) $ do
        let invPivot = recip pivot
            rowK = n*k
        forM_ [k+1 .. n-1] $ \i -> do
          let rowI = n*i
          -- The multiplier is stored in place of the eliminated entry.
          factor <- (invPivot *) <$> readDataFrameOff matPtr (rowI + k)
          writeDataFrameOff matPtr (rowI + k) factor
          forM_ [k+1 .. n-1] $ \j -> do
            pivotRowElem <- readDataFrameOff matPtr (rowK + j)
            current <- readDataFrameOff matPtr (rowI + j)
            writeDataFrameOff matPtr (rowI + j) (current - factor*pivotRowElem)
      return didSwap

    -- Index of the row in [k .. n-1] whose entry in column k has the largest
    -- absolute value; the first such row wins, and k is returned if every
    -- candidate is zero.
    findPivot :: Int -> ST s Int
    findPivot k = snd <$> foldM pick (0, k) [k .. n-1]
      where
        pick :: (Scalar t, Int) -> Int -> ST s (Scalar t, Int)
        pick best@(bestAbs, _) i = do
          candidate <- abs <$> readDataFrameOff matPtr (n*i + k)
          return $ if candidate > bestAbs then (candidate, i) else best

    -- Exchange rows i and j of the matrix (via the scratch buffer) and the
    -- corresponding entries of the pivot index vector.
    swapRows :: Int -> Int -> ST s ()
    swapRows i j = do
      let rowI = subDataFrameView' (fromIntegral i :* U) matPtr
          rowJ = subDataFrameView' (fromIntegral j :* U) matPtr
      copyMutableDataFrame' U rowI rowBuf
      copyMutableDataFrame' U rowJ rowI
      copyMutableDataFrame' U rowBuf rowJ
      permI <- readDataFrameOff permPtr i
      permJ <- readDataFrameOff permPtr j
      writeDataFrameOff permPtr i permJ
      writeDataFrameOff permPtr j permI