{-# LANGUAGE FlexibleInstances, ScopedTypeVariables #-}

-- | Pure and composable Matrix Market reader and writer.
--
-- Usage example:
--
-- @
-- rm <- 'readMM' \`liftM\` readFile \"my-real-matrix.mtx\" :: IO ('ReadMatrix' Double)
-- case rm of
--   Right m -> -- Do something with the matrix m
--   Left err -> -- Report error
-- @
module Data.MatrixMarket
  ( -- * Data types
    Matrix(..)
  , MatrixData(..)
  , MValue(..)
  , MField(..)
  , Symmetry(..)
  , CM(..)
  , AM(..)
  , ReadError(..)
  , ReadMatrix
    -- * Read and write Matrix Market
  , readMM
  , dumpMM
    -- * Utility functions
  , mm'rows, mm'cols, mm'shape
  , toList
  , toCompleteList
  , toArray
  , toArrayM
  , at
    -- * Re-exports
  , Complex(..)
  )
  where

import Data.Array.IArray (IArray, array)
import Data.Array.MArray (MArray, newArray, writeArray)
import Data.Char (toLower)
import Data.Complex (Complex(..), conjugate)
import qualified Data.Map as Map
import Data.Maybe (listToMaybe)
import Control.Applicative ((<$>), (<*>))
import Control.Monad (join, forM_)

-- | In-memory representation of a Matrix Market file.
--
-- The type carries no datatype context: the @MValue@ requirement is
-- stated on the functions that need it, which avoids the deprecated
-- @DatatypeContexts@ extension.
data Matrix a
  = MM
  { mm'data :: MatrixData a   -- ^ the data block (sparse or dense)
  , mm'field :: MField        -- ^ field named in the header line
  , mm'symmetry :: Symmetry   -- ^ symmetry class named in the header line
  , mm'comments :: [String]   -- ^ comment lines, without the leading @%@
  } deriving (Show, Eq)

-- | Data block of a matrix: either sparse (coordinate) or dense (array).
--
-- No datatype context is used; functions that need @MValue@ state the
-- constraint themselves (the @DatatypeContexts@ extension is deprecated).
data MatrixData a
                = CoordinateM { coords'm :: CM a }  -- ^ coordinate (sparse) storage
                | ArrayM { array'm :: AM a }        -- ^ array (dense) storage
  deriving (Show, Eq)

-- | Row count of the matrix, taken from whichever data block it holds.
mm'rows :: MValue a => Matrix a -> Int
mm'rows (MM (CoordinateM sparse) _ _ _) = cm'rows sparse
mm'rows (MM (ArrayM dense) _ _ _)       = am'rows dense

-- | Column count of the matrix, taken from whichever data block it holds.
mm'cols :: MValue a => Matrix a -> Int
mm'cols MM { mm'data = block } = colsOf block
  where
    colsOf (CoordinateM sparse) = cm'cols sparse
    colsOf (ArrayM dense)       = am'cols dense

-- | Shape of the matrix as a pair @(rows, columns)@.
mm'shape :: MValue a => Matrix a -> (Int, Int)
mm'shape mat = (nrows, ncols)
  where
    nrows = mm'rows mat
    ncols = mm'cols mat

-- | Coordinate format (sparse matrix).
--
-- No datatype context is used; functions that need @MValue@ state the
-- constraint themselves (the @DatatypeContexts@ extension is deprecated).
data CM a = CM
  { cm'rows :: Int                  -- ^ number of rows
  , cm'cols :: Int                  -- ^ number of columns
  , cm'size :: Int                  -- ^ entry count as written on the size line
  , cm'values :: [((Int, Int), a)]  -- ^ stored entries keyed by (row, column)
  } deriving (Show, Eq)

-- | Array format (dense matrix).
--
-- No datatype context is used; functions that need @MValue@ state the
-- constraint themselves (the @DatatypeContexts@ extension is deprecated).
data AM a = AM
  { am'rows :: Int     -- ^ number of rows
  , am'cols :: Int     -- ^ number of columns
  , am'values :: [a]   -- ^ stored values, in the order they appear in the file
  } deriving (Show, Eq)

-- | Kind of values stored in the matrix, as named in the header line.
data MField
  = MInt      -- ^ @integer@
  | MReal     -- ^ @real@
  | MComplex  -- ^ @complex@
  | MPattern  -- ^ @pattern@
  deriving (Show, Eq)

-- | Header keyword written for each field.
fieldname :: MField -> String
fieldname fld = case fld of
  MInt     -> "integer"
  MReal    -> "real"
  MComplex -> "complex"
  MPattern -> "pattern"

-- | Number of text tokens that make up one value of the given field.
numColumns :: MField -> Int
numColumns fld = case fld of
  MPattern -> 0
  MComplex -> 2
  MInt     -> 1
  MReal    -> 1

-- | Values allowed in the Matrix Market files.
--
-- An instance ties a Haskell type to one Matrix Market field and says
-- how a value of that type is read from and written to text.
class (Num a, Show a) => MValue a where
    -- | Field keyword for this type; 'readMM' compares it with the
    -- field named in the header line.
    typename :: a -> String
    -- | Parse one value from the tokens that make it up; 'Nothing' if
    -- the tokens do not form a value.
    readval  :: [String] -> Maybe a
    -- | Render one value as text for 'dumpMM'.
    showval  :: a -> String
    -- | Conjugate of a value; used to mirror entries of Hermitian
    -- matrices.
    conj     :: a -> a

-- | Integer values: one token per value, conjugation is the identity.
instance MValue Int where
  typename = const "integer"
  readval toks = case toks of
    [tok] -> maybeRead tok
    _     -> Nothing
  showval n = show n
  conj n = n

-- | Real values: one token per value, conjugation is the identity.
instance MValue Double where
  typename = const "real"
  readval toks = case toks of
    [tok] -> maybeRead tok
    _     -> Nothing
  showval x = show x
  conj x = x

-- | Complex values: two tokens per value (real part, then imaginary part).
instance MValue (Complex Double) where
  typename = const "complex"
  readval toks = case toks of
    [reTok, imTok] -> do
      rePart <- maybeRead reTok
      imPart <- maybeRead imTok
      return (rePart :+ imPart)
    _ -> Nothing
  showval (rePart :+ imPart) = unwords (map show [rePart, imPart])
  conj z = conjugate z

-- | Read a value from a string, requiring the whole string to be consumed.
--
-- Returns 'Nothing' when the string does not parse, or when anything
-- other than whitespace is left over after the value.  A plain
-- @reads@-and-take-the-first approach would accept @\"12abc\"@ as @12@
-- and @\"1.E+05\"@ as @1.0@, silently corrupting the data.
maybeRead :: (Read a) => String -> Maybe a
maybeRead s =
  -- 'lex' yields ("","") exactly when only whitespace remains.
  case [ x | (x, rest) <- reads s, ("", "") <- lex rest ] of
    (x:_) -> Just x
    []    -> Nothing

-- | Symmetry class named in the header line.
data Symmetry
  = General        -- ^ @general@
  | Symmetric      -- ^ @symmetric@
  | SkewSymmetric  -- ^ @skew-symmetric@
  | Hermitian      -- ^ @hermitian@
  deriving (Show, Eq)

-- | Parsing errors reported by 'readMM'.
data ReadError
  = NotAMatrixMarketFormat
    -- ^ the input does not begin with the @%%MatrixMarket@ banner
  | InvalidHeader String
    -- ^ the header line does not have the expected shape; carries the
    -- header line as read
  | UnknownFormat String
    -- ^ the format keyword is not recognised; carries the lower-cased
    -- keyword
  | UnexpectedField String
    -- ^ the field keyword is unknown or does not match the requested
    -- value type; carries the lower-cased keyword
  | UnknownSymmetry String
    -- ^ the symmetry keyword is not recognised; carries the lower-cased
    -- keyword
  | NoParse
    -- ^ the header was accepted but the data block could not be parsed
  deriving (Show, Eq)

-- | List the stored entries with their (row, column) positions, without
-- adding the entries implied by symmetry.
--
-- Coordinate matrices return their entries as stored.  Array matrices
-- pair the stored values, which are in column-major order, with the
-- positions that the symmetry class says are present in the file:
--
-- * general: every position;
--
-- * symmetric and Hermitian: positions on or below the diagonal;
--
-- * skew-symmetric: positions strictly below the diagonal (the diagonal
--   is zero by definition and is not stored).
toList :: (MValue a) => Matrix a -> [((Int, Int), a)]
toList (MM (CoordinateM cm) _ _ _) = cm'values cm
toList (MM (ArrayM am) _ sym _) = zip positions (am'values am)
  where
    nrs = am'rows am
    ncs = am'cols am
    positions = case sym of
      General       -> [ (r,c) | c <- [1..ncs], r <- [1..nrs] ]
      -- The diagonal is omitted from skew-symmetric files, so start one
      -- row below it; otherwise every value would be shifted.
      SkewSymmetric -> [ (r,c) | c <- [1..ncs], r <- [c+1..nrs] ]
      _             -> [ (r,c) | c <- [1..ncs], r <- [c..nrs] ]


-- | List every entry of the matrix: the stored entries plus, for the
-- non-general symmetry classes, the mirrored off-diagonal entries.
toCompleteList :: (MValue a) => Matrix a -> [((Int, Int), a)]
toCompleteList m =
  case mm'symmetry m of
    General       -> stored
    Symmetric     -> concatMap insertSymmetric stored
    SkewSymmetric -> concatMap insertSkew stored
    Hermitian     -> concatMap insertHermitian stored
  where
    stored = toList m

-- | Build an immutable dense array from the matrix.  Positions with no
-- entry are filled with 0.
toArray :: (IArray arr a, MValue a)
        => Int      -- ^ array starting index, usually 0 or 1
        -> Matrix a -- ^ matrix to convert
        -> arr (Int,Int) a
toArray idx0 m = array rng cells
  where
    -- Matrix positions are 1-based; move them to start at idx0.
    shift k = idx0 + k - 1
    lastRow = mm'rows m - 1 + idx0
    lastCol = mm'cols m - 1 + idx0
    rng     = ((idx0, idx0), (lastRow, lastCol))
    entries = Map.fromList
                [ ((shift i, shift j), v) | ((i, j), v) <- toCompleteList m ]
    cells   = [ (ij, Map.findWithDefault 0 ij entries)
              | i <- [idx0 .. lastRow], j <- [idx0 .. lastCol]
              , let ij = (i, j) ]

-- | Build a mutable dense array from the matrix: the array starts out
-- filled with 0 and each entry is then written in list order.
toArrayM :: (MArray arr a m, MValue a)
         => Int       -- ^ array starting index, usually 0 or 1
         -> Matrix a  -- ^ matrix to convert
         -> m (arr (Int,Int) a)
toArrayM idx0 mat = do
    target <- newArray ((idx0, idx0), (lastRow, lastCol)) 0
    mapM_ (\((i, j), v) -> writeArray target (shift i, shift j) v)
          (toCompleteList mat)
    return target
  where
    -- Matrix positions are 1-based; move them to start at idx0.
    shift k = idx0 + k - 1
    lastRow = mm'rows mat - 1 + idx0
    lastCol = mm'cols mat - 1 + idx0

-- | Expand an entry of a symmetric matrix: an off-diagonal entry is
-- followed by its mirror image carrying the same value.
insertSymmetric :: Eq i => ((i, i), v) -> [((i, i), v)]
insertSymmetric entry@((row, col), val) =
  if row == col
    then [entry]
    else [entry, ((col, row), val)]
-- | Expand an entry of a skew-symmetric matrix: an off-diagonal entry is
-- followed by its mirror image carrying the negated value.
insertSkew :: (Eq i, Num v) => ((i, i), v) -> [((i, i), v)]
insertSkew entry@((row, col), val) =
  if row == col
    then [entry]
    else [entry, ((col, row), negate val)]
-- | Expand an entry of a Hermitian matrix: an off-diagonal entry is
-- followed by its mirror image carrying the conjugated value.
insertHermitian :: (Eq i, MValue v) => ((i, i), v) -> [((i, i), v)]
insertHermitian entry@((row, col), val) =
  if row == col
    then [entry]
    else [entry, ((col, row), conj val)]

-- | Look up the element at a (row, column) position; positions with no
-- entry yield 0.
-- Warning: this searches the complete entry list on every call, so it is
-- slow.  Convert with 'toArray' or 'toArrayM' first for repeated access.
at :: MValue a => Matrix a -> (Int, Int) -> a
m `at` ij =
  case lookup ij (toCompleteList m) of
    Just v  -> v
    Nothing -> 0

-- | Write Matrix Market format.
--
-- The output is the header line, then each comment prefixed with @%@,
-- then the data block: a size line followed by one line per stored
-- value.  Values are rendered with 'showval' in both the coordinate and
-- the array format, so complex values are written as two
-- space-separated numbers rather than in Haskell's @a :+ b@ notation.
dumpMM :: MValue a => Matrix a -> String
dumpMM (MM md fld sy coms) = unlines $ header : (map ('%':) coms) ++ body
  where
    header =
        let fmt = case md of
                    (CoordinateM _) -> "coordinate"
                    (ArrayM _) -> "array"
            sym = case sy of
                      General -> "general"
                      Symmetric -> "symmetric"
                      SkewSymmetric -> "skew-symmetric"
                      Hermitian -> "hermitian"
        in  "%%MatrixMarket matrix " ++ unwords [fmt, fieldname fld, sym]
    body = case md of
      (CoordinateM cm) -> dumpCM cm
      (ArrayM am) -> dumpAM am
    -- Size line "rows cols size", then one "row col value" line per entry.
    dumpCM (CM rows cols size vals) =
        unwords [show rows, show cols, show size] :
        map (\((i,j), v) -> unwords [show i, show j, showval v]) vals
    -- Size line "rows cols", then one value per line.  Uses 'showval'
    -- (not 'show') so the text matches what 'readval' expects.
    dumpAM (AM rows cols vals) =
        unwords [show rows, show cols] : map showval vals

-- | Result of 'readMM': either a 'ReadError' or the parsed 'Matrix'.
-- Use this synonym in a type annotation to choose the value type that
-- 'readMM' should parse, e.g. @'ReadMatrix' Double@.
type ReadMatrix a = Either ReadError (Matrix a)

-- | Parse Matrix Market format.
--
-- The value type of the result selects which field the file must
-- declare: the field keyword in the header has to equal 'typename' for
-- that type, otherwise 'UnexpectedField' is returned.
--
-- Errors are reported in this order of precedence: an input that does
-- not start with the banner gives 'NotAMatrixMarketFormat'; a header
-- with too few words gives 'InvalidHeader'; then 'UnknownFormat',
-- 'UnexpectedField' and 'UnknownSymmetry' for the individual keywords;
-- and 'NoParse' when the keywords are fine but the data block is not.
readMM :: (MValue a) => String -> ReadMatrix a
readMM mtx
  -- The banner check is case-sensitive; the keyword matching further
  -- down is done on a lower-cased copy of the header.
  | mtx `startsWith` "%%MatrixMarket" = readMM' mtx
  | otherwise = Left NotAMatrixMarketFormat
  where
  readMM' s = r
    where
    -- The result is bound to a name so that it can be referred to below
    -- purely to recover its value type.
    r =
      let hdr  = safeHead $ lines s
          -- Leading run of comment lines (the header line is the first
          -- of them) and the remaining data lines.  Because of how
          -- 'startsWith' treats strings shorter than the prefix, empty
          -- lines inside that leading run are grouped with the comments.
          (clins,lins) = span (`startsWith` "%") $ lines s
          coms = map (drop 1) (drop 1 clins)  -- comments
          -- The data block is handled as a flat token stream; line
          -- boundaries are not significant from here on.
          toks = concatMap words lins
      in  case words . map toLower $ hdr of
          ("%%matrixmarket":"matrix":fmt:field:sym:_) ->
            let p = lookup fmt parsers
                -- A value of the result's element type that is never
                -- evaluated: it is only handed to 'typename' to learn
                -- which field keyword the caller's type expects.
                aval = (undefined::Either ReadError (Matrix a) -> a) r
                fi = if field == typename aval
                     then lookup field fields :: Maybe MField
                     else Nothing
                ncols = numColumns <$> fi    :: Maybe Int
                sy = lookup sym symmetries   :: Maybe Symmetry
                -- Run the selected parser on the tokens; 'join' merges
                -- "no parser / no field" with "parser failed".
                d = join $ p <*> ncols <*> (Just toks)
                m = MM <$> d <*> fi <*> sy <*> (Just coms)
            in  case m of
                Just m' -> Right m'
                -- Work out which ingredient was missing to pick the
                -- error; if all keywords were fine the data block
                -- itself failed to parse.
                Nothing -> Left $
                  case (p,fi,sy) of
                    (Nothing,_,_) -> UnknownFormat fmt
                    (_,Nothing,_) -> UnexpectedField field
                    (_,_,Nothing) -> UnknownSymmetry sym
                    _             -> NoParse
          _ -> Left $ InvalidHeader hdr
  --
  -- Lookup tables from (lower-cased) header keywords to their meaning.
  parsers = [ ("coordinate", readCoords)
            , ("array",      readArray)]
  fields =  [ ("real", MReal)
            , ("integer", MInt)
            , ("complex", MComplex)
            , ("pattern", MPattern)]
  symmetries = [ ("general",        General)
               , ("symmetric",      Symmetric)
               , ("skew-symmetric", SkewSymmetric)
               , ("hermitian",      Hermitian)]

-- | Parse a coordinate-format data block from its token stream.
--
-- The first three tokens are the row count, column count and entry
-- count; the rest is split into groups of @2 + n@ tokens (row, column
-- and @n@ value tokens), each read with 'readpt'.  Any token that fails
-- to parse makes the whole result 'Nothing'.
readCoords :: forall a . MValue a => Int -> [String] -> Maybe (MatrixData a)
readCoords n (rowsTok:colsTok:sizeTok:entryToks) = do
  nrows <- maybeRead rowsTok
  ncols <- maybeRead colsTok
  nsize <- maybeRead sizeTok
  entries <- mapM (readpt n) (ngroup (2 + n) entryToks)
               :: Maybe [((Int,Int),a)]
  return (CoordinateM (CM nrows ncols nsize entries))
readCoords _ _ = Nothing

-- | Parse an array-format data block from its token stream.
--
-- The first two tokens are the row and column counts.  The remaining
-- tokens are split into groups of @n@ tokens, one group per value, and
-- each group is read with 'readval'; this is what lets complex values
-- (two tokens each) be parsed.  A width below 1 is treated as 1, so a
-- value is always built from at least one token.  Any token that fails
-- to parse makes the whole result 'Nothing'.
readArray :: forall a . MValue a => Int -> [String] -> Maybe (MatrixData a)
readArray n (n1:n2:toks) =
    let nrows = maybeRead n1 :: Maybe Int
        ncols = maybeRead n2 :: Maybe Int
        width = max 1 n
        vals = mapM readval (ngroup width toks) :: Maybe [a]
        am = AM <$> nrows <*> ncols <*> vals :: Maybe (AM a)
    in  ArrayM <$> am
readArray _ _ = Nothing

-- | Parse one coordinate entry: a row token, a column token and exactly
-- @n@ value tokens.
--
-- Returns 'Nothing' when the group has the wrong number of tokens
-- (including fewer than two, as happens for a truncated final group) or
-- when any token fails to parse.
readpt :: forall a . MValue a => Int -> [String] -> Maybe ((Int,Int),a)
readpt n (si:sj:rest)
    | length rest /= n = Nothing
    | otherwise =
        let i = maybeRead si :: Maybe Int
            j = maybeRead sj :: Maybe Int
            v = readval rest :: Maybe a
            coords = (,) <$> i <*> j :: Maybe (Int,Int)
        in  (,) <$> coords <*> v     :: Maybe ((Int,Int), a)
-- Groups too short to hold even the two indices are a parse failure,
-- not a pattern-match crash.
readpt _ _ = Nothing

-- | Compare two lists element by element over their common length and
-- report whether every compared pair is equal.
--
-- Note that the comparison stops at the shorter list, so a subject
-- shorter than the prefix (including the empty list) still yields
-- 'True' as long as the elements it does have agree.
startsWith :: (Eq a) => [a] -> [a] -> Bool
subject `startsWith` prefix = and (zipWith (==) subject prefix)

-- | Split a list into consecutive groups of @n@ elements; the last group
-- holds whatever is left and may be shorter.
--
-- A group size below 1 yields no groups.  Without this case the
-- function would never terminate on a non-empty list, because dropping
-- zero elements makes no progress.
ngroup :: Int -> [a] -> [[a]]
ngroup n xs
  | n <= 0 || null xs = []
  | otherwise         = chunk : ngroup n rest
  where
    (chunk, rest) = splitAt n xs

-- | First element of a list of lists, or the empty list when there is
-- none.
safeHead :: [[a]] -> [a]
safeHead xss = case xss of
  (firstOne:_) -> firstOne
  []           -> []