{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE GADTs #-}
-- | SNMPv3 (USM) client.
--
-- The 'Client' built here discovers the agent's engine parameters on the
-- first request, encrypts outgoing requests when the security level is
-- 'AuthPriv', runs every request through 'signPacket', and decrypts answers
-- before handing their variable bindings back to the caller.
module Network.Snmp.Client.Version3
    ( clientV3
    ) where

import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toStrict, fromStrict)
import Network.Socket hiding (recv, socket, close)
import qualified Network.Socket as NS
import Network.Socket.ByteString.Lazy (recv, sendAll)
import Control.Applicative ((<$>))
import Control.Concurrent.Async
import Data.IORef (newIORef, IORef, readIORef, atomicWriteIORef)
import Control.Concurrent (threadDelay)
import Control.Monad (when)
import Data.ASN1.Types hiding (Context)
import Data.Monoid ((<>), mempty, mconcat)
import Data.Int
import Control.Exception
import System.Random (randomIO)
import Data.Binary

import Network.Protocol.Snmp
import Network.Snmp.Client.Types
import Network.Snmp.Client.Internal

-- | Template packet every outgoing SNMPv3 message is built from.
v3 :: Packet
v3 = initial Version3

-- | Per-client configuration and mutable state shared by all operations.
data ST = ST
    { authCache'     :: IORef (Maybe ByteString) -- ^ localized authentication key, filled on first use
    , privCache'     :: IORef (Maybe ByteString) -- ^ localized privacy key, filled on first use
    , engine'        :: IORef (Maybe (ByteString, Int32, Int32)) -- ^ (engine id, boots, time) from discovery
    , ref'           :: IORef Int32 -- ^ counter request ids are drawn from
    , salt32         :: IORef Int32 -- ^ salt counter used for DES
    , salt64         :: IORef Int64 -- ^ salt counter used for AES
    , securityLevel' :: PrivAuth
    , authType'      :: AuthType
    , privType'      :: PrivType
    , timeout'       :: Int -- ^ reply timeout in microseconds (handed to 'threadDelay')
    , socket'        :: Socket
    , login'         :: Login
    , authPass'      :: Password
    , privPass'      :: Password
    }

