module RL_Glue.Network (
kExperimentConnection, kAgentConnection, kEnvironmentConnection,
kAgentInit, kAgentStart, kAgentStep, kAgentEnd, kAgentCleanup, kAgentMessage,
kEnvInit, kEnvStart, kEnvStep, kEnvCleanup, kEnvMessage,
kRLInit, kRLStart, kRLStep, kRLCleanup, kRLReturn, kRLNumSteps,
kRLNumEpisodes, kRLEpisode, kRLAgentMessage, kRLEnvMessage,
kRLTerm,
kLocalHost, kDefaultPort, kRetryTimeout,
kDefaultBufferSize, kIntSize, kDoubleSize, kCharSize,
RLAbstractType(RLAbstractType), Action(Action), Observation(Observation),
Reward, Terminal,
sizeOfObs, sizeOfAction, sizeOfRewardObs,
getObservation, getObservationOrDie, getAction, getRewardOrDie,
getRewardObservation, getRewardObservationOrDie,
getInt, getDouble, getString, getStringOrDie,
putObservation, putAction, putTerminalRewardObs, putString,
glueConnect, doCallWithNoParams, confirmState,
sendAgentMessage, sendAgentMessageStr,
sendEnvMessage, sendEnvMessageStr
) where
import Control.Exception
import Control.Monad
import Data.Int (Int32)
import Data.Maybe (fromMaybe)
import Data.Word
import System.Environment
import System.Exit
import System.IO.Error

import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Data.Binary.Get
import Data.Binary.IEEE754
import Data.Binary.Put
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC
import qualified Data.ByteString.Lazy as LBS
import Network.Simple.TCP
import Network.Socket (socketPort)
-- RL-Glue protocol codes. All of them are 'Word32' values; the message
-- codes are written to the socket with 'putWord32be'.

-- Connection-type codes.
kExperimentConnection, kAgentConnection, kEnvironmentConnection :: Word32
kExperimentConnection  = 1
kAgentConnection       = 2
kEnvironmentConnection = 3

-- Agent-side message codes (note: 9 is not used).
kAgentInit, kAgentStart, kAgentStep, kAgentEnd, kAgentCleanup, kAgentMessage
  :: Word32
kAgentInit    = 4
kAgentStart   = 5
kAgentStep    = 6
kAgentEnd     = 7
kAgentCleanup = 8
kAgentMessage = 10

-- Environment-side message codes (note: 15-18 are not used).
kEnvInit, kEnvStart, kEnvStep, kEnvCleanup, kEnvMessage :: Word32
kEnvInit    = 11
kEnvStart   = 12
kEnvStep    = 13
kEnvCleanup = 14
kEnvMessage = 19

-- Experiment-side (RL_*) message codes.
kRLInit, kRLStart, kRLStep, kRLCleanup, kRLReturn, kRLNumSteps :: Word32
kRLInit     = 20
kRLStart    = 21
kRLStep     = 22
kRLCleanup  = 23
kRLReturn   = 24
kRLNumSteps = 25

kRLNumEpisodes, kRLEpisode, kRLAgentMessage, kRLEnvMessage, kRLTerm :: Word32
kRLNumEpisodes  = 26
kRLEpisode      = 27
kRLAgentMessage = 33
kRLEnvMessage   = 34
kRLTerm         = 35
-- | Host used by 'glueConnect' when @RLGLUE_HOST@ is not set.
kLocalHost :: String
kLocalHost = "127.0.0.1"

-- | Port (as a service name) used by 'glueConnect' when @RLGLUE_PORT@ is
-- not set.
kDefaultPort :: String
kDefaultPort = "4096"

-- | Retry timeout; not used in this module (unit not established here).
kRetryTimeout :: Integer
kRetryTimeout = 2

-- | Default buffer size; not used in this module.
kDefaultBufferSize :: Integer
kDefaultBufferSize = 4096

-- | Byte widths of the wire encodings: 32-bit ints, 64-bit doubles and
-- single-byte chars.
kIntSize, kDoubleSize, kCharSize :: Int
kIntSize    = 4
kDoubleSize = 8
kCharSize   = 1
-- | Generic RL-Glue payload: a list of integers, a list of doubles and a
-- raw byte string. The three parts are serialized in this order.
data RLAbstractType = RLAbstractType [Int] [Double] BS.ByteString
  deriving (Show)

-- | An action, carried as an 'RLAbstractType' payload.
newtype Action = Action RLAbstractType
  deriving (Show)

-- | An observation, carried as an 'RLAbstractType' payload.
newtype Observation = Observation RLAbstractType
  deriving (Show)

