module Math.Linear.Sparse where
import Math.Linear.Sparse.IntMap
import Control.Monad.Primitive
import Control.Monad (mapM_, forM_, replicateM)
import Control.Monad.Loops
import Control.Monad.Cont
import Control.Monad.State.Strict
import Control.Monad.Writer
import Control.Monad.Trans
import Control.Monad.Trans.State (runStateT)
import Control.Monad.Trans.Writer (runWriterT)
import qualified Data.IntMap.Strict as IM
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC
import Data.Monoid
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import qualified Data.List as L
import Data.Maybe
-- | Containers supporting componentwise addition and subtraction.
--
-- 'zero' is the additive identity. '^-^' has a default defined as
-- "add the negation", so instances only override it for speed.
class Functor f => Additive f where
  zero :: Num a => f a
  (^+^) :: Num a => f a -> f a -> f a
  (^-^) :: Num a => f a -> f a -> f a
  -- Adding the negation is correct even for sparse containers, where a
  -- union-based lift of (-) would leave right-only entries un-negated.
  x ^-^ y = x ^+^ negated y

-- | Negate every component.
negated :: (Num a, Functor f) => f a -> f a
negated = fmap negate

-- | Prefix-friendly subtraction: @x `minus` y = x ^+^ negated y@.
minus :: (Additive f, Num a) => f a -> f a -> f a
x `minus` y = x ^+^ negated y
-- | Additive structures that can be scaled by a scalar.
class Additive f => VectorSpace f where
  -- | Scalar multiplication.
  (.*) :: Num a => a -> f a -> f a

-- | Linear interpolation @a*u + (1 - a)*v@; @lerp 1 u v == u@ and
-- @lerp 0 u v == v@.
lerp :: (VectorSpace f, Num a) => a -> f a -> f a -> f a
lerp a u v = (a .* u) ^+^ ((1 - a) .* v)
-- | Vector spaces equipped with an inner product.
class VectorSpace f => Hilbert f where
  -- | Inner product of two vectors.
  dot :: Num a => f a -> f a -> a

-- | Spaces with a norm selected by its first argument (the IntMap
-- instance treats 1 and 2 specially and falls back to 'normP').
class Hilbert f => Normed f where
  norm :: (Floating a, Eq a) => a -> f a -> a
-- | Squared Euclidean norm: the inner product of a vector with itself.
normSq :: (Hilbert f, Num a) => f a -> a
normSq v = dot v v

-- | L1 norm: sum of the absolute values of the components.
norm1 :: (Foldable t, Num a, Functor t) => t a -> a
norm1 = sum . fmap abs

-- | Euclidean (L2) norm.
norm2 :: (Hilbert f, Floating a) => f a -> a
norm2 = sqrt . normSq
-- | Lp norm: @(sum_i |x_i|^p) ** (1/p)@.
--
-- Absolute values are taken before raising to @p@. Without them, a
-- negative component with an odd or fractional @p@ contributes a
-- negative or NaN term.
normP :: (Foldable t, Functor t, Floating a) => a -> t a -> a
normP p v = sum (fmap ((** p) . abs) v) ** (1 / p)

-- | Infinity norm: the largest absolute value among the components.
-- Returns 0 for an empty container instead of failing.
normInfty :: (Foldable t, Num a, Ord a) => t a -> a
normInfty = F.foldl' (\acc x -> max acc (abs x)) 0
-- | Rescale a vector to unit length in the norm selected by @n@.
-- A zero-norm input is not guarded against.
normalize :: (Normed f, Floating a, Eq a) => a -> f a -> f a
normalize n v = factor .* v
  where
    factor = 1 / norm n v
-- | Lp-style pairing @(sum_i (a_i * b_i)^p) ** (1/p)@, taken over the
-- positions present in both containers (via 'liftI2').
dotLp :: (Set t, Foldable t, Floating a) => a -> t a -> t a -> a
dotLp p v1 v2 = total ** (1 / p)
  where
    total = sum (liftI2 (\a b -> (a * b) ** p) v1 v2)
-- | Elementwise reciprocal.
reciprocal :: (Functor f, Fractional b) => f b -> f b
reciprocal v = fmap recip v

-- | Multiply every element by @n@ (as @x * n@).
scale :: (Num b, Functor f) => b -> f b -> f b
scale n v = fmap (\x -> x * n) v
-- | Containers that carry an explicit dimension.
class Additive f => FiniteDim f where
  -- | How the dimension is represented (a length, or a shape tuple).
  type FDSize f :: *
  dim :: f a -> FDSize f

-- | A sparse vector's dimension is its logical length.
instance FiniteDim SpVector where
  type FDSize SpVector = Int
  dim (SV n _) = n

-- | A sparse matrix's dimension is its (rows, columns) shape.
instance FiniteDim SpMatrix where
  type FDSize SpMatrix = (Rows, Cols)
  dim (SM shape _) = shape
-- | Apply @f@ to @x@ when the predicate @p@ accepts @x@'s dimension.
-- Otherwise fail with message @e@ followed by @show (ef x)@.
withDim :: (FiniteDim f, Show e)
        => f a
        -> (FDSize f -> f a -> Bool)
        -> (f a -> c)
        -> String
        -> (f a -> e)
        -> c
withDim x p f e ef =
  if p (dim x) x
    then f x
    else error (e ++ show (ef x))

-- | Two-argument analogue of 'withDim'.
withDim2 :: (FiniteDim f, FiniteDim g, Show e)
         => f a
         -> g b
         -> (FDSize f -> FDSize g -> f a -> g b -> Bool)
         -> (f a -> g b -> c)
         -> String
         -> (f a -> g b -> e)
         -> c
withDim2 x y p f e ef =
  if p (dim x) (dim y) x y
    then f x y
    else error (e ++ show (ef x y))
-- | Containers that expose their underlying storage.
class Additive f => HasData f a where
  type HDData f a :: *
  dat :: f a -> HDData f a

instance HasData SpVector a where
  type HDData SpVector a = IM.IntMap a
  dat (SV _ storage) = storage

instance HasData SpMatrix a where
  type HDData SpMatrix a = IM.IntMap (IM.IntMap a)
  dat (SM _ storage) = storage

-- | Sparse containers, which can report their density.
class (FiniteDim f, HasData f a) => Sparse f a where
  -- | Stored entries divided by the total number of positions.
  spy :: Fractional b => f a -> b

instance Sparse SpVector a where
  spy v = spySV v

instance Sparse SpMatrix a where
  spy m = spySM m
-- | Containers that can be zipped over the union or the intersection
-- of their keys.
class Functor f => Set f where
  -- | Combine over the union of keys.
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
  -- | Combine over the intersection of keys.
  liftI2 :: (a -> b -> c) -> f a -> f b -> f c

instance Set IM.IntMap where
  liftU2 f x y = IM.unionWith f x y
  liftI2 f x y = IM.intersectionWith f x y
-- | IntMaps add by key union; a key missing on one side acts as zero.
instance Additive IM.IntMap where
  zero = IM.empty
  x ^+^ y = IM.unionWith (+) x y
  x ^-^ y = x ^+^ negated y

instance VectorSpace IM.IntMap where
  (.*) n = IM.map (* n)

-- | Inner product over the keys stored in both maps.
instance Hilbert IM.IntMap where
  dot a b = sum (IM.intersectionWith (*) a b)

-- | p = 1 and p = 2 use the dedicated norms; any other p uses 'normP'.
instance Normed IM.IntMap where
  norm p v =
    if p == 1 then norm1 v
    else if p == 2 then norm2 v
    else normP p v
-- | Sparse vector: a logical length plus an index -> value map.
-- Indices absent from the map read as zero in the dense views.
data SpVector a = SV
  { svDim  :: Int
  , svData :: IM.IntMap a
  } deriving Eq

