module Math.Linear.Sparse where
import Math.Linear.Sparse.IntMap
import Control.Monad.Primitive
import Control.Monad (mapM_, forM_, replicateM)
import Control.Monad.State.Strict
import qualified Data.IntMap.Strict as IM
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC
import Data.Monoid
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Data.Maybe
-- | Functors that support addition of their numeric contents.
class Functor f => Additive f where
  -- | The additive identity of the structure.
  zero :: Num a => f a
  -- | Sum of two structures.
  (^+^) :: Num a => f a -> f a -> f a

-- | Flip the sign of every element.
negated :: (Num a, Functor f) => f a -> f a
negated v = negate <$> v

-- | Subtraction, expressed as addition of the negated right operand.
(^-^) :: (Additive f, Num a) => f a -> f a -> f a
(^-^) u v = u ^+^ negated v
-- | Additive structures that can be multiplied by a scalar.
class Additive f => VectorSpace f where
  -- | Multiply a structure by a scalar (scalar on the left).
  (.*) :: Num a => a -> f a -> f a

-- | Linear interpolation: @lerp a u v = a * u + (1 - a) * v@.
-- @a = 1@ yields @u@, @a = 0@ yields @v@.
lerp :: (VectorSpace f, Num a) => a -> f a -> f a -> f a
lerp a u v = (a .* u) ^+^ ((1 - a) .* v)
-- | Vector spaces equipped with an inner product.
class VectorSpace f => Hilbert f where
  -- | Inner product of two vectors.
  dot :: Num a => f a -> f a -> a

-- | Hilbert spaces with a family of norms selected by the first argument.
class Hilbert f => Normed f where
  -- | @norm p v@ : norm of @v@ selected by @p@.
  norm :: (Floating a, Eq a) => a -> f a -> a

-- | Squared Euclidean norm, i.e. the inner product of a vector with itself.
normSq :: (Hilbert f, Num a) => f a -> a
normSq v = dot v v
-- | L1 norm: sum of the absolute values of the elements.
norm1 :: (Foldable t, Num a, Functor t) => t a -> a
norm1 = sum . fmap abs
-- | Euclidean (L2) norm: square root of 'normSq'.
norm2 :: (Hilbert f, Floating a) => f a -> a
norm2 = sqrt . normSq
-- | Lp norm: @(sum |x_i|^p)^(1/p)@.
--
-- The absolute value is taken before exponentiation; without it negative
-- entries cancel for odd @p@ and produce NaN for fractional @p@.
normP :: (Foldable t, Functor t, Floating a) => a -> t a -> a
normP p v = sum powers ** (1 / p)
  where
    powers = fmap (\x -> abs x ** p) v
-- | Largest element of the structure.
--
-- NOTE(review): no absolute value is applied, so this equals the
-- mathematical infinity norm only for non-negative entries.
normInfty :: (Foldable t, Ord a) => t a -> a
normInfty xs = maximum xs
-- | Rescale a vector by the reciprocal of its @p@-norm.
normalize :: (Normed f, Floating a, Eq a) => a -> f a -> f a
normalize p v = factor .* v
  where
    factor = 1 / norm p v
-- | Lp-style inner product: the elementwise products (combined with
-- 'liftI2') are raised to @p@, summed, and the sum is raised to @1/p@.
dotLp :: (Set t, Foldable t, Floating a) => a -> t a -> t a -> a
dotLp p v1 v2 = sum products ** (1 / p)
  where
    products = liftI2 (\a b -> (a * b) ** p) v1 v2
-- | Replace every element by its reciprocal.
reciprocal :: (Functor f, Fractional b) => f b -> f b
reciprocal v = recip <$> v

-- | Multiply every element by @n@ (scalar on the right of each product).
scale :: (Num b, Functor f) => b -> f b -> f b
scale n v = fmap (\x -> x * n) v
-- | Additive structures that carry a dimension.
class Additive f => FiniteDim f where
  -- | Type used to describe the size of the structure.
  type FDSize f :: *
  -- | Size of the given value.
  dim :: f a -> FDSize f
-- | Apply a function to a value only if a predicate on its dimension and
-- contents holds; otherwise fail with 'error'.
--
-- The error text is the message followed by 'show' of the diagnostic
-- extracted from the value.
withDim :: (FiniteDim f, Show e) =>
     f a                          -- ^ value to check
  -> (FDSize f -> f a -> Bool)    -- ^ predicate on dimension and value
  -> (f a -> c)                   -- ^ function applied on success
  -> String                       -- ^ error message prefix
  -> (f a -> e)                   -- ^ diagnostic appended to the message
  -> c
withDim x p f e ef =
  if p (dim x) x
    then f x
    else error (e ++ show (ef x))
-- | Two-argument variant of 'withDim': the predicate sees both dimensions
-- and both values; on failure 'error' is raised with the message followed
-- by 'show' of the diagnostic.
withDim2 :: (FiniteDim f, FiniteDim g, Show e) =>
     f a                                          -- ^ first value
  -> g b                                          -- ^ second value
  -> (FDSize f -> FDSize g -> f a -> g b -> Bool) -- ^ validity predicate
  -> (f a -> g b -> c)                            -- ^ function applied on success
  -> String                                       -- ^ error message prefix
  -> (f a -> g b -> e)                            -- ^ diagnostic appended to the message
  -> c
withDim2 x y p f e ef =
  if p (dim x) (dim y) x y
    then f x y
    else error (e ++ show (ef x y))
-- | Structures that expose their underlying data container.
class Additive f => HasData f a where
  -- | Type of the underlying container.
  type HDData f a :: *
  -- | Extract the underlying container.
  dat :: f a -> HDData f a

-- | Finite-dimensional structures with a sparsity measure.
class (FiniteDim f, HasData f a) => Sparse f a where
  -- | Sparsity figure of the structure, as a fractional number.
  spy :: Fractional b => f a -> b

-- | Functors supporting union-like and intersection-like binary lifting.
class Functor f => Set f where
  -- | Combine two structures, merging elements with the given function.
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
  -- | Combine two structures of possibly different element types.
  liftI2 :: (a -> b -> c) -> f a -> f b -> f c
-- | Union lifting keeps keys of either map; intersection lifting keeps
-- only keys present in both.
instance Set IM.IntMap where
  liftU2 f = IM.unionWith f
  liftI2 f = IM.intersectionWith f

-- | The empty map is zero; addition sums values under matching keys and
-- keeps unmatched entries.
instance Additive IM.IntMap where
  zero = IM.empty
  u ^+^ v = liftU2 (+) u v

-- | Scaling multiplies every stored value by the scalar (on the right).
instance VectorSpace IM.IntMap where
  (.*) n = IM.map (* n)

-- | Inner product over the keys common to both maps.
instance Hilbert IM.IntMap where
  dot a b = sum (liftI2 (*) a b)

-- | @p == 1@ and @p == 2@ use the dedicated norms, anything else 'normP'.
instance Normed IM.IntMap where
  norm p v =
    if p == 1
      then norm1 v
      else if p == 2
        then norm2 v
        else normP p v
-- | Sparse vector: a declared dimension plus an index-to-value map.
data SpVector a = SV
  { svDim  :: Int          -- ^ declared dimension
  , svData :: IM.IntMap a  -- ^ stored entries, keyed by index
  } deriving Eq

-- | Ratio of stored entries to the declared dimension.
spySV :: Fractional b => SpVector a -> b
spySV s = stored / total
  where
    stored = fromIntegral (IM.size (dat s))
    total  = fromIntegral (dim s)
-- | Map over stored entries; the declared dimension is unchanged.
instance Functor SpVector where
  fmap f (SV n entries) = SV n (f <$> entries)

-- | Both liftings take the larger of the two declared dimensions.
instance Set SpVector where
  liftU2 f2 (SV n1 x1) (SV n2 x2) = SV (max n1 n2) (liftU2 f2 x1 x2)
  liftI2 f2 (SV n1 x1) (SV n2 x2) = SV (max n1 n2) (liftI2 f2 x1 x2)

