module Numeric.LinearAlgebra.Sparse where
import Data.Sparse.Common
import Control.Monad.Primitive
import Control.Monad (mapM_, forM_, replicateM)
import Control.Monad.State.Strict
import qualified Data.IntMap.Strict as IM
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC
import Data.Monoid
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.Maybe
-- | Remove every stored entry whose magnitude is below 'eps', keeping the
-- vector's declared dimension unchanged.
sparsifySV :: SpVector Double -> SpVector Double
sparsifySV (SV d entries) = SV d (IM.filter keep entries)
  where
    keep v = abs v >= eps
-- | Condition-number estimate of a matrix.
--
-- Computed as the ratio of the largest to the smallest /absolute/ value on
-- the diagonal of the @R@ factor of its QR decomposition ('qr').
-- Magnitudes are taken before the extremes: a negative diagonal entry must
-- not make the ratio smaller than 1.
--
-- Raises an error when the ratio is infinite or NaN, i.e. when some diagonal
-- magnitude is zero (rank-deficient matrix).
--
-- NOTE(review): if 'extractDiagonalDSM' omits zero entries from the sparse
-- result, a structurally-zero diagonal entry will not be seen here — confirm.
conditionNumberSM :: SpMatrix Double -> Double
conditionNumberSM m
  | isInfinite kappa || isNaN kappa =
      error "Infinite condition number : rank-deficient system"
  | otherwise = kappa
  where
    (_, r) = qr m
    -- Magnitudes of R's diagonal; the sign of each entry is irrelevant here.
    u = map abs (F.toList (extractDiagonalDSM r))
    lmax = maximum u
    lmin = minimum u
    kappa = lmax / lmin
-- | Householder-style matrix @I - beta * (x >< x)@ with the size of @x@.
-- (@><@ is presumably the outer product @x x^T@ — confirm in Data.Sparse.Common.)
hhMat :: Num a => a -> SpVector a -> SpMatrix a
hhMat beta x =
  let identity = eye (dim x)
      rankOne = scale beta (x >< x)
  in identity ^-^ rankOne
-- | Householder reflector @I - 2 (x >< x)@, i.e. 'hhMat' with @beta = 2@.
-- NOTE(review): this is a reflection only when @x@ has unit 2-norm — confirm
-- at call sites.
hhRefl :: SpVector Double -> SpMatrix Double
hhRefl v = hhMat 2.0 v
-- | @sqrt (x^2 + y^2)@, computed as @|x| * sqrt (1 + (y/x)^2)@ so that the
-- squared magnitude of @x@ is never formed.
--
-- Requires @x /= 0@: the signature only offers 'Floating' (no 'Eq'), so the
-- zero case cannot be tested here; with @x == 0@ the result is NaN.
hypot :: Floating a => a -> a -> a
hypot x y = abs x * sqrt (1 + (y / x) ** 2)
-- | Sign function: @1@ for positive, @0@ for zero, @-1@ for negative input.
-- Values that are neither @> 0@ nor @== 0@ (e.g. NaN) fall through to @-1@.
sign :: (Ord a, Num a) => a -> a
sign x
  | x > 0 = 1
  | x == 0 = 0
  | otherwise = -1
-- | Coefficients @(c, s, r)@ of a Givens rotation for the pair @(a, b)@.
--
-- In the two general branches, @c = a / r@, @s = b / r@ and @r^2 = a^2 + b^2@
-- (up to rounding). Hence @c*a + s*b = r@ and @-s*a + c*b = 0@.
-- The ratio of the smaller to the larger operand is used so that
-- @a^2 + b^2@ is never formed directly.
givensCoef :: (Ord a, Floating a) => a -> a -> (a, a, a)
givensCoef a b
  | b == 0 = (sign a, 0, abs a)
  | a == 0 = (0, sign b, abs b)
  | abs a > abs b =
      let (t, u) = ratioAndScale a b
      in (1 / u, t / u, a * u)
  | otherwise =
      let (t, u) = ratioAndScale b a
      in (t / u, 1 / u, b * u)
  where
    -- For the dominant operand p: t = q / p, u = sign p * |sqrt (1 + t^2)|.
    ratioAndScale p q =
      let t = q / p
      in (t, sign p * abs (sqrt (1 + t ** 2)))
