module Numeric.LinearAlgebra.Sparse
(
qr,
lu,
chol,
conditionNumberSM,
hhMat, hhRefl,
givens,
arnoldi,
eigsQR, eigRayleigh,
linSolve, LinSolveMethod(..), (<\>),
pinv,
luSolve,
ilu0, mSsor,
diagPartitions,
randArray,
randMat, randVec,
randSpMat, randSpVec,
sparsifySV,
modifyInspectN, runAppendN', untilConverged,
diffSqL
)
where
import Data.Sparse.Common
import Control.Monad.Primitive
import Control.Monad (mapM_, forM_, replicateM)
import Control.Monad.State.Strict
import qualified Data.IntMap.Strict as IM
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC
import Data.Monoid
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.Maybe
import qualified Data.Vector as V
-- | Remove from a sparse vector every entry for which 'isNz' is false,
-- i.e. the numerically-zero values.
sparsifySV :: Epsilon a => SpVector a -> SpVector a
sparsifySV vec = filterSV isNz vec
-- | Cheap estimate of the condition number of a matrix: the ratio between the
-- largest and the smallest diagonal magnitude of the triangular factor @R@
-- of its QR decomposition ('qr').
--
-- Magnitudes are taken entry-wise /before/ the extremes are selected, so
-- negative diagonal entries of @R@ are handled correctly.
--
-- Errors out when the smallest diagonal magnitude is 'nearZero'.
conditionNumberSM :: (Epsilon a, RealFloat a) => SpMatrix a -> a
conditionNumberSM m
  | nearZero lmin = error "Infinite condition number : rank-deficient system"
  | otherwise = lmax / lmin
  where
    (_, r) = qr m
    -- |R_ii|: compare magnitudes, not signed values
    u = fmap abs (extractDiagDense r)
    lmax = maximum u
    lmin = minimum u
-- | The matrix @I - beta * (x >< x)@, where the identity has size @dim x@.
hhMat :: Num a => a -> SpVector a -> SpMatrix a
hhMat beta x = identity ^-^ scaledOuter
  where
    identity = eye (dim x)
    scaledOuter = scale beta (x >< x)
-- | 'hhMat' with @beta = 2@, i.e. @I - 2 * (x >< x)@.
hhRefl :: Num a => SpVector a -> SpMatrix a
hhRefl x = hhMat 2 x
-- | @sqrt (x^2 + y^2)@ in the scaled form @|x| * sqrt (1 + (y/x)^2)@, which
-- avoids squaring large values.
--
-- NOTE(review): the scaled form divides by @x@, so @x == 0@ yields NaN;
-- callers must pass a nonzero first argument (ideally the larger-magnitude
-- one, for accuracy).
hypot :: Floating a => a -> a -> a
hypot x y = abs x * sqrt (1 + (y / x) ** 2)
-- | Sign of a number: @1@ for positive, @0@ for zero and @-1@ for negative
-- input.
sign :: (Ord a, Num a) => a -> a
sign x
  | x > 0 = 1
  | x == 0 = 0
  | otherwise = -1
-- | Coefficients @(c, s, r)@ of a Givens rotation for the pair @(a, b)@.
--
-- In every branch @c = a / r@ and @s = b / r@, so @-s*a + c*b = 0@. When
-- 'sign' returns @-1@ for negative arguments, @r = sqrt (a^2 + b^2) >= 0@
-- and @c*a + s*b = r@.
--
-- The ratio of the smaller to the larger magnitude is squared instead of the
-- inputs themselves, to avoid overflow.
givensCoef :: (Ord a, Floating a) => a -> a -> (a, a, a)
givensCoef a b
  | b == 0 = (sign a, 0, abs a)
  | a == 0 = (0, sign b, abs b)
  | abs a > abs b =
      let ratio = b / a
          scl = sign a * sqrt (1 + ratio ** 2)
      in (1 / scl, ratio / scl, a * scl)
  | otherwise =
      let ratio = a / b
          scl = sign b * sqrt (1 + ratio ** 2)
      in (ratio / scl, 1 / scl, b * scl)
