{-# LANGUAGE FlexibleContexts #-}
{-# language TypeFamilies, FlexibleInstances, DeriveFunctor #-}
{-# language DeriveFoldable, DeriveTraversable #-}
module Data.Sparse.Internal.TriMatrix where

-- import qualified Data.Map.Strict as M
import qualified Data.IntMap.Strict as IM
import Data.IntMap.Strict ((!))
-- import qualified Data.Set as S
-- import qualified Data.Vector as V

import Data.Foldable (foldrM)
import Data.List (sort)
import Data.Maybe (fromMaybe)
-- import Data.Monoid
import Data.Complex
import qualified Data.Graph as G
import qualified Data.Tree as T

import Numeric.Eps
import Data.Sparse.Types
-- import Data.Sparse.Utils
import Data.Sparse.Internal.SList

import Numeric.LinearAlgebra.Class
-- import Data.Sparse.Internal.CSC
import Data.Sparse.Internal.SVector
import qualified Data.Sparse.Internal.SVector.Mutable as SMV
import Data.Sparse.SpMatrix (fromListSM, fromListDenseSM, insertSpMatrix, zeroSM, transposeSM, sparsifySM)
import Data.Sparse.Common (prd, prd0, (@@!), nrows, ncols, lookupSM, extractRow, extractCol, SpVector, SpMatrix, foldlWithKeySV, (##), (#~#))
import Control.Iterative (IterationConfig(IterConf), modifyUntilM, modifyUntilM')
import Data.Sparse.PPrint

import Control.Monad.Catch (MonadThrow, throwM)
import Control.Exception.Common

import Control.Monad (when)
import Control.Monad.IO.Class
import Control.Monad.Trans.State -- (execStateT, get, put, modify)
import Control.Monad.Primitive
-- import Control.Monad.State (MonadState())

import qualified Data.Vector as V (Vector, freeze, thaw, fromList, toList, zip)
import qualified Data.Vector.Mutable as VM (MVector, new, set, write, read, clone)



-- | Flatten a forest into one list: every tree is flattened in pre-order
-- (via 'T.flatten') and the per-tree results are concatenated left to right.
flattenForest :: T.Forest a -> [a]
flattenForest forest = [node | tree <- forest, node <- T.flatten tree]

-- | Given a lower triangular system L x = b, find the indices of the nonzero
-- entries of the solution vector x: the nodes reachable from the r.h.s.
-- indices @vs@ by a depth-first search on @G.transposeG g@ (the graph
-- G(L^T) when @g@ is G(L)).
--
-- 'G.dfs' visits each node at most once, so the result has no duplicates; the
-- indices are then sorted in ascending order. Cost: building the transposed
-- graph is O(V + E) over the whole of @g@, the DFS is linear in the explored
-- subgraph, and the final sort is O(N log N) in the number N of reachable
-- nodes — so the total is not O(N).
--
-- NOTE(review): every index in @vs@ must lie within the bounds of @g@, since
-- 'G.dfs' indexes the adjacency array directly — confirm at call sites.
reachableFromRHS :: G.Graph -> V.Vector Int -> V.Vector Int
reachableFromRHS g vs = V.fromList . sort . flattenForest $ G.dfs (G.transposeG g) (V.toList vs)



-- triLowerSolve ll b = do
--   initializeSoln xinz b
--   where
--   -- xinz : nonzeros of solution vector x obtained from reachable nodes of b via G(L^T)
--   xinz = reachableFromRHS (cscToGraph ll) (svIx b)

-- -- tlUpdateColumn lldiag llsubdiag x j = undefined where
-- --   xj' = xj / lldiag

-- | Build the initial (mutable, sparse) solution vector @x@ for the triangular
-- solve, from the index set @ixnz@ of its nonzero entries and the right-hand
-- side @bb@ of the linear system:
--
-- 1. allocate a mutable value buffer with one slot per index in @ixnz@ and
--    fill it with 0;
--
-- 2. copy the (index, value) entries of @bb@ into it via
--    'SMV.fromListOverwrite'.
--
-- Note: this assumes the index set of @bb@ is contained in @ixnz@ (which holds
-- for the triangular solve, see 'reachableFromRHS').
initializeSoln ::
  (PrimMonad m, Num a) => V.Vector Int -> SVector a -> m (SMV.SMVector m a)
initializeSoln ixnz (SV nb bIxs bVals) = do
  vals <- VM.new (length ixnz)
  VM.set vals 0
  ixsMut <- V.thaw ixnz
  SMV.fromListOverwrite ixsMut vals nb (V.toList (V.zip bIxs bVals))




-- | Compute @x - xj * lsubdiag@. Judging by the name, this is presumably the
-- column-update step of a lower-triangular solve (subtracting the scaled
-- subdiagonal part of column j once x_j is known) — confirm with callers.
tlUpdateSubdiag :: VectorSpace v => v -> Scalar v -> v -> v
tlUpdateSubdiag lsubdiag xj x = x ^-^ scaledCol
  where
    scaledCol = xj .* lsubdiag
   


  



-- | Example graph on vertices 0..2; every edge (i, j) has j <= i, i.e. it is
-- the sparsity pattern of a 3 x 3 lower-triangular matrix.
g0 :: G.Graph
g0 = G.buildG vertexRange edgeList
  where
    vertexRange = (0, 2)
    edgeList = [(0, 0), (2, 0), (1, 1), (2, 2)]

-- | DFS forest of the transpose of 'g0', started from vertex 0.
t0 :: T.Forest Int
t0 = G.dfs reversed [0]
  where
    reversed = G.transposeG g0

-- | Example graph on vertices 0..4; every edge (i, j) has j <= i, i.e. it is
-- the sparsity pattern of a 5 x 5 lower-triangular matrix.
g1 :: G.Graph
g1 = G.buildG vertexRange edgeList
  where
    vertexRange = (0, 4)
    edgeList =
      [ (0, 0)
      , (1, 0), (1, 1)
      , (2, 0), (2, 2)
      , (3, 1), (3, 2), (3, 3)
      , (4, 0), (4, 4)
      ]



{- | Triangular sparse matrix in row-major order, stored as an IntMap of
sparse lists ('SList'), one list per row. This layout gives:

* fast random access to rows
* fast consing of elements onto a row
-}
newtype TriMatrix a = TM
  { unTM :: IM.IntMap (SList a)
  }
  deriving (Show, Functor)

-- | An IntMap with keys @0 .. n-1@, each bound to an empty 'SList'
-- (the map is empty when @n <= 0@).
emptyIMSL :: Int -> IM.IntMap (SList a)
emptyIMSL n = IM.fromList (zip [0 .. n - 1] (repeat emptySL))

-- | A 'TriMatrix' with rows @0 .. n-1@, all of them empty.
emptyTM :: Int -> TriMatrix a
emptyTM = TM . emptyIMSL

-- | @appendIM i x im@ adds the entry @x@ to the 'SList' stored at key @i@ of
-- @im@ using 'consSL' (so it goes on the front of the list). If key @i@ is
-- absent, the list is started from 'emptySL'.
appendIM :: IM.Key -> (Int, a) -> IM.IntMap (SList a) -> IM.IntMap (SList a)
appendIM i x im = IM.alter (Just . consSL x . fromMaybe emptySL) i im


-- | Two-level lookup (row, then entry within the row) that yields 0 when
-- either level is missing.
lookupWD :: Num a =>
     (irow -> mat -> Maybe row)    -- ^ row lookup
  -> (jcol -> row -> Maybe a)      -- ^ in-row lookup
  -> mat
  -> irow
  -> jcol
  -> a
lookupWD rowLookup entryLookup aa i j =
  case rowLookup i aa >>= entryLookup j of
    Just v  -> v
    Nothing -> 0

 







{- | LU factorization without pivoting. L has a unit diagonal (Doolittle
variant).

Internally L is stored row-wise and U column-wise (i.e. as U^T), both as
IntMaps of sparse lists; both factors are converted to 'SpMatrix' at the end.

Throws 'NeedsPivoting' if any diagonal entry of U, including the first pivot
U_00 = A_00, is near zero according to 'nearZero'.

NOTE(review): the iteration runs for @nrows amat@ steps while U is indexed by
column, so this presumably requires a square input — confirm.
-}
lu :: (Scalar (SpVector t) ~ t, Elt t, AdditiveGroup t, VectorSpace (SpVector t),
      MonadThrow m, MonadIO m, PrintDense (SpMatrix t),
      Epsilon t) =>
     SpMatrix t -> m (SpMatrix t, SpMatrix t) -- ^ L, U
lu amat = do
  let d@(m, n) = (nrows amat, ncols amat)
      urow0 = extractRow amat 0   -- first row of U (equal to the first row of A)
      u00 = urow0 @@ 0            -- first pivot U00
  -- The first pivot is a divisor below (first column of L), so it must be
  -- checked like every later pivot in luStep; otherwise a zero A00 would
  -- silently divide by zero.
  when (nearZero u00) $
    throwM (NeedsPivoting "LU" (unwords ["U", show (0 :: Int, 0 :: Int)]) :: MatrixException Double)
  let q (_, _, i) = i == m        -- stopping criterion: all m rows processed
      luInit = (lmat0, umat0, 1) where
        lcol0 = extractCol amat 0 ./ u00                 -- first col of L, div by U00
        umat0 = foldlWithKeySV ins (emptyIMSL n) urow0   -- populate umat0
        lmat0 = IM.insert 0 (SL [(0, 1)]) l0 where       -- populate lmat0 (unit L00)
          l0 = foldlWithKeySV ins (emptyIMSL m) lcol0
        ins acc i x = appendIM i (0, x) acc
      luStep (lmat, umat, i) = do
        let (umat', uii) = uStep amat lmat umat i       -- new U
        when (nearZero uii) $
          throwM (NeedsPivoting "LU" (unwords ["U", show (i, i)]) :: MatrixException Double)
        let lmat' = lStep amat lmat umat' uii i         -- new L
        return (lmat', umat', i + 1)
  (lfin, ufin, _) <- modifyUntilM' (luConfig d) q luStep luInit
  let uu = fillSM d True ufin
      ll = fillSM d False lfin
  return (ll, uu)


-- | Iteration configuration used by 'lu' for 'modifyUntilM''. The first two
-- 'IterConf' fields are passed as @0@ and @True@ (see "Control.Iterative" for
-- their meaning). The view function keeps only the (L, U) pair of the
-- iteration state. The report function materializes L, and U (transposed back
-- from its column-wise storage), as @d@-sized matrices and passes each to
-- 'prd0'.
luConfig d = IterConf 0 True projectLU reportLU
  where
    projectLU (l, u, _) = (l, u)
    reportLU (l, u) = prd0 (fillSM d False l) >> prd0 (fillSM d True u)




-- | Row step @i@ of the LU factorization: compute row @i@ of U.
--
-- L is stored row-wise (@lmat ! i@ is row i of L) and U column-wise
-- (@umat ! j@ is column j of U), so U_ij is added to @umat ! j@ as @(i, U_ij)@.
-- For every column @j@ in @[i .. n-1]@:
--
-- > U_ij = A_ij - <L_i, U_j>
--
-- Off-diagonal entries that are zero according to 'isNz' are not stored. The
-- diagonal entry U_ii is always stored and is also returned, so the caller can
-- check it before dividing by it.
uStep :: (Elt a, Epsilon a, AdditiveGroup a) =>
     SpMatrix a
  -> IM.IntMap (SList a)
  -> IM.IntMap (SList a)
  -> IM.Key
  -> (IM.IntMap (SList a), a)   -- ^ updated U, i'th diagonal element Uii
uStep amat lmat umat i = (umatNew, uii)
  where
    ncol = ncols amat
    rowL = lmat ! i                                    -- row i of L
    uii = amat @@! (i, i) - (rowL <.> (umat ! i))      -- diagonal entry U_ii
    umatNew = foldr addEntry umat [i .. ncol - 1]
    addEntry j acc
      | j == i    = appendIM j (i, uii) acc
      | isNz uij  = appendIM j (i, uij) acc
      | otherwise = acc
      where
        aij = amat @@! (i, j)
        uij = aij - (rowL <.> (umat ! j))
  

-- | Column step @j@ of the LU factorization: compute column @j@ of L.
--
-- For every row @i@ in @[j .. m-1]@ the diagonal entry L_jj is set to 1, and
--
-- > L_ij = (A_ij - <L_i, U_j>) / U_jj
--
-- is added to row @i@ of L as @(j, L_ij)@ when it is nonzero according to
-- 'isNz'. @umat ! j@ must already hold column j of U (see 'uStep').
lStep :: (Elt a, Epsilon a, AdditiveGroup a) =>
     SpMatrix a
  -> IM.IntMap (SList a)
  -> IM.IntMap (SList a)
  -> a                   -- ^ diagonal element of U (must be nonzero)
  -> IM.Key
  -> IM.IntMap (SList a) -- ^ updated L
lStep amat lmat umat udiag j = foldr addEntry lmat [j .. nrow - 1]
  where
    nrow = nrows amat
    colU = umat ! j                                  -- column j of U
    addEntry i acc
      | i == j    = appendIM i (j, 1) acc            -- unit diagonal
      | isNz lij  = appendIM i (j, lij) acc
      | otherwise = acc
      where
        aij = amat @@! (i, j)
        rowL = lmat ! i
        lij = (aij - (rowL <.> colU)) / udiag




-- | Convert an IntMap-of-SLists factor into an m x n 'SpMatrix'.
--
-- Key @i@ of @tm@ holds an 'SList' of @(j, x)@ entries. When @transpq@ is
-- 'False', each entry is written at (i, j). When it is 'True', it is written
-- at (j, i), i.e. @tm@ holds the transpose of the result. This is how 'lu'
-- stores U.
fillSM :: (Rows, Cols) -> Bool -> IM.IntMap (SList a) -> SpMatrix a
fillSM (m, n) transpq tm = IM.foldlWithKey' rowIns (zeroSM m n) tm
  where
    -- Strict left fold over the rows: the lazy 'IM.foldlWithKey' would build
    -- a chain of pending insertions before the result is ever forced.
    rowIns accRow i row = foldr ins accRow (unSL row)
      where
        ins (j, x) acc
          | transpq   = insertSpMatrix j i x acc   -- transposed fill
          | otherwise = insertSpMatrix i j x acc



-- test data

-- | Debugging helper: factorize @mm@ with 'lu', then pass L, U, the input
-- matrix and the product @l #~# u@ to 'prd', so the product can be compared
-- with the input.
test mm = lu mm >>= \(l, u) ->
  prd l >> prd u >> prd mm >> prd (l #~# u)



-- Real-valued example matrices.
tm0, tm2, tm4, tm9 :: SpMatrix Double

-- 2 x 2, built from a dense list
tm0 = fromListDenseSM 2 [1, 3, 2, 4]

-- 3 x 3, built from a dense list
tm2 = fromListDenseSM 3
  [ 12, 6, -4
  , -51, 167, 24
  , 4, -68, -41
  ]

-- 4 x 4, built from a dense list and then sparsified
tm4 = sparsifySM (fromListDenseSM 4
  [ 1, 0, 0, 0
  , 2, 5, 0, 10
  , 3, 6, 8, 11
  , 4, 7, 9, 12
  ])

-- 4 x 4, built from (row, col, value) triples
tm9 = fromListSM (4, 4)
  [ (0, 0, pi)
  , (1, 1, 3)
  , (3, 0, 23)
  , (1, 3, 45)
  , (2, 2, 4)
  , (3, 2, 1)
  , (3, 1, 5)
  , (3, 3, exp 1)
  ]

-- | 3 x 3 complex-valued example matrix, built from a dense list.
tmc4 :: SpMatrix (Complex Double)
tmc4 = fromListDenseSM 3
  [ 3 :+ 1, 4 :+ (-1), (-5) :+ 3
  , 2 :+ 2, 3 :+ (-2), 5 :+ 0.2
  , 7 :+ (-2), 9 :+ (-1), 2 :+ 3
  ]




-- λ> test tmc4

-- 1.00            , _               , _               
-- 1.10 - 0.70i    , 1.00            , _               
-- -1.20 + 1.40i   , -0.56 + 1.04i   , 1.00            

-- 3.00 + 1.00i    , 2.00 + 2.00i    , 7.00 - 2.00i    
-- _               , 2.20 - 5.60i    , -0.10 - 3.70i   
-- _               , _               , 16.99 + 8.24i   







-- playground

-- | Binary tree with values at the branches (playground type). The derived
-- 'Foldable' instance visits the left subtree, the value, then the right
-- subtree.
data T a = Leaf | Branch (T a) a (T a)
  deriving (Show, Functor, Foldable, Traversable)

-- | Return the current counter value, then increment the counter.
next :: Monad m => StateT Int m Int
next = get <* modify (+ 1)

-- | Return the current state, then update it with @f@.
next' :: Monad m => (s -> s) -> StateT s m s
next' f = get <* modify f

-- | Pair @x@ with a fresh label taken from the counter (see 'next').
labelElt :: Monad m => x -> StateT Int m (Int, x)
labelElt x = (\lbl -> (lbl, x)) <$> next