{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts, TypeFamilies, MultiParamTypeClasses, FlexibleInstances #-}
{-# language ApplicativeDo #-}
module Numeric.LinearAlgebra.Sparse
(
(<\>),
pinv,
luSolve,
triLowerSolve,
triUpperSolve,
eigsQR,
eigsArnoldi,
qr,
lu,
chol,
arnoldi,
givens,
conditionNumberSM,
hhRefl,
diagPartitions,
fromListSV, toListSV,
vr, vc,
fromVector, toVectorDense,
constv,
fromListSM, toListSM,
fromRowsL, toRowsL,
fromColsL, toColsL,
fromRowsV, fromColsV,
(-=-), (-||-), fromBlocksDiag,
eye, mkDiagonal, mkSubDiagonal,
permutationSM, permutPairsSM,
isOrthogonalSM, isDiagonalSM,
filterSV, ifilterSV,
nearZero, nearOne, isNz,
(.*), (./),
(<.>),
(#>), (<#),
(##), (#^#), (##^),
(#~#), (#~^#), (#~#^),
(><),
dim, nnz, spy,
cvx,
norm, norm2, norm2', normalize, normalize2, normalize2',
norm1, hilbertDistSq,
transpose, trace, normFrobenius,
prd, prd0,
untilConvergedG0, untilConvergedG, untilConvergedGM,
modifyInspectGuarded, modifyInspectGuardedM, IterationConfig (..),
modifyUntil, modifyUntilM,
linSolve0, LinSolveMethod(..),
PartialFunctionError,InputError, OutOfBoundsIndexError,
OperandSizeMismatch, MatrixException, IterationException
)
where
import Control.Exception.Common
import Control.Iterative
import Data.Sparse.Common
import Control.Monad.Catch
import Data.Typeable
import Control.Monad.State.Strict
import qualified Control.Monad.Trans.State as MTS
import Data.Complex
import qualified Data.Sparse.Internal.IntM as I
import Data.Maybe
import qualified Data.Vector as V
-- | Constraint synonym bundling the element-type requirements shared by
-- 'conditionNumberSM' and 'eigsQR'.
type Num' x = (Elt x, Epsilon x, Ord x, Show x, Typeable x)
-- | Estimate the condition number of a matrix from its QR factorisation.
-- The estimate is the ratio between the largest and the smallest magnitude
-- on the diagonal of the @R@ factor.
--
-- Throws 'HugeConditionNumber' when the smallest diagonal magnitude is
-- (numerically) zero.
conditionNumberSM :: (MonadThrow m, MonadIO m, MatrixRing (SpMatrix a),
                      PrintDense (SpMatrix a), Num' a) =>
                     SpMatrix a -> m a
conditionNumberSM m = do
  (_, r) <- qr m
  let
    -- Take magnitudes *before* searching for the extrema. The extremal
    -- magnitudes are not the magnitudes of the extremal signed values:
    -- e.g. diag [-5, 1] has max 1 and min -5, but magnitudes 5 and 1.
    u = abs <$> extractDiagDense r
    lmax = maximum u
    lmin = minimum u
    kappa = lmax / lmin
  if nearZero lmin
    then throwM (HugeConditionNumber "conditionNumberSM" kappa)
    else return kappa
-- | @hhMat beta x@ is @eye n ^-^ beta `scale` (x >< x)@ with @n = dim x@.
-- With @beta = 2@ this is the matrix built by 'hhRefl'.
hhMat :: (Num a, AdditiveGroup a) => a -> SpVector a -> SpMatrix a
hhMat beta x = identityN ^-^ scaledOuter
  where
    identityN = eye (dim x)
    scaledOuter = beta `scale` (x >< x)
-- | @hhRefl x = hhMat 2 x@. The input is not normalised here; the result is
-- a (Householder) reflector only when @x@ has unit norm.
hhRefl :: (Num a, AdditiveGroup a) => SpVector a -> SpMatrix a
hhRefl x = hhMat 2 x
{-# inline givens #-}
-- | Build a Givens rotation matrix meant to zero out entry @(i, j)@ of @aa@
-- when applied from the left.
--
-- The rotation coefficients are computed from @aa(i', j)@ and @aa(i, j)@.
-- Here @i'@ is the first key (as returned by @I.keys@) of a row other than
-- @i@ whose first stored column is @j@.
--
-- Throws 'OOBIxsError' if @(i, j)@ is not a valid index or @aa@ has fewer
-- rows than columns. Throws 'OOBNoCompatRows' if no candidate row exists.
givens :: (Elt a, MonadThrow m) => SpMatrix a -> IxRow -> IxCol -> m (SpMatrix a)
givens aa i j
  | isValidIxSM aa (i,j) && nrows aa >= ncols aa = do
      i' <- candidateRows' (immSM aa) i j
      return $ givensMat aa i i' j
  | otherwise = throwM (OOBIxsError "Givens" [i, j])
  where
    -- Identity of size (nrows mm) with a 2x2 rotation block written at
    -- positions (i,i), (i,j), (j,i), (j,j).
    -- NOTE(review): the coefficients come from rows i' and i, but the block
    -- is placed on rows i and j (the column index). The two only agree when
    -- i' == j; confirm whether i' should be used for placement.
    givensMat mm i i' j =
      fromListSM'
        [(i,i, conj c), (i,j, - conj s),
         (j,i, s), (j,j, c)]
        (eye (nrows mm))
      where
        (c, s, _) = givensCoef a b
        a = mm @@ (i', j)
        b = mm @@ (i, j)
    -- First row (other than i) whose leading stored column is j. `head` is
    -- safe here because the `null u` guard is tried first.
    candidateRows' mm i j | null u = throwM (OOBNoCompatRows "Givens" (i,j))
                          | otherwise = return $ head (I.keys u) where
      u = I.filterWithKey (\irow row -> irow /= i &&
                                        firstNZColumn row j) mm
      -- True iff column k is stored in the row and no column < k is.
      firstNZColumn m k = isJust (I.lookup k m) &&
                          isNothing (I.lookupLT k m)
-- | Givens coefficients for the pair @(u, v)@: returns
-- @(conj u / r, conj v / r, r)@ with @r = hypot u v@.
-- NOTE(review): divides by zero when @u@ and @v@ are both zero.
givensCoef :: Elt t => t -> t -> (t, t, t)
givensCoef u v =
  let r = hypot u v
  in (conj u / r, conj v / r, r)
-- | @sqrt (x * conj x + y * conj y)@, the Euclidean length of the pair
-- @(x, y)@. It is computed directly, without rescaling against overflow.
hypot :: Elt a => a -> a -> a
hypot x y = sqrt (sqMag x + sqMag y)
  where
    sqMag z = z * conj z
{-# inline qr #-}
-- | QR factorisation by Givens rotations.
--
-- The state starts as @(I, A, subdiagIndicesSM A)@. Each step takes the next
-- index @(i, j)@, builds @G = givens A i j@, and updates @Q^T <- G Q^T@ and
-- @A <- G A@. The result is @(transpose Q^T, R)@.
qr :: (Elt a, MatrixRing (SpMatrix a), PrintDense (SpMatrix a),
       Epsilon a, MonadThrow m, MonadIO m) =>
      SpMatrix a
   -> m (SpMatrix a, SpMatrix a)
qr mm = do
  (qt, r, _) <- modifyUntilM' config haltf qrstepf gminit
  return (transpose qt, r)
  where
    gminit = (eye (nrows mm), mm, subdiagIndicesSM mm)
    -- Halt once every index has been processed.
    haltf (_, _, iis) = null iis
    -- fst2 projects the state onto (Q^T, R); prd2 prints both matrices.
    config = IterConf 0 False fst2 prd2 where
      fst2 (x,y,_) = (x,y)
      prd2 (x,y) = do
        prd0 x
        prd0 y
    -- Pattern-match instead of head/tail so an empty index list (e.g. an
    -- input with no subdiagonal entries) leaves the state unchanged rather
    -- than crashing, whichever order modifyUntilM' tests and steps in.
    qrstepf st@(qmatt, m, iis) = case iis of
      [] -> return st
      (i, j) : rest -> do
        g <- givens m i j
        let
          qmatt' = g #~# qmatt
          m' = g #~# m
        return (qmatt', m', rest)
-- | Eigenvalues by (unshifted) QR iteration.
--
-- Each step factors the current iterate @A_k = Q_k R_k@ and forms
-- @A_{k+1} = Q_k^T A_k Q_k@ (which equals @R_k Q_k@). The diagonal of the
-- iterate is both the convergence projection and the returned result.
--
-- @nitermax@ and @debq@ are passed to 'IterConf'.
eigsQR :: (MonadThrow m, MonadIO m, Num' a, Normed (SpVector a), MatrixRing (SpMatrix a), Typeable (Magnitude (SpVector a)), PrintDense (SpVector a), PrintDense (SpMatrix a)) =>
          Int
       -> Bool
       -> SpMatrix a
       -> m (SpVector a)
eigsQR nitermax debq m = pf <$> untilConvergedGM "eigsQR" c (const True) stepf m
  where
    pf = extractDiagDense
    c = IterConf nitermax debq pf prd
    -- The similarity transform must use the *current* iterate mm. Using the
    -- original matrix m gives Q_k^T A_0 Q_k, which does not accumulate the
    -- transforms and only has triangular A_0 as a fixed point.
    stepf mm = do
      (q, _) <- qr mm
      return $ q #~^# (mm ## q)
-- | Run 'arnoldi' on @(aa, b)@ with Krylov dimension @nitermax@, then
-- QR-factor the resulting matrix @H@. Returns the Arnoldi basis, the @Q@
-- factor of @H@, and the diagonal of the @R@ factor of @H@.
--
-- NOTE(review): a single QR factorisation of @H@ does not in general yield
-- its eigenvalues; that needs QR iteration (cf. 'eigsQR'). Confirm the
-- intended meaning of the third component.
eigsArnoldi :: (Scalar (SpVector t) ~ t, MatrixType (SpVector t) ~ SpMatrix t,
                Elt t, V (SpVector t), Epsilon t, PrintDense (SpMatrix t),
                MatrixRing (SpMatrix t), MonadThrow m, MonadIO m) =>
               Int
            -> SpMatrix t
            -> SpVector t
            -> m (SpMatrix t, SpMatrix t, SpVector t)
eigsArnoldi nitermax aa b =
  arnoldi aa b nitermax >>= \(q, h) ->
    (\(o, r) -> (q, o, extractDiagDense r)) <$> qr h
-- | Householder vector for @x@, after Golub & Van Loan, Algorithm 5.1.1.
-- Returns @(v, beta)@ with @v(0) = 1@. For real @x@, @(I - beta v v^T) x@
-- is a multiple of the first unit vector.
--
-- When the tail of @x@ is (numerically) zero, returns @([1, tail x], 0)@.
hhV :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t), Epsilon t) =>
       SpVector t -> (SpVector t, t)
hhV x = (v, beta) where
  tx = tailSV x
  sigma = tx <.> tx
  vtemp = singletonSV 1 `concatSV` tx
  (v, beta)
    | nearZero sigma = (vtemp, 0)
    | otherwise =
        let xh = headSV x
            mu = sqrt (xh**2 + sigma)
            -- Both branches equal xh - mu algebraically, since
            -- mu^2 = xh^2 + sigma. The second form avoids cancellation when
            -- xh is close to mu.
            -- NOTE(review): GVL picks the branch on the sign of xh
            -- (xh <= 0), not on its magnitude. This is kept because Elt
            -- offers no ordering for complex values.
            vh | mag xh <= 1 = xh - mu
               | otherwise = - sigma / (xh + mu)
            -- Normalise so that the leading entry is exactly 1.
            vnew = (1 / vh) `scale` insertSpVector 0 vh vtemp
        -- beta = 2 v1^2 / (sigma + v1^2), where v1 = vh is the leading entry
        -- *before* normalisation (the previous code wrongly used xh here).
        in (vnew, 2 * vh**2 / (sigma + vh**2))
-- | Cholesky-style factorisation of @aa@, built one row of @L@ at a time.
--
-- For row @i@, the entries @(i, j)@ with @j < i@ are computed first as
-- @(A(i,j) - inn) / L(j,j)@, with @inn = contractSub ll ll i j (j - 1)@
-- (presumably the partial inner product of rows i and j of L). The diagonal
-- @(i, i)@ is then @sqrt (A(i,i) - sum of squares of L(i, 0..i-1))@.
--
-- Throws 'NeedsPivoting' when @L(j,j)@ or @A(i,i)@ fails the 'isNz' test.
-- NOTE(review): only A(i,i) is checked before the square root. A
-- non-positive @A(i,i) - sum@ (matrix not positive definite) goes undetected.
chol :: (Elt a, Epsilon a, MonadThrow m, MonadIO m, PrintDense (SpMatrix a)) =>
        SpMatrix a
     -> m (SpMatrix a)
chol aa = do
  let n = nrows aa
      -- Stop once all n rows have been filled.
      q (i, _) = i == n
      config = IterConf 0 False snd prd0
  -- Row 0 is computed up front; the loop then fills rows 1 .. n-1.
  l0 <- cholUpd aa (0, zeroSM n n)
  (_, lfin) <- modifyUntilM' config q (cholUpd aa) l0
  return lfin
  where
    oops i = throwM (NeedsPivoting "chol" (unwords ["L", show (i,i)]) :: MatrixException Double)
    -- One iteration: fill row i of L, then advance to row i + 1.
    cholUpd aa (i, ll) = do
      sd <- cholSDRowUpd aa ll i
      ll' <- cholDiagUpd aa sd i
      return (i + 1, ll')
    -- Sub-diagonal part of row i: entries (i, j) for j in [0 .. i-1].
    cholSDRowUpd aa ll i = do
      lrs <- fromListSV (i + 1) <$> onRangeSparseM cholSubDiag [0 .. i-1]
      return $ insertRow ll lrs i where
        cholSubDiag j | isNz ljj = return $ 1/ljj*(aij - inn)
                      | otherwise = oops j
          where
            ljj = ll @@! (j, j)
            aij = aa @@! (i, j)
            inn = contractSub ll ll i j (j - 1)
    -- Diagonal entry (i, i), written into the matrix produced above.
    cholDiagUpd aa ll i = do
      cd <- cholDiag
      return $ insertSpMatrix i i cd ll where
        cholDiag | i == 0 = sqrt <$> aai
                 | otherwise = do
                     a <- aai
                     let l = sum (fmap (**2) lrow)
                     return $ sqrt (a - l)
          where
            -- Entries of row i strictly left of the diagonal.
            lrow = ifilterSV (\j _ -> j < i) (extractRow ll i)
            aai | isNz aaii = return aaii
                | otherwise = oops i
              where
                aaii = aa @@! (i,i)
-- | LU factorisation @A = L U@ without pivoting. @L@ is initialised from the
-- identity matrix.
--
-- @luInit@ sets row 0 of @U@ (row 0 of @A@) and column 0 of @L@
-- (@A(1.., 0) / U(0,0)@). Each loop iteration @i@ then computes row @i@ of
-- @U@ followed by column @i@ of @L@. After the loop, one more 'uUpd' fills
-- the last row of @U@.
--
-- Throws 'NeedsPivoting' when a diagonal entry of @U@ fails 'isNz'.
-- NOTE(review): the state starts at i = 1 and halts on i == n - 1. For
-- n == 1 (and, if modifyUntilM steps before testing, also n == 2) the test
-- is never satisfied. Verify small-matrix behaviour.
lu :: (Scalar (SpVector t) ~ t, Elt t, VectorSpace (SpVector t), Epsilon t,
       MonadThrow m) =>
      SpMatrix t
   -> m (SpMatrix t, SpMatrix t)
lu aa = do
  let oops j = throwM (NeedsPivoting "solveForLij" ("U" ++ show (j, j)) :: MatrixException Double)
      n = nrows aa
      -- Stop when the loop index reaches the last row.
      q (i, _, _) = i == n - 1
      luInit | isNz u00 = return (1, l0, u0)
             | otherwise = oops (0 :: Int)
        where
          l0 = insertCol (eye n) (extractSubCol aa 0 (1, n-1) ./ u00 ) 0
          u0 = insertRow (zeroSM n n) (extractRow aa 0) 0
          u00 = u0 @@! (0,0)
      -- One iteration: row i of U first, then column i of L (which needs it).
      luUpd (i, l, u) = do
        u' <- uUpd aa n (i, l, u)
        l' <- lUpd (i, l, u')
        return (i + 1, l', u')
      -- Row ix of U: entries (ix, j) for j in [ix .. n-1] are
      -- A(ix, j) - contractSub lmat umat ix j (ix - 1).
      uUpd aa n (ix, lmat, umat) = do
        let us = onRangeSparse (solveForUij ix) [ix .. n - 1]
            solveForUij i j = a - p where
              a = aa @@! (i, j)
              p = contractSub lmat umat i j (i - 1)
        return $ insertRow umat (fromListSV n us) ix
      -- Column ix of L: entries (i, ix) for i in [ix+1 .. n-1], each divided
      -- by the pivot U(ix, ix).
      lUpd (ix, lmat, umat) = do
          ls <- lsm
          return $ insertCol lmat (fromListSV n ls) ix
        where
          lsm = onRangeSparseM (`solveForLij` ix) [ix + 1 .. n - 1]
          solveForLij i j
            | isNz ujj = return $ (a - p)/ujj
            | otherwise = oops j
            where
              a = aa @@! (i, j)
              ujj = umat @@! (j , j)
              p = contractSub lmat umat i j (i - 1)
  s0 <- luInit
  (ixf, lf, uf) <- MTS.execStateT (modifyUntilM q luUpd) s0
  -- The loop stops before filling the last row of U; do it here.
  ufin <- uUpd aa n (ixf, lf, uf)
  return (lf, ufin)
-- Variant of 'lu' (not exported, no type signature).
-- NOTE(review): near-duplicate of 'lu'. It differs only in driving the loop
-- with modifyUntilM' and an IterConf whose second argument is True, whose
-- projection is the (L, U) pair printed by prd2. Consider sharing the
-- helpers with 'lu'.
lu' aa = do
  let oops j = throwM (NeedsPivoting "solveForLij" ("U" ++ show (j, j)) :: MatrixException Double)
      n = nrows aa
      -- Stop when the loop index reaches the last row.
      q (i, _, _) = i == n - 1
      luInit | isNz u00 = return (1, l0, u0)
             | otherwise = oops (0 :: Int)
        where
          l0 = insertCol (eye n) ((extractSubCol aa 0 (1, n-1)) ./ u00 ) 0
          u0 = insertRow (zeroSM n n) (extractRow aa 0) 0
          u00 = u0 @@! (0,0)
      -- One iteration: row i of U first, then column i of L (which needs it).
      luUpd (i, l, u) = do
        u' <- uUpd aa n (i, l, u)
        l' <- lUpd (i, l, u')
        return (i + 1, l', u')
      -- Row ix of U over columns [ix .. n-1].
      uUpd aa n (ix, lmat, umat) = do
        let us = onRangeSparse (solveForUij ix) [ix .. n - 1]
            solveForUij i j = a - p where
              a = aa @@! (i, j)
              p = contractSub lmat umat i j (i - 1)
        return $ insertRow umat (fromListSV n us) ix
      -- Column ix of L over rows [ix+1 .. n-1], divided by the pivot U(ix,ix).
      lUpd (ix, lmat, umat) = do
          ls <- lsm
          return $ insertCol lmat (fromListSV n ls) ix
        where
          lsm = onRangeSparseM (`solveForLij` ix) [ix + 1 .. n - 1]
          solveForLij i j
            | isNz ujj = return $ (a - p)/ujj
            | otherwise = oops j
            where
              a = aa @@! (i, j)
              ujj = umat @@! (j , j)
              p = contractSub lmat umat i j (i - 1)
  s0 <- luInit
  let config = IterConf 0 True vf prd2 where
        vf (_, l, u) = (l, u)
        prd2 (x, y) = do
          prd0 x
          prd0 y
  (ixf, lf, uf) <- modifyUntilM' config q luUpd s0
  -- The loop stops before filling the last row of U; do it here.
  ufin <- uUpd aa n (ixf, lf, uf)
  return (lf, ufin)
-- Small complex-valued fixtures (not exported). The first argument of
-- 'fromListDenseSM' is presumably the row count (3, 4 and 2 here, matching
-- 9, 16 and 4 entries).
tmc4, tmc5, tmc6 :: SpMatrix (Complex Double)
tmc4 = fromListDenseSM 3 [3:+1, 4:+(-1), (-5):+3, 2:+2, 3:+(-2), 5:+0.2, 7:+(-2), 9:+(-1), 2:+3]

-- Explicit signature. Without it, the component type is left to defaulting,
-- which picks 'Integer' for plain numeric literals (or fails on non-standard
-- classes) instead of matching the 'Complex' 'Double' fixtures around it.
tvc4 :: SpVector (Complex Double)
tvc4 = vc [1:+3,2:+2,1:+9]

tmc5 = fromListDenseSM 4 $ zipWith (:+) [16..31] [17,14..]

tmc6 = fromListDenseSM 2 $ zipWith (:+) [1,2,3,4] [4,3,2,1]
-- | Arnoldi iteration on @(aa, b)@ with classical Gram-Schmidt
-- orthogonalisation (all projections are taken against the same @A q_i@).
--
-- Returns the basis vectors as the columns of a matrix (@nmax + 1@ of them)
-- and the @(nmax + 1) x nmax@ matrix @H@ of projection coefficients.
-- Iteration stops when @nmax@ reaches @kn@ or when the new residual norm
-- is 'nearZero' (breakdown).
--
-- Throws 'MatVecSizeMismatchException' if @ncols aa /= dim b@.
arnoldi :: (MatrixType (SpVector a) ~ SpMatrix a, V (SpVector a) ,
            Scalar (SpVector a) ~ a, Epsilon a, MonadThrow m) =>
           SpMatrix a
        -> SpVector a
        -> Int
        -> m (SpMatrix a, SpMatrix a)
arnoldi aa b kn | n == nb = return (fromColsV qvfin, fromListSM (nmax + 1, nmax) hhfin)
                | otherwise = throwM (MatVecSizeMismatchException "arnoldi" (m,n) nb)
  where
    -- State: (basis vectors, H entries as (row, col, value), step, breakdown).
    (qvfin, hhfin, nmax, _) = execState (modifyUntil tf arnoldiStep) arnInit
    tf (_, _, ii, fbreak) = ii == kn || fbreak
    (m, n) = (nrows aa, ncols aa)
    nb = dim b
    -- First step done explicitly: q0 = b / |b|, column 0 of H = (h11, h21).
    arnInit = (qv1, hh1, 1, False) where
      q0 = normalize2 b
      aq0 = aa #> q0
      h11 = q0 `dot` aq0
      q1nn = aq0 ^-^ (h11 .* q0)
      hh1 = V.fromList [(0, 0, h11), (1, 0, h21)] where
        h21 = norm2' q1nn
      q1 = normalize2 q1nn
      qv1 = V.fromList [q0, q1]
    -- Step i: orthogonalise A q_i against all basis vectors so far, and
    -- append column i of H (projections, then the residual norm).
    arnoldiStep (qv, hh, i, _) = (qv', hh', i + 1, breakf) where
      qi = V.last qv
      aqi = aa #> qi
      hhcoli = fmap (`dot` aqi) qv
      zv = zeroSV m
      qipnn =
        aqi ^-^ V.foldl' (^+^) zv (V.zipWith (.*) hhcoli qv)
      qipnorm = norm2' qipnn
      -- NOTE(review): on breakdown this normalises a (near-)zero vector, and
      -- the result is still appended to the basis.
      qip = normalize2 qipnn
      -- ii is longer than needed; V.zip3 truncates it to the column length.
      hh' = (V.++) hh (indexed2 $ V.snoc hhcoli qipnorm) where
        indexed2 v = V.zip3 ii jj v
        ii = V.fromList [0 .. n]
        jj = V.replicate (n + 1) i
      qv' = V.snoc qv qip
      breakf | nearZero qipnorm = True
             | otherwise = False
-- | Split a matrix into the three parts returned by 'extractSubDiag',
-- 'extractDiag' and 'extractSuperDiag', in that order.
diagPartitions :: SpMatrix a
               -> (SpMatrix a, SpMatrix a, SpMatrix a)
diagPartitions aa =
  (extractSubDiag aa, extractDiag aa, extractSuperDiag aa)
-- | Jacobi preconditioner: 'recip' mapped over the diagonal part of the
-- matrix, as returned by 'extractDiag'.
jacobiPre :: Fractional a => SpMatrix a -> SpMatrix a
jacobiPre = fmap recip . extractDiag
-- | ILU(0)-style preconditioner. A full LU factorisation (see 'lu') is
-- computed, and each factor is then restricted to the sparsity pattern of
-- @aa@: an entry of @L@ or @U@ is kept only if @aa@ stores an entry at the
-- same position.
--
-- Because this truncates the complete factors rather than running the
-- incomplete factorisation itself, the result may differ from textbook
-- ILU(0).
ilu0Pre :: (Scalar (SpVector t) ~ t, Elt t, VectorSpace (SpVector t),
            Epsilon t, MonadThrow m) =>
           SpMatrix t
        -> m (SpMatrix t, SpMatrix t)
ilu0Pre aa = do
  (l, u) <- lu aa
  let onPatternOfA mat = ifilterSM (\i j _ -> isJust (lookupSM aa i j)) mat
  return (onPatternOfA l, onPatternOfA u)
-- | SSOR-type preconditioner factors with relaxation parameter @omega@.
-- With @(E, D, F) = diagPartitions aa@, returns
-- @((I - omega E) D^{-1}, D - omega F)@, where @D^{-1}@ is 'reciprocal' D.
mSsorPre :: (MatrixRing (SpMatrix b), Fractional b) =>
            SpMatrix b
         -> b
         -> (SpMatrix b, SpMatrix b)
mSsorPre aa omega = (lower, upper)
  where
    (sub, dia, super) = diagPartitions aa
    lower = (eye (nrows sub) ^-^ scale omega sub) ## reciprocal dia
    upper = dia ^-^ scale omega super
-- | Iteration configuration shared by the triangular solvers. The solver
-- state @(x, i)@ is projected onto the partial solution @x@ and printed with
-- 'prd0'; the leading @0@ and @False@ are passed through to 'IterConf'.
luSolveConfig :: PrintDense (SpVector t) => IterationConfig (SpVector t, IxRow) (SpVector t)
luSolveConfig = IterConf 0 False solution prd0
  where
    solution (x, _) = x
-- | Solve @L U x = b@ given the two triangular factors. The lower system
-- @L w = b@ is solved first, then the upper system @U x = w@.
--
-- Throws 'NonTriangularException' unless @ll@ is lower triangular and @uu@
-- is upper triangular.
luSolve :: (Scalar (SpVector t) ~ t, MonadThrow m, Elt t, InnerSpace (SpVector t),
            Epsilon t, PrintDense (SpVector t), MonadIO m) =>
           SpMatrix t
        -> SpMatrix t
        -> SpVector t
        -> m (SpVector t)
luSolve ll uu b
  | not (isLowerTriSM ll && isUpperTriSM uu) = throwM (NonTriangularException "luSolve")
  | otherwise =
      triLowerSolve0 luSolveConfig ll b >>= triUpperSolve0 luSolveConfig uu
-- | Solve @L w = b@ for lower-triangular @L@ using 'triLowerSolve0' with the
-- shared 'luSolveConfig'.
triLowerSolve
  :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
      PrintDense (SpVector t), Epsilon t, MonadThrow m, MonadIO m) =>
     SpMatrix t -> SpVector t -> m (SpVector t)
triLowerSolve ll b = triLowerSolve0 luSolveConfig ll b
-- | Forward substitution for @L w = b@, driven by 'modifyUntilM''.
--
-- Step @i@ computes @w_i = (b_i - L(i, 0..i-1) . w(0..i-1)) / L(i,i)@. The
-- iteration state is @(partial w, next row index)@, and the result is
-- 'sparsifySV'-ed.
--
-- Throws 'NeedsPivoting' when a diagonal entry @L(i,i)@ fails 'isNz'.
triLowerSolve0 :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
                   Epsilon t, MonadThrow m, MonadIO m) =>
                  IterationConfig (SpVector t, IxRow) b
               -> SpMatrix t -> SpVector t -> m (SpVector t)
triLowerSolve0 config ll b = do
  let q (_, i) = i == nb
      nb = svDim b
      oops i = throwM (NeedsPivoting "triLowerSolve" (unwords ["L", show (i, i)]) :: MatrixException Double)
      -- Row i: subtract the contribution of the already-solved w_0 .. w_{i-1}.
      lStep (ww, i) = do
        let
          lii = ll @@ (i, i)
          bi = b @@ i
          wi = (bi - r)/lii where
            r = extractSubRow ll i (0, i-1) `dot` takeSV i ww
        if isNz lii
          then return (insertSpVector i wi ww, i + 1)
          else oops i
      -- Row 0 has no sub-diagonal part: w_0 = b_0 / L(0,0).
      lInit = do
        let
          l00 = ll @@ (0, 0)
          b0 = b @@ 0
          w0 = b0 / l00
        if isNz l00
          then return (insertSpVector 0 w0 $ zeroSV (dim b), 1)
          else oops (0 :: Int)
  l0 <- lInit
  (v, _) <- modifyUntilM' config q lStep l0
  return $ sparsifySV v
-- | Solve @U x = w@ for upper-triangular @U@ using 'triUpperSolve0' with the
-- shared 'luSolveConfig'.
triUpperSolve
  :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
      PrintDense (SpVector t), Epsilon t, MonadThrow m, MonadIO m) =>
     SpMatrix t -> SpVector t -> m (SpVector t)
triUpperSolve uu w = triUpperSolve0 luSolveConfig uu w
-- | Back substitution for @U x = w@, driven by 'modifyUntilM''.
--
-- Rows are processed from the last (@nw - 1@) down to 0. Step @i@ computes
-- @x_i = (w_i - U(i, i+1..) . x(i+1..)) / U(i,i)@. The result is
-- 'sparsifySV'-ed.
--
-- Throws 'NeedsPivoting' when a diagonal entry @U(i,i)@ fails 'isNz'. The
-- error reports the offending index.
triUpperSolve0 :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
                   Epsilon t, MonadThrow m, MonadIO m) =>
                  IterationConfig (SpVector t, IxRow) b
               -> SpMatrix t -> SpVector t -> m (SpVector t)
triUpperSolve0 conf uu w = do
  let q (_, i) = i == (- 1)
      nw = svDim w
      oops i = throwM (NeedsPivoting "triUpperSolve" (unwords ["U", show (i, i)]) :: MatrixException Double)
      -- Row i: subtract the contribution of the already-solved x_{i+1} ...
      uStep (xx, i) = do
        let uii = uu @@ (i, i)
            wi = w @@ i
            r = extractSubRow_RK uu i (i + 1, nw - 1) `dot` dropSV (i + 1) xx
            xi = (wi - r) / uii
        if isNz uii
          then return (insertSpVector i xi xx, i - 1)
          else oops i
      -- Last row has no super-diagonal part: x_{n-1} = w_{n-1} / U(n-1,n-1).
      uInit = do
        let i = nw - 1
            u00 = uu @@! (i, i)
            w0 = w @@ i
            x0 = w0 / u00
        if isNz u00
          then return (insertSpVector i x0 (zeroSV nw), i - 1)
          -- Report the pivot actually examined, (nw-1, nw-1), not (0, 0).
          else oops i
  u0 <- uInit
  (x, _) <- modifyUntilM' conf q uStep u0
  return $ sparsifySV x
-- | GMRES with a single Arnoldi pass of dimension @ncols aa@ and an implicit
-- zero initial guess (the right-hand side @b@ is used as the residual):
--
--   1. Arnoldi on @(aa, b)@ gives the basis @qa@ and the matrix @ha@.
--   2. QR-factor @ha = qh rh@.
--   3. Solve the square upper-triangular system made of all but the last row
--      of @rh@, against all but the last entry of @qh^T (|b| ei (nrows ha) 1)@.
--   4. Return @qa'  y@, where @qa'@ drops the last Arnoldi basis vector.
gmres aa b = do
  (qa, ha) <- arnoldi aa b (ncols aa)
  (qh, rh) <- qr ha
  let scaledE1 = norm2' b .* ei (nrows ha) 1
      rhsFull = transpose qh #> scaledE1
      rhs = takeSV (dim scaledE1 - 1) rhsFull
      rSquare = takeRows (nrows rh - 1) rh
      qBasis = takeCols (ncols qa - 1) qa
  y <- triUpperSolve rSquare rhs
  return $ qBasis #> y
-- | State of the CGNE iteration: current iterate, residual and search
-- direction.
data CGNE a = CGNE
  { _xCgne :: SpVector a
  , _rCgne :: SpVector a
  , _pCgne :: SpVector a
  } deriving Eq

-- Multi-line rendering: one "name = value" line per field.
instance Show a => Show (CGNE a) where
  show (CGNE x r p) = unlines
    [ "x = " ++ show x
    , "r = " ++ show r
    , "p = " ++ show p
    ]
-- | Initial CGNE state from the guess @x0@: @r0 = b - A x0@ and
-- @p0 = A^T r0@.
cgneInit :: (MatrixType (SpVector a) ~ SpMatrix a,
             LinearVectorSpace (SpVector a)) =>
            SpMatrix a -> SpVector a -> SpVector a -> CGNE a
cgneInit aa b x0 =
  CGNE { _xCgne = x0, _rCgne = r0, _pCgne = transposeSM aa #> r0 }
  where
    r0 = b ^-^ (aa #> x0)
-- | One CGNE (Craig's method) step:
--
-- > alpha = (r, r) / (p, p)
-- > x'    = x + alpha p
-- > r'    = r - alpha A p
-- > beta  = (r', r') / (r, r)
-- > p'    = A^T r' + beta p
cgneStep :: (MatrixType (SpVector a) ~ SpMatrix a,
             LinearVectorSpace (SpVector a), InnerSpace (SpVector a),
             MatrixRing (SpMatrix a), Fractional (Scalar (SpVector a))) =>
            SpMatrix a -> CGNE a -> CGNE a
cgneStep aa (CGNE x r p) = CGNE x1 r1 p1 where
  alphai = (r `dot` r) / (p `dot` p)
  x1 = x ^+^ (alphai .* p)
  r1 = r ^-^ (alphai .* (aa #> p))
  beta = (r1 `dot` r1) / (r `dot` r)
  -- The new direction is built from the *updated* residual r1, consistent
  -- with beta above. Explicit parentheses avoid relying on the relative
  -- fixity of #> and ^+^.
  p1 = (transpose aa #> r1) ^+^ (beta .* p)
-- | State of the BiCG iteration.
data BCG a = BCG
  { _xBcg    :: SpVector a  -- ^ current iterate
  , _rBcg    :: SpVector a  -- ^ residual
  , _rHatBcg :: SpVector a  -- ^ shadow residual
  , _pBcg    :: SpVector a  -- ^ search direction
  , _pHatBcg :: SpVector a  -- ^ shadow search direction
  } deriving Eq
-- | Initial BiCG state from the guess @x0@. The residual @b - A x0@ seeds
-- the residual, shadow residual and both search directions.
bcgInit :: LinearVectorSpace (SpVector a) =>
           MatrixType (SpVector a) -> SpVector a -> SpVector a -> BCG a
bcgInit aa b x0 = BCG x0 r0 r0 r0 r0
  where
    r0 = b ^-^ (aa #> x0)
-- | One BiCG step:
--
-- > alpha = (r, r^) / (A p, p^)
-- > x' = x + alpha p;  r' = r - alpha A p;  r^' = r^ - alpha A^T p^
-- > beta  = (r', r^') / (r, r^)
-- > p' = r' + beta p;  p^' = r^' + beta p^
bcgStep :: (MatrixType (SpVector a) ~ SpMatrix a,
            LinearVectorSpace (SpVector a), InnerSpace (SpVector a),
            MatrixRing (SpMatrix a), Fractional (Scalar (SpVector a))) =>
           SpMatrix a -> BCG a -> BCG a
bcgStep aa (BCG x r rhat p phat) =
  let ap      = aa #> p
      alpha   = (r `dot` rhat) / (ap `dot` phat)
      rNew    = r ^-^ (alpha .* ap)
      rhatNew = rhat ^-^ (alpha .* (transpose aa #> phat))
      beta    = (rNew `dot` rhatNew) / (r `dot` rhat)
  in BCG (x ^+^ (alpha .* p))
         rNew
         rhatNew
         (rNew ^+^ (beta .* p))
         (rhatNew ^+^ (beta .* phat))
-- Multi-line rendering: one "name = value" line per field.
instance Show a => Show (BCG a) where
  show (BCG x r rhat p phat) = unlines
    [ "x = " ++ show x
    , "r = " ++ show r
    , "r_hat = " ++ show rhat
    , "p = " ++ show p
    , "p_hat = " ++ show phat
    ]
-- | State of the CGS iteration.
data CGS a = CGS
  { _x :: SpVector a  -- ^ current iterate
  , _r :: SpVector a  -- ^ residual
  , _p :: SpVector a  -- ^ search direction
  , _u :: SpVector a  -- ^ auxiliary direction
  } deriving Eq
-- | Initial CGS state from the guess @x0@. The residual @b - A x0@ seeds
-- @r@, @p@ and @u@.
cgsInit :: LinearVectorSpace (SpVector a) =>
           MatrixType (SpVector a) -> SpVector a -> SpVector a -> CGS a
cgsInit aa b x0 = CGS { _x = x0, _r = r0, _p = r0, _u = r0 }
  where
    r0 = b ^-^ (aa #> x0)
-- | One CGS step with fixed shadow residual @rhat@:
--
-- > alpha = (r, rhat) / (A p, rhat);  q = u - alpha A p
-- > x' = x + alpha (u + q);  r' = r - alpha A (u + q)
-- > beta = (r', rhat) / (r, rhat)
-- > u' = r' + beta q;  p' = u' + beta (q + beta p)
cgsStep :: (V (SpVector a), Fractional (Scalar (SpVector a))) =>
           MatrixType (SpVector a) -> SpVector a -> CGS a -> CGS a
cgsStep aa rhat (CGS x r p u) =
  let ap    = aa #> p
      alpha = (r `dot` rhat) / (ap `dot` rhat)
      q     = u ^-^ (alpha .* ap)
      uq    = u ^+^ q
      rNew  = r ^-^ (alpha .* (aa #> uq))
      beta  = (rNew `dot` rhat) / (r `dot` rhat)
      uNew  = rNew ^+^ (beta .* q)
      pNew  = uNew ^+^ (beta .* (q ^+^ (beta .* p)))
  in CGS (x ^+^ (alpha .* uq)) rNew pNew uNew
-- Multi-line rendering: one "name = value" line per field.
instance (Show a) => Show (CGS a) where
  show (CGS x r p u) = unlines
    [ "x = " ++ show x
    , "r = " ++ show r
    , "p = " ++ show p
    , "u = " ++ show u
    ]
-- | State of the BiCGSTAB iteration.
data BICGSTAB a = BICGSTAB
  { _xBicgstab :: SpVector a  -- ^ current iterate
  , _rBicgstab :: SpVector a  -- ^ residual
  , _pBicgstab :: SpVector a  -- ^ search direction
  } deriving Eq
-- | Initial BiCGSTAB state from the guess @x0@. The residual @b - A x0@
-- seeds both @r@ and @p@.
bicgsInit :: LinearVectorSpace (SpVector a) =>
             MatrixType (SpVector a) -> SpVector a -> SpVector a -> BICGSTAB a
bicgsInit aa b x0 = BICGSTAB x0 r0 r0
  where
    r0 = b ^-^ (aa #> x0)
-- | One BiCGSTAB step with fixed shadow residual @r0hat@:
--
-- > alpha = (r, r0^) / (A p, r0^);  s = r - alpha A p
-- > omega = (A s, s) / (A s, A s)
-- > x' = x + alpha p + omega s;  r' = s - omega A s
-- > beta = (r', r0^) / (r, r0^) * alpha / omega
-- > p' = r' + beta (p - omega A p)
bicgstabStep :: (V (SpVector a), Fractional (Scalar (SpVector a))) =>
                MatrixType (SpVector a) -> SpVector a -> BICGSTAB a -> BICGSTAB a
bicgstabStep aa r0hat (BICGSTAB x r p) =
  let ap    = aa #> p
      alpha = (r <.> r0hat) / (ap <.> r0hat)
      s     = r ^-^ (alpha .* ap)
      aas   = aa #> s
      omega = (aas <.> s) / (aas <.> aas)
      rNew  = s ^-^ (omega .* aas)
      beta  = (rNew <.> r0hat)/(r <.> r0hat) * alpha / omega
  in BICGSTAB (x ^+^ (alpha .* p) ^+^ (omega .* s))
              rNew
              (rNew ^+^ (beta .* (p ^-^ (omega .* ap))))
-- Multi-line rendering: one "name = value" line per field.
instance Show a => Show (BICGSTAB a) where
  show (BICGSTAB x r p) = unlines
    [ "x = " ++ show x
    , "r = " ++ show r
    , "p = " ++ show p
    ]
-- | Least-squares solve through the normal equations. The system
-- @(aa #^# aa) x = transpose aa #> b@ (presumably @A^T A x = A^T b@) is
-- solved with '<\>'.
pinv :: (LinearSystem v, MatrixRing (MatrixType v), MonadThrow m, MonadIO m) =>
        MatrixType v -> v -> m v
pinv aa b = normalMatrix <\> rhs
  where
    normalMatrix = aa #^# aa
    rhs = transpose aa #> b
-- | Iterative method selected by 'linSolve0'.
data LinSolveMethod
  = GMRES_     -- ^ generalised minimal residual
  | CGNE_      -- ^ conjugate gradient on the normal equations
  | BCG_       -- ^ bi-conjugate gradient
  | CGS_       -- ^ conjugate gradient squared
  | BICGSTAB_  -- ^ bi-conjugate gradient stabilised
  deriving (Eq, Show)
-- | Solve @aa x = b@ with the chosen iterative @method@ and initial guess
-- @x0@.
--
-- If @aa@ is diagonal ('isDiagonalSM'), the result is @reciprocal aa #> b@
-- and no iteration runs. Otherwise the Krylov solvers run through
-- 'untilConvergedG' with an iteration cap of 200 (@nits@).
--
-- Throws 'MatVecSizeMismatchException' when @nrows aa /= dim b@. Only the
-- row count is checked.
-- NOTE(review): the GMRES_ branch calls 'gmres' directly and ignores @x0@.
linSolve0 method aa b x0
  | m /= nb = throwM (MatVecSizeMismatchException "linSolve0" dm nb)
  | otherwise = solve aa b where
      -- Shortcut for diagonal systems; the Krylov dispatch is in xHat.
      solve aa' b' | isDiagonalSM aa' = return $ reciprocal aa' #> b'
                   | otherwise = xHat
      xHat = case method of
        BICGSTAB_ -> solver "BICGSTAB" nits _xBicgstab (bicgstabStep aa r0hat) (bicgsInit aa b x0)
        BCG_ -> solver "BCG" nits _xBcg (bcgStep aa) (bcgInit aa b x0)
        CGS_ -> solver "CGS" nits _x (cgsStep aa r0hat) (cgsInit aa b x0)
        GMRES_ -> gmres aa b
        CGNE_ -> solver "CGNE" nits _xCgne (cgneStep aa) (cgneInit aa b x0)
      -- Shadow residual for CGS / BiCGSTAB: the initial residual.
      r0hat = b ^-^ (aa #> x0)
      nits = 200
      dm@(m,n) = dim aa
      nb = dim b
      -- Iterate stepf from initf, then project the final state onto the
      -- solution vector with fproj.
      solver fname nitermax fproj stepf initf = do
          xf <- untilConvergedG fname config (const True) stepf initf
          return $ fproj xf
        where
          config = IterConf nitermax False fproj prd0
-- | Default sparse solver for 'Double' vectors: GMRES via 'linSolve0'. The
-- initial guess is a constant 0.1 vector of length @ncols aa@.
-- NOTE(review): the GMRES_ branch of 'linSolve0' does not use this guess.
instance LinearSystem (SpVector Double) where
  aa <\> b = linSolve0 GMRES_ aa b initialGuess
    where
      n = ncols aa
      initialGuess = mkSpVR n (replicate n 0.1)