module Network.TFTP.UDPIO ( UDPIO(..)
, Address
, udpIO) where
import Control.Exception
import Foreign.Marshal(mallocArray, peekArray, pokeArray, free)
import Foreign.Ptr(Ptr(..))
import qualified Network.Socket as Sock
import Network.TFTP.Types
import Prelude hiding (catch)
import System.IO.Error(ioeGetErrorType, ioeGetHandle, ioeGetLocation)
import System.Timeout(timeout)
-- | Peer/endpoint address used by this UDP transport; a plain alias for the
-- socket layer's 'Sock.SockAddr'.
type Address = Sock.SockAddr
-- | Internal state threaded through every 'UDPIO' computation.
data UDPIOSt = UDPIOSt
  { udpReader    :: Reader
    -- ^ Action used to receive one datagram (see 'makeReader').
  , udpWriter    :: Writer
    -- ^ Action used to send a message to a peer (see 'makeWriter').
  , udpMyAddress :: Address
    -- ^ Address the local socket was bound to.
  }
-- | Monad for UDP message exchange: a state transformer carrying 'UDPIOSt'
-- on top of 'IO'. Run it with 'udpIO'.
newtype UDPIO a = UDPIO
  { runUDPIO :: StateT UDPIOSt IO a
    -- ^ Unwrap to the underlying state-transformer computation.
  }
  deriving ( Functor
           , Monad
           , MonadIO
           , MonadState UDPIOSt
           , Applicative
           )
-- | Message I/O over UDP: each method looks up the matching field of the
-- current 'UDPIOSt' and runs it in 'IO'.
instance MessageIO UDPIO Address where
  -- Send @msg@ to the peer @to@ using the stored writer.
  sendTo to msg =
    get >>= \st -> liftIO (udpWriter st to msg)

  -- Receive one message using the stored reader; the argument is handed to
  -- the reader unchanged.
  receiveFrom maybeSecs =
    get >>= \st -> liftIO (udpReader st maybeSecs)

  -- The address recorded when the socket was bound.
  localAddress = fmap udpMyAddress get
-- | Run a 'UDPIO' action on a freshly bound UDP socket.
--
-- The host name and port are passed through to 'bindUDPSocket'. Two buffers
-- of 'bufferSize' bytes are allocated, one for receiving and one for sending.
--
-- The socket and both buffers are released when the action finishes, whether
-- it returns normally or throws an exception.
udpIO :: Maybe String -- ^ host name handed to 'bindUDPSocket'
      -> Maybe String -- ^ port / service name handed to 'bindUDPSocket'
      -> UDPIO a      -- ^ action to run
      -> IO a
udpIO host port action =
  -- 'bracket' guarantees the release actions run even if the action throws;
  -- the previous straight-line cleanup leaked all three resources on error.
  bracket (bindUDPSocket host port) (Sock.sClose . snd) $ \(addr, sock) ->
    bracket (mallocArray bufferSize) free $ \readBuf ->
      bracket (mallocArray bufferSize) free $ \writeBuf -> do
        let reader = makeReader sock readBuf
            writer = makeWriter sock writeBuf
        evalStateT (runUDPIO action) (UDPIOSt reader writer addr)
-- | Size, in bytes, of the receive buffer and of the send buffer; also the
-- largest chunk handed to the socket in a single send or receive call.
bufferSize :: Int
bufferSize = 4096
-- | Resolve the given host/port with 'Sock.getAddrInfo' (using the
-- 'Sock.AI_PASSIVE' hint), create a datagram socket for the first result and
-- bind it.
--
-- Returns the address reported by 'Sock.getSocketName' after binding,
-- together with the bound socket. The caller is responsible for closing the
-- socket.
--
-- Throws an 'IOError' if address resolution yields no result. If binding or
-- querying the socket name fails, the socket is closed before the exception
-- propagates.
bindUDPSocket :: Maybe String -> Maybe String -> IO (Address, Sock.Socket)
bindUDPSocket hostname port = do
  let myHints = Sock.defaultHints { Sock.addrFlags = [Sock.AI_PASSIVE] }
  addrInfos <- Sock.getAddrInfo (Just myHints) hostname port
  case addrInfos of
    [] ->
      ioError (userError "bindUDPSocket: address lookup returned no results")
    (serverAddr : _) ->
      -- Close the socket only on failure; on success ownership passes to
      -- the caller.
      bracketOnError
        (Sock.socket (Sock.addrFamily serverAddr) Sock.Datagram Sock.defaultProtocol)
        Sock.sClose
        (\sock -> do
            Sock.bindSocket sock (Sock.addrAddress serverAddr)
            boundAddr <- Sock.getSocketName sock
            logInfo (printf "Bound socket at address %s" (show boundAddr))
            return (boundAddr, sock))
