{-# LANGUAGE FlexibleContexts, FlexibleInstances #-} {-# LANGUAGE RecordWildCards #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Internal.CG( cgSolve, cgSolve', CGState(..), R, V ) where import Internal.Vector import Internal.Matrix import Internal.Numeric import Internal.Element import Internal.IO import Internal.Container import Internal.Sparse import Numeric.Vector() import Internal.Algorithms(linearSolveLS, linearSolve, relativeError, pnorm, NormType(..)) import Control.Arrow((***)) {- import Util.Misc(debug, debugMat) (//) :: Show a => a -> String -> a infix 0 // -- , /// a // b = debug b id a (///) :: V -> String -> V infix 0 /// v /// b = debugMat b 2 asRow v -} type V = Vector R data CGState = CGState { cgp :: Vector R -- ^ conjugate gradient , cgr :: Vector R -- ^ residual , cgr2 :: R -- ^ squared norm of residual , cgx :: Vector R -- ^ current solution , cgdx :: R -- ^ normalized size of correction } cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx where ap1 = a p ap | sym = ap1 | otherwise = at ap1 pap | sym = p <.> ap1 | otherwise = norm2 ap1 ** 2 alpha = r2 / pap dx = scale alpha p x' = x + dx r' = r - scale alpha ap r'2 = r' <.> r' beta = r'2 / r2 p' = r' + scale beta p rdx = norm2 dx / max 1 (norm2 x) conjugrad :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState] conjugrad sym a b = solveG sym (tr a !#>) (a !#>) (cg sym) b solveG :: Bool -> (V -> V) -> (V -> V) -> ((V -> V) -> (V -> V) -> CGState -> CGState) -> V -> V -> R -> R -> [CGState] solveG sym mat ma meth rawb x0' ϵb ϵx = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1 where a = if sym then ma else mat . ma b = if sym then rawb else mat rawb x0 = if x0' == 0 then konst 0 (dim b) else x0' r0 = b - a x0 r20 = r0 <.> r0 p0 = r0 nb2 = b <.> b ok CGState {..} = cgr2 Bool) -> [a] -> [a] takeUntil q xs = a++ take 1 b where (a,b) = break q xs -- | Solve a sparse linear system using the conjugate gradient method with default parameters. cgSolve :: Bool -- ^ is symmetric -> GMatrix -- ^ coefficient matrix -> Vector R -- ^ right-hand side -> Vector R -- ^ solution cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0 where n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double)) -- | Solve a sparse linear system using the conjugate gradient method with default parameters. cgSolve' :: Bool -- ^ symmetric -> R -- ^ relative tolerance for the residual (e.g. 1E-4) -> R -- ^ relative tolerance for δx (e.g. 1E-3) -> Int -- ^ maximum number of iterations -> GMatrix -- ^ coefficient matrix -> Vector R -- ^ initial solution -> Vector R -- ^ right-hand side -> [CGState] -- ^ solution cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es -------------------------------------------------------------------------------- instance Testable GMatrix where checkT _ = (ok,info) where sma = convo2 20 3 x1 = vect [1..20] x2 = vect [1..40] sm = mkSparse sma dm = toDense sma s1 = sm !#> x1 d1 = dm #> x1 s2 = tr sm !#> x2 d2 = tr dm #> x2 sdia = mkDiagR 40 20 (vect [1..10]) s3 = sdia !#> x1 s4 = tr sdia !#> x2 ddia = diagRect 0 (vect [1..10]) 40 20 d3 = ddia #> x1 d4 = tr ddia #> x2 v = testb 40 s5 = cgSolve False sm v d5 = denseSolve dm v symassoc = [((0,0),1.0),((1,1),2.0),((0,1),0.5),((1,0),0.5)] b = vect [3,4] d6 = flatten $ linearSolve (toDense symassoc) (asColumn b) s6 = cgSolve True (mkSparse symassoc) b info = do print sm disp (toDense sma) print s1; print d1 print s2; print d2 print s3; print d3 print s4; print d4 print s5; print d5 print $ relativeError (pnorm Infinity) s5 d5 print s6; print d6 print $ relativeError (pnorm Infinity) s6 d6 ok = s1==d1 && s2==d2 && s3==d3 && s4==d4 && relativeError (pnorm Infinity) s5 d5 < 1E-10 && relativeError (pnorm Infinity) s6 d6 < 1E-10 disp = putStr . dispf 2 vect = fromList :: [Double] -> Vector Double convomat :: Int -> Int -> AssocMatrix convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]] convo2 :: Int -> Int -> AssocMatrix convo2 n k = m1 ++ m2 where m1 = convomat n k m2 = map (((+n) *** id) *** id) m1 testb n = vect $ take n $ cycle ([0..10]++[9,8..1]) denseSolve a = flatten . linearSolveLS a . asColumn -- mkDiag v = mkDiagR (dim v) (dim v) v