{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveGeneric         #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators #-}

module Data.Matrix.Static.IO
    ( fromMM
    , withMM
    ) where

import qualified Data.ByteString.Char8 as B
import Conduit
import Control.Monad (when)
import qualified Data.Vector.Generic as G
import           Data.ByteString.Lex.Fractional (readExponential, readSigned)
import           Data.ByteString.Lex.Integral   (readDecimal, readDecimal_)
import Data.Singletons
import Data.Maybe
import Data.Singletons.TypeLits

import qualified Data.Matrix.Static.Sparse as S

-- | Value field of the @%%MatrixMarket@ banner line (its fourth word), as
-- recognised by 'parseHeader'.
data MMElem
    = MMReal     -- ^ banner value field @real@
    | MMComplex  -- ^ banner value field @complex@
    | MMInteger  -- ^ banner value field @integer@
    | MMPattern  -- ^ banner value field @pattern@
    deriving Eq

-- | Element types that can be read from a MatrixMarket stream.
class IOElement a where
    -- | Banner value field a file must declare to be read as this type.
    -- The 'Proxy' argument only selects the instance and carries no data.
    elemType :: Proxy a -> MMElem

    -- | Decode one token into a value.  Callers in this module pass the
    -- individual whitespace-separated words of a line.
    decodeElem :: B.ByteString -> a

instance IOElement Int where
    -- Parse an optionally signed decimal integer.  Anything left in the token
    -- after the number is ignored; an unparseable token aborts with 'error'.
    decodeElem token = case readSigned readDecimal token of
        Just (value, _rest) -> value
        Nothing -> error $ "readInt: Fail to cast ByteString to Int:" ++ show token
    elemType _ = MMInteger

instance IOElement Double where
    -- Parse an optionally signed decimal with an optional exponent.  Anything
    -- left in the token after the number is ignored; an unparseable token
    -- aborts with 'error'.
    decodeElem token = maybe failure fst (readSigned readExponential token)
      where
        failure = error $ "readDouble: Fail to cast ByteString to Double:" ++ show token
    elemType _ = MMReal

-- | Parse a MatrixMarket stream and hand the resulting matrix to a
-- continuation that is polymorphic in the (run-time) dimensions.
--
-- The stream is run twice: once by 'parseHeader' to read the banner and the
-- size line, and once more to decode the entries.  Aborts with 'error' if the
-- banner's element type differs from @'elemType' (Proxy :: Proxy a)@, or if
-- the length of the matrix's first vector field (presumably its stored
-- values; confirm against "Data.Matrix.Static.Sparse") differs from the
-- size line's @nnz@.
withMM :: forall m v a b. (Monad m, G.Vector v a, IOElement a)
       => ConduitT () B.ByteString m ()
       -> (forall r c. S.SparseMatrix r c v a -> m b)
       -> m b
withMM conduit f = do
    (ty, (r,c,nnz)) <- parseHeader conduit
    -- The checks are written with 'if' rather than 'when': a discarded
    -- @when cond (error ...)@ action is never forced in lazy monads such as
    -- 'Identity', so the checks would be silently skipped there.
    if elemType (Proxy :: Proxy a) /= ty
        then error "Element types do not match"
        else withSomeSing (fromIntegral (r :: Int)) $ \(SNat :: Sing r) ->
            withSomeSing (fromIntegral (c :: Int)) $ \(SNat :: Sing c) -> do
                mat@(S.SparseMatrix v _ _) <- S.fromTripletC triplets
                let n = G.length v
                if n /= nnz
                    then error $ "number of non-zeros do not match: "
                        <> show nnz <> "/=" <> show n
                    else f (mat :: S.SparseMatrix r c v a)
  where
    -- Comment lines (including the banner) and blank lines are dropped.
    -- The first remaining line is the size line, which is skipped before
    -- the entries are decoded.
    triplets = conduit .| linesUnboundedAsciiC .| filterC isDataLine .|
        (dropC 1 >> streamTriplet)
    -- True when the line's first token does not start with '%'.  A blank or
    -- whitespace-only line has no token and is rejected, instead of
    -- crashing 'B.head' on an empty line.
    isDataLine line = case B.words line of
        []         -> False
        (word : _) -> B.head word /= '%'

-- | Parse a MatrixMarket stream into a matrix whose dimensions are fixed by
-- the caller's type.
--
-- Aborts with 'error' in three cases:
--
-- * the banner's element type differs from @'elemType' (Proxy :: Proxy a)@;
-- * the size line's row/column counts differ from the type-level @r@ and @c@;
-- * the length of the matrix's first vector field (presumably its stored
--   values; confirm against "Data.Matrix.Static.Sparse") differs from the
--   size line's @nnz@.
--
-- The type and dimension checks run before any entry is decoded.
fromMM :: forall m r c v a. (Monad m, SingI r, SingI c, G.Vector v a, IOElement a)
       => ConduitT () B.ByteString m () -> m (S.SparseMatrix r c v a)