-- | Folds visit the stored entries only.
instance Foldable SpVector where
  foldr f z (SV _ entries) = F.foldr f z entries

-- | Zero is the empty vector of dimension 0.
instance Additive SpVector where
  zero = SV 0 IM.empty
  u ^+^ v = liftU2 (+) u v

instance VectorSpace SpVector where
  (.*) n = scale n

instance FiniteDim SpVector where
  type FDSize SpVector = Int
  dim = svDim

instance HasData SpVector a where
  type HDData SpVector a = IM.IntMap a
  dat = svData

instance Sparse SpVector a where
  spy = spySV

-- | Inner product; fails with 'error' when the dimensions differ.
instance Hilbert SpVector where
  dot a b =
    if dim a == dim b
      then dot (dat a) (dat b)
      else error $ "dot : sizes must coincide, instead we got " ++
                   show (dim a, dim b)

-- | Norms are delegated to the underlying map.
instance Normed SpVector where
  norm p sv = norm p (svData sv)
-- | Vector of the given dimension with no stored entries.
zeroSV :: Int -> SpVector a
zeroSV n = SV { svDim = n, svData = IM.empty }

-- | One-dimensional vector holding the given value at index 0.
singletonSV :: a -> SpVector a
singletonSV x = SV { svDim = 1, svData = IM.singleton 0 x }
-- | Build a sparse vector from a map, dropping zero values and keys
-- outside @[0, d)@.
mkSpVector :: (Num a, Eq a) => Int -> IM.IntMap a -> SpVector a
mkSpVector d im = SV d (IM.filterWithKey keep im)
  where
    keep k v = v /= 0 && inBounds0 d k

-- | Build a sparse vector from the first @d@ elements of a dense list
-- (zeros are not stored).
mkSpVectorD :: (Num a, Eq a) => Int -> [a] -> SpVector a
mkSpVectorD d ll = mkSpVector d indexed
  where
    indexed = IM.fromList (denseIxArray (take d ll))

-- | Like 'mkSpVector' but without the non-zero filter: only keys outside
-- @[0, d)@ are dropped.
mkSpVector1 :: Int -> IM.IntMap a -> SpVector a
mkSpVector1 d ll = SV d (IM.filterWithKey keep ll)
  where
    keep k _ = inBounds0 d k

-- | Store the first @d@ elements of a dense list, zeros included.
fromListDenseSV :: Int -> [a] -> SpVector a
fromListDenseSV d ll = SV d indexed
  where
    indexed = IM.fromList (denseIxArray (take d ll))
-- | Vector of dimension @d@ with an explicit 1 stored at every index.
onesSV :: Num a => Int -> SpVector a
onesSV d = SV d (IM.fromList (denseIxArray (replicate d 1)))

-- | Vector of dimension @d@ with an explicit 0 stored at every index.
zerosSV :: Num a => Int -> SpVector a
zerosSV d = SV d (IM.fromList (denseIxArray (replicate d 0)))
-- | Insert (or overwrite) the value at index @i@.
-- Fails with 'error' when @i@ is outside @[0, dimension)@.
insertSpVector :: Int -> a -> SpVector a -> SpVector a
insertSpVector i x (SV d xim) =
  if inBounds0 d i
    then SV d (IM.insert i x xim)
    else error "insertSpVector : index out of bounds"
-- | Build a sparse vector from (index, value) pairs; pairs whose index is
-- outside @[0, d)@ are discarded.
fromListSV :: Int -> [(Int, a)] -> SpVector a
fromListSV d iix = SV d (IM.fromList inRange)
  where
    inRange = [p | p@(i, _) <- iix, inBounds0 d i]

-- | Stored (index, value) pairs in ascending index order.
toListSV :: SpVector a -> [(IM.Key, a)]
toListSV = IM.toList . dat
-- | Dense list of all @d@ components (indices @0 .. d - 1@); components
-- that are not stored read as 0.
toDenseListSV :: Num b => SpVector b -> [b]
toDenseListSV (SV d im) = [IM.findWithDefault 0 i im | i <- [0 .. d - 1]]
-- | Shows the dimension followed by the stored (index, value) pairs.
instance Show a => Show (SpVector a) where
  show (SV d x) = concat ["SV (", show d, ") ", show (IM.toList x)]

-- | Component at index @i@, reading 0 when nothing is stored there.
-- No bounds check is performed.
lookupDenseSV :: Num a => IM.Key -> SpVector a -> a
lookupDenseSV i sv = IM.findWithDefault 0 i (svData sv)
-- | Drop the component at index 0: the dimension shrinks by one and every
-- remaining index is shifted down by one.
tailSV :: SpVector a -> SpVector a
tailSV (SV n sv) = SV (n - 1) shifted
  where
    shifted = IM.mapKeys (subtract 1) (IM.delete 0 sv)
-- | Component at index 0, or 0 when it is not stored.
headSV :: Num a => SpVector a -> a
headSV sv = IM.findWithDefault 0 0 (dat sv)
-- | Concatenate two vectors: indices of the second are offset by the
-- dimension of the first and the dimensions are added.
concatSV :: SpVector a -> SpVector a -> SpVector a
concatSV (SV n1 s1) (SV n2 s2) = SV (n1 + n2) (IM.union s1 offsetSecond)
  where
    offsetSecond = IM.mapKeys (n1 +) s2
-- | Promote a sparse vector of dimension @n@ to an @n x 1@ column matrix.
--
-- Matrix data is keyed row-first (see 'lookupSM'), so each vector entry
-- @i@ becomes row @i@ holding a single value in column 0. Storing the
-- whole vector under row 0 would contradict the declared @(n, 1)@ shape.
svToSM :: SpVector a -> SpMatrix a
svToSM (SV n d) = SM (n, 1) (IM.map (IM.singleton 0) d)
-- | Outer product of two sparse vectors: an @m x n@ matrix whose entry
-- @(i, j)@ is the product of stored entry @i@ of the first vector and
-- stored entry @j@ of the second. '(><)' is an operator synonym.
outerProdSV, (><) :: Num a => SpVector a -> SpVector a -> SpMatrix a
outerProdSV v1 v2 = fromListSM (dim v1, dim v2) products
  where
    products = do
      (i, x) <- toListSV v1
      (j, y) <- toListSV v2
      return (i, j, x * y)
(><) = outerProdSV
-- | Sparse matrix: declared (rows, columns) plus a row-indexed map of
-- column-indexed maps.
data SpMatrix a = SM
  { smDim  :: (Rows, Cols)              -- ^ declared dimensions
  , smData :: IM.IntMap (IM.IntMap a)   -- ^ row index -> column index -> value
  } deriving Eq

-- | Shows the size summary ('sizeStr') followed by the raw row list.
instance Show a => Show (SpMatrix a) where
  show sm = concat ["SM ", sizeStr sm, " ", show (IM.toList (smData sm))]
-- | Map over every stored entry; dimensions are unchanged.
instance Functor SpMatrix where
  fmap f (SM d rows) = SM d (fmap (fmap f) rows)

-- | Union lifting takes the componentwise maximum of the dimensions,
-- intersection lifting the componentwise minimum.
instance Set SpMatrix where
  liftU2 f2 (SM n1 x1) (SM n2 x2) = SM (maxTup n1 n2) (liftU2 (liftU2 f2) x1 x2)
  liftI2 f2 (SM n1 x1) (SM n2 x2) = SM (minTup n1 n2) (liftI2 (liftI2 f2) x1 x2)

-- | Zero is the empty @0 x 0@ matrix.
instance Additive SpMatrix where
  zero = SM (0, 0) IM.empty
  a ^+^ b = liftU2 (+) a b

instance FiniteDim SpMatrix where
  type FDSize SpMatrix = (Rows, Cols)
  dim = smDim

instance HasData SpMatrix a where
  type HDData SpMatrix a = IM.IntMap (IM.IntMap a)
  dat = smData

instance Sparse SpMatrix a where
  spy = spySM
