module Numeric.LinearAlgebra.Sparse
(
qr, lu,
ilu0,
conditionNumberSM,
hhMat, hhRefl,
givens,
eigsQR, eigRayleigh,
linSolve, LinSolveMethod, (<\>),
cgne, tfqmr, bicgstab, cgs, bcg,
_xCgne, _xTfq, _xBicgstab, _x, _xBcg,
cgsStep, bicgstabStep,
CGNE, TFQMR, BICGSTAB, CGS, BCG,
diagPartitions,
randArray,
randMat, randVec,
randSpMat, randSpVec,
sparsifySV,
modifyInspectN, runAppendN',
diffSqL
)
where
import Data.Sparse.Common
import Control.Monad.Primitive
import Control.Monad (mapM_, forM_, replicateM)
import Control.Monad.State.Strict
import Control.Monad.Writer
import qualified Data.IntMap.Strict as IM
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC
import Data.Monoid
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.Maybe
-- | Drop every entry of a sparse vector whose magnitude is below 'eps'.
-- The declared dimension of the vector is left unchanged.
sparsifySV :: SpVector Double -> SpVector Double
sparsifySV (SV d im) = SV d (IM.filter keep im)
  where
    keep x = abs x >= eps
-- | Estimate the condition number of a matrix.
--
-- The estimate is the ratio between the largest and the smallest
-- magnitude found on the diagonal of the R factor of its QR decomposition.
--
-- Magnitudes are taken BEFORE selecting the extremes, because the diagonal
-- of R can contain entries of either sign.  A zero smallest magnitude
-- (infinite or NaN ratio) signals a rank-deficient matrix and is reported
-- as an error.
conditionNumberSM :: SpMatrix Double -> Double
conditionNumberSM m
  | isInfinite kappa || isNaN kappa =
      error "Infinite condition number : rank-deficient system"
  | otherwise = kappa
  where
    kappa = lmax / lmin
    (_, r) = qr m
    -- |R(k,k)| for every stored diagonal entry
    u = map abs (F.toList (extractDiagDense r))
    lmax = maximum u
    lmin = minimum u
-- | The matrix @eye n ^-^ scale beta (x >< x)@, where @n = dim x@.
-- With @x >< x@ being the outer product, this is @I - beta * x x^T@.
hhMat :: Num a => a -> SpVector a -> SpMatrix a
hhMat beta x = eye (dim x) ^-^ scale beta outer
  where
    outer = x >< x
-- | 'hhMat' with @beta = 2@.
--
-- NOTE(review): this is an orthogonal reflector only when @x@ has unit
-- 2-norm; callers are presumably expected to normalize @x@ first.
hhRefl :: SpVector Double -> SpMatrix Double
hhRefl x = hhMat 2 x
-- | @sqrt (x^2 + y^2)@, evaluated as @|x| * sqrt (1 + (y/x)^2)@.
--
-- This form squares the ratio @y/x@ instead of @x@ itself.
--
-- NOTE: the constraint only provides 'Floating', so @x == 0@ cannot be
-- special-cased here; for @x == 0@ the result is not meaningful
-- (@0 * Infinity@).
hypot :: Floating a => a -> a -> a
hypot x y = abs x * sqrt (1 + (y / x) ** 2)
-- | Signum for ordered numeric types: @1@ for positive arguments, @0@ for
-- zero and @-1@ for negative arguments.
sign :: (Ord a, Num a) => a -> a
sign x
  | x > 0 = 1
  | x == 0 = 0
  | otherwise = -1
-- | Coefficients @(c, s, r)@ of a Givens rotation for the pair @(a, b)@.
--
-- Assuming 'sign' returns the signum of its argument, these are
-- @c = a / r@, @s = b / r@ and @r = sqrt (a^2 + b^2)@.  The ratio of the
-- smaller to the larger magnitude is formed first, so neither @a@ nor @b@
-- is squared directly.
givensCoef :: (Ord a, Floating a) => a -> a -> (a, a, a)
givensCoef a b
  | b == 0 = (sign a, 0, abs a)
  | a == 0 = (0, sign b, abs b)
  | abs a > abs b =
      let ratio = b / a
          scaleA = sign a * abs (sqrt (1 + ratio ** 2))
      in (1 / scaleA, ratio / scaleA, a * scaleA)
  | otherwise =
      let ratio = a / b
          scaleB = sign b * abs (sqrt (1 + ratio ** 2))
      in (ratio / scaleB, 1 / scaleB, b * scaleB)
