module RL_Glue.TaskSpec (
  -- * Task Spec data types
  TaskSpec(TaskSpec), ProblemType(Episodic, Continuing, OtherProblemType),
  DiscountFactor, AbsDataType(AbsDataType), IntsBounds, DoublesBounds,
  NumChars, DataBounds, RewardBounds,
  -- ** Bounds
  LowBound(LowBound, NegInf, LBUnspec), UpBound(UpBound, PosInf, UBUnspec),
  -- * Parsing functions
  toTaskSpec, toTaskSpecOrDie, parseTaskSpec
  ) where

import Control.Monad
import qualified Data.ByteString as BS
import Text.Parsec
import Text.Parsec.ByteString
import System.Exit

-- Datatype definitions
-- | A parsed task spec.  The constructor arguments are, in order: the
-- problem type, the discount factor, the observation description, the
-- action description, the reward bounds and the text of the extra section.
data TaskSpec
  = TaskSpec
      ProblemType
      DiscountFactor
      AbsDataType
      AbsDataType
      RewardBounds
      String
  deriving Show

-- | The problem type named in a task spec.
data ProblemType
  = Episodic                 -- ^ Written @episodic@ in the spec.
  | Continuing               -- ^ Written @continuing@ in the spec.
  | OtherProblemType String  -- ^ Any other name, kept as written.
  deriving Show

-- | The discount factor stored in a 'TaskSpec'.
type DiscountFactor = Double

-- | Description of an observation or an action: the bounds of its integer
-- components, the bounds of its double components, and its character count.
data AbsDataType
  = AbsDataType IntsBounds DoublesBounds NumChars
  deriving Show
-- | Bounds read from an @INTS@ section; empty when the section is absent.
type IntsBounds = [DataBounds Int]
-- | Bounds read from a @DOUBLES@ section; empty when the section is absent.
type DoublesBounds = [DataBounds Double]
-- | The number given by a @CHARCOUNT@ field; 0 when the field is absent.
type NumChars = Int
-- | A lower bound paired with an upper bound.
type DataBounds a = (LowBound a, UpBound a)
-- | The lower end of a range.
data LowBound a
  = LowBound a  -- ^ An explicit value.
  | NegInf      -- ^ Written @NEGINF@ in the spec.
  | LBUnspec    -- ^ Written @UNSPEC@ in the spec.
  deriving Show
-- | The upper end of a range.
data UpBound a
  = UpBound a  -- ^ An explicit value.
  | PosInf     -- ^ Written @POSINF@ in the spec.
  | UBUnspec   -- ^ Written @UNSPEC@ in the spec.
  deriving Show

-- | The lower and upper bound read from the @REWARDS@ section.
type RewardBounds = DataBounds Double

-- Parsing functions
-- | Run 'parseTaskSpec' over a raw task-spec string.  A failure is returned
-- as a 'ParseError' whose source name is @(network)@.
toTaskSpec :: BS.ByteString -> Either ParseError TaskSpec
toTaskSpec input = runParser parseTaskSpec () "(network)" input

-- | Like 'toTaskSpec', but on a parse error the error is printed with
-- 'print' and the program terminates via 'exitFailure'.
toTaskSpecOrDie :: BS.ByteString -> IO TaskSpec
toTaskSpecOrDie str =
  case toTaskSpec str of
    Left parseErr -> print parseErr >> exitFailure
    Right spec -> return spec

-- | Parser for a complete task spec.
--
-- The sections are read in a fixed order -- version, problem type, discount
-- factor, observations, actions, rewards, extra -- with optional whitespace
-- between consecutive sections.
--
-- The signature is spelled out because this parser is exported: without it
-- the type is fixed only by the monomorphism restriction and the use sites
-- inside this module.
parseTaskSpec :: Parsec BS.ByteString () TaskSpec
parseTaskSpec = do
  -- The matched version string is not stored in the result.
  _ <- parseVersion
  spaces
  probType <- parseProblemType
  spaces
  discountFactor <- parseDiscountFactor
  spaces
  obs <- parseObservations
  spaces
  act <- parseActions
  spaces
  reward <- parseRewards
  spaces
  extra <- parseExtra
  return $ TaskSpec probType discountFactor obs act reward extra