-- | Givens rotation that zeroes entry @(i, j)@ of @mm@ when applied from the
-- left (@givens mm i j #~# mm@).
--
-- The rotation acts in the plane of row @i@ and a pivot row @i'@, the first
-- row reported by 'candidateRows' (a row other than @i@ whose first stored
-- entry is in column @j@). With @(c, s, _) = givensCoef a b@,
-- @a = mm(i', j)@ and @b = mm(i, j)@, the result is the identity except for
--
-- > G(i', i') =  c    G(i', i) = s
-- > G(i , i') = -s    G(i , i) = c
--
-- which is orthogonal and makes row @i@ of @G mm@ equal to @-s*a + c*b = 0@
-- in column @j@.
--
-- Errors out when @(i, j)@ is not a valid index, when @mm@ has fewer rows
-- than columns, or when no pivot row exists.
givens :: (Floating a, Epsilon a, Ord a) => SpMatrix a -> IxRow -> IxCol -> SpMatrix a
givens mm i j
  | isValidIxSM mm (i, j) && nrows mm >= ncols mm =
      fromListSM' [(i', i', c), (i, i, c), (i', i, s), (i, i', -s)] (eye (nrows mm))
  | otherwise = error "givens : indices out of bounds"
  where
    (c, s, _) = givensCoef a b
    i' = case candidateRows (immSM mm) i j of
      Just (k : _) -> k
      _ -> error $ "givens: no compatible rows for entry " ++ show (i, j)
    a = mm @@ (i', j)
    -- the entry to be zeroed
    b = mm @@ (i, j)
-- | Keys of the rows of @mm@, other than row @i@, whose first stored entry is
-- in column @j@ (see 'firstNonZeroColumn'); 'Nothing' when there are none.
candidateRows :: IM.IntMap (IM.IntMap a) -> IxRow -> IxCol -> Maybe [IM.Key]
candidateRows mm i j =
  case IM.keys (IM.filterWithKey eligible mm) of
    [] -> Nothing
    ks -> Just ks
  where
    eligible irow row = irow /= i && firstNonZeroColumn row j
-- | @True@ iff the row @mm@ has a stored entry at column @k@ and no stored
-- entry in any column to the left of @k@.
--
-- Only key presence is checked: a stored entry whose value is zero still
-- counts.
firstNonZeroColumn :: IM.IntMap a -> IxCol -> Bool
firstNonZeroColumn mm k =
  IM.member k mm && isNothing (IM.lookupLT k mm)
-- | QR decomposition by Givens rotations. Returns @(q, r)@, where @r@ is the
-- rotated matrix and @q@ is the transpose of the accumulated product of the
-- rotations.
--
-- One rotation ('givens') is applied for each index pair returned by
-- 'subdiagIndicesSM' for the /input/ matrix, in that order. Entries that
-- become nonzero during the sweep are not revisited.
--
-- When there are no such index pairs, @r@ is the input itself and @q@ is the
-- identity.
qr :: (Epsilon a, Ord a, Floating a) => SpMatrix a -> (SpMatrix a, SpMatrix a)
qr mm = (transposeSM qt, r)
  where
    (qt, r) = sweep (eye (nrows mm)) mm (subdiagIndicesSM mm)
    -- Apply one rotation per index pair, accumulating it into Q^T and R.
    -- Testing for the empty list first avoids running a step with no entry.
    sweep qmatt m [] = (qmatt, m)
    sweep qmatt m ((i, j) : rest) = sweep (g #~# qmatt) (g #~# m) rest
      where
        g = givens m i j
-- | Eigenvalue estimates by unshifted QR iteration. The matrix is repeatedly
-- replaced by @r #~# q@, where @(q, r) = qr@ of the current matrix, and the
-- diagonal of the final iterate is returned.
--
-- Stops when the diagonals of two consecutive iterates differ by a
-- 'nearZero' 2-norm, or after @nitermax@ steps (see 'modifyInspectN').
eigsQR :: (Epsilon a, Real a, Floating a) => Int -> SpMatrix a -> SpVector a
eigsQR nitermax m = extractDiagDense (execState iterateQR m)
  where
    iterateQR = modifyInspectN nitermax diagsClose qrStep
    qrStep a = let (q, r) = qr a in r #~# q
    diagsClose [a1, a2] = nearZero $ norm2 (extractDiagDense a1 ^-^ extractDiagDense a2)
-- | Rayleigh quotient iteration on @m@.
--
-- @eigRayleigh nitermax m (b0, mu0)@ starts from vector @b0@ and shift @mu0@
-- and returns the final @(b, mu)@ pair. Each step does three things:
--
-- * solves @(m - mu I) x = b@ with '<\>',
-- * normalizes @x@ to obtain the new @b@,
-- * takes the Rayleigh quotient of @m@ at @b@ as the new @mu@.
--
-- Stops when two consecutive vectors differ by a 'nearZero' 2-norm, or after
-- @nitermax@ steps (see 'modifyInspectN'). The signature is fixed to
-- 'Double' because '<\>' is.
eigRayleigh :: Int -> SpMatrix Double -> (SpVector Double, Double) -> (SpVector Double, Double)
eigRayleigh nitermax m = execState (convergtest (rayleighStep m))
  where
    convergtest g = modifyInspectN nitermax f g
      where
        f [(b1, _), (b2, _)] = nearZero $ norm2 (b2 ^-^ b1)
    rayleighStep aa (b, mu) = (b', mu')
      where
        ii = eye (nrows aa)
        nom = (aa ^-^ (mu `matScale` ii)) <\> b
        b' = normalize 2 nom
        mu' = b' `dot` (aa #> b') / (b' `dot` b')
-- | Householder vector of @x@ (Golub & Van Loan, "Matrix Computations",
-- Algorithm 5.1.1). Returns @(v, beta)@ with @v(0) = 1@ such that
-- @(I - beta v v^T) x@ is a multiple of the first unit vector.
--
-- When the tail @x(1..)@ is 'nearZero', returns @([1, x(1..)], 0)@, i.e. the
-- identity transformation.
hhV :: (Epsilon a, Real a, Floating a) => SpVector a -> (SpVector a, a)
hhV x = (v, beta)
  where
    tx = tailSV x
    sigma = tx `dot` tx
    vtemp = singletonSV 1 `concatSV` tx
    (v, beta)
      | nearZero sigma = (vtemp, 0)
      | otherwise =
          let xh = headSV x
              mu = sqrt (xh ** 2 + sigma)
              -- Pick the formula for v(0) that avoids cancellation.
              vh
                | xh <= 0 = xh - mu
                | otherwise = -sigma / (xh + mu)
              -- Replace the leading 1 with vh, then rescale so that v(0) = 1.
              vnew = (1 / vh) .* insertSpVector 0 vh vtemp
          in (vnew, 2 * vh ** 2 / (sigma + vh ** 2))
-- | Cholesky factor of a symmetric positive-definite matrix: the
-- lower-triangular @l@ with @l ## transposeSM l == aa@ (up to rounding).
--
-- Rows are filled top to bottom and each row left to right, so every entry
-- only uses entries computed before it:
--
-- > l(i,j) = (aa(i,j) - sum_{k<j} l(i,k) * l(j,k)) / l(j,j)    for j < i
-- > l(i,i) = sqrt (aa(i,i) - sum_{k<i} l(i,k)^2)
--
-- Only the lower triangle (and diagonal) of @aa@ is read. Off-diagonal
-- values for which 'isNz' is false are not stored. No pivoting or
-- positivity check is performed.
chol :: (Epsilon a, Real a, Floating a) => SpMatrix a -> SpMatrix a
chol aa = F.foldl' cholRow (zeroSM n n) [0 .. n - 1]
  where
    n = nrows aa
    -- Fill the strictly-lower part of row i, then its diagonal entry.
    cholRow ll i = insertSpMatrix i i (cholDiag ll' i) ll'
      where
        ll' = F.foldl' (cholSubDiag i) ll [0 .. i - 1]
    -- Insert l(i,j). It uses row i's entries left of j, which were inserted
    -- by the previous steps of this fold.
    cholSubDiag i ll j
      | isNz lij = insertSpMatrix i j lij ll
      | otherwise = ll
      where
        lij = ((aa @@ (i, j)) - inn) / (ll @@ (j, j))
        inn = leftOf j (extractRow ll i) `dot` leftOf j (extractRow ll j)
    cholDiag ll i = sqrt ((aa @@ (i, i)) - sum (fmap (** 2) (leftOf i (extractRow ll i))))
    -- Entries of a row strictly to the left of column c.
    leftOf c = ifilterSV (\k _ -> k < c)
-- | LU decomposition without pivoting. Returns @(l, u)@, where @l@ is built
-- starting from the identity (columns filled below the diagonal) and @u@ is
-- upper-triangular, so that @l ## u@ reproduces @aa@ (up to rounding).
--
-- The computation runs in three phases:
--
-- * Step 0: row 0 of @u@ and column 0 of @l@ come straight from @aa@.
-- * For @i = 1 .. n-2@: row @i@ of @u@, then column @i@ of @l@.
-- * Finally: the last row of @u@.
--
-- A 1x1 matrix is returned as @(eye 1, aa)@. Entries for which 'isNz' is
-- false are not stored. Errors out when a diagonal entry of @u@ used as a
-- divisor is (numerically) zero, since rows are never permuted.
lu :: (Epsilon a, Fractional a, Real a) => SpMatrix a -> (SpMatrix a, SpMatrix a)
lu aa
  | n == 1 = (eye 1, aa)
  | otherwise = (lf, ufin)
  where
    n = nrows aa
    -- Step 0
    l0 = insertCol (eye n) ((1 / u00) .* extractSubCol aa 0 (1, n - 1)) 0
    u0 = insertRow (zeroSM n n) (extractRow aa 0) 0
    u00 = u0 @@ (0, 0)
    -- Steps 1 .. n-2: column i of L needs the freshly computed row i of U.
    (lf, uf) = F.foldl' luUpd (l0, u0) [1 .. n - 2]
    luUpd (l, u) i = (l', u')
      where
        u' = uUpdSparse i l u
        l' = lUpdSparse i l u'
    -- Last step: only the last row of U remains (no L entries below it).
    ufin = uUpdSparse (n - 1) lf uf
    uUpdSparse ix lmat umat = insertRow umat (fromListSV n us) ix
      where
        us = onRangeSparse solveForUij [ix .. n - 1]
        solveForUij j = (aa @@! (ix, j)) - contractSub lmat umat ix j (ix - 1)
    lUpdSparse ix lmat umat = insertCol lmat (fromListSV n ls) ix
      where
        ls = onRangeSparse (`solveForLij` ix) [ix + 1 .. n - 1]
        solveForLij i j
          | isNz ujj = (a - p) / ujj
          | otherwise =
              error $ unwords ["solveForLij : U",
                               show (j, j),
                               "is close to 0. Permute rows in order to have a nonzero diagonal of U"]
          where
            a = aa @@! (i, j)
            ujj = umat @@! (j, j)
            p = contractSub lmat umat i j (i - 1)
-- | Evaluate @f@ at every index of @ixs@ and keep the @(index, value)@ pairs
-- whose value 'isNz', in the order of @ixs@.
onRangeSparse :: Epsilon b => (Int -> b) -> [Int] -> [(Int, b)]
onRangeSparse f ixs = [(ix, y) | ix <- ixs, let y = f ix, isNz y]
-- | Arnoldi iteration on @aa@ started from @b@, for at most @kn@ steps.
--
-- Returns @(q, h)@. The columns of @q@ are the @nmax + 1@ generated basis
-- vectors. @h@ is the @(nmax + 1) x nmax@ Hessenberg matrix of projection
-- coefficients, where @nmax@ is the number of steps taken. The first step is
-- always performed, so @nmax >= 1@.
--
-- Iteration stops after @kn@ steps, or as soon as the norm of a new residual
-- is 'nearZero' (breakdown). On breakdown the last column of @q@ comes from
-- normalizing that (near-)zero residual and should be discarded.
arnoldi ::
  (Epsilon a, Floating a, Eq a) => SpMatrix a -> SpVector a -> Int -> (SpMatrix a, SpMatrix a)
arnoldi aa b kn = (fromCols qvfin, fromListSM (nmax + 1, nmax) hhfin)
  where
    (qvfin, hhfin, nmax, _) = iter arnInit
    -- Test the stopping criterion *before* each further step, so kn <= 1 and
    -- an early breakdown terminate.
    iter st@(_, _, i, fbreak)
      | i >= kn || fbreak = st
      | otherwise = iter (arnoldiStep st)
    (m, n) = dim aa
    arnInit = (qv1, hh1, 1, nearZero h21)
      where
        q0 = normalize 2 b
        aq0 = aa #> q0
        h11 = q0 `dot` aq0
        q1nn = aq0 ^-^ (h11 .* q0)
        h21 = norm 2 q1nn
        hh1 = V.fromList [(0, 0, h11), (1, 0, h21)]
        q1 = normalize 2 q1nn
        qv1 = V.fromList [q0, q1]
    -- Step i: project A q_i on all current basis vectors (column i of H) and
    -- append the normalized remainder as the next basis vector.
    arnoldiStep (qv, hh, i, _) = (qv', hh', i + 1, fb')
      where
        qi = V.last qv
        aqi = aa #> qi
        hhcoli = fmap (`dot` aqi) qv
        zv = zeroSV m
        qipnn = aqi ^-^ (V.foldl' (^+^) zv (V.zipWith (.*) hhcoli qv))
        qipnorm = norm 2 qipnn
        qip = normalize 2 qipnn
        hh' = hh V.++ indexed2 (V.snoc hhcoli qipnorm)
          where
            indexed2 v = V.zip3 (V.fromList [0 .. n]) (V.replicate (n + 1) i) v
        qv' = V.snoc qv qip
        fb' = nearZero qipnorm
-- | Split a matrix into the three parts produced by 'extractSubDiag',
-- 'extractDiag' and 'extractSuperDiag', in that order.
diagPartitions :: SpMatrix a -> (SpMatrix a, SpMatrix a, SpMatrix a)
diagPartitions aa = (extractSubDiag aa, extractDiag aa, extractSuperDiag aa)
-- | Incomplete LU factors restricted to the sparsity pattern of @aa@.
--
-- Computes the complete 'lu' factors, then keeps only the entries at
-- positions where @aa@ has a stored entry (according to 'lookupSM').
--
-- NOTE(review): truncating the complete factors is not the same as classic
-- ILU(0), which performs the elimination restricted to the pattern. Also,
-- any diagonal entry of @l@ is dropped where @aa@ stores no diagonal entry.
ilu0 :: (Epsilon a, Real a, Fractional a) => SpMatrix a -> (SpMatrix a, SpMatrix a)
ilu0 aa = (keepPattern l, keepPattern u)
  where
    (l, u) = lu aa
    keepPattern mat = ifilterSM (\i j _ -> isJust (lookupSM aa i j)) mat
-- | SSOR preconditioner factors @(l, r)@ of @aa@ for relaxation parameter
-- @omega@. They are built from the parts @(e, d, f)@ returned by
-- 'diagPartitions':
--
-- > l = (d - omega e) d^{-1}
-- > r =  d - omega f
--
-- Both factors carry the diagonal @d@, so @l ## r@ has the
-- @(D - wE) D^{-1} (D - wF)@ form of the (unscaled) SSOR matrix.
--
-- NOTE(review): the minus signs follow the @A = D - E - F@ convention. If
-- 'extractSubDiag' / 'extractSuperDiag' return the parts of @aa@ without
-- negation, the standard SSOR factors use @+@ instead. Confirm.
mSsor :: Fractional a => SpMatrix a -> a -> (SpMatrix a, SpMatrix a)
mSsor aa omega = (l, r)
  where
    (e, d, f) = diagPartitions aa
    -- The lower factor must use d, not the identity, to be consistent with r.
    l = (d ^-^ scale omega e) ## reciprocal d
    r = d ^-^ scale omega f
-- | Solve @ll ## uu x = b@ from triangular factors. It runs forward
-- substitution with @ll@ ('triLowerSolve'), then backward substitution with
-- @uu@ ('triUpperSolve').
--
-- Errors out unless @ll@ is lower-triangular and @uu@ is upper-triangular.
luSolve ::
  (Fractional a, Eq a, Epsilon a) => SpMatrix a -> SpMatrix a -> SpVector a -> SpVector a
luSolve ll uu b
  | not bothTriangular = error "luSolve : factors must be triangular matrices"
  | otherwise = triUpperSolve uu w
  where
    bothTriangular = isLowerTriSM ll && isUpperTriSM uu
    w = triLowerSolve ll b
-- | Forward substitution: solve @ll w = b@ for lower-triangular @ll@.
--
-- > w_0 = b_0 / ll(0,0)
-- > w_i = (b_i - ll(i, 0..i-1) . w(0..i-1)) / ll(i,i)
--
-- Entries of the result for which 'isNz' is false are then dropped. Works
-- for any @dim b >= 1@.
triLowerSolve :: (Epsilon a, Fractional a) => SpMatrix a -> SpVector a -> SpVector a
triLowerSolve ll b = sparsifySV (F.foldl' lStep ww0 [1 .. nb - 1])
  where
    nb = dim b
    ww0 = insertSpVector 0 ((b @@ 0) / (ll @@ (0, 0))) (zeroSV nb)
    lStep ww i = insertSpVector i wi ww
      where
        r = extractSubRow ll i (0, i - 1) `dot` takeSV i ww
        wi = ((b @@ i) - r) / (ll @@ (i, i))
-- | Backward substitution: solve @uu x = w@ for upper-triangular @uu@,
-- starting from the last entry.
--
-- > x_{n-1} = w_{n-1} / uu(n-1,n-1)
-- > x_i     = (w_i - uu(i, i+1..n-1) . x(i+1..n-1)) / uu(i,i)
--
-- Entries of the result for which 'isNz' is false are then dropped. Works
-- for any @dim w >= 1@.
triUpperSolve :: (Epsilon a, Fractional a) => SpMatrix a -> SpVector a -> SpVector a
triUpperSolve uu w = sparsifySV (F.foldl' uStep xx0 [nw - 2, nw - 3 .. 0])
  where
    nw = dim w
    iLast = nw - 1
    xx0 = insertSpVector iLast ((w @@ iLast) / (uu @@ (iLast, iLast))) (zeroSV nw)
    uStep xx i = insertSpVector i xi xx
      where
        r = extractSubRow_RK uu i (i + 1, nw - 1) `dot` dropSV (i + 1) xx
        xi = ((w @@ i) - r) / (uu @@ (i, i))
-- | GMRES with zero initial guess and a Krylov space of dimension at most
-- @ncols aa@ (no restarts).
--
-- The method works in four stages:
--
-- * Build the Arnoldi basis @qa@ and Hessenberg matrix @ha@ ('arnoldi').
-- * Form @b' = norm2 b .* ei (nrows ha) 1@, presumably @beta e1@; assumes
--   'ei' is 1-based.
-- * Solve the small least-squares problem for @y@ through @qr ha@, dropping
--   the last row of @R@ and the last entry of @Q^T b'@.
-- * Return @qa' #> y@, with @qa'@ the basis without its last column.
gmres :: (Epsilon a, Ord a, Floating a) => SpMatrix a -> SpVector a -> SpVector a
gmres aa b = qa' #> yhat
  where
    m = ncols aa
    (qa, ha) = arnoldi aa b m
    b' = norm2 b .* ei (nrows ha) 1
    (qh, rh) = qr ha
    yhat = triUpperSolve rh' rhs'
      where
        rhs' = takeSV (dim b' - 1) (transposeSM qh #> b')
        rh' = takeRows (nrows rh - 1) rh
    qa' = takeCols (ncols qa - 1) qa
-- | One step of CGNE (Craig's method; Saad, "Iterative Methods for Sparse
-- Linear Systems", section 8.3):
--
-- > alpha = (r, r) / (p, p)
-- > x'    = x + alpha p
-- > r'    = r - alpha A p
-- > beta  = (r', r') / (r, r)
-- > p'    = A^T r' + beta p
cgneStep :: SpMatrix Double -> CGNE -> CGNE
cgneStep aa (CGNE x r p) = CGNE x1 r1 p1
  where
    alphai = r `dot` r / (p `dot` p)
    x1 = x ^+^ (alphai .* p)
    r1 = r ^-^ (alphai .* (aa #> p))
    beta = r1 `dot` r1 / (r `dot` r)
    -- The new direction is built from the *updated* residual r1.
    p1 = (transposeSM aa #> r1) ^+^ (beta .* p)
-- | State of the CGNE iteration: iterate, residual and search direction.
data CGNE = CGNE
  { _xCgne :: SpVector Double
  , _rCgne :: SpVector Double
  , _pCgne :: SpVector Double
  } deriving Eq
-- | One "name = value" line per field.
instance Show CGNE where
  show (CGNE x r p) = unlines ["x = " ++ show x, "r = " ++ show r, "p = " ++ show p]
-- | Run CGNE on @aa x = b@ from the initial guess @x0@, with initial
-- direction @A^T r0@, until 'untilConverged' stops. Returns the final state.
cgne :: SpMatrix Double -> SpVector Double -> SpVector Double -> CGNE
cgne aa b x0 =
  execState (untilConverged _xCgne (cgneStep aa)) (CGNE x0 res0 (transposeSM aa #> res0))
  where
    res0 = b ^-^ (aa #> x0)
-- | One step of TFQMR with shadow residual @r0hat@ (Saad, "Iterative Methods
-- for Sparse Linear Systems", TFQMR algorithm).
--
-- At even @m@ a fresh @alpha_m = rho / (v, r0hat)@ is computed, and it is
-- this /current/ value that enters the @w@, @d@ and @eta@ updates. At odd
-- @m@ the @alpha@ of the preceding even step is reused, and @rho@, @u@ and
-- @v@ are refreshed.
tfqmrStep :: SpMatrix Double -> SpVector Double -> TFQMR -> TFQMR
tfqmrStep aa r0hat (TFQMR x w u v d m tau theta eta rho alpha) =
  TFQMR x1 w1 u1 v1 d1 (m + 1) tau1 theta1 eta1 rho1 alphaM
  where
    -- alpha_m: recomputed at even m, carried over at odd m
    alphaM
      | even m = rho / (v `dot` r0hat)
      | otherwise = alpha
    w1 = w ^-^ (alphaM .* (aa #> u))
    d1 = u ^+^ ((theta ** 2 / alphaM * eta) .* d)
    theta1 = norm2 w1 / tau
    c = recip $ sqrt (1 + theta1 ** 2)
    tau1 = tau * theta1 * c
    eta1 = c ** 2 * alphaM
    x1 = x ^+^ (eta1 .* d1)
    (u1, rho1, v1)
      | even m = (u ^-^ (alphaM .* v), rho, v)
      | otherwise =
          let rho' = w1 `dot` r0hat
              beta = rho' / rho
              u' = w1 ^+^ (beta .* u)
              -- v' = A u' + beta (A u + beta v)
              v' = (aa #> u') ^+^ (beta .* ((aa #> u) ^+^ (beta .* v)))
          in (u', rho', v')
-- | Run TFQMR on @aa x = b@ from @x0@, using the initial residual as shadow
-- residual, until 'untilConverged' stops. Returns the final state.
tfqmr :: SpMatrix Double -> SpVector Double -> SpVector Double -> TFQMR
tfqmr aa b x0 = execState (untilConverged _xTfq (tfqmrStep aa r0hat)) initState
  where
    r0 = b ^-^ (aa #> x0)
    r0hat = r0
    v0 = aa #> r0
    rho0 = r0hat `dot` r0
    initState = TFQMR
      { _xTfq = x0
      , _wTfq = r0
      , _uTfq = r0
      , _vTfq = v0
      , _dTfq = zeroSV (dim b)
      , _mTfq = 0
      , _tauTfq = norm2 r0
      , _thetaTfq = 0
      , _etaTfq = 0
      , _rhoTfq = rho0
      , _alphaTfq = rho0 / (v0 `dot` r0hat)
      }
-- | State of the TFQMR iteration. It holds:
--
-- * the vectors @x@, @w@, @u@, @v@ and @d@,
-- * the step counter @m@,
-- * the scalars @tau@, @theta@, @eta@, @rho@ and @alpha@.
data TFQMR = TFQMR
  { _xTfq, _wTfq, _uTfq, _vTfq, _dTfq :: SpVector Double
  , _mTfq :: Int
  , _tauTfq, _thetaTfq, _etaTfq, _rhoTfq, _alphaTfq :: Double
  } deriving Eq
-- | Only the current iterate is shown.
instance Show TFQMR where
  show st = "x = " ++ show (_xTfq st) ++ "\n"
-- | One step of the biconjugate gradient method (real case). It updates the
-- iterate, the residual, the shadow residual and both search directions.
bcgStep :: SpMatrix Double -> BCG -> BCG
bcgStep aa (BCG x r rhat p phat) =
  let ap = aa #> p
      alpha = (r `dot` rhat) / (ap `dot` phat)
      rNew = r ^-^ (alpha .* ap)
      rhatNew = rhat ^-^ (alpha .* (transposeSM aa #> phat))
      beta = (rNew `dot` rhatNew) / (r `dot` rhat)
  in BCG (x ^+^ (alpha .* p))
         rNew
         rhatNew
         (rNew ^+^ (beta .* p))
         (rhatNew ^+^ (beta .* phat))
-- | State of the BiCG iteration. It holds the iterate, the residual, the
-- shadow residual, the search direction and the shadow search direction.
data BCG = BCG
  { _xBcg :: SpVector Double
  , _rBcg :: SpVector Double
  , _rHatBcg :: SpVector Double
  , _pBcg :: SpVector Double
  , _pHatBcg :: SpVector Double
  } deriving Eq
-- | Run BiCG on @aa x = b@ from @x0@ until 'untilConverged' stops. The
-- initial residual seeds the residual, the shadow residual and both
-- directions. Returns the final state.
bcg :: SpMatrix Double -> SpVector Double -> SpVector Double -> BCG
bcg aa b x0 = execState (untilConverged _xBcg (bcgStep aa)) start
  where
    res0 = b ^-^ (aa #> x0)
    start = BCG x0 res0 res0 res0 res0
-- | One "name = value" line per field.
instance Show BCG where
  show (BCG x r rhat p phat) =
    unlines
      [ "x = " ++ show x
      , "r = " ++ show r
      , "r_hat = " ++ show rhat
      , "p = " ++ show p
      , "p_hat = " ++ show phat
      ]
-- | One step of conjugate gradient squared (CGS) with fixed shadow residual
-- @rhat@.
cgsStep :: SpMatrix Double -> SpVector Double -> CGS -> CGS
cgsStep aa rhat (CGS x r p u) =
  let ap = aa #> p
      alpha = (r `dot` rhat) / (ap `dot` rhat)
      q = u ^-^ (alpha .* ap)
      uq = u ^+^ q
      rNext = r ^-^ (alpha .* (aa #> uq))
      beta = (rNext `dot` rhat) / (r `dot` rhat)
      uNext = rNext ^+^ (beta .* q)
  in CGS (x ^+^ (alpha .* uq))
         rNext
         (uNext ^+^ (beta .* (q ^+^ (beta .* p))))
         uNext
-- | State of the CGS iteration: iterate, residual, and the two direction
-- vectors @p@ and @u@.
data CGS = CGS
  { _x :: SpVector Double
  , _r :: SpVector Double
  , _p :: SpVector Double
  , _u :: SpVector Double
  } deriving Eq
-- | Run CGS on @aa x = b@ from @x0@ with shadow residual @rhat@ until
-- 'untilConverged' stops. The initial residual seeds @r@, @p@ and @u@.
-- Returns the final state.
cgs ::
  SpMatrix Double ->
  SpVector Double ->
  SpVector Double ->
  SpVector Double ->
  CGS
cgs aa b x0 rhat = execState (untilConverged _x (cgsStep aa rhat)) (CGS x0 res0 res0 res0)
  where
    res0 = b ^-^ (aa #> x0)
-- | One "name = value" line per field.
instance Show CGS where
  show (CGS x r p u) =
    unlines ["x = " ++ show x, "r = " ++ show r, "p = " ++ show p, "u = " ++ show u]
-- | One step of BiCGSTAB with fixed shadow residual @r0hat@.
bicgstabStep :: SpMatrix Double -> SpVector Double -> BICGSTAB -> BICGSTAB
bicgstabStep aa r0hat (BICGSTAB x r p) =
  let ap = aa #> p
      alpha = (r `dot` r0hat) / (ap `dot` r0hat)
      s = r ^-^ (alpha .* ap)
      asj = aa #> s
      omega = (asj `dot` s) / (asj `dot` asj)
      rNext = s ^-^ (omega .* asj)
      beta = (rNext `dot` r0hat) / (r `dot` r0hat) * alpha / omega
  in BICGSTAB (x ^+^ (alpha .* p) ^+^ (omega .* s))
              rNext
              (rNext ^+^ (beta .* (p ^-^ (omega .* ap))))
-- | State of the BiCGSTAB iteration: iterate, residual and search direction.
data BICGSTAB = BICGSTAB
  { _xBicgstab :: SpVector Double
  , _rBicgstab :: SpVector Double
  , _pBicgstab :: SpVector Double
  } deriving Eq
-- | Run BiCGSTAB on @aa x = b@ from @x0@ with shadow residual @r0hat@ until
-- 'untilConverged' stops. The initial residual seeds @r@ and @p@. Returns
-- the final state.
bicgstab
  :: SpMatrix Double
  -> SpVector Double
  -> SpVector Double
  -> SpVector Double
  -> BICGSTAB
bicgstab aa b x0 r0hat =
  execState (untilConverged _xBicgstab (bicgstabStep aa r0hat)) (BICGSTAB x0 res0 res0)
  where
    res0 = b ^-^ (aa #> x0)
-- | One "name = value" line per field.
instance Show BICGSTAB where
  show (BICGSTAB x r p) = unlines ["x = " ++ show x, "r = " ++ show r, "p = " ++ show p]
-- | Least-squares solution of @aa x = b@ through the normal equations
-- @(aa #^# aa) x = transposeSM aa #> b@, solved with '<\>'. This assumes
-- @#^#@ computes @transpose a ## b@.
--
-- Despite the name, this returns a solution vector, not a pseudo-inverse
-- matrix. The grouping of the system matrix is explicit so that it does not
-- depend on the relative fixities of @#^#@ and '<\>'.
pinv :: SpMatrix Double -> SpVector Double -> SpVector Double
pinv aa b = (aa #^# aa) <\> atb
  where
    atb = transposeSM aa #> b
-- | Iterative method selector for 'linSolve'.
data LinSolveMethod
  = GMRES_
  | CGNE_
  | TFQMR_
  | BCG_
  | CGS_
  | BICGSTAB_
  deriving (Eq, Show)
-- | Solve @aa x = b@ with the selected iterative method, starting from an
-- initial guess whose entries are all @0.1@.
--
-- Diagonal systems are solved directly by multiplying with 'reciprocal'.
-- BiCGSTAB and CGS use the initial guess itself as their shadow residual.
--
-- Errors out when the length of @b@ differs from the number of rows of @aa@.
linSolve ::
  LinSolveMethod -> SpMatrix Double -> SpVector Double -> SpVector Double
linSolve method aa b
  -- b lives in the row space of aa: one entry per row.
  | m /= nb = error "linSolve : operand dimensions mismatch"
  | isDiagonalSM aa = reciprocal aa #> b
  | otherwise = case method of
      GMRES_ -> gmres aa b
      CGNE_ -> _xCgne (cgne aa b x0)
      TFQMR_ -> _xTfq (tfqmr aa b x0)
      BCG_ -> _xBcg (bcg aa b x0)
      BICGSTAB_ -> _xBicgstab (bicgstab aa b x0 x0)
      CGS_ -> _x (cgs aa b x0 x0)
  where
    (m, n) = dim aa
    nb = dim b
    -- the unknown has one entry per column
    x0 = mkSpVectorD n $ replicate n 0.1
-- | Solve @aa x = b@ with GMRES, i.e. @linSolve GMRES_@.
(<\>) :: SpMatrix Double -> SpVector Double -> SpVector Double
aa <\> b = linSolve GMRES_ aa b
-- | Repeatedly replace the state with @f@ of it until the new state
-- satisfies @q@, then return that state.
--
-- @f@ is always applied at least once: @q@ is only tested on states produced
-- by @f@, never on the initial state.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = loop
  where
    loop = do
      s' <- gets f
      put s'
      if q s' then return s' else loop
-- | Iterate @f@ from @x@, keeping the two most recent iterates (newest first).
--
-- Once two iterates are available, the loop returns the current iterate as
-- soon as @q@ holds on that pair or the step counter reaches @nitermax@.
-- At least two steps are always taken.
loopUntilAcc :: Int -> ([t] -> Bool) -> (t -> t) -> t -> t
loopUntilAcc nitermax q f x = go 0 [] x
  where
    go i ll xx
      | length ll < 2 = go (i + 1) (y : ll) y
      -- '>=' rather than '==': the counter is already 2 here, so
      -- nitermax < 2 must still stop the loop.
      | q ll || i >= nitermax = xx
      | otherwise = go (i + 1) (take 2 (y : ll)) y
      where
        y = f xx
-- | Repeatedly replace the state with @f@ of it, keeping the two most recent
-- new states (newest first) in a buffer.
--
-- Once the buffer holds two states, the loop stops as soon as @q@ holds on
-- it or the step counter reaches @nitermax@. The state produced by that
-- final application of @f@ is stored and returned. Errors out when
-- @nitermax <= 0@.
modifyInspectN ::
  MonadState s m =>
  Int ->
  ([s] -> Bool) ->
  (s -> s) ->
  m s
modifyInspectN nitermax q f
  | nitermax > 0 = go 0 []
  | otherwise = error "modifyInspectN : n must be > 0"
  where
    go i ll = do
      x <- get
      let y = f x
      put y
      if length ll < 2
        then go (i + 1) (y : ll)
        -- '>=' rather than '==': the counter is already 2 here, so
        -- nitermax == 1 must still end the loop.
        else if q ll || i >= nitermax
          then return y
          else go (i + 1) (take 2 (y : ll))
-- | Arithmetic mean of a container, computed as @(1 / length) * sum@.
-- For IEEE floating types an empty container gives NaN.
meanl :: (Foldable t, Fractional a) => t a -> a
meanl xx = invLen * sum xx
  where
    invLen = 1 / fromIntegral (length xx)
-- | Euclidean norm of a container: square root of the sum of squares.
norm2l :: (Foldable t, Functor t, Floating a) => t a -> a
norm2l = sqrt . sum . fmap (** 2)
-- | Squared difference @(x1 - x2)^2@ of the first two elements of a list.
-- Elements after the second are ignored. Errors out on lists with fewer
-- than two elements.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = (x1 - x2) ** 2
diffSqL _ = error "diffSqL : list must have at least two elements"
-- | Iterate a state update with 'modifyInspectN' (at most 200 steps) until
-- 'normDiffConverged' accepts the projections @fproj@ of the buffered
-- states.
untilConverged :: MonadState a m => (a -> SpVector Double) -> (a -> a) -> m a
untilConverged fproj step = modifyInspectN 200 (normDiffConverged fproj) step
-- | Combine the projections @fp@ of the given states with
-- @foldrMap fp (^-^) (zeroSV 0)@. Report whether the squared norm of the
-- result is 'nearZero'.
normDiffConverged :: (Foldable t, Functor t) =>
  (a -> SpVector Double) -> t a -> Bool
normDiffConverged fp xx = nearZero (normSq diffs)
  where
    diffs = foldrMap fp (^-^) (zeroSV 0) xx
-- | Apply @ff@ repeatedly starting from @x0@, collecting the produced values
-- newest first.
--
-- The loop stops after @niter@ values, or earlier as soon as @qq@ holds on
-- the values collected so far; @qq@ is tested before each new value is
-- added. Errors out when @niter < 0@.
runAppendN :: ([t] -> Bool) -> (t -> t) -> Int -> t -> [t]
runAppendN qq ff niter x0
  | niter < 0 = error "runAppendN : niter must be >= 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 || qq xs = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)
-- | Apply @ff@ @niter@ times starting from @x0@ and return the produced values
-- newest first (@x0@ itself is not included). Errors out when @niter < 0@.
runAppendN' :: (t -> t) -> Int -> t -> [t]
runAppendN' ff niter x0
  | niter < 0 = error "runAppendN' : niter must be >= 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)
-- | @n@ samples from a normal distribution with mean @mu@ and standard
-- deviation @sig@.
--
-- NOTE(review): 'MWC.create' starts from a fixed seed, so every call
-- returns the same samples for the same arguments.
randArray :: PrimMonad m => Int -> Double -> Double -> m [Double]
randArray n mu sig = MWC.create >>= replicateM n . MWC.normal mu sig
-- | @n x n@ matrix with all @n^2@ entries stored, each a standard normal
-- sample. Samples are assigned in column-major order. Uses the fixed-seed
-- generator from 'MWC.create', so results are reproducible.
randMat :: PrimMonad m => Int -> m (SpMatrix Double)
randMat n = do
  g <- MWC.create
  aav <- replicateM (n ^ 2) (MWC.normal 0 1 g)
  let ixs = [(i, j) | j <- [0 .. n - 1], i <- [0 .. n - 1]]
  return $ fromListSM (n, n) [(i, j, x) | ((i, j), x) <- zip ixs aav]
-- | Vector of length @n@ with every entry a standard normal sample. Uses the
-- fixed-seed generator from 'MWC.create', so results are reproducible.
randVec :: PrimMonad m => Int -> m (SpVector Double)
randVec n = do
  g <- MWC.create
  bv <- replicateM n (MWC.normal 0 1 g)
  return $ fromListSV n $ zip [0 .. n - 1] bv
-- | @n x n@ sparse matrix built from @nsp@ standard normal values placed at
-- uniformly drawn positions. Positions are drawn independently, so the same
-- position may occur more than once; how duplicates are combined is up to
-- 'fromListSM'. Errors out when @nsp > n^2@.
randSpMat :: Int -> Int -> IO (SpMatrix Double)
randSpMat n nsp
  | nsp > n ^ 2 = error "randSpMat : nsp must be <= n^2"
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      jj <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSM (n, n) $ zip3 ii jj aav
-- | Sparse vector of length @n@ built from @nsp@ standard normal values at
-- uniformly drawn indices. Indices are drawn independently and may repeat;
-- how duplicates are combined is up to 'fromListSV'. Errors out when
-- @nsp > n@.
randSpVec :: Int -> Int -> IO (SpVector Double)
randSpVec n nsp
  | nsp > n = error "randSpVec : nsp must be <= n"
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSV n $ zip ii aav