-- | Componentwise maximum / minimum of two pairs.
maxTup, minTup :: Ord t => (t, t) -> (t, t) -> (t, t)
maxTup (a1, b1) (a2, b2) = (a1 `max` a2, b1 `max` b2)
minTup (a1, b1) (a2, b2) = (a1 `min` a2, b1 `min` b2)
-- | Matrix of the given dimensions with no stored entries.
emptySpMatrix :: (Int, Int) -> SpMatrix a
emptySpMatrix d = SM { smDim = d, smData = IM.empty }

-- | Multiply every stored entry by the scalar (on the right).
matScale :: Num a => a -> SpMatrix a -> SpMatrix a
matScale a m = fmap (\x -> x * a) m
-- | Frobenius norm: square root of the sum of squared entries.
--
-- This equals @sqrt (trace (A^T A))@. Summing /all/ entries of @A^T A@
-- (rather than its trace) gives a different, generally wrong, value, so
-- the squares are accumulated directly over the stored entries.
normFrobenius :: SpMatrix Double -> Double
normFrobenius m = sqrt (foldlSM addSquare 0 m)
  where
    -- first argument is the matrix entry, second the running total
    addSquare x acc = acc + x * x
-- | Number of rows of a matrix (first component of its dimensions).
type Rows = Int
-- | Number of columns of a matrix (second component of its dimensions).
type Cols = Int
-- | Row index.
type IxRow = Int
-- | Column index.
type IxCol = Int
-- | Is the (row, column) pair inside the matrix dimensions?
validIxSM :: SpMatrix a -> (Int, Int) -> Bool
validIxSM mm ij = inBounds02 (dim mm) ij

-- | Does the matrix have as many rows as columns?
isSquareSM :: SpMatrix a -> Bool
isSquareSM m = r == c
  where
    (r, c) = dim m
-- | True when every one of the @nrows@ rows stores exactly one entry and
-- that entry sits on the diagonal.
isDiagonalSM :: SpMatrix a -> Bool
isDiagonalSM m = IM.size diagonalRows == nrows m
  where
    diagonalRows = IM.filterWithKey onlyDiagonal (immSM m)
    -- a row qualifies when its single stored column equals its row index
    onlyDiagonal irow row = IM.size row == 1 && IM.member irow row
-- | Underlying nested map of the matrix.
immSM :: SpMatrix t -> IM.IntMap (IM.IntMap t)
immSM = smData

-- | Declared (rows, columns).
dimSM :: SpMatrix t -> (Rows, Cols)
dimSM = smDim

-- | Total number of positions, rows times columns.
nelSM :: SpMatrix t -> Int
nelSM m = nr * nc
  where
    (nr, nc) = smDim m

-- | Number of rows.
nrows :: SpMatrix a -> Rows
nrows m = fst (dim m)

-- | Number of columns.
ncols :: SpMatrix a -> Cols
ncols m = snd (dim m)
-- | Summary figures of a sparse matrix.
data SMInfo = SMInfo
  { smNz  :: Int     -- ^ number of stored entries
  , smSpy :: Double  -- ^ stored entries divided by total positions
  } deriving (Eq, Show)

-- | Collect the summary figures of a matrix.
infoSM :: SpMatrix a -> SMInfo
infoSM s = SMInfo { smNz = nzSM s, smSpy = spySM s }

-- | Number of stored entries, summed over all rows.
nzSM :: SpMatrix a -> Int
nzSM s = IM.foldr (\row acc -> IM.size row + acc) 0 (immSM s)

-- | Stored entries divided by the total number of positions.
spySM :: Fractional b => SpMatrix a -> b
spySM s = stored / total
  where
    stored = fromIntegral (nzSM s)
    total  = fromIntegral (nelSM s)
-- | Number of stored entries in row @i@ (0 for a row with no data).
-- Fails with 'error' when @i@ is not a valid row index.
nzRow :: SpMatrix a -> IM.Key -> Int
nzRow s i =
  if inBounds0 (nrows s) i
    then case IM.lookup i (immSM s) of
           Nothing  -> 0
           Just row -> IM.size row
    else error "nzRow : index out of bounds"
-- | Smallest row bandwidth of the matrix.
bwMinSM :: SpMatrix a -> Int
bwMinSM = fst . bwBoundsSM

-- | Largest row bandwidth of the matrix.
bwMaxSM :: SpMatrix a -> Int
bwMaxSM = snd . bwBoundsSM

-- | (minimum, maximum) bandwidth over all non-empty rows, where the
-- bandwidth of a row is @last stored column - first stored column + 1@.
--
-- The extremes are taken over the bandwidth /values/; selecting with
-- 'IM.findMin' / 'IM.findMax' would instead report the bandwidth of the
-- first and last row. Rows without entries are skipped because they have
-- no first/last column. Fails if the matrix stores no entries at all.
bwBoundsSM :: SpMatrix a -> (Int, Int)
bwBoundsSM s = (minimum widths, maximum widths)
  where
    widths = fmap rowWidth (IM.filter (not . IM.null) (immSM s))
    rowWidth row = fst (IM.findMax row) - fst (IM.findMin row) + 1 :: Int
-- | @m x n@ matrix with no stored entries.
zeroSM :: Int -> Int -> SpMatrix a
zeroSM m n = SM { smDim = (m, n), smData = IM.empty }

-- | Insert a value at position @(i, j)@.
-- Fails with 'error' when the position is outside the matrix dimensions.
insertSpMatrix :: IxRow -> IxCol -> a -> SpMatrix a -> SpMatrix a
insertSpMatrix i j x s =
  if inBounds02 d (i, j)
    then SM d (insertIM2 i j x (immSM s))
    else error "insertSpMatrix : index out of bounds"
  where
    d = dim s
-- | Insert every (row, column, value) triple into an existing matrix.
-- A strict left fold is used so no chain of pending inserts builds up
-- for long inputs. Out-of-bounds triples raise the error of
-- 'insertSpMatrix'.
fromListSM' :: Foldable t => t (IxRow, IxCol, a) -> SpMatrix a -> SpMatrix a
fromListSM' iix sm = F.foldl' ins sm iix
  where
    ins t (i, j, x) = insertSpMatrix i j x t

-- | Build an @m x n@ matrix from (row, column, value) triples.
fromListSM :: Foldable t => (Int, Int) -> t (IxRow, IxCol, a) -> SpMatrix a
fromListSM (m, n) iix = fromListSM' iix (zeroSM m n)
-- | Build a matrix with @m@ rows from a dense list; the column count is
-- the list length divided (integer division) by @m@. Element positions
-- are assigned by 'denseIxArray2'.
fromListDenseSM :: Int -> [a] -> SpMatrix a
fromListDenseSM m ll = fromListSM (m, cols) (denseIxArray2 m ll)
  where
    cols = length ll `div` m
-- | Every position of the matrix as a (row, column, value) triple, rows
-- varying slowest; positions without a stored entry read as 0.
toDenseListSM :: Num t => SpMatrix t -> [(IxRow, IxCol, t)]
toDenseListSM m =
  [ (i, j, m @@ (i, j))
  | i <- [0 .. nrows m - 1]
  , j <- [0 .. ncols m - 1]
  ]
-- | @n x n@ matrix with the given values on the main diagonal.
mkDiagonal :: Int -> [a] -> SpMatrix a
mkDiagonal n xs = mkSubDiagonal n 0 xs

-- | @n x n@ identity matrix.
eye :: Num a => Int -> SpMatrix a
eye n = mkDiagonal n ones
  where
    ones = replicate n 1
-- | @n x n@ matrix with the given values on the diagonal at offset @o@.
--
-- A non-negative offset pairs rows @0 ..@ with columns @o ..@ (above the
-- main diagonal); a negative offset pairs rows @|o| ..@ with columns
-- @0 ..@ (below it). Fails with 'error' when @|o| >= n@.
mkSubDiagonal :: Int -> Int -> [a] -> SpMatrix a
mkSubDiagonal n o xx
  | abs o < n = if o >= 0
                  then fz ii jj xx
                  else fz jj ii xx
  | otherwise = error "mkSubDiagonal : offset > dimension"
  where
    ii = [0 .. n - 1]
    jj = [abs o .. n - 1]
    fz a b x = fromListSM (n, n) (zip3 a b x)