-- | Reward value, serialized as a 64-bit IEEE double.
type Reward = Double

-- | Terminal flag, serialized as a 32-bit big-endian word.
type Terminal = Int
-- | Run a 'MaybeT' action on @input@ and return its result. If the action
-- yields 'Nothing', print @err@ and terminate the process with exit code 1.
orDie :: (a -> MaybeT IO c) -> String -> a -> IO c
orDie action err input = runMaybeT (action input) >>= maybe abort return
  where
    abort = putStrLn err >> exitWith (ExitFailure 1)
-- | Number of bytes an 'RLAbstractType' occupies when serialized: a header
-- of three 32-bit counts, then the ints, doubles and raw bytes.
sizeOfType :: RLAbstractType -> Int
sizeOfType (RLAbstractType ints doubles bytes) =
  headerBytes + intBytes + doubleBytes + charBytes
  where
    headerBytes = 3 * kIntSize
    intBytes    = length ints * kIntSize
    doubleBytes = length doubles * kDoubleSize
    charBytes   = BS.length bytes * kCharSize
-- | Serialized size in bytes of an 'Observation' (that of its payload).
sizeOfObs :: Observation -> Int
sizeOfObs obs = case obs of
  Observation payload -> sizeOfType payload

-- | Serialized size in bytes of an 'Action' (that of its payload).
sizeOfAction :: Action -> Int
sizeOfAction act = case act of
  Action payload -> sizeOfType payload
-- | Serialized size in bytes of a (terminal, reward, observation) triple:
-- one int for the terminal flag, one double for the reward, plus the
-- observation payload.
sizeOfRewardObs :: (Terminal, Reward, Observation) -> Int
sizeOfRewardObs (_, _, obs) = terminalBytes + rewardBytes + sizeOfObs obs
  where
    terminalBytes = kIntSize
    rewardBytes   = kDoubleSize
-- | Read one 'RLAbstractType' from the socket.
--
-- Layout (big-endian): three 32-bit counts (ints, doubles, chars), then
-- that many 32-bit ints, 64-bit IEEE doubles and raw bytes. Yields
-- 'Nothing' if the connection closes before all bytes arrive.
getAbstractType :: Socket -> MaybeT IO RLAbstractType
getAbstractType sock = do
  header <- recvExactly sock (3 * kIntSize)
  let (numInts, numDoubles, numChars) =
        runGet parseCounts (LBS.fromStrict header)
      size = numInts * kIntSize + numDoubles * kDoubleSize + numChars * kCharSize
  payload <- recvExactly sock size
  return $ runGet (parsePayload numInts numDoubles numChars) (LBS.fromStrict payload)
  where
    -- Element counts are non-negative, so plain unsigned widening is fine.
    parseCounts :: Get (Int, Int, Int)
    parseCounts = do
      numInts <- getWord32be
      numDoubles <- getWord32be
      numChars <- getWord32be
      return (fromIntegral numInts, fromIntegral numDoubles, fromIntegral numChars)

    parsePayload :: Int -> Int -> Int -> Get RLAbstractType
    parsePayload numInts numDoubles numChars = do
      ints <- replicateM numInts (signed32 <$> getWord32be)
      doubles <- replicateM numDoubles getFloat64be
      chars <- getByteString numChars
      return (RLAbstractType ints doubles chars)

    -- Ints are written from 'Int' via 'fromIntegral', which wraps negative
    -- values into the upper half of 'Word32'. Going through 'Int32'
    -- restores the sign (0xFFFFFFFF -> -1) instead of yielding 4294967295.
    signed32 :: Word32 -> Int
    signed32 w = fromIntegral (fromIntegral w :: Int32)
-- | Read an 'Observation' (an 'RLAbstractType' payload) from the socket.
getObservation :: Socket -> MaybeT IO Observation
getObservation sock = Observation <$> getAbstractType sock

-- | Like 'getObservation', but print an error and exit with status 1 on
-- failure.
getObservationOrDie :: Socket -> IO Observation
getObservationOrDie sock =
  orDie getObservation "Error: Could not get observation" sock
-- | Read an 'Action' (an 'RLAbstractType' payload) from the socket.
getAction :: Socket -> MaybeT IO Action
getAction = fmap Action . getAbstractType
-- | Read a 'Reward' (a 64-bit double) from the socket, printing an error
-- and exiting with status 1 on failure.
getRewardOrDie :: Socket -> IO Reward
getRewardOrDie sock = orDie getDouble "Error: Could not get reward" sock
-- | Read a reward followed by an observation, in that order.
getRewardObservation :: Socket -> MaybeT IO (Reward, Observation)
getRewardObservation sock = (,) <$> getDouble sock <*> getObservation sock