-- | Logical length of the vector.
dimSV :: SpVector a -> Int
dimSV (SV n _) = n

-- | Density: the number of stored entries divided by the length.
spySV :: Fractional b => SpVector a -> b
spySV (SV n storage) = fromIntegral (IM.size storage) / fromIntegral n

-- | The underlying index -> value map.
imSV :: SpVector a -> IM.IntMap a
imSV (SV _ storage) = storage
-- | Map over stored entries; the length is unchanged.
instance Functor SpVector where
  fmap f v = v { svData = fmap f (svData v) }

-- | Union and intersection zips both keep the larger of the two lengths.
instance Set SpVector where
  liftU2 f (SV na xa) (SV nb xb) = SV (max na nb) (IM.unionWith f xa xb)
  liftI2 f (SV na xa) (SV nb xb) = SV (max na nb) (IM.intersectionWith f xa xb)

-- | Folds visit only the stored entries.
instance Foldable SpVector where
  foldr f z (SV _ storage) = F.foldr f z storage
-- | Sparse-vector addition and subtraction over the union of stored
-- indices; the result keeps the larger length.
instance Additive SpVector where
  zero = SV 0 IM.empty
  (^+^) = liftU2 (+)
  -- Subtract by adding the negation: a union with (-) would copy
  -- entries present only in the right operand through un-negated.
  x ^-^ y = x ^+^ negated y
instance VectorSpace SpVector where
  (.*) = scale

-- | Inner product of equal-length vectors; fails when lengths differ.
instance Hilbert SpVector where
  dot (SV na da) (SV nb db)
    | na /= nb = error $ "dot : sizes must coincide, instead we got " ++
                 show (na, nb)
    | otherwise = dot da db

-- | Norms are computed on the stored entries.
instance Normed SpVector where
  norm p = norm p . svData
-- | All-zero vector of length @n@ (nothing stored).
zeroSV :: Int -> SpVector a
zeroSV n = SV n IM.empty

-- | Length-1 vector holding @x@ at index 0.
singletonSV :: a -> SpVector a
singletonSV x = SV 1 (IM.singleton 0 x)

-- | Length-@d@ vector from a map, dropping zeros and out-of-range keys.
mkSpVector :: (Num a, Eq a) => Int -> IM.IntMap a -> SpVector a
mkSpVector d im = SV d (IM.filterWithKey keep im)
  where
    keep k v = v /= 0 && inBounds0 d k

-- | 'mkSpVector' from a dense list (only the first @d@ elements are used).
mkSpVectorD :: (Num a, Eq a) => Int -> [a] -> SpVector a
mkSpVectorD d ll = mkSpVector d . IM.fromList . denseIxArray $ take d ll

-- | Length-@d@ vector keeping every in-range entry, zeros included.
mkSpVector1 :: Int -> IM.IntMap a -> SpVector a
mkSpVector1 d ll = SV d (IM.filterWithKey (\k _ -> inBounds0 d k) ll)

-- | 'mkSpVector1' from a dense list (only the first @d@ elements are used).
mkSpVector1D :: Int -> [a] -> SpVector a
mkSpVector1D d ll = mkSpVector1 d . IM.fromList . denseIxArray $ take d ll

-- | Length-@d@ vector with every entry stored as 1.
onesSV :: Num a => Int -> SpVector a
onesSV d = SV d (IM.fromList (denseIxArray (replicate d 1)))

-- | Length-@d@ vector with every entry stored as an explicit 0.
zerosSV :: Num a => Int -> SpVector a
zerosSV d = SV d (IM.fromList (denseIxArray (replicate d 0)))
-- | Store @x@ at index @i@; fails when @i@ is outside @[0, d)@.
insertSpVector :: Int -> a -> SpVector a -> SpVector a
insertSpVector i x (SV d xim) =
  if inBounds0 d i
    then SV d (IM.insert i x xim)
    else error "insertSpVector : index out of bounds"

-- | Length-@d@ vector from (index, value) pairs. Out-of-range indices
-- are dropped silently, and the last value wins for a repeated index.
fromListSV :: Int -> [(Int, a)] -> SpVector a
fromListSV d iix = SV d . IM.fromList $ [pair | pair@(i, _) <- iix, inBounds0 d i]

-- | Stored entries as (index, value) pairs in ascending index order.
toListSV :: SpVector a -> [(IM.Key, a)]
toListSV = IM.toList . imSV
-- | All @d@ components in order, with absent entries read as 0.
toDenseListSV :: Num b => SpVector b -> [b]
toDenseListSV (SV d im) = [IM.findWithDefault 0 i im | i <- [0 .. d - 1]]

instance Show a => Show (SpVector a) where
  show (SV d x) = "SV (" ++ show d ++ ") "++ show (IM.toList x)

-- | Component @i@, or 0 when it is not stored.
lookupDenseSV :: Num a => IM.Key -> SpVector a -> a
lookupDenseSV i (SV _ im) = IM.findWithDefault 0 i im

-- | Map lookup that defaults to 0.
findWithDefault0IM :: Num a => IM.Key -> IM.IntMap a -> a
findWithDefault0IM = IM.findWithDefault 0
-- | Drop component 0 and shift the remaining indices down by one.
-- The tail of a length-0 vector has length 0 (the length is clamped).
tailSV :: SpVector a -> SpVector a
tailSV (SV n sv) = SV (max 0 (n - 1)) ta
  where
    -- (subtract 1) is strictly increasing, so the monotonic rekey is safe.
    ta = IM.mapKeysMonotonic (subtract 1) $ IM.delete 0 sv

-- | Component 0, or 0 when it is not stored.
headSV :: Num a => SpVector a -> a
headSV sv = fromMaybe 0 (IM.lookup 0 (imSV sv))