-- | Create an SNMPv3 client talking to @hostname@:@port@.
--
-- Nothing is sent here: the agent's engine id / boots / time are discovered
-- by the first request. @timeout@ is in microseconds. The socket opened
-- here is released by the returned client's 'close' action.
clientV3 :: Hostname -> Port -> Int -> Login -> Password -> Password -> PrivAuth -> AuthType -> PrivType -> IO Client
clientV3 hostname port timeout securityName authPass privPass securityLevel authType privType = do
    socket <- makeSocket hostname port
    uniqInteger <- uniqID
    ref <- newIORef uniqInteger
    -- random starting points for the DES / AES salt counters
    desSalt <- newIORef =<< abs <$> randomIO
    aesSalt <- newIORef =<< abs <$> randomIO
    authCache <- newIORef Nothing
    privCache <- newIORef Nothing
    engine <- newIORef Nothing
    let st = ST { authCache'     = authCache
                , privCache'     = privCache
                , engine'        = engine
                , ref'           = ref
                , salt32         = desSalt
                , salt64         = aesSalt
                , securityLevel' = securityLevel
                , authType'      = authType
                , privType'      = privType
                , timeout'       = timeout
                , socket'        = socket
                , login'         = securityName
                , authPass'      = authPass
                , privPass'      = privPass
                }
    return $ Client
        { get      = get' st
        , bulkget  = bulkget' st
        , getnext  = getnext' st
        , walk     = \oids -> mconcat <$> mapM (\oi -> walk' st oi oi mempty) oids
        , bulkwalk = \oids -> mconcat <$> mapM (\oi -> bulkwalk' st oi oi mempty) oids
        , set      = set' st
        , close    = NS.close socket
        }

-- | Engine discovery: send an unauthenticated (NoAuthNoPriv) probe, then
-- cache and return the (engine id, engine boots, engine time) found in the
-- answer.
--
-- Throws 'TimeoutException' when no answer arrives within 'timeout''.
init' :: ST -> IO (ByteString, Int32, Int32)
init' st = withSocketsDo $ do
    rid <- predCounter (ref' st)
    sendAll (socket' st) $ encode $ probe rid
    resp <- returnResult st
    let engine = (getEngineIdP resp, getEngineBootsP resp, getEngineTimeP resp)
    atomicWriteIORef (engine' st) (Just engine)
    return engine
  where
    probe x =
        ( setIDP (ID x)
        . setMaxSizeP (MaxSize 1500)
        . setReportableP False
        . setPrivAuthP NoAuthNoPriv
        . setRid x
        ) v3

-- | Send one request of kind @kind@ carrying @suite@ and return the
-- variable bindings of the answer. Shared by 'get'', 'bulkget'',
-- 'getnext'' and 'set''.
--
-- Throws 'TimeoutException' if no answer arrives in time and
-- 'ServerException' if the answer carries a non-zero error-status.
request' :: ST -> Suite -> (RequestId -> ErrorStatus -> ErrorIndex -> Request) -> IO Suite
request' st suite kind = withSocketsDo $ do
    full <- packet st suite kind
    sendPacket st full
    answer <- decryptPacketWithCache st =<< returnResult st
    checkError answer
    return $ getSuite answer

-- | GET the given OIDs.
get' :: ST -> OIDS -> IO Suite
get' st oids = request' st (toEmptySuite oids) GetRequest

-- | GETBULK starting at the given OIDs (see 'packet' for the repetition count).
bulkget' :: ST -> OIDS -> IO Suite
bulkget' st oids = request' st (toEmptySuite oids) GetBulk

-- | GETNEXT for the given OIDs.
getnext' :: ST -> OIDS -> IO Suite
getnext' st oids = request' st (toEmptySuite oids) GetNextRequest

-- | SET the given variable bindings.
set' :: ST -> Suite -> IO Suite
set' st suite = request' st suite SetRequest

-- | Walk the subtree rooted at @base@ using GET + GETNEXT.
--
-- On the first call (@oids == base@) @base@ itself is fetched with GET so a
-- scalar instance given as the root is part of the result; after that the
-- walk proceeds with 'walkNext''. Throws @ServerException 5@ when the agent
-- answers a single-OID request with anything but a single binding.
walk' :: ST -> OID -> OID -> Suite -> IO Suite
walk' st oids base accumulator
    | oids == base = do
        first <- get' st [oids]
        case first of
            Suite [Coupla _ NoSuchObject]   -> walkNext' st oids base accumulator
            Suite [Coupla _ NoSuchInstance] -> walkNext' st oids base accumulator
            Suite [Coupla _ EndOfMibView]   -> return accumulator
            Suite [_]                       -> walkNext' st oids base (accumulator <> first)
            _                               -> throwIO $ ServerException 5
    | otherwise = walkNext' st oids base accumulator

-- | GETNEXT loop of 'walk'': fetch the successor of @oid@ and keep going
-- while it lies under @base@.
--
-- Stops when the successor is outside @base@ ('isUpLevel'), is
-- 'EndOfMibView', or does not advance past @oid@. The last check matters
-- because at the end of the MIB the agent echoes the requested OID, and
-- without it the walk would loop forever.
walkNext' :: ST -> OID -> OID -> Suite -> IO Suite
walkNext' st oid base accumulator = do
    nextData <- getnext' st [oid]
    case nextData of
        Suite [Coupla next v]
            | next == oid || isUpLevel next base -> return accumulator
            | otherwise -> case v of
                EndOfMibView   -> return accumulator
                NoSuchObject   -> walkNext' st next base accumulator
                NoSuchInstance -> walkNext' st next base accumulator
                _              -> walkNext' st next base (accumulator <> nextData)
        _ -> throwIO $ ServerException 5

-- | Walk the subtree rooted at @base@ using GETBULK, continuing from the
-- last OID of each answer.
--
-- The final chunk is trimmed to bindings under @base@ and drops
-- 'EndOfMibView' placeholders, so neither leaks into the result. The walk
-- also stops if the last OID of an answer does not advance past @oids@.
bulkwalk' :: ST -> OID -> OID -> Suite -> IO Suite
bulkwalk' st oids base accumulator = do
    chunk <- bulkget' st [oids]
    let Coupla next snmpData = lastS chunk
    case snmpData of
        EndOfMibView -> return $ accumulator <> inside chunk
        _ | next == oids || isUpLevel next base -> return $ accumulator <> inside chunk
          | otherwise -> bulkwalk' st next base (accumulator <> chunk)
  where
    inside (Suite xs) = Suite (filter keep xs)
    keep (Coupla x v) = not (isUpLevel x base || isEndOfMib v)
    isEndOfMib EndOfMibView = True
    isEndOfMib _            = False

-- | Throw 'ServerException' carrying the packet's error-status when it is non-zero.
checkError :: Packet -> IO ()
checkError p = when (status /= 0) $ throwIO $ ServerException status
  where
    status = getErrorStatus p

-- | Request bindings for @oids@ with placeholder 'Zero' values.
toEmptySuite :: OIDS -> Suite
toEmptySuite = Suite . map (\x -> Coupla x Zero)

-- | Build the (not yet encrypted or signed) request packet for the cached
-- engine, running discovery via 'init'' first if nothing is cached.
--
-- For 'GetBulk' the third request field (presumably max-repetitions) is
-- forced to 10.
--
-- NOTE(review): boots/time are taken once from discovery and never
-- advanced, so requests made long after discovery may fall outside the
-- agent's USM time window (RFC 3414 section 3.2) — confirm.
packet :: ST -> Suite -> (RequestId -> ErrorStatus -> ErrorIndex -> Request) -> IO Packet
packet st suite r = do
    rid <- predCounter (ref' st)
    cached <- readIORef (engine' st)
    (eid, boots, time) <- maybe (init' st) return cached
    return $ ( setReportableP True
             . setPrivAuthP (securityLevel' st)
             . setUserNameP (login' st)
             . setEngineIdP eid
             . setEngineBootsP boots
             . setEngineTimeP time
             . setAuthenticationParametersP cleanPass
             . setIDP (ID rid)
             . setRequest (wrapBulk (r rid 0 0))
             . setSuite suite
             ) v3
  where
    wrapBulk (GetBulk rid' x _) = GetBulk rid' x 10
    wrapBulk x = x

-- | Encrypt (AuthPriv only), sign, encode and send a request packet.
sendPacket :: ST -> Packet -> IO ()
sendPacket st packet' =
    encryptPacketWithCache st packet' >>= signPacketWithCache st >>= sendAll (socket' st) . encode

-- | Wait up to 'timeout'' for one datagram (single recv of at most 1500
-- bytes) and decode it as a 'Packet'; throws 'TimeoutException' otherwise.
--
-- NOTE(review): the reply is not matched against the request id, so a late
-- answer to a timed-out request would be taken as the next request's answer.
returnResult :: ST -> IO Packet
returnResult st = do
    result <- race (threadDelay (timeout' st)) (decode <$> recv (socket' st) 1500 :: IO Packet)
    either (const $ throwIO TimeoutException) return result

-- | Return the key stored in @cache@, or localize @pass@ against the engine
-- id carried by @p@ (using the configured auth algorithm) and store it.
cachedKey' :: ST -> IORef (Maybe ByteString) -> Password -> Packet -> IO ByteString
cachedKey' st cache pass p = do
    cached <- readIORef cache
    case cached of
        Just key -> return key
        Nothing -> do
            let key = passwordToKey (authType' st) pass (getEngineIdP p)
            atomicWriteIORef cache (Just key)
            return key

-- | Sign a packet with the (cached) localized authentication key.
signPacketWithCache :: ST -> Packet -> IO Packet
signPacketWithCache st packet' = do
    key <- cachedKey' st (authCache' st) (authPass' st) packet'
    return $ signPacket (authType' st) key packet'

-- | Encrypt a packet's PDU with the (cached) privacy key when the security
-- level is 'AuthPriv'; otherwise return it unchanged.
encryptPacketWithCache :: ST -> Packet -> IO Packet
encryptPacketWithCache st packet'
    | securityLevel' st == AuthPriv = do
        key <- cachedKey' st (privCache' st) (privPass' st) packet'
        encryptPacket st key packet'
    | otherwise = return packet'

-- | Decrypt an answer with the (cached) privacy key when the security level
-- is 'AuthPriv'; otherwise return it unchanged (mirrors
-- 'encryptPacketWithCache').
decryptPacketWithCache :: ST -> Packet -> IO Packet
decryptPacketWithCache st packet'
    | securityLevel' st == AuthPriv = do
        key <- cachedKey' st (privCache' st) (privPass' st) packet'
        return $ decryptPacket st key packet'
    | otherwise = return packet'

-- | Replace the packet's PDU by its DES- or AES-encrypted form and store the
-- salt in the privacy parameters. The salt counters are advanced on every
-- call. Throws @ServerException 5@ for any other privacy type.
encryptPacket :: ST -> Key -> Packet -> IO Packet
encryptPacket st key packet'
    | privType' st == DES = do
        s <- succCounter (salt32 st)
        let (encrypted, salt) = desEncrypt key (getEngineBootsP packet') s plain
        return $ seal encrypted salt
    | privType' st == AES = do
        s <- succCounter (salt64 st)
        let (encrypted, salt) = aesEncrypt key (getEngineBootsP packet') (getEngineTimeP packet') s plain
        return $ seal encrypted salt
    | otherwise = throwIO $ ServerException 5
  where
    plain = toStrict $ encode (getPDU packet' :: PDU V3)
    seal encrypted salt = setPrivParametersP salt . setPDU (CryptedPDU encrypted) $ packet'

-- | Decrypt a 'CryptedPDU' with DES or AES; a packet whose PDU is not
-- encrypted is returned unchanged. For other privacy types the result is
-- @throw (ServerException 5)@, raised when the value is forced.
decryptPacket :: ST -> Key -> Packet -> Packet
decryptPacket st key packet'
    | privType' st == DES =
        let pdu = getPDU packet' :: PDU V3
            salt = getPrivParametersP packet'
        in case pdu of
            CryptedPDU x -> setPDU (decode (fromStrict $ desDecrypt key salt x) :: PDU V3) packet'
            _ -> packet'
    | privType' st == AES =
        let pdu = getPDU packet' :: PDU V3
            salt = getPrivParametersP packet'
            eib = getEngineBootsP packet'
            t = getEngineTimeP packet'
        in case pdu of
            CryptedPDU x -> setPDU (decode (fromStrict $ aesDecrypt key salt eib t x) :: PDU V3) packet'
            _ -> packet'
    | otherwise = throw $ ServerException 5