module Water.SWMM where

import Control.Applicative ((<$>), (<*>))
import Control.Monad       (replicateM)
import Data.Binary.Get     (Get, getWord32le, getByteString)
import Data.Binary.IEEE754 (getFloat32le, getFloat64le)
import Data.List.Split     (chunksOf)
import Pipes.Binary        (decodeGet)
import Pipes.Parse         (Parser, evalStateT)
import System.IO           ( Handle, IOMode (ReadMode), SeekMode (AbsoluteSeek)
                           , hFileSize, hSeek, withBinaryFile )

import qualified Data.ByteString.Char8 as BSC (ByteString)
import qualified Data.ByteString.Lazy  as BL  (ByteString, readFile, drop, take, length, hGet)
import qualified Pipes.ByteString      as PB  (ByteString, fromLazy)

-- | Every parameter section decoded from a SWMM binary output file by
--   'extractSWMMParams'. The per-period result bytes are not part of
--   this record; they are returned alongside it.
data SWMMParams = SWMMParams
  { header        :: Header             -- ^ section decoded by 'getHeader'
  , ids           :: ObjectIds          -- ^ section decoded by 'getObjectIds'
  , properties    :: ObjectProperties   -- ^ section decoded by 'getObjectProperties'
  , variables     :: ReportingVariables -- ^ section decoded by 'getReportingVariables'
  , intervals     :: ReportingInterval  -- ^ section decoded by 'getReportingInterval'
  , closingRecord :: ClosingRecord      -- ^ section decoded by 'getClosingRecord'
  } deriving (Show, Eq)

-- | Opening record of the file: seven 4 byte integers, decoded in
--   field order by 'getHeader'. The object counts size the sections
--   that follow it.
data Header = Header
  { headerIdNumber        :: Int -- ^ first integer of the file (an identifying number)
  , versionNumber         :: Int
  , codeNumber            :: Int -- ^ presumably the flow-units code; confirm against the SWMM format
  , numberOfSubcatchments :: Int -- ^ how many subcatchment ids / property rows follow
  , numberOfNodes         :: Int -- ^ how many node ids / property rows follow
  , numberOfLinks         :: Int -- ^ how many link ids / property rows follow
  , numberOfPollutants    :: Int -- ^ how many pollutant ids and pollutant codes follow
  } deriving (Show, Eq)

-- | Object names and per-pollutant codes, decoded by 'getObjectIds'.
data ObjectIds = ObjectIds
  { subcatchmentIds  :: [BSC.ByteString] -- ^ one name per subcatchment
  , nodeIds          :: [BSC.ByteString] -- ^ one name per node
  , linkIds          :: [BSC.ByteString] -- ^ one name per link
  , pollutantIds     :: [BSC.ByteString] -- ^ one name per pollutant
  , concentrationIds :: [Int]            -- ^ one integer code per pollutant (presumably its concentration units)
  } deriving (Show, Eq)

-- | The three property blocks, decoded in field order by
--   'getObjectProperties'.
data ObjectProperties = ObjectProperties
  { subcatchmentProperties :: Properties -- ^ sized by 'numberOfSubcatchments'
  , nodeProperties         :: Properties -- ^ sized by 'numberOfNodes'
  , linkProperties         :: Properties -- ^ sized by 'numberOfLinks'
  } deriving (Show, Eq)

-- | The reporting-variable lists for each object class, decoded in
--   field order by 'getReportingVariables'.
data ReportingVariables = ReportingVariables
  { subcatchmentVariables :: Variables -- ^ values reported per subcatchment
  , nodeVariables         :: Variables -- ^ values reported per node
  , linkVariables         :: Variables -- ^ values reported per link
  , systemVariables       :: Variables -- ^ values reported once per period
  } deriving (Show, Eq)