-- | Build a 'Reader' that receives one datagram from the socket into the
-- given buffer (at most 'bufferSize' bytes) and returns the sender's address
-- together with the received bytes.
--
-- The reader's argument is an optional timeout in seconds. @Just n@ waits at
-- most @n@ seconds; @Nothing@ waits indefinitely. The reader yields 'Nothing'
-- when the timeout expires before a datagram arrives.
makeReader :: Sock.Socket -> Ptr Word8 -> Reader
makeReader sock buffer maybeTimeoutSecs = do
  -- 'timeout' treats a negative interval as "no timeout". The previous
  -- default of 1 made a 'Nothing' timeout expire after a single microsecond.
  let timeoutMicros = maybe (-1) (* 1000000) maybeTimeoutSecs
  mResult <- timeout timeoutMicros (Sock.recvBufFrom sock buffer bufferSize)
  case mResult of
    Just (bytesRead, from) -> do
      logInfo ("Read " ++ show bytesRead ++ " bytes from " ++ show from)
      -- Copy only the bytes actually received out of the shared buffer.
      res <- peekArray bytesRead buffer
      return (Just (from, pack res))
    Nothing -> do
      logWarn "Receive timeout!"
      return Nothing
-- | Build a 'Writer' that sends a message to a destination address through
-- the given socket, staging the bytes in the supplied buffer.
--
-- The message is sent in chunks of at most 'bufferSize' bytes; after each
-- send the loop continues with whatever the socket did not accept. The
-- writer yields 'True' once nothing is left to send, and 'False' if an
-- 'IOError' is raised along the way (the error is logged as a warning).
makeWriter :: Sock.Socket -> Ptr Word8 -> Writer
makeWriter sock buffer destAddr toSend =
  sendAll (unpack toSend) `catch` onIOError
  where
    -- Send the remaining bytes chunk by chunk until the list is exhausted.
    sendAll [] = return True
    sendAll pending = do
      let chunkLen = min bufferSize (length pending)
      pokeArray buffer (take chunkLen pending)
      sent <- Sock.sendBufTo sock buffer chunkLen destAddr
      -- Advance by what was actually sent, not by what was offered.
      let remaining = drop sent pending
      logInfo $ "Sent " ++ show sent ++ " " ++ show (length remaining) ++ " left"
      sendAll remaining

    -- Report the IO error and signal failure to the caller.
    onIOError e = do
      logWarn $
        printf "Caught IO Exception: \'%s\' handle: \'%s\' at: \'%s\'."
          (show $ ioeGetErrorType e)
          (show $ ioeGetHandle e)
          (ioeGetLocation e)
      return False
-- | Forward a message to 'debugM' under the logger name @\"TFTP.UDPIO\"@.
logInfo msg = debugM "TFTP.UDPIO" msg
-- | Forward a message to 'warningM' under the logger name @\"TFTP.UDPIO\"@.
logWarn msg = warningM "TFTP.UDPIO" msg
-- | Forward a message to 'errorM' under the logger name @\"TFTP.UDPIO\"@.
logError msg = errorM "TFTP.UDPIO" msg
-- | Receive action: given an optional timeout in seconds, yields the sender's
-- address and the received bytes, or 'Nothing' if the timeout expired first.
type Reader = Maybe Int -> IO (Maybe (Address, ByteString))
-- | Send action: given a destination address and the bytes to send, yields
-- 'True' when everything was sent and 'False' when sending failed with an
-- IO error (see 'makeWriter').
type Writer = Address -> ByteString -> IO Bool