{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE GADTs #-}
module Network.Snmp.Client.Version3 
( clientV3
)
where

import Data.ByteString (ByteString)
import Network.Socket hiding (recv, socket, close)
import qualified Network.Socket as NS
import Network.Socket.ByteString (recv, sendAll)
import Control.Applicative ((<$>))
import Control.Concurrent.Async
import Data.IORef (newIORef, IORef, readIORef, atomicWriteIORef)
import Control.Concurrent (threadDelay)
import Control.Monad (when)
import Data.ASN1.Types hiding (Context) 
import Data.Monoid ((<>), mempty, mconcat)
import Data.Int
import Control.Exception 
import System.Random (randomIO)

import Network.Protocol.Snmp
import Network.Snmp.Client.Types
import Network.Snmp.Client.Internal 

-- | Template SNMPv3 message (@initial Version3@). Every outgoing packet in
-- this module, both the discovery probe and regular requests, is built by
-- applying setters to this value.
v3 :: Packet 
v3 = initial Version3 

-- | Per-client state captured by every operation closure that 'clientV3'
-- returns: mutable caches and counters, plus the fixed connection and
-- security configuration.
data ST = ST
  { authCache' :: IORef (Maybe ByteString)
    -- ^ Cached authentication key, derived from 'authPass'' and the engine ID
    -- on first signing; 'Nothing' until then.
  , privCache' :: IORef (Maybe ByteString)
    -- ^ Cached privacy key, derived from 'privPass'' and the engine ID on
    -- first encryption/decryption; 'Nothing' until then.
  , engine' :: IORef (Maybe (ByteString, Int32, Int32))
    -- ^ Engine ID, engine boots and engine time learned by discovery
    -- ('init''); 'Nothing' until the first request triggers discovery.
  , ref' :: IORef Int32
    -- ^ Request-ID counter, advanced with 'predCounter' for every message.
  , salt32 :: IORef Int32
    -- ^ Salt counter used for DES encryption.
  , salt64 :: IORef Int64
    -- ^ Salt counter used for AES encryption.
  , securityLevel' :: PrivAuth
    -- ^ Security level stamped on requests; 'AuthPriv' turns on encryption.
  , authType' :: AuthType
    -- ^ Authentication protocol, used for signing and for key derivation.
  , privType' :: PrivType
    -- ^ Privacy cipher (DES or AES are handled; anything else throws).
  , timeout' :: Int
    -- ^ Reply timeout passed to 'threadDelay', i.e. microseconds.
  , socket' :: Socket
    -- ^ Socket opened by 'makeSocket' in 'clientV3'.
  , login' :: Login
    -- ^ Security (user) name placed in every request.
  , authPass' :: Password
    -- ^ Authentication password.
  , privPass' :: Password
    -- ^ Privacy password.
  } 

-- | Build an SNMPv3 'Client' that talks to @hostname:port@.
--
-- The socket is opened immediately. Engine discovery is deferred until the
-- first request, because the engine cache starts out as 'Nothing'.
-- @timeout@ is the per-reply wait handed to 'threadDelay' (microseconds).
-- The DES and AES salt counters start from @abs@ of a random value.
-- 'walk' and 'bulkwalk' treat each requested OID as its own subtree and
-- concatenate the results.
clientV3 :: Hostname -> Port -> Int -> Login -> Password -> Password -> PrivAuth -> AuthType -> PrivType -> IO Client
clientV3 hostname port timeout securityName authPass privPass securityLevel authType privType = do
    sock <- makeSocket hostname port
    requestIds <- newIORef =<< uniqID
    desSalt <- newIORef =<< abs <$> randomIO
    aesSalt <- newIORef =<< abs <$> randomIO
    authKeyCache <- newIORef Nothing
    privKeyCache <- newIORef Nothing
    engineCache <- newIORef Nothing
    let st = ST
          { authCache' = authKeyCache
          , privCache' = privKeyCache
          , engine' = engineCache
          , ref' = requestIds
          , salt32 = desSalt
          , salt64 = aesSalt
          , securityLevel' = securityLevel
          , authType' = authType
          , privType' = privType
          , timeout' = timeout
          , socket' = sock
          , login' = securityName
          , authPass' = authPass
          , privPass' = privPass
          }
        walkAll step = fmap mconcat . mapM (\root -> step st root root mempty)
    return Client
      { get = get' st
      , bulkget = bulkget' st
      , getnext = getnext' st
      , walk = walkAll walk'
      , bulkwalk = walkAll bulkwalk'
      , set = set' st
      , close = NS.close sock
      }


-- | Engine discovery: send a probe and cache the agent's engine parameters.
--
-- The probe is a 'NoAuthNoPriv' message with the reportable flag cleared
-- and a 1500-byte max-size. The engine ID, boots and time are read from
-- the reply, stored in 'engine'' and returned.
--
-- The reply is awaited through 'returnResult', so 'TimeoutException' is
-- thrown if the agent does not answer within 'timeout''.
init' :: ST -> IO (ByteString, Int32, Int32)
init' st = withSocketsDo $ do
    rid <- predCounter (ref' st)
    sendAll (socket' st) $ encode $ packet' rid
    -- A bare recv here would block forever on an unreachable agent; the
    -- shared timed receive turns that into a TimeoutException instead.
    resp <- returnResult st
    let engine = (getEngineIdP resp, getEngineBootsP resp, getEngineTimeP resp)
    atomicWriteIORef (engine' st) (Just engine)
    return engine
    where
      -- Discovery probe carrying request id @x@.
      packet' x = ( setIDP (ID x)
                  . setMaxSizeP (MaxSize 1500)
                  . setReportableP False
                  . setPrivAuthP NoAuthNoPriv
                  . setRid x
                  ) v3


-- | One full request/response cycle shared by the simple operations below.
--
-- Steps:
--
-- * build a packet from @suite@ and the PDU constructor @mkRequest@;
-- * send it (encrypted and signed as configured) and wait for the reply;
-- * decrypt the reply and throw 'ServerException' if its error-status is
--   non-zero;
-- * otherwise return the reply's variable bindings.
exchange' :: ST -> Suite -> (RequestId -> ErrorStatus -> ErrorIndex -> Request) -> IO Suite
exchange' st suite mkRequest = withSocketsDo $ do
    outgoing <- packet st suite mkRequest
    sendPacket st outgoing
    reply <- returnResult st >>= decryptPacketWithCache st
    checkError reply
    return (getSuite reply)

-- | GetRequest for the given OIDs.
get' :: ST -> OIDS -> IO Suite
get' st oids = exchange' st (toEmptySuite oids) GetRequest

-- | GetBulk for the given OIDs; 'packet' fixes max-repetitions at 10.
bulkget' :: ST -> OIDS -> IO Suite
bulkget' st oids = exchange' st (toEmptySuite oids) GetBulk

-- | GetNextRequest for the given OIDs.
getnext' :: ST -> OIDS -> IO Suite
getnext' st oids = exchange' st (toEmptySuite oids) GetNextRequest

-- | SetRequest carrying the given bindings.
set' :: ST -> Suite -> IO Suite
set' st suite = exchange' st suite SetRequest

-- | Walk the subtree rooted at @base@ with GetRequest/GetNextRequest.
--
-- @oids@ is the current cursor. On the first call it equals @base@: the
-- base itself is fetched with a GetRequest, so a scalar base contributes
-- its own value, and the walk continues from its GetNext successor.
-- Later steps append each in-subtree binding to @accumulator@. They skip
-- NoSuchObject/NoSuchInstance values and stop at EndOfMibView or at the
-- first OID outside the subtree ('isUpLevel').
--
-- Throws @ServerException 5@ when an agent reply does not carry exactly one
-- variable binding.
walk' :: ST -> OID -> OID -> Suite -> IO Suite
walk' st oids base accumulator
    | oids == base = do
        first <- get' st [oids]
        next <- getnext' st [oids]
        case (first, next) of
             (Suite [Coupla _ NoSuchObject], Suite [Coupla nextOid nextValue]) -> descend nextOid nextValue next
             (Suite [Coupla _ NoSuchInstance], Suite [Coupla nextOid nextValue]) -> descend nextOid nextValue next
             (Suite [Coupla _ EndOfMibView], _) -> return accumulator
             (_, Suite [Coupla nextOid _]) -> walk' st nextOid base (accumulator <> first)
             (_, _) -> throwIO $ ServerException 5
    | otherwise = do
        nextData <- getnext' st [oids]
        -- Match explicitly instead of an irrefutable let-pattern, so a reply
        -- with zero or several bindings raises a ServerException rather than
        -- a pattern-match failure.
        case nextData of
            Suite [Coupla next v]
                | isUpLevel next base -> return accumulator
                | otherwise -> case v of
                    NoSuchObject -> walk' st next base accumulator
                    NoSuchInstance -> walk' st next base accumulator
                    EndOfMibView -> return accumulator
                    _ -> walk' st next base (accumulator <> nextData)
            _ -> throwIO $ ServerException 5
  where
    -- The base has no value of its own, so the walk starts from its GetNext
    -- successor. If that successor is already outside the subtree (or the
    -- MIB has ended), the subtree is empty and the foreign binding must not
    -- be collected.
    descend nextOid nextValue nextSuite = case nextValue of
        EndOfMibView -> return accumulator
        _ | isUpLevel nextOid base -> return accumulator
          | otherwise -> walk' st nextOid base (accumulator <> nextSuite)

-- | Walk the subtree rooted at @base@ with repeated GetBulk requests,
-- starting after @oids@ and appending results to @accumulator@.
--
-- The last binding of each reply decides what happens next:
--
-- * if it is still inside the subtree (and not EndOfMibView), the whole
--   chunk is kept and the walk continues from that OID;
-- * otherwise the final chunk is filtered and the walk stops. Filtering
--   drops bindings outside the subtree and EndOfMibView sentinels, matching
--   what 'walk'' returns.
--
-- NOTE(review): assumes 'lastS' is defined for the agent's reply (i.e. the
-- reply is non-empty) — confirm its behaviour on an empty 'Suite'.
bulkwalk' :: ST -> OID -> OID -> Suite -> IO Suite
bulkwalk' st oids base accumulator = do
       first <- bulkget' st [oids]
       let Coupla next snmpData = lastS first
           isEndOfMib EndOfMibView = True
           isEndOfMib _ = False
           -- Keep only real data that lies inside the requested subtree.
           wanted (Coupla x v) = not (isUpLevel x base) && not (isEndOfMib v)
           filtered (Suite xs) = Suite $ filter wanted xs
       case (isUpLevel next base , snmpData) of
            (_, EndOfMibView) -> return $ accumulator <> filtered first
            (False, _) -> bulkwalk' st next base (accumulator <> first)
            (True, _) -> return $ accumulator <> filtered first
     
-- | Throw 'ServerException' carrying the reply's error-status when it is
-- non-zero; do nothing otherwise.
checkError :: Packet -> IO ()
checkError p
  | status == 0 = return ()
  | otherwise = throwIO $ ServerException status
  where
    status = getErrorStatus p

-- | Pair every OID with a 'Zero' placeholder value, producing the binding
-- list for read-style requests.
toEmptySuite :: OIDS -> Suite
toEmptySuite oids = Suite [Coupla oid Zero | oid <- oids]

-- | Build an outgoing SNMPv3 request packet, not yet encrypted or signed.
--
-- Allocates a fresh request id and runs engine discovery ('init'') if no
-- engine parameters are cached yet. It then stamps the security header
-- (reportable flag, security level, user, engine ID/boots/time, cleared
-- authentication parameters, message id). Finally it installs the PDU
-- built by @mkRequest@ together with @suite@. A GetBulk PDU always gets
-- max-repetitions 10.
packet :: ST -> Suite -> (RequestId -> ErrorStatus -> ErrorIndex -> Request) -> IO Packet
packet st suite mkRequest = do
    rid <- predCounter (ref' st)
    (eid, boots, time) <- maybe (init' st) return =<< readIORef (engine' st)
    let withRepetitions req = case req of
            GetBulk rid' nonRepeaters _ -> GetBulk rid' nonRepeaters 10
            other -> other
        header = setReportableP True
               . setPrivAuthP (securityLevel' st)
               . setUserNameP (login' st)
               . setEngineIdP eid
               . setEngineBootsP boots
               . setEngineTimeP time
               . setAuthenticationParametersP cleanPass
               . setIDP (ID rid)
        body = setRequest (withRepetitions (mkRequest rid 0 0))
             . setSuite suite
    return $ header (body v3)


-- | Encrypt (when the security level asks for it), sign, encode and send a
-- request packet over the client socket.
sendPacket :: ST -> Packet -> IO ()
sendPacket st packet' = do
    encrypted <- encryptPacketWithCache st packet'
    signed <- signPacketWithCache st encrypted
    sendAll (socket' st) (encode signed)

-- | Receive one reply message from the client socket and decode it as a
-- 'Packet'.
--
-- Waits at most 'timeout'' (passed to 'threadDelay', so microseconds) and
-- throws 'TimeoutException' if nothing arrives in time.
returnResult :: ST -> IO Packet
returnResult st = do
    result <- race (threadDelay (timeout' st)) receive
    case result of
         Right resp -> return resp
         Left _ -> throwIO TimeoutException
  where
    -- On a datagram socket, recv silently truncates any message larger
    -- than the buffer. GetBulk replies routinely exceed one Ethernet MTU,
    -- and requests do not lower msgMaxSize, so the buffer is sized for the
    -- largest possible UDP payload instead of 1500 bytes.
    -- NOTE(review): assumes makeSocket opens a UDP socket — confirm.
    maxDatagram :: Int
    maxDatagram = 65535
    receive :: IO Packet
    receive = decode <$> recv (socket' st) maxDatagram

-- | Sign a packet with the authentication key, deriving and caching that
-- key on first use. The key is derived from 'authPass'' and the packet's
-- engine ID using 'authType''.
signPacketWithCache :: ST -> Packet -> IO Packet
signPacketWithCache st packet' = do
    key <- readIORef (authCache' st) >>= maybe remember return
    return $ signPacket (authType' st) key packet'
  where
    remember = do
        let fresh = passwordToKey (authType' st) (authPass' st) (getEngineIdP packet')
        atomicWriteIORef (authCache' st) (Just fresh)
        return fresh

-- | Encrypt a packet's PDU when the security level is 'AuthPriv'; any other
-- level returns the packet untouched.
--
-- The privacy key is derived from 'privPass'' and the packet's engine ID
-- using 'authType''. It is computed once and cached in 'privCache''.
encryptPacketWithCache :: ST -> Packet -> IO Packet
encryptPacketWithCache st packet'
    | securityLevel' st /= AuthPriv = return packet'
    | otherwise = do
        key <- readIORef (privCache' st) >>= maybe remember return
        encryptPacket st key packet'
  where
    remember = do
        let fresh = passwordToKey (authType' st) (privPass' st) (getEngineIdP packet')
        atomicWriteIORef (privCache' st) (Just fresh)
        return fresh

-- | Decrypt a reply's PDU when the security level is 'AuthPriv'; any other
-- level returns the reply untouched, mirroring 'encryptPacketWithCache'.
--
-- The privacy key is derived from 'privPass'' and the reply's engine ID
-- using 'authType''. It is computed once and cached in 'privCache''.
-- Without this guard, sessions that never encrypt would still derive a
-- privacy key and run 'decryptPacket'. That function throws
-- @ServerException 5@ for unsupported ciphers even when nothing is
-- encrypted.
decryptPacketWithCache :: ST -> Packet -> IO Packet
decryptPacketWithCache st packet'
    | securityLevel' st /= AuthPriv = return packet'
    | otherwise = do
        key <- readIORef (privCache' st) >>= maybe deriveKey return
        return $ decryptPacket st key packet'
  where
    deriveKey = do
        let key = passwordToKey (authType' st) (privPass' st) (getEngineIdP packet')
        atomicWriteIORef (privCache' st) (Just key)
        return key
 
-- | Encrypt the packet's V3 PDU with the configured privacy cipher.
--
-- * DES draws its salt from the 32-bit counter and uses the engine boots.
-- * AES draws its salt from the 64-bit counter and uses engine boots and
--   time.
--
-- The ciphertext replaces the PDU ('CryptedPDU') and the returned salt goes
-- into the privacy parameters. Any other 'PrivType' throws
-- @ServerException 5@.
encryptPacket :: ST -> Key -> Packet -> IO Packet
encryptPacket st key packet'
  | privType' st == DES = do
      s <- succCounter (salt32 st)
      seal (desEncrypt key boots s plain)
  | privType' st == AES = do
      s <- succCounter (salt64 st)
      seal (aesEncrypt key boots time s plain)
  | otherwise = throwIO $ ServerException 5
  where
    boots = getEngineBootsP packet'
    time = getEngineTimeP packet'
    plain = encode (getPDU packet' :: PDU V3)
    seal (encrypted, salt) =
        return $ setPrivParametersP salt . setPDU (CryptedPDU encrypted) $ packet'

-- | Decrypt a 'CryptedPDU' with the configured privacy cipher.
--
-- The salt comes from the packet's privacy parameters; AES additionally
-- uses engine boots and time. A PDU that is not a 'CryptedPDU' is left as
-- it is. Any 'PrivType' other than DES or AES throws @ServerException 5@
-- (lazily, via 'throw').
decryptPacket :: ST -> Key -> Packet -> Packet
decryptPacket st key packet'
  | privType' st == DES = open (desDecrypt key salt)
  | privType' st == AES = open (aesDecrypt key salt boots time)
  | otherwise = throw $ ServerException 5
  where
    salt = getPrivParametersP packet'
    boots = getEngineBootsP packet'
    time = getEngineTimeP packet'
    open decipher = case getPDU packet' :: PDU V3 of
        CryptedPDU x -> setPDU (decode (decipher x) :: PDU V3) packet'
        _ -> packet'