module Data.Sparse.Internal.TriMatrix where
import qualified Data.IntMap.Strict as IM
import Data.IntMap.Strict ((!))
import Data.Foldable (foldrM)
import Data.List (sort)
import Data.Maybe (fromMaybe)
import Data.Complex
import qualified Data.Graph as G
import qualified Data.Tree as T
import Numeric.Eps
import Data.Sparse.Types
import Data.Sparse.Internal.SList
import Numeric.LinearAlgebra.Class
import Data.Sparse.Internal.SVector
import qualified Data.Sparse.Internal.SVector.Mutable as SMV
import Data.Sparse.SpMatrix (fromListSM, fromListDenseSM, insertSpMatrix, zeroSM, transposeSM, sparsifySM)
import Data.Sparse.Common (prd, prd0, (@@!), nrows, ncols, lookupSM, extractRow, extractCol, SpVector, SpMatrix, foldlWithKeySV, (##), (#~#))
import Control.Iterative (IterationConfig(IterConf), modifyUntilM, modifyUntilM')
import Data.Sparse.PPrint
import Control.Monad.Catch (MonadThrow, throwM)
import Control.Exception.Common
import Control.Monad (when)
import Control.Monad.IO.Class
import Control.Monad.Trans.State
import Control.Monad.Primitive
import qualified Data.Vector as V (Vector, freeze, thaw, fromList, toList, zip)
import qualified Data.Vector.Mutable as VM (MVector, new, set, write, read, clone)
-- | Concatenate the pre-order flattenings of every tree in a forest, keeping
-- the trees in their original left-to-right order.
flattenForest :: T.Forest a -> [a]
flattenForest forest = [node | tree <- forest, node <- T.flatten tree]
-- | Vertices visited by a depth-first search of the transposed graph
-- @transposeG g@ started from every vertex in @vs@, returned in ascending
-- order.
reachableFromRHS :: G.Graph -> V.Vector Int -> V.Vector Int
reachableFromRHS g vs = V.fromList (sort reached)
  where
    reversedGraph = G.transposeG g
    reached = flattenForest (G.dfs reversedGraph (V.toList vs))
-- | Build the mutable solution vector from a set of indices and a sparse
-- right-hand side.
--
-- A zero-filled value buffer of length @length ixnz@ is allocated and
-- @ixnz@ is thawed into a mutable index vector. Both are handed, together
-- with the @(index, value)@ pairs of @b@, to 'SMV.fromListOverwrite'.
-- NOTE(review): the exact overwrite semantics live in 'SMV.fromListOverwrite'
-- (not visible here); presumably entries of @b@ overwrite the zeros at
-- matching indices -- confirm.
initializeSoln ::
  (PrimMonad m, Num a) => V.Vector Int -> SVector a -> m (SMV.SMVector m a)
initializeSoln ixnz (SV n ixb b) = do
  ixnzm <- V.thaw ixnz
  xm <- VM.new (length ixnz)
  VM.set xm 0
  SMV.fromListOverwrite ixnzm xm n (V.toList (V.zip ixb b))
-- | Subtract @xj@ times the sub-diagonal vector @lsubdiag@ from @x@,
-- i.e. @x - xj * lsubdiag@.
tlUpdateSubdiag :: VectorSpace v => v -> Scalar v -> v -> v
tlUpdateSubdiag lsubdiag xj x = x ^-^ scaled
  where
    scaled = xj .* lsubdiag
-- | Small example graph on vertices 0..2. Every edge @(a, b)@ has @a >= b@,
-- so the edge list reads like a lower-triangular sparsity pattern.
g0 :: G.Graph
g0 = G.buildG (0, 2) [(0, 0), (2, 0), (1, 1), (2, 2)]

-- | Depth-first forest of the transpose of 'g0', started from vertex 0.
t0 :: T.Forest Int
t0 = G.dfs (G.transposeG g0) [0]

-- | Larger example graph on vertices 0..4; as with 'g0', every edge
-- @(a, b)@ has @a >= b@.
g1 :: G.Graph
g1 =
  G.buildG
    (0, 4)
    [ (0, 0), (1, 0), (1, 1), (2, 0), (2, 2)
    , (3, 1), (3, 2), (3, 3), (4, 0), (4, 4)
    ]
-- | A triangular matrix held as an 'IM.IntMap' from an outer index to a
-- sparse list of @(inner index, value)@ entries.
-- NOTE(review): whether the outer key is a row or a column is up to the
-- producer; nothing in this declaration fixes it.
newtype TriMatrix a = TM
  { unTM :: IM.IntMap (SList a) -- ^ outer index -> sparse entries
  }
  deriving (Show, Functor)
-- | A map holding an empty sparse list at every key @0 .. n-1@.
--
-- The result is empty when @n <= 0@. The keys are generated in strictly
-- ascending order, so the linear-time 'IM.fromDistinctAscList' applies.
emptyIMSL :: Int -> IM.IntMap (SList a)
emptyIMSL n = IM.fromDistinctAscList [(i, emptySL) | i <- [0 .. n - 1]]
-- | A 'TriMatrix' whose @n@ slots (keys @0 .. n-1@) all hold empty sparse
-- lists.
emptyTM :: Int -> TriMatrix a
emptyTM = TM . emptyIMSL
-- | Add entry @x@ to the sparse list stored at key @i@ using 'consSL'.
-- When the key is absent, the list is started from 'emptySL'.
appendIM :: IM.Key -> (Int, a) -> IM.IntMap (SList a) -> IM.IntMap (SList a)
appendIM i x = IM.alter (Just . consSL x . fromMaybe emptySL) i
-- | Two-level lookup with a default of zero.
--
-- Row @i@ is found with the row-lookup function, then column @j@ inside that
-- row with the column-lookup function. If either step yields 'Nothing', the
-- result is @0@.
lookupWD ::
  Num a =>
  (irow -> mat -> Maybe row) ->
  (jcol -> row -> Maybe a) ->
  mat ->
  irow ->
  jcol ->
  a
lookupWD rowOf colOf mat i j =
  case rowOf i mat of
    Nothing -> 0
    Just row -> fromMaybe 0 (colOf j row)
-- | LU factorization without pivoting.
--
-- @L@ is unit lower-triangular: its diagonal entries are set to 1. It is
-- accumulated row-wise (keyed by row index). @U@ is accumulated column-wise
-- (keyed by column index). Row 0 of @U@ and column 0 of @L@ are seeded
-- directly from @A@. Rows 1 .. m-1 are then produced one step at a time by
-- 'uStep' and 'lStep' until the step counter reaches @m@.
--
-- Returns @(L, U)@ as 'SpMatrix' values.
--
-- Throws 'NeedsPivoting' when a pivot @u_ii@ is (near) zero. This now
-- includes the first pivot @u_00@, which was previously used as a divisor
-- without any check.
lu :: (Scalar (SpVector t) ~ t, Elt t, AdditiveGroup t, VectorSpace (SpVector t),
       MonadThrow m, MonadIO m, PrintDense (SpMatrix t),
       Epsilon t) =>
      SpMatrix t -> m (SpMatrix t, SpMatrix t)
lu amat = do
  let d@(m, n) = (nrows amat, ncols amat)
      -- Stop once the step counter has reached m.
      q (_, _, i) = i == m
      -- Row 0 of U is row 0 of A; its first entry is the first pivot.
      urow0 = extractRow amat 0
      u00 = urow0 @@ 0
      -- Column 0 of L is column 0 of A scaled by 1/u00; row 0 is then
      -- overwritten so that L(0,0) = 1.
      lcol0 = extractCol amat 0 ./ u00
      umat0 = foldlWithKeySV ins (emptyIMSL n) urow0
      lmat0 = IM.insert 0 (SL [(0, 1)]) (foldlWithKeySV ins (emptyIMSL m) lcol0)
      ins acc i x = appendIM i (0, x) acc
      luInit = (lmat0, umat0, 1)
      luStep (lmat, umat, i) = do
        let (umat', uii) = uStep amat lmat umat i
        when (nearZero uii) $
          throwM (NeedsPivoting "LU" (unwords ["U", show (i, i)]) :: MatrixException Double)
        let lmat' = lStep amat lmat umat' uii i
        return (lmat', umat', i + 1)
  -- The seed column of L divides by u00, so reject a (near-)zero leading
  -- pivot here, exactly as 'luStep' does for every later pivot.
  when (nearZero u00) $
    throwM (NeedsPivoting "LU" (unwords ["U", show (0 :: Int, 0 :: Int)]) :: MatrixException Double)
  (lfin, ufin, _) <- modifyUntilM' (luConfig d) q luStep luInit
  let uu = fillSM d True ufin
      ll = fillSM d False lfin
  return (ll, uu)
-- | Iteration configuration used by 'lu'.
--
-- The projection drops the step counter from the @(L, U, i)@ state. The
-- printing action shows the current L (key = row) and U (key = column)
-- densely via 'prd0', using 'fillSM' with dimensions @d@.
-- NOTE(review): the meaning of the leading @0@ and @True@ arguments is
-- defined by 'IterConf' (not visible here) -- confirm before changing.
luConfig d = IterConf 0 True dropCounter printFactors
  where
    dropCounter (lmat, umat, _) = (lmat, umat)
    printFactors (lmat, umat) = do
      prd0 (fillSM d False lmat)
      prd0 (fillSM d True umat)
-- | Compute row @i@ of the upper factor @U@ for one LU step.
--
-- @lmat@ holds @L@ keyed by row and @umat@ holds @U@ keyed by column. For
-- each column @j@ in @i .. n-1@ the entry @u_ij = a_ij - l_i . u_j@ is
-- computed. Here @l_i@ is the current row @i@ of L and @u_j@ the current
-- column @j@ of U, assuming '<.>' is the sparse inner product over matching
-- indices. The entry is appended to column @j@ as @(i, u_ij)@ when it is
-- non-zero according to 'isNz'. The diagonal @u_ii@ is always appended and
-- is also returned so the caller can test it before dividing by it.
--
-- Fixes the previous body, which had lost its subtractions (a scalar was
-- applied as a function) and referenced the unbound name @n1@.
uStep :: (Elt a, Epsilon a, AdditiveGroup a) =>
         SpMatrix a
      -> IM.IntMap (SList a)   -- ^ L so far, keyed by row
      -> IM.IntMap (SList a)   -- ^ U so far, keyed by column
      -> IM.Key                -- ^ row index i
      -> (IM.IntMap (SList a), a)
uStep amat lmat umat i = (foldr ins umat [i .. n - 1], udiag)
  where
    n = ncols amat
    li = lmat ! i
    udiag = amat @@! (i, i) - (li <.> (umat ! i))
    ins j acc
      | i == j = appendIM j (i, udiag) acc
      | isNz uij = appendIM j (i, uij) acc
      | otherwise = acc
      where
        uij = aij - (li <.> uj)
        aij = amat @@! (i, j)
        uj = umat ! j
-- | Compute column @j@ of the unit lower-triangular factor @L@ for one LU
-- step.
--
-- @lmat@ holds @L@ keyed by row. @umat@ holds @U@ keyed by column and
-- should already contain column @j@'s entries, including the pivot. For
-- each row @i@ in @j .. m-1@:
--
-- * on the diagonal (@i == j@), @(j, 1)@ is appended to row @j@;
-- * below it, @l_ij = (a_ij - l_i . u_j) / u_jj@ is computed (assuming
--   '<.>' is the sparse inner product) and appended to row @i@ when it is
--   non-zero according to 'isNz'.
--
-- Fixes the previous body, which had lost its subtraction and referenced
-- the unbound name @m1@.
lStep :: (Elt a, Epsilon a, AdditiveGroup a) =>
         SpMatrix a
      -> IM.IntMap (SList a)   -- ^ L so far, keyed by row
      -> IM.IntMap (SList a)   -- ^ U so far, keyed by column
      -> a                     -- ^ pivot u_jj
      -> IM.Key                -- ^ column index j
      -> IM.IntMap (SList a)
lStep amat lmat umat udiag j = foldr ins lmat [j .. m - 1]
  where
    m = nrows amat
    uj = umat ! j
    ins i acc
      | i == j = appendIM i (j, 1) acc
      | isNz lij = appendIM i (j, lij) acc
      | otherwise = acc
      where
        lij = (aij - (li <.> uj)) / udiag
        aij = amat @@! (i, j)
        li = lmat ! i
-- | Materialize an @IntMap (SList a)@ as an @m x n@ 'SpMatrix'.
--
-- Every @(j, x)@ in the list stored under key @i@ is inserted with
-- 'insertSpMatrix'. When @transpq@ is 'False' it goes to position @(i, j)@,
-- so the key is the row. When @transpq@ is 'True' it goes to @(j, i)@, so
-- the key is the column.
fillSM :: (Rows, Cols) -> Bool -> IM.IntMap (SList a) -> SpMatrix a
fillSM (m, n) transpq tm = IM.foldlWithKey addRow (zeroSM m n) tm
  where
    addRow mat i row = foldr (place i) mat (unSL row)
    place i (j, x) mat
      | transpq = insertSpMatrix j i x mat
      | otherwise = insertSpMatrix i j x mat
-- | Debug driver: factor @mm@ with 'lu', then print L, U, the input, and the
-- product @L #~# U@, in that order.
test mm = do
  (l, u) <- lu mm
  mapM_ prd [l, u, mm]
  prd (l #~# u)
-- | Real-valued test matrices for 'lu'.
--
-- Dense literals are split into runs of the dimension argument purely for
-- readability. The element ordering is whatever 'fromListDenseSM' expects.
tm0, tm2, tm4, tm9 :: SpMatrix Double
tm0 = fromListDenseSM 2 [1, 3, 2, 4]
tm2 =
  fromListDenseSM 3
    [ 12, 6, 4
    , 51, 167, 24
    , 4, 68, 41
    ]
tm4 =
  sparsifySM . fromListDenseSM 4 $
    [ 1, 0, 0, 0
    , 2, 5, 0, 10
    , 3, 6, 8, 11
    , 4, 7, 9, 12
    ]
tm9 =
  fromListSM (4, 4)
    [ (0, 0, pi)
    , (1, 1, 3)
    , (3, 0, 23)
    , (1, 3, 45)
    , (2, 2, 4)
    , (3, 2, 1)
    , (3, 1, 5)
    , (3, 3, exp 1)
    ]
-- | Complex-valued test matrix for 'lu' (dimension argument 3, nine dense
-- entries).
--
-- The negative parts restore the minus signs implied by the previously
-- redundant parentheses, e.g. @4:+(1)@ and @(5):+3@.
tmc4 :: SpMatrix (Complex Double)
tmc4 =
  fromListDenseSM 3
    [ 3 :+ 1, 4 :+ (-1), (-5) :+ 3
    , 2 :+ 2, 3 :+ (-2), 5 :+ 0.2
    , 7 :+ (-2), 9 :+ (-1), 2 :+ 3
    ]
-- | A plain binary tree: either empty ('Leaf') or a 'Branch' holding a left
-- subtree, a value, and a right subtree. The derived 'Foldable' and
-- 'Traversable' instances visit the left subtree, then the value, then the
-- right subtree.
data T x
  = Leaf
  | Branch (T x) x (T x)
  deriving (Show, Functor, Foldable, Traversable)
-- | Return the current counter value, then increment the stored counter by
-- one.
next :: Monad m => StateT Int m Int
next = get <* modify (+ 1)
-- | Generalization of 'next': return the current state, then replace it by
-- @f@ applied to it.
next' :: Monad m => (s -> s) -> StateT s m s
next' f = do
  x <- get
  modify f
  return x
-- | Pair an element with a fresh label taken from the counter via 'next'.
labelElt :: Monad m => x -> StateT Int m (Int, x)
labelElt x = do
  label <- next
  pure (label, x)