-- | Extract the submatrix spanning rows @i1 .. i2@ and columns @j1 .. j2@
-- (both ranges inclusive). Indices of the result are rebased to start at 0
-- and rows left without entries are dropped from the data.
--
-- Fails with 'error' when any bound lies outside the matrix or when a
-- range is reversed (@i2 < i1@ or @j2 < j1@).
extractSubmatrixSM :: SpMatrix a -> (IxRow, IxCol) -> (IxRow, IxCol) -> SpMatrix a
extractSubmatrixSM (SM (r, c) im) (i1, i2) (j1, j2)
  | validRange = SM (m', n') imm'
  | otherwise =
      error $ "extractSubmatrixSM : invalid indexing " ++ show (i1, i2) ++ ", " ++ show (j1, j2)
  where
    imm' = mapKeysIM2 (\i -> i - i1) (\j -> j - j1) $
             IM.filter (not . IM.null) $
             ifilterIM2 inWindow im
    inWindow i j _ = i1 <= i && i <= i2 && j1 <= j && j <= j2
    (m', n') = (i2 - i1 + 1, j2 - j1 + 1)
    validRange =
      inBounds0 r i1 && inBounds0 r i2 &&
      inBounds0 c j1 && inBounds0 c j2 &&
      i2 >= i1 && j2 >= j1
-- | Demote a single-row or single-column matrix to a sparse vector.
--
-- * A @1 x n@ matrix yields the contents of row 0 (empty if row 0 stores
--   nothing).
-- * An @m x 1@ matrix yields, for every row that stores a value in
--   column 0, that value at the row's index. Taking only the first stored
--   row would lose all but one entry of a column.
--
-- Any other shape fails with 'error'.
toSV :: SpMatrix a -> SpVector a
toSV (SM (m, n) im)
  | m == 1 = SV n (IM.findWithDefault IM.empty 0 im)
  | n == 1 = SV m (IM.mapMaybe (IM.lookup 0) im)
  | otherwise = error $ "toSV : incompatible dimensions " ++ show (m,n)
-- | Column @j@ as an @nrows x 1@ matrix.
extractColSM :: SpMatrix a -> IxCol -> SpMatrix a
extractColSM sm j = extractSubmatrixSM sm (0, nrows sm - 1) (j, j)

-- | Column @j@ as a sparse vector.
extractCol :: SpMatrix a -> IxCol -> SpVector a
extractCol m j = toSV (extractColSM m j)

-- | Row @i@ as a @1 x ncols@ matrix.
extractRowSM :: SpMatrix a -> IxRow -> SpMatrix a
extractRowSM sm i = extractSubmatrixSM sm (i, i) (0, ncols sm - 1)