-- | Givens rotation (as a full @n x n@ sparse matrix) that zeroes entry
-- @(i, j)@ of the square matrix @mm@ when applied from the left.
--
-- The pivot row @i'@ is the smallest-index row other than @i@ whose leftmost
-- stored entry lies in column @j@ (see 'candidateRows').
-- With @a = mm(i', j)@, @b = mm(i, j)@ and @(c, s, _) = givensCoef a b@, the
-- rotation mixes rows @i'@ and @i@ only:
--
-- > row i' <-  c * row i' + s * row i   -- entry (i', j) becomes r
-- > row i  <- -s * row i' + c * row i   -- entry (i, j)  becomes 0
--
-- The pattern @[[c, s], [-s, c]]@ is orthogonal, unlike a symmetric one.
-- Rows @i'@ and @i@ have no stored entries left of column @j@ (that is how
-- @i'@ is chosen), so mixing them does not refill earlier columns.
--
-- Errors if @(i, j)@ is out of bounds, @mm@ is not square, or no pivot row
-- exists.
givens :: SpMatrix Double -> IxRow -> IxCol -> SpMatrix Double
givens mm i j
  | validIxSM mm (i, j) && isSquareSM mm =
      sparsifySM $
        fromListSM'
          [(i', i', c), (i, i, c), (i', i, s), (i, i', -s)]
          (eye (nrows mm))
  | otherwise = error "givens : indices out of bounds"
  where
    (c, s, _) = givensCoef a b
    -- candidateRows only ever returns a non-empty list inside 'Just'.
    i' = head $ fromMaybe (error $ "givens: no compatible rows for entry " ++ show (i,j)) (candidateRows (immSM mm) i j)
    a = mm @@ (i', j)
    b = mm @@ (i, j)
-- | True when key @k@ is stored in the row map @mm@ and no smaller key is,
-- i.e. @k@ is the leftmost stored column of the row.
-- (The parameter is typed 'IxRow' but is used as a column key.)
firstNonZeroColumn :: IM.IntMap a -> IxRow -> Bool
firstNonZeroColumn mm k =
  IM.member k mm && null (IM.lookupLT k mm)
-- | Keys (ascending) of all rows other than @i@ whose leftmost stored column
-- is @j@, or 'Nothing' when there are none.
-- The 'Just' list is therefore never empty.
candidateRows :: IM.IntMap (IM.IntMap a) -> IxRow -> IxCol -> Maybe [IM.Key]
candidateRows mm i j =
  case IM.keys (IM.filterWithKey startsAtJ mm) of
    [] -> Nothing
    ks -> Just ks
  where
    startsAtJ irow row = irow /= i && firstNonZeroColumn row j
-- | QR decomposition by Givens rotations: returns @(Q, R)@.
--
-- 'gmats' yields the rotations @g1, g2, ..., gk@ in the order they are
-- applied (each built on the result of the previous ones). Hence
-- @Q^T = gk ... g2 g1@, @R = Q^T mm@ and @Q@ is its transpose.
qr :: SpMatrix Double -> (SpMatrix Double, SpMatrix Double)
qr mm = (transposeSM qmatt, rmat)
  where
    -- Each new rotation is multiplied on the LEFT of the accumulated product,
    -- giving gk #~# (... #~# (g1 #~# I)); a plain foldl' (#~#) would build
    -- g1 g2 ... gk, the wrong order.
    qmatt = F.foldl' (flip (#~#)) ee (gmats mm)
    rmat = qmatt #~# mm
    ee = eye (nrows mm)
-- | The sequence of Givens rotations used by 'qr', in application order.
--
-- One rotation is built per index returned by @subdiagIndicesSM mm@
-- (computed once on the input). Each rotation is built from the matrix
-- obtained by applying all previous rotations to @mm@.
gmats :: SpMatrix Double -> [SpMatrix Double]
gmats mm = snd (T.mapAccumL step mm (subdiagIndicesSM mm))
  where
    -- Thread the partially-reduced matrix as the accumulator; emit the rotation.
    step current (i, j) =
      let rot = givens current i j
      in (rot #~# current, rot)
-- | Eigenvalue estimates by the (unshifted) QR algorithm.
--
-- Repeatedly replaces @A@ by @R Q@ where @(Q, R) = qr A@, starting from @m@.
-- Stops when the diagonals of the last two inspected iterates differ by at
-- most 'eps' in 2-norm, or after @nitermax@ steps (see 'modifyInspectN').
-- Returns the diagonal of the final iterate.
eigsQR :: Int -> SpMatrix Double -> SpVector Double
eigsQR nitermax m = extractDiagonalDSM (execState iterations m)
  where
    iterations = modifyInspectN nitermax diagonalsClose qrStep
    qrStep a = let (q, r) = qr a in r #~# q
    diagonalsClose [a1, a2] =
      norm2 (extractDiagonalDSM a1 ^-^ extractDiagonalDSM a2) <= eps
-- | Rayleigh quotient iteration on @m@, starting from an initial
-- @(vector, shift)@ pair.
--
-- One step solves @(m - mu I) y = b@ (via '<\>') and normalises @y@ with
-- @normalize 2@ to get the new vector @b'@. The new shift is the quotient
-- @b'.(m b') / b'.b'@. Iteration stops when two successive vectors differ by
-- at most 'eps' in 2-norm, or after @nitermax@ steps (see 'modifyInspectN').
eigRayleigh :: Int
            -> SpMatrix Double
            -> (SpVector Double, Double)
            -> (SpVector Double, Double)
eigRayleigh nitermax m =
  execState (modifyInspectN nitermax vectorsClose (rayleighStep m))
  where
    vectorsClose [(b1, _), (b2, _)] = norm2 (b2 ^-^ b1) <= eps
    rayleighStep aa (b, mu) = (bNext, muNext)
      where
        shifted = aa ^-^ (mu `matScale` eye (nrows aa))
        bNext = normalize 2 (shifted <\> b)
        muNext = bNext `dot` (aa #> bNext) / (bNext `dot` bNext)
-- | Householder vector of @x@, following Golub & Van Loan's algorithm:
-- returns @(v, beta)@ with @v@'s first entry equal to 1.
--
-- Let @sigma@ be the squared 2-norm of the tail of @x@.
--
-- * When @sigma <= eps@, returns @(1 : tail x, 0)@.
-- * Otherwise, with @mu = sqrt (x0^2 + sigma)@, the first component is
--   @vh = x0 - mu@ for @x0 <= 0@. For @x0 > 0@ it uses the algebraically
--   equal @-sigma / (x0 + mu)@, which avoids cancellation.
--   The vector is scaled by @1/vh@ and @beta = 2 vh^2 / (sigma + vh^2)@.
hhV :: SpVector Double -> (SpVector Double, Double)
hhV x = (v, beta)
  where
    tx = tailSV x
    sigma = tx `dot` tx
    vtemp = singletonSV 1 `concatSV` tx
    (v, beta)
      | sigma <= eps = (vtemp, 0)
      | otherwise =
          let xh = headSV x
              mu = sqrt (xh ** 2 + sigma)
              vh
                | xh <= 0 = xh - mu
                | otherwise = negate sigma / (xh + mu)
              -- Put vh in slot 0, then scale so that slot 0 becomes exactly 1.
              vnew = (1 / vh) .* insertSpVector 0 vh vtemp
          in (vnew, 2 * vh ** 2 / (sigma + vh ** 2))
-- | Row permutation helper.
--
-- Returns 'Nothing' when entry @(iref, jref)@ is stored in the matrix.
-- Otherwise it finds the smallest-index row, other than @iref@, that has a
-- stored entry in column @jref@, and returns @permutationSM nro [thatRow]@.
-- NOTE(review): crashes (via 'head') when no such row exists.
permutAA :: Num b => SpMatrix a -> IxRow -> IxCol -> Maybe (SpMatrix b)
permutAA (SM (nro, _) mm) iref jref
  | isJust (lookupIM2 iref jref mm) = Nothing
  | otherwise = Just (permutationSM nro [head candidates])
  where
    candidates = IM.keys (ifilterIM2 inColumn mm)
    inColumn i j _ = i /= iref && j == jref
-- | One iteration of the biconjugate gradient (BiCG) method for @aa x = b@.
-- The shadow vectors use the transpose of @aa@.
bcgStep :: SpMatrix Double -> BCG -> BCG
bcgStep aa (BCG x r rhat p phat) = BCG xNext rNext rhatNext pNext phatNext
  where
    ap = aa #> p
    atPhat = transposeSM aa #> phat
    rho = r `dot` rhat
    alpha = rho / (ap `dot` phat)
    xNext = x ^+^ (alpha .* p)
    rNext = r ^-^ (alpha .* ap)
    rhatNext = rhat ^-^ (alpha .* atPhat)
    beta = (rNext `dot` rhatNext) / rho
    pNext = rNext ^+^ (beta .* p)
    phatNext = rhatNext ^+^ (beta .* phat)
-- | State of the BiCG iteration: iterate, residual, shadow residual, search
-- direction and shadow search direction.
data BCG = BCG
  { _xBcg :: SpVector Double
  , _rBcg :: SpVector Double
  , _rHatBcg :: SpVector Double
  , _pBcg :: SpVector Double
  , _pHatBcg :: SpVector Double
  } deriving Eq
-- | Run BiCG on @aa x = b@ from the initial guess @x0@ until 'untilConverged'
-- stops, and return the final state.
bcg :: SpMatrix Double -> SpVector Double -> SpVector Double -> BCG
bcg aa b x0 = execState (untilConverged _xBcg (bcgStep aa)) initialState
  where
    residual0 = b ^-^ (aa #> x0)
    -- All four auxiliary vectors start equal to the initial residual.
    initialState = BCG x0 residual0 residual0 residual0 residual0
-- | One labelled line per vector of the BiCG state.
instance Show BCG where
  show (BCG x r rhat p phat) =
    unlines
      [ "x = " ++ show x
      , "r = " ++ show r
      , "r_hat = " ++ show rhat
      , "p = " ++ show p
      , "p_hat = " ++ show phat
      ]
-- | One iteration of the conjugate gradient squared (CGS) method for
-- @aa x = b@, with fixed shadow residual @rhat@.
cgsStep :: SpMatrix Double -> SpVector Double -> CGS -> CGS
cgsStep aa rhat (CGS x r p u) = CGS xNext rNext pNext uNext
  where
    ap = aa #> p
    rho = r `dot` rhat
    alpha = rho / (ap `dot` rhat)
    q = u ^-^ (alpha .* ap)
    uPlusQ = u ^+^ q
    xNext = x ^+^ (alpha .* uPlusQ)
    rNext = r ^-^ (alpha .* (aa #> uPlusQ))
    beta = (rNext `dot` rhat) / rho
    uNext = rNext ^+^ (beta .* q)
    pNext = uNext ^+^ (beta .* (q ^+^ (beta .* p)))
-- | State of the CGS iteration: iterate, residual and the two auxiliary
-- direction vectors.
data CGS = CGS
  { _x :: SpVector Double
  , _r :: SpVector Double
  , _p :: SpVector Double
  , _u :: SpVector Double
  } deriving Eq
-- | Run CGS on @aa x = b@ from @x0@ with shadow residual @rhat@ until
-- 'untilConverged' stops, and return the final state.
cgs
  :: SpMatrix Double
  -> SpVector Double
  -> SpVector Double
  -> SpVector Double
  -> CGS
cgs aa b x0 rhat =
  execState (untilConverged _x (cgsStep aa rhat)) initialState
  where
    residual0 = b ^-^ (aa #> x0)
    -- Both direction vectors start equal to the initial residual.
    initialState = CGS x0 residual0 residual0 residual0
-- | One labelled line per vector of the CGS state.
instance Show CGS where
  show (CGS x r p u) =
    unlines
      [ "x = " ++ show x
      , "r = " ++ show r
      , "p = " ++ show p
      , "u = " ++ show u
      ]
-- | One iteration of the stabilised biconjugate gradient (BiCGSTAB) method
-- for @aa x = b@, with fixed shadow residual @r0hat@.
bicgstabStep :: SpMatrix Double -> SpVector Double -> BICGSTAB -> BICGSTAB
bicgstabStep aa r0hat (BICGSTAB x r p) = BICGSTAB xNext rNext pNext
  where
    ap = aa #> p
    rho = r `dot` r0hat
    alpha = rho / (ap `dot` r0hat)
    sVec = r ^-^ (alpha .* ap)
    aaS = aa #> sVec
    omega = (aaS `dot` sVec) / (aaS `dot` aaS)
    xNext = x ^+^ (alpha .* p) ^+^ (omega .* sVec)
    rNext = sVec ^-^ (omega .* aaS)
    beta = (rNext `dot` r0hat) / rho * alpha / omega
    pNext = rNext ^+^ (beta .* (p ^-^ (omega .* ap)))
-- | State of the BiCGSTAB iteration: iterate, residual and search direction.
data BICGSTAB = BICGSTAB
  { _xBicgstab :: SpVector Double
  , _rBicgstab :: SpVector Double
  , _pBicgstab :: SpVector Double
  } deriving Eq
-- | Run BiCGSTAB on @aa x = b@ from @x0@ with shadow residual @r0hat@ until
-- 'untilConverged' stops, and return the final state.
bicgstab
  :: SpMatrix Double
  -> SpVector Double
  -> SpVector Double
  -> SpVector Double
  -> BICGSTAB
bicgstab aa b x0 r0hat =
  execState (untilConverged _xBicgstab (bicgstabStep aa r0hat)) initialState
  where
    residual0 = b ^-^ (aa #> x0)
    -- The first search direction is the initial residual.
    initialState = BICGSTAB x0 residual0 residual0
-- | One labelled line per vector of the BiCGSTAB state.
instance Show BICGSTAB where
  show (BICGSTAB x r p) =
    unlines
      [ "x = " ++ show x
      , "r = " ++ show r
      , "p = " ++ show p
      ]
-- | Least-squares solution of @aa x = b@ via the normal equations, solving
-- @(aa #^# aa) x = transpose aa #> b@ with '<\>'.
-- (@#^#@ is presumably @A^T B@.) Despite the name, this returns a solution
-- vector, not a pseudo-inverse matrix.
pinv :: SpMatrix Double -> SpVector Double -> SpVector Double
pinv aa b = normalMatrix <\> rhs
  where
    normalMatrix = aa #^# aa
    rhs = transposeSM aa #> b
-- | Iterative method selector for 'linSolve'.
data LinSolveMethod
  = BCG_
  | CGS_
  | BICGSTAB_
  deriving (Eq, Show)
-- | Solve @aa x = b@.
--
-- If @aa@ is diagonal, the answer is computed directly as
-- @reciprocal aa #> b@. Otherwise the requested iterative method runs from
-- the constant initial guess @0.1@. That guess is also used as the shadow
-- residual for CGS and BiCGSTAB.
--
-- Errors when the column count of @aa@ differs from the length of @b@.
linSolve ::
  LinSolveMethod -> SpMatrix Double -> SpVector Double -> SpVector Double
linSolve method aa b
  | n /= nb = error "linSolve : operand dimensions mismatch"
  | isDiagonalSM aa = reciprocal aa #> b
  | otherwise =
      -- Each constructor dispatches to its own solver (CGS_ and BICGSTAB_
      -- were previously swapped).
      case method of
        BCG_ -> _xBcg (bcg aa b x0)
        CGS_ -> _x (cgs aa b x0 x0)
        BICGSTAB_ -> _xBicgstab (bicgstab aa b x0 x0)
  where
    (_, n) = dim aa
    nb = dim b
    x0 = mkSpVectorD n (replicate n 0.1)
-- | Solve @aa x = b@ with 'linSolve' using the 'BICGSTAB_' method selector.
(<\>) :: SpMatrix Double -> SpVector Double -> SpVector Double
aa <\> b = linSolve BICGSTAB_ aa b
-- | Repeatedly apply @f@ to the state, storing each result, until the stored
-- state satisfies @q@. Returns that state. @f@ is applied at least once, and
-- there is no iteration limit.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = do
  modify f
  current <- get
  if q current
    then return current
    else modifyUntil q f
-- | Iterate @f@ from @x@, keeping a window of the two most recent iterates
-- (most recent first).
--
-- Stops when @q@ holds on that window or once @nitermax@ steps have been
-- taken, and returns the current iterate. At least two steps are always taken
-- so that @q@ has two iterates to inspect. The limit test uses @>=@: with
-- @==@, any @nitermax < 2@ could never be reached and the loop would run
-- forever whenever @q@ never holds.
loopUntilAcc :: Int -> ([t] -> Bool) -> (t -> t) -> t -> t
loopUntilAcc nitermax q f = go 0 []
  where
    go i ll xx
      | length ll < 2 = go (i + 1) (y : ll) y
      | q ll || i >= nitermax = xx
      | otherwise = go (i + 1) (take 2 (y : ll)) y
      where
        y = f xx
-- | Repeatedly apply @f@ to the state, storing each result.
--
-- A window of the two most recent states (most recent first) is kept. Once
-- two states are available, stop when @q@ holds on the window as it stood
-- before the current step, or once @nitermax@ steps have been taken. Returns
-- the last stored state.
--
-- The limit test uses @>=@: the first check only happens at step 2, so with
-- @==@ a limit of 1 could never be reached.
--
-- Errors when @nitermax <= 0@.
modifyInspectN ::
  MonadState s m =>
  Int ->
  ([s] -> Bool) ->
  (s -> s) ->
  m s
modifyInspectN nitermax q f
  | nitermax > 0 = go 0 []
  | otherwise = error "modifyInspectN : n must be > 0"
  where
    go i ll = do
      y <- gets f
      put y
      if length ll < 2
        then go (i + 1) (y : ll)
        else
          if q ll || i >= nitermax
            then return y
            else go (i + 1) (take 2 (y : ll))
-- | Arithmetic mean of a container, computed as @(1 / length) * sum@.
-- An empty container divides by zero.
meanl :: (Foldable t, Fractional a) => t a -> a
meanl xx = recip (fromIntegral (length xx)) * sum xx
-- | Euclidean (2-)norm of a container of numbers.
norm2l :: (Foldable t, Functor t, Floating a) => t a -> a
norm2l = sqrt . sum . fmap (** 2)
-- | Squared difference of the first two elements of a list.
-- Any further elements are ignored. Errors with a descriptive message when
-- the list has fewer than two elements.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = (x1 - x2) ** 2
diffSqL _ = error "diffSqL : need at least two elements"
-- | Iterate @step@ on the state with 'modifyInspectN', for at most 100 steps.
-- Stops early when 'normDiffConverged' reports that the @fproj@-projections
-- of the last two states are within tolerance.
untilConverged :: MonadState a m => (a -> SpVector Double) -> (a -> a) -> m a
untilConverged fproj step =
  modifyInspectN maxIterations (normDiffConverged fproj) step
  where
    maxIterations = 100
-- | True when the combined difference of the projected states is small.
--
-- The projections (via @fp@) are folded with @^-^@ starting from
-- @zeroSV 0@, and the result's 'normSq' is compared against 'eps'.
-- (For a two-state window this is presumably @fp a ^-^ fp b@ — confirm
-- 'foldrMap' semantics.)
normDiffConverged :: (Foldable t, Functor t) =>
                     (a -> SpVector Double) -> t a -> Bool
normDiffConverged fp xx = normSq difference <= eps
  where
    difference = foldrMap fp (^-^) (zeroSV 0) xx
-- | Iterate @ff@ from @x0@ and collect the iterates, most recent first.
--
-- Stops after @niter@ iterates or as soon as @qq@ holds on the iterates
-- collected so far (checked before each new one is added). @x0@ itself is
-- not included. Errors when @niter@ is negative.
runAppendN :: ([t] -> Bool) -> (t -> t) -> Int -> t -> [t]
runAppendN qq ff niter x0
  | niter < 0 = error "runAppendN : niter must be >= 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 || qq xs = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)
-- | Iterate @ff@ from @x0@ exactly @niter@ times and collect the iterates,
-- most recent first (@x0@ itself is not included). Errors when @niter@ is
-- negative.
runAppendN' :: (t -> t) -> Int -> t -> [t]
runAppendN' ff niter x0
  | niter < 0 = error "runAppendN' : niter must be >= 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)
-- | Dense @n x n@ matrix with standard-normal entries, stored as a sparse
-- matrix.
--
-- The generator comes from 'MWC.create', which uses a fixed seed, so every
-- call yields the same matrix for a given @n@.
randMat :: PrimMonad m => Int -> m (SpMatrix Double)
randMat n = do
  g <- MWC.create
  aav <- replicateM (n ^ 2) (MWC.normal 0 1 g)
  -- Column-major index order: the row index varies fastest.
  let ixs = [(i, j) | j <- [0 .. n - 1], i <- [0 .. n - 1]]
  return $ fromListSM (n, n) $ zipWith (\(i, j) v -> (i, j, v)) ixs aav
-- | Dense length-@n@ vector with standard-normal entries, stored as a sparse
-- vector. Deterministic per @n@, since 'MWC.create' uses a fixed seed.
randVec :: PrimMonad m => Int -> m (SpVector Double)
randVec n = do
  g <- MWC.create
  bv <- replicateM n (MWC.normal 0 1 g)
  return $ fromListSV n $ zip [0 .. n - 1] bv
-- | @n x n@ sparse matrix with @nsp@ standard-normal values at uniformly
-- random positions. Deterministic, since 'MWC.create' uses a fixed seed.
--
-- Positions are drawn independently, so duplicates are possible; how
-- 'fromListSM' resolves them is not visible here. Errors when @nsp > n^2@.
randSpMat :: Int -> Int -> IO (SpMatrix Double)
randSpMat n nsp
  | nsp > n ^ 2 = error "randSpMat : nsp must be <= n^2"
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      jj <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSM (n, n) $ zip3 ii jj aav
-- | Length-@n@ sparse vector with @nsp@ standard-normal values at uniformly
-- random indices. Deterministic, since 'MWC.create' uses a fixed seed.
--
-- Indices are drawn independently, so duplicates are possible; how
-- 'fromListSV' resolves them is not visible here. Errors when @nsp > n@.
randSpVec :: Int -> Int -> IO (SpVector Double)
randSpVec n nsp
  | nsp > n = error "randSpVec : nsp must be <= n"
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSV n $ zip ii aav
-- | Render a value with 'show', or as a single blank for zero.
showNonZero :: (Show a, Num a, Eq a) => a -> String
showNonZero x
  | x == 0 = " "
  | otherwise = show x
-- | Row @irow@ of the matrix as a dense list of length @ncol@. Missing
-- entries are filled in by 'lookupWD_IM' (presumably with zero — confirm).
toDenseRow :: Num a => SpMatrix a -> IM.Key -> [a]
toDenseRow (SM (_, ncol) im) irow =
  fmap (\icol -> im `lookupWD_IM` (irow, icol)) [0 .. ncol - 1]
-- | Render row @irow@ densely.
--
-- When the matrix has more than @ncomax@ columns, only the first
-- @ncomax - 2@ entries are shown, then @" ... "@, then the last entry.
-- Otherwise the whole row is shown in list syntax.
toDenseRowClip :: (Show a, Num a) => SpMatrix a -> IM.Key -> Int -> String
toDenseRowClip sm irow ncomax
  | ncols sm > ncomax = unwords (map show h) ++ " ... " ++ show t
  | otherwise = show dr
  where
    dr = toDenseRow sm irow
    h = take (ncomax - 2) dr
    t = last dr
-- | Write a single line break to stdout.
newline :: IO ()
newline = putStr "\n"
-- | Print a matrix densely to stdout: its size string, then at most 5 rows
-- of at most 5 columns each, with blank lines around each section.
--
-- When there are more than 5 rows, the first 3 rows are shown, then a
-- @" ... "@ line, then the last row. Wide rows are clipped by
-- 'toDenseRowClip'.
printDenseSM :: (Show t, Num t) => SpMatrix t -> IO ()
printDenseSM sm = do
  newline
  putStrLn $ sizeStr sm
  newline
  printDenseSM' sm 5 5
  newline
  where
    printDenseSM' :: (Show t, Num t) => SpMatrix t -> Int -> Int -> IO ()
    printDenseSM' sm'@(SM (nr, _) _) nromax ncomax = mapM_ putStrLn rows'
      where
        rows = map (\i -> toDenseRowClip sm' i ncomax) [0 .. nr - 1]
        rows'
          | nr > nromax = take (nromax - 2) rows ++ [" ... "] ++ [last rows]
          | otherwise = rows
-- | Render a vector densely.
--
-- When its dimension exceeds @ncomax@, only the first @ncomax - 2@ entries
-- are shown, then @" ... "@, then the last entry. Otherwise the whole dense
-- list is shown in list syntax.
toDenseListClip :: (Show a, Num a) => SpVector a -> Int -> String
toDenseListClip sv ncomax
  | dim sv > ncomax = unwords (map show h) ++ " ... " ++ show t
  | otherwise = show dr
  where
    dr = toDenseListSV sv
    h = take (ncomax - 2) dr
    t = last dr
-- | Print a vector densely to stdout (at most 5 entries, clipped by
-- 'toDenseListClip'), surrounded by blank lines.
--
-- The clipped string is printed as-is. Re-clipping it character-by-character,
-- as before, truncated the already-formatted numbers.
printDenseSV :: (Show t, Num t) => SpVector t -> IO ()
printDenseSV sv = do
  newline
  putStrLn (toDenseListClip sv 5)
  newline
-- | Containers that can be printed to stdout in dense form.
class PrintDense a where
  -- | Print the value densely to stdout.
  prd :: a -> IO ()
-- | Dense printing of sparse vectors via 'printDenseSV'.
instance (Show a, Num a) => PrintDense (SpVector a) where
  prd v = printDenseSV v
-- | Dense printing of sparse matrices via 'printDenseSM'.
instance (Show a, Num a) => PrintDense (SpMatrix a) where
  prd mat = printDenseSM mat