-- | Concatenate two vectors; the second one's indices are shifted by
-- the first one's length.
concatSV :: SpVector a -> SpVector a -> SpVector a
concatSV (SV n1 s1) (SV n2 s2) = SV (n1 + n2) (IM.union s1 s2')
  where
    s2' = IM.mapKeysMonotonic (+ n1) s2
-- | View a length-@n@ vector as an @n x 1@ column matrix: entry @i@
-- becomes row @i@, column 0. Earlier code stored the whole vector as
-- row 0, which contradicted the declared (n, 1) shape.
svToSM :: SpVector a -> SpMatrix a
svToSM (SV n d) = SM (n, 1) (IM.map (IM.singleton 0) d)

-- | Outer product @v1 v2^T@, an @(dim v1) x (dim v2)@ matrix built from
-- the stored entries of each vector.
outerProdSV, (><) :: Num a => SpVector a -> SpVector a -> SpMatrix a
outerProdSV v1 v2 = fromListSM (m, n) ixy
  where
    m = dim v1
    n = dim v2
    ixy = [(i, j, x * y) | (i, x) <- toListSV v1, (j, y) <- toListSV v2]
(><) = outerProdSV
-- | Sparse matrix: a (rows, columns) shape plus a nested map from row
-- index to (column index -> value). Missing entries are zero.
data SpMatrix a = SM
  { smDim  :: (Rows, Cols)
  , smData :: IM.IntMap (IM.IntMap a)
  } deriving Eq

instance Show a => Show (SpMatrix a) where
  show sm = "SM " ++ sizeStr sm ++ " " ++ show (IM.toList (smData sm))

-- | Map over stored entries; the shape is unchanged.
instance Functor SpMatrix where
  fmap f (SM d md) = SM d (fmap (fmap f) md)

-- | Union zips keep the larger shape; intersection zips the smaller one.
instance Set SpMatrix where
  liftU2 f (SM da xa) (SM db xb) = SM (maxTup da db) (liftU2 (liftU2 f) xa xb)
  liftI2 f (SM da xa) (SM db xb) = SM (minTup da db) (liftI2 (liftI2 f) xa xb)
-- | Matrix addition and subtraction over the union of stored entries;
-- the result has the larger shape.
instance Additive SpMatrix where
  zero = SM (0,0) IM.empty
  (^+^) = liftU2 (+)
  -- Subtract by adding the negation: a union with (-) would leave
  -- entries present only in the right operand un-negated.
  x ^-^ y = x ^+^ negated y
-- | Componentwise maximum / minimum of two pairs (used to merge shapes).
maxTup, minTup :: Ord t => (t, t) -> (t, t) -> (t, t)
maxTup (a1, b1) (a2, b2) = (a1 `max` a2, b1 `max` b2)
minTup (a1, b1) (a2, b2) = (a1 `min` a2, b1 `min` b2)
-- | Matrix of the given shape with nothing stored.
emptySpMatrix :: (Int, Int) -> SpMatrix a
emptySpMatrix d = SM d IM.empty

-- | Multiply every stored entry by @a@.
matScale :: Num a => a -> SpMatrix a -> SpMatrix a
matScale a = fmap (*a)

-- | Frobenius norm: @sqrt (sum_ij a_ij^2)@, i.e. @sqrt (trace (A^T A))@.
-- The old code summed every entry of @A^T A@, off-diagonal ones
-- included, which is not the Frobenius norm.
normFrobenius :: SpMatrix Double -> Double
normFrobenius m = sqrt $ foldlSM (\x acc -> acc + x * x) 0 m
type Rows = Int
type Cols = Int
type IxRow = Int
type IxCol = Int

-- | Is (i, j) a valid 0-based index for this matrix's shape?
validIxSM :: SpMatrix a -> (Int, Int) -> Bool
validIxSM mm ij = inBounds02 (dim mm) ij

-- | Does the matrix have as many rows as columns?
isSquareSM :: SpMatrix a -> Bool
isSquareSM m = let (r, c) = dim m in r == c
-- | True when the number of rows whose only stored entry lies on the
-- diagonal equals 'nrows'. An unstored diagonal position therefore
-- makes this False.
isDiagonalSM :: SpMatrix a -> Bool
isDiagonalSM m = IM.size diagRows == nrows m
  where
    diagRows = IM.filterWithKey onlyDiagonal (immSM m)
    onlyDiagonal irow row = IM.keys row == [irow]

-- | The nested row -> column -> value map.
immSM :: SpMatrix t -> IM.IntMap (IM.IntMap t)
immSM = smData

-- | The (rows, columns) shape.
dimSM :: SpMatrix t -> (Rows, Cols)
dimSM = smDim

-- | Total number of positions, rows * columns.
nelSM :: SpMatrix t -> Int
nelSM m = uncurry (*) (smDim m)

-- | Number of rows / columns.
nrows, ncols :: SpMatrix a -> Int
nrows m = fst (dim m)
ncols m = snd (dim m)
-- | Summary of a matrix: stored-entry count and density.
data SMInfo = SMInfo
  { smNz  :: Int
  , smSpy :: Double
  } deriving (Eq, Show)

infoSM :: SpMatrix a -> SMInfo
infoSM s = SMInfo { smNz = nzSM s, smSpy = spySM s }

-- | Number of stored entries.
nzSM :: SpMatrix a -> Int
nzSM s = IM.foldr (\row acc -> IM.size row + acc) 0 (immSM s)

-- | Density: stored entries divided by rows * columns.
spySM :: Fractional b => SpMatrix a -> b
spySM s = fromIntegral (nzSM s) / fromIntegral (nelSM s)

-- | Stored entries in row @i@ (no bounds check).
nzRowU :: SpMatrix a -> IM.Key -> Int
nzRowU s i = IM.size (IM.findWithDefault IM.empty i (immSM s))

-- | Stored entries in row @i@; fails when @i@ is outside the row range.
nzRow :: SpMatrix a -> IM.Key -> Int
nzRow s i =
  if inBounds0 (nrows s) i
    then nzRowU s i
    else error "nzRow : index out of bounds"
-- | Smallest row bandwidth (see 'bwBoundsSM').
bwMinSM :: SpMatrix a -> Int
bwMinSM = fst . bwBoundsSM

-- | Largest row bandwidth (see 'bwBoundsSM').
bwMaxSM :: SpMatrix a -> Int
bwMaxSM = snd . bwBoundsSM

-- | (min, max) of the row bandwidths, where a row's bandwidth is
-- @lastStoredColumn - firstStoredColumn + 1@.
--
-- Rows with no stored entries are skipped; a matrix with no stored
-- entries at all is an error. The old version picked the first and
-- last rows by key instead of the extreme values.
bwBoundsSM :: SpMatrix a -> (Int, Int)
bwBoundsSM s
  | IM.null b = error "bwBoundsSM : matrix has no stored entries"
  | otherwise = (F.minimum b, F.maximum b)
  where
    b = fmap rowWidth . IM.filter (not . IM.null) $ immSM s
    rowWidth row = fst (IM.findMax row) - fst (IM.findMin row) + 1
-- | m x n matrix with nothing stored.
zeroSM :: Int -> Int -> SpMatrix a
zeroSM m n = SM (m,n) IM.empty

-- | Store @x@ at (i, j); fails when the index is outside the shape.
insertSpMatrix :: IxRow -> IxCol -> a -> SpMatrix a -> SpMatrix a
insertSpMatrix i j x s
  | inBounds02 d (i,j) = SM d $ insertIM2 i j x smd
  | otherwise = error "insertSpMatrix : index out of bounds"
  where
    smd = immSM s
    d = dim s

-- | Insert (row, column, value) triples into an existing matrix, left to
-- right. A strict left fold avoids building a long chain of thunks.
fromListSM' :: Foldable t => t (IxRow, IxCol, a) -> SpMatrix a -> SpMatrix a
fromListSM' iix sm = F.foldl' ins sm iix
  where
    ins t (i,j,x) = insertSpMatrix i j x t

-- | m x n matrix from (row, column, value) triples.
fromListSM :: Foldable t => (Int, Int) -> t (IxRow, IxCol, a) -> SpMatrix a
fromListSM (m,n) iix = fromListSM' iix (zeroSM m n)

-- | Matrix with @m@ rows from a column-major dense list; the column
-- count is @length ll `div` m@.
fromListDenseSM :: Int -> [a] -> SpMatrix a
fromListDenseSM m ll = fromListSM (m, n) $ denseIxArray2 m ll
  where
    n = length ll `div` m
-- | Every (i, j, value) triple in row-major order, with absent entries as 0.
toDenseListSM :: Num t => SpMatrix t -> [(IxRow, IxCol, t)]
toDenseListSM m =
  [(i, j, m @@ (i, j)) | i <- [0 .. nrows m - 1], j <- [0 .. ncols m - 1]]
-- | n x n matrix with the list on the main diagonal.
mkDiagonal :: Int -> [a] -> SpMatrix a
mkDiagonal n = mkSubDiagonal n 0

-- | n x n identity matrix.
eye :: Num a => Int -> SpMatrix a
eye n = mkDiagonal n (ones n)

-- | List of @n@ ones.
ones :: Num a => Int -> [a]
ones n = replicate n 1

-- | n x n matrix with the list placed on the diagonal at offset @o@.
-- Entries go to (k, k + o) when @o >= 0@ and to (k + |o|, k) when
-- @o < 0@. Fails when @|o| >= n@.
mkSubDiagonal :: Int -> Int -> [a] -> SpMatrix a
mkSubDiagonal n o xx
  | abs o < n = if o >= 0
                  then fz ii jj xx
                  else fz jj ii xx
  | otherwise = error "mkSubDiagonal : offset > dimension"
  where
    ii = [0 .. n - 1]
    jj = [abs o .. n - 1]
    fz a b x = fromListSM (n,n) (zip3 a b x)
-- | Column-major linear index of (i, j) in a matrix with @nr@ rows.
encode :: (Int, Int) -> (Rows, Cols) -> Int
encode (nr, _) (i, j) = j * nr + i

-- | Inverse of 'encode'.
decode :: (Int, Int) -> Int -> (Rows, Cols)
decode (nr, _) ci = let (c, r) = ci `quotRem` nr in (r, c)
-- | Submatrix covering rows @[i1, i2]@ and columns @[j1, j2]@ (both
-- inclusive), re-indexed to start at (0, 0). Fails when any bound is
-- outside the matrix or either range is reversed.
extractSubmatrixSM :: SpMatrix a -> (Int, Int) -> (Int, Int) -> SpMatrix a
extractSubmatrixSM (SM (r, c) im) (i1, i2) (j1, j2)
  | q = SM (m', n') imm'
  | otherwise = error $ "extractSubmatrixSM : invalid indexing "
                        ++ show (i1, i2) ++ ", " ++ show (j1, j2)
  where
    imm' = mapKeysIM2 (\i -> i - i1) (\j -> j - j1) $
           IM.filter (not . IM.null) $
           ifilterIM2 ff im
    ff i j _ = i1 <= i && i <= i2 && j1 <= j && j <= j2
    (m', n') = (i2 - i1 + 1, j2 - j1 + 1)
    q = inBounds0 r i1 &&
        inBounds0 r i2 &&
        inBounds0 c j1 &&
        inBounds0 c j2 &&
        i2 >= i1 &&
        j2 >= j1
-- | Row @i@ as a 1 x ncols matrix.
extractRowSM :: SpMatrix a -> Int -> SpMatrix a
extractRowSM sm i = extractSubmatrixSM sm (i, i) (0, ncols sm - 1)

-- | Column @j@ as an nrows x 1 matrix.
extractColSM :: SpMatrix a -> Int -> SpMatrix a
extractColSM sm j = extractSubmatrixSM sm (0, nrows sm - 1) (j, j)

-- | Turn a single-row or single-column matrix into a vector.
--
-- A row matrix yields its row 0. A column matrix yields the column-0
-- entry of every row, so all stored entries are kept; the old version
-- kept only the first row's entry. An all-zero row or column gives an
-- empty vector instead of crashing on @head []@.
toSV :: SpMatrix a -> SpVector a
toSV (SM (m,n) im)
  | m == 1 && n >= 1 = SV n (IM.findWithDefault IM.empty 0 im)
  | n == 1 && m > 1 = SV m (IM.mapMaybe (IM.lookup 0) im)
  | otherwise = error $ "toSV : incompatible dimensions " ++ show (m,n)

-- | Column @j@ as a vector.
extractCol :: SpMatrix a -> Int -> SpVector a
extractCol m j = toSV $ extractColSM m j

-- | Row @i@ as a vector.
extractRow :: SpMatrix a -> Int -> SpVector a
extractRow m i = toSV $ extractRowSM m i
-- | Stack @mm1@ above @mm2@; the result takes the wider column count.
vertStackSM, (-=-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
vertStackSM mm1 mm2 = SM (top + nrows mm2, max (ncols mm1) (ncols mm2)) stacked
  where
    top = nrows mm1
    shifted = IM.mapKeys (+ top) (immSM mm2)
    stacked = IM.union (immSM mm1) shifted
(-=-) = vertStackSM

-- | Place @mm2@ to the right of @mm1@ (transpose, stack, transpose back).
horizStackSM, (-||-) :: SpMatrix a -> SpMatrix a -> SpMatrix a
horizStackSM mm1 mm2 = transposeSM (transposeSM mm1 -=- transposeSM mm2)
(-||-) = horizStackSM
-- | Stored value at (i, j), if any.
lookupSM :: SpMatrix a -> IM.Key -> IM.Key -> Maybe a
lookupSM (SM _ im) i j = do
  row <- IM.lookup i im
  IM.lookup j row

-- | Value at (i, j), reading absent entries as 0.
lookupWD_SM, (@@) :: Num a => SpMatrix a -> (IM.Key, IM.Key) -> a
lookupWD_SM sm (i,j) = maybe 0 id (lookupSM sm i j)

-- | The same lookup on a bare nested map.
lookupWD_IM :: Num a => IM.IntMap (IM.IntMap a) -> (IM.Key, IM.Key) -> a
lookupWD_IM im (i,j) = case IM.lookup i im of
  Nothing  -> 0
  Just row -> IM.findWithDefault 0 j row

(@@) = lookupWD_SM
-- | Fold over stored entries (delegates to 'foldlIM2').
foldlSM :: (a -> b -> b) -> b -> SpMatrix a -> b
foldlSM f z (SM _ storage) = foldlIM2 f z storage

-- | Index-aware fold over stored entries (delegates to 'ifoldlIM2'').
ifoldlSM :: (IM.Key -> IM.Key -> a -> b -> b) -> b -> SpMatrix a -> b
ifoldlSM f z (SM _ storage) = ifoldlIM2' f z storage

-- | Count of stored entries below the diagonal (via 'countSubdiagonalNZ').
countSubdiagonalNZSM :: SpMatrix a -> Int
countSubdiagonalNZSM (SM _ storage) = countSubdiagonalNZ storage
-- | Main diagonal as a length-nrows vector. Every position is stored,
-- so absent diagonal entries appear as explicit zeros.
extractDiagonalDSM :: Num a => SpMatrix a -> SpVector a
extractDiagonalDSM mm = mkSpVector1D n [mm @@ (i, i) | i <- [0 .. n - 1]]
  where
    n = nrows mm
-- | Indices of stored entries below the diagonal (via 'subdiagIndices').
subdiagIndicesSM :: SpMatrix a -> [(IM.Key, IM.Key)]
subdiagIndicesSM (SM _ storage) = subdiagIndices storage

-- | Drop entries whose magnitude is below 'eps'.
sparsifyIM2 :: IM.IntMap (IM.IntMap Double) -> IM.IntMap (IM.IntMap Double)
sparsifyIM2 = ifilterIM2 keep
  where
    keep _ _ x = abs x >= eps

sparsifySM :: SpMatrix Double -> SpMatrix Double
sparsifySM (SM d im) = SM d (sparsifyIM2 im)

-- | Snap near-0 and near-1 entries to exactly 0 / 1, then sparsify.
roundZeroOneSM :: SpMatrix Double -> SpMatrix Double
roundZeroOneSM (SM d im) = sparsifySM (SM d (mapIM2 roundZeroOne im))

-- | Transpose; the shape swaps to (cols, rows).
transposeSM, (#^) :: SpMatrix a -> SpMatrix a
transposeSM (SM (m, n) im) = SM (n, m) (transposeIM2 im)
(#^) = transposeSM
-- | @A^T B@, sparsified.
(#^#) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
(#^#) a b = transposeSM a #~# b

-- | @A B^T@, sparsified.
(##^) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
(##^) a b = a #~# transposeSM b
-- | Matrix-vector product. Each stored row yields one entry (its dot
-- product with the vector); rows that are not stored stay absent.
matVec, (#>) :: Num a => SpMatrix a -> SpVector a -> SpVector a
matVec (SM (nr, nc) mdata) (SV n sv)
  | nc /= n = error $ "matVec : mismatching dimensions " ++ show (nc, n)
  | otherwise = SV nr (fmap (\row -> row `dot` sv) mdata)
(#>) = matVec

-- | Vector-matrix product @v^T A@.
vecMat, (<#) :: Num a => SpVector a -> SpMatrix a -> SpVector a
vecMat (SV n sv) (SM (nr, nc) mdata)
  | n /= nr = error $ "vecMat : mismatching dimensions " ++ show (n, nr)
  | otherwise = SV nc (fmap (\col -> col `dot` sv) (transposeIM2 mdata))
(<#) = vecMat
-- | Matrix product without a shape check: each stored row of @m1@ is
-- dotted with each column of @m2@.
matMatU :: Num a => SpMatrix a -> SpMatrix a -> SpMatrix a
matMatU m1 m2 = SM (nrows m1, ncols m2) (fmap rowTimesCols (immSM m1))
  where
    cols2 = transposeIM2 (immSM m2)
    rowTimesCols r = fmap (\col -> col `dot` r) cols2

-- | Matrix product; fails unless the inner dimensions agree.
matMat, (##) :: Num a => SpMatrix a -> SpMatrix a -> SpMatrix a
matMat m1 m2 =
  let d1@(_, c1) = dim m1
      d2@(r2, _) = dim m2
  in if c1 == r2
       then matMatU m1 m2
       else error $ "matMat : incompatible matrix sizes" ++ show (d1, d2)
(##) = matMat

-- | Matrix product followed by 'sparsifySM'.
matMatSparsified, (#~#) :: SpMatrix Double -> SpMatrix Double -> SpMatrix Double
matMatSparsified m1 m2 = sparsifySM (m1 ## m2)
(#~#) = matMatSparsified
-- | True when @Q^T Q@, with near-0/near-1 entries snapped, equals the
-- identity of size ncols.
isOrthogonalSM :: SpMatrix Double -> Bool
isOrthogonalSM sm = gram == eye (ncols sm)
  where
    gram = roundZeroOneSM (transposeSM sm ## sm)
-- | Condition-number estimate from the R factor of 'qr': the ratio of
-- the largest to the smallest @|R_ii|@.
--
-- The magnitudes are taken before max/min; the old code applied 'abs'
-- after, which gives the wrong ratio when the diagonal has mixed signs.
-- An infinite or NaN ratio (a zero on the diagonal) is reported as
-- rank deficiency.
conditionNumberSM :: SpMatrix Double -> Double
conditionNumberSM m
  | isInfinite kappa || isNaN kappa =
      error "Infinite condition number : rank-deficient system"
  | otherwise = kappa
  where
    kappa = lmax / lmin
    (_, r) = qr m
    u = fmap abs (extractDiagonalDSM r)
    lmax = maximum u
    lmin = minimum u
-- | @I - beta * x x^T@, of size @dim x@.
hhMat :: Num a => a -> SpVector a -> SpMatrix a
hhMat beta x = eye n ^-^ scale beta outer
  where
    n = dim x
    outer = x >< x

-- | Householder reflector @I - 2 x x^T@.
hhRefl :: SpVector Double -> SpMatrix Double
hhRefl x = hhMat 2.0 x
-- | @sqrt (x^2 + y^2)@. The old formula
-- @abs x * (sqrt (1 + y/x))**2@ was algebraically wrong. This direct
-- form is exact for moderate inputs but can overflow for huge ones.
hypot :: Floating a => a -> a -> a
hypot x y = sqrt (x * x + y * y)

-- | Sign of a number: 1, 0 or -1.
sign :: (Ord a, Num a) => a -> a
sign x
  | x > 0 = 1
  | x == 0 = 0
  | otherwise = -1

-- | Givens coefficients @(c, s, r)@ with @c = a/r@, @s = b/r@ and
-- @r = sqrt (a^2 + b^2)@, computed through the ratio of the smaller to
-- the larger magnitude to avoid overflow.
givensCoef :: (Ord a, Floating a) => a -> a -> (a, a, a)
givensCoef a b
  | b == 0 = (sign a, 0, abs a)
  | a == 0 = (0, sign b, abs b)
  | abs a > abs b = let t = b / a
                        u = sign a * abs (sqrt (1 + t ** 2))
                    in (1 / u, t / u, a * u)
  | otherwise = let t = a / b
                    u = sign b * abs (sqrt (1 + t ** 2))
                in (t / u, 1 / u, b * u)
-- | Givens rotation that zeroes entry (i, j) of a square matrix.
--
-- The pivot row @i'@ is the first row other than @i@ whose leading
-- stored entry is in column @j@ ('candidateRows'). With
-- @a = mm(i', j)@, @b = mm(i, j)@ and @(c, s, _) = givensCoef a b@, the
-- rotation mixes rows @i'@ and @i@:
--
-- > row i' <-  c * row i' + s * row i
-- > row i  <- -s * row i' + c * row i
--
-- so @(G mm)(i, j) = c*b - s*a = 0@. The old version put @+s@ in both
-- off-diagonal slots, which is not a rotation, and used row @j@ rather
-- than the pivot row @i'@.
givens :: SpMatrix Double -> Int -> Int -> SpMatrix Double
givens mm i j
  | validIxSM mm (i,j) && isSquareSM mm =
      sparsifySM $
        fromListSM' [(i, i, c), (i', i', c), (i', i, s), (i, i', -s)] (eye (nrows mm))
  | otherwise = error "givens : indices out of bounds"
  where
    (c, s, _) = givensCoef a b
    i' = case candidateRows (immSM mm) i j of
           Just (k : _) -> k
           _ -> error $ "givens: no compatible rows for entry " ++ show (i,j)
    a = mm @@ (i', j)
    b = mm @@ (i, j)
-- | Does the row store an entry at column @k@ and nothing to its left?
firstNonZeroColumn :: IM.IntMap a -> IM.Key -> Bool
firstNonZeroColumn mm k = IM.member k mm && isNothing (IM.lookupLT k mm)

-- | Rows other than @i@ whose first stored column is @j@, in ascending
-- order; 'Nothing' when there are none.
candidateRows :: IM.IntMap (IM.IntMap a) -> IM.Key -> IM.Key -> Maybe [IM.Key]
candidateRows mm i j =
  let u = IM.filterWithKey (\irow row -> irow /= i && firstNonZeroColumn row j) mm
  in if IM.null u then Nothing else Just (IM.keys u)
-- | QR factorisation by a sequence of Givens rotations, returning (Q, R).
--
-- 'gmats' produces G_1, G_2, ..., G_k, where each G_t is built from the
-- matrix already rotated by G_1..G_(t-1). So R = G_k ... G_1 A and
-- Q^T = G_k ... G_1. The fold therefore left-multiplies each new
-- rotation onto the accumulator; the old fold produced G_1 ... G_k.
qr :: SpMatrix Double -> (SpMatrix Double, SpMatrix Double)
qr mm = (transposeSM qmatt, rmat)
  where
    qmatt = F.foldl' (flip (#~#)) ee $ gmats mm
    rmat = qmatt #~# mm
    ee = eye (nrows mm)

-- | Successive Givens rotations, one per subdiagonal entry of the input.
gmats :: SpMatrix Double -> [SpMatrix Double]
gmats mm = gm mm (subdiagIndicesSM mm)
  where
    gm m ((i,j):is) = let g = givens m i j
                      in g : gm (g #~# m) is
    gm _ [] = []
-- | Eigenvalue estimates by unshifted QR iteration. The matrix is
-- repeatedly replaced by R Q, and the diagonal is returned once two
-- successive diagonals differ by at most 'eps' in L2 norm, or when the
-- 'modifyInspectN' step limit is reached.
eigsQR :: Int -> SpMatrix Double -> SpVector Double
eigsQR nitermax m =
  extractDiagonalDSM (execState (modifyInspectN nitermax settled qrStep) m)
  where
    qrStep a = let (q, r) = qr a in r #~# q
    settled [a1, a2] =
      norm2 (extractDiagonalDSM a1 ^-^ extractDiagonalDSM a2) <= eps
-- | One Rayleigh-quotient iteration step. Solve @(A - mu I) y = b@,
-- normalise @y@ to get @b'@, and set @mu' = b'^T A b' / b'^T b'@.
rayleighStep ::
  SpMatrix Double ->
  (SpVector Double, Double) ->
  (SpVector Double, Double)
rayleighStep aa (b, mu) =
  let shifted = aa ^-^ (mu `matScale` eye (nrows aa))
      b' = normalize 2 (shifted <\> b)
      mu' = (b' `dot` (aa #> b')) / (b' `dot` b')
  in (b', mu')
-- | Iterate 'rayleighStep' from an initial (vector, shift) pair until
-- successive vectors differ by at most 'eps' in L2 norm, or until the
-- 'modifyInspectN' step limit is reached.
eigRayleigh :: Int
            -> SpMatrix Double
            -> (SpVector Double, Double)
            -> (SpVector Double, Double)
eigRayleigh nitermax m = execState (modifyInspectN nitermax settled (rayleighStep m))
  where
    settled [(b1, _), (b2, _)] = norm2 (b2 ^-^ b1) <= eps
-- | Householder vector for @x@, following Golub & Van Loan's Householder
-- vector algorithm. Returns @(v, beta)@ with @v_0 = 1@ such that
-- @(I - beta v v^T) x@ is a multiple of e_1. When the tail of @x@ is
-- (near) zero, returns @beta = 0@.
hhV :: SpVector Double -> (SpVector Double, Double)
hhV x = (v, beta)
  where
    tx = tailSV x
    sigma = tx `dot` tx
    vtemp = singletonSV 1 `concatSV` tx
    (v, beta)
      | sigma <= eps = (vtemp, 0)
      | otherwise =
          let xh = headSV x
              mu = sqrt (xh ** 2 + sigma)
              -- Both branches equal xh - mu; the second avoids
              -- cancellation when xh > 0.
              vh | xh <= 0 = xh - mu
                 | otherwise = -sigma / (xh + mu)
              vnew = (1 / vh) .* insertSpVector 0 vh vtemp
              -- beta = 2 / (v^T v) with v normalised so v_0 = 1.
          in (vnew, 2 * vh ** 2 / (sigma + vh ** 2))
-- | Tolerance for sparsification, rounding and convergence tests.
eps :: Double
eps = 1e-8

-- | Residual @b - A x0@.
residual :: Num a => SpMatrix a -> SpVector a -> SpVector a -> SpVector a
residual aa b x0 = b ^-^ matVec aa x0

-- | True when the squared residual norm is at most 'eps'.
converged :: SpMatrix Double -> SpVector Double -> SpVector Double -> Bool
converged aa b x0 = normSq (residual aa b x0) <= eps
-- | One Conjugate Gradient Squared iteration with shadow residual @rhat@.
cgsStep :: SpMatrix Double -> SpVector Double -> CGS -> CGS
cgsStep aa rhat (CGS x r p u) =
  let aap = aa #> p
      rho = r `dot` rhat
      alphaj = rho / (aap `dot` rhat)
      q = u ^-^ (alphaj .* aap)
      uq = u ^+^ q
      xj1 = x ^+^ (alphaj .* uq)
      rj1 = r ^-^ (alphaj .* (aa #> uq))
      betaj = (rj1 `dot` rhat) / rho
      uj1 = rj1 ^+^ (betaj .* q)
      pj1 = uj1 ^+^ (betaj .* (q ^+^ (betaj .* p)))
  in CGS xj1 rj1 pj1 uj1

-- | CGS iteration state: iterate, residual, and the p / u directions.
data CGS = CGS
  { _x :: SpVector Double
  , _r :: SpVector Double
  , _p :: SpVector Double
  , _u :: SpVector Double
  } deriving Eq
-- | Run CGS from @x0@ with shadow residual @rhat@ until the iterate
-- settles (see 'untilConverged').
cgs ::
  SpMatrix Double ->
  SpVector Double ->
  SpVector Double ->
  SpVector Double ->
  CGS
cgs aa b x0 rhat = execState (untilConverged _x step) start
  where
    step = cgsStep aa rhat
    r0 = b ^-^ (aa #> x0)
    start = CGS x0 r0 r0 r0

instance Show CGS where
  show (CGS x r p u) = concat
    [ "x = ", show x, "\n"
    , "r = ", show r, "\n"
    , "p = ", show p, "\n"
    , "u = ", show u, "\n"
    ]
-- | One BiCGSTAB iteration with shadow residual @r0hat@.
bicgstabStep :: SpMatrix Double -> SpVector Double -> BICGSTAB -> BICGSTAB
bicgstabStep aa r0hat (BICGSTAB x r p) = BICGSTAB xNext rNext pNext
  where
    apv = aa #> p
    rho = r `dot` r0hat
    alpha = rho / (apv `dot` r0hat)
    sv = r ^-^ (alpha .* apv)
    asv = aa #> sv
    omega = (asv `dot` sv) / (asv `dot` asv)
    xNext = x ^+^ (alpha .* p) ^+^ (omega .* sv)
    rNext = sv ^-^ (omega .* asv)
    beta = (rNext `dot` r0hat) / rho * alpha / omega
    pNext = rNext ^+^ (beta .* (p ^-^ (omega .* apv)))

-- | BiCGSTAB iteration state: iterate, residual and search direction.
data BICGSTAB = BICGSTAB
  { _xBicgstab :: SpVector Double
  , _rBicgstab :: SpVector Double
  , _pBicgstab :: SpVector Double
  } deriving Eq
-- | Run BiCGSTAB from @x0@ with shadow residual @r0hat@ until the
-- iterate settles (see 'untilConverged').
bicgstab
  :: SpMatrix Double
  -> SpVector Double
  -> SpVector Double
  -> SpVector Double
  -> BICGSTAB
bicgstab aa b x0 r0hat = execState (untilConverged _xBicgstab step) start
  where
    step = bicgstabStep aa r0hat
    r0 = b ^-^ (aa #> x0)
    start = BICGSTAB x0 r0 r0

instance Show BICGSTAB where
  show (BICGSTAB x r p) = concat
    [ "x = ", show x, "\n"
    , "r = ", show r, "\n"
    , "p = ", show p, "\n"
    ]
-- | Which iterative solver to use.
data LinSolveMethod = CGS_ | BICGSTAB_ deriving (Eq, Show)

-- | Solve @aa x = b@ with the selected method. The initial guess comes
-- from 'randVec' and also serves as the shadow residual. Each
-- constructor now dispatches to its own solver; they used to be swapped.
linSolveM ::
  PrimMonad m =>
  LinSolveMethod -> SpMatrix Double -> SpVector Double -> m (SpVector Double)
linSolveM method aa b = do
  let (_, n) = dim aa
      nb = dim b
  if n /= nb
    then error "linSolve : operand dimensions mismatch"
    else do
      x0 <- randVec nb
      case method of
        CGS_      -> return $ _x (cgs aa b x0 x0)
        BICGSTAB_ -> return $ _xBicgstab (bicgstab aa b x0 x0)
-- | Solve @aa x = b@. A matrix accepted by 'isDiagonalSM' is inverted
-- elementwise; otherwise the selected iterative method runs from a
-- constant initial guess of 0.1. Each constructor now dispatches to
-- its own solver; they used to be swapped.
linSolve ::
  LinSolveMethod -> SpMatrix Double -> SpVector Double -> SpVector Double
linSolve method aa b
  | n /= nb = error "linSolve : operand dimensions mismatch"
  | otherwise = solve aa b
  where
    solve aa' b'
      | isDiagonalSM aa' = reciprocal aa' #> b'
      | otherwise = solveWith aa' b'
    solveWith aa' b' = case method of
      CGS_      -> _x (cgs aa' b' x0 x0)
      BICGSTAB_ -> _xBicgstab (bicgstab aa' b' x0 x0)
    x0 = mkSpVectorD n $ replicate n 0.1
    (_, n) = dim aa
    nb = dim b

-- | Solve @aa x = b@ with BiCGSTAB.
(<\>) :: SpMatrix Double -> SpVector Double -> SpVector Double
(<\>) = linSolve BICGSTAB_
-- | One-line summary of shape, stored-entry count and density.
sizeStr :: SpMatrix a -> String
sizeStr sm = unwords
  [ "(", show (nrows sm), "rows,", show (ncols sm), "columns ) ,"
  , show nz, "NZ ( sparsity", show sy, ")"
  ]
  where
    (SMInfo nz sy) = infoSM sm

-- | Blank string for zero, otherwise 'show'.
showNonZero :: (Show a, Num a, Eq a) => a -> String
showNonZero x
  | x == 0 = " "
  | otherwise = show x
-- | Row @irow@ as a dense list of length ncols (absent entries are 0).
toDenseRow :: Num a => SpMatrix a -> IM.Key -> [a]
toDenseRow (SM (_,ncol) im) irow =
  fmap (\icol -> im `lookupWD_IM` (irow,icol)) [0 .. ncol - 1]

-- | Printable row. When there are more than @ncomax@ columns, the first
-- @ncomax - 2@ values are shown, then " ... ", then the last value.
toDenseRowClip :: (Show a, Num a) => SpMatrix a -> IM.Key -> Int -> String
toDenseRowClip sm irow ncomax
  | ncols sm > ncomax = unwords (map show h) ++ " ... " ++ show t
  | otherwise = show dr
  where
    dr = toDenseRow sm irow
    h = take (ncomax - 2) dr
    t = last dr
-- | Print an empty line.
newline :: IO ()
newline = putStrLn ""

-- | Print the size summary and a dense view clipped to 5 x 5.
printDenseSM :: (Show t, Num t) => SpMatrix t -> IO ()
printDenseSM sm = do
  newline
  putStrLn $ sizeStr sm
  newline
  printDenseSM' sm 5 5
  newline
  where
    printDenseSM' :: (Show t, Num t) => SpMatrix t -> Int -> Int -> IO ()
    printDenseSM' sm'@(SM (nr,_) _) nromax ncomax = mapM_ putStrLn rr_'
      where
        rr_ = map (\i -> toDenseRowClip sm' i ncomax) [0 .. nr - 1]
        -- With more than nromax rows, show the first nromax - 2 rows,
        -- an ellipsis, and the last row.
        rr_' | nr > nromax = take (nromax - 2) rr_ ++ [" ... "] ++ [last rr_]
             | otherwise = rr_
-- | Printable vector. When it is longer than @ncomax@, the first
-- @ncomax - 2@ values are shown, then " ... ", then the last value.
toDenseListClip :: (Show a, Num a) => SpVector a -> Int -> String
toDenseListClip sv ncomax
  | dim sv > ncomax = unwords (map show h) ++ " ... " ++ show t
  | otherwise = show dr
  where
    dr = toDenseListSV sv
    h = take (ncomax - 2) dr
    t = last dr

-- | Print a vector densely, clipped to 5 values.
printDenseSV :: (Show t, Num t) => SpVector t -> IO ()
printDenseSV sv = do
  newline
  -- toDenseListClip already elides the middle of long vectors; the old
  -- code clipped that string again character by character, cutting
  -- through numbers.
  putStrLn (toDenseListClip sv 5)
  newline
-- | Things that can be pretty-printed in dense form.
class PrintDense a where
  prd :: a -> IO ()

instance (Show a, Num a) => PrintDense (SpVector a) where
  prd v = printDenseSV v

instance (Show a, Num a) => PrintDense (SpMatrix a) where
  prd m = printDenseSM m
-- | Is the value within 'eps' of 0 (resp. 1), boundaries included?
almostZero, almostOne :: Double -> Bool
almostZero x = abs x <= eps
almostOne x = abs (x - 1) <= eps

-- | Replace @x@ by @d@ when @q x@ holds.
withDefault :: (t -> Bool) -> t -> t -> t
withDefault q d x | q x = d
                  | otherwise = x

-- | Snap near-0 values to 0 / near-1 values to 1.
roundZero, roundOne :: Double -> Double
roundZero = withDefault almostZero 0
roundOne = withDefault almostOne 1

-- | Replace @x@ by @d1@ when @q1 x@ holds, otherwise by @d2@ when
-- @q2 x@ holds.
with2Defaults :: (t -> Bool) -> (t -> Bool) -> t -> t -> t -> t
with2Defaults q1 q2 d1 d2 x | q1 x = d1
                            | q2 x = d2
                            | otherwise = x

-- | Snap values near 0 or 1 to exactly 0 or 1.
roundZeroOne :: Double -> Double
roundZeroOne = with2Defaults almostZero almostOne 0 1
-- | Apply @f@ to the state repeatedly, at least once, until the new
-- state satisfies @q@; return that state.
modifyUntil :: MonadState s m => (s -> Bool) -> (s -> s) -> m s
modifyUntil q f = loop
  where
    loop = do
      y <- gets f
      put y
      if q y then return y else loop
-- | Iterate @f@ from @x@, keeping the two most recent iterates (newest
-- first). Once two are available, stop and return the current iterate
-- when @q@ holds on them or when the step count reaches @nitermax@.
-- The step test uses @>=@: with @==@, a @nitermax@ below 2 was skipped
-- over and the loop never ended.
loopUntilAcc :: Int -> ([t] -> Bool) -> (t -> t) -> t -> t
loopUntilAcc nitermax q f x = go 0 [] x
  where
    go i ll xx
      | length ll < 2 = go (i + 1) (y : ll) y
      | q ll || i >= nitermax = xx
      | otherwise = go (i + 1) (take 2 (y : ll)) y
      where
        y = f xx
-- | Apply @f@ to the state repeatedly, tracking the two most recent
-- states (newest first). Once two are tracked, one more step is taken
-- and its state returned when @q@ holds on the pair or when the step
-- count reaches @nitermax@.
--
-- The step test uses @>=@: with @==@, @nitermax = 1@ was skipped over
-- and the loop never ended. Fails when @nitermax <= 0@.
modifyInspectN ::
  MonadState s m => Int -> ([s] -> Bool) -> (s -> s) -> m s
modifyInspectN nitermax q f
  | nitermax > 0 = go 0 []
  | otherwise = error "modifyInspectN : n must be > 0"
  where
    go i ll = do
      x <- get
      let y = f x
      put y
      if length ll < 2
        then go (i + 1) (y : ll)
        else if q ll || i >= nitermax
          then return y
          else go (i + 1) (take 2 $ y : ll)
-- | Arithmetic mean (NaN for an empty container).
meanl :: (Foldable t, Fractional a) => t a -> a
meanl xx = 1/fromIntegral (length xx) * sum xx

-- | Euclidean norm of a container of numbers.
norm2l :: (Foldable t, Functor t, Floating a) => t a -> a
norm2l xx = sqrt $ sum (fmap (**2) xx)

-- | Squared difference of the first two list elements; fails on lists
-- with fewer than two elements.
diffSqL :: Floating a => [a] -> a
diffSqL (x1 : x2 : _) = (x1 - x2) ** 2
diffSqL _ = error "diffSqL : need at least two elements"
-- | Step the state with at most 100 inspections, stopping once the
-- projected vectors of the last two states are 'eps'-close.
untilConverged :: MonadState a m => (a -> SpVector Double) -> (a -> a) -> m a
untilConverged fproj step = modifyInspectN 100 (normDiffConverged fproj) step

-- | Is the squared norm of the right-folded difference of the
-- projections at most 'eps'?
normDiffConverged :: (Foldable t, Functor t) =>
                     (a -> SpVector Double) -> t a -> Bool
normDiffConverged fp xx = normSq diffs <= eps
  where
    diffs = foldr (^-^) (zeroSV 0) (fmap fp xx)
-- | Iterate @ff@ from @x0@ up to @niter@ times, collecting results newest
-- first, and stop early once @qq@ holds on the collected list.
runAppendN :: ([t] -> Bool) -> (t -> t) -> Int -> t -> [t]
runAppendN qq ff niter x0
  | niter < 0 = error "runAppendN : niter must be > 0"
  | otherwise = go qq ff niter x0 []
  where
    go q f n z xs =
      let x = f z in
      if n <= 0 || q xs then xs
      else go q f (n - 1) x (x : xs)

-- | Iterate @ff@ from @x0@ exactly @niter@ times, collecting the
-- results newest first.
runAppendN' :: (t -> t) -> Int -> t -> [t]
runAppendN' ff niter x0
  | niter < 0 = error "runAppendN : niter must be > 0"
  | otherwise = go ff niter x0 []
  where
    go f n z xs =
      let x = f z in
      if n <= 0 then xs
      else go f (n - 1) x (x : xs)
-- | n x n matrix filled with standard-normal samples. 'MWC.create'
-- uses a fixed seed, so repeated calls give the same matrix.
randMat :: PrimMonad m => Int -> m (SpMatrix Double)
randMat n = do
  g <- MWC.create
  aav <- replicateM (n^2) (MWC.normal 0 1 g)
  let ii_ = [0 .. n - 1]
      -- (i, j) pairs in column-major order
      (ix_,iy_) = unzip $ concatMap (zip ii_ . replicate n) ii_
  return $ fromListSM (n,n) $ zip3 ix_ iy_ aav

-- | Length-n vector of standard-normal samples (fixed seed, as above).
randVec :: PrimMonad m => Int -> m (SpVector Double)
randVec n = do
  g <- MWC.create
  bv <- replicateM n (MWC.normal 0 1 g)
  let ii_ = [0 .. n - 1]
  return $ fromListSV n $ zip ii_ bv
-- | n x n matrix with @nsp@ standard-normal samples at uniformly random
-- positions. Repeated positions overwrite one another, so fewer than
-- @nsp@ entries may end up stored.
randSpMat :: Int -> Int -> IO (SpMatrix Double)
randSpMat n nsp
  | nsp > n^2 = error "randSpMat : nsp must be < n^2 "
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      jj <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSM (n,n) $ zip3 ii jj aav

-- | Length-n vector with @nsp@ standard-normal samples at random indices.
randSpVec :: Int -> Int -> IO (SpVector Double)
randSpVec n nsp
  | nsp > n = error "randSpVec : nsp must be < n"
  | otherwise = do
      g <- MWC.create
      aav <- replicateM nsp (MWC.normal 0 1 g)
      ii <- replicateM nsp (MWC.uniformR (0, n - 1) g :: IO Int)
      return $ fromListSV n $ zip ii aav
-- | Pair each element with its 0-based position (lazy, so infinite
-- lists also work).
denseIxArray :: [b] -> [(Int, b)]
denseIxArray = zip [0 ..]

-- | Column-major (row, column, value) triples for a dense list with @m@
-- rows. Elements beyond a whole number of columns are dropped.
denseIxArray2 :: Int -> [c] -> [(Int, Int, c)]
denseIxArray2 m xs = zip3 (concat $ replicate n ii_) jj_ xs
  where
    ii_ = [0 .. m - 1]
    jj_ = concatMap (replicate m) [0 .. n - 1]
    n = length xs `div` m
-- | Map each element with @ff@, then right-fold the results with @gg@.
foldrMap :: (Foldable t, Functor t) => (a -> b) -> (b -> c -> c) -> c -> t a -> c
foldrMap ff gg x0 xs = foldr gg x0 (fmap ff xs)

-- | Left fold that forces the accumulator (to WHNF) at every step.
foldlStrict :: (a -> b -> a) -> a -> [b] -> a
foldlStrict f z0 xs0 = loop z0 xs0
  where
    loop acc ys = case ys of
      [] -> acc
      (y : rest) -> let acc' = f acc y in acc' `seq` loop acc' rest

-- | Right fold where each element's contribution is @f i z x@, with @i@
-- its 0-based position and @z@ a fixed extra argument.
ifoldr :: Num i =>
  (a -> b -> b) -> b -> (i -> c -> d -> a) -> c -> [d] -> b
ifoldr mjoin mneutral f z xs0 = walk 0 xs0
  where
    walk i (x : xs) = mjoin (f i z x) (walk (i + 1) xs)
    walk _ [] = mneutral
type LB = Int
type UB = Int

-- | Half-open range check: @ibl <= i < ibu@.
inBounds :: LB -> UB -> Int -> Bool
inBounds ibl ibu i = ibl <= i && i < ibu

-- | Both coordinates within the same half-open range.
inBounds2 :: (LB, UB) -> (Int, Int) -> Bool
inBounds2 (ibl, ibu) (ix, iy) = all (inBounds ibl ibu) [ix, iy]

-- | @0 <= i < ub@.
inBounds0 :: UB -> Int -> Bool
inBounds0 ub = inBounds 0 ub

-- | Each coordinate within @[0, bound)@ of its own axis.
inBounds02 :: (UB, UB) -> (Int, Int) -> Bool
inBounds02 (bx, by) (i, j) = inBounds0 bx i && inBounds0 by j
-- Small fixtures for interactive experimentation.
tm0, tm1, tm2, tm3, tm4 :: SpMatrix Double
tm0 = fromListSM (2,2) [(0,0,pi), (1,0,sqrt 2), (0,1, exp 1), (1,1,sqrt 5)]

tv0, tv1 :: SpVector Double
tv0 = mkSpVectorD 2 [5, 6]
tv1 = SV 2 $ IM.singleton 0 1

-- | Symmetric 3x3 matrix [[6,5,0],[5,1,4],[0,4,3]] (column-major input),
-- followed by two manual Givens steps on it.
tm1 = sparsifySM $ fromListDenseSM 3 [6,5,0,5,1,4,0,4,3]

tm1g1, tm1a2, tm1g2, tm1a3, tm1q :: SpMatrix Double
tm1g1 = givens tm1 1 0
tm1a2 = tm1g1 ## tm1
tm1g2 = givens tm1a2 2 1
tm1a3 = tm1g2 ## tm1a2
tm1q = transposeSM (tm1g2 ## tm1g1)

-- | [[12,-51,4],[6,167,-68],[-4,24,-41]] in column-major order; the
-- negative signs had been stripped. Presumably the standard QR
-- textbook example — confirm.
tm2 = fromListDenseSM 3 [12, 6, -4, -51, 167, 24, 4, -68, -41]

tm3 = transposeSM $ fromListDenseSM 3 [1 .. 9]

-- | Rotation [[1,0,0],[0,c,s],[0,-s,c]]; one off-diagonal term must be
-- negative for the matrix to be orthogonal.
tm3g1 :: SpMatrix Double
tm3g1 = fromListDenseSM 3 [1, 0, 0, 0, c, -s, 0, s, c]
  where
    c = 0.4961
    s = 0.8682

tm4 = sparsifySM $ fromListDenseSM 4 [1,0,0,0,2,5,0,10,3,6,8,11,4,7,9,12]
-- | Apply @f@ starting from @x@ until @p@ holds or @n@ steps have been
-- taken, whichever comes first.
untilC :: (a -> Bool) -> Int -> (a -> a) -> a -> a
untilC p n f = go n
  where
    go m x
      | p x || m <= 0 = x
      | otherwise = go (m - 1) (f x)