-- | Parse the @VERSION@ header and return the version string that matched.
-- Only @RL-Glue-3.0@ is accepted; anything else is a parse error.
parseVersion :: Parsec BS.ByteString () String
parseVersion = do
  _ <- string "VERSION"
  spaces
  -- We currently only parse version 3.0
  string "RL-Glue-3.0"

-- Parse the PROBLEMTYPE section: the keyword, optional whitespace, then a
-- (possibly empty) run of letters and digits that names the problem type.
parseProblemType =
  string "PROBLEMTYPE" >> spaces >> liftM classify (many (letter <|> digit))
  where
    -- Map the two known names to their constructors; keep any other name.
    classify name
      | name == "episodic" = Episodic
      | name == "continuing" = Continuing
      | otherwise = OtherProblemType name

-- | Parse the @DISCOUNTFACTOR@ section and return its value.
--
-- The number is a run of digits and dots that must read in full as a
-- 'Double' (for example @1@ or @0.9@).  A run that does not -- such as
-- @1.2.3@, @.9@ or @1.@ -- is reported as a parse error.  The result type is
-- 'Double', which is what the @DiscountFactor@ alias stands for.
parseDiscountFactor :: Parsec BS.ByteString () Double
parseDiscountFactor = do
  _ <- string "DISCOUNTFACTOR"
  spaces
  raw <- many1 (digit <|> char '.')
  -- 'reads' instead of 'read': a malformed number fails here, inside the
  -- parser, instead of leaving a value in the result that throws when it is
  -- first evaluated.
  case reads raw of
    [(value, [])] -> return value
    _ -> parserFail ("invalid discount factor: " ++ show raw)

-- Parse the OBSERVATIONS section: the keyword, optional whitespace, then the
-- data description handled by parseAbsDataType.
parseObservations = string "OBSERVATIONS" >> spaces >> parseAbsDataType

-- Parse the ACTIONS section: the keyword, optional whitespace, then the data
-- description handled by parseAbsDataType.
parseActions = string "ACTIONS" >> spaces >> parseAbsDataType

-- Parse the REWARDS section: the keyword, optional whitespace, then one
-- parenthesised "lower upper" pair.  Whitespace is allowed between the two
-- bounds but not directly inside the parentheses.
parseRewards =
  string "REWARDS" >> spaces >> between (char '(') (char ')') range
  where
    range = liftM2 (,) parseLB (spaces >> parseUB)

-- | Parse the @EXTRA@ section: the keyword, optional whitespace, then every
-- remaining character of the input, returned as a 'String'.
parseExtra :: Parsec BS.ByteString () String
parseExtra = do
  _ <- string "EXTRA"
  spaces
  many anyChar

-- | Parse the body of an observation or action description: the integer
-- bounds, then the double bounds, then the character count, in that order.
parseAbsDataType :: Parsec BS.ByteString () AbsDataType
parseAbsDataType =
  liftM3 AbsDataType parseIntsBounds parseDoublesBounds parseNumChars

-- | Parse either @N item@, yielding @N@ copies of the item, or a single
-- @item@, yielding a one-element list.
--
-- The repeated form is attempted first and abandoned without consuming
-- input if it does not fit, so an item that itself starts with digits still
-- parses as a single item.
parseRepeatable :: Parsec BS.ByteString () a -> Parsec BS.ByteString () [a]
parseRepeatable parser = try repeated <|> liftM (: []) parser
  where
    repeated = do
      times <- liftM read (many1 digit)
      -- Digits directly followed by '.' are the integer part of a decimal
      -- item, not a repeat count.  Without this check "0.0 1.0" is read as
      -- zero copies of (".0", "1.0") and "1.5 2.5" as one copy of
      -- (".5", "2.5").
      notFollowedBy (char '.')
      spaces
      item <- parser
      return (replicate times item)

