{-# LANGUAGE GADTs, OverloadedStrings, DeriveFunctor #-}
module Data.Matrix.MatrixMarket.Internal
(readMatrix, readMatrix', readArray, readArray',
writeMatrix, writeMatrix', writeArray, writeArray',
Matrix(..), Array(..),
Format (Coordinate, Array), Structure (General, Symmetric, Hermitian, Skew),
nnz, dim, numDat,
dimArr, numDatArr,
ImportError(..), ExportError(..)) where
import Control.Applicative hiding ( many )
import Data.Functor (($>))
import Data.Int
import qualified Data.Char as C
import Data.Complex
import qualified Data.Scientific as S
import Data.Attoparsec.ByteString.Char8 hiding (I)
import qualified Data.ByteString.Lazy.Char8 as B
import qualified Data.Attoparsec.Lazy as L
import qualified Data.ByteString.Lazy as LBS
import Control.Monad.Catch
import Control.Exception.Common (ImportError(..), ExportError(..))
-- | Storage layout named in the MatrixMarket header line: @coordinate@
-- (entries given as index/value lines) or @array@ (values listed one after
-- another).
data Format
  = Coordinate
  | Array
  deriving (Eq, Show)
-- | Numeric field named in the MatrixMarket header line.
data Field
  = R -- ^ @real@
  | C -- ^ @complex@
  | I -- ^ @integer@
  | P -- ^ @pattern@
  deriving (Eq, Show)
-- | Symmetry qualifier named in the MatrixMarket header line.
data Structure
  = General   -- ^ @general@
  | Symmetric -- ^ @symmetric@
  | Hermitian -- ^ @hermitian@
  | Skew      -- ^ @skew-symmetric@
  deriving (Eq, Show)
-- | A matrix in coordinate format. Every constructor carries, in order: the
-- (rows, columns) pair from the size line, the entry count declared on the
-- size line, the structure qualifier, and the list of parsed entries.
data Matrix a
  = RMatrix (Int, Int) Int Structure [(Int, Int, a)]
    -- ^ Real-valued entries: (row, column, value).
  | CMatrix (Int, Int) Int Structure [(Int, Int, Complex a)]
    -- ^ Complex-valued entries: (row, column, value).
  | PatternMatrix (Int, Int) Int Structure [(Int32, Int32)]
    -- ^ Index pairs only, without values.
  | IntMatrix (Int, Int) Int Structure [(Int32, Int32, Int)]
    -- ^ Integer-valued entries: (row, column, value).
  deriving (Eq, Show)
-- | A matrix in array format. Both constructors carry the (rows, columns)
-- pair from the size line, the structure qualifier, and the flat list of
-- values in the order they appear in the file.
data Array a
  = RArray (Int, Int) Structure [a]
    -- ^ Real-valued data.
  | CArray (Int, Int) Structure [Complex a]
    -- ^ Complex-valued data.
  deriving (Eq, Show)
-- | Parse one comment line: a @%@, then everything up to the line break,
-- then the line break itself.
comment :: Parser ()
comment = do
  _ <- char '%'
  skipWhile notLineBreak
  endOfLine
  where
    -- True for every character that is neither a line feed nor a carriage return.
    notLineBreak c = c /= '\n' && c /= '\r'
-- | Skip any leading whitespace, then parse a number as a 'S.Scientific'.
floating :: Parser S.Scientific
floating = skipSpace' >> scientific
-- | Skip any leading whitespace, then parse a decimal integer.
--
-- Uses attoparsec's 'decimal', which does not accept a sign.
integral :: Integral a => Parser a
integral = skipSpace' >> decimal
-- | Parse the format keyword of the header: @coordinate@ or @array@.
format :: Parser Format
format =
  ( Coordinate <$ string "coordinate"
      <|> Array <$ string "array"
  )
    <?> "matrix format"
-- | Parse the field keyword of the header: @real@, @complex@, @integer@ or
-- @pattern@.
field :: Parser Field
field =
  ( R <$ string "real"
      <|> C <$ string "complex"
      <|> I <$ string "integer"
      <|> P <$ string "pattern"
  )
    <?> "matrix field"
-- | Parse the structure keyword of the header: @general@, @symmetric@,
-- @hermitian@ or @skew-symmetric@.
structure :: Parser Structure
structure =
  ( General <$ string "general"
      <|> Symmetric <$ string "symmetric"
      <|> Hermitian <$ string "hermitian"
      <|> Skew <$ string "skew-symmetric"
  )
    <?> "matrix structure"
-- | Parse the banner line @%%MatrixMarket matrix \<format\> \<field\>
-- \<structure\>@ including its terminating line break, and return the three
-- keywords.
header :: Parser (Format, Field, Structure)
header = banner <?> "MatrixMarket header"
  where
    banner = do
      _ <- string "%%MatrixMarket matrix"
      fmt <- skipSpace' *> format
      fld <- skipSpace' *> field
      str <- skipSpace' *> structure
      endOfLine
      pure (fmt, fld, str)
-- | Parse the size line of a coordinate file: three integers followed by a
-- line break, returned in file order as (rows, columns, entry count).
extentMatrix :: Parser (Int, Int, Int)
extentMatrix = do
  _ <- skipSpace'
  rows <- integral
  cols <- integral
  entries <- integral
  endOfLine
  pure (rows, cols, entries)
-- | Parse the size line of an array file: two integers followed by a line
-- break, returned in file order as (rows, columns).
extentArray :: Parser (Int, Int)
extentArray = do
  _ <- skipSpace'
  rows <- integral
  cols <- integral
  endOfLine
  pure (rows, cols)
-- | Parse one coordinate entry: two integer indices, then a value parsed by
-- the supplied parser, then either a line break or the end of the input.
line3 :: Integral i => Parser a -> Parser (i, i, a)
line3 f = do
  row <- integral
  col <- integral
  val <- f
  endOfLine <|> L.endOfInput
  pure (row, col, val)
-- | Consume zero or more whitespace characters and return them.
-- Never fails; call sites in this module discard the result.
skipSpace' :: Parser String
skipSpace' = many' whitespace
  where
    whitespace = space
-- | Parse a complete coordinate-format MatrixMarket file: header, optional
-- comment lines, size line, then one or more entry lines.
--
-- Fails with @"matrix is not in Coordinate format"@ when the header names
-- another format. The check runs before the size line is read, so a
-- wrong-format file reports this message instead of an unrelated failure
-- from the size-line parser.
--
-- The entry count on the size line is stored in the result but is not
-- checked against the number of entries actually parsed.
matrix :: Parser (Matrix S.Scientific)
matrix = do
  (f, t, s) <- header
  if f /= Coordinate
    then fail "matrix is not in Coordinate format"
    else do
      (m, n, l) <- skipMany comment *> extentMatrix
      case t of
        R -> RMatrix (m, n) l s <$> many1 (line3 floating)
        C -> CMatrix (m, n) l s <$> many1 (line3 complexValue)
        I -> IntMatrix (m, n) l s <$> many1 (line3 signedIntegral)
        P -> PatternMatrix (m, n) l s <$> many1 ((,) <$> integral <*> integral)
  where
    -- Real part then imaginary part, separated by whitespace.
    complexValue = (:+) <$> floating <*> floating
    -- Integer values may be negative. Plain 'decimal' rejects a sign, which
    -- would make 'many1' stop at the first negative entry and silently drop
    -- the rest of the file. Indices keep using the unsigned 'integral'.
    signedIntegral = skipSpace' *> signed decimal
-- | Parse a complete array-format MatrixMarket file: header, optional
-- comment lines, size line, then one or more values.
--
-- Fails with @"array is not in Array format"@ when the header names another
-- format. The check runs before the size line is read, so a wrong-format
-- file reports this message instead of an unrelated failure from the
-- size-line parser. Integer and pattern fields are rejected.
array :: Parser (Array S.Scientific)
array = do
  (f, t, s) <- header
  if f /= Array
    then fail "array is not in Array format"
    else do
      (m, n) <- skipMany comment *> extentArray
      case t of
        R -> RArray (m, n) s <$> many1 floating
        C -> CArray (m, n) s <$> many1 ((:+) <$> floating <*> floating)
        _ -> fail "integer and pattern cases not relevant for the dense case"
-- | Read the file at the given path and parse it as a coordinate-format
-- matrix via 'readMatrix''.
readMatrix :: FilePath -> IO (Matrix S.Scientific)
readMatrix file = do
  contents <- LBS.readFile file
  readMatrix' contents
-- | Parse a lazy 'LBS.ByteString' as a coordinate-format matrix.
--
-- On parser failure, throws 'FileParseError' tagged @"readMatrix"@ with the
-- parser's message. Input left over after a successful parse is ignored.
readMatrix' :: MonadThrow m => LBS.ByteString -> m (Matrix S.Scientific)
readMatrix' = interpret . L.parse matrix
  where
    interpret (L.Done _ mtx) = return mtx
    interpret (L.Fail _ _ msg) = throwM (FileParseError "readMatrix" msg)
-- | Read the file at the given path and parse it as an array-format matrix
-- via 'readArray''.
readArray :: FilePath -> IO (Array S.Scientific)
readArray file = LBS.readFile file >>= readArray'
-- | Parse a lazy 'LBS.ByteString' as an array-format matrix.
--
-- On parser failure, throws 'FileParseError' tagged @"readArray"@ with the
-- parser's message. Input left over after a successful parse is ignored.
readArray' :: MonadThrow m => LBS.ByteString -> m (Array S.Scientific)
readArray' = interpret . L.parse array
  where
    interpret (L.Done _ arr) = return arr
    interpret (L.Fail _ _ msg) = throwM (FileParseError "readArray" msg)
-- | Header keyword for a 'Format': the constructor name in lower case.
showFormat :: Format -> String
showFormat Coordinate = "coordinate"
showFormat Array = "array"
-- | Header keyword for a 'Field'.
showField :: Field -> String
showField R = "real"
showField C = "complex"
showField I = "integer"
showField P = "pattern"
-- | Header keyword for a 'Structure'.
--
-- 'Skew' is written as @skew-symmetric@, the keyword the 'structure' parser
-- accepts. Lower-casing the constructor name would give @skew@, which this
-- module cannot read back.
showStruct :: Structure -> String
showStruct General = "general"
showStruct Symmetric = "symmetric"
showStruct Hermitian = "hermitian"
showStruct Skew = "skew-symmetric"
-- | Build the banner line (without a trailing line break) from the format,
-- field and structure keywords, separated by single spaces.
headerStr :: Format -> Field -> Structure -> LBS.ByteString
headerStr f t s = B.pack (unwords (banner : keywords))
  where
    banner = "%%MatrixMarket matrix"
    keywords = [showFormat f, showField t, showStruct s]
-- | A single line-feed character as a lazy 'LBS.ByteString'.
nl :: LBS.ByteString
nl = B.singleton '\n'
-- | Render every element with the supplied function, append a line feed to
-- each rendering, and concatenate the results into one lazy
-- 'LBS.ByteString'.
showLines :: (a -> String) -> [a] -> LBS.ByteString
showLines showf = LBS.concat . map renderLine
  where
    -- One element rendered as bytes, terminated by a line feed.
    renderLine = LBS.pack . map (toEnum . C.ord) . (++ "\n") . showf
-- | Serialise a matrix with 'writeMatrix'' and write the bytes to the given
-- path. Any exception thrown by 'writeMatrix'' propagates in 'IO'.
writeMatrix :: Show b => FilePath -> Matrix b -> IO ()
writeMatrix fp mat = writeMatrix' mat >>= LBS.writeFile fp
-- | Serialise a matrix in coordinate format: banner line, size line, then
-- one line per entry.
--
-- Complex values are written as two whitespace-separated numbers (real part,
-- then imaginary part), which is what the 'matrix' parser reads. Using
-- 'show' on a 'Complex' value would emit @re :+ im@, which cannot be parsed
-- back.
--
-- Throws 'FormatExportNotSupported' for 'PatternMatrix'.
writeMatrix' :: (MonadThrow m, Show b) => Matrix b -> m LBS.ByteString
writeMatrix' mat =
  case mat of
    RMatrix d nz s dat ->
      pure $ matrixByteString d nz R s (entry show <$> dat)
    CMatrix d nz s dat ->
      pure $ matrixByteString d nz C s (entry showComplex <$> dat)
    IntMatrix d nz s dat ->
      pure $ matrixByteString d nz I s (entry show <$> dat)
    _ -> throwM (FormatExportNotSupported "writeMatrix" "PatternMatrix not implemented yet")
  where
    -- Assemble the file from the already-rendered entry lines.
    matrixByteString di nz t s entryLines =
      LBS.concat
        [ headerStr Coordinate t s
        , nl
        , headerSzMatrix di nz
        , nl
        , showLines id entryLines
        ]
    -- One entry line: row, column, then the value rendered by 'showVal'.
    entry showVal (i, j, x) = unwords [show i, show j, showVal x]
    -- Real and imaginary parts as separate tokens.
    showComplex (re :+ im) = unwords [show re, show im]
    -- Size line: rows, columns, declared entry count.
    headerSzMatrix (m, n) numz = B.pack $ unwords [show m, show n, show numz]
-- | Serialise an array with 'writeArray'' and write the bytes to the given
-- path.
--
-- Delegates to 'writeArray'' rather than repeating its serialisation logic,
-- so the on-disk and in-memory renderings cannot drift apart. This mirrors
-- how 'writeMatrix' uses 'writeMatrix''.
writeArray :: Show a => FilePath -> Array a -> IO ()
writeArray file arr = LBS.writeFile file (writeArray' arr)
-- | Serialise an array in array format: banner line, size line, then one
-- line per value.
--
-- Complex values are written as two whitespace-separated numbers (real part,
-- then imaginary part), which is what the 'array' parser reads. Using 'show'
-- on a 'Complex' value would emit @re :+ im@, which cannot be parsed back.
writeArray' :: Show a => Array a -> LBS.ByteString
writeArray' arr =
  case arr of
    RArray d s dat -> arrayByteString d R s (show <$> dat)
    CArray d s dat -> arrayByteString d C s (showComplex <$> dat)
  where
    -- Assemble the file from the already-rendered value lines.
    arrayByteString di t s valueLines =
      LBS.concat
        [ headerStr Array t s
        , nl
        , headerSzArray di
        , nl
        , showLines id valueLines
        ]
    -- Real and imaginary parts as separate tokens.
    showComplex (re :+ im) = unwords [show re, show im]
    -- Size line: rows, columns.
    headerSzArray (m, n) = B.pack $ unwords [show m, show n]
-- | The entry count stored in the matrix (its second constructor field).
nnz :: Matrix t -> Int
nnz (RMatrix _ nz _ _) = nz
nnz (CMatrix _ nz _ _) = nz
nnz (PatternMatrix _ nz _ _) = nz
nnz (IntMatrix _ nz _ _) = nz
-- | The dimension pair stored in the matrix (its first constructor field).
dim :: Matrix t -> (Int, Int)
dim (RMatrix d _ _ _) = d
dim (CMatrix d _ _ _) = d
dim (PatternMatrix d _ _ _) = d
dim (IntMatrix d _ _ _) = d
-- | The length of the entry list held by the matrix. This is computed from
-- the list itself, independently of the count returned by 'nnz'.
numDat :: Matrix t -> Int
numDat (RMatrix _ _ _ entries) = length entries
numDat (CMatrix _ _ _ entries) = length entries
numDat (PatternMatrix _ _ _ entries) = length entries
numDat (IntMatrix _ _ _ entries) = length entries
-- | The dimension pair stored in the array (its first constructor field).
dimArr :: Array t -> (Int, Int)
dimArr (RArray d _ _) = d
dimArr (CArray d _ _) = d
-- | The length of the value list held by the array.
numDatArr :: Array a -> Int
numDatArr (RArray _ _ values) = length values
numDatArr (CArray _ _ values) = length values
-- | Convert a 'String' to a lazy 'LBS.ByteString' by mapping each character
-- to the byte with the same code point.
toLBS :: String -> LBS.ByteString
toLBS = LBS.pack . map charToByte
  where
    charToByte = toEnum . C.ord