-- | Reporting interval record, decoded by 'getReportingInterval'.
data ReportingInterval = ReportingInterval
  { startDateTime :: Double -- ^ 8 byte float; presumably the reporting start date/time
  , timeIntervals :: Int    -- ^ 4 byte integer; presumably the reporting time step
  } deriving (Show, Eq)

-- | Result values for one reporting period, decoded by 'getResults'.
--   Equality and ordering are defined by hand on 'dateTimeValue' alone.
data ValuesForOneDateTime = ValuesForOneDateTime
  { dateTimeValue     :: Double    -- ^ timestamp of the period (8 byte float)
  , subcatchmentValue :: [[Float]] -- ^ one inner list per subcatchment, one value per subcatchment variable
  , nodeValue         :: [[Float]] -- ^ one inner list per node, one value per node variable
  , linkValue         :: [[Float]] -- ^ one inner list per link, one value per link variable
  , systemValue       :: [Float]   -- ^ one value per system variable
  } deriving (Show)

-- | Trailing record of the file: six 4 byte integers, decoded in field
--   order by 'getClosingRecord'.
data ClosingRecord = ClosingRecord
  { idBytePosition         :: Int -- ^ presumably the byte offset of the object id section
  , propertiesBytePosition :: Int -- ^ presumably the byte offset of the property section
  , resultBytePosition     :: Int -- ^ byte offset at which the result data begins
  , numberOfPeriods        :: Int -- ^ presumably the number of reporting periods stored
  , errorCode              :: Int
  , closingIdNumber        :: Int -- ^ presumably expected to equal 'headerIdNumber'
  } deriving (Show, Eq)

-- | One property block, decoded by 'getProperties'.
data Properties = Properties
  { numberOfProperties   :: Int     -- ^ how many properties each object carries
  , codeNumberProperties :: [Int]   -- ^ one code per property
  , valueProperties      :: [Float] -- ^ flat list of (properties * objects) values
  } deriving (Show, Eq)

-- | One reporting-variable list, decoded by 'getVariables'.
data Variables = Variables
  { numberOfVariables   :: Int   -- ^ how many variables are reported
  , codeNumberVariables :: [Int] -- ^ one code per variable
  } deriving (Show, Eq)

-- | Results are ordered chronologically by 'dateTimeValue' only; the
--   per-object values play no part in the comparison.
instance Ord ValuesForOneDateTime where
    compare a b = compare (dateTimeValue a) (dateTimeValue b)

-- | Two results are considered equal when their 'dateTimeValue's are
--   equal; the per-object values are not compared.
instance Eq ValuesForOneDateTime where
    ValuesForOneDateTime { dateTimeValue = t1 }
      == ValuesForOneDateTime { dateTimeValue = t2 } = t1 == t2


-- | Decode a 4 byte integer: an unsigned little-endian 32 bit word,
--   converted to 'Int' with 'fromIntegral'.
getInt :: Get Int
getInt = do
  word <- getWord32le
  return (fromIntegral word)

-- | Decode a length-prefixed string: a 4 byte length read with
--   'getInt', followed by exactly that many raw bytes.
getString :: Get BSC.ByteString
getString = do
  len <- getInt
  getByteString len

-- | Decode the opening header: seven consecutive 4 byte integers,
--   assigned to the 'Header' fields in declaration order.
getHeader :: Get Header
getHeader = do
  idNumber      <- getInt
  version       <- getInt
  code          <- getInt
  subcatchments <- getInt
  nodes         <- getInt
  links         <- getInt
  pollutants    <- getInt
  return Header { headerIdNumber        = idNumber
                , versionNumber         = version
                , codeNumber            = code
                , numberOfSubcatchments = subcatchments
                , numberOfNodes         = nodes
                , numberOfLinks         = links
                , numberOfPollutants    = pollutants
                }

