{- Math.Clustering.Hierarchical.Spectral.Load
Gregory W. Schwartz

Collects the functions pertaining to loading a matrix.
-}

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

module Math.Clustering.Hierarchical.Spectral.Load
    ( readDenseAdjMatrix
    , readSparseAdjMatrix
    -- , readEigenSparseAdjMatrix
    ) where

-- Remote
import Control.Monad.Except (runExceptT, ExceptT (..))
import Control.Monad.Managed (with, liftIO, Managed (..))
import Data.Maybe (fromMaybe, catMaybes)
import System.IO (Handle (..))
import qualified Data.ByteString.Streaming.Char8 as BS
import qualified Data.Csv as CSV
-- import qualified Data.Eigen.SparseMatrix as E
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import qualified Data.Sparse.Common as SH
import qualified Data.Text as T
import qualified Data.Vector as V
import qualified Numeric.LinearAlgebra as H
import qualified Streaming as S
import qualified Streaming.Cassava as S
import qualified Streaming.Prelude as S
import qualified Streaming.With.Lifted as SW

-- Local
import Math.Clustering.Hierarchical.Spectral.Types

-- | Generic error message.
errorMsg = error "Not correct format (requires row,column,value)"

-- | Parse a row of a label index file.
parseRow :: (T.Text, T.Text, Double) -> ((T.Text, T.Text), Double)
parseRow (i, j, v) = ((i, j), v)

-- | Ignore the disconnected vertices, not used (rather use very small weight).
ignoreDisconnected :: V.Vector T.Text
                   -> H.Matrix Double
                   -> (V.Vector T.Text, H.Matrix Double)
ignoreDisconnected items mat = (newItems, newMat)
  where
    newItems = V.fromList $ fmap ((V.!) items) valid
    newMat = mat H.?? (H.Pos $ H.idxs valid, H.Pos $ H.idxs valid)
    valid = catMaybes
          . zipWith (\x xs -> if sum xs > 0 then Just x else Nothing) [0..]
          . H.toLists
          $ mat

-- | Ensure symmetry.
symmetric :: [((Int, Int), Double)] -> [((Int, Int), Double)]
symmetric = concatMap (\((!i, !j), v) -> [((i, j), v), ((j, i), v)])

-- | Ensure zeros on diagonal.
zeroDiag :: [((Int, Int), Double)] -> [((Int, Int), Double)]
zeroDiag = filter (\((!i, !j), _) -> i /= j)

-- | Get the translated matrix indices.
getNewIndices
    -- :: (Eq a, Ord a)
    -- => [((a, a), Double)] -> [((Int, Int), Double)]
    :: [((T.Text, T.Text), Double)] -> [((Int, Int), Double)]
getNewIndices xs =
    fmap
        (\((!i,!j),!v) ->
              ( ( Map.findWithDefault eMsg i idxMap
                , Map.findWithDefault eMsg j idxMap
                )
              , v
              )
        )
        xs
  where
    eMsg     = error "Index not found during index conversion."
    indices  = getAllIndices xs
    idxMap   = Map.fromList $ zip indices [0 ..]

-- | Get the list of all indices.
getAllIndices :: (Eq a, Ord a) => [((a, a), Double)] -> [a]
getAllIndices xs = Set.toAscList . Set.union (getSet fst) $ getSet snd
  where
    getSet f = Set.fromList . fmap (f . fst) $ xs

-- | Get a dense adjacency matrix from a handle.
readDenseAdjMatrix :: CSV.DecodeOptions
                   -> Handle
                   -> IO (V.Vector T.Text, H.Matrix Double)
readDenseAdjMatrix decodeOpt handle = flip with return $ do
    let getAssocList = S.toList_ . S.map parseRow

    assocList <-
        fmap (either (error . show) id)
            . runExceptT
            . getAssocList
            . S.decodeWith decodeOpt S.NoHeader
            $ (BS.hGetContents handle :: BS.ByteString (ExceptT S.CsvParseException Managed) ())

    let items = V.fromList $ getAllIndices assocList
        mat   = H.assoc (V.length items, V.length items) 0
              . Set.toList
              . Set.fromList -- Ensure no duplicates.
              . symmetric -- Ensure symmetry.
              . zeroDiag -- Ensure zeros on diagonal.
              . getNewIndices -- Only look at present rows by converting indices.
              $ assocList

    return (items, mat)

-- | Get a sparse adjacency matrix from a handle.
readSparseAdjMatrix :: CSV.DecodeOptions
                    -> Handle
                    -> IO (V.Vector T.Text, SH.SpMatrix Double)
readSparseAdjMatrix decodeOpt handle = flip with return $ do
    let getAssocList = S.toList_ . S.map parseRow

    assocList <-
        fmap (either (error . show) id)
            . runExceptT
            . getAssocList
            . S.decodeWith decodeOpt S.NoHeader
            $ (BS.hGetContents handle :: BS.ByteString (ExceptT S.CsvParseException Managed) ())

    let items = V.fromList $ getAllIndices assocList
        mat   = SH.fromListSM (V.length items, V.length items)
              . Set.toList
              . Set.fromList -- Ensure no duplicates.
              . fmap (\((i, j), v) -> (i, j, v))
              . symmetric -- Ensure symmetry.
              . zeroDiag -- Ensure zeros on diagonal.
              . getNewIndices -- Only look at present rows by converting indices.
              $ assocList

    return (items, mat)

-- -- | Get a sparse adjacency matrix from a handle.
-- readEigenSparseAdjMatrix :: CSV.DecodeOptions
--                     -> Handle
--                     -> IO (V.Vector T.Text, E.SparseMatrixXd)
-- readEigenSparseAdjMatrix decodeOpt handle = flip with return $ do
--     let getAssocList = S.toList_ . S.map parseRow

--     assocList <-
--         fmap (either (error . show) id)
--             . runExceptT
--             . getAssocList
--             . S.decodeWith decodeOpt S.NoHeader
--             $ (BS.hGetContents handle :: BS.ByteString (ExceptT S.CsvParseException Managed) ())

--     let items = V.fromList $ getAllIndices assocList
--         mat   = E.fromList (V.length items) (V.length items)
--               . Set.toList
--               . Set.fromList -- Ensure no duplicates.
--               . fmap (\((i, j), v) -> (i, j, v))
--               . symmetric -- Ensure symmetry.
--               . zeroDiag -- Ensure zeros on diagonal.
--               . getNewIndices -- Only look at present rows by converting indices.
--               $ assocList

--     return (items, mat)