-- | Row @i@ as a sparse vector.
extractRow :: SpMatrix a -> IxRow -> SpVector a
extractRow m i = toSV (extractRowSM m i)
-- | Stack the second matrix below the first. Row indices of the second
-- are offset by the row count of the first; the result is as wide as the
-- wider operand. '(-=-)' is an operator synonym.
vertStackSM, (-=-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
vertStackSM top bottom = SM (totalRows, width) (IM.union (immSM top) shifted)
  where
    topRows   = nrows top
    totalRows = topRows + nrows bottom
    width     = max (ncols top) (ncols bottom)
    shifted   = IM.mapKeys (+ topRows) (immSM bottom)
(-=-) = vertStackSM

-- | Place the second matrix to the right of the first, implemented as a
-- vertical stack of the transposes, transposed back. '(-||-)' is an
-- operator synonym.
horizStackSM, (-||-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
horizStackSM left right =
  transposeSM (transposeSM left -=- transposeSM right)
(-||-) = horizStackSM
-- | Stored value at @(i, j)@, if any.
lookupSM :: SpMatrix a -> IxRow -> IxCol -> Maybe a
lookupSM (SM _ im) i j = do
  row <- IM.lookup i im
  IM.lookup j row

-- | Value at @(i, j)@, reading 0 when nothing is stored there.
-- '(@@)' is an operator synonym.
lookupWD_SM, (@@) :: Num a => SpMatrix a -> (IxRow, IxCol) -> a
lookupWD_SM sm (i, j) = maybe 0 id (lookupSM sm i j)

-- | Same as 'lookupWD_SM' but on the raw nested map.
lookupWD_IM :: Num a => IM.IntMap (IM.IntMap a) -> (IxRow, IxCol) -> a
lookupWD_IM im (i, j) = maybe 0 id (IM.lookup i im >>= IM.lookup j)

(@@) = lookupWD_SM
-- | Fold over the stored entries of a matrix (delegates to 'foldlIM2').
foldlSM :: (a -> b -> b) -> b -> SpMatrix a -> b
foldlSM f n m = foldlIM2 f n (smData m)

-- | Indexed fold over the stored entries (delegates to @ifoldlIM2'@).
ifoldlSM :: (IM.Key -> IM.Key -> a -> b -> b) -> b -> SpMatrix a -> b
ifoldlSM f n m = ifoldlIM2' f n (smData m)

-- | Delegates to 'countSubdiagonalNZ' on the matrix data.
countSubdiagonalNZSM :: SpMatrix a -> Int
countSubdiagonalNZSM = countSubdiagonalNZ . smData
-- | Main diagonal as a densely stored vector of length @nrows@: entry @i@
-- holds the matrix value at @(i, i)@, with 0 for positions that are not
-- stored.
extractDiagonalDSM :: Num a => SpMatrix a -> SpVector a
extractDiagonalDSM mm = fromListDenseSV n [mm @@ (i, i) | i <- [0 .. n - 1]]
  where
    n = nrows mm
-- | Delegates to 'subdiagIndices' on the matrix data.
subdiagIndicesSM :: SpMatrix a -> [(IM.Key, IM.Key)]
subdiagIndicesSM = subdiagIndices . smData

-- | Keep only entries whose magnitude is at least 'eps'.
sparsifyIM2 :: IM.IntMap (IM.IntMap Double) -> IM.IntMap (IM.IntMap Double)
sparsifyIM2 = ifilterIM2 keep
  where
    keep _ _ x = abs x >= eps

-- | Drop entries of a matrix whose magnitude is below 'eps'.
sparsifySM :: SpMatrix Double -> SpMatrix Double
sparsifySM m = m { smData = sparsifyIM2 (smData m) }

-- | Snap entries close to 0 or 1 to exactly 0 or 1 ('roundZeroOne'), then
-- drop the resulting near-zero entries.
roundZeroOneSM :: SpMatrix Double -> SpMatrix Double
roundZeroOneSM (SM d im) = sparsifySM (SM d rounded)
  where
    rounded = mapIM2 roundZeroOne im
-- | Transpose: swap the declared dimensions and transpose the data with
-- 'transposeIM2'. '(#^)' is an operator synonym.
transposeSM, (#^) :: SpMatrix a -> SpMatrix a
transposeSM m = SM (cols, rows) (transposeIM2 (smData m))
  where
    (rows, cols) = smDim m
(#^) = transposeSM
-- | Matrix-vector product: each stored row is dotted with the vector.
-- Fails with 'error' when the column count differs from the vector
-- dimension. '(#>)' is an operator synonym.
matVec, (#>) :: Num a => SpMatrix a -> SpVector a -> SpVector a
matVec (SM (nr, nc) mdata) (SV n sv) =
  if nc == n
    then SV nr (fmap (\row -> row `dot` sv) mdata)
    else error $ "matVec : mismatching dimensions " ++ show (nc, n)
(#>) = matVec

-- | Vector-matrix product: each column (row of the transposed data) is
-- dotted with the vector. Fails with 'error' when the vector dimension
-- differs from the row count. '(<#)' is an operator synonym.
vecMat, (<#) :: Num a => SpVector a -> SpMatrix a -> SpVector a
vecMat (SV n sv) (SM (nr, nc) mdata) =
  if n == nr
    then SV nc (fmap (\col -> col `dot` sv) (transposeIM2 mdata))
    else error $ "vecMat : mismatching dimensions " ++ show (n, nr)
(<#) = vecMat
-- | Matrix-matrix product. Fails with 'error' when the column count of
-- the first operand differs from the row count of the second.
-- '(##)' is an operator synonym.
matMat, (##) :: Num a => SpMatrix a -> SpMatrix a -> SpMatrix a
matMat m1 m2 =
  if c1 == r2
    then matMatU m1 m2
    else error $ "matMat : incompatible matrix sizes" ++ show (d1, d2)
  where
    d1@(_, c1) = dim m1
    d2@(r2, _) = dim m2

-- | Matrix product without the dimension check: every stored row of the
-- first operand is dotted with every column of the second.
matMatU :: Num a => SpMatrix a -> SpMatrix a -> SpMatrix a
matMatU m1 m2 = SM (nrows m1, ncols m2) (fmap rowTimesCols (immSM m1))
  where
    colsOfM2 = transposeIM2 (immSM m2)
    rowTimesCols vm1 = fmap (\col -> col `dot` vm1) colsOfM2
(##) = matMat
-- | Matrix product followed by removal of near-zero entries
-- ('sparsifySM'). '(#~#)' is an operator synonym.
matMatSparsified, (#~#) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
matMatSparsified m1 m2 = sparsifySM (m1 `matMat` m2)
(#~#) = matMatSparsified

-- | Sparsified product with the first operand transposed.
(#^#) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
(#^#) a b = transposeSM a #~# b

-- | Sparsified product with the second operand transposed.
(##^) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
(##^) a b = a #~# transposeSM b
-- | Is @A^T A@, after rounding entries near 0/1 with 'roundZeroOneSM',
-- equal to the identity of size @ncols@?
isOrthogonalSM :: SpMatrix Double -> Bool
isOrthogonalSM sm = gram == eye (ncols sm)
  where
    gram = roundZeroOneSM (transposeSM sm ## sm)
-- | Condition number estimate: ratio of the largest to the smallest
-- diagonal magnitude of the R factor of the QR decomposition.
--
-- Magnitudes are taken /before/ selecting the extremes: the diagonal of R
-- may contain negative values, and @abs (maximum u)@ / @abs (minimum u)@
-- would then pick the wrong entries. Fails with 'error' when the ratio is
-- infinite.
conditionNumberSM :: SpMatrix Double -> Double
conditionNumberSM m
  | isInfinite kappa = error "Infinite condition number : rank-deficient system"
  | otherwise = kappa
  where
    (_, r)  = qr m
    diagAbs = abs <$> extractDiagonalDSM r
    lmax    = maximum diagAbs
    lmin    = minimum diagAbs
    kappa   = lmax / lmin
-- | Householder-style matrix @I - beta * (x >< x)@, with @I@ sized by the
-- dimension of @x@.
hhMat :: Num a => a -> SpVector a -> SpMatrix a
hhMat beta x = identity ^-^ scale beta outer
  where
    identity = eye (dim x)
    outer    = x >< x

-- | 'hhMat' with @beta = 2@.
hhRefl :: SpVector Double -> SpMatrix Double
hhRefl x = hhMat 2.0 x
-- | Length of the hypotenuse, @sqrt (x^2 + y^2)@, computed as
-- @|x| * sqrt (1 + (y/x)^2)@ to limit intermediate overflow.
--
-- The ratio must be squared inside the square root; squaring the root
-- instead yields @|x| * (1 + y/x)@, which is not a length.
-- NOTE: @x == 0@ divides by zero; the 'Floating' constraint alone does not
-- allow testing for it here.
hypot :: Floating a => a -> a -> a
hypot x y = abs x * sqrt (1 + (y / x) ** 2)
-- | Sign of a number: 1 for positive, 0 for zero, -1 for negative.
sign :: (Ord a, Num a) => a -> a
sign x
  | x > 0 = 1
  | x == 0 = 0
  | otherwise = -1
-- | Coefficients @(c, s, r)@ of a Givens rotation for the pair @(a, b)@.
--
-- The division is always by the operand of larger magnitude; the two
-- degenerate cases (@b == 0@, @a == 0@) are handled first.
givensCoef :: (Ord a, Floating a) => a -> a -> (a, a, a)
givensCoef a b
  | b == 0 = (sign a, 0, abs a)
  | a == 0 = (0, sign b, abs b)
  | abs a > abs b =
      let t = b / a
          u = signedRoot a t
      in (1 / u, t / u, a * u)
  | otherwise =
      let t = a / b
          u = signedRoot b t
      in (t / u, 1 / u, b * u)
  where
    -- sqrt (1 + t^2) carrying the sign of the dominant operand
    signedRoot lead t = sign lead * abs (sqrt (1 + t ** 2))
-- | Givens rotation matrix for entry @(i, j)@ of a square matrix.
--
-- The result is the identity with the four entries @(i,i) = c@,
-- @(j,j) = c@, @(j,i) = -s@, @(i,j) = s@ overwritten; the off-diagonal
-- pair must have opposite signs for the matrix to be a rotation. The
-- coefficients come from 'givensCoef' applied to the entry in column @j@
-- of the first candidate row ('candidateRows') and to the entry @(i, j)@.
--
-- Fails with 'error' when the index is out of bounds, the matrix is not
-- square, or no candidate row exists.
givens :: SpMatrix Double -> IxRow -> IxCol -> SpMatrix Double
givens mm i j
  | validIxSM mm (i,j) && isSquareSM mm =
      sparsifySM $ fromListSM' [(i,i,c),(j,j,c),(j,i,-s),(i,j,s)] (eye (nrows mm))
  | otherwise = error "givens : indices out of bounds"
  where
    (c, s, _) = givensCoef a b
    i' = head $ fromMaybe (error $ "givens: no compatible rows for entry " ++ show (i,j)) (candidateRows (immSM mm) i j)
    a = mm @@ (i', j)
    b = mm @@ (i, j)
-- | Is @k@ a stored key with no stored key smaller than it, i.e. the
-- first stored column of the row?
firstNonZeroColumn :: IM.IntMap a -> IxRow -> Bool
firstNonZeroColumn mm k = IM.member k mm && isNothing (IM.lookupLT k mm)

-- | Rows other than @i@ whose first stored column is @j@;
-- 'Nothing' when there are none.
candidateRows :: IM.IntMap (IM.IntMap a) -> IxRow -> IxCol -> Maybe [IM.Key]
candidateRows mm i j =
  case IM.keys matching of
    [] -> Nothing
    ks -> Just ks
  where
    matching = IM.filterWithKey eligible mm
    eligible irow row = irow /= i && firstNonZeroColumn row j
-- | QR factorization by Givens rotations; returns @(Q, R)@.
--
-- 'gmats' yields the rotations @[G1, G2, .., Gk]@ in the order they are
-- applied to the matrix, so the triangular factor is
-- @R = Gk .. G2 G1 A@ and therefore @Q^T = Gk .. G2 G1@. Each new
-- rotation must be multiplied on the /left/ of the accumulated product;
-- multiplying on the right would build @G1 G2 .. Gk@ instead.
qr :: SpMatrix Double -> (SpMatrix Double, SpMatrix Double)
qr mm = (transposeSM qmatt, rmat)
  where
    qmatt = F.foldl' (\acc g -> g #~# acc) ee (gmats mm)
    rmat  = qmatt #~# mm
    ee    = eye (nrows mm)

-- | Givens rotations for the positions returned by 'subdiagIndicesSM', in
-- application order; each rotation is computed on the matrix already
-- transformed by its predecessors.
gmats :: SpMatrix Double -> [SpMatrix Double]
gmats mm = gm mm (subdiagIndicesSM mm)
  where
    gm m ((i,j):is) = let g = givens m i j
                      in g : gm (g #~# m) is
    gm _ [] = []
-- | Eigenvalue estimates by QR iteration (@A <- R Q@), returned as the
-- diagonal of the final iterate.
--
-- Iteration is driven by 'modifyInspectN' with the given iteration cap
-- and a test that the diagonals of the two retained iterates differ by
-- at most 'eps' in 2-norm.
eigsQR :: Int -> SpMatrix Double -> SpVector Double
eigsQR nitermax m = extractDiagonalDSM (execState iterateQR m)
  where
    iterateQR = modifyInspectN nitermax diagonalsSettled qrStep
    qrStep a = let (q, r) = qr a in r #~# q
    diagonalsSettled [m1, m2] =
      norm2 (extractDiagonalDSM m1 ^-^ extractDiagonalDSM m2) <= eps
-- | One step of Rayleigh quotient iteration.
--
-- Solves @(A - mu I) y = b@, normalizes @y@ to unit 2-norm to obtain the
-- new vector, and recomputes the Rayleigh quotient with it.
rayleighStep ::
  SpMatrix Double ->
  (SpVector Double, Double) ->
  (SpVector Double, Double)
rayleighStep aa (b, mu) = (bNext, muNext)
  where
    identity = eye (nrows aa)
    shifted  = aa ^-^ (mu `matScale` identity)
    solved   = shifted <\> b
    bNext    = normalize 2 solved
    muNext   = bNext `dot` (aa #> bNext) / (bNext `dot` bNext)
-- | Rayleigh quotient iteration from a starting (vector, eigenvalue)
-- guess, driven by 'modifyInspectN' with the given iteration cap; the
-- stopping test compares the vectors of the two retained iterates
-- against 'eps' in 2-norm.
eigRayleigh :: Int
            -> SpMatrix Double
            -> (SpVector Double, Double)
            -> (SpVector Double, Double)
eigRayleigh nitermax m start = execState loop start
  where
    loop = modifyInspectN nitermax vectorsSettled (rayleighStep m)
    vectorsSettled [(b1, _), (b2, _)] = norm2 (b2 ^-^ b1) <= eps
-- | Householder vector and coefficient @(v, beta)@ for @x@, following the
-- standard construction (Golub & Van Loan, "Householder vector"):
--
-- * @sigma@ is the squared norm of the tail of @x@; if it is negligible
--   the result is @([1, tail x], 0)@.
-- * Otherwise, with @mu = sqrt (x1^2 + sigma)@, the head of @v@ is
--   @x1 - mu@ when @x1 <= 0@ and @- sigma / (x1 + mu)@ otherwise (the two
--   are algebraically equal; the choice avoids cancellation),
--   @beta = 2 v1^2 / (sigma + v1^2)@, and @v@ is scaled by @1 / v1@.
--
-- NOTE(review): the previous text compared the head against 1 and used
-- the head of @x@ (not of @v@) in @beta@; both deviate from the textbook
-- formula and are corrected here — confirm against downstream users.
hhV :: SpVector Double -> (SpVector Double, Double)
hhV x = (v, beta)
  where
    tx = tailSV x
    sigma = tx `dot` tx
    vtemp = singletonSV 1 `concatSV` tx
    (v, beta)
      | sigma <= eps = (vtemp, 0)
      | otherwise =
          let xh = headSV x
              mu = sqrt (xh ** 2 + sigma)
              vh | xh <= 0 = xh - mu
                 | otherwise = - sigma / (xh + mu)
              vnew = (1 / vh) .* insertSpVector 0 vh vtemp
          in (vnew, 2 * vh ** 2 / (sigma + vh ** 2))
-- | Numerical tolerance used throughout the module.
eps :: Double
eps = 1e-8

-- | Residual @b - A x0@ of a linear system.
residual :: Num a => SpMatrix a -> SpVector a -> SpVector a -> SpVector a
residual aa b x0 = b ^-^ matVec aa x0

-- | Is the squared norm of the residual at most 'eps'?
converged :: SpMatrix Double -> SpVector Double -> SpVector Double -> Bool
converged aa b x0 = residualSq <= eps
  where
    residualSq = normSq (residual aa b x0)
-- | One iteration of the Conjugate Gradient Squared method for the matrix
-- @aa@ with shadow residual @rhat@.
cgsStep :: SpMatrix Double -> SpVector Double -> CGS -> CGS
cgsStep aa rhat (CGS x r p u) = CGS xNext rNext pNext uNext
  where
    ap    = aa #> p
    rDotH = r `dot` rhat
    alpha = rDotH / (ap `dot` rhat)
    q     = u ^-^ (alpha .* ap)
    uq    = u ^+^ q
    xNext = x ^+^ (alpha .* uq)
    rNext = r ^-^ (alpha .* (aa #> uq))
    beta  = (rNext `dot` rhat) / rDotH
    uNext = rNext ^+^ (beta .* q)
    pNext = uNext ^+^ (beta .* (q ^+^ (beta .* p)))

-- | State carried between CGS iterations.
data CGS = CGS
  { _x :: SpVector Double  -- ^ current solution estimate
  , _r :: SpVector Double  -- ^ current residual
  , _p :: SpVector Double  -- ^ search direction
  , _u :: SpVector Double  -- ^ auxiliary direction
  } deriving Eq
-- | Run CGS from the initial guess @x0@ with shadow residual @rhat@,
-- iterating 'cgsStep' under 'untilConverged' on the solution component.
-- The initial residual @b - A x0@ seeds the residual and both directions.
cgs ::
  SpMatrix Double ->
  SpVector Double ->
  SpVector Double ->
  SpVector Double ->
  CGS
cgs aa b x0 rhat = execState iteration initial
  where
    iteration = untilConverged _x (cgsStep aa rhat)
    r0        = b ^-^ (aa #> x0)
    initial   = CGS x0 r0 r0 r0

-- | One labelled line per state component.
instance Show CGS where
  show (CGS x r p u) =
    unlines
      [ "x = " ++ show x
      , "r = " ++ show r
      , "p = " ++ show p
      , "u = " ++ show u
      ]
-- | One iteration of the BiCGSTAB method for the matrix @aa@ with shadow
-- residual @r0hat@.
bicgstabStep :: SpMatrix Double -> SpVector Double -> BICGSTAB -> BICGSTAB
bicgstabStep aa r0hat (BICGSTAB x r p) = BICGSTAB xNext rNext pNext
  where
    ap    = aa #> p
    rDotH = r `dot` r0hat
    alpha = rDotH / (ap `dot` r0hat)
    s     = r ^-^ (alpha .* ap)
    as    = aa #> s
    omega = (as `dot` s) / (as `dot` as)
    xNext = x ^+^ (alpha .* p) ^+^ (omega .* s)
    rNext = s ^-^ (omega .* as)
    beta  = (rNext `dot` r0hat) / rDotH * alpha / omega
    pNext = rNext ^+^ (beta .* (p ^-^ (omega .* ap)))

-- | State carried between BiCGSTAB iterations.
data BICGSTAB = BICGSTAB
  { _xBicgstab :: SpVector Double  -- ^ current solution estimate
  , _rBicgstab :: SpVector Double  -- ^ current residual
  , _pBicgstab :: SpVector Double  -- ^ search direction
  } deriving Eq
-- | Run BiCGSTAB from the initial guess @x0@ with shadow residual
-- @r0hat@, iterating 'bicgstabStep' under 'untilConverged' on the
-- solution component. The initial residual @b - A x0@ seeds both the
-- residual and the search direction.
bicgstab
  :: SpMatrix Double
  -> SpVector Double
  -> SpVector Double
  -> SpVector Double
  -> BICGSTAB
bicgstab aa b x0 r0hat = execState iteration initial
  where
    iteration = untilConverged _xBicgstab (bicgstabStep aa r0hat)
    r0        = b ^-^ (aa #> x0)
    initial   = BICGSTAB x0 r0 r0

-- | One labelled line per state component.
instance Show BICGSTAB where
  show (BICGSTAB x r p) =
    unlines
      [ "x = " ++ show x
      , "r = " ++ show r
      , "p = " ++ show p
      ]
-- | Available iterative linear solvers.
data LinSolveMethod = CGS_ | BICGSTAB_ deriving (Eq, Show)

-- | Solve @A x = b@ with the chosen method, starting from a random
-- initial guess ('randVec') that is also used as the shadow residual.
--
-- 'CGS_' runs 'cgs' and 'BICGSTAB_' runs 'bicgstab' (the two were
-- previously dispatched to each other's solver). Fails with 'error' when
-- the column count of @A@ differs from the dimension of @b@.
linSolveM ::
  PrimMonad m =>
  LinSolveMethod -> SpMatrix Double -> SpVector Double -> m (SpVector Double)
linSolveM method aa b
  | n /= nb = error "linSolve : operand dimensions mismatch"
  | otherwise = do
      x0 <- randVec nb
      return $ case method of
        CGS_      -> _x (cgs aa b x0 x0)
        BICGSTAB_ -> _xBicgstab (bicgstab aa b x0 x0)
  where
    (_, n) = dim aa
    nb = dim b
-- | Solve @A x = b@ with the chosen method and a fixed initial guess
-- (every component 0.1), which is also used as the shadow residual.
--
-- A diagonal matrix is solved directly by multiplying @b@ with the
-- entrywise reciprocal of @A@. Otherwise 'CGS_' runs 'cgs' and
-- 'BICGSTAB_' runs 'bicgstab' (the two were previously dispatched to each
-- other's solver). Fails with 'error' when the column count of @A@
-- differs from the dimension of @b@.
linSolve ::
  LinSolveMethod -> SpMatrix Double -> SpVector Double -> SpVector Double
linSolve method aa b
  | n /= nb = error "linSolve : operand dimensions mismatch"
  | isDiagonalSM aa = reciprocal aa #> b
  | otherwise = case method of
      CGS_      -> _x (cgs aa b x0 x0)
      BICGSTAB_ -> _xBicgstab (bicgstab aa b x0 x0)
  where
    x0 = mkSpVectorD n $ replicate n 0.1
    (_, n) = dim aa
    nb = dim b
-- | Linear solve operator: 'linSolve' with the 'BICGSTAB_' method.
(<\>) :: SpMatrix Double -> SpVector Double -> SpVector Double
(<\>) aa b = linSolve BICGSTAB_ aa b
-- | Human-readable summary: rows, columns, stored entries and sparsity.
sizeStr :: SpMatrix a -> String
sizeStr sm =
  unwords
    [ "("
    , show (nrows sm)
    , "rows,"
    , show (ncols sm)
    , "columns ) ,"
    , show (smNz info)
    , "NZ ( sparsity"
    , show (smSpy info)
    , ")"
    ]
  where
    info = infoSM sm

-- | Show a number, rendering zero as a single blank.
showNonZero :: (Show a, Num a, Eq a) => a -> String
showNonZero x
  | x == 0 = " "
  | otherwise = show x
-- | Row @irow@ as a dense list over columns @0 .. ncol - 1@; positions
-- that are not stored read as 0.
toDenseRow :: Num a => SpMatrix a -> IM.Key -> [a]
toDenseRow (SM (_, ncol) im) irow =
  [im `lookupWD_IM` (irow, icol) | icol <- [0 .. ncol - 1]]
-- | Render row @irow@ as text. When the matrix has more than @ncomax@
-- columns only the first @ncomax - 2@ values are shown, followed by an
-- ellipsis and the last value; otherwise the whole dense row is shown.
toDenseRowClip :: (Show a, Num a) => SpMatrix a -> IM.Key -> Int -> String
toDenseRowClip sm irow ncomax
  | ncols sm > ncomax = unwords (map show h) ++ " ... " ++ show t
  | otherwise = show dr
  where
    dr = toDenseRow sm irow
    h  = take (ncomax - 2) dr
    t  = last dr
-- | Print an empty line.
newline :: IO ()
newline = putStrLn ""

-- | Print the size summary of a matrix followed by a clipped dense
-- rendering (at most 5 rows and 5 columns are considered before clipping).
-- When there are more rows than the limit, the first @limit - 2@ rows, an
-- ellipsis line and the last row are printed.
printDenseSM :: (Show t, Num t) => SpMatrix t -> IO ()
printDenseSM sm = do
  newline
  putStrLn $ sizeStr sm
  newline
  printDenseSM' sm 5 5
  newline
  where
    printDenseSM' :: (Show t, Num t) => SpMatrix t -> Int -> Int -> IO ()
    printDenseSM' sm'@(SM (nr,_) _) nromax ncomax = mapM_ putStrLn rr_'
      where
        rr_ = map (\i -> toDenseRowClip sm' i ncomax) [0 .. nr - 1]
        rr_' | nrows sm > nromax = take (nromax - 2) rr_ ++ [" ... "] ++[last rr_]
             | otherwise = rr_
-- | Render a sparse vector as text. When its dimension exceeds @ncomax@
-- only the first @ncomax - 2@ values are shown, followed by an ellipsis
-- and the last value; otherwise the whole dense list is shown.
toDenseListClip :: (Show a, Num a) => SpVector a -> Int -> String
toDenseListClip sv ncomax
  | dim sv > ncomax = unwords (map show h) ++ " ... " ++ show t
  | otherwise = show dr
  where
    dr = toDenseListSV sv
    h  = take (ncomax - 2) dr
    t  = last dr
-- | Print a clipped dense rendering of a sparse vector between two blank
-- lines. The text produced by 'toDenseListClip' (limit 5) is clipped a
-- second time at character level when the dimension exceeds the limit.
printDenseSV :: (Show t, Num t) => SpVector t -> IO ()
printDenseSV sv = do
  newline
  printDenseSV' sv 5
  newline
  where
    printDenseSV' v nco = putStrLn rr_'
      where
        rr_ = toDenseListClip v nco :: String
        rr_' | dim sv > nco = unwords [take (nco - 2) rr_ , " ... " , [last rr_]]
             | otherwise = rr_
-- | Types that can be pretty-printed in dense form.
class PrintDense a where
  -- | Print a dense rendering to standard output.
  prd :: a -> IO ()

instance (Show a, Num a) => PrintDense (SpVector a) where
  prd v = printDenseSV v

instance (Show a, Num a) => PrintDense (SpMatrix a) where
  prd m = printDenseSM m
-- | Repeatedly apply @f@ to the state until the predicate holds for the
-- updated state; that state is returned. @f@ is applied at least once.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = loop
  where
    loop = do
      modify f
      y <- get
      if q y
        then return y
        else loop
-- | Iterate @f@ from @x@, keeping a window of the two most recent
-- iterates (newest first). Once the window is full, stop and return the
-- current value when the predicate accepts the window or the step counter
-- equals @nitermax@.
loopUntilAcc :: Int -> ([t] -> Bool) -> (t -> t) -> t -> t
loopUntilAcc nitermax q f = go 0 []
  where
    go i window cur
      | length window < 2 = go (i + 1) (next : window) next
      | q window || i == nitermax = cur
      | otherwise = go (i + 1) (take 2 (next : window)) next
      where
        next = f cur
-- | Iterate a state update while inspecting a window of the two most
-- recent states (newest first).
--
-- Every pass applies @f@ to the state and stores the result. Once the
-- window holds two states, iteration ends as soon as the predicate
-- accepts the window (as it was before this pass) or the step counter
-- equals @nitermax@; the freshly stored state is returned.
-- Fails with 'error' when @nitermax@ is not positive.
modifyInspectN ::
  MonadState s m =>
  Int ->
  ([s] -> Bool) ->
  (s -> s) ->
  m s
modifyInspectN nitermax q f
  | nitermax <= 0 = error "modifyInspectN : n must be > 0"
  | otherwise = go 0 []
  where
    go i window = do
      x <- get
      let y = f x
      put y
      if length window < 2
        then go (i + 1) (y : window)
        else if q window || i == nitermax
          then return y
          else go (i + 1) (take 2 (y : window))
-- | Arithmetic mean: the sum scaled by the reciprocal of the length.
meanl :: (Foldable t, Fractional a) => t a -> a
meanl xx = weight * sum xx
  where
    weight = 1 / fromIntegral (length xx)

-- | Euclidean norm of the elements of a container.
norm2l :: (Foldable t, Functor t, Floating a) => t a -> a
norm2l xx = sqrt (sum squares)
  where
    squares = fmap (** 2) xx
-- | Squared difference of the first two elements of a list.
-- Fails with 'error' when the list has fewer than two elements.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = (x1 - x2) ** 2
diffSqL _ = error "diffSqL : list must have at least two elements"
-- | Iterate a state update with 'modifyInspectN' (cap of 100), stopping
-- when the projected vectors of the two retained states are close.
untilConverged :: MonadState a m => (a -> SpVector Double) -> (a -> a) -> m a
untilConverged fproj step = modifyInspectN 100 (normDiffConverged fproj) step

-- | Project every state to a vector, combine the projections with a right
-- fold of '(^-^)' seeded by the empty vector, and test whether the
-- squared norm of the result is at most 'eps'.
normDiffConverged :: (Foldable t, Functor t) =>
  (a -> SpVector Double) -> t a -> Bool
normDiffConverged fp xx = normSq difference <= eps
  where
    difference = foldrMap fp (^-^) (zeroSV 0) xx
-- | Iterate @ff@ from @x0@ for at most @niter@ steps, collecting the
-- iterates newest-first. Stops early when the predicate accepts the list
-- collected so far. Fails with 'error' for a negative @niter@.
runAppendN :: ([t] -> Bool) -> (t -> t) -> Int -> t -> [t]
runAppendN qq ff niter x0
  | niter < 0 = error "runAppendN : niter must be > 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 || qq xs = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)

-- | Like 'runAppendN' but without an early-exit predicate: exactly
-- @niter@ iterates are collected, newest first.
runAppendN' :: (t -> t) -> Int -> t -> [t]
runAppendN' ff niter x0
  | niter < 0 = error "runAppendN : niter must be > 0"
  | otherwise = go niter x0 []
  where
    go n z xs
      | n <= 0 = xs
      | otherwise = let x = ff z in go (n - 1) x (x : xs)
-- | Tolerance tests against 'eps': 'almostZero' accepts values with
-- magnitude at most 'eps'; 'almostOne' accepts values in the half-open
-- interval @[1 - eps, 1 + eps)@.
almostZero, almostOne :: Double -> Bool
almostZero x = abs x <= eps
almostOne x = x >= (1 - eps) && x < (1 + eps)
-- | Replace a value by the default when the predicate holds for it.
withDefault :: (t -> Bool) -> t -> t -> t
withDefault q d x = if q x then d else x

-- | Snap values near 0 (resp. near 1) to exactly 0 (resp. 1).
roundZero, roundOne :: Double -> Double
roundZero x = withDefault almostZero 0 x
roundOne x = withDefault almostOne 1 x

-- | Two-predicate variant of 'withDefault'; the first matching predicate
-- wins.
with2Defaults :: (t -> Bool) -> (t -> Bool) -> t -> t -> t -> t
with2Defaults q1 q2 d1 d2 x =
  if q1 x
    then d1
    else if q2 x
      then d2
      else x

-- | Snap values near 0 to 0 and values near 1 to 1; others are unchanged.
roundZeroOne :: Double -> Double
roundZeroOne x = with2Defaults almostZero almostOne 0 1 x
-- | Dense @n x n@ matrix of standard-normal samples, drawn from a
-- generator obtained with 'MWC.create'. Positions are filled with the row
-- index varying fastest.
randMat :: PrimMonad m => Int -> m (SpMatrix Double)
randMat n = do
  g <- MWC.create
  aav <- replicateM (n^2) (MWC.normal 0 1 g)
  let ii_ = [0 .. n - 1]
      (ix_,iy_) = unzip $ concatMap (zip ii_ . replicate n) ii_
  return $ fromListSM (n,n) $ zip3 ix_ iy_ aav

-- | Dense vector of @n@ standard-normal samples, drawn from a generator
-- obtained with 'MWC.create'.
randVec :: PrimMonad m => Int -> m (SpVector Double)
randVec n = do
  g <- MWC.create
  bv <- replicateM n (MWC.normal 0 1 g)
  let ii_ = [0 .. n - 1]
  return $ fromListSV n $ zip ii_ bv
-- | Random sparse @n x n@ matrix built from @nsp@ standard-normal values
-- placed at uniformly drawn positions in @[0, n - 1]@. Positions may
-- repeat, so the result can store fewer than @nsp@ entries.
-- Fails with 'error' when @nsp > n^2@.
randSpMat :: Int -> Int -> IO (SpMatrix Double)
randSpMat n nsp
  | nsp > n^2 = error "randSpMat : nsp must be < n^2 "
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      jj <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSM (n,n) $ zip3 ii jj aav

-- | Random sparse vector of dimension @n@ built from @nsp@
-- standard-normal values placed at uniformly drawn indices in
-- @[0, n - 1]@. Indices may repeat, so the result can store fewer than
-- @nsp@ entries. Fails with 'error' when @nsp > n@.
randSpVec :: Int -> Int -> IO (SpVector Double)
randSpVec n nsp
  | nsp > n = error "randSpVec : nsp must be < n"
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSV n $ zip ii aav
-- | Pair every list element with its 0-based position.
denseIxArray :: [b] -> [(Int, b)]
denseIxArray xs = zip [0 .. length xs - 1] xs

-- | Tag the elements of a list with (row, column) indices for a matrix
-- with @m@ rows: the row index cycles through @0 .. m - 1@ and the column
-- index advances after every @m@ elements. The column count is the list
-- length divided (integer division) by @m@.
denseIxArray2 :: Int -> [c] -> [(Int, Int, c)]
denseIxArray2 m xs = zip3 rowIxs colIxs xs
  where
    n      = length xs `div` m
    rowIxs = concat (replicate n [0 .. m - 1])
    colIxs = concatMap (replicate m) [0 .. n - 1]
-- | Map every element, then right-fold the mapped values.
foldrMap :: (Foldable t, Functor t) => (a -> b) -> (b -> c -> c) -> c -> t a -> c
foldrMap ff gg x0 xs = foldr gg x0 (fmap ff xs)

-- | Left fold over a list that forces each new accumulator before
-- continuing.
foldlStrict :: (a -> b -> a) -> a -> [b] -> a
foldlStrict f = go
  where
    go acc [] = acc
    go acc (y : ys) = (go $! f acc y) ys

-- | Indexed right fold over a list: each element is transformed together
-- with its 0-based position and a fixed extra argument, and the results
-- are joined right-to-left starting from the neutral value.
ifoldr :: Num i =>
  (a -> b -> b) -> b -> (i -> c -> d -> a) -> c -> [d] -> b
ifoldr mjoin mneutral f = step 0
  where
    step _ _ [] = mneutral
    step i z (y : ys) = f i z y `mjoin` step (i + 1) z ys
-- | Lower bound (inclusive) of an index range.
type LB = Int
-- | Upper bound (exclusive) of an index range.
type UB = Int

-- | Is the index inside the half-open range @[ibl, ibu)@?
inBounds :: LB -> UB -> Int -> Bool
inBounds ibl ibu i = ibl <= i && i < ibu

-- | Are both indices inside the same half-open range?
inBounds2 :: (LB, UB) -> (Int, Int) -> Bool
inBounds2 (ibl, ibu) (ix, iy) = all (inBounds ibl ibu) [ix, iy]

-- | Is the index inside @[0, ub)@?
inBounds0 :: UB -> Int -> Bool
inBounds0 ub i = inBounds 0 ub i

-- | Is each index inside @[0, bound)@ for its own upper bound?
inBounds02 :: (UB, UB) -> (Int, Int) -> Bool
inBounds02 (bx, by) (i, j) = inBounds0 bx i && inBounds0 by j