-- | Givens rotation matrix that zeroes the entry @(i, j)@ of @mm@ when
-- applied on the left.
--
-- The pivot row @i'@ is the first row other than @i@ whose first nonzero
-- column is @j@ (see 'candidateRows').  With @a = mm(i', j)@ and
-- @b = mm(i, j)@, 'givensCoef' gives @c = a / r@ and @s = b / r@.  The
-- result is the identity except for the rows @i@ and @i'@:
--
-- >   (i , i) = c    (i , i') = -s
-- >   (i', i) = s    (i', i') =  c
--
-- so that @(G mm)(i, j) = c*b - s*a = 0@ and @(G mm)(i', j) = s*b + c*a@.
-- The rotation acts on the same row the coefficient @a@ is taken from, and
-- the two off-diagonal entries have opposite signs so @G@ is orthogonal.
givens :: SpMatrix Double -> IxRow -> IxCol -> SpMatrix Double
givens mm i j
  | isValidIxSM mm (i, j) && isSquareSM mm =
      sparsifySM $
        fromListSM'
          [(i, i, c), (i', i', c), (i', i, s), (i, i', -s)]
          (eye (nrows mm))
  | otherwise = error "givens : indices out of bounds"
  where
    (c, s, _) = givensCoef a b
    i' = case candidateRows (immSM mm) i j of
      Just (row : _) -> row
      _ -> error $ "givens: no compatible rows for entry " ++ show (i, j)
    a = mm @@ (i', j)
    b = mm @@ (i, j)   -- entry to be zeroed
-- | Checks whether @k@ is the first stored column of a row.
--
-- The row is given as a map keyed by column index.  The result is @True@
-- when the row has an entry at key @k@ and no entry at any smaller key.
firstNonZeroColumn :: IM.IntMap a -> IxRow -> Bool
firstNonZeroColumn mm k = IM.member k mm && not (hasEarlier mm)
  where
    hasEarlier = isJust . IM.lookupLT k
-- | Keys of the rows whose first stored column is @j@, or 'Nothing' when
-- there are none.  Row @i@ itself is excluded.
candidateRows :: IM.IntMap (IM.IntMap a) -> IxRow -> IxCol -> Maybe [IM.Key]
candidateRows mm i j =
  case IM.keys (IM.filterWithKey eligible mm) of
    [] -> Nothing
    ks -> Just ks
  where
    eligible irow row = irow /= i && firstNonZeroColumn row j
-- | QR decomposition by Givens rotations; returns @(Q, R)@.
--
-- 'gmats' lists the rotations in application order @[G_1, .., G_k]@, where
-- each one is computed on the result of the previous ones.  Therefore
-- @Q^T = G_k * .. * G_2 * G_1@, and every rotation must be multiplied on
-- the LEFT of the product accumulated so far.
qr :: SpMatrix Double -> (SpMatrix Double, SpMatrix Double)
qr mm = (transposeSM qmatt, rmat)
  where
    -- Q^T = G_k ... G_1
    qmatt = F.foldl' (\acc g -> g #~# acc) ee (gmats mm)
    -- R = Q^T A
    rmat = qmatt #~# mm
    ee = eye (nrows mm)
-- | Givens rotations for the subdiagonal entries of @mm@.
--
-- There is one rotation per index returned by 'subdiagIndicesSM'.  Each
-- rotation is computed on the matrix already transformed by all the
-- previous ones.  The list is returned in application order.
gmats :: SpMatrix Double -> [SpMatrix Double]
gmats mm = go mm (subdiagIndicesSM mm)
  where
    go _ [] = []
    go m ((i, j) : rest) = g : go (g #~# m) rest
      where
        g = givens m i j
-- | Eigenvalue estimates by the unshifted QR algorithm.
--
-- Each step replaces the matrix @A@ with @R Q@, where @(Q, R) = qr A@.
-- Iteration is driven by 'modifyInspectN' with budget @nitermax@.  It stops
-- early when the diagonals of two successive iterates differ by at most
-- 'eps' in 2-norm.  The diagonal of the final iterate is returned.
eigsQR :: Int -> SpMatrix Double -> SpVector Double
eigsQR nitermax m = extractDiagDense (execState iterateQR m)
  where
    iterateQR = modifyInspectN nitermax diagSettled rqStep
    rqStep a = let (q, r) = qr a in r #~# q
    diagSettled [m1, m2] =
      norm2 (extractDiagDense m1 ^-^ extractDiagDense m2) <= eps
-- | Rayleigh quotient iteration from an initial pair @(b, mu)@.
--
-- The pair holds an approximate eigenvector and eigenvalue.  Each step:
--
-- * solves @(A - mu I) x = b@ with '<\>';
-- * normalizes @x@ with @normalize 2@ to get the next vector;
-- * takes the vector's Rayleigh quotient as the next shift.
--
-- Iteration is driven by 'modifyInspectN' with budget @nitermax@.  It stops
-- when two successive vectors differ by at most 'eps' in 2-norm.
eigRayleigh :: Int
            -> SpMatrix Double
            -> (SpVector Double, Double)
            -> (SpVector Double, Double)
eigRayleigh nitermax m = execState (modifyInspectN nitermax settled step)
  where
    settled [(b1, _), (b2, _)] = norm2 (b2 ^-^ b1) <= eps
    step (b, mu) = (b', mu')
      where
        shifted = m ^-^ (mu `matScale` eye (nrows m))
        b' = normalize 2 (shifted <\> b)
        mu' = b' `dot` (m #> b') / (b' `dot` b')
-- | Householder vector of @x@, following Golub & Van Loan,
-- "Matrix Computations", Algorithm 5.1.1.
--
-- Returns @(v, beta)@ with @v(0) = 1@.  When the squared norm @sigma@ of
-- the tail of @x@ exceeds 'eps', @(I - beta v v^T) x = ||x|| e_1@.
-- Otherwise @beta = 0@, i.e. no reflection is applied.
hhV :: SpVector Double -> (SpVector Double, Double)
hhV x = (v, beta)
  where
    xh = headSV x
    tx = tailSV x
    sigma = tx `dot` tx
    vtemp = singletonSV 1 `concatSV` tx
    (v, beta)
      | sigma <= eps = (vtemp, 0)
      | otherwise =
          let mu = sqrt (xh ** 2 + sigma)   -- ||x||_2
              -- v(0) = x(0) - mu; for x(0) > 0 use the algebraically
              -- equivalent -sigma / (x(0) + mu) to avoid cancellation
              vh | xh <= 0 = xh - mu
                 | otherwise = negate sigma / (xh + mu)
              -- place vh at index 0 (presumably what insertSpVector does),
              -- then divide by it so that the leading entry becomes 1
              vnew = (1 / vh) .* insertSpVector 0 vh vtemp
          in (vnew, 2 * vh ** 2 / (sigma + vh ** 2))
-- | Doolittle LU factorization @A = L U@ with unit-diagonal @L@, computed
-- without pivoting.
--
-- The factorization proceeds in three phases:
--
-- * 'luInit' sets row 0 of @U@ and column 0 of @L@ and starts the index
--   at 1;
-- * 'luUpd' then fills row @i@ of @U@ and column @i@ of @L@ for
--   @i = 1 .. n-2@;
-- * 'uUpd' finally fills the last row of @U@.
--
-- The stopping test is evaluated BEFORE each update.  This lets 2x2
-- matrices, which need no intermediate update, terminate; with a test
-- after the update, the index would skip past @n - 1@.
lu :: SpMatrix Double -> (SpMatrix Double, SpMatrix Double)
lu aa = (lfin, ufin)
  where
    n = nrows aa
    (ixf, lfin, uf) = go (luInit aa)
    go st@(i, _, _)
      | i >= n - 1 = st
      | otherwise = go (luUpd aa st)
    ufin = uUpd aa (ixf, lfin, uf)
-- | Initial state of the LU recursion.  The state is:
--
-- * the index 1;
-- * @L@: the identity, with column 0 replaced by
--   @extractSubCol aa 0 (1, n - 1)@ scaled by @1 / U(0, 0)@;
-- * @U@: zero except for row 0, which is copied from row 0 of @aa@.
luInit ::
  (Num t, Fractional a) => SpMatrix a -> (t, SpMatrix a, SpMatrix a)
luInit aa = (1, l0, u0)
  where
    n = nrows aa
    l0 = insertCol (eye n) ((1 / u00) .* extractSubCol aa 0 (1, n - 1)) 0
    u0 = insertRow (zeroSM n n) (extractRow aa 0) 0
    u00 = u0 @@ (0, 0)
-- | One LU step at index @i@, which then advances the index.
--
-- Row @i@ of @U@ is computed first.  Column @i@ of @L@ is then computed
-- from the UPDATED @U@.
luUpd :: SpMatrix Double
      -> (Int, SpMatrix Double, SpMatrix Double)
      -> (Int, SpMatrix Double, SpMatrix Double)
luUpd aa (i, l, u) =
  let uNext = uUpdSparse aa (i, l, u)
      lNext = lUpdSparse aa (i, l, uNext)
  in (i + 1, lNext, uNext)
-- | Fill row @ix@ of @U@.
--
-- For every column @j@ in @[ix .. n-1]@, 'solveForUij' computes @U(ix, j)@.
-- The @(column, value)@ pairs are passed through @ff@ (e.g. to drop zeros),
-- and the resulting sparse row is inserted into @umat@ at row @ix@.
uUpd' ::
  Num a =>
  ([(Int, a)] -> [(Int, a)]) ->
  SpMatrix a ->
  (Rows, SpMatrix a, SpMatrix a) ->
  SpMatrix a
uUpd' ff amat (ix, lmat, umat) = insertRow umat uv ix
  where
    n = nrows amat
    colsix = [ix .. n - 1]
    us = ff $ zip colsix $ map (solveForUij amat lmat umat ix) colsix
    uv = fromListSV n us
-- | Row update of @U@ ('uUpd'') that keeps every computed entry, zeros
-- included.
uUpd :: Num a => SpMatrix a -> (Rows, SpMatrix a, SpMatrix a) -> SpMatrix a
uUpd amat st = uUpd' id amat st
-- | Row update of @U@ ('uUpd'') that drops every entry for which 'isNz'
-- is false.
uUpdSparse ::
  SpMatrix Double -> (Rows, SpMatrix Double, SpMatrix Double) -> SpMatrix Double
uUpdSparse amat st = uUpd' (filter (\(_, v) -> isNz v)) amat st
-- | Entry @U(i, j)@ of the Doolittle factorization.
--
-- It is @A(i, j)@ minus @contractSub lmat umat i j (i - 1)@, which
-- presumably is @sum_{k <= i-1} L(i, k) U(k, j)@.
solveForUij ::
  Num a => SpMatrix a -> SpMatrix a -> SpMatrix a -> IxRow -> IxCol -> a
solveForUij amat lmat umat i j = a - p
  where
    a = amat @@! (i, j)
    p = contractSub lmat umat i j (i - 1)
-- | Entry @L(i, j)@ of the Doolittle factorization: @(A(i, j) - p) / U(j, j)@,
-- with @p = contractSub lmat umat i j (i - 1)@.
--
-- Fails when @U(j, j)@ is numerically zero, since no pivoting is
-- performed.
solveForLij ::
  SpMatrix Double -> SpMatrix Double -> SpMatrix Double -> IxRow -> IxCol -> Double
solveForLij amat lmat umat i j
  | isNz ujj = (a - p) / ujj
  | otherwise =
      error $ unwords ["solveForLij : U",
                       show (j ,j ),
                       "is close to 0. Permute rows in order to have a nonzero diagonal of U"]
  where
    a = amat @@! (i, j)
    ujj = umat @@! (j, j)
    p = contractSub lmat umat i j (i - 1)
-- | Fill column @ix@ of @L@ below the diagonal.
--
-- For every row @i@ in @[ix+1 .. n-1]@, 'solveForLij' computes @L(i, ix)@.
-- The @(row, value)@ pairs are passed through @ff@, and the resulting
-- sparse column is inserted into @lmat@ at column @ix@.
lUpd' :: ([(Rows, Double)] -> [(Int, Double)])
      -> SpMatrix Double
      -> (Rows, SpMatrix Double, SpMatrix Double)
      -> SpMatrix Double
lUpd' ff amat (ix, lmat, umat) = insertCol lmat lv ix
  where
    n = nrows amat
    rowsix = [ix + 1 .. n - 1]
    ls = ff $ zip rowsix $ map (\i -> solveForLij amat lmat umat i ix) rowsix
    lv = fromListSV n ls
-- | Column update of @L@ ('lUpd'') that keeps every computed entry.
lUpd :: SpMatrix Double -> (Rows, SpMatrix Double, SpMatrix Double) -> SpMatrix Double
lUpd amat st = lUpd' id amat st
-- | Column update of @L@ ('lUpd'') that drops every entry for which 'isNz'
-- is false.
lUpdSparse ::
  SpMatrix Double -> (Rows, SpMatrix Double, SpMatrix Double) -> SpMatrix Double
lUpdSparse amat st = lUpd' (filter (\(_, v) -> isNz v)) amat st
-- | Row permutation associated with position @(iref, jref)@.
--
-- Returns 'Nothing' when @(iref, jref)@ already has a stored entry.
-- Otherwise it returns @permutationSM nro [r]@, where @r@ is the first key
-- of @ifilterIM2@ restricted to rows other than @iref@ and column @jref@.
-- 'head' is used on those keys, so this fails when there are none.
permutAA :: Num b => IxRow -> IxCol -> SpMatrix a -> Maybe (SpMatrix b)
permutAA iref jref (SM (nro, _) mm)
  | isJust (lookupIM2 iref jref mm) = Nothing
  | otherwise = Just (permutationSM nro [head rowsInCol])
  where
    rowsInCol = IM.keys (ifilterIM2 inColumn mm)
    inColumn i j _ = i /= iref && j == jref
-- | Incomplete LU factors restricted to the stored pattern of @aa@.
--
-- The complete factorization is computed with 'lu'.  Every entry of @L@
-- and @U@ at a position where @aa@ has no stored entry is then discarded.
--
-- NOTE(review): this truncates the exact factors instead of running the
-- ILU(0) recurrence on the pattern of @aa@, so the result can differ from
-- textbook ILU(0).  The unit diagonal of @L@ is also dropped wherever @aa@
-- has no stored diagonal entry.
ilu0 :: SpMatrix Double -> (SpMatrix Double, SpMatrix Double)
ilu0 aa = (lh, uh)
  where
    (l, u) = lu aa
    lh = sparsifyLU l aa
    uh = sparsifyLU u aa
    -- keep the entries of m whose position is stored in m2
    sparsifyLU m m2 = SM (dim m) $ ifilterIM2 f (dat m)
      where
        f i j _ = isJust (lookupSM m2 i j)
-- | Split a matrix into the parts returned by 'extractSubDiag',
-- 'extractDiag' and 'extractSuperDiag', in that order.
diagPartitions :: SpMatrix a -> (SpMatrix a, SpMatrix a, SpMatrix a)
diagPartitions aa = (extractSubDiag aa, extractDiag aa, extractSuperDiag aa)
-- | SSOR-type preconditioner @(D - omega E) ## (I - omega D^-1 F)@, where
-- @(E, D, F)@ come from 'diagPartitions'.
--
-- NOTE(review): the usual SSOR splitting writes @A = D - E - F@, with E and
-- F the NEGATED strict triangles.  Here E and F are taken directly from
-- 'diagPartitions'; confirm the intended sign convention before relying on
-- this.
mSsor :: Fractional a => SpMatrix a -> a -> SpMatrix a
mSsor aa omega = lower ## upper
  where
    (e, d, f) = diagPartitions aa
    lower = d ^-^ scale omega e
    upper = eye (nrows e) ^-^ scale omega (reciprocal d ## f)
-- | One iteration of CGNE (Craig's method; see Saad, "Iterative Methods for
-- Sparse Linear Systems").
--
-- > alpha = (r, r) / (p, p)
-- > x'    = x + alpha p
-- > r'    = r - alpha A p
-- > beta  = (r', r') / (r, r)
-- > p'    = A^T r' + beta p
--
-- The new search direction must be built from the UPDATED residual @r'@.
cgneStep :: SpMatrix Double -> CGNE -> CGNE
cgneStep aa (CGNE x r p) = CGNE x1 r1 p1
  where
    alphai = r `dot` r / (p `dot` p)
    x1 = x ^+^ (alphai .* p)
    r1 = r ^-^ (alphai .* (aa #> p))
    beta = r1 `dot` r1 / (r `dot` r)
    p1 = (transposeSM aa #> r1) ^+^ (beta .* p)
-- | State of the CGNE iteration.
data CGNE = CGNE
  { _xCgne :: SpVector Double  -- ^ current iterate
  , _rCgne :: SpVector Double  -- ^ current residual
  , _pCgne :: SpVector Double  -- ^ current search direction
  } deriving Eq
-- | Solve @A x = b@ with CGNE starting from @x0@.
--
-- The step is repeated through 'untilConverged' on the iterate field.
cgne :: SpMatrix Double -> SpVector Double -> SpVector Double -> CGNE
cgne aa b x0 = execState (untilConverged _xCgne (cgneStep aa)) start
  where
    res0 = b ^-^ (aa #> x0)
    start = CGNE x0 res0 (transposeSM aa #> res0)
-- | One step of TFQMR; see the TFQMR algorithm in Saad, "Iterative Methods
-- for Sparse Linear Systems".
--
-- @alpha_m@ is recomputed as @rho / (v, r0hat)@ on even steps and carried
-- over on odd steps.  The value in force at step @m@ is the one used for
-- the @w@, @d@ and @eta@ updates of that same step.
--
-- On odd steps the new @v@ is @A u' + beta (A u + beta v)@.  The
-- parentheses are written out explicitly: the inner term is @A u@ plus
-- @beta v@, not @A (u + beta v)@.
tfqmrStep :: SpMatrix Double -> SpVector Double -> TFQMR -> TFQMR
tfqmrStep aa r0hat (TFQMR x w u v d m tau theta eta rho alpha) =
  TFQMR x1 w1 u1 v1 d1 (m + 1) tau1 theta1 eta1 rho1 alpha1
  where
    -- alpha_m: fresh on even steps, reused from the previous step on odd ones
    alpham
      | even m = rho / (v `dot` r0hat)
      | otherwise = alpha
    au = aa #> u
    w1 = w ^-^ (alpham .* au)
    d1 = u ^+^ ((theta ** 2 / alpham * eta) .* d)
    theta1 = norm2 w1 / tau
    c = recip $ sqrt (1 + theta1 ** 2)
    tau1 = tau * theta1 * c
    eta1 = c ** 2 * alpham
    x1 = x ^+^ (eta1 .* d1)
    alpha1 = alpham
    (u1, rho1, v1)
      | even m = (u ^-^ (alpham .* v), rho, v)
      | otherwise =
          let rho' = w1 `dot` r0hat
              beta = rho' / rho
              u' = w1 ^+^ (beta .* u)
              v' = (aa #> u') ^+^ (beta .* (au ^+^ (beta .* v)))
          in (u', rho', v')
-- | Solve @A x = b@ with TFQMR starting from @x0@.
--
-- The initial residual is used as the shadow vector @r0hat@.  The step is
-- repeated through 'untilConverged' on the iterate field.
tfqmr :: SpMatrix Double -> SpVector Double -> SpVector Double -> TFQMR
tfqmr aa b x0 =
  execState (untilConverged _xTfq (tfqmrStep aa res0)) initial
  where
    res0 = b ^-^ (aa #> x0)
    v0 = aa #> res0
    rho0 = res0 `dot` res0
    initial = TFQMR
      x0                         -- x
      res0                       -- w
      res0                       -- u
      v0                         -- v
      (zeroSV (dim b))           -- d
      0                          -- m
      (norm2 res0)               -- tau
      0                          -- theta
      0                          -- eta
      rho0                       -- rho
      (rho0 / (v0 `dot` res0))   -- alpha
-- | State of the TFQMR iteration.
data TFQMR = TFQMR
  { _xTfq     :: SpVector Double  -- ^ current iterate
  , _wTfq     :: SpVector Double
  , _uTfq     :: SpVector Double
  , _vTfq     :: SpVector Double
  , _dTfq     :: SpVector Double
  , _mTfq     :: Int              -- ^ step counter
  , _tauTfq   :: Double
  , _thetaTfq :: Double
  , _etaTfq   :: Double
  , _rhoTfq   :: Double
  , _alphaTfq :: Double
  } deriving Eq
-- | One BiCG iteration.
--
-- It advances the iterate, the residual and the shadow residual, then both
-- search directions.
bcgStep :: SpMatrix Double -> BCG -> BCG
bcgStep aa (BCG x r rhat p phat) =
  let ap = aa #> p
      rr = r `dot` rhat
      alpha = rr / (ap `dot` phat)
      r' = r ^-^ (alpha .* ap)
      rhat' = rhat ^-^ (alpha .* (transposeSM aa #> phat))
      beta = (r' `dot` rhat') / rr
  in BCG (x ^+^ (alpha .* p))
         r'
         rhat'
         (r' ^+^ (beta .* p))
         (rhat' ^+^ (beta .* phat))
-- | State of the BiCG iteration.
data BCG = BCG
  { _xBcg    :: SpVector Double  -- ^ current iterate
  , _rBcg    :: SpVector Double  -- ^ residual
  , _rHatBcg :: SpVector Double  -- ^ shadow residual
  , _pBcg    :: SpVector Double  -- ^ search direction
  , _pHatBcg :: SpVector Double  -- ^ shadow search direction
  } deriving Eq
-- | Solve @A x = b@ with BiCG starting from @x0@.
--
-- The residual, the shadow residual and both directions all start at
-- @b - A x0@.
bcg :: SpMatrix Double -> SpVector Double -> SpVector Double -> BCG
bcg aa b x0 =
  execState (untilConverged _xBcg (bcgStep aa)) (BCG x0 res0 res0 res0 res0)
  where
    res0 = b ^-^ (aa #> x0)
-- | One labelled line per field of the BiCG state.
instance Show BCG where
  show (BCG x r rhat p phat) =
    concatMap line
      [ ("x = ", x), ("r = ", r), ("r_hat = ", rhat)
      , ("p = ", p), ("p_hat = ", phat) ]
    where
      line (label, vec) = label ++ show vec ++ "\n"
-- | One CGS (conjugate gradient squared) iteration with shadow vector @rhat@.
cgsStep :: SpMatrix Double -> SpVector Double -> CGS -> CGS
cgsStep aa rhat (CGS x r p u) = CGS x' r' p' u'
  where
    ap = aa #> p
    rDotRhat = r `dot` rhat
    alpha = rDotRhat / (ap `dot` rhat)
    q = u ^-^ (alpha .* ap)
    uPlusQ = u ^+^ q
    x' = x ^+^ (alpha .* uPlusQ)
    r' = r ^-^ (alpha .* (aa #> uPlusQ))
    beta = (r' `dot` rhat) / rDotRhat
    u' = r' ^+^ (beta .* q)
    p' = u' ^+^ (beta .* (q ^+^ (beta .* p)))
-- | State of the CGS iteration.
data CGS = CGS
  { _x :: SpVector Double  -- ^ current iterate
  , _r :: SpVector Double  -- ^ residual
  , _p :: SpVector Double
  , _u :: SpVector Double
  } deriving Eq
-- | Solve @A x = b@ with CGS starting from @x0@.
--
-- The residual and both auxiliary vectors start at @b - A x0@.
cgs ::
  SpMatrix Double ->  -- ^ system matrix A
  SpVector Double ->  -- ^ right-hand side b
  SpVector Double ->  -- ^ initial guess x0
  SpVector Double ->  -- ^ shadow vector rhat
  CGS
cgs aa b x0 rhat =
  execState (untilConverged _x (cgsStep aa rhat)) (CGS x0 res0 res0 res0)
  where
    res0 = b ^-^ (aa #> x0)
-- | One labelled line per field of the CGS state.
instance Show CGS where
  show (CGS x r p u) =
    concatMap line [("x = ", x), ("r = ", r), ("p = ", p), ("u = ", u)]
    where
      line (label, vec) = label ++ show vec ++ "\n"
-- | One BiCGSTAB iteration with shadow vector @r0hat@.
bicgstabStep :: SpMatrix Double -> SpVector Double -> BICGSTAB -> BICGSTAB
bicgstabStep aa r0hat (BICGSTAB x r p) = BICGSTAB x' r' p'
  where
    aTimesP = aa #> p
    rDotR0 = r `dot` r0hat
    alpha = rDotR0 / (aTimesP `dot` r0hat)
    s = r ^-^ (alpha .* aTimesP)
    aTimesS = aa #> s
    omega = (aTimesS `dot` s) / (aTimesS `dot` aTimesS)
    x' = x ^+^ (alpha .* p) ^+^ (omega .* s)
    r' = s ^-^ (omega .* aTimesS)
    beta = (r' `dot` r0hat) / rDotR0 * alpha / omega
    p' = r' ^+^ (beta .* (p ^-^ (omega .* aTimesP)))
-- | State of the BiCGSTAB iteration.
data BICGSTAB = BICGSTAB
  { _xBicgstab :: SpVector Double  -- ^ current iterate
  , _rBicgstab :: SpVector Double  -- ^ residual
  , _pBicgstab :: SpVector Double  -- ^ search direction
  } deriving Eq
-- | Solve @A x = b@ with BiCGSTAB starting from @x0@.
--
-- The residual and the search direction start at @b - A x0@.
bicgstab
  :: SpMatrix Double   -- ^ system matrix A
  -> SpVector Double   -- ^ right-hand side b
  -> SpVector Double   -- ^ initial guess x0
  -> SpVector Double   -- ^ shadow vector r0hat
  -> BICGSTAB
bicgstab aa b x0 r0hat =
  execState (untilConverged _xBicgstab (bicgstabStep aa r0hat))
            (BICGSTAB x0 res0 res0)
  where
    res0 = b ^-^ (aa #> x0)
-- | One labelled line per field of the BiCGSTAB state.
instance Show BICGSTAB where
  show (BICGSTAB x r p) =
    concatMap line [("x = ", x), ("r = ", r), ("p = ", p)]
    where
      line (label, vec) = label ++ show vec ++ "\n"
-- | Least-squares solution of @A x = b@ through the normal equations
-- @(aa #^# aa) x = A^T b@, solved with '<\>'.  Presumably @aa #^# aa@ is
-- @A^T A@.
--
-- The matrix product is parenthesized explicitly.  '<\>' has no fixity
-- declaration in this module, so it defaults to @infixl 9@, and the
-- grouping must not depend on the fixity of '#^#'.
pinv :: SpMatrix Double -> SpVector Double -> SpVector Double
pinv aa b = (aa #^# aa) <\> atb
  where
    atb = transposeSM aa #> b
-- | Iterative method selected in 'linSolve'.
data LinSolveMethod
  = CGNE_      -- ^ CGNE (Craig's method)
  | TFQMR_     -- ^ transpose-free QMR
  | BCG_       -- ^ biconjugate gradient
  | CGS_       -- ^ conjugate gradient squared
  | BICGSTAB_  -- ^ stabilized biconjugate gradient
  deriving (Eq, Show)
-- | Solve @A x = b@ with the selected iterative method.
--
-- The initial guess is the constant vector 0.1.  Diagonal systems are
-- solved directly with the reciprocal of the diagonal.
--
-- Errors when the number of rows of @A@ differs from the length of @b@.
-- Each constructor dispatches to the solver of the same name.
linSolve ::
  LinSolveMethod -> SpMatrix Double -> SpVector Double -> SpVector Double
linSolve method aa b
  | m /= nb = error "linSolve : operand dimensions mismatch"
  | otherwise = solve aa b
  where
    solve aa' b'
      | isDiagonalSM aa' = reciprocal aa' #> b'
      | otherwise = solveWith aa' b'
    solveWith aa' b' = case method of
      CGNE_ -> _xCgne (cgne aa' b' x0)
      TFQMR_ -> _xTfq (tfqmr aa' b' x0)
      BCG_ -> _xBcg (bcg aa' b' x0)
      CGS_ -> _x (cgs aa' b' x0 x0)
      BICGSTAB_ -> _xBicgstab (bicgstab aa' b' x0 x0)
    x0 = mkSpVectorD n $ replicate n 0.1
    (m, n) = dim aa   -- (rows, columns)
    nb = dim b
-- | Shorthand for @'linSolve' 'BICGSTAB_'@.
(<\>) :: SpMatrix Double -> SpVector Double -> SpVector Double
aa <\> b = linSolve BICGSTAB_ aa b
-- | Repeatedly replace the state with @f@ applied to it, until the newly
-- stored state satisfies @q@.  Returns that final state.
--
-- @f@ is always applied at least once, because @q@ is tested only on
-- updated states.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = loop
  where
    loop = do
      y <- f <$> get
      put y
      if q y then return y else loop
-- | Iterate @f@ from @x@, keeping the two most recent results (newest
-- first).
--
-- Once two results exist, iteration stops and returns the current value as
-- soon as @q@ accepts them or @nitermax@ steps have been taken.  At least
-- two steps are always taken.  The budget test is @>=@ because the counter
-- is first compared at 2; with @==@, any budget below 2 would never stop
-- the loop.
loopUntilAcc :: Int -> ([t] -> Bool) -> (t -> t) -> t -> t
loopUntilAcc nitermax q f x = go 0 [] x
  where
    go i ll xx
      | length ll < 2 = go (i + 1) (y : ll) y
      | q ll || i >= nitermax = xx
      | otherwise = go (i + 1) (take 2 (y : ll)) y
      where
        y = f xx
-- | Repeatedly replace the state with @f@ applied to it.
--
-- The two most recent results are kept, newest first.  Once both exist,
-- each further step checks them with @q@.  If @q@ accepts them, or the
-- step counter has reached @nitermax@, the step performs one last update
-- and returns the stored state.
--
-- The budget test is @>=@: the counter is first compared after two steps,
-- so an equality test would never fire for @nitermax < 2@.
modifyInspectN ::
  MonadState s m =>
  Int ->
  ([s] -> Bool) ->
  (s -> s) ->
  m s
modifyInspectN nitermax q f
  | nitermax > 0 = go 0 []
  | otherwise = error "modifyInspectN : n must be > 0"
  where
    go i ll = do
      x <- get
      let y = f x
      put y
      if length ll < 2
        then go (i + 1) (y : ll)
        else if q ll || i >= nitermax
          then return y
          else go (i + 1) (take 2 $ y : ll)
-- | Arithmetic mean of a container, computed as @(1 / length) * sum@.
meanl :: (Foldable t, Fractional a) => t a -> a
meanl xx = (1 / count) * sum xx
  where
    count = fromIntegral (length xx)
-- | Euclidean norm of the elements of a container.
norm2l :: (Foldable t, Functor t, Floating a) => t a -> a
norm2l = sqrt . sum . fmap (** 2)
-- | Squared difference of the first two elements of a list.
--
-- Lists with fewer than two elements raise an explicit error, instead of
-- failing inside 'head' or '!!'.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = (x1 - x2) ** 2
diffSqL _ = error "diffSqL : the list must have at least two elements"
-- | Run a solver step through 'modifyInspectN' with a budget of 200.
--
-- Iteration stops when 'normDiffConverged' accepts the projections
-- @fproj@ of the two most recent states.
untilConverged :: MonadState a m => (a -> SpVector Double) -> (a -> a) -> m a
untilConverged fproj step = modifyInspectN 200 (normDiffConverged fproj) step
-- | Convergence test on a collection of states.
--
-- The projections @fp@ of the states are combined with
-- @foldrMap fp (^-^) (zeroSV 0)@.  The result is @True@ when the squared
-- norm of that combination is at most 'eps'.
normDiffConverged :: (Foldable t, Functor t) =>
                     (a -> SpVector Double) -> t a -> Bool
normDiffConverged fp xx = normSq combined <= eps
  where
    combined = foldrMap fp (^-^) (zeroSV 0) xx
-- | Iterate @ff@ from @x0@ at most @niter@ times, accumulating the results
-- newest first.
--
-- Iteration stops early when @qq@ accepts the list accumulated so far.
-- @niter == 0@ is allowed and yields @[]@.
runAppendN :: ([t] -> Bool) -> (t -> t) -> Int -> t -> [t]
runAppendN qq ff niter x0
  | niter < 0 = error "runAppendN : niter must be >= 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 || qq xs = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)
-- | Iterate @ff@ from @x0@ exactly @niter@ times, returning the results
-- newest first.  @niter == 0@ is allowed and yields @[]@.
runAppendN' :: (t -> t) -> Int -> t -> [t]
runAppendN' ff niter x0
  | niter < 0 = error "runAppendN' : niter must be >= 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)
-- | @n@ samples from a normal distribution with mean @mu@ and standard
-- deviation @sig@.
--
-- A new generator is created with 'MWC.create' on every call, so the
-- sequence is the same each time.
randArray :: PrimMonad m => Int -> Double -> Double -> m [Double]
randArray n mu sig =
  MWC.create >>= \gen -> replicateM n (MWC.normal mu sig gen)
-- | Dense @n x n@ matrix with standard-normal entries.
--
-- The samples are assigned in column-major order: @(0,0), (1,0), ..@.
-- A new generator is created with 'MWC.create' on every call.
randMat :: PrimMonad m => Int -> m (SpMatrix Double)
randMat n = do
  g <- MWC.create
  aav <- replicateM (n ^ 2) (MWC.normal 0 1 g)
  let idxs = [(i, j) | j <- [0 .. n - 1], i <- [0 .. n - 1]]
  return $ fromListSM (n, n) $ zipWith (\(i, j) v -> (i, j, v)) idxs aav
-- | Length-@n@ vector with standard-normal entries at indices
-- @0 .. n-1@.  A new generator is created with 'MWC.create' on every call.
randVec :: PrimMonad m => Int -> m (SpVector Double)
randVec n = do
  g <- MWC.create
  bv <- replicateM n (MWC.normal 0 1 g)
  return $ fromListSV n $ zip [0 .. n - 1] bv
-- | @n x n@ sparse matrix with @nsp@ standard-normal samples placed at
-- uniformly drawn positions.
--
-- Positions may repeat, so the result can hold fewer than @nsp@ entries.
-- Samples, row indices and column indices are drawn in that order from a
-- generator created with 'MWC.create'.
randSpMat :: Int -> Int -> IO (SpMatrix Double)
randSpMat n nsp
  | nsp > n ^ 2 = error "randSpMat : nsp must be <= n^2 "
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      jj <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSM (n, n) $ zip3 ii jj aav
-- | Length-@n@ sparse vector with @nsp@ standard-normal samples placed at
-- uniformly drawn indices.
--
-- Indices may repeat, so the result can hold fewer than @nsp@ entries.
-- A new generator is created with 'MWC.create' on every call.
randSpVec :: Int -> Int -> IO (SpVector Double)
randSpVec n nsp
  | nsp > n = error "randSpVec : nsp must be <= n"
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSV n $ zip ii aav