{-# language BangPatterns #-}
{-# language OverloadedStrings #-}
module Snmp.Client
( Session(..)
, Config(..)
, Destination(..)
, Credentials(..)
, CredentialsV2(..)
, CredentialsV3(..)
, Context(..)
, PerHostV3(..)
, SnmpException(..)
, openSession
, closeSession
, get
, get'
, getBulkStep
, getBulkStep'
, getBulkChildren
, getBulkChildren'
) where
import Control.Applicative ((<|>))
import Control.Concurrent (Chan,newChan,writeList2Chan,readChan,threadWaitReadSTM,writeChan)
import Control.Concurrent.STM.TVar (TVar,newTVarIO,registerDelay,readTVar,writeTVar)
import Control.Exception (Exception,bracket,bracketOnError,onException,throwIO)
import Control.Monad (replicateM,replicateM_,when,(<=<))
import Control.Monad.STM (STM,atomically,check)
import Data.Bits ((.|.))
import Data.ByteString (ByteString)
import Data.Functor (($>))
import Data.IORef (IORef,readIORef,writeIORef,newIORef)
import Data.Int (Int32)
import Data.Map (Map)
import Data.Maybe (mapMaybe)
import Data.Vector (Vector)
import Data.Word (Word16)
import Language.Asn.Types
import Snmp.Types
import Net.Types (IPv4)
import Text.Printf (printf)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as LB
import qualified Data.Map as Map
import qualified Data.Vector as Vector
import qualified Language.Asn.Decoding as AsnDecoding
import qualified Language.Asn.Encoding as AsnEncoding
import qualified Language.Asn.ObjectIdentifier as OID
import qualified Net.IPv4 as IPv4
import qualified Snmp.Decoding as SnmpDecoding
import qualified Snmp.Encoding as SnmpEncoding
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB
import qualified System.Posix.Types
-- | A client session: a pool of UDP sockets plus the state shared by every
-- request made through it. Created by 'openSession', released by
-- 'closeSession'.
data Session = Session
  { sessionSockets :: !(Chan NS.Socket)
    -- ^ Pool of bound sockets. A request reads a socket from this channel
    -- and writes it back when it is done with it.
  , sessionSocketCount :: !Int
    -- ^ Number of sockets in the pool; 'closeSession' closes this many.
  , sessionRequestId :: !(TVar RequestId)
    -- ^ Request-id counter, advanced by 'nextRequestId'.
  , sessionAesSalt :: !(TVar AesSalt)
    -- ^ Privacy-salt counter, advanced by 'nextSalt'.
  , sessionTimeoutMicroseconds :: !Int
    -- ^ How long to wait for a reply to one transmission, in microseconds.
  , sessionMaxTries :: !Int
    -- ^ Maximum number of transmissions per request.
  , sessionKeyCache :: !(IORef (Map (AuthType,ByteString,EngineId) ByteString))
    -- ^ Memoized results of 'SnmpEncoding.passwordToKey', keyed by
    -- (auth type, password, engine id).
  }
-- | Parameters for 'openSession'.
data Config = Config
  { configSocketPoolSize :: !Int
    -- ^ Number of UDP sockets to open. This bounds how many requests can be
    -- in flight at once; further requests wait for a socket to be returned.
  , configTimeoutMicroseconds :: !Int
    -- ^ How long to wait for a reply after each transmission, in
    -- microseconds. The wait restarts whenever a stale reply is discarded.
  , configRetries :: !Int
    -- ^ Maximum number of transmissions per request. Despite the name, this
    -- is the total number of attempts, not extra attempts after the first:
    -- 1 sends once, and 0 or less reports a timeout without sending.
  } deriving (Show,Eq)
-- | Where to send requests: the agent's IPv4 address and UDP port.
data Destination = Destination
  { destinationHost :: !IPv4
    -- ^ Address of the agent.
  , destinationPort :: !Word16
    -- ^ UDP port of the agent.
  } deriving (Show,Eq)
-- | Which SNMP version to speak, together with the credentials it needs.
data Credentials
  = CredentialsConstructV2 CredentialsV2
    -- ^ SNMPv2c with a community string.
  | CredentialsConstructV3 CredentialsV3
    -- ^ SNMPv3 with user-based security parameters.
  deriving (Show,Eq)
-- | SNMPv2c credentials.
newtype CredentialsV2 = CredentialsV2
  { credentialsV2CommunityString :: ByteString
    -- ^ Community string placed in every request message.
  } deriving (Show,Eq)
-- | SNMPv3 credentials.
data CredentialsV3 = CredentialsV3
  { credentialsV3Crypto :: !Crypto
    -- ^ Security level (no auth / auth / auth + privacy) with the
    -- authentication and privacy parameters it requires.
  , credentialsV3ContextName :: !ByteString
    -- ^ Context name placed in the scoped PDU.
  , credentialsV3User :: !ByteString
    -- ^ User name placed in the USM security parameters.
  } deriving (Show,Eq)
-- | Everything needed to issue a request: the session that provides sockets
-- and shared state, the agent to talk to, and the credentials to use.
data Context = Context
  { contextSession :: !Session
  , contextDestination :: !Destination
  , contextCredentials :: !Credentials
  }
-- | Per-agent SNMPv3 values needed to build a request: the authoritative
-- engine id, engine time and engine boots. 'generalRequest' starts from
-- placeholder values and replaces them with the ones carried in the agent's
-- Report PDU.
data PerHostV3 = PerHostV3
  { perHostV3AuthoritativeEngineId :: !EngineId
  , perHostV3ReceiverTime :: !Int32
  , perHostV3ReceiverBoots :: !Int32
  }
-- | Open a session. Creates a pool of 'configSocketPoolSize' UDP sockets,
-- each bound to 0.0.0.0 with no service given. It also creates fresh
-- request-id, salt and key-cache state.
--
-- If creating or binding any socket fails, every socket opened so far is
-- closed before the exception propagates, so a failed call does not leave
-- descriptors behind.
--
-- NOTE(review): a pool size of zero or less yields a session with no
-- sockets, on which every request blocks forever waiting for one.
openSession :: Config -> IO Session
openSession (Config socketPoolSize timeout retries) = do
  addrinfos <- NS.getAddrInfo
    (Just (NS.defaultHints {NS.addrFlags = [NS.AI_PASSIVE]}))
    (Just "0.0.0.0")
    Nothing
  -- Match explicitly instead of using the partial 'head'.
  serveraddr <- case addrinfos of
    addr : _ -> return addr
    [] -> ioError (userError "Snmp.Client.openSession: getAddrInfo returned no addresses")
  -- Create and bind one socket, closing it again if binding fails.
  let openOne = bracketOnError
        (NS.socket (NS.addrFamily serveraddr) NS.Datagram NS.defaultProtocol)
        NS.close
        (\sock -> NS.bind sock (NS.addrAddress serveraddr) >> return sock)
      -- Open @n@ more sockets. @opened@ holds the ones already open (newest
      -- first); they are closed if a later one cannot be opened.
      openAll :: Int -> [NS.Socket] -> IO [NS.Socket]
      openAll !n opened
        | n <= 0 = return (reverse opened)
        | otherwise = do
            sock <- openOne `onException` mapM_ NS.close opened
            openAll (n - 1) (sock : opened)
  allSockets <- openAll socketPoolSize []
  requestIdVar <- newTVarIO (RequestId 1)
  aesSaltVar <- newTVarIO (AesSalt 1)
  socketChan <- newChan
  writeList2Chan socketChan allSockets
  keyCache <- newIORef Map.empty
  return (Session socketChan socketPoolSize requestIdVar aesSaltVar timeout retries keyCache)
-- | Close a session. Takes each of its 'sessionSocketCount' pooled sockets
-- out of the pool and closes it. A socket currently borrowed by an in-flight
-- request is waited for, so this blocks until that request hands it back.
closeSession :: Session -> IO ()
closeSession session =
  replicateM_ (sessionSocketCount session) (readChan (sessionSockets session) >>= NS.close)
-- | Send one request to the context's destination and wait for its response.
--
-- @pdusFromRequestId@ builds the request PDUs for a given request id.
-- @fromPdu@ interprets the response PDU once it has been matched to the
-- request.
--
-- A socket is borrowed from the session's pool for the whole exchange. It is
-- handed back when the exchange finishes, including when it throws. The
-- SNMPv3 path throws (rather than returns) 'SnmpExceptionDecoding',
-- 'SnmpExceptionDecryptionFailure' and 'SnmpExceptionAuthenticationFailure'.
--
-- Each transmission waits up to 'sessionTimeoutMicroseconds' for a reply,
-- and at most 'sessionMaxTries' transmissions are made. A reply carrying an
-- older request id than the one awaited is discarded and the wait starts
-- over. A newer one yields 'SnmpExceptionMissedResponse'.
--
-- For SNMPv3 the first message uses a placeholder engine id, time and boots.
-- When the agent answers with a Report PDU, the request is rebuilt from the
-- engine id, time and boots carried in that report and sent again. A further
-- Report yields 'SnmpExceptionBadEngineId'.
generalRequest ::
  (RequestId -> Pdus)
  -> (Pdu -> Either SnmpException a)
  -> Context
  -> IO (Either SnmpException a)
generalRequest pdusFromRequestId fromPdu (Context session (Destination ip port) creds) =
  -- 'bracket' returns the borrowed socket to the pool even if the exchange
  -- throws. Otherwise every failure would shrink the pool until later
  -- requests (and 'closeSession') block forever in 'readChan'.
  bracket (readChan (sessionSockets session)) (writeChan (sessionSockets session)) $ \sock -> case creds of
    CredentialsConstructV2 (CredentialsV2 commStr) -> do
      requestId <- nextRequestId (sessionRequestId session)
      let !bs = id
            $ LB.toStrict
            $ AsnEncoding.der SnmpEncoding.messageV2
            $ MessageV2 commStr
            $ pdusFromRequestId requestId
          !bsLen = ByteString.length bs
          -- Transmission loop; @n1@ is the number of transmissions still
          -- allowed. Every transmission resends the same bytes, so replies are
          -- always matched against the same @requestId@.
          go1 :: Int -> IO (Either SnmpException Pdu)
          go1 !n1 = if n1 > 0
            then do
              when inDebugMode $ putStrLn "Sending:"
              when inDebugMode $ putStrLn (hexByteStringInternal bs)
              bytesSentLen <- NSB.sendTo sock bs (NS.SockAddrInet (fromIntegral port) (NS.tupleToHostAddress (IPv4.toOctets ip)))
              if bytesSentLen /= bsLen
                then return $ Left $ SnmpExceptionNotAllBytesSent bytesSentLen bsLen
                else do
                  -- Wait until the socket is readable or the timeout fires.
                  -- A stale reply re-enters this wait with a fresh timeout.
                  let go2 = do
                        (isReadyAction,deregister) <- threadWaitReadSTM =<< mySockFd sock
                        delay <- registerDelay (sessionTimeoutMicroseconds session)
                        isContentReady <- atomically $ (isReadyAction $> True) <|> (fini delay $> False)
                        deregister
                        if not isContentReady
                          then go1 (n1 - 1)
                          else do
                            bsRecv <- NSB.recv sock 10000
                            when inDebugMode $ putStrLn "Received:"
                            when inDebugMode $ print bsRecv
                            if ByteString.null bsRecv
                              then return (Left SnmpExceptionSocketClosed)
                              else case AsnDecoding.ber SnmpDecoding.messageV2 bsRecv of
                                Left err -> return (Left $ SnmpExceptionDecoding err)
                                Right msg -> case messageV2Data msg of
                                  PdusResponse pdu@(Pdu respRequestId _ _ _) ->
                                    case compare requestId respRequestId of
                                      GT -> go2
                                      EQ -> return (Right pdu)
                                      LT -> return $ Left $ SnmpExceptionMissedResponse requestId respRequestId
                                  _ -> return (Left (SnmpExceptionNonPduResponseV2 msg))
                  go2
            else return $ Left SnmpExceptionTimeout
      e <- go1 (sessionMaxTries session)
      return (e >>= fromPdu)
    CredentialsConstructV3 (CredentialsV3 crypto contextName user) -> do
      -- Look up a localized key in the session-wide cache, deriving and
      -- storing it on a miss.
      let passwordToKeyCached typ password eng = do
            keyCache <- readIORef (sessionKeyCache session)
            let triple = (typ,password,eng)
            case Map.lookup triple keyCache of
              Nothing -> do
                let key = SnmpEncoding.passwordToKey typ password eng
                writeIORef (sessionKeyCache session) (Map.insert triple key keyCache)
                return key
              Just key -> return key
          -- NOTE(review): 0x04 is presumably the reportable bit of the
          -- msgFlags octet (RFC 3412); confirm.
          flags = cryptoFlags crypto .|. 0x04
          -- Authentication parameters. With authentication, this is the
          -- signature over the message serialized with 12 zero bytes in the
          -- auth-parameters field; without it, this is empty.
          mkAuthParams :: RequestId -> PerHostV3 -> (ByteString,ScopedPduData) -> IO ByteString
          mkAuthParams reqId phv3 privPair = case cryptoAuth crypto of
            Nothing -> return ByteString.empty
            Just (AuthParameters typ password) -> do
              key <- passwordToKeyCached typ password (perHostV3AuthoritativeEngineId phv3)
              let serializationWithoutAuth = snd (makeBs (ByteString.replicate 12 0x00) reqId privPair phv3)
              pure (SnmpEncoding.mkSign typ key serializationWithoutAuth)
          -- Privacy parameters plus scoped PDU. Under AuthPriv the PDU is
          -- encrypted and the salt bytes actually used are returned;
          -- otherwise the PDU is sent in plaintext with empty parameters.
          mkPrivParams :: AesSalt -> RequestId -> PerHostV3 -> IO (ByteString,ScopedPduData)
          mkPrivParams theSalt reqId phv3 = case crypto of
              AuthPriv (AuthParameters authType _) (PrivParameters privType privPass) -> do
                -- The privacy key is derived from the privacy password using
                -- the authentication type.
                key <- passwordToKeyCached authType privPass (perHostV3AuthoritativeEngineId phv3)
                case privType of
                  PrivTypeAes ->
                    let (encrypted,actualSaltBs) = SnmpEncoding.aesEncrypt
                          key
                          (perHostV3ReceiverBoots phv3)
                          (perHostV3ReceiverTime phv3)
                          theSalt
                          (LB.toStrict (AsnEncoding.der SnmpEncoding.scopedPdu spdu))
                    in pure (actualSaltBs,ScopedPduDataEncrypted encrypted)
                  PrivTypeDes ->
                    let (encrypted,actualSaltBs) = SnmpEncoding.desEncrypt
                          key
                          (perHostV3ReceiverBoots phv3)
                          (fromIntegral (getAesSalt theSalt))
                          (LB.toStrict (AsnEncoding.der SnmpEncoding.scopedPdu spdu))
                    in pure (actualSaltBs,ScopedPduDataEncrypted encrypted)
              _ -> pure (ByteString.empty,ScopedPduDataPlaintext spdu)
            where spdu = ScopedPdu (perHostV3AuthoritativeEngineId phv3) contextName (pdusFromRequestId reqId)
          -- Assemble a message from already-computed parameters, returning
          -- it together with its DER encoding.
          makeBs :: ByteString -> RequestId -> (ByteString,ScopedPduData) -> PerHostV3 -> (MessageV3,ByteString)
          makeBs activeAuthParams reqId (activePrivParams,spdud) (PerHostV3 authoritativeEngineId receiverTime boots) =
            let myMsg = MessageV3
                  (HeaderData reqId 1500 flags)
                  (Usm authoritativeEngineId boots receiverTime user activeAuthParams activePrivParams)
                  spdud
             in (myMsg, LB.toStrict $ AsnEncoding.der SnmpEncoding.messageV3 $ myMsg)
          -- Build a complete message: encrypt (if needed), then sign.
          fullMakeBs :: AesSalt -> RequestId -> PerHostV3 -> IO (MessageV3, ByteString)
          fullMakeBs theSalt reqId phv3 = do
            privPair <- mkPrivParams theSalt reqId phv3
            authParams <- mkAuthParams reqId phv3 privPair
            return (makeBs authParams reqId privPair phv3)
          -- Transmission loop. Invariant: @requestId@ is the id embedded in
          -- @bsSent@. @engineIdsAcquired@ records whether the message was
          -- built from values learned from a Report PDU.
          go1 :: Int -> RequestId -> (MessageV3,ByteString) -> Bool -> IO (Either SnmpException Pdu)
          go1 !n1 !requestId (!sentMsg,!bsSent) !engineIdsAcquired = if n1 > 0
            then do
              when inDebugMode $ putStrLn "Sending:"
              when inDebugMode $ putStrLn (hexByteStringInternal bsSent)
              let bsLen = ByteString.length bsSent
              bytesSentLen <- NSB.sendTo sock bsSent (NS.SockAddrInet (fromIntegral port) (NS.tupleToHostAddress (IPv4.toOctets ip)))
              if bytesSentLen /= bsLen
                then return $ Left $ SnmpExceptionNotAllBytesSent bytesSentLen bsLen
                else do
                  let go2 :: IO (Either SnmpException Pdu)
                      go2 = do
                        (isReadyAction,deregister) <- threadWaitReadSTM =<< mySockFd sock
                        delay <- registerDelay (sessionTimeoutMicroseconds session)
                        isContentReady <- atomically $ (isReadyAction $> True) <|> (fini delay $> False)
                        deregister
                        if not isContentReady
                          then do
                            when inDebugMode $ putStrLn "NO RESPONSE"
                            -- Retransmit the identical bytes, so the reply
                            -- still carries @requestId@. Matching against a
                            -- freshly allocated id instead would classify that
                            -- reply as stale (GT below), and no retry could
                            -- ever succeed.
                            go1 (n1 - 1) requestId (sentMsg,bsSent) engineIdsAcquired
                          else do
                            bsRecv <- NSB.recv sock 10000
                            when inDebugMode $ putStrLn "Received:"
                            when inDebugMode $ putStrLn (hexByteStringInternal bsRecv)
                            if ByteString.null bsRecv
                              then return (Left SnmpExceptionSocketClosed)
                              else case AsnDecoding.ber SnmpDecoding.messageV3 bsRecv of
                                Left err -> return (Left $ SnmpExceptionDecoding err)
                                Right msg -> do
                                  case cryptoAuth crypto of
                                    Nothing -> return ()
                                    Just (AuthParameters typ password) -> do
                                      when inDebugMode $ putStrLn "THE RECEIVED MESSAGE"
                                      when inDebugMode $ print msg
                                      let reencoded = LB.toStrict $ AsnEncoding.der SnmpEncoding.messageV3 msg
                                      when inDebugMode $ putStrLn $ hexByteStringInternal $ reencoded
                                      -- NOTE(review): the signature is only
                                      -- checked when re-encoding does NOT
                                      -- reproduce the received bytes, and a
                                      -- mismatch only throws when @actual@ is
                                      -- non-empty. Confirm this is intended.
                                      when (reencoded /= bsRecv) $ do
                                        when inDebugMode $ putStrLn "NOT THE SAME"
                                        key <- passwordToKeyCached typ password (usmAuthoritativeEngineId (messageV3SecurityParameters msg))
                                        case SnmpEncoding.checkSign typ key msg of
                                          Nothing -> return ()
                                          Just (expected,actual) -> do
                                            when (not $ ByteString.null actual) $ do
                                              throwIO $ SnmpExceptionAuthenticationFailure expected actual
                                  let handleSpdu :: ScopedPdu -> IO (Either SnmpException Pdu)
                                      handleSpdu spdu = case scopedPduData spdu of
                                        PdusResponse pdu@(Pdu respRequestId _ _ _) ->
                                          case compare requestId respRequestId of
                                            GT -> go2
                                            EQ -> return (Right pdu)
                                            LT -> return $ Left $ SnmpExceptionMissedResponse requestId respRequestId
                                        -- A Report carries the agent's engine
                                        -- id, time and boots; rebuild the
                                        -- request with them, but only once.
                                        PdusReport (Pdu respRequestId _ _ _) -> do
                                          when inDebugMode $ putStrLn $ "Expected Request ID: " ++ show requestId
                                          when inDebugMode $ putStrLn $ "Received Request ID: " ++ show respRequestId
                                          if engineIdsAcquired
                                            then return $ Left (SnmpExceptionBadEngineId sentMsg msg)
                                            else do
                                              let usm = messageV3SecurityParameters msg
                                                  phv3 = PerHostV3
                                                    (usmAuthoritativeEngineId usm)
                                                    (usmAuthoritativeEngineTime usm)
                                                    (usmAuthoritativeEngineBoots usm)
                                              theSalt <- atomically $ nextSalt (sessionAesSalt session)
                                              requestId' <- nextRequestId (sessionRequestId session)
                                              internalFragment <- fullMakeBs theSalt requestId' phv3
                                              go1 n1 requestId' internalFragment True
                                        _ -> return (Left (SnmpExceptionNonPduResponseV3 msg))
                                  case messageV3Data msg of
                                    -- NOTE(review): these 'error' calls are
                                    -- reachable whenever an agent sends an
                                    -- encrypted scoped PDU for a request made
                                    -- without privacy.
                                    ScopedPduDataEncrypted encrypted -> case crypto of
                                      NoAuthNoPriv -> error "internal library error: messageV3Data NoAuthPriv"
                                      AuthNoPriv _ -> error "internal library error: messageV3Data (AuthNoPriv _)"
                                      AuthPriv (AuthParameters authType _) (PrivParameters privType privPass) -> do
                                        let usm = messageV3SecurityParameters msg
                                        key <- passwordToKeyCached authType privPass (usmAuthoritativeEngineId usm)
                                        let mdecrypted = case privType of
                                              PrivTypeDes -> SnmpEncoding.desDecrypt key (usmPrivacyParameters usm) encrypted
                                              PrivTypeAes -> SnmpEncoding.aesDecrypt key (usmPrivacyParameters usm) (usmAuthoritativeEngineBoots usm) (usmAuthoritativeEngineTime usm) encrypted
                                        case mdecrypted of
                                          Just bs -> case AsnDecoding.ber SnmpDecoding.scopedPdu bs of
                                            Left err -> throwIO (SnmpExceptionDecoding err)
                                            Right spdu -> handleSpdu spdu
                                          Nothing -> throwIO SnmpExceptionDecryptionFailure
                                    ScopedPduDataPlaintext spdu -> handleSpdu spdu
                  go2
            else return $ Left $ SnmpExceptionTimeoutV3 sentMsg
      -- The first message uses placeholder engine values; a Report from the
      -- agent supplies the real ones (see handleSpdu).
      let originalPhv3 = PerHostV3 (EngineId "initial-engine-id") 0xFFFFFF 0xEEEEEE
      theSalt <- atomically $ nextSalt (sessionAesSalt session)
      requestId' <- nextRequestId (sessionRequestId session)
      theFragment <- fullMakeBs theSalt requestId' originalPhv3
      e <- go1 (sessionMaxTries session) requestId' theFragment False
      return (e >>= fromPdu)
-- | Advance the salt counter by one inside the current transaction and
-- return the new value.
nextSalt :: TVar AesSalt -> STM AesSalt
nextSalt saltVar = readTVar saltVar >>= \(AesSalt current) ->
  let bumped = AesSalt (current + 1)
  in writeTVar saltVar bumped $> bumped
-- | Run an action and rethrow a 'Left' result as an exception, returning the
-- 'Right' value otherwise.
throwSnmpException :: IO (Either SnmpException a) -> IO a
throwSnmpException action = action >>= either throwIO pure
-- | Throwing variant of 'get'': fetch the value bound to one identifier,
-- throwing any 'SnmpException' instead of returning it.
get :: Context -> ObjectIdentifier -> IO ObjectSyntax
get ctx = throwSnmpException . get' ctx
-- | Throwing variant of 'getBulkStep'': one GetBulk request, throwing any
-- 'SnmpException' instead of returning it.
getBulkStep :: Context -> Int -> ObjectIdentifier -> IO (Vector (ObjectIdentifier,ObjectSyntax))
getBulkStep ctx maxRep = throwSnmpException . getBulkStep' ctx maxRep
-- | Throwing variant of 'getBulkChildren'': walk a subtree, throwing any
-- 'SnmpException' instead of returning it.
getBulkChildren :: Context -> Int -> ObjectIdentifier -> IO (Vector (ObjectIdentifier,ObjectSyntax))
getBulkChildren ctx maxRep = throwSnmpException . getBulkChildren' ctx maxRep
-- | Fetch the value bound to one object identifier with a GetRequest.
-- Failures are returned as 'Left'. The response's error status must be zero,
-- and it must contain exactly one binding, for @ident@, that carries a value.
get' :: Context -> ObjectIdentifier -> IO (Either SnmpException ObjectSyntax)
get' ctx ident = generalRequest request decode ctx
  where
    request reqId = PdusGetRequest
      (Pdu reqId (ErrorStatus 0) (ErrorIndex 0) (Vector.singleton (VarBind ident BindingResultUnspecified)))
    decode = singleBindingValue ident <=< onlyBindings
-- | Issue one GetBulk request for @ident@ with bulk parameters 0 and @maxRep@
-- (presumably non-repeaters and max-repetitions). Returns, in response order,
-- the (identifier, value) pairs whose bindings carry a value. A non-zero
-- error status is returned as 'Left'.
getBulkStep' :: Context -> Int -> ObjectIdentifier -> IO (Either SnmpException (Vector (ObjectIdentifier,ObjectSyntax)))
getBulkStep' ctx maxRep ident = generalRequest request decode ctx
  where
    request reqId = PdusGetBulkRequest
      (BulkPdu reqId 0 (fromIntegral maxRep) (Vector.singleton (VarBind ident BindingResultUnspecified)))
    decode pdu = multipleBindings <$> onlyBindings pdu
-- | Walk the subtree rooted at @oid1@ with repeated GetBulk requests of
-- @maxRep@ repetitions each. Returns, in the order received, every
-- (identifier, value) pair whose identifier has @oid1@ as a prefix.
--
-- Each step continues from the last in-subtree identifier of the previous
-- step. The walk stops when:
--
-- * a step yields no in-subtree values;
-- * a step's response already contains values outside the subtree;
-- * a step makes no progress (its last in-subtree identifier is the one it
--   started from). That step's pairs are discarded; without this check the
--   walk would loop forever.
--
-- The first error from any step is returned.
getBulkChildren' :: Context -> Int -> ObjectIdentifier -> IO (Either SnmpException (Vector (ObjectIdentifier,ObjectSyntax)))
getBulkChildren' ctx maxRep oid1 = go [] oid1 where
  -- @chunks@ holds the in-subtree results of earlier steps, newest first.
  -- They are concatenated once at the end rather than re-copying a growing
  -- vector on every step.
  finish chunks = return (Right (Vector.concat (reverse chunks)))
  go chunks ident = do
    epairsUnfiltered <- getBulkStep' ctx maxRep ident
    case epairsUnfiltered of
      Left e -> return (Left e)
      Right pairsUnfiltered -> do
        let pairs = Vector.filter (\(oid,_) -> oidIsPrefixOf oid1 oid) pairsUnfiltered
        if Vector.null pairs
          then finish chunks
          else do
            let lastOid = fst (Vector.last pairs)
                -- Some values were filtered out, i.e. the response already
                -- walked past the end of the subtree.
                leftSubtree = Vector.length pairs < Vector.length pairsUnfiltered
            if lastOid == ident
              then finish chunks
              else if leftSubtree
                then finish (pairs : chunks)
                else go (pairs : chunks) lastOid
-- | Subtree membership test. Delegates to 'OID.isPrefixOf'; callers pass the
-- subtree root first and the candidate identifier second.
oidIsPrefixOf :: ObjectIdentifier -> ObjectIdentifier -> Bool
oidIsPrefixOf prefix oid = OID.isPrefixOf prefix oid
-- | Keep, in order, the bindings whose result is a value, as
-- (identifier, value) pairs. Every other binding result (for example
-- noSuchObject or endOfMibView) is dropped.
multipleBindings :: Vector VarBind -> Vector (ObjectIdentifier,ObjectSyntax)
multipleBindings = Vector.mapMaybe valueOf
  where
    valueOf (VarBind name (BindingResultValue obj)) = Just (name,obj)
    valueOf _ = Nothing
-- | Extract the value from a response that should hold exactly one binding,
-- for @oid@. Checks, in order: the binding count, the bound identifier, and
-- finally the binding result, mapping each non-value result to its
-- 'SnmpException'.
singleBindingValue :: ObjectIdentifier -> Vector VarBind -> Either SnmpException ObjectSyntax
singleBindingValue oid v
  | Vector.length v /= 1 = Left (SnmpExceptionMultipleBindings (Vector.length v))
  | name /= oid = Left (SnmpExceptionMismatchedBinding oid name)
  | otherwise = case res of
      BindingResultValue obj -> Right obj
      BindingResultUnspecified -> Left SnmpExceptionUnspecified
      BindingResultNoSuchObject -> Left (SnmpExceptionNoSuchObject oid)
      BindingResultNoSuchInstance -> Left (SnmpExceptionNoSuchInstance oid)
      BindingResultEndOfMibView -> Left SnmpExceptionEndOfMibView
  where
    -- Lazy pattern binding: only forced after the length guard has passed.
    VarBind name res = Vector.head v
-- | The PDU's bindings when its error status is zero; otherwise the status
-- and index as 'SnmpExceptionPduError'.
onlyBindings :: Pdu -> Either SnmpException (Vector VarBind)
onlyBindings (Pdu _ errStatus@(ErrorStatus e) errIndex bindings)
  | e == 0 = Right bindings
  | otherwise = Left (SnmpExceptionPduError errStatus errIndex)
-- | Failures reported by this module. The primed functions ('get'' and
-- friends) return these as 'Left'; the unprimed ones throw them. Some
-- SNMPv3 failures are thrown even by the primed functions.
data SnmpException
  = SnmpExceptionNotAllBytesSent !Int !Int
    -- ^ A datagram was only partly sent: bytes sent, then bytes intended.
  | SnmpExceptionTimeout
    -- ^ SNMPv2c: no matching reply after every allowed transmission.
  | SnmpExceptionTimeoutV3 !MessageV3
    -- ^ SNMPv3: no matching reply after every allowed transmission. Carries
    -- the last message sent.
  | SnmpExceptionPduError !ErrorStatus !ErrorIndex
    -- ^ The response PDU had a non-zero error status.
  | SnmpExceptionMultipleBindings !Int
    -- ^ Exactly one binding was expected; carries the number received.
  | SnmpExceptionMismatchedBinding !ObjectIdentifier !ObjectIdentifier
    -- ^ The binding was for another identifier: requested, then received.
  | SnmpExceptionUnspecified
    -- ^ The binding's result was unspecified.
  | SnmpExceptionNoSuchObject !ObjectIdentifier
    -- ^ The agent reported noSuchObject for the requested identifier.
  | SnmpExceptionNoSuchInstance !ObjectIdentifier
    -- ^ The agent reported noSuchInstance for the requested identifier.
  | SnmpExceptionEndOfMibView
    -- ^ The agent reported endOfMibView.
  | SnmpExceptionMissedResponse !RequestId !RequestId
    -- ^ A reply arrived for a newer request than the one awaited: the
    -- awaited id, then the received id.
  | SnmpExceptionNonPduResponseV2 !MessageV2
    -- ^ An SNMPv2c reply did not carry a Response PDU.
  | SnmpExceptionNonPduResponseV3 !MessageV3
    -- ^ An SNMPv3 reply carried neither a Response nor a Report PDU.
  | SnmpExceptionDecoding !String
    -- ^ A received message, or a decrypted scoped PDU, failed to decode.
  | SnmpExceptionSocketClosed
    -- ^ Receiving from the socket returned zero bytes.
  | SnmpExceptionAuthenticationFailure !ByteString !ByteString
    -- ^ The SNMPv3 reply's signature check failed: expected, then actual.
  | SnmpExceptionBadEngineId !MessageV3 !MessageV3
    -- ^ The agent sent another Report after the engine values had already
    -- been learned: the message sent, then the Report received.
  | SnmpExceptionDecryptionFailure
    -- ^ An encrypted SNMPv3 scoped PDU could not be decrypted.
  deriving (Show,Eq)
instance Exception SnmpException
-- | Retry the surrounding transaction until the flag becomes 'True'. Paired
-- with 'registerDelay' to wait for a timeout.
fini :: TVar Bool -> STM ()
fini flag = readTVar flag >>= check
-- | Atomically advance the request-id counter and return the new id. The
-- counter wraps modulo 100000000, and 0 is skipped in favour of 1.
nextRequestId :: TVar RequestId -> IO RequestId
nextRequestId requestIdVar = atomically $ do
  RequestId previous <- readTVar requestIdVar
  let !wrapped = mod (previous + 1) 100000000
      !next = if wrapped == 0 then 1 else wrapped
      issued = RequestId next
  writeTVar requestIdVar issued
  pure issued
-- | The socket's file descriptor wrapped as an 'System.Posix.Types.Fd', the
-- form 'threadWaitReadSTM' takes.
mySockFd :: NS.Socket -> IO System.Posix.Types.Fd
mySockFd sock = System.Posix.Types.Fd <$> NS.fdSocket sock
-- | Render bytes as upper-case hexadecimal, two digits per byte, with no
-- separators (debug output only).
hexByteStringInternal :: ByteString -> String
hexByteStringInternal bytes = concatMap (printf "%02X") (ByteString.unpack bytes)
-- | Compile-time switch for the diagnostic output printed by
-- 'generalRequest' (bytes sent and received, decoded messages, request ids).
-- Set it to 'True' and rebuild to enable that output.
inDebugMode :: Bool
inDebugMode = False