-- | Decode the object id section. The header counts determine how many
--   strings are read for each object class. The section ends with one
--   integer code per pollutant.
getObjectIds :: Header -> Get ObjectIds
getObjectIds hdr =
  ObjectIds <$> names numberOfSubcatchments
            <*> names numberOfNodes
            <*> names numberOfLinks
            <*> names numberOfPollutants
            <*> replicateM (numberOfPollutants hdr) getInt
  where
    -- Read as many length-prefixed strings as the selected count says.
    names count = replicateM (count hdr) getString

-- | Decode one property block for @objectCount@ objects. The block
--   holds a count @k@ of properties, then @k@ integer codes, then
--   @k * objectCount@ 32 bit floats.
--
--   NOTE(review): every value is decoded as a float. Confirm against
--   the SWMM output format that none of these values (e.g. object type
--   codes) are actually stored as integers.
getProperties :: Int -> Get Properties
getProperties objectCount = do
  count <- getInt
  Properties count <$> replicateM count getInt
                   <*> replicateM (count * objectCount) getFloat32le

-- | Decode the subcatchment, node and link property blocks, in that
--   order. Each block is sized by the matching object count in the
--   header.
getObjectProperties :: Header -> Get ObjectProperties
getObjectProperties hdr = do
  subcatchments <- getProperties (numberOfSubcatchments hdr)
  nodes         <- getProperties (numberOfNodes hdr)
  links         <- getProperties (numberOfLinks hdr)
  return (ObjectProperties subcatchments nodes links)

-- | Decode one reporting-variable list: a 4 byte count followed by
--   that many 4 byte variable codes.
getVariables :: Get Variables
getVariables =
  getInt >>= \count -> Variables count <$> replicateM count getInt

-- | Decode the four reporting-variable lists in order: subcatchment,
--   node, link, then system.
getReportingVariables :: Get ReportingVariables
getReportingVariables = do
  subcatchmentVars <- getVariables
  nodeVars         <- getVariables
  linkVars         <- getVariables
  systemVars       <- getVariables
  return (ReportingVariables subcatchmentVars nodeVars linkVars systemVars)

-- | Decode the reporting interval record: an 8 byte little-endian
--   float followed by a 4 byte integer.
getReportingInterval :: Get ReportingInterval
getReportingInterval = do
  start <- getFloat64le
  step  <- getInt
  return ReportingInterval { startDateTime = start, timeIntervals = step }

-- | Decode every section that precedes the result data. The header is
--   read first because its object counts size the id and property
--   sections that follow it.
getSWMMTopParams :: Get ( Header, ObjectIds, ObjectProperties
                        , ReportingVariables, ReportingInterval )
getSWMMTopParams = do
  hdr <- getHeader
  (,,,,) hdr <$> getObjectIds hdr
             <*> getObjectProperties hdr
             <*> getReportingVariables
             <*> getReportingInterval

-- | Decode the closing record: six consecutive 4 byte integers,
--   assigned to the 'ClosingRecord' fields in declaration order.
getClosingRecord :: Get ClosingRecord
getClosingRecord = do
  idPos     <- getInt
  propPos   <- getInt
  resultPos <- getInt
  periods   <- getInt
  err       <- getInt
  closingId <- getInt
  return ClosingRecord { idBytePosition         = idPos
                       , propertiesBytePosition = propPos
                       , resultBytePosition     = resultPos
                       , numberOfPeriods        = periods
                       , errorCode              = err
                       , closingIdNumber        = closingId
                       }

