-- Public interface of the linear-algebra algorithms module: linear solvers,
-- matrix factorizations, matrix functions, norms, and the helpers used to
-- unpack compact QR / Hessenberg results.  Only three methods of 'Field' are
-- re-exported through the class itself; the rest are exported individually.
module Numeric.LinearAlgebra.Algorithms (
linearSolve,
inv, pinv,
pinvTol, det, rank, rcond,
svd,
full, economy, --thin,
eig, eigSH,
qr,
chol,
hess,
schur,
lu,
expm,
sqrtm,
matFunc,
nullspacePrec,
nullVector,
Normed(..), NormType(..),
ctrans,
eps, i,
haussholder,
unpackQR, unpackHess,
Field(linearSolveSVD,eigSH',cholSH)
) where
import Data.Packed.Internal hiding (fromComplex, toComplex, comp, conj, (//))
import Data.Packed
import qualified Numeric.GSL.Matrix as GSL
import Numeric.GSL.Vector
import Numeric.LinearAlgebra.LAPACK as LAPACK
import Complex
import Numeric.LinearAlgebra.Linear
import Data.List(foldl1')
import Data.Array
-- | Element types for which the matrix algorithms of this module are
--   available.  This file provides instances for 'Double' and
--   'Complex' 'Double'.
class (Normed (Matrix t), Linear Vector t, Linear Matrix t) => Field t where
    -- | Singular value decomposition, returned as @(u, s, v)@ where @s@
    --   holds the (real) singular values.
    svd            :: Matrix t -> (Matrix t, Vector Double, Matrix t)
    -- | LU factorization in packed form plus the list of pivot indices
    --   consumed by 'det' and 'lu'.
    luPacked       :: Matrix t -> (Matrix t, [Int])
    -- | Solve a linear system given the coefficient matrix and the
    --   right-hand sides.
    linearSolve    :: Matrix t -> Matrix t -> Matrix t
    -- | SVD-based solver; 'pinv' is built on it.
    linearSolveSVD :: Matrix t -> Matrix t -> Matrix t
    -- | Eigenvalues and eigenvectors, always returned as complex data.
    eig            :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
    -- | Eigensystem routine used by 'eigSH'; performs no symmetry check itself.
    eigSH'         :: Matrix t -> (Vector Double, Matrix t)
    -- | Cholesky routine used by 'chol'; performs no symmetry check itself.
    cholSH         :: Matrix t -> Matrix t
    -- | QR factorization, returned as @(q, r)@.
    qr             :: Matrix t -> (Matrix t, Matrix t)
    -- | Hessenberg factorization, returned as a pair of matrices.
    hess           :: Matrix t -> (Matrix t, Matrix t)
    -- | Schur factorization, returned as a pair of matrices.
    schur          :: Matrix t -> (Matrix t, Matrix t)
    -- | Conjugate transpose (plain transpose for real matrices).
    ctrans         :: Matrix t -> Matrix t
-- | Real-valued algorithms, delegating to the @...R@ / @...S@ routines.
instance Field Double where
    -- decompositions
    svd            = svdR
    luPacked       = luR
    -- solvers
    linearSolve    = linearSolveR
    linearSolveSVD = linearSolveSVDR Nothing
    -- eigensystems
    eig            = eigR
    eigSH'         = eigS
    -- remaining factorizations
    cholSH         = cholS
    -- NOTE: the real instance unpacks with the GSL helper, whereas the
    -- complex instance uses the 'unpackQR' defined in this module.
    qr m           = GSL.unpackQR (qrR m)
    hess           = unpackHess hessR
    schur          = schurR
    -- a real matrix needs no conjugation
    ctrans         = trans
-- | Complex-valued algorithms, delegating to the @...C@ / @...H@ routines.
instance Field (Complex Double) where
    -- decompositions
    svd            = svdC
    luPacked       = luC
    -- solvers
    linearSolve    = linearSolveC
    linearSolveSVD = linearSolveSVDC Nothing
    -- eigensystems
    eig            = eigC
    eigSH'         = eigH
    -- remaining factorizations
    cholSH         = cholH
    qr m           = unpackQR (qrC m)
    hess           = unpackHess hessC
    schur          = schurC
    -- transpose first, then conjugate every element
    ctrans m       = conj (trans m)
-- | Eigenvalues and eigenvectors of a real symmetric or complex hermitian
--   matrix.  The input is compared with its own conjugate transpose and an
--   'error' is raised when they differ; otherwise the work is done by 'eigSH''.
eigSH :: Field t => Matrix t -> (Vector Double, Matrix t)
eigSH m =
    if m `equal` ctrans m
        then eigSH' m
        else error "eigSH requires complex hermitian or real symmetric matrix"
-- | Cholesky factorization.  Only symmetry / hermiticity is verified here
--   (by comparing the matrix with its conjugate transpose); the actual
--   factorization is delegated to 'cholSH'.
chol :: Field t => Matrix t -> Matrix t
chol m =
    if m `equal` ctrans m
        then cholSH m
        else error "chol requires positive definite complex hermitian or real symmetric matrix"
-- | True when the matrix has as many rows as columns.
square m = nRows == nCols
  where
    nRows = rows m
    nCols = cols m
-- | Determinant of a square matrix: the product of the diagonal of the packed
--   LU factorization, multiplied by the sign that 'signlp' derives from the
--   pivot list.  Raises an 'error' for non-square input.
det :: Field t => Matrix t -> t
det m
    | square m  = sign * product (toList (takeDiag packed))
    | otherwise = error "det of nonsquare matrix"
  where
    (packed, pivots) = luPacked m
    sign             = signlp (rows m) pivots
-- | Explicit LU factorization: the packed result of 'luPacked' expanded by
--   'luFact' into @(l, u, p, s)@ (two factors, a permutation matrix and a sign).
lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t)
lu m = luFact (luPacked m)
-- | Inverse of a square matrix, obtained by solving against the identity of
--   the same size.  Raises an 'error' for non-square input.
inv :: Field t => Matrix t -> Matrix t
inv m =
    if square m
        then linearSolve m (ident (rows m))
        else error "inv of nonsquare matrix"
-- | Pseudoinverse, obtained by handing the matrix and an identity with as many
--   rows as the matrix to the SVD-based solver 'linearSolveSVD'.
pinv :: Field t => Matrix t -> Matrix t
pinv m = m `linearSolveSVD` identity
  where
    identity = ident (rows m)
-- | Wrap an SVD routine so that the singular values come back as a
--   rectangular diagonal matrix with the same shape as the input, instead of
--   as a vector.
full :: Element t
     => (Matrix t -> (Matrix t, Vector Double, Matrix t)) -> Matrix t -> (Matrix t, Matrix Double, Matrix t)
full svd' m = (u, sigma, v)
  where
    (u, sv, v) = svd' m
    sigma      = diagRect sv (rows m) (cols m)
-- | Wrap an SVD routine so that only the significant part is returned.
--   A singular value is kept when it exceeds
--   @max rows cols * (first singular value) * eps@; the leading columns of
--   @u@ and @v@ and the leading singular values are truncated to that count.
economy :: Element t
        => (Matrix t -> (Matrix t, Vector Double, Matrix t)) -> Matrix t -> (Matrix t, Vector Double, Matrix t)
economy svd' m = (takeColumns keep u, subVector 0 keep s, takeColumns keep v)
  where
    (u, s, v)        = svd' m
    -- the threshold is relative to the first singular value
    values@(top : _) = toList s
    scaleFactor      = 1
    threshold        = (fromIntegral (max (rows m) (cols m)) * top * scaleFactor * eps)
    significant      = fromList (filter (> threshold) values)
    keep             = dim significant
-- | Tolerance constant used throughout this module for rank decisions and
--   convergence tests.
eps :: Double
eps = 2.22044604925031e-16
-- | The imaginary unit.
i :: Complex Double
i = 0 :+ 1
-- | Matrix-matrix product (a thin alias for 'multiply').
mXm :: (Num t, Field t) => Matrix t -> Matrix t -> Matrix t
mXm a b = multiply a b
-- | Matrix-vector product: the vector is promoted to a one-column matrix,
--   multiplied, and the result flattened back to a vector.
mXv :: (Num t, Field t) => Matrix t -> Vector t -> Vector t
mXv m v = flatten (mXm m (asColumn v))
-- | Vector-matrix product: the vector is promoted to a one-row matrix,
--   multiplied, and the result flattened back to a vector.
vXm :: (Num t, Field t) => Vector t -> Matrix t -> Vector t
vXm v m = flatten (mXm (asRow v) m)
-- | 2-norm of a real vector, reduced through 'toScalarR' with 'Norm2'.
norm2 :: Vector Double -> Double
norm2 v = toScalarR Norm2 v
-- | 1-norm of a real vector, reduced through 'toScalarR' with 'AbsSum'.
norm1 :: Vector Double -> Double
norm1 v = toScalarR AbsSum v
-- | Selector for the norm computed by 'pnorm'.
data NormType
    = Infinity  -- ^ largest absolute entry (vectors) / largest absolute row sum (matrices)
    | PNorm1    -- ^ sum of absolute entries (vectors) / largest absolute column sum (matrices)
    | PNorm2    -- ^ 2-norm (vectors) / first singular value (matrices)
-- | Norm of a real vector for the requested 'NormType'.
pnormRV kind = case kind of
    PNorm2   -> norm2
    PNorm1   -> norm1
    -- largest absolute value among the entries
    Infinity -> vectorMax . vectorMapR Abs
-- | Norm of a complex vector for the requested 'NormType'.
pnormCV kind = case kind of
    -- the 2-norm is taken over the real view produced by 'asReal'
    PNorm2   -> norm2 . asReal
    -- the other norms work on the element magnitudes
    PNorm1   -> norm1 . liftVector magnitude
    Infinity -> vectorMax . liftVector magnitude
-- | Norm of a real matrix for the requested 'NormType':
--   first singular value, largest absolute column sum, or largest absolute
--   row sum respectively.
pnormRM kind m = case kind of
    PNorm2   -> let (_, sv, _) = svdR m in head (toList sv)
    -- a row of ones times |m| gives the column sums
    PNorm1   -> vectorMax (constant 1 (rows m) `vXm` absM)
    -- |m| times a column of ones gives the row sums
    Infinity -> vectorMax (absM `mXv` constant 1 (cols m))
  where
    absM = liftMatrix (vectorMapR Abs) m
-- | Norm of a complex matrix for the requested 'NormType':
--   first singular value, largest column sum of magnitudes, or largest row
--   sum of magnitudes respectively.
pnormCM kind m = case kind of
    PNorm2   -> let (_, sv, _) = svdC m in head (toList sv)
    -- a row of ones times the magnitudes gives the column sums
    PNorm1   -> vectorMax (constant 1 (rows m) `vXm` magM)
    -- the magnitudes times a column of ones gives the row sums
    Infinity -> vectorMax (magM `mXv` constant 1 (cols m))
  where
    magM = liftMatrix (liftVector magnitude) m
-- | Objects that have a norm selectable through 'NormType'.
class Normed t where
    -- | Compute the requested norm.
    pnorm :: NormType -> t -> Double

-- | Real vectors.
instance Normed (Vector Double) where
    pnorm = pnormRV

-- | Complex vectors.
instance Normed (Vector (Complex Double)) where
    pnorm = pnormCV

-- | Real matrices.
instance Normed (Matrix Double) where
    pnorm = pnormRM

-- | Complex matrices.
instance Normed (Matrix (Complex Double)) where
    pnorm = pnormCM
-- | Vectors spanning the numerical null space of a matrix.
--
--   The numerical rank is the number of singular values above a cutoff scaled
--   by the first argument; the rows of @ctrans v@ beyond that rank are returned.
nullspacePrec :: Field t
              => Double     -- ^ relative tolerance, in units of 'eps'
              -> Matrix t   -- ^ input matrix
              -> [Vector t] -- ^ rows of @ctrans v@ past the numerical rank
nullspacePrec t m = drop rank' (toRows (ctrans v))
  where
    (_, s, v)       = svd m
    svals@(top : _) = toList s
    tol             = (fromIntegral (max (rows m) (cols m)) * top * t * eps)
    -- count the singular values regarded as non-zero
    rank'           = length [x | x <- svals, x > top * tol]
-- | The last vector of @nullspacePrec 1@.  Partial: fails when that list is
--   empty.
nullVector :: Field t => Matrix t -> Vector t
nullVector m = last (nullspacePrec 1 m)
-- | Pseudoinverse of a real matrix through its SVD, with an adjustable
--   tolerance (in units of 'eps') for discarding small singular values.
--
--   Singular values at or below the cutoff contribute nothing to the result:
--   their reciprocal is replaced by 0.  (Mapping them to 1, as an earlier
--   revision did, injects spurious components along the numerical null
--   space, and a strict comparison produced @1/0@ for an all-zero matrix.)
pinvTol :: Double -> Matrix Double -> Matrix Double
pinvTol t m = v' `mXm` diag s' `mXm` trans u'
  where
    (u, s, v)   = svdR m
    sl@(g : _)  = toList s
    s'          = fromList (map recipOrZero sl)
    -- NOTE(review): the cutoff multiplies by g twice (once inside tol);
    -- kept as in the original, confirm whether that is intended.
    recipOrZero x = if x <= g * tol then 0 else 1 / x
    tol         = (fromIntegral (max r c) * g * t * eps)
    r           = rows m
    c           = cols m
    d           = dim s
    u'          = takeColumns d u
    v'          = takeColumns d v
-- | Householder-style reflector @I - tau * w * ctrans w@, where @w@ is the
--   given vector taken as a column.
haussholder :: (Field a) => a -> Vector a -> Matrix a
haussholder tau v = sub (ident (dim v)) (scale tau (mXm col (ctrans col)))
  where
    col = asColumn v
-- | Build the reflector vector number @k@ (1-based) from a packed column:
--   the first @k-1@ entries become 0, entry @k@ becomes 1 and the remaining
--   entries are copied from @v@.
zh k v = fromList $ replicate (k - 1) 0 ++ (1 : drop k xs)
  where
    xs = toList v
-- | Replace the last @k@ entries of a vector by zeros (identity for @k = 0@).
zt 0 v = v
zt k v = join [subVector 0 (dim v - k) v, constant 0 k]
-- | Expand a compact QR result (packed matrix plus the vector of @tau@
--   coefficients) into explicit @(q, r)@ factors.
--
--   @r@ is the packed matrix with everything below the diagonal zeroed;
--   @q@ is the product of the reflectors built by 'haussholder' from the
--   sub-diagonal part of each column.
unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t)
unpackQR (pq, tau) = (q, r)
  where
    cs = toColumns pq
    m  = rows pq
    n  = cols pq
    mn = min m n
    -- column j (0-based) keeps its first j+1 entries: zero the last m-1-j,
    -- and nothing once the diagonal has reached the bottom row
    r  = fromColumns $ zipWith zt ([m - 1, m - 2 .. 1] ++ repeat 0) cs
    vs = zipWith zh [1 .. mn] cs
    hs = zipWith haussholder (toList tau) vs
    q  = foldl1' mXm hs
-- | Run a compact Hessenberg routine and expand its result with 'uH'.
--   A matrix with a single row is returned unchanged, paired with the 1x1
--   identity, without calling the routine.
unpackHess :: (Field t) => (Matrix t -> (Matrix t,Vector t)) -> Matrix t -> (Matrix t, Matrix t)
unpackHess hf m =
    if rows m == 1
        then ((1><1)[1], m)
        else uH (hf m)
-- | Expand a compact Hessenberg result (packed matrix plus @tau@
--   coefficients) into explicit @(p, h)@ factors.
--
--   @h@ is the packed matrix with everything below the first sub-diagonal
--   zeroed; @p@ is the product of the reflectors built by 'haussholder'.
uH (pq, tau) = (p, h)
  where
    cs = toColumns pq
    m  = rows pq
    n  = cols pq
    mn = min m n
    -- column j (0-based) keeps its first j+2 entries: zero the last m-2-j
    h  = fromColumns $ zipWith zt ([m - 2, m - 3 .. 1] ++ repeat 0) cs
    -- reflectors start one row below the diagonal, hence the index from 2
    vs = zipWith zh [2 .. mn] cs
    hs = zipWith haussholder (toList tau) vs
    p  = foldl1' mXm hs
-- | Ratio between the last and the first singular value returned by 'svd'.
rcond :: Field t => Matrix t -> Double
rcond m = last svals / head svals
  where
    (_, sv, _) = svd m
    svals      = toList sv
-- | Numerical rank: 0 when the 1-norm of the matrix is below 'eps',
--   otherwise the number of singular values retained by @economy svd@.
rank :: Field t => Matrix t -> Int
rank m =
    if pnorm PNorm1 m < eps
        then 0
        else let (_, kept, _) = economy svd m in dim kept
-- | Eigen-decomposition usable for diagonalization: @Just (values, vectors)@
--   when the eigenvector matrix has full rank, @Nothing@ otherwise.
--   Matrices equal to their conjugate transpose go through 'eigSH', all
--   others through 'eig'.
diagonalize m
    | rank v == rows m = Just (l, v)
    | otherwise        = Nothing
  where
    (l, v)
        | m `equal` ctrans m = let (l', v') = eigSH m in (real l', v')
        | otherwise          = eig m
-- | Apply a scalar function to a matrix through its eigen-decomposition:
--   the function is mapped over the eigenvalues and the result is
--   @v * diag (f l) * inv v@.  Raises an 'error' when 'diagonalize' fails.
matFunc :: Field t => (Complex Double -> Complex Double) -> Matrix t -> Matrix (Complex Double)
matFunc f m = maybe notDiagonalizable rebuild (diagonalize (complex m))
  where
    notDiagonalizable = error "Sorry, matFunc requires a diagonalizable matrix"
    rebuild (l, v)    = v `mXm` diag (liftVector f l) `mXm` inv v
-- | Error-bound coefficient of the @(p, q)@ Padé approximant used by the
--   matrix exponential:
--
--   @2^(3 - (p+q)) * p! * q! / ((p+q)! * (p+q+1)!)@
golubeps :: Integer -> Integer -> Double
golubeps p q = a * fromIntegral b / fromIntegral c
  where
    -- '^^' because the exponent becomes negative as soon as p + q > 3
    a = 2 ^^ (3 - p - q)
    b = fact p * fact q
    c = fact (p + q) * fact (p + q + 1)
    fact n = product [1 .. n]
-- | Infinite table pairing each diagonal Padé order @k@ (from 1 upwards)
--   with its coefficient @golubeps k k@.
epslist = map entry [1..]
  where
    entry k = (fromIntegral k, golubeps k k)
-- | First order in 'epslist' whose coefficient is strictly below @delta@.
geps delta = fst . head . filter ((< delta) . snd) $ epslist
-- | Matrix exponential by Padé approximation with scaling and squaring.
--
--   The matrix is divided by @2^j@ (with @j@ derived from its infinity norm),
--   numerator @nf@ and denominator @df@ of the diagonal Padé approximant of
--   order @q@ are accumulated term by term, @df * f = nf@ is solved for @f@,
--   and @f@ is finally squared @j@ times.
expGolub m = iterate msq f !! j
  where
    j      = max 0 $ floor $ log2 $ pnorm Infinity m
    log2 x = log x / log 2
    a      = m */ fromIntegral ((2::Int)^j)
    -- Padé order for which the coefficient bound drops below eps
    q      = geps eps
    eye    = ident (rows m)
    -- one step of the recurrence: state is (term index, coefficient,
    -- current power of a, numerator so far, denominator so far)
    work (k,c,x,n,d) = (k',c',x',n',d')
      where
        k' = k+1
        c' = c * fromIntegral (q-k+1) / fromIntegral ((2*q-k+1)*k)
        x' = a <> x
        n' = n |+| (c' .* x')
        -- the denominator uses the same terms with alternating sign
        d' = d |+| (((-1)^k * c') .* x')
    (_,_,_,nf,df) = iterate work (1,1,eye,eye,eye) !! q
    f      = linearSolve df nf
    msq x  = x <> x
    (<>)   = multiply
    v */ x = scale (recip x) v
    (.*)   = scale
    (|+|)  = add
-- | Matrix exponential, computed by 'expGolub'.
expm :: Field t => Matrix t -> Matrix t
expm m = expGolub m
-- | Matrix square root, computed by the coupled iteration in 'sqrtmInv'.
sqrtm :: Field t => Matrix t -> Matrix t
sqrtm m = sqrtmInv m
-- | Matrix square root by the coupled iteration
--   @y' = (y + inv z) / 2@, @z' = (inv y + z) / 2@ started from @(x, I)@.
--   Iteration stops at the first state whose @y@ differs from the next one by
--   less than 'eps' in 1-norm; the @y@ component of that state is returned.
sqrtmInv x = fst (converge (iterate step (x, ident (rows x))))
  where
    converge (cur : rest@(next : _))
        | pnorm PNorm1 (fst cur |-| fst next) < eps = cur
        | otherwise                                  = converge rest
    converge _ = error "fixedpoint with impossible inputs"
    step (y, z) = ( 0.5 .* (y |+| inv z)
                  , 0.5 .* (inv y |+| z) )
    (.*)  = scale
    (|+|) = add
    (|-|) = sub
-- | Sign of the row permutation described by a pivot list: every position
--   (among the first @r@) whose pivot differs from its own index counts as
--   one row exchange and flips the sign.
signlp r vals = foldl f 1 (zip [0 .. r - 1] vals)
  where
    f s (a, b)
        | a /= b    = -s
        | otherwise = s
-- | Apply one pivot step to an array and a running sign: when the two
--   positions differ, their contents are exchanged and the sign is negated;
--   otherwise both are returned unchanged.
swap (arr, s) (a, b)
    | a /= b    = (arr // [(a, arr ! b), (b, arr ! a)], -s)
    | otherwise = (arr, s)
-- | Turn a pivot list into an explicit @r x r@ permutation matrix and a sign.
--   The columns of the identity are exchanged step by step with 'swap',
--   pairing position @k@ with the @k@-th pivot; the sign is whatever 'swap'
--   accumulates starting from 1.
fixPerm r vals = (fromColumns $ elems res, sign)
  where
    v = [0 .. r - 1]
    s = toColumns (ident r)
    (res, sign) = foldl swap (listArray (0, r - 1) s, 1) (zip v vals)
-- | @r x c@ mask matrix: entries on or above the @h@-th diagonal
--   (@j - i >= h@) hold @v@, all others hold @1 - v@.  With @v@ equal to 1 or
--   0 this selects an upper or a lower triangular part respectively.
triang r c h v = reshape c $ fromList [el i j | i <- [0 .. r - 1], j <- [0 .. c - 1]]
  where
    el i j = if j - i >= h then v else 1 - v
-- | Expand a packed LU factorization and its pivot list into
--   @(l, u, p, s)@: lower factor with unit diagonal, upper factor,
--   permutation matrix and sign (the last two from 'fixPerm').
--   The shapes of the factors depend on whether the packed matrix has at
--   most as many rows as columns.
luFact (packed, perm) =
    if nr <= nc
        then (lWide, uWide, permM, sign)
        else (lTall, uTall, permM, sign)
  where
    nr = rows packed
    nc = cols packed
    -- element-wise masks: strictly lower part and upper part (with diagonal)
    lowerPart = mul packed (triang nr nc 0 0)
    upperPart = mul packed (triang nr nc 0 1)
    -- rows <= cols: l is square (nr x nr), u keeps the full shape
    lWide = add (takeColumns nr lowerPart) (diagRect (constant 1 nr) nr nr)
    uWide = upperPart
    -- rows > cols: l keeps the full shape, u is cut to nc rows
    lTall = add lowerPart (diagRect (constant 1 nc) nr nc)
    uTall = takeRows nc upperPart
    (permM, sign) = fixPerm nr perm