fromMM conduit = do
    (ty, (r,c,nnz)) <- parseHeader conduit
    mat@(S.SparseMatrix v _ _) <- case () of
        _ | elemType (Proxy :: Proxy a) /= ty -> error "Element types do not match"
          | (r, c) /= (nrow, ncol) -> error $ "Dimensions do not match: " <>
                show (r,c) <> "/=" <> show (nrow,ncol)
          | otherwise -> S.fromTripletC triplets
    let n = G.length v
    if n /= nnz
        then error $ "number of non-zeros do not match: " <> show nnz <> "/=" <> show n
        else return mat
  where
    -- Comment lines (including the banner) and blank lines are dropped.
    -- The first remaining line is the size line, which is skipped before
    -- the entries are decoded.
    triplets = conduit .| linesUnboundedAsciiC .| filterC isDataLine .|
        (dropC 1 >> streamTriplet)
    -- True when the line's first token does not start with '%'.  A blank or
    -- whitespace-only line has no token and is rejected, instead of
    -- crashing 'B.head' on an empty line.
    isDataLine line = case B.words line of
        []         -> False
        (word : _) -> B.head word /= '%'
    -- Expected dimensions, taken from the type-level r and c.
    nrow = fromIntegral $ fromSing (sing :: Sing r) :: Int
    ncol = fromIntegral $ fromSing (sing :: Sing c) :: Int

-- | Read the @%%MatrixMarket@ banner and the size line from a stream.
--
-- Returns the element type declared in the banner together with the
-- @(rows, columns, nnz)@ triple from the size line.  The banner must be the
-- very first line.  Only its value field (@real@, @complex@, @integer@ or
-- @pattern@) is interpreted; the object, format and symmetry fields are not
-- checked.  Comment lines (first token starting with @%@) and blank lines
-- before the size line are skipped.  Malformed input aborts with 'error'.
--
-- NOTE(review): because the symmetry field is ignored, a file declared e.g.
-- @symmetric@ (which stores only one triangle) is not expanded here.
parseHeader :: Monad m => ConduitT () B.ByteString m () -> m (MMElem, (Int, Int, Int))
parseHeader conduit = runConduit $
    conduit .| linesUnboundedAsciiC .| headerParser
  where
    parse x
        | "%%MatrixMarket" `B.isPrefixOf` x = case B.words x of
            [_banner, _object, _format, ty, _symmetry] -> case ty of
                "real" -> MMReal
                "complex" -> MMComplex
                "integer" -> MMInteger
                "pattern" -> MMPattern
                t -> error $ "Unknown type: " <> show t
            _ -> error $ "Cannot parse header: " <> show x
        | otherwise = error $ "Cannot parse header: " <> show x
    headerParser = do
        ty <- headC >>= \case
            Nothing -> error "Empty file"
            Just header -> return $ parse header
        line <- filterC isDataLine .| headC
        case line of
            Nothing -> error "Empty file"
            -- A size line that is not exactly three fields (e.g. the
            -- two-field size line of an "array" file) gets a descriptive
            -- error instead of an irrefutable-pattern failure.
            Just x -> case map decodeElem (B.words x) of
                [r, c, nnz] -> return (ty, (r, c, nnz))
                _ -> error $ "Cannot parse size line: " <> show x
    -- True when the line's first token does not start with '%'.  A blank or
    -- whitespace-only line has no token and is rejected, instead of
    -- crashing 'B.head' on an empty line.
    isDataLine l = case B.words l of
        []        -> False
        (tok : _) -> B.head tok /= '%'
{-# INLINE parseHeader #-}

-- | Decode each data line into a @(row, column, value)@ triplet.
--
-- Indices are passed through exactly as written in the file; no change of
-- index base is done here.  A line without exactly three whitespace-separated
-- fields, or whose row/column field is not an unsigned decimal integer,
-- aborts with 'error'.  The value is decoded with 'decodeElem'.
streamTriplet :: (Monad m, IOElement a) => ConduitT B.ByteString (Int, Int, a) m ()
streamTriplet = mapC (f . B.words)
  where
    f [i,j,x] = (index i, index j, decodeElem x)
    f x = error $ "Formatting error: " <> show x
    -- Checked parse: 'readDecimal_' would silently turn a malformed index
    -- into 0 (or keep only a numeric prefix), corrupting the matrix.
    index w = case readDecimal w of
        Just (k, rest) | B.null rest -> k
        _ -> error $ "Formatting error: invalid index " <> show w
{-# INLINE streamTriplet #-}