-- | Extract the SWMM parameters from the file, and return
--   the remainder of the binary file (containing only the
--   result data now) as a lazy bytestring.
--
--   The top parameters are decoded from the start of one lazy read of
--   the file. The result bytes are sliced out of that same read with a
--   drop followed by a take of a precomputed length. They are therefore
--   read on demand and can be consumed incrementally; forcing all of
--   them at once holds them all in memory.
--
--   The closing record is read through a separate handle. That handle
--   seeks straight to the last 'closingRecordSize' bytes of the file,
--   so locating it does not read the data in between.
--
--   A section that fails to decode is reported with 'fail' in 'IO'.
--   The message names the section, the file and the decoder's error.
--
--   NOTE(review): 'errorCode' and 'closingIdNumber' of the closing
--   record are returned as-is and not validated here.
extractSWMMParams :: FilePath -> IO (SWMMParams, BL.ByteString)
extractSWMMParams f = do
  -- Grab the top SWMM parameters from the start of a lazy read.
  contents <- BL.readFile f
  (h, oi, op, rv, ri) <- evalStateT decodeTopSWMMParameters (PB.fromLazy contents)

  -- Grab the closing record by seeking directly to the end of the file.
  (closingStart, closingBytes) <- withBinaryFile f ReadMode readClosingBytes
  c <- evalStateT decodeClosingRecord (PB.fromLazy closingBytes)

  -- The result data runs from the result byte position up to the
  -- closing record. Drop-then-take with a known length avoids forcing
  -- the length of the whole file.
  let resultStart = fromIntegral (resultBytePosition c)
      resultSize  = fromInteger closingStart - resultStart
      results     = BL.take resultSize (BL.drop resultStart contents)

  return (SWMMParams h oi op rv ri c, results)
    where
      -- | Decode all parameters aligned before the data, failing with a
      --   descriptive message instead of a pattern-match failure.
      decodeTopSWMMParameters
        :: Parser PB.ByteString IO ( Header, ObjectIds, ObjectProperties
                                   , ReportingVariables, ReportingInterval )
      decodeTopSWMMParameters =
        decodeGet getSWMMTopParams
          >>= either (fail . decodeError "top parameters") return

      -- | Decode the closing record, failing with a descriptive message.
      decodeClosingRecord :: Parser PB.ByteString IO ClosingRecord
      decodeClosingRecord =
        decodeGet getClosingRecord
          >>= either (fail . decodeError "closing record") return

      -- | Build the failure message for a section that did not decode.
      decodeError :: Show e => String -> e -> String
      decodeError section err =
        "extractSWMMParams: cannot decode the " ++ section
          ++ " of " ++ show f ++ ": " ++ show err

      -- | The number of bytes spanned by the closing record
      --   (six 4 byte integers).
      closingRecordSize :: Num a => a
      closingRecordSize = 4 * 6

      -- | Seek to the closing record and read exactly its bytes.
      --   Returns the byte offset at which the record starts, together
      --   with the record's bytes.
      readClosingBytes :: Handle -> IO (Integer, BL.ByteString)
      readClosingBytes hdl = do
        size <- hFileSize hdl
        let start = size - closingRecordSize
        hSeek hdl AbsoluteSeek start
        bytes <- BL.hGet hdl closingRecordSize
        return (start, bytes)

-- | Read @nv@ 32 bit floats and group them into consecutive lists of
--   @n@ values each (the final group may be shorter).
getSplitValues :: Int -> Int -> Get [[Float]]
getSplitValues n nv = do
  flat <- replicateM nv getFloat32le
  return (chunksOf n flat)

-- | Read the given number of 32 bit floats, in order.
getValues :: Int -> Get [Float]
getValues = flip replicateM getFloat32le

-- | Decode the values for one reporting period, in this order:
--
--   * an 8 byte timestamp;
--   * for every subcatchment, node and link, one 32 bit float per
--     reporting variable of that object class, grouped into one list
--     per object;
--   * one float per system variable.
getResults :: Header -> ReportingVariables -> Get ValuesForOneDateTime
getResults hdr vars =
  ValuesForOneDateTime <$> getFloat64le
                       <*> perObject numberOfSubcatchments subcatchmentVariables
                       <*> perObject numberOfNodes         nodeVariables
                       <*> perObject numberOfLinks         linkVariables
                       <*> getValues (count systemVariables)
  where
    -- Number of variables reported for the selected object class.
    count section = numberOfVariables (section vars)

    -- One list of 'count section' values for each of the objects.
    perObject objects section =
      getSplitValues (count section) (objects hdr * count section)