---------------------------------------------------------------------------------
-- |
-- Module      :  Math.ConjugateGradient
-- Copyright   :  (c) Levent Erkok
-- License     :  BSD3
-- Maintainer  :  erkokl@gmail.com
-- Stability   :  stable
--
-- (The linear equation solver library is hosted at <http://github.com/LeventErkok/conjugateGradient>.
-- Comments, bug reports, and patches are always welcome.)
--
-- Sparse matrix linear-equation solver, using the conjugate gradient algorithm. Note that the technique only
-- applies to matrices that are symmetric and positive definite. See <http://en.wikipedia.org/wiki/Conjugate_gradient_method>
-- for details.
--
--  The conjugate gradient method can handle very large sparse matrices, where direct
--  methods (such as LU decomposition) are way too expensive to be useful in practice.
--  Such large sparse matrices arise naturally in many engineering problems, such as
--  in ASIC placement algorithms and when solving partial differential equations.
--
-- Here's an example usage, for the simple system:
--
-- @
--       4x +  y = 1
--        x + 3y = 2
-- @
--
-- >>> import Data.IntMap.Strict
-- >>> import System.Random
-- >>> import Math.ConjugateGradient
-- >>> let a = SM (2, fromList [(0, SV (fromList [(0, 4), (1, 1)])), (1, SV (fromList [(0, 1), (1, 3)]))]) :: SM Double
-- >>> let b = SV (fromList [(0, 1), (1, 2)]) :: SV Double
-- >>> let g = mkStdGen 12345
-- >>> let x = solveCG g a b
-- >>> putStrLn $ showSolution 4 a b x
--       A       |   x    =   b   
-- --------------+----------------
-- 4.0000 1.0000 | 0.0909 = 1.0000
-- 1.0000 3.0000 | 0.6364 = 2.0000
---------------------------------------------------------------------------------

module Math.ConjugateGradient(
          -- * Types
            SV(..), SM(..)
          -- * Sparse operations
          , lookupSV, lookupSM, addSV, subSV, dotSV, normSV, sMulSV, sMulSM, mulSMV
          -- * Conjugate-Gradient solver
          , solveCG
          -- * Displaying solutions
          , showSolution
        ) where

import Data.List                          (intercalate)
import Data.Maybe                         (fromMaybe)
import qualified Data.IntMap.Strict as IM (IntMap, lookup, map, foldr, unionWith, intersectionWith, fromList)
import System.Random                      (Random, RandomGen, randomRs)
import Numeric                            (showFFloat)

-- | A sparse vector containing elements of type @a@. For efficiency, only the indices that hold
-- non-@0@ elements should be given. (Nothing will break if you put in elements that are @0@;
-- it is just not as efficient.)
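--
-- For example, the vector @(0, 5, 0, 2)@ can be written as follows (a small
-- sketch, assuming @fromList@ from "Data.IntMap.Strict" is in scope):
--
-- @
--   SV (fromList [(1, 5), (3, 2)]) :: SV Double
-- @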
newtype SV a = SV (IM.IntMap a)

-- | A sparse matrix is essentially an int-map containing sparse row-vectors:
--
--     * The first element, @n@, is the number of rows in the matrix, including those with all @0@ elements.
--
--     * The matrix is implicitly assumed to be @nxn@, indexed by keys @(0, 0)@ to @(n-1, n-1)@.
--
--     * When constructing a sparse-matrix, only put in rows that have a non-@0@ element in them for efficiency.
--
--     * Note that you have to give all the non-@0@ elements: even though the matrix must be symmetric for the algorithm
--       to work, it should contain all the non-@0@ elements, not just the upper (or lower) triangle.
--
--     * Make sure the keys of the int-map are a subset of @[0 .. n-1]@, both for the row indices and for the indices within the vectors representing the sparse rows.
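--
-- For example, the symmetric @2x2@ matrix used in the example at the top of this module,
--
-- @
--   4 1
--   1 3
-- @
--
-- can be written as follows (a small sketch, assuming @fromList@ from "Data.IntMap.Strict" is in scope):
--
-- @
--   SM (2, fromList [ (0, SV (fromList [(0, 4), (1, 1)]))
--                   , (1, SV (fromList [(0, 1), (1, 3)]))
--                   ]) :: SM Double
-- @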
newtype SM a = SM (Int, IM.IntMap (SV a))

---------------------------------------------------------------------------------
-- Sparse vector/matrix operations
---------------------------------------------------------------------------------

-- | Look up a value in a sparse vector.
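--
-- Indices that are not present in the vector read as @0@. For example (with
-- @fromList@ from "Data.IntMap.Strict" in scope):
--
-- @
--   SV (fromList [(0, 4), (2, 7)]) `lookupSV` 2   -- 7
--   SV (fromList [(0, 4), (2, 7)]) `lookupSV` 1   -- 0
-- @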
lookupSV :: Num a => SV a -> Int -> a
lookupSV (SV v) k = fromMaybe 0 (k `IM.lookup` v)

-- | Look up a value in a sparse matrix.
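--
-- Entries in rows that are missing entirely, or columns missing from a present row,
-- read as @0@. A small sketch (with @fromList@ from "Data.IntMap.Strict" in scope):
--
-- @
--   let m = SM (3, fromList [(0, SV (fromList [(0, 2)])), (2, SV (fromList [(2, 5)]))])
--   in  (m `lookupSM` (0, 0), m `lookupSM` (1, 1), m `lookupSM` (2, 2))   -- (2, 0, 5)
-- @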
lookupSM :: Num a => SM a -> (Int, Int) -> a
lookupSM (SM (_, m)) (i, j) = maybe 0 (`lookupSV` j) (i `IM.lookup` m)

-- | Multiply a sparse-vector by a scalar.
sMulSV :: Num a => a -> SV a -> SV a
sMulSV s (SV v) = SV (IM.map (s *) v)

-- | Multiply a sparse-matrix by a scalar.
sMulSM :: Num a => a -> SM a -> SM a
sMulSM s (SM (n, m)) = SM (n, IM.map (s `sMulSV`) m)

-- | Add two sparse vectors.
addSV :: Num a => SV a -> SV a -> SV a
addSV (SV v1) (SV v2) = SV (IM.unionWith (+) v1 v2)

-- | Subtract two sparse vectors.
subSV :: Num a => SV a -> SV a -> SV a
subSV v1 (SV v2) = addSV v1 (SV (IM.map ((-1)*) v2))

-- | Dot product of two sparse vectors.
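--
-- Only indices present in both vectors contribute to the sum. For example (with
-- @fromList@ from "Data.IntMap.Strict" in scope):
--
-- @
--   SV (fromList [(0, 1), (2, 3)]) `dotSV` SV (fromList [(0, 2), (1, 5)])   -- 2
-- @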
dotSV :: Num a => SV a -> SV a -> a
dotSV (SV v1) (SV v2) = IM.foldr (+) 0 $ IM.intersectionWith (*) v1 v2

-- | Multiply a sparse matrix (@nxn@) with a sparse vector (@nx1@), obtaining a sparse vector (@nx1@).
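--
-- For example, multiplying the @2x2@ matrix from the example at the top of this
-- module by the vector @(1, 2)@ (a small sketch, with @fromList@ from
-- "Data.IntMap.Strict" in scope):
--
-- @
--   let a = SM (2, fromList [ (0, SV (fromList [(0, 4), (1, 1)]))
--                           , (1, SV (fromList [(0, 1), (1, 3)]))])
--       v = SV (fromList [(0, 1), (1, 2)])
--   in  a `mulSMV` v   -- the sparse vector with entries (0, 6) and (1, 7)
-- @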
mulSMV :: Num a => SM a -> SV a -> SV a
mulSMV (SM (_, m)) v = SV (IM.map (`dotSV` v) m)

-- | Norm of a sparse vector. (Square-root of its dot-product with itself.)
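--
-- For example (with @fromList@ from "Data.IntMap.Strict" in scope):
--
-- @
--   normSV (SV (fromList [(0, 3), (1, 4)]))   -- 5.0
-- @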
normSV :: RealFloat a => SV a -> a
normSV (SV v) = sqrt . IM.foldr (\e s -> e*e + s) 0 $ v

-- | Conjugate Gradient Solver for the system @Ax=b@. See: <http://en.wikipedia.org/wiki/Conjugate_gradient_method>.
--
-- NB. Assumptions on the input:
--
--    * The @A@ matrix is symmetric and positive definite.
--
--    * All non-@0@ rows are present. (Even though the matrix is symmetric, every non-@0@ row must be given, not just those in the upper or lower triangle.)
--
--    * The indices start from @0@ and go consecutively up to @n-1@. (Only non-@0@ value/row
--      indices have to be present, of course.)
--
-- For efficiency reasons, we do not check that these properties hold of the input. (If these assumptions are
-- violated, the algorithm will still produce a result, but not the one you expected!)
--
-- We perform at most @10^6@ iterations of the Conjugate-Gradient algorithm, stopping early
-- once the error drops below @1e-10@. The error is measured as the norm of the residual
-- @b - Ax@ at the current iterate. See
-- <http://en.wikipedia.org/wiki/Conjugate_gradient_method#Convergence_properties_of_the_conjugate_gradient_method>
-- for a discussion on the convergence properties of this algorithm.
--
-- The solver can throw an error if it does not converge by @10^6@ iterations. This is typically an indication
-- that the input matrix is not well formed, i.e., not symmetric positive-definite.
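--
-- One simple sanity check on a returned solution is to inspect the norm of the
-- residual @b - Ax@ directly (a small sketch, re-using @a@, @b@, and @x@ from the
-- example at the top of this module):
--
-- @
--   normSV (b `subSV` (a `mulSMV` x))   -- small; on the order of 1e-10 or below
-- @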
solveCG :: (RandomGen g, RealFloat a, Random a)
        => g          -- ^ The random-number generator, used to construct the initial guess.
        -> SM a       -- ^ The @A@ sparse matrix (@nxn@).
        -> SV a       -- ^ The @b@ sparse vector (@nx1@).
        -> SV a       -- ^ The @x@ sparse vector (@nx1@), such that @Ax = b@.
solveCG g a@(SM (n, _)) b = cg a b x0
  where rs = take n (randomRs (0, 1) g)
        x0 = SV $ IM.fromList [p | p@(_, j) <- zip [0..] rs, j /= 0]

-- | The Conjugate-gradient algorithm. Our implementation closely follows the
-- one given here: <http://en.wikipedia.org/wiki/Conjugate_gradient_method#Example_code_in_Matlab>
cg :: RealFloat a => SM a -> SV a -> SV a -> SV a
cg a b x0 = cgIter (1000000 :: Int) (norm r0) r0 r0 x0
 where r0 = b `subSV` (a `mulSMV` x0)
       cgIter 0 _   _ _ _ = error "Conjugate Gradient: No convergence after 10^6 iterations. Make sure the input matrix is symmetric positive-definite!"
       cgIter i eps p r x
        -- Stop if the square of the error is less than 1e-20, i.e.,
        -- if the error itself is less than 1e-10.
        | eps' < 1e-20 = x'
        | True         = cgIter (i-1) eps' p' r' x'
        where ap    = a `mulSMV` p                          -- A*p
              alpha = eps / (ap `dotSV` p)                  -- step size along the search direction
              x'    = x `addSV` (alpha `sMulSV` p)          -- improved solution estimate
              r'    = r `subSV` (alpha `sMulSV` ap)         -- new residual
              eps'  = norm r'                               -- squared norm of the new residual
              p'    = r' `addSV` ((eps' / eps) `sMulSV` p)  -- next (conjugate) search direction
       norm (SV v) = IM.foldr (\e s -> e*e + s) 0 v -- square of normSV, but no need for expensive square-root

-- | Display a solution in a human-readable form. Needless to say, only use this
-- method when the system is small enough to fit nicely on the screen.
showSolution :: RealFloat a
             => Int   -- ^ Precision: Use this many digits after the decimal point.
             -> SM a  -- ^ The @A@ matrix, @nxn@
             -> SV a  -- ^ The @b@ vector, @nx1@
             -> SV a  -- ^ The @x@ vector, @nx1@, as returned by 'solveCG', for instance.
             -> String
showSolution prec ma@(SM (n, _)) vb vx = intercalate "\n" $ header ++ res
  where res   = zipWith3 row a x b
        range = [0..n-1]
        sf d = showFFloat (Just prec) d ""
        a = [[sf (ma `lookupSM` (i, j)) | j <- range] | i <- range]
        x = [sf (vx `lookupSV` i) | i <- range]
        b = [sf (vb `lookupSV` i) | i <- range]
        cellWidth = maximum (0 : map length (concat a ++ x ++ b))
        row as xv bv = unwords (map pad as) ++ " | " ++ pad xv ++ " = " ++ pad bv
        pad s  = reverse $ take (length s `max` cellWidth) $ reverse s ++ repeat ' '
        center l s = let extra = l - length s
                         left  = extra `div` 2
                         right = extra - left
                     in  replicate left ' ' ++ s ++ replicate right ' '
        header = case res of
                   []    -> ["Empty matrix"]
                   (r:_) -> let l = length (takeWhile (/= '|') r)
                                h =  center (l-1) "A"  ++ " | "
                                  ++ center cellWidth "x" ++ " = " ++ center cellWidth "b"
                                s = replicate l '-' ++ "+" ++ replicate (length r - l - 1) '-'
                            in [h, s]