-- | Like 'getRewardObservation', but print an error and exit with status 1
-- on failure.
getRewardObservationOrDie :: Socket -> IO (Reward, Observation)
getRewardObservationOrDie sock =
  orDie getRewardObservation "Error: Could not get reward and observation" sock
-- | Serialize an 'RLAbstractType': three big-endian 32-bit counts, then the
-- ints (as 32-bit words), the doubles (as 64-bit IEEE) and the raw bytes.
putAbstractType :: RLAbstractType -> Put
putAbstractType (RLAbstractType ints doubles bytes) = do
  -- Header: element counts.
  putWord32be (fromIntegral (length ints))
  putWord32be (fromIntegral (length doubles))
  putWord32be (fromIntegral (BS.length bytes))
  -- Body: the elements, in the same order as the header.
  forM_ ints (putWord32be . fromIntegral)
  forM_ doubles putFloat64be
  putByteString bytes
-- | Serialize an 'Observation' by writing its payload.
putObservation :: Observation -> Put
putObservation obs = case obs of
  Observation payload -> putAbstractType payload

-- | Serialize an 'Action' by writing its payload.
putAction :: Action -> Put
putAction act = case act of
  Action payload -> putAbstractType payload
-- | Serialize a (terminal, reward, observation) triple: the terminal flag
-- as a big-endian 32-bit word, the reward as a 64-bit IEEE double, then
-- the observation payload.
putTerminalRewardObs :: (Terminal, Reward, Observation) -> Put
putTerminalRewardObs (terminal, reward, obs) =
  putWord32be (fromIntegral terminal)
    >> putFloat64be reward
    >> putObservation obs
-- | Connect to the RL-Glue server and run @func@ with the connected socket
-- and the server's address, returning @func@'s result.
--
-- Host and port are taken from the @RLGLUE_HOST@ and @RLGLUE_PORT@
-- environment variables, falling back to 'kLocalHost' and 'kDefaultPort'
-- when a variable is unset. A line is printed before and after @func@ runs.
glueConnect :: ((Socket, SockAddr) -> IO r) -> IO r
glueConnect func = do
  host <- fromMaybe kLocalHost <$> lookupEnv "RLGLUE_HOST"
  port <- fromMaybe kDefaultPort <$> lookupEnv "RLGLUE_PORT"
  connect host port withLogging
  where
    -- Deliberately no local type signature: its type is inferred from
    -- 'func', so the block needs neither ExplicitForAll nor
    -- ScopedTypeVariables.
    withLogging (sock, addr) = do
      clientPort <- socketPort sock
      putStrLn $ "Connecting to " ++ show addr ++ " on port " ++ show clientPort ++ "..."
      result <- func (sock, addr)
      putStrLn ("Disconnecting from " ++ show addr ++ "...")
      return result
-- | Send the message code @code@ with a zero data size, then require the
-- server to answer with the same code (exits the process otherwise; see
-- 'confirmState').
doCallWithNoParams :: Socket -> Word32 -> IO ()
doCallWithNoParams sock code = do
  sendLazy sock (runPut header)
  confirmState sock code
  where
    header = putWord32be code >> putWord32be 0
-- | Read the two big-endian 32-bit words that precede a reply: the state
-- code followed by a data-size word.
doStandardRecv :: Socket -> MaybeT IO (Word32, Word32)
doStandardRecv sock = parseHeader <$> recvExactly sock (2 * kIntSize)
  where
    parseHeader = runGet ((,) <$> getWord32be <*> getWord32be) . LBS.fromStrict
-- | Read one big-endian 32-bit integer from the socket.
--
-- Integers are written by this module from 'Int' via 'fromIntegral', which
-- wraps negative values into the upper half of 'Word32'. The raw word is
-- therefore reinterpreted as 'Int32' before widening, so 0xFFFFFFFF decodes
-- to -1 rather than 4294967295.
getInt :: Socket -> MaybeT IO Int
getInt sock = do
  bytes <- recvExactly sock kIntSize
  let word = runGet getWord32be (LBS.fromStrict bytes)
  return (fromIntegral (fromIntegral word :: Int32))
-- | Read one big-endian 64-bit IEEE double from the socket.
getDouble :: Socket -> MaybeT IO Double
getDouble sock =
  runGet getFloat64be . LBS.fromStrict <$> recvExactly sock kDoubleSize