-- | Parse an optional @INTS@ section.  An absent section yields @[]@.
--
-- Only the keyword is wrapped in 'try'.  Once @INTS@ has matched, an error
-- inside the bounds list is reported where it occurs, instead of being
-- discarded and resurfacing later as an unexpected @INTS@.
parseIntsBounds :: Parsec BS.ByteString () IntsBounds
parseIntsBounds = option [] (try (string "INTS") >> parseInnerBounds)

-- | Parse an optional @DOUBLES@ section.  An absent section yields @[]@.
--
-- Only the keyword is wrapped in 'try'.  Once @DOUBLES@ has matched, an
-- error inside the bounds list is reported where it occurs, instead of being
-- discarded and resurfacing later as an unexpected @DOUBLES@.
parseDoublesBounds :: Parsec BS.ByteString () DoublesBounds
parseDoublesBounds = option [] (try (string "DOUBLES") >> parseInnerBounds)

-- | Parse zero or more parenthesised bounds groups, each followed by
-- optional whitespace, and flatten their contents into one list.  Leading
-- whitespace is skipped first.
parseInnerBounds :: Read a => Parsec BS.ByteString () [DataBounds a]
parseInnerBounds =
  liftM concat (spaces >> endBy parseAbsDataTypeTuple spaces)

-- | Parse one parenthesised group: @(@, a possibly repeated bounds pair,
-- then @)@.
parseAbsDataTypeTuple :: Read a => Parsec BS.ByteString () [DataBounds a]
parseAbsDataTypeTuple =
  between (char '(') (char ')') (parseRepeatable parseBoundsTuple)

-- | Parse a lower bound, optional whitespace, then an upper bound, and pair
-- them.
parseBoundsTuple :: Read a => Parsec BS.ByteString () (DataBounds a)
parseBoundsTuple = liftM2 (,) parseLB (spaces >> parseUB)

-- | Parse a lower bound: a number, @NEGINF@ or @UNSPEC@.
--
-- A number is a run of @-@, @.@ and digits that must read in full at the
-- requested type.  A run that does not -- for example @1.5@ for an 'Int'
-- bound, @.5@, @1.@ or a lone @-@ -- is reported as a parse error.
parseLB :: Read a => Parsec BS.ByteString () (LowBound a)
parseLB =
  numeric
    <|> (string "NEGINF" >> return NegInf)
    <|> (string "UNSPEC" >> return LBUnspec)
  where
    numeric = do
      raw <- many1 (char '-' <|> char '.' <|> digit)
      -- 'reads' instead of 'read': a malformed number fails here, inside
      -- the parser, instead of leaving a value in the result that throws
      -- when it is first evaluated.
      case reads raw of
        [(value, [])] -> return (LowBound value)
        _ -> parserFail ("invalid lower bound: " ++ show raw)

-- | Parse an upper bound: a number, @POSINF@ or @UNSPEC@.
--
-- A number is a run of @-@, @.@ and digits that must read in full at the
-- requested type.  A run that does not -- for example @1.5@ for an 'Int'
-- bound, @.5@, @1.@ or a lone @-@ -- is reported as a parse error.
parseUB :: Read a => Parsec BS.ByteString () (UpBound a)
parseUB =
  numeric
    <|> (string "POSINF" >> return PosInf)
    <|> (string "UNSPEC" >> return UBUnspec)
  where
    numeric = do
      raw <- many1 (char '-' <|> char '.' <|> digit)
      -- 'reads' instead of 'read': a malformed number fails here, inside
      -- the parser, instead of leaving a value in the result that throws
      -- when it is first evaluated.
      case reads raw of
        [(value, [])] -> return (UpBound value)
        _ -> parserFail ("invalid upper bound: " ++ show raw)

-- | Parse an optional @CHARCOUNT@ field.  An absent field yields 0.
--
-- Only the keyword is wrapped in 'try'.  Once @CHARCOUNT@ has matched, a
-- missing count is reported as a parse error at that position instead of
-- being treated as if the field were absent.
parseNumChars :: Parsec BS.ByteString () NumChars
parseNumChars = option 0 $ do
  _ <- try (string "CHARCOUNT")
  spaces
  liftM read (many1 digit)