{-# language TypeFamilies #-}
{-# language TypeOperators, GADTs #-}
{-# language FlexibleInstances, FlexibleContexts #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   :  (C) 2016 Marco Zocca
-- License     :  GPL-3 (see LICENSE)
-- Maintainer  :  zocca.marco gmail
-- Stability   :  provisional
-- Portability :  portable
--
-----------------------------------------------------------------------------
module Data.Sparse.Common
       ( module X,
         insertRowWith, insertRow, insertColWith, insertCol,
         diagonalSM,
         outerProdSV, (><), toSV, svToSM,
         lookupRowSM, 
         extractCol, extractRow,
         extractVectorDenseWith, extractRowDense, extractColDense,
         extractDiagDense,
         extractSubRow, extractSubCol,
         extractSubRow_RK, extractSubCol_RK,
         fromRowsL, fromRowsV, fromColsV, fromColsL, toRowsL, toColsL) where

-- import Control.Exception
-- import Control.Exception.Common
-- import Control.Monad.Catch

import Data.Sparse.Utils as X
import Data.Sparse.PPrint as X
import Data.Sparse.Types as X
import Data.Sparse.Internal.IntMap2 -- as X
import qualified Data.Sparse.Internal.IntM as I
import Data.Sparse.Internal.IntM (IntM(..))
import Data.Sparse.SpMatrix as X
import Data.Sparse.SpVector as X
-- import Data.Sparse.Internal.CSR as X

import Numeric.Eps as X
import Numeric.LinearAlgebra.Class as X

import qualified Data.IntMap.Strict as IM

import GHC.Exts
import Data.Complex

-- import Control.Applicative
-- import Data.Traversable

import Data.Maybe (fromMaybe, maybe)
import qualified Data.Vector as V








-- withBoundsSM m ij e f
--   | isValidIxSM m ij = f m ij
--   | otherwise = error e

-- | Overwrite the nominal dimension of a 'SpVector', leaving its stored
-- entries untouched. No check is made that the existing keys fit within the
-- new dimension. Internal helper: do not use directly.
resizeSV :: Int -> SpVector a -> SpVector a
resizeSV newDim v = case v of
  SV _ entries -> SV newDim entries
-- | Apply a function to every key of a 'SpVector'; the nominal dimension is
-- kept as is. Internal helper: do not use directly.
mapKeysSV :: (IM.Key -> IM.Key) -> SpVector a -> SpVector a
mapKeysSV rekey (SV dimV entries) = SV dimV (I.mapKeys rekey entries)

-- * Insert row/column vector in matrix

-- | Insert a vector as row @i@ of a matrix, after remapping the vector's keys
-- (which become column indices) with the supplied function.
--
-- If row @i@ already holds entries, the remapped vector is merged with them
-- via @I.union@, with the new entries passed as the first argument.
--
-- Calls 'error' when @i@ is not a valid row index, or when the vector
-- dimension exceeds the number of matrix columns.
insertRowWith :: (IxCol -> IxCol) -> SpMatrix a -> SpVector a -> IM.Key -> SpMatrix a
insertRowWith fj (SM (m, n) im) (SV d sv) i
  | not (inBounds0 m i) = error "insertRowSM : index out of bounds"
  | n < d = error $ "insertRowSM : incompatible dimensions " ++ show (n, d)
  | otherwise = SM (m, n) (I.insert i mergedRow im)
  where
    -- the incoming entries, relabelled with the column-index transformation
    newRow = I.mapKeys fj sv
    -- combine with whatever is already stored at row i, if anything
    mergedRow = case I.lookup i im of
      Nothing -> newRow
      Just oldRow -> I.union newRow oldRow

-- | Insert a vector as a row of a matrix, keeping the vector's keys as the
-- column indices ('insertRowWith' with the identity transformation).
insertRow :: SpMatrix a -> SpVector a -> IM.Key -> SpMatrix a
insertRow sm sv i = insertRowWith id sm sv i

-- | Insert a vector as column @j@ of a matrix, after remapping the vector's
-- keys (which become row indices) with the supplied function. Entries are
-- inserted one at a time with 'insertSpMatrix', in the order given by
-- 'toListSV'.
--
-- Calls 'error' when @j@ is not a valid column index, or when the vector
-- dimension exceeds the number of matrix rows.
insertColWith :: (IxRow -> IxRow) -> SpMatrix a -> SpVector a -> IxCol -> SpMatrix a
insertColWith fi smm sv j
  | not (inBounds0 ncolsM j) = error "insertColSM : index out of bounds"
  | nrowsM < lenV = error $ "insertColSM : incompatible dimensions " ++ show (nrowsM, lenV)
  | otherwise = foldl step smm (toListSV sv)
  where
    (nrowsM, ncolsM) = dim smm
    lenV = dim sv
    -- place one vector component at (transformed row, j)
    step acc (i, x) = insertSpMatrix (fi i) j x acc

-- | Insert a vector as a column of a matrix, keeping the vector's keys as the
-- row indices ('insertColWith' with the identity transformation).
insertCol :: SpMatrix a -> SpVector a -> IxCol -> SpMatrix a
insertCol sm sv j = insertColWith id sm sv j



-- * Outer vector product

-- | Outer product of two sparse vectors: the result has dimensions
-- @(dim v1, dim v2)@ and is built from the products of every pair of entries
-- listed by 'toListSV'. '(><)' is an operator synonym.
outerProdSV, (><) :: Num a => SpVector a -> SpVector a -> SpMatrix a
outerProdSV u v = fromListSM (dim u, dim v) entries
  where
    us = toListSV u
    vs = toListSV v
    -- for each entry of u, pair it with every entry of v
    entries = concatMap (\(i, x) -> map (\(j, y) -> (i, j, x * y)) vs) us

(><) = outerProdSV



-- * Diagonal matrix

-- | Build a square matrix, of side equal to the vector dimension, whose
-- diagonal holds the stored components of the given 'SpVector'.
diagonalSM :: SpVector a -> SpMatrix a
diagonalSM sv = ifoldSV placeOnDiag emptySquare sv
  where
    side = dim sv
    emptySquare = zeroSM side side
    -- component k of the vector goes to position (k, k)
    placeOnDiag k x acc = insertSpMatrix k k x acc



-- * Matrix-vector conversions

-- | Promote a 'SpVector' of dimension @n@ to an @(n, 1)@ column 'SpMatrix'.
--
-- Component @i@ of the vector is stored at matrix position @(i, 0)@: the
-- outer map is keyed by row, so each vector entry becomes its own
-- single-entry row. (Storing the whole vector under row key 0 would describe
-- a @(1, n)@ row vector and contradict the declared dimensions.)
svToSM :: SpVector a -> SpMatrix a
svToSM (SV n d) = SM (n, 1) (fmap (I.singleton 0) d)

-- | Demote an @(n x 1)@ or @(1 x n)@ 'SpMatrix' to a 'SpVector' of dimension
-- @n@.
--
-- * Row vector (fewer rows than columns): the first stored row becomes the
--   vector; if no row is stored at all the result is the zero vector.
-- * Otherwise: every stored row contributes its first stored entry, keyed by
--   the row index.
--
-- Calls 'error' when the matrix is neither a row nor a column vector, or when
-- a stored row of a column vector has no entries.
toSV :: SpMatrix a -> SpVector a
toSV (SM (m, n) im)
  | m < n = case toList im of
      (_, row) : _ -> SV d row
      -- no stored rows: an all-zero row vector, not a crash
      []           -> zeroSV d
  | otherwise = SV d (fmap (firstEntry . toList) im)
  where
    firstEntry ((_, x) : _) = x
    firstEntry []           = error "toSV : matrix row without stored entries"
    d | m==1 && n==1 = 1
      | m==1 && n>1 = n
      | n==1 && m>1 = m
      | otherwise = error $ "toSV : incompatible matrix dimension " ++ show (m,n)





-- | Look up row @i@ in the matrix data. Returns 'Nothing' when no row is
-- stored under that key; otherwise the row wrapped as a 'SpVector' whose
-- dimension is the number of matrix columns.
lookupRowSM :: SpMatrix a -> IxRow -> Maybe (SpVector a)
lookupRowSM sm i = case I.lookup i (dat sm) of
  Nothing -> Nothing
  Just row -> Just (SV (ncols sm) row)


-- * Extract a SpVector from an SpMatrix
-- ** Sparse extract

-- | Extract the i-th row as a 'SpVector'. A row with no stored data yields
-- the zero vector of length @ncols m@; an out-of-bounds row index calls
-- 'error'.
extractRow :: SpMatrix a -> IxRow -> SpVector a
extractRow m i
  | not (inBounds0 (nrows m) i) =
      error $ unwords ["extractRow : index", show i, "out of bounds"]
  | otherwise = case lookupRowSM m i of
      Just row -> row
      Nothing -> zeroSV (ncols m)

-- | Extract the j-th column as a 'SpVector', by taking the column submatrix
-- with 'extractColSM' and demoting it with 'toSV'.
extractCol :: SpMatrix a -> IxCol -> SpVector a
extractCol m = toSV . extractColSM m



-- ** Dense extract (default == 0)
-- | Generic dense extraction: for every @i@ in @[0 .. ncols - 1]@, look up
-- the matrix element at position @f i@ and pack the results, in order, into
-- a vector with 'fromListDenseSV'.
--
-- Note that the number of extracted elements is always the number of
-- /columns/ of the matrix, whatever the index function does.
extractVectorDenseWith ::
  Num a => (Int -> (IxRow, IxCol)) -> SpMatrix a -> SpVector a
extractVectorDenseWith f mm = fromListDenseSV len [mm @@ f k | k <- [0 .. len - 1]]
  where
    len = snd (dim mm)

-- | Extract the i-th row as a dense-built vector: every column position of
-- row @iref@ is looked up through 'extractVectorDenseWith'.
extractRowDense :: Num a => SpMatrix a -> IxRow -> SpVector a
extractRowDense mm iref = extractVectorDenseWith alongRow mm
  where
    alongRow j = (iref, j)

-- | Extract the j-th column as a dense-built vector.
--
-- The result has one component per matrix /row/: positions
-- @(0, jref) .. (nrows - 1, jref)@ are looked up in order. The length is
-- taken from the row count so that non-square matrices are neither truncated
-- nor indexed past their last row.
extractColDense :: Num a => SpMatrix a -> IxCol -> SpVector a
extractColDense mm jref = fromListDenseSV m [mm @@ (i, jref) | i <- [0 .. m - 1]]
  where
    m = nrows mm

-- | Extract the main diagonal as a dense-built vector.
--
-- The result has @min nrows ncols@ components, i.e. positions
-- @(0, 0) .. (k - 1, k - 1)@, so a matrix with fewer rows than columns is
-- never indexed past its last row.
extractDiagDense :: Num a => SpMatrix a -> SpVector a
extractDiagDense mm = fromListDenseSV k [mm @@ (i, i) | i <- [0 .. k - 1]]
  where
    k = min (nrows mm) (ncols mm)



-- | extract row interval (all entries between columns j1 and j2, INCLUDED, are returned)
-- extractSubRow :: SpMatrix a -> IxRow -> (IxCol, IxCol) -> SpVector a
-- extractSubRow m i (j1, j2) = case lookupRowSM m i of
--   Nothing -> zeroSV (ncols m)
--   Just rv -> ifilterSV (\j _ -> j >= j1 && j <= j2) rv

-- |", returning in Maybe
-- extractSubRow :: SpMatrix a -> IxRow -> (Int, Int) -> Maybe (SpVector a)
-- extractSubRow m i (j1, j2) =
--   resizeSV (j2 - j1) . ifilterSV (\j _ -> j >= j1 && j <= j2) <$> lookupRowSM m i

-- | Extract the components of row @i@ whose column index lies in the closed
-- interval @[j1, j2]@. The resulting 'SpVector' has dimension @j2 - j1 + 1@,
-- but its keys are /not/ rebalanced: components keep the column indices they
-- had in the source matrix. A row with no stored data yields the zero vector
-- of that dimension.
extractSubRow :: SpMatrix a -> IxRow -> (Int, Int) -> SpVector a
extractSubRow m i (j1, j2) =
  case lookupRowSM m i of
    Nothing -> zeroSV width
    Just row -> resizeSV width (ifilterSV inInterval row)
  where
    width = j2 - j1 + 1
    -- keep only the keys inside [j1, j2], both ends included
    inInterval j _ = j >= j1 && j <= j2

-- | Like 'extractSubRow', but with rebalanced keys: the lower bound @j1@ is
-- subtracted from every key of the extracted interval.
extractSubRow_RK :: SpMatrix a -> IxRow -> (IxCol, IxCol) -> SpVector a
extractSubRow_RK m i (j1, j2) = mapKeysSV shift (extractSubRow m i (j1, j2))
  where
    shift j = j - j1

  -- toSV $ extractSubRowSM_RK m i (j1, j2)



-- | Extract the interval @(i1, i2)@ of rows from column @j@, by demoting the
-- result of 'extractSubColSM' with 'toSV'.
extractSubCol :: SpMatrix a -> IxCol -> (IxRow, IxRow) -> SpVector a
extractSubCol m j rowRange@(_, _) = toSV (extractSubColSM m j rowRange)

-- | Extract the interval @(i1, i2)@ of rows from column @j@ with rebalanced
-- keys, by demoting the result of 'extractSubColSM_RK' with 'toSV'.
extractSubCol_RK :: SpMatrix a -> IxCol -> (IxRow, IxRow) -> SpVector a
extractSubCol_RK m j rowRange@(_, _) = toSV (extractSubColSM_RK m j rowRange)




-- ** Matrix action on a vector

{- 
FIXME : matVec is more general than SpVector's :

\m v -> fmap (`dot` v) m
  :: (Normed f1, Num b, Functor f) => f (f1 b) -> f1 b -> f b
-}

-- | Sparse vectors are acted upon by sparse matrices: '#>' is the
-- matrix-on-vector action ('matVecSD') and '<#' the vector-on-matrix action
-- ('vecMatSD').
instance (InnerSpace t, Scalar t ~ t) => LinearVectorSpace (SpVector t) where
  type MatrixType (SpVector t) = SpMatrix t
  mat #> vec = matVecSD mat vec
  vec <# mat = vecMatSD vec mat

-- | Matrix-on-vector action: every stored matrix row is replaced by its
-- 'dot' product with the vector data. The result has dimension equal to the
-- number of matrix rows. Calls 'error' when the number of matrix columns
-- differs from the vector dimension.
matVecSD :: (InnerSpace t, Scalar t ~ t) => SpMatrix t -> SpVector t -> SpVector t
matVecSD (SM (nr, nc) mdata) (SV n sv)
  | nc /= n = error $ "matVec : mismatched dimensions " ++ show (nc, n)
  | otherwise = SV nr (fmap rowDot mdata)
  where
    rowDot row = row `dot` sv

-- | Vector-on-matrix action: the matrix data is transposed with
-- 'transposeIM2' and every resulting row is replaced by its 'dot' product
-- with the vector data. The result has dimension equal to the number of
-- matrix columns. Calls 'error' when the vector dimension differs from the
-- number of matrix rows.
--
-- FIXME : transposes the matrix, so it is probably more costly than
-- 'matVecSD'.
vecMatSD :: (InnerSpace t, Scalar t ~ t) => SpVector t -> SpMatrix t -> SpVector t
vecMatSD (SV n sv) (SM (nr, nc) mdata)
  | n /= nr = error $ "vecMat : mismatching dimensions " ++ show (n, nr)
  | otherwise = SV nc (fmap colDot (transposeIM2 mdata))
  where
    colDot col = col `dot` sv





-- -- generalized matVec : we require a function `rowsf` that produces a functor of elements of a Hilbert space (the rows of `m`)
-- matVecG :: (Hilbert v, Functor f, f (Scalar v) ~ v) => (m -> f v) -> m -> v -> v
-- matVecG rowsf m v = fmap (`dot` v) (rowsf m)

-- matVecGA
--   :: (Hilbert v, Traversable t, t (Scalar v) ~ v) =>
--      (m -> t v) -> m -> v -> v
-- matVecGA rowsf m v = traverse (<.> v) (rowsf m)

-- -- --  Really, a matrix is just notation for a linear map between two finite-dimensional Hilbert spaces, i.e.
-- matVec :: (Hilbert u, Hilbert v) => (u -> v) -> u -> v
-- which is a specialization of a function application operator like ($) :: (a -> b) -> a -> b




-- -- --  from `vector-space`

-- data a -* b where
--   Dot :: VectorSpace b => b -> (b -* Scalar b)
--   (:&&) :: (a -* c) -> (a -* d) -> (a -* (c, d)) -- a,c,d should be constrained

-- apply :: Hilbert a => (a -* b) -> (a -> b)
-- apply (Dot b) = dot b
-- apply (f :&& g) = apply f &&& apply g
--   where (u &&& v) a = (u a, v a) -- (&&&) from Control.Arrow

-- -- type a :~ b = Scalar a ~ Scalar b






-- | Pack a list of 'SpVector's as the rows of an 'SpMatrix' (converts the
-- list to a @V.Vector@ and defers to 'fromRowsV').
fromRowsL :: [SpVector a] -> SpMatrix a
fromRowsL rows = fromRowsV (V.fromList rows)

-- | Pack a list of 'SpVector's as the columns of an 'SpMatrix' (converts the
-- list to a @V.Vector@ and defers to 'fromColsV').
fromColsL :: [SpVector a] -> SpMatrix a
fromColsL cols = fromColsV (V.fromList cols)

-- | Unpack an 'SpMatrix' into the list of its rows, from row 0 to the last
-- one, each obtained with 'extractRow'.
toRowsL :: SpMatrix a -> [SpVector a]
toRowsL aa = [extractRow aa i | i <- [0 .. m - 1]]
  where
    (m, _) = dim aa

-- | Unpack an 'SpMatrix' into the list of its columns, from column 0 to the
-- last one, each obtained with 'extractCol'.
toColsL :: SpMatrix a -> [SpVector a]
toColsL aa = [extractCol aa j | j <- [0 .. n - 1]]
  where
    (_, n) = dim aa




-- | Pack a @V.Vector@ of 'SpVector's as the columns of an 'SpMatrix'.
--
-- The number of rows is the dimension of the first vector and the number of
-- columns is the length of the input; vector number @i@ is inserted as
-- column @i@ with 'insertCol'. An empty input yields the @0 x 0@ matrix
-- instead of failing on @V.head@.
fromColsV :: V.Vector (SpVector a) -> SpMatrix a
fromColsV qv
  | V.null qv = zeroSM 0 0
  | otherwise = V.ifoldl' ins (zeroSM m n) qv
  where
    n = V.length qv
    -- only evaluated in the non-empty branch, so V.head is safe here
    m = dim $ V.head qv
    ins mm i c = insertCol mm c i


-- | Pack a @V.Vector@ of 'SpVector's as the rows of an 'SpMatrix'.
--
-- The number of rows is the length of the input and the number of columns is
-- the dimension of the first vector; vector number @i@ is inserted as row
-- @i@ with 'insertRow'. An empty input yields the @0 x 0@ matrix instead of
-- failing on @V.head@.
fromRowsV :: V.Vector (SpVector a) -> SpMatrix a
fromRowsV qv
  | V.null qv = zeroSM 0 0
  | otherwise = V.ifoldl' ins (zeroSM nr nc) qv
  where
    -- one matrix row per input vector ...
    nr = V.length qv
    -- ... and as many columns as the vectors have components
    -- (only evaluated in the non-empty branch, so V.head is safe here)
    nc = svDim $ V.head qv
    ins mm i c = insertRow mm c i





-- * Pretty printing


-- | Render a value for printing: values for which 'nearZero' holds are shown
-- as the placeholder @" _ "@, everything else with 'show'.
showNz :: (Epsilon a, Show a) => a -> String
showNz x = if nearZero x then " _ " else show x

-- | List the elements of matrix row @irow@ at every column position, from
-- column 0 to the last one, looking each of them up with '@@'.
toDenseRow :: Num a => SpMatrix a -> IM.Key -> [a]
toDenseRow sm irow = [sm @@ (irow, icol) | icol <- [0 .. ncols sm - 1]]



-- | Pretty-printing options used for real-valued data.
-- NOTE(review): the meaning of the three 'PPOpts' fields is defined where
-- 'PPOpts' is declared — confirm there before changing these numbers.
prdR :: PPrintOptions
prdR = PPOpts 4 2 7

-- | Pretty-printing options used for complex-valued data; differs from
-- 'prdR' only in the third field.
prdC :: PPrintOptions
prdC = PPOpts 4 2 16



-- -- printDenseSM :: (Show t, Num t) => SpMatrix t -> IO ()
-- printDenseSM :: (ScIx c ~ (Int, Int), FDSize c ~ (Int, Int), SpContainer c a, Show a, Epsilon a) => c a -> IO ()
-- | Print a matrix preceded by its size string: blank line, the output of
-- 'sizeStrSM', blank line, then the matrix body via 'prd0'.
printDenseSM sm =
  newline
    >> putStrLn (sizeStrSM sm)
    >> newline
    >> prd0 sm



-- | Print the rows of a matrix, one line each, followed by a blank line.
--
-- The first argument renders one row: it receives the matrix, the row index
-- and the column cap (fixed at 5 here). When the matrix has more than 5
-- rows, only the first 3 are printed, then an ellipsis line, then the last
-- row.
printDenseSM0
  :: (SpMatrix a -> IxRow -> Int -> String) -> SpMatrix a -> IO ()
printDenseSM0 f sm = mapM_ putStrLn shownRows >> newline
  where
    maxRows = 5
    maxCols = 5
    nr = nrows sm
    allRows = [f sm i maxCols | i <- [0 .. nr - 1]]
    shownRows
      | nr <= maxRows = allRows
      | otherwise = take (maxRows - 2) allRows ++ [" ... ", last allRows]

-- | Print the body of a real-valued matrix: each row is densified with
-- 'toDenseRow' and rendered by 'printDN' using the 'prdR' options.
printDenseSM0r :: SpMatrix Double -> IO ()
printDenseSM0r = printDenseSM0 showRow
  where
    showRow mat irow ncolmax =
      printDN (ncols mat) ncolmax prdR (toDenseRow mat irow)
-- | Print the body of a complex-valued matrix: each row is densified with
-- 'toDenseRow' and rendered by 'printCN' using the 'prdC' options.
printDenseSM0c :: SpMatrix (Complex Double) -> IO ()
printDenseSM0c = printDenseSM0 showRow
  where
    showRow mat irow ncolmax =
      printCN (ncols mat) ncolmax prdC (toDenseRow mat irow)


-- printDenseSV :: (Show t, Epsilon t) => SpVector t -> IO ()
-- | Print a vector preceded by its size string: blank line, the output of
-- 'sizeStrSV', blank line, then the vector body via 'prd0'.
printDenseSV :: PrintDense (SpVector a) => SpVector a -> IO ()
printDenseSV sv =
  newline
    >> putStrLn (sizeStrSV sv)
    >> newline
    >> prd0 sv

-- | Print the body of a real-valued vector, rendered by 'printDN' with the
-- 'prdR' options.
printDenseSV0r :: SpVector Double -> IO ()
printDenseSV0r sv = printDenseSV0 (\len cap xs -> printDN len cap prdR xs) sv
-- | Print the body of a complex-valued vector, rendered by 'printCN' with
-- the 'prdC' options.
printDenseSV0c :: SpVector (Complex Double) -> IO ()
printDenseSV0c sv = printDenseSV0 (\len cap xs -> printCN len cap prdC xs) sv

-- | Print the body of a vector on one line, followed by a blank line.
--
-- The first argument renders the line: it receives the vector dimension, the
-- element cap (fixed at 5 here) and the dense list of components produced by
-- 'toDenseListSV'.
printDenseSV0 :: Num a =>
  (Int -> Int -> [a] -> String) -> SpVector a -> IO ()
printDenseSV0 f sv = putStrLn rendered >> newline
  where
    rendered = f (svDim sv) 5 (toDenseListSV sv)


    
-- ** Pretty printer typeclass

-- | Real-valued vectors: 'prd' prints size header and body, 'prd0' the body
-- only.
instance PrintDense (SpVector Double) where
  prd v = printDenseSV v
  prd0 v = printDenseSV0r v

-- | Complex-valued vectors: 'prd' prints size header and body, 'prd0' the
-- body only.
instance PrintDense (SpVector (Complex Double)) where
  prd v = printDenseSV v
  prd0 v = printDenseSV0c v

-- | Real-valued matrices: 'prd' prints size header and body, 'prd0' the body
-- only.
instance PrintDense (SpMatrix Double) where
  prd m = printDenseSM m
  prd0 m = printDenseSM0r m

-- | Complex-valued matrices: 'prd' prints size header and body, 'prd0' the
-- body only.
instance PrintDense (SpMatrix (Complex Double)) where
  prd m = printDenseSM m
  prd0 m = printDenseSM0c m





  



-- instance (Elt a, Show a) => PrintDense (CsrMatrix a) where
--   prd = printDenseSM