-- | Read a length-prefixed byte string: a 32-bit length (via 'getInt')
-- followed by that many bytes.
getString :: Socket -> MaybeT IO BS.ByteString
getString sock =
  getInt sock >>= \charCount -> recvExactly sock (charCount * kCharSize)
-- | Like 'getString', but print @err@ and exit with status 1 on failure.
getStringOrDie :: String -> Socket -> IO BS.ByteString
getStringOrDie err sock = orDie getString err sock
-- | Serialize a byte string as a big-endian 32-bit length followed by the
-- raw bytes.
putString :: BS.ByteString -> Put
putString bytes =
  putWord32be (fromIntegral (BS.length bytes)) >> putByteString bytes
-- | Receive exactly @nBytes@ bytes from the socket, looping over partial
-- reads. Yields 'Nothing' if the peer closes the connection before enough
-- bytes arrive. A non-positive request returns the empty string without
-- touching the socket.
recvExactly :: Socket -> Int -> MaybeT IO BS.ByteString
recvExactly sock nBytes = BS.concat <$> collect nBytes
  where
    -- Gather chunks until the requested total has been read. The chunks are
    -- joined once at the end instead of re-appending at every step.
    collect remaining
      -- Zero-length payloads (e.g. an empty string reply) are legal, but a
      -- 0-byte recv cannot be distinguished from EOF (and may be rejected
      -- outright by the socket layer; NOTE(review): per network /
      -- network-simple docs, confirm for the pinned versions).
      | remaining <= 0 = return []
      | otherwise = do
          chunk <- MaybeT (recv sock remaining)
          -- Defensive: an empty chunk would make no progress and loop forever.
          guard (not (BS.null chunk))
          rest <- collect (remaining - BS.length chunk)
          return (chunk : rest)
-- | Read a reply header and require its state code to equal @exptState@.
-- The data-size word of the header is ignored. On a read failure or a
-- mismatched state, print a message and exit with status 1.
confirmState :: Socket -> Word32 -> IO ()
confirmState sock exptState = do
  received <- runMaybeT (doStandardRecv sock)
  case received of
    Nothing -> bail "Failed to receive state. Exiting..."
    Just (state, _) ->
      unless (state == exptState) $
        bail ("State " ++ show state ++ " doesn't match expected state " ++
              show exptState ++ ". Exiting...")
  where
    bail msg = putStrLn msg >> exitWith (ExitFailure 1)
-- | Send a string message tagged with @selByte@, confirm the server echoes
-- the same state code, and return the string it replies with.
--
-- On a state mismatch 'confirmState' exits the process. If the reply
-- string cannot be read, an error naming the target (agent or environment)
-- is printed and the process exits with status 1.
sendMessage :: Word32 -> Socket -> BS.ByteString -> IO BS.ByteString
sendMessage selByte sock msg = do
  sendLazy sock packedMsg
  confirmState sock selByte
  orDie getString ("Error: Could not read response from " ++ target ++ " message") sock
  where
    -- Frame: selector, payload size, then the payload. The payload is a
    -- length-prefixed string, hence the extra int in the size field.
    packedMsg = runPut $ do
      putWord32be selByte
      putWord32be (fromIntegral (kIntSize + BS.length msg))
      putString msg
    -- The original always said "agent", even for environment messages.
    target
      | selByte == kRLEnvMessage = "environment"
      | otherwise = "agent"
-- | Send a raw message to the agent and return its reply.
sendAgentMessage :: Socket -> BS.ByteString -> IO BS.ByteString
sendAgentMessage sock msg = sendMessage kRLAgentMessage sock msg

-- | 'String' convenience wrapper around 'sendAgentMessage'.
-- NOTE(review): 'BSC.pack' keeps only the low 8 bits of each 'Char', so
-- non-Latin-1 text is silently truncated rather than UTF-8 encoded.
sendAgentMessageStr :: Socket -> String -> IO BS.ByteString
sendAgentMessageStr sock = sendAgentMessage sock . BSC.pack
-- | Send a raw message to the environment and return its reply.
sendEnvMessage :: Socket -> BS.ByteString -> IO BS.ByteString
sendEnvMessage sock msg = sendMessage kRLEnvMessage sock msg

-- | 'String' convenience wrapper around 'sendEnvMessage'.
-- NOTE(review): 'BSC.pack' keeps only the low 8 bits of each 'Char', so
-- non-Latin-1 text is silently truncated rather than UTF-8 encoded.
sendEnvMessageStr :: Socket -> String -> IO BS.ByteString
sendEnvMessageStr sock = sendEnvMessage sock . BSC.pack