{-# LANGUAGE ScopedTypeVariables #-}

module RL_Glue.Network (
  -- * Magic numbers
  -- ** Connection types
  kExperimentConnection, kAgentConnection, kEnvironmentConnection,
  -- ** Agent
  kAgentInit, kAgentStart, kAgentStep, kAgentEnd, kAgentCleanup, kAgentMessage,
  -- ** Environment
  kEnvInit, kEnvStart, kEnvStep, kEnvCleanup, kEnvMessage,
  -- ** Experiment
  kRLInit, kRLStart, kRLStep, kRLCleanup, kRLReturn, kRLNumSteps,
  kRLNumEpisodes, kRLEpisode, kRLAgentMessage, kRLEnvMessage,
  -- ** Other
  kRLTerm,
  -- * Other constants
  kLocalHost, kDefaultPort, kRetryTimeout,
  kDefaultBufferSize, kIntSize, kDoubleSize, kCharSize,
  -- * Data types
  RLAbstractType(RLAbstractType), Action(Action), Observation(Observation),
  Reward, Terminal,
  -- * sizeOf functions
  sizeOfObs, sizeOfAction, sizeOfRewardObs,
  -- * Get network objects
  getObservation, getObservationOrDie, getAction, getRewardOrDie,
  getRewardObservation, getRewardObservationOrDie,
  getInt, getDouble, getString, getStringOrDie,
  -- * Put network objects
  putObservation, putAction, putTerminalRewardObs, putString,
  -- * Network functions
  glueConnect, doCallWithNoParams, confirmState,
  sendAgentMessage, sendAgentMessageStr,
  sendEnvMessage, sendEnvMessageStr
  ) where

import Control.Exception
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Data.Binary.Get
import Data.Binary.IEEE754
import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC
import qualified Data.ByteString.Lazy as LBS
import Data.Int (Int32)
import Data.Maybe (fromMaybe)
import Data.Word
import Network.Simple.TCP
import Network.Socket (socketPort)
import System.Environment
import System.Exit
import System.IO.Error

-- Connection-type codes.
kExperimentConnection, kAgentConnection, kEnvironmentConnection :: Word32
kExperimentConnection  = 1
kAgentConnection       = 2
kEnvironmentConnection = 3

-- Agent codes.
kAgentInit, kAgentStart, kAgentStep, kAgentEnd, kAgentCleanup,
  kAgentMessage :: Word32
kAgentInit    = 4
kAgentStart   = 5
kAgentStep    = 6
kAgentEnd     = 7
kAgentCleanup = 8
kAgentMessage = 10

-- Environment codes.
kEnvInit, kEnvStart, kEnvStep, kEnvCleanup, kEnvMessage :: Word32
kEnvInit    = 11
kEnvStart   = 12
kEnvStep    = 13
kEnvCleanup = 14
kEnvMessage = 19

-- Experiment codes (also used as message selectors, see 'sendMessage').
kRLInit, kRLStart, kRLStep, kRLCleanup, kRLReturn, kRLNumSteps,
  kRLNumEpisodes, kRLEpisode, kRLAgentMessage, kRLEnvMessage :: Word32
kRLInit         = 20
kRLStart        = 21
kRLStep         = 22
kRLCleanup      = 23
kRLReturn       = 24
kRLNumSteps     = 25
kRLNumEpisodes  = 26
kRLEpisode      = 27
kRLAgentMessage = 33
kRLEnvMessage   = 34

-- Miscellaneous codes.
kRLTerm :: Word32
kRLTerm = 35

-- | Host used by 'glueConnect' when @RLGLUE_HOST@ is not set.
kLocalHost :: String
kLocalHost = "127.0.0.1"

-- | Port used by 'glueConnect' when @RLGLUE_PORT@ is not set.
kDefaultPort :: String
kDefaultPort = "4096"

-- | Retry timeout. Currently unused.
--
-- Typed 'Integer' because that is what the previously unannotated binding
-- defaulted to; keeping it avoids changing the exported type.
kRetryTimeout :: Integer
kRetryTimeout = 2

-- | Default buffer size. Not referenced in this module; typed 'Integer' for
-- the same backward-compatibility reason as 'kRetryTimeout'.
kDefaultBufferSize :: Integer
kDefaultBufferSize = 4096

-- | On-the-wire sizes, in bytes, of an int, a double and a single char.
kIntSize, kDoubleSize, kCharSize :: Int
kIntSize = 4
kDoubleSize = 8
kCharSize = 1

-- Data types
-- | Generic RL-Glue payload: a list of ints, a list of doubles and a run of
-- raw chars. It is serialised as three 32-bit counts (ints, doubles, chars)
-- followed by the values themselves.
data RLAbstractType = RLAbstractType [Int] [Double] BS.ByteString
  deriving (Eq, Show)

-- | An action, carried as an 'RLAbstractType'.
newtype Action = Action RLAbstractType
  deriving (Eq, Show)

-- | An observation, carried as an 'RLAbstractType'.
newtype Observation = Observation RLAbstractType
  deriving (Eq, Show)

-- | A reward value.
type Reward = Double

-- | Terminal flag; written to the wire as a 32-bit int.
type Terminal = Int

-- Abstract type functions
-- | Run the 'MaybeT' action @f x@. If it yields a value, return it;
-- otherwise print @err@ to stdout and exit with status 1.
orDie :: (a -> MaybeT IO c) -> String -> a -> IO c
orDie f err x = runMaybeT (f x) >>= maybe bail return
  where
    bail = do
      putStrLn err
      exitWith (ExitFailure 1)

-- | Number of bytes an 'RLAbstractType' occupies on the wire: a header of
-- three ints (the element counts) plus the ints, doubles and chars.
sizeOfType :: RLAbstractType -> Int
sizeOfType (RLAbstractType ints doubles chars) =
  headerBytes + intBytes + doubleBytes + charBytes
  where
    headerBytes = 3 * kIntSize
    intBytes    = length ints * kIntSize
    doubleBytes = length doubles * kDoubleSize
    charBytes   = BS.length chars * kCharSize

-- | Wire size, in bytes, of an 'Observation'.
sizeOfObs :: Observation -> Int
sizeOfObs obs = case obs of
  Observation payload -> sizeOfType payload

-- | Wire size, in bytes, of an 'Action'.
sizeOfAction :: Action -> Int
sizeOfAction act = case act of
  Action payload -> sizeOfType payload

-- | Wire size of a (terminal, reward, observation) triple: one int for the
-- terminal flag, one double for the reward, then the observation itself.
sizeOfRewardObs :: (Terminal, Reward, Observation) -> Int
sizeOfRewardObs (_, _, obs) = sizeOfObs obs + kDoubleSize + kIntSize

-- | Receive an 'RLAbstractType' from the socket.
--
-- The wire format is a header of three big-endian 32-bit counts (ints,
-- doubles, chars), followed by that many 32-bit ints, big-endian IEEE-754
-- doubles and raw bytes. Fails (in 'MaybeT') when 'recvExactly' does.
getAbstractType :: Socket -> MaybeT IO RLAbstractType
getAbstractType sock = do
  header <- recvExactly sock (3 * kIntSize)
  let (numInts, numDoubles, numChars) = runGet parseHeader (LBS.fromStrict header)
      payloadSize =
        numInts * kIntSize + numDoubles * kDoubleSize + numChars * kCharSize
  -- Nothing follows an all-empty header; don't ask the socket for 0 bytes.
  payload <-
    if payloadSize > 0
      then recvExactly sock payloadSize
      else return BS.empty
  return $ runGet (parsePayload numInts numDoubles numChars) (LBS.fromStrict payload)
  where
    -- The three element counts (unsigned).
    parseHeader :: Get (Int, Int, Int)
    parseHeader = do
      numInts <- getWord32be
      numDoubles <- getWord32be
      numChars <- getWord32be
      return (fromIntegral numInts, fromIntegral numDoubles, fromIntegral numChars)

    -- The values themselves, given the counts from the header.
    parsePayload :: Int -> Int -> Int -> Get RLAbstractType
    parsePayload numInts numDoubles numChars = do
      ints <- replicateM numInts getSignedInt32
      doubles <- replicateM numDoubles getFloat64be
      chars <- getByteString numChars
      return (RLAbstractType ints doubles chars)

    -- 'putAbstractType' encodes each Int with two's-complement truncation
    -- ('fromIntegral' to Word32), so decode through Int32 to sign-extend;
    -- otherwise e.g. -1 would come back as 4294967295.
    getSignedInt32 :: Get Int
    getSignedInt32 = fromIntegral . (fromIntegral :: Word32 -> Int32) <$> getWord32be

-- | Receive an 'Observation' (an 'RLAbstractType' payload) from the socket.
getObservation :: Socket -> MaybeT IO Observation
getObservation = fmap Observation . getAbstractType

-- | Like 'getObservation', but prints an error and exits with status 1 if
-- the observation cannot be received.
getObservationOrDie :: Socket -> IO Observation
getObservationOrDie = orDie getObservation failureMsg
  where
    failureMsg = "Error: Could not get observation"

-- | Receive an 'Action' (an 'RLAbstractType' payload) from the socket.
getAction :: Socket -> MaybeT IO Action
getAction = fmap Action . getAbstractType

-- | Receive a reward (one big-endian double); prints an error and exits
-- with status 1 if it cannot be received.
getRewardOrDie :: Socket -> IO Reward
getRewardOrDie = orDie getDouble failureMsg
  where
    failureMsg = "Error: Could not get reward"

-- | Receive a reward followed by an observation, in that order.
getRewardObservation :: Socket -> MaybeT IO (Reward, Observation)
getRewardObservation sock = (,) <$> getDouble sock <*> getObservation sock

-- | Like 'getRewardObservation', but prints an error and exits with status 1
-- on failure.
getRewardObservationOrDie :: Socket -> IO (Reward, Observation)
getRewardObservationOrDie = orDie getRewardObservation failureMsg
  where
    failureMsg = "Error: Could not get reward and observation"

-- | Serialise an 'RLAbstractType': three big-endian 32-bit counts (ints,
-- doubles, chars), then the ints, the doubles and the raw chars.
putAbstractType :: RLAbstractType -> Put
putAbstractType (RLAbstractType ints doubles chars) = do
  -- Header: element counts.
  mapM_ putInt [length ints, length doubles, BS.length chars]
  -- Body: the values themselves.
  mapM_ putInt ints
  mapM_ putFloat64be doubles
  putByteString chars
  where
    putInt :: Int -> Put
    putInt = putWord32be . fromIntegral

-- | Serialise an 'Observation' as its underlying 'RLAbstractType'.
putObservation :: Observation -> Put
putObservation obs = case obs of
  Observation payload -> putAbstractType payload

-- | Serialise an 'Action' as its underlying 'RLAbstractType'.
putAction :: Action -> Put
putAction act = case act of
  Action payload -> putAbstractType payload

-- | Serialise a (terminal, reward, observation) triple: the terminal flag as
-- a 32-bit int, the reward as a big-endian double, then the observation.
putTerminalRewardObs :: (Terminal, Reward, Observation) -> Put
putTerminalRewardObs (terminal, reward, obs) = do
  putWord32be (fromIntegral terminal)
  putFloat64be reward
  putObservation obs

-- Connecting to RL-Glue
-- | Open a TCP connection to RL-Glue and run @func@ on it.
--
-- The host and port come from the @RLGLUE_HOST@ and @RLGLUE_PORT@
-- environment variables, falling back to 'kLocalHost' and 'kDefaultPort'
-- when a variable is unset. A progress line is printed before and after
-- @func@ runs, and @func@'s result is returned.
glueConnect :: forall r. ((Socket, SockAddr) -> IO r) -> IO r
glueConnect func = do
  host <- fromMaybe kLocalHost <$> lookupEnv "RLGLUE_HOST"
  port <- fromMaybe kDefaultPort <$> lookupEnv "RLGLUE_PORT"
  connect host port announce
  where
    announce :: (Socket, SockAddr) -> IO r
    announce (sock, addr) = do
      -- NOTE: socketPort reports this socket's own (local) port, not the
      -- remote one in addr.
      clientPort <- socketPort sock
      putStrLn $ "Connecting to " ++ show addr ++ " on port " ++ show clientPort ++ "..."
      x <- func (sock, addr)
      putStrLn ("Disconnecting from " ++ show addr ++ "...")
      return x

-- Send/Recv helper functions
-- | Send a bare command code with a zero payload size, then check (via
-- 'confirmState') that the reply carries the same code; exits otherwise.
doCallWithNoParams :: Socket -> Word32 -> IO ()
doCallWithNoParams sock code = do
  sendLazy sock . runPut $ do
    putWord32be code
    putWord32be 0
  confirmState sock code

-- | Receive the standard reply header: two big-endian 32-bit words, the
-- state code followed by the payload size.
doStandardRecv :: Socket -> MaybeT IO (Word32, Word32)
doStandardRecv sock = decodeHeader <$> recvExactly sock (2 * 4)
  where
    decodeHeader = runGet ((,) <$> getWord32be <*> getWord32be) . LBS.fromStrict

-- | Receive one big-endian 32-bit int.
--
-- The four bytes are decoded as a two's-complement signed value. Ints are
-- written elsewhere in this module with 'fromIntegral' truncation to
-- 'Word32', so decoding as signed lets negative values round-trip:
-- e.g. @0xFFFFFFFF@ decodes to @-1@ rather than 4294967295.
getInt :: Socket -> MaybeT IO Int
getInt sock = do
  bs <- recvExactly sock kIntSize
  let w = runGet getWord32be (LBS.fromStrict bs)
  return (fromIntegral (fromIntegral w :: Int32))

-- | Receive one big-endian IEEE-754 double.
getDouble :: Socket -> MaybeT IO Double
getDouble sock = decodeDouble <$> recvExactly sock kDoubleSize
  where
    decodeDouble = runGet getFloat64be . LBS.fromStrict

-- | Receive a length-prefixed string: a 32-bit length, then that many chars.
--
-- A zero length yields the empty string without reading from the socket
-- again (e.g. an empty message reply). A negative length is treated as a
-- malformed reply and fails in 'MaybeT' instead of crashing inside 'recv'.
getString :: Socket -> MaybeT IO BS.ByteString
getString sock = do
  len <- getInt sock
  case compare len 0 of
    LT -> mzero
    EQ -> return BS.empty
    GT -> recvExactly sock (len * kCharSize)

-- | Like 'getString', but on failure prints @err@ and exits with status 1.
-- Note the argument order: the error message comes before the socket.
getStringOrDie :: String -> Socket -> IO BS.ByteString
getStringOrDie err sock = orDie getString err sock

-- | Serialise a string as a big-endian 32-bit length followed by its bytes
-- (the format read by 'getString').
putString :: BS.ByteString -> Put
putString bs = putWord32be len >> putByteString bs
  where
    len = fromIntegral (BS.length bs)

-- | Receive exactly @nBytes@ bytes, calling 'recv' as many times as needed.
--
-- Fails (in 'MaybeT') if 'recv' returns 'Nothing' before enough bytes have
-- arrived. A request for zero (or fewer) bytes returns the empty string
-- immediately instead of calling 'recv', which cannot hand back an empty
-- chunk as a success and would otherwise make empty strings/payloads fail.
recvExactly :: Socket -> Int -> MaybeT IO BS.ByteString
recvExactly sock nBytes = BS.concat <$> go nBytes
  where
    -- Collect the chunks and concatenate once at the end, instead of
    -- re-appending after every partial read.
    go :: Int -> MaybeT IO [BS.ByteString]
    go remaining
      | remaining <= 0 = return []
      | otherwise = do
          chunk <- MaybeT (recv sock remaining)
          rest <- go (remaining - BS.length chunk)
          return (chunk : rest)

-- Other functions
-- | Read the reply header (state, size) and check that the state equals
-- @exptState@. Prints a message and exits with status 1 if the header
-- cannot be read or the state differs; the size field is ignored.
confirmState :: Socket -> Word32 -> IO ()
confirmState sock exptState = runMaybeT (doStandardRecv sock) >>= check
  where
    check Nothing = do
      putStrLn "Failed to receive state. Exiting..."
      exitWith (ExitFailure 1)
    check (Just (state, _))
      | state == exptState = return ()
      | otherwise = do
          putStrLn $ "State " ++ show state ++ " doesn't match expected state " ++
            show exptState ++ ". Exiting..."
          exitWith (ExitFailure 1)

-- | Send a message command @selByte@ carrying @msg@ and return the string
-- reply.
--
-- Wire format: the command code, the payload size (a 32-bit length prefix
-- plus the message bytes), then @msg@ as a length-prefixed string. Exits the
-- process (status 1) if the reply state does not match @selByte@ or the
-- reply string cannot be read.
sendMessage :: Word32 -> Socket -> BS.ByteString -> IO BS.ByteString
sendMessage selByte sock msg = do
  sendLazy sock . runPut $ do
    putWord32be selByte
    putWord32be (fromIntegral (kIntSize + BS.length msg))
    putString msg
  confirmState sock selByte
  resp <- runMaybeT (getString sock)
  case resp of
    Nothing -> do
      putStrLn $ "Error: Could not read response from " ++ target ++ " message"
      exitWith (ExitFailure 1)
    Just reply -> return reply
  where
    -- Name the right peer so failed environment messages are not reported
    -- as agent messages.
    target
      | selByte == kRLEnvMessage = "environment"
      | otherwise = "agent"

-- | Send @msg@ with the 'kRLAgentMessage' selector and return the reply.
sendAgentMessage :: Socket -> BS.ByteString -> IO BS.ByteString
sendAgentMessage sock msg = sendMessage kRLAgentMessage sock msg

-- | 'sendAgentMessage' for a 'String'. The string is packed with
-- 'BSC.pack', which keeps only the low 8 bits of each character.
sendAgentMessageStr :: Socket -> String -> IO BS.ByteString
sendAgentMessageStr sock = sendAgentMessage sock . BSC.pack

-- | Send @msg@ with the 'kRLEnvMessage' selector and return the reply.
sendEnvMessage :: Socket -> BS.ByteString -> IO BS.ByteString
sendEnvMessage sock msg = sendMessage kRLEnvMessage sock msg

-- | 'sendEnvMessage' for a 'String'. The string is packed with 'BSC.pack',
-- which keeps only the low 8 bits of each character.
sendEnvMessageStr :: Socket -> String -> IO BS.ByteString
sendEnvMessageStr sock = sendEnvMessage sock . BSC.pack