{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts, TypeFamilies, MultiParamTypeClasses, FlexibleInstances  #-}
{-# language ApplicativeDo #-}
-- {-# OPTIONS_GHC -O2 -rtsopts -with-rtsopts=-K32m -prof#-}
{-|

This module exposes the high-level functionality of the library.

-}

module Numeric.LinearAlgebra.Sparse
       (
         -- * Linear solvers
         -- ** Iterative methods
         (<\>),
         -- ** Moore-Penrose pseudoinverse
         pinv,
         -- ** Preconditioners
         jacobiPre, ilu0, mSsor,         
         -- ** Direct methods
         luSolve,
         -- *** Forward substitution
         triLowerSolve,
         -- *** Backward substitution
         triUpperSolve,
         -- * Eigensolvers
         eigsQR,
         -- eigRayleigh,
         eigsArnoldi,
         -- * Matrix factorization algorithms
         -- ** QR
         qr,
         -- ** LU
         lu,
         -- ** Cholesky
         chol, 
         -- ** Arnoldi iteration
         arnoldi,    
         -- * Matrix partitioning
         diagPartitions,
         -- * Utilities
         -- ** Givens' rotation
         givens,
         -- ** Condition number
         conditionNumberSM,
         -- ** Householder reflection
         hhRefl,
         -- -- * Householder bidiagonalization         
         -- -- * Random arrays
         -- randArray,
         -- -- * Random matrices and vectors
         -- randMat, randVec, 
         -- -- ** Sparse "
         -- randSpMat, randSpVec,
         -- * Creation and conversion of sparse data
         -- ** From/to SpVector
         fromListSV, toListSV,
         -- ** From/to SpMatrix
         fromListSM, toListSM,
         -- ** Packing and unpacking, rows or columns of a sparse matrix
         -- *** ", using lists as container
         fromRowsL, toRowsL,
         fromColsL, toColsL,
         -- *** ", using Vector as container
         fromRowsV, fromColsV,
         -- * Operators
         (.*), (./), 
         -- ** Inner product
         (<.>),
         -- ** Matrix-vector product
         (#>), (<#),
         -- ** Matrix-matrix product
         (##), (#^#), (##^),
         -- *** Sparsifying matrix-matrix product
         (#~#), (#~^#), (#~#^),
         -- ** Vector outer product
         (><),
         -- * Common operations
         dim, nnz, spy,
         -- ** Vector-related
         cvx,
         -- *** Norms and normalization
         norm, norm2, norm2', normalize, normalize2, normalize2',
         norm1, hilbertDistSq,
         -- ** Matrix-related
         transpose, normFrobenius,
         -- * Pretty-printing
         prd, prd0,
         -- * Iteration combinators
         untilConvergedG0, untilConvergedG, untilConvergedGM,
         modifyInspectGuarded, modifyInspectGuardedM, IterationConfig (..),
         modifyUntil, modifyUntilM,
         -- * Internal
         linSolve0, LinSolveMethod(..)
       )
       where


import Control.Exception.Common
import Control.Iterative
import Data.Sparse.Common

import Control.Monad.Catch
import Data.Typeable

-- import Control.Applicative ((<|>))

-- import Control.Monad (replicateM)
import Control.Monad.State.Strict
-- import Control.Monad.Writer
-- import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.State  as MTS -- (runStateT)
-- import Control.Monad.Trans.Writer (runWriterT)
import Data.Complex

import Data.VectorSpace hiding (magnitude)

import qualified Data.Sparse.Internal.IntM as I
-- import Data.Utils.StrictFold (foldlStrict) -- hidden in `containers`

-- import qualified System.Random.MWC as MWC
-- import qualified System.Random.MWC.Distributions as MWC

-- import Data.Monoid
-- import qualified Data.Foldable as F
-- import qualified Data.Traversable as T

-- import qualified Data.List as L

import Data.Maybe

import qualified Data.Vector as V



-- | A lumped constraint for numerical types
type Num' x = (Epsilon x, Elt x, Show x, Ord x, Typeable x)





-- * Matrix condition number

-- | Matrix condition number: computes the QR factorization and extracts the extremal eigenvalues from the R factor
conditionNumberSM :: (MonadThrow m, MatrixRing (SpMatrix a), Num' a, Typeable a) =>
     SpMatrix a -> m a
conditionNumberSM m = do
  (_, r) <- qr m
  let
   u = extractDiagDense r 
   lmax = abs (maximum u)
   lmin = abs (minimum u)
   kappa = lmax / lmin                     
  if nearZero lmin
  then throwM (HugeConditionNumber "conditionNumberSM" kappa)
  else return kappa 
                          





-- * Householder transformation

hhMat :: Num a => a -> SpVector a -> SpMatrix a
hhMat beta x = eye n ^-^ beta `scale` (x >< x) where
  n = dim x


-- | Householder reflection: a vector `x` uniquely defines an orthogonal (hyper)plane, i.e. an orthogonal subspace; the Householder operator reflects any point `v` through this subspace: v' = (I - 2 x >< x) v
hhRefl :: Num a => SpVector a -> SpMatrix a
hhRefl = hhMat 2











-- * Givens rotation matrix

-- -- -- | Givens coefficients (using stable algorithm shown in  Anderson, Edward (4 December 2000). "Discontinuous Plane Rotations and the Symmetric Eigenvalue Problem". LAPACK Working Note)
-- -- -- givensCoef0 :: (Ord b, Floating a, Eq a) => (a -> b) -> a -> a -> (a, a, a)
-- -- givensCoef0
-- --   :: (Signum t, Ord a, Floating t, Eq t) =>
-- --      (t -> a) -> t -> t -> (t, t, t)
-- givensCoef0 ff a b  -- returns (c, s, r) where r = norm (a, b)
--   | b==0 = (signum' a, 0, abs a)
--   | a==0 = (0, signum' b, abs b)
--   | ff a > ff b = let t = b/a
--                       u = signum' a * abs ( sqrt (1 + t**2))
--                   in (1/u, - t/u, a*u)
--   | otherwise = let t = a/b
--                     u = signum' b * abs ( sqrt (1 + t**2))
--                 in (t/u, - 1/u, b*u)

-- -- | Givens coefficients, real-valued
-- givensCoef :: (Floating a, Ord a) => a -> a -> (a, a, a)
-- givensCoef = givensCoef0 abs

-- -- | Givens coefficients, complex-valued
-- givensCoefC :: RealFloat a => Complex a -> Complex a -> (Complex a, Complex a, Complex a)
-- givensCoefC = givensCoef0 magnitude



                  
{- |
Givens method, row version: choose other row index i' s.t. i' is :
* below the diagonal
* corresponding element is nonzero

QR.C1 ) To zero out entry A(i, j) we must find row k such that A(k, j) is
non-zero but A has zeros in row k for all columns less than j.

NB: the current version is quite inefficient in that:
1. the Givens' matrix `G_i` is different from Identity only in 4 entries
2. at each iteration `i` we multiply `G_i` by the previous partial result `M`. Since this corresponds to a rotation, and the `givensCoef` function already computes the value of the resulting non-zero component (output `r`), `G_i ## M` can be simplified by just changing two entries of `M` (i.e. zeroing one out and changing the other into `r`).
-}
{-# inline givens #-}
givens :: (Elt a, MonadThrow m) => SpMatrix a -> Int -> Int -> m (SpMatrix a)
givens aa i j 
  | isValidIxSM aa (i,j) && nrows aa >= ncols aa = do
      i' <- candidateRows' (immSM aa) i j
      return $ givensMat aa i i' j
  | otherwise = throwM (OOBIxsError "Givens" [i, j])
  where
  givensMat mm i i' j =
    fromListSM'
      [(i,i, c), (j,j, conj c), (j,i, - conj s), (i,j, s)]
      (eye (nrows mm))
           where
             (c, s, _) = givensCoef a b
             a = mm @@ (i', j)
             b = mm @@ (i, j)   -- element to zero out
  candidateRows' mm i j | null u = throwM (OOBNoCompatRows "Givens" (i,j))
                        | otherwise = return $ head (I.keys u) where
    u = I.filterWithKey (\irow row -> irow /= i &&
                                      firstNZColumn row j) mm
    firstNZColumn m k = isJust (I.lookup k m) &&
                        isNothing (I.lookupLT k m)

rotMat :: Elt e => e -> e -> IxRow -> IxRow -> Int -> SpMatrix e
rotMat a b i j n =
  fromListSM' [(i,i, c), (j,j, conj c), (j,i, - conj s), (i,j, s)] (eye n)
    where
      (c, s, _) = givensCoef a b
    

-- | Givens coefficients and norm of associated vector
givensCoef :: Elt t => t -> t -> (t, t, t)
givensCoef a b = (c0/r, s0/r, r) where
  c0 = conj a
  s0 = - conj b
  r = hypot c0 s0  
  hypot x y = abs x * sqrt (1 + (y/x)**2)

{-

   ( c    s )
G =(        )
   ( -s*  c*)

-}


                          




-- * QR decomposition

-- | Given a matrix A, returns a pair of matrices (Q, R) such that Q R = A, where Q is orthogonal and R is upper triangular. Applies Givens rotation iteratively to zero out sub-diagonal elements.
{-# inline qr #-}
qr :: (Elt a, MatrixRing (SpMatrix a), Epsilon a, MonadThrow m) =>
     SpMatrix a
     -> m (SpMatrix a, SpMatrix a)  -- ^ Q, R
qr mm = do 
     (qt, r, _) <- MTS.execStateT (modifyUntilM haltf qrstepf) gminit
     return (transpose qt, r) 
  where
    gminit = (eye (nrows mm), mm, subdiagIndicesSM mm)
    haltf (_, _, iis) = null iis
    qrstepf (qmatt, m, iis) = do
        let (i, j) = head iis
        g <- givens m i j
        let
          qmatt' = g #~# qmatt  -- update Q'
          m' = g #~# m          -- update R
        return (qmatt', m', tail iis)

  








-- * Eigenvalue algorithms

-- ** QR algorithm

-- | `eigsQR n mm` performs `n` iterations of the QR algorithm on matrix `mm`, and returns a SpVector containing all eigenvalues
eigsQR :: (MonadThrow m, MonadIO m, Num' a, Normed (SpVector a), MatrixRing (SpMatrix a), Typeable (Magnitude (SpVector a))) =>
        Int
     -> Bool           -- ^ Print debug information        
     -> SpMatrix a     -- ^ Operand matrix
     -> m (SpVector a) -- ^ Eigenvalues {λ_i}
eigsQR nitermax debq m = pf <$> untilConvergedGM "eigsQR" c (const True) stepf m
  where
    pf = extractDiagDense
    c = IterConf nitermax debq pf prd
    stepf mm = do
      (q, _) <- qr mm
      return $ q #~^# (m ## q) -- r #~# q



  




-- ** Rayleigh iteration

-- | `eigsRayleigh n mm` performs `n` iterations of the Rayleigh algorithm on matrix `mm` and returns the eigenpair closest to the initialization. It displays cubic-order convergence, but it also requires an educated guess on the initial eigenpair.

eigRayleigh nitermax debq prntf m = untilConvergedGM "eigRayleigh" config (const True) (rayStep m)
  where
    ii = eye (nrows m)
    config = IterConf nitermax debq fst prntf
    rayStep aa (b, mu) = do
      nom <- (m ^-^ (mu `matScale` ii)) <\> b
      let b' = normalize2' nom
          mu' = (b' <.> (aa #> b')) / (b' <.> b')
      return (b', mu')


-- | `eigsArnoldi n aa b` computes at most n iterations of the Arnoldi algorithm to find a Krylov subspace of (A, b), denoted Q, along with a Hessenberg matrix of coefficients H. After that, it computes the QR decomposition of H, denoted (O, R) and the eigenvalues {λ_i} of A are listed on the diagonal of the R factor.
eigsArnoldi :: (Scalar (SpVector t) ~ t, MatrixType (SpVector t) ~ SpMatrix t,
      Elt t, V (SpVector t), MatrixRing (SpMatrix t), Epsilon t, MonadThrow m) =>
     Int
     -> SpMatrix t
     -> SpVector t
     -> m (SpMatrix t, SpMatrix t, SpVector t) -- ^ Q, O, {λ_i}
eigsArnoldi nitermax aa b = do
  (q, h) <- arnoldi aa b nitermax
  (o, r) <- qr h
  return (q, o, extractDiagDense r)
  



-- * Householder vector 

-- (Golub & Van Loan, Alg. 5.1.1, function `house`)
hhV :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t), Epsilon t) =>
     SpVector t -> (SpVector t, t)
hhV x = (v, beta) where
  tx = tailSV x
  sigma = tx <.> tx 
  vtemp = singletonSV 1 `concatSV` tx
  (v, beta) | nearZero sigma = (vtemp, 0)
            | otherwise = let mu = sqrt (headSV x**2 + sigma)
                              xh = headSV x
                              vh | mag xh <= 1 = xh - mu
                                 | otherwise = - sigma / (xh + mu)
                              vnew = (1 / vh) `scale` insertSpVector 0 vh vtemp     
                          in (vnew, 2 * xh**2 / (sigma + vh**2))

                         



-- -- * Bidiagonalization



-- -- * SVD

{- Golub & Van Loan, sec 8.6.2 (p 452 segg.)

SVD of A, Golub-Kahan method

* reduce A to upper bidiagonal form B (Alg. 5.4.2, Householder bidiagonalization)
* compute SVD of B (implicit-shift QR step applied to B^T B, Alg. 8.3.2)

-}








-- * Cholesky factorization

-- | Given a positive semidefinite matrix A, returns a lower-triangular matrix L such that L L^T = A . This is an implementation of the Cholesky–Banachiewicz algorithm, i.e. proceeding row by row from the upper-left corner.
-- | NB: The algorithm throws an exception if some diagonal element of A is zero.
chol :: (Elt a, Epsilon a, MonadThrow m) =>
        SpMatrix a
     -> m (SpMatrix a)  -- ^ L
chol aa = do
  let n = nrows aa
      q (i, _) = i == n
  l0 <- cholUpd aa (0, zeroSM n n)
  (_, lfin) <- MTS.execStateT (modifyUntilM q (cholUpd aa)) l0
  return lfin
  where
   oops i = throwM (NeedsPivoting "chol" (unwords ["L", show (i,i)]) :: MatrixException Double)
   cholUpd aa (i, ll) = do 
     sd <- cholSDRowUpd aa ll i  -- update subdiagonal entries
     ll' <- cholDiagUpd aa sd i  -- update diagonal entries
     return (i + 1, ll')
   cholSDRowUpd aa ll i = do 
     lrs <- fromListSV (i + 1) <$> onRangeSparseM cholSubDiag [0 .. i-1]
     return $ insertRow ll lrs i where
       cholSubDiag j | isNz ljj = return $ 1/ljj*(aij - inn)
                     | otherwise = oops j
        where
          ljj = ll @@! (j, j)
          aij = aa @@! (i, j)
          inn = contractSub ll ll i j (j - 1)
   cholDiagUpd aa ll i = do
     cd <- cholDiag 
     return $ insertSpMatrix i i cd ll where
       cholDiag | i == 0 = sqrt <$> aai
                | otherwise = do
                    a <- aai
                    let l = sum (fmap (**2) lrow)
                    return $ sqrt (a - l)
         where
          lrow = ifilterSV (\j _ -> j < i) (extractRow ll i) -- sub-diagonal elems of L
          aai | isNz aaii = return aaii
              | otherwise = oops i
            where
             aaii = aa @@! (i,i)
                 
















-- * LU factorization

-- | Given a matrix A, returns a pair of matrices (L, U) where L is lower triangular and U is upper triangular such that L U = A . Implements the Doolittle algorithm, which sets the diagonal of the L matrix to ones and expects all diagonal entries of A to be nonzero. Apply pivoting (row or column permutation) to enforce a nonzero diagonal of the A matrix (the algorithm throws an appropriate exception otherwise).
lu :: (Scalar (SpVector t) ~ t, Elt t, VectorSpace (SpVector t), Epsilon t,
        MonadThrow m) =>
     SpMatrix t
     -> m (SpMatrix t, SpMatrix t) -- ^ L, U
lu aa = do
  let oops j = throwM (NeedsPivoting "solveForLij" ("U" ++ show (j, j)) :: MatrixException Double)
      n = nrows aa
      q (i, _, _) = i == n - 1
      luInit | isNz u00 = return (1, l0, u0)
             | otherwise = oops (0 :: Int)
        where
          l0 = insertCol (eye n) ((extractSubCol aa 0 (1, n-1)) ./ u00 ) 0
          u0 = insertRow (zeroSM n n) (extractRow aa 0) 0   -- initial U
          u00 = u0 @@! (0,0)  -- make sure this is non-zero by applying permutation
      luUpd (i, l, u) = do -- (i + 1, l', u') 
        u' <- uUpd aa n (i, l, u)  -- update U
        l' <- lUpd (i, l, u') -- update L
        return (i + 1, l', u')
      uUpd aa n (ix, lmat, umat) = do
        let us = onRangeSparse (solveForUij ix) [ix .. n - 1]
            solveForUij i j = a - p where
              a = aa @@! (i, j)
              p = contractSub lmat umat i j (i - 1)
        return $ insertRow umat (fromListSV n us) ix
      lUpd (ix, lmat, umat) = do -- insertCol lmat (fromListSV n ls) ix
        ls <- lsm
        return $ insertCol lmat (fromListSV n ls) ix
        where
          lsm = onRangeSparseM (`solveForLij` ix) [ix + 1 .. n - 1]
          solveForLij i j
            | isNz ujj = return $ (a - p)/ujj
            | otherwise = oops j
            where
             a = aa @@! (i, j)
             ujj = umat @@! (j , j)   -- NB this must be /= 0
             p = contractSub lmat umat i j (i - 1)    
  s0 <- luInit
  (ixf, lf, uf) <- MTS.execStateT (modifyUntilM q luUpd) s0
  ufin <- uUpd aa n (ixf, lf, uf)   -- final U update
  return (lf, ufin)
    






         







-- -- Produces the permutation matrix necessary to have a nonzero in position (iref, jref). This is used in the LU factorization
-- permutAA :: Num b => IxRow -> IxCol -> SpMatrix a -> Maybe (SpMatrix b)
-- permutAA iref jref (SM (nro,_) mm) 
--   | isJust (lookupIM2 iref jref mm) = Nothing -- eye nro
--   | otherwise = Just $ permutationSM nro [head u] where
--       u = IM.keys (ifilterIM2 ff mm)
--       ff i j _ = i /= iref &&
--                  j == jref
















-- * Arnoldi iteration

-- | Given a matrix A, a vector b and a positive integer `n`, this procedure finds the basis of an order `n` Krylov subspace (as the columns of matrix Q), along with an upper Hessenberg matrix H, such that A = Q^T H Q.
-- At the i`th iteration, it finds (i + 1) coefficients (the i`th column of the Hessenberg matrix H) and the (i + 1)`th Krylov vector.
arnoldi :: (MatrixType (SpVector a) ~ SpMatrix a, V (SpVector a) ,
            Scalar (SpVector a) ~ a, Epsilon a, MonadThrow m) =>
     SpMatrix a                    -- ^ System matrix
     -> SpVector a                 -- ^ r.h.s.
     -> Int                        -- ^ Max. # of iterations
     -> m (SpMatrix a, SpMatrix a) -- ^ Q, H
arnoldi aa b kn | n == nb = return (fromColsV qvfin, fromListSM (nmax + 1, nmax) hhfin)
                | otherwise = throwM (MatVecSizeMismatchException "arnoldi" (m,n) nb)
  where
  (qvfin, hhfin, nmax, _) = execState (modifyUntil tf arnoldiStep) arnInit 
  tf (_, _, ii, fbreak) = ii == kn || fbreak -- termination criterion
  (m, n) = (nrows aa, ncols aa)
  nb = dim b
  arnInit = (qv1, hh1, 1, False) where
      q0 = normalize2 b   -- starting basis vector
      aq0 = aa #> q0       -- A q0
      h11 = q0 `dot` aq0          
      q1nn = aq0 ^-^ (h11 .* q0)
      hh1 = V.fromList [(0, 0, h11), (1, 0, h21)] where        
        h21 = norm2' q1nn
      q1 = normalize2 q1nn       -- q1 `dot` q0 ~ 0
      qv1 = V.fromList [q0, q1]
  arnoldiStep (qv, hh, i, _) = (qv', hh', i + 1, breakf) where
    qi = V.last qv
    aqi = aa #> qi
    hhcoli = fmap (`dot` aqi) qv -- H_{1, i}, H_{2, i}, .. , H_{m + 1, i}
    zv = zeroSV m
    qipnn =
      aqi ^-^ V.foldl' (^+^) zv (V.zipWith (.*) hhcoli qv) -- unnormalized q_{i+1}
    qipnorm = norm2' qipnn      -- normalization factor H_{i+1, i}
    qip = normalize2 qipnn              -- q_{i + 1}
    hh' = (V.++) hh (indexed2 $ V.snoc hhcoli qipnorm) where -- update H
      indexed2 v = V.zip3 ii jj v
      ii = V.fromList [0 .. n]    -- nth col of upper Hessenberg has `n+1` nz
      jj = V.replicate (n + 1) i  -- `n+1` replicas of `i`
    qv' = V.snoc qv qip        -- append q_{i+1} to Krylov basis Q_i
    breakf | nearZero qipnorm = True  -- breakdown condition
           | otherwise = False







-- * Preconditioning

-- | Partition a matrix into strictly subdiagonal, diagonal and strictly superdiagonal parts
diagPartitions :: SpMatrix a
               -> (SpMatrix a, SpMatrix a, SpMatrix a) -- ^ Subdiagonal, diagonal, superdiagonal partitions
diagPartitions aa = (e,d,f) where
  e = extractSubDiag aa
  d = extractDiag aa
  f = extractSuperDiag aa


-- ** Jacobi preconditioner

-- | The Jacobi preconditioner is just the reciprocal of the diagonal 
jacobiPre :: Fractional a => SpMatrix a -> SpMatrix a
jacobiPre x = recip <$> extractDiag x


-- ** Incomplete LU

-- | Used for Incomplete LU : remove entries in the output matrix corresponding to zero entries in the input matrix (this is called ILU(0) in the preconditioner literature)
ilu0 :: (Scalar (SpVector t) ~ t, Elt t, VectorSpace (SpVector t),
      Epsilon t, MonadThrow m) =>
     SpMatrix t
     -> m (SpMatrix t, SpMatrix t) -- ^ L, U (with holes)
ilu0 aa = do
  (l, u) <- lu aa
  let lh = sparsifyLU l aa
      uh = sparsifyLU u aa
      sparsifyLU m m2 = ifilterSM f m where
        f i j _ = isJust (lookupSM m2 i j)
  return (lh, uh)


-- ** SSOR

-- | Symmetric Successive Over-Relaxation. `mSsor aa omega` : if `omega = 1` it returns the symmetric Gauss-Seidel preconditioner. When ω = 1, the SOR reduces to Gauss-Seidel; when ω > 1 and ω < 1, it corresponds to over-relaxation and under-relaxation, respectively.
mSsor :: (MatrixRing (SpMatrix b), Fractional b) =>
     SpMatrix b
     -> b  -- ^ relaxation factor
     -> (SpMatrix b, SpMatrix b) -- ^ Left, right factors
mSsor aa omega = (l, r) where
  (e, d, f) = diagPartitions aa
  n = nrows e
  l = (eye n ^-^ scale omega e) ## reciprocal d
  r = d ^-^ scale omega f 








-- * Linear solver, LU-based

-- | Direct solver based on a triangular factorization of the system matrix.
luSolve :: (Scalar (SpVector t) ~ t, MonadThrow m, Elt t, InnerSpace (SpVector t), Epsilon t) =>
     SpMatrix t    -- ^ Lower triangular
     -> SpMatrix t -- ^ Upper triangular
     -> SpVector t -- ^ r.h.s.
     -> m (SpVector t)
luSolve ll uu b
  | isLowerTriSM ll && isUpperTriSM uu = do
      w <- triLowerSolve ll b
      triUpperSolve uu w
  | otherwise = throwM (NonTriangularException "luSolve")

-- | Forward substitution solver
triLowerSolve :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
      Epsilon t, MonadThrow m) =>
     SpMatrix t    -- ^ Lower triangular
     -> SpVector t -- ^ r.h.s
     -> m (SpVector t)
triLowerSolve ll b = do
  let q (_, i) = i == nb
      nb = svDim b
      oops i = throwM (NeedsPivoting "triLowerSolve" (unwords ["L", show (i, i)]) :: MatrixException Double)
      lStep (ww, i) = do -- (ww', i + 1) where
        let
          lii = ll @@ (i, i)
          bi = b @@ i
          wi = (bi - r)/lii where
            r = extractSubRow ll i (0, i-1) `dot` takeSV i ww
        if isNz lii
          then return (insertSpVector i wi ww, i + 1)
          else oops i
      lInit = do -- (ww0, 1) where
        let
          l00 = ll @@ (0, 0)
          b0 = b @@ 0
          w0 = b0 / l00
        if isNz l00
          then return (insertSpVector 0 w0 $ zeroSV (dim b), 1)
          else oops (0 :: Int)
  l0 <- lInit             
  (v, _) <- MTS.execStateT (modifyUntilM q lStep) l0
  return $ sparsifySV v



-- NB in the computation of `xi` we must rebalance the subrow indices (extractSubRow_RK) because `dropSV` does that too, in order to take the inner product with consistent index pairs
-- | Backward substitution solver
triUpperSolve :: (Scalar (SpVector t) ~ t, Elt t, InnerSpace (SpVector t),
      Epsilon t, MonadThrow m) =>
     SpMatrix t    -- ^ Upper triangular
     -> SpVector t -- ^ r.h.s
     -> m (SpVector t)
triUpperSolve uu w = do 
  let q (_, i) = i == (- 1)
      nw = svDim w
      oops i = throwM (NeedsPivoting "triUpperSolve" (unwords ["U", show (i, i)]) :: MatrixException Double)
      uStep (xx, i) = do
        let uii = uu @@ (i, i) 
            wi = w @@ i
            r = extractSubRow_RK uu i (i + 1, nw - 1) `dot` dropSV (i + 1) xx  
            xi = (wi - r) / uii
        if isNz uii
          then return (insertSpVector i xi xx, i - 1)
          else oops i 
      uInit = do 
        let i = nw - 1
            u00 = uu @@! (i, i)
            w0 = w @@ i
            x0 = w0 / u00
        if isNz u00
          then return (insertSpVector i x0 (zeroSV nw), i - 1)
          else oops (0 :: Int)
  u0 <- uInit             
  (x, _) <- MTS.execStateT (modifyUntilM q uStep) u0
  return $ sparsifySV x      





    





-- * Iterative linear solvers


-- ** GMRES

-- | Given a linear system `A x = b` where `A` is an (m x m) real-valued matrix, the GMRES method finds an approximate solution `xhat` such that the Euclidean norm of the residual `A xhat - b` is minimized. `xhat` is spanned by the order-`n` Krylov subspace of (A, b).
-- In this implementation:
-- 1) the Arnoldi iteration is carried out until numerical breakdown (therefore yielding _at_most_ `m+1` Krylov basis vectors)
-- 2) the resulting Hessenberg matrix H is factorized in QR form (H = Q R)
-- 3) the Krylov-subspace solution `yhat` is found by backsubstitution (since R is upper-triangular)
-- 4) the approximate solution in the original space `xhat` is computed using the Krylov basis, `xhat = Q_n yhat`
--
-- Many optimizations are possible, for example interleaving the QR factorization (and the subsequent triangular solve) with the Arnoldi process (and employing an updating QR factorization which only requires one Givens' rotation at every update). 

gmres :: (Scalar (SpVector t) ~ t, MatrixType (SpVector t) ~ SpMatrix t,
      Elt t, Normed (SpVector t), LinearVectorSpace (SpVector t), Epsilon t,
      MonadThrow m) =>
     SpMatrix t -> SpVector t -> m (SpVector t)
gmres aa b = do
  let m = ncols aa
  (qa, ha) <- arnoldi aa b m   -- at most m steps of Arnoldi (aa, b)
    -- b' = (transposeSe qa) #> b
  let b' = norm2' b .* ei mp1 1  -- b rotated back to canonical basis by Q^T
        where mp1 = nrows ha     -- = 1 + (# Arnoldi iterations)
  (qh, rh) <- qr ha
  let rhs' = takeSV (dim b' - 1) (transpose qh #> b')
      rh' = takeRows (nrows rh - 1) rh -- last row of `rh` is 0
  yhat <- triUpperSolve rh' rhs'
  let qa' = takeCols (ncols qa - 1) qa  -- we don't use last column of Krylov basis   
  return $ qa' #> yhat



  
-- ** CGNE

data CGNE a =
  CGNE {_xCgne , _rCgne, _pCgne :: SpVector a} deriving Eq
instance Show a => Show (CGNE a) where
    show (CGNE x r p) = "x = " ++ show x ++ "\n" ++
                        "r = " ++ show r ++ "\n" ++
                        "p = " ++ show p ++ "\n"

cgneInit :: (MatrixType (SpVector a) ~ SpMatrix a,
      LinearVectorSpace (SpVector a)) =>
     SpMatrix a -> SpVector a -> SpVector a -> CGNE a
cgneInit aa b x0 = CGNE x0 r0 p0 where
  r0 = b ^-^ (aa #> x0)    -- residual of initial guess solution
  p0 = transposeSM aa #> r0

cgneStep :: (MatrixType (SpVector a) ~ SpMatrix a,
      LinearVectorSpace (SpVector a), InnerSpace (SpVector a),
      Fractional (Scalar (SpVector a))) =>
     SpMatrix a -> CGNE a -> CGNE a
cgneStep aa (CGNE x r p) = CGNE x1 r1 p1 where
    alphai = (r `dot` r) / (p `dot` p)
    x1 = x ^+^ (alphai .* p)
    r1 = r ^-^ (alphai .* (aa #> p))
    beta = (r1 `dot` r1) / (r `dot` r)
    p1 = transpose aa #> r ^+^ (beta .* p)





-- ** BCG

data BCG a =
  BCG { _xBcg, _rBcg, _rHatBcg, _pBcg, _pHatBcg :: SpVector a } deriving Eq

bcgInit :: LinearVectorSpace (SpVector a) =>
     MatrixType (SpVector a) -> SpVector a -> SpVector a -> BCG a
bcgInit aa b x0 = BCG x0 r0 r0hat p0 p0hat where
  r0 = b ^-^ (aa #> x0)    
  r0hat = r0
  p0 = r0
  p0hat = r0

bcgStep :: (MatrixType (SpVector a) ~ SpMatrix a,
      LinearVectorSpace (SpVector a), InnerSpace (SpVector a),
      Fractional (Scalar (SpVector a))) =>
     SpMatrix a -> BCG a -> BCG a 
bcgStep aa (BCG x r rhat p phat) = BCG x1 r1 rhat1 p1 phat1 where
    aap = aa #> p
    alpha = (r `dot` rhat) / (aap `dot` phat)
    x1 = x ^+^ (alpha .* p)
    r1 = r ^-^ (alpha .* aap)
    rhat1 = rhat ^-^ (alpha .* (transpose aa #> phat))
    beta = (r1 `dot` rhat1) / (r `dot` rhat)
    p1 = r1 ^+^ (beta .* p)
    phat1 = rhat1 ^+^ (beta .* phat)

instance Show a => Show (BCG a) where
  show (BCG x r rhat p phat) = "x = " ++ show x ++ "\n" ++
                       "r = " ++ show r ++ "\n" ++
                       "r_hat = " ++ show rhat ++ "\n" ++
                       "p = " ++ show p ++ "\n" ++
                       "p_hat = " ++ show phat ++ "\n"


-- ** CGS

data CGS a = CGS { _x, _r, _p, _u :: SpVector a} deriving Eq

cgsInit :: LinearVectorSpace (SpVector a) =>
     MatrixType (SpVector a) -> SpVector a -> SpVector a -> CGS a
cgsInit aa b x0 = CGS x0 r0 r0 r0 where
  r0 = b ^-^ (aa #> x0)    -- residual of initial guess solution

cgsStep :: (V (SpVector a), Fractional (Scalar (SpVector a))) =>
     MatrixType (SpVector a) -> SpVector a -> CGS a -> CGS a
cgsStep aa rhat (CGS x r p u) = CGS xj1 rj1 pj1 uj1
    where
    aap = aa #> p
    alphaj = (r `dot` rhat) / (aap `dot` rhat)
    q = u ^-^ (alphaj .* aap)
    xj1 = x ^+^ (alphaj .* (u ^+^ q))         -- updated solution
    rj1 = r ^-^ (alphaj .* (aa #> (u ^+^ q))) -- updated residual
    betaj = (rj1 `dot` rhat) / (r `dot` rhat)
    uj1 = rj1 ^+^ (betaj .* q)
    pj1 = uj1 ^+^ (betaj .* (q ^+^ (betaj .* p)))


instance (Show a) => Show (CGS a) where
  show (CGS x r p u) = "x = " ++ show x ++ "\n" ++
                                "r = " ++ show r ++ "\n" ++
                                "p = " ++ show p ++ "\n" ++
                                "u = " ++ show u ++ "\n"





-- ** BiCGSTAB

data BICGSTAB a =
  BICGSTAB { _xBicgstab, _rBicgstab, _pBicgstab :: SpVector a} deriving Eq

bicgsInit :: LinearVectorSpace (SpVector a) =>
     MatrixType (SpVector a) -> SpVector a -> SpVector a -> BICGSTAB a
bicgsInit aa b x0 = BICGSTAB x0 r0 r0 where
  r0 = b ^-^ (aa #> x0)   -- residual of initial guess solution

bicgstabStep :: (V (SpVector a), Fractional (Scalar (SpVector a))) =>
     MatrixType (SpVector a) -> SpVector a -> BICGSTAB a -> BICGSTAB a
bicgstabStep aa r0hat (BICGSTAB x r p) = BICGSTAB xj1 rj1 pj1 where
     aap = aa #> p
     alphaj = (r <.> r0hat) / (aap <.> r0hat)
     sj = r ^-^ (alphaj .* aap)
     aasj = aa #> sj
     omegaj = (aasj <.> sj) / (aasj <.> aasj)
     xj1 = x ^+^ (alphaj .* p) ^+^ (omegaj .* sj)    -- updated solution
     rj1 = sj ^-^ (omegaj .* aasj)
     betaj = (rj1 <.> r0hat)/(r <.> r0hat) * alphaj / omegaj
     pj1 = rj1 ^+^ (betaj .* (p ^-^ (omegaj .* aap)))

instance Show a => Show (BICGSTAB a) where
  show (BICGSTAB x r p) = "x = " ++ show x ++ "\n" ++
                          "r = " ++ show r ++ "\n" ++
                          "p = " ++ show p ++ "\n"






-- * Moore-Penrose pseudoinverse
-- | Least-squares approximation of a rectangular system of equations.
pinv :: (MatrixType v ~ SpMatrix a, LinearSystem v, Epsilon a,
         MonadThrow m, MonadIO m) =>
     SpMatrix a -> v -> m v
pinv aa b = aa #~^# aa <\> atb where
  atb = transpose aa #> b




-- | Interface method to individual linear solvers, do not use directly
linSolve0 method aa b x0
  | m /= nb = throwM (MatVecSizeMismatchException "linSolve0" dm nb)
  | otherwise = solve aa b where
     solve aa' b' | isDiagonalSM aa' = return $ reciprocal aa' #> b' -- diagonal solve
                  | otherwise = xHat
     xHat = case method of
       BICGSTAB_ -> solver "BICGSTAB" nits _xBicgstab (bicgstabStep aa r0hat) (bicgsInit aa b x0)
       BCG_ -> solver "BCG" nits _xBcg (bcgStep aa) (bcgInit aa b x0)
       CGS_ -> solver "CGS" nits _x  (cgsStep aa r0hat) (cgsInit aa b x0)
       GMRES_ -> gmres aa b
       CGNE_ -> solver "CGNE" nits _xCgne (cgneStep aa) (cgneInit aa b x0)
     r0hat = b ^-^ (aa #> x0)
     nits = 200
     dm@(m,n) = dim aa
     nb = dim b
     solver fname nitermax fproj stepf initf = do
      xf <- untilConvergedG fname config (const True) stepf initf
      return $ fproj xf
      where
        config = IterConf nitermax True fproj prd0
  

-- * Linear solver interface

-- -- TFQMR is in LinearSolvers.Experimental for now
data LinSolveMethod = GMRES_ | CGNE_ | BCG_ | CGS_ | BICGSTAB_ deriving (Eq, Show) 




-- -- -- | linSolve using the GMRES method as default
instance LinearSystem (SpVector Double) where
  aa <\> b = linSolve0 GMRES_ aa b (mkSpVR n $ replicate n 0.1)
    where n = ncols aa

instance LinearSystem (SpVector (Complex Double)) where
  aa <\> b = linSolve0 GMRES_ aa b (mkSpVC n $ replicate n 0.1)
    where n = ncols aa











-- | TODO : if system is poorly conditioned, is it better to warn the user or just switch solvers (e.g. via the pseudoinverse) ?

-- linSolveQR aa b init f1 stepf
--   | isInfinite k = do
--        tell "linSolveQR : rank-deficient system"
--   | otherwise = do
--        solv aa b init
--     where
--      (q, r) = qr aa
--      k = conditionNumberSM r
--      solv aa b init = execState (untilConverged f1 stepf) init



  





-- test data

-- aa4 : eigenvalues 1 (mult.=2) and -1
aa4 :: SpMatrix Double
aa4 = fromListDenseSM 3 [3,2,-2,2,2,-1,6,5,-4] 

aa4c :: SpMatrix (Complex Double)
aa4c = toC <$> aa4








-- -- * Random arrays

-- randArray :: PrimMonad m => Int -> Double -> Double -> m [Double]
-- randArray n mu sig = do
--   g <- MWC.create
--   replicateM n (MWC.normal mu sig g)
  



-- -- * Random matrices and vectors

-- -- |Dense SpMatrix
-- randMat :: PrimMonad m => Int -> m (SpMatrix Double)
-- randMat n = do
--   g <- MWC.create
--   aav <- replicateM (n^2) (MWC.normal 0 1 g)
--   let ii_ = [0 .. n-1]
--       (ix_,iy_) = unzip $ concatMap (zip ii_ . replicate n) ii_
--   return $ fromListSM (n,n) $ zip3 ix_ iy_ aav

-- -- | Dense SpVector  
-- randVec :: PrimMonad m => Int -> m (SpVector Double)
-- randVec n = do
--   g <- MWC.create
--   bv <- replicateM n (MWC.normal 0 1 g)
--   let ii_ = [0..n-1]
--   return $ fromListSV n $ zip ii_ bv



-- -- | Sparse SpMatrix
-- randSpMat :: Int -> Int -> IO (SpMatrix Double)
-- randSpMat n nsp | nsp > n^2 = error "randSpMat : nsp must be < n^2 "
--                 | otherwise = do
--   g <- MWC.create
--   aav <- replicateM nsp (MWC.normal 0 1 g)
--   ii <- replicateM nsp (MWC.uniformR (0, n-1) g :: IO Int)
--   jj <- replicateM nsp (MWC.uniformR (0, n-1) g :: IO Int)
--   return $ fromListSM (n,n) $ zip3 ii jj aav

-- -- | Sparse SpVector
-- randSpVec :: Int -> Int -> IO (SpVector Double)
-- randSpVec n nsp | nsp > n = error "randSpVec : nsp must be < n"
--                 | otherwise = do
--   g <- MWC.create
--   aav <- replicateM nsp (MWC.normal 0 1 g)
--   ii <- replicateM nsp (MWC.uniformR (0, n-1) g :: IO Int)
--   return $ fromListSV n $ zip ii aav