module Network.Stun
(
bindRequest
, stunRequest
, stunRequest'
, Message(..)
, MessageClass(..)
, TransactionID(..)
, Attribute(..)
, findAttribute
, IsAttribute(..)
, findMappedAddress
, MappedAddress
, XorMappedAddress
, fromXorMappedAddress
, xorMappedAddress
, Username(..)
, Credentials(..)
, withMessageIntegrity
, checkMessageIntegrity
, StunError(..)
, ErrorAttribute(..)
, errTryAlternate
, errBadRequest
, errUnauthorized
, errUnknownAttribute
, errStaleNonce
, errServerError
) where
import Control.Applicative
import Control.Concurrent.Timeout
import Control.Exception (finally, onException)
import Control.Monad.IO.Class
import Control.Monad.Trans.Error
import Data.Serialize
import qualified Network.BSD as Net
import qualified Network.Socket as S
import qualified Network.Socket.ByteString as SocketBS
import Network.Stun.Base
import Network.Stun.Credentials
import Network.Stun.Error
import Network.Stun.MappedAddress
import System.Random
-- | Build a STUN Binding request: method 1, class 'Request', a freshly
-- randomised transaction ID, no attributes, and 'fingerprint' enabled.
bindRequest :: IO Message
bindRequest = do
    -- Draw the three transaction-ID components in constructor field order.
    part1 <- randomIO
    part2 <- randomIO
    part3 <- randomIO
    return Message
        { messageMethod     = 1
        , messageClass      = Request
        , transactionID     = TID part1 part2 part3
        , messageAttributes = []
        , fingerprint       = True
        }
-- | Ways a STUN transaction performed by this module can fail.
data StunError
    = TimeOut
      -- ^ No reply arrived within any of the retransmission timeouts.
    | ProtocolError
      -- ^ The reply could not be decoded, or an expected address attribute
      -- was missing, duplicated, or unparsable.
    | ErrorMsg !Message
      -- ^ The server answered with an error response (class 'Failure').
    | WrongMessageType !Message
      -- ^ The reply was neither a success nor an error response.
    deriving (Show, Eq)

-- | Required to use 'StunError' as the error type of 'ErrorT'.
--
-- The class's default 'noMsg' and 'strMsg' are each defined in terms of
-- the other, so an empty instance makes both diverge whenever 'ErrorT'
-- reaches them (e.g. via 'fail', 'mzero' or 'empty'). Generic,
-- message-less failures are therefore mapped to 'ProtocolError'; 'strMsg'
-- keeps its default, which now resolves to this value.
instance Error StunError where
    noMsg = ProtocolError
-- | Send a STUN message and return the server's successful reply.
--
-- Delegates to 'stunRequest'' and closes the socket it hands back on
-- success, so the caller only receives the decoded reply. Errors from
-- 'stunRequest'' are passed through unchanged.
stunRequest
    :: S.SockAddr
    -> Net.PortNumber
    -> [Integer]
    -> Message
    -> IO (Either StunError Message)
stunRequest host localPort timeOuts msg = do
    outcome <- stunRequest' host localPort timeOuts msg
    case outcome of
        Left err -> return (Left err)
        Right (reply, sock) -> do
            S.close sock
            return (Right reply)
-- | Send a STUN message over UDP and wait for a reply, retransmitting on
-- timeout.
--
-- A host port of 0 is replaced with 3478, the default STUN port. The
-- message is (re)sent once per entry of the timeout list, each entry being
-- the wait passed to 'timeout' (microseconds); an empty list selects the
-- default schedule @[500000, 1000000, 2000000]@.
--
-- On success the decoded 'Success' reply is returned together with the
-- still-open socket, which the caller must close. On any failure, whether
-- a 'StunError' result or an exception, the socket is closed before
-- returning or rethrowing.
--
-- Possible errors:
--
-- * 'TimeOut': no datagram arrived within any of the timeouts;
-- * 'ProtocolError': the reply could not be decoded;
-- * 'ErrorMsg': the reply has class 'Failure';
-- * 'WrongMessageType': the reply is neither 'Success' nor 'Failure'.
--
-- Calls 'error' for address families other than IPv4/IPv6.
--
-- NOTE(review): the local-port argument is currently ignored (the socket
-- is never bound to it), and the reply's transaction ID is not compared
-- against the request's. Confirm whether either is intended.
stunRequest'
    :: S.SockAddr
    -> Net.PortNumber
    -> [Integer]
    -> Message
    -> IO (Either StunError (Message, S.Socket))
stunRequest' host' _localPort timeOuts msg = do
    s <- openSocket
    -- Release the socket if anything throws while it is in use (e.g. a
    -- connected UDP socket surfacing ECONNREFUSED from recv).
    result <- runErrorT (exchange s) `onException` S.close s
    case result of
        -- The socket is only handed back on success, so every error
        -- path must release it here.
        Left err -> do
            S.close s
            return (Left err)
        Right reply -> return (Right (reply, s))
  where
    host = setHostPort host'

    -- Open a UDP socket matching the destination's address family,
    -- closing it again if configuring it fails.
    openSocket = case host of
        S.SockAddrInet {} ->
            S.socket S.AF_INET S.Datagram S.defaultProtocol
        S.SockAddrInet6 {} -> do
            s <- S.socket S.AF_INET6 S.Datagram S.defaultProtocol
            S.setSocketOption s S.IPv6Only 1 `onException` S.close s
            return s
        _ -> error "stunRequest': SockAddrUnix not implemented"

    -- Connect, send with retransmission, then decode and classify the
    -- reply. Never closes the socket itself; the caller above does.
    exchange s = do
        liftIO $ S.connect s host
        let go [] = throwError TimeOut
            go (to:tos) = do
                _ <- liftIO $ SocketBS.send s (encode msg)
                r <- liftIO . timeout to $ SocketBS.recv s 1024
                case r of
                    Nothing -> go tos
                    Just answer -> return answer
        answer <- go schedule
        case decode answer of
            Left _ -> throwError ProtocolError
            Right msg' -> case messageClass msg' of
                Failure -> throwError $ ErrorMsg msg'
                Success -> return msg'
                _ -> throwError $ WrongMessageType msg'

    -- Retransmission waits: caller-supplied, or the default schedule.
    schedule
        | null timeOuts = [500000, 1000000, 2000000]
        | otherwise = timeOuts

    -- Substitute the default STUN port when none was given.
    setHostPort (S.SockAddrInet pn ha) =
        S.SockAddrInet (defaultPort pn) ha
    setHostPort (S.SockAddrInet6 pn fl ha si) =
        S.SockAddrInet6 (defaultPort pn) fl ha si
    setHostPort s = s

    defaultPort pn = if pn == 0 then 3478 else pn
-- | Ask a STUN server which address it sees this host's requests coming
-- from, using a Binding request.
--
-- Returns @(mapped, local)@:
--
-- * @mapped@ is the address reported by the server. XOR-MAPPED-ADDRESS is
--   preferred, with MAPPED-ADDRESS as the fallback.
-- * @local@ is the local address of the socket that was used.
--
-- The socket is always closed before this returns, including on errors
-- and exceptions.
--
-- Fails with 'ProtocolError' if neither attribute is present, or if either
-- occurs more than once or cannot be parsed. Other failures are as for
-- 'stunRequest''.
findMappedAddress :: S.SockAddr
                  -> Net.PortNumber
                  -> [Integer]
                  -> IO (Either StunError (S.SockAddr, S.SockAddr))
findMappedAddress host localPort timeOuts = runErrorT $ do
    br <- liftIO bindRequest
    (msg, s) <- ErrorT $ stunRequest' host localPort timeOuts br
    -- 'finally' closes the socket exactly once on every exit path: success,
    -- a 'ProtocolError' from attribute extraction, or an exception.
    ErrorT $ runErrorT (resolve msg s) `finally` S.close s
  where
    -- Extract the mapped address from the reply and read the socket's own
    -- address; does not close the socket.
    resolve msg s = do
        xma <- case findAttribute $ messageAttributes msg of
            Right [x] -> return . Just
                $! fromXorMappedAddress (transactionID msg) x
            Right [] -> return Nothing
            _ -> throwError ProtocolError
        ma <- case findAttribute $ messageAttributes msg of
            Right [a] -> return . Just $! unMA a
            Right [] -> return Nothing
            _ -> throwError ProtocolError
        m <- maybe (throwError ProtocolError) return (xma <|> ma)
        local <- liftIO $ S.getSocketName s
        return (m, local)