{-# language BangPatterns #-}
{-# language OverloadedStrings #-}

module Snmp.Client
  ( Session(..)
  , Config(..)
  , Destination(..)
  , Credentials(..)
  , CredentialsV2(..)
  , CredentialsV3(..)
  , Context(..)
  , PerHostV3(..)
  , SnmpException(..)
  , openSession
  , closeSession
  , get
  , get'
  , getBulkStep
  , getBulkStep'
  , getBulkChildren
  , getBulkChildren'
  ) where

import Control.Applicative ((<|>))
import Control.Concurrent (Chan,newChan,writeList2Chan,readChan,threadWaitReadSTM,writeChan)
import Control.Concurrent.STM.TVar (TVar,newTVarIO,registerDelay,readTVar,writeTVar)
import Control.Exception (throwIO,Exception)
import Control.Monad (replicateM,replicateM_,when,(<=<))
import Control.Monad.STM (STM,atomically,check)
import Data.Bits ((.|.))
import Data.ByteString (ByteString)
import Data.Functor (($>))
import Data.IORef (IORef,readIORef,writeIORef,newIORef)
import Data.Int (Int32)
import Data.Map (Map)
import Data.Maybe (mapMaybe)
import Data.Vector (Vector)
import Data.Word (Word16)
import Language.Asn.Types
import Snmp.Types
import Net.Types (IPv4)
import Text.Printf (printf)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as LB
import qualified Data.Map as Map
import qualified Data.Vector as Vector
import qualified Language.Asn.Decoding as AsnDecoding
import qualified Language.Asn.Encoding as AsnEncoding
import qualified Language.Asn.ObjectIdentifier as OID
import qualified Net.IPv4 as IPv4
import qualified Snmp.Decoding as SnmpDecoding
import qualified Snmp.Encoding as SnmpEncoding
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB
import qualified System.Posix.Types

data Session = Session
  { sessionSockets :: !(Chan NS.Socket)
  -- , sessionCredsTimestamps :: !(TVar (Map Word32
  , sessionSocketCount :: !Int
  , sessionRequestId :: !(TVar RequestId)
  , sessionAesSalt :: !(TVar AesSalt)
  , sessionTimeoutMicroseconds :: !Int
  , sessionMaxTries :: !Int
  , sessionKeyCache :: !(IORef (Map (AuthType,ByteString,EngineId) ByteString))
  }

data Config = Config
  { configSocketPoolSize :: !Int
  , configTimeoutMicroseconds :: !Int
  , configRetries :: !Int
  } deriving (Show,Eq)

data Destination = Destination
  { destinationHost :: !IPv4
  , destinationPort :: !Word16
  } deriving (Show,Eq)

data Credentials
  = CredentialsConstructV2 CredentialsV2
  | CredentialsConstructV3 CredentialsV3
  deriving (Show,Eq)

newtype CredentialsV2 = CredentialsV2
  { credentialsV2CommunityString :: ByteString
  } deriving (Show,Eq)


data CredentialsV3 = CredentialsV3
  { credentialsV3Crypto :: !Crypto
  , credentialsV3ContextName :: !ByteString
  , credentialsV3User :: !ByteString
  } deriving (Show,Eq)

data Context = Context
  { contextSession :: !Session
  , contextDestination :: !Destination
  , contextCredentials :: !Credentials
  }

data PerHostV3 = PerHostV3
  { perHostV3AuthoritativeEngineId :: !EngineId
  , perHostV3ReceiverTime :: !Int32
  , perHostV3ReceiverBoots :: !Int32
  }

-- | Only one connection can be open at a time on a given port.
openSession :: Config -> IO Session
openSession (Config socketPoolSize timeout retries) = do
  addrinfos <- NS.getAddrInfo
    (Just (NS.defaultHints {NS.addrFlags = [NS.AI_PASSIVE]}))
    (Just "0.0.0.0")
    Nothing
  let serveraddr = head addrinfos
  allSockets <- replicateM socketPoolSize $ do
    sock <- NS.socket (NS.addrFamily serveraddr) NS.Datagram NS.defaultProtocol
    NS.bind sock (NS.addrAddress serveraddr)
    return sock
  requestIdVar <- newTVarIO (RequestId 1)
  aesSaltVar <- newTVarIO (AesSalt 1)
  socketChan <- newChan
  writeList2Chan socketChan allSockets
  keyCache <- newIORef Map.empty
  return (Session socketChan socketPoolSize requestIdVar aesSaltVar timeout retries keyCache)

closeSession :: Session -> IO ()
closeSession session = replicateM_ (sessionSocketCount session) $ do
  sock <- readChan (sessionSockets session)
  NS.close sock

generalRequest ::
     (RequestId -> Pdus)
  -> (Pdu -> Either SnmpException a)
  -> Context
  -> IO (Either SnmpException a)
generalRequest pdusFromRequestId fromPdu (Context session (Destination ip port) creds) = do
  sock <- readChan (sessionSockets session)
  case creds of
    CredentialsConstructV2 (CredentialsV2 commStr) -> do
      requestId <- nextRequestId (sessionRequestId session)
      let !bs = id
            $ LB.toStrict
            $ AsnEncoding.der SnmpEncoding.messageV2
            $ MessageV2 commStr
            $ pdusFromRequestId requestId
          !bsLen = ByteString.length bs
          go1 :: Int -> IO (Either SnmpException Pdu)
          go1 !n1 = if n1 > 0
            then do
              when inDebugMode $ putStrLn "Sending:"
              when inDebugMode $ putStrLn (hexByteStringInternal bs)
              bytesSentLen <- NSB.sendTo sock bs (NS.SockAddrInet (fromIntegral port) (NS.tupleToHostAddress (IPv4.toOctets ip)))
              if bytesSentLen /= bsLen
                then return $ Left $ SnmpExceptionNotAllBytesSent bytesSentLen bsLen
                else do
                  let go2 mperHostV3 = do
                        (isReadyAction,deregister) <- threadWaitReadSTM =<< mySockFd sock
                        delay <- registerDelay (sessionTimeoutMicroseconds session)
                        isContentReady <- atomically $ (isReadyAction $> True) <|> (fini delay $> False)
                        deregister
                        if not isContentReady
                          then go1 (n1 - 1)
                          else do
                            bsRecv <- NSB.recv sock 10000
                            when inDebugMode $ putStrLn "Received:"
                            when inDebugMode $ print bsRecv
                            if ByteString.null bsRecv
                              then return (Left SnmpExceptionSocketClosed)
                              else case AsnDecoding.ber SnmpDecoding.messageV2 bsRecv of
                                  Left err -> return (Left $ SnmpExceptionDecoding err)
                                  Right msg -> case messageV2Data msg of
                                    PdusResponse pdu@(Pdu respRequestId _ _ _) ->
                                      case compare requestId respRequestId of
                                        GT -> go2 mperHostV3
                                        EQ -> return (Right pdu)
                                        LT -> return $ Left $ SnmpExceptionMissedResponse requestId respRequestId
                                    _ -> return (Left (SnmpExceptionNonPduResponseV2 msg))
                  go2 Nothing
            else return $ Left SnmpExceptionTimeout
      e <- go1 (sessionMaxTries session)
      writeChan (sessionSockets session) sock
      return (e >>= fromPdu)
    CredentialsConstructV3 (CredentialsV3 crypto contextName user) -> do
      let passwordToKeyCached typ password eng = do
            keyCache <- readIORef (sessionKeyCache session)
            let triple = (typ,password,eng)
            case Map.lookup triple keyCache of
              Nothing -> do
                let key = SnmpEncoding.passwordToKey typ password eng
                writeIORef (sessionKeyCache session) (Map.insert triple key keyCache)
                return key
              Just key -> return key
      -- setting the reportable flags is very important
      -- for AuthPriv
      let flags = cryptoFlags crypto .|. 0x04
          mkAuthParams :: RequestId -> PerHostV3 -> (ByteString,ScopedPduData) -> IO ByteString
          mkAuthParams reqId phv3 privPair = case cryptoAuth crypto of
            Nothing -> return ByteString.empty
            Just (AuthParameters typ password) -> do
              key <- passwordToKeyCached typ password (perHostV3AuthoritativeEngineId phv3)
              let serializationWithoutAuth = snd (makeBs (ByteString.replicate 12 0x00) reqId privPair phv3)
              pure (SnmpEncoding.mkSign typ key serializationWithoutAuth)
          mkPrivParams :: AesSalt -> RequestId -> PerHostV3 -> IO (ByteString,ScopedPduData)
          mkPrivParams theSalt reqId phv3 = case crypto of
            AuthPriv (AuthParameters authType _) (PrivParameters privType privPass) -> do
              key <- passwordToKeyCached authType privPass (perHostV3AuthoritativeEngineId phv3)
              case privType of
                PrivTypeAes ->
                  let (encrypted,actualSaltBs) = SnmpEncoding.aesEncrypt
                        key
                        (perHostV3ReceiverBoots phv3)
                        (perHostV3ReceiverTime phv3)
                        theSalt
                        (LB.toStrict (AsnEncoding.der SnmpEncoding.scopedPdu spdu))
                   in pure (actualSaltBs,ScopedPduDataEncrypted encrypted)
                PrivTypeDes ->
                  let (encrypted,actualSaltBs) = SnmpEncoding.desEncrypt
                        key
                        (perHostV3ReceiverBoots phv3)
                        (fromIntegral (getAesSalt theSalt))
                        -- (perHostV3ReceiverTime phv3)
                        (LB.toStrict (AsnEncoding.der SnmpEncoding.scopedPdu spdu))
                   in pure (actualSaltBs,ScopedPduDataEncrypted encrypted)
            _ -> pure (ByteString.empty,ScopedPduDataPlaintext spdu)
            where spdu = ScopedPdu (perHostV3AuthoritativeEngineId phv3) contextName (pdusFromRequestId reqId)
          makeBs :: ByteString -> RequestId -> (ByteString,ScopedPduData) -> PerHostV3 -> (MessageV3,ByteString)
          makeBs activeAuthParams reqId (activePrivParams,spdud) (PerHostV3 authoritativeEngineId receiverTime boots) =
            let myMsg = MessageV3
                  (HeaderData reqId 1500 flags) -- making up a max size
                  (Usm authoritativeEngineId boots receiverTime user activeAuthParams activePrivParams)
                  spdud
                -- myMsg2 = trace ("THE MESSAGE TO SEND: " ++ show myMsg) myMsg
             in (myMsg, LB.toStrict $ AsnEncoding.der SnmpEncoding.messageV3 $ myMsg)
          fullMakeBs :: AesSalt -> RequestId -> PerHostV3 -> IO (MessageV3, ByteString)
          fullMakeBs theSalt reqId phv3 = do
            privPair <- mkPrivParams theSalt reqId phv3
            authParams <- mkAuthParams reqId phv3 privPair
            return (makeBs authParams reqId privPair phv3)
          go1 :: Int -> RequestId -> (MessageV3,ByteString) -> Bool -> IO (Either SnmpException Pdu)
          go1 !n1 !requestId (!sentMsg,!bsSent) !engineIdsAcquired = if n1 > 0
            then do
              when inDebugMode $ putStrLn "Sending:"
              when inDebugMode $ putStrLn (hexByteStringInternal bsSent)
              let bsLen = ByteString.length bsSent
              bytesSentLen <- NSB.sendTo sock bsSent (NS.SockAddrInet (fromIntegral port) (NS.tupleToHostAddress (IPv4.toOctets ip)))
              if bytesSentLen /= bsLen
                then return $ Left $ SnmpExceptionNotAllBytesSent bytesSentLen bsLen
                else do
                  let go2 :: IO (Either SnmpException Pdu)
                      go2 = do
                        (isReadyAction,deregister) <- threadWaitReadSTM =<< mySockFd sock
                        delay <- registerDelay (sessionTimeoutMicroseconds session)
                        isContentReady <- atomically $ (isReadyAction $> True) <|> (fini delay $> False)
                        deregister
                        if not isContentReady
                          then do
                            when inDebugMode $ putStrLn "NO RESPONSE"
                            requestId' <- nextRequestId (sessionRequestId session)
                            go1 (n1 - 1) requestId' (sentMsg,bsSent) engineIdsAcquired
                          else do
                            bsRecv <- NSB.recv sock 10000
                            when inDebugMode $ putStrLn "Received:"
                            when inDebugMode $ putStrLn (hexByteStringInternal bsRecv)
                            if ByteString.null bsRecv
                              then return (Left SnmpExceptionSocketClosed)
                              else case AsnDecoding.ber SnmpDecoding.messageV3 bsRecv of
                                Left err -> return (Left $ SnmpExceptionDecoding err)
                                Right msg -> do
                                  case cryptoAuth crypto of
                                    Nothing -> return ()
                                    Just (AuthParameters typ password) -> do
                                      when inDebugMode $ putStrLn "THE RECEIVED MESSAGE"
                                      when inDebugMode $ print msg
                                      let reencoded = LB.toStrict $ AsnEncoding.der SnmpEncoding.messageV3 msg
                                      when inDebugMode $ putStrLn $ hexByteStringInternal $ reencoded
                                      when (reencoded /= bsRecv) $ do
                                        when inDebugMode $ putStrLn "NOT THE SAME"
                                      key <- passwordToKeyCached typ password (usmAuthoritativeEngineId (messageV3SecurityParameters msg))
                                      case SnmpEncoding.checkSign typ key msg of
                                        Nothing -> return ()
                                        Just (expected,actual) -> do
                                          when (not $ ByteString.null actual) $ do
                                            throwIO $ SnmpExceptionAuthenticationFailure expected actual
                                  let handleSpdu :: ScopedPdu -> IO (Either SnmpException Pdu)
                                      handleSpdu spdu = case scopedPduData spdu of
                                        -- check to make sure that we requested an unencrypted response
                                        -- somehow check the message id in here too
                                        PdusResponse pdu@(Pdu respRequestId _ _ _) ->
                                          case compare requestId respRequestId of
                                            GT -> go2
                                            EQ -> return (Right pdu)
                                            LT -> return $ Left $ SnmpExceptionMissedResponse requestId respRequestId
                                        PdusReport (Pdu respRequestId _ _ _) -> do
                                          when inDebugMode $ putStrLn $ "Expected Request ID: " ++ show requestId
                                          when inDebugMode $ putStrLn $ "Received Request ID: " ++ show respRequestId
                                          if engineIdsAcquired
                                            then return $ Left (SnmpExceptionBadEngineId sentMsg msg)
                                            else do
                                              let usm = messageV3SecurityParameters msg
                                                  phv3 = PerHostV3
                                                    (usmAuthoritativeEngineId usm)
                                                    (usmAuthoritativeEngineTime usm)
                                                    (usmAuthoritativeEngineBoots usm)
                                              theSalt <- atomically $ nextSalt (sessionAesSalt session)
                                              requestId' <- nextRequestId (sessionRequestId session)
                                              -- Notice that n1 is not decremented in this
                                              -- situation. This is intentional.
                                              internalFragment <- fullMakeBs theSalt requestId' phv3
                                              go1 n1 requestId' internalFragment True
                                        _ -> return (Left (SnmpExceptionNonPduResponseV3 msg))
                                  case messageV3Data msg of
                                    ScopedPduDataEncrypted encrypted -> case crypto of
                                     NoAuthNoPriv -> error "internal library error: messageV3Data NoAuthPriv"
                                     AuthNoPriv _ -> error "internal library error: messageV3Data (AuthNoPriv _)"
                                     AuthPriv (AuthParameters authType _) (PrivParameters privType privPass) -> do
                                        let usm = messageV3SecurityParameters msg
                                        key <- passwordToKeyCached authType privPass (usmAuthoritativeEngineId usm)
                                        let mdecrypted = case privType of
                                              PrivTypeDes -> SnmpEncoding.desDecrypt key (usmPrivacyParameters usm) encrypted
                                              PrivTypeAes -> SnmpEncoding.aesDecrypt key (usmPrivacyParameters usm) (usmAuthoritativeEngineBoots usm) (usmAuthoritativeEngineTime usm) encrypted
                                        case mdecrypted of
                                          Just bs -> case AsnDecoding.ber SnmpDecoding.scopedPdu bs of
                                            Left err -> throwIO (SnmpExceptionDecoding err)
                                            Right spdu -> handleSpdu spdu
                                          Nothing -> throwIO SnmpExceptionDecryptionFailure
                                    ScopedPduDataPlaintext spdu -> handleSpdu spdu
                  go2
            else return $ Left $ SnmpExceptionTimeoutV3 sentMsg
      -- boots and estimated time are made up for this, we could do better
      let originalPhv3 = PerHostV3 (EngineId "initial-engine-id") 0xFFFFFF 0xEEEEEE
      theSalt <- atomically $ nextSalt (sessionAesSalt session)
      requestId' <- nextRequestId (sessionRequestId session)
      theFragment <- fullMakeBs theSalt requestId' originalPhv3
      e <- go1 (sessionMaxTries session) requestId' theFragment False
      writeChan (sessionSockets session) sock
      return (e >>= fromPdu)

nextSalt :: TVar AesSalt -> STM AesSalt
nextSalt v = do
  AesSalt w <- readTVar v
  let s = AesSalt (w + 1)
  writeTVar v s
  return s

throwSnmpException :: IO (Either SnmpException a) -> IO a
throwSnmpException = (either throwIO return =<<)

get :: Context -> ObjectIdentifier -> IO ObjectSyntax
get ctx ident = throwSnmpException (get' ctx ident)

getBulkStep :: Context -> Int -> ObjectIdentifier -> IO (Vector (ObjectIdentifier,ObjectSyntax))
getBulkStep ctx maxRep ident = throwSnmpException (getBulkStep' ctx maxRep ident)

getBulkChildren :: Context -> Int -> ObjectIdentifier -> IO (Vector (ObjectIdentifier,ObjectSyntax))
getBulkChildren ctx maxRep oid1 = throwSnmpException (getBulkChildren' ctx maxRep oid1)

get' :: Context -> ObjectIdentifier -> IO (Either SnmpException ObjectSyntax)
get' ctx ident = generalRequest
  (\reqId -> PdusGetRequest (Pdu reqId (ErrorStatus 0) (ErrorIndex 0) (Vector.singleton (VarBind ident BindingResultUnspecified))))
  (singleBindingValue ident <=< onlyBindings)
  ctx

getBulkStep' :: Context -> Int -> ObjectIdentifier -> IO (Either SnmpException (Vector (ObjectIdentifier,ObjectSyntax)))
getBulkStep' ctx maxRep ident = generalRequest
  (\reqId -> PdusGetBulkRequest (BulkPdu reqId 0 (fromIntegral maxRep) (Vector.singleton (VarBind ident BindingResultUnspecified))))
  (fmap multipleBindings . onlyBindings)
  ctx

getBulkChildren' :: Context -> Int -> ObjectIdentifier -> IO (Either SnmpException (Vector (ObjectIdentifier,ObjectSyntax)))
getBulkChildren' ctx maxRep oid1 = go Vector.empty oid1 where
  go prevPairs ident = do
    epairsUnfiltered <- getBulkStep' ctx maxRep ident
    case epairsUnfiltered of
      Left e -> return (Left e)
      Right pairsUnfiltered -> do
        let pairs = Vector.filter (\(oid,_) -> oidIsPrefixOf oid1 oid) pairsUnfiltered
        if Vector.null pairs
          then return (Right prevPairs)
          else go (prevPairs Vector.++ pairs) (fst (Vector.last pairs))

oidIsPrefixOf :: ObjectIdentifier -> ObjectIdentifier -> Bool
oidIsPrefixOf = OID.isPrefixOf

-- There is not a mapMaybe for vector until 0.12.0.0
multipleBindings :: Vector VarBind -> Vector (ObjectIdentifier,ObjectSyntax)
multipleBindings = Vector.fromList . mapMaybe
  ( \(VarBind ident br) -> case br of
       BindingResultValue obj -> Just (ident,obj)
       _ -> Nothing
  ) . Vector.toList

singleBindingValue :: ObjectIdentifier -> Vector VarBind -> Either SnmpException ObjectSyntax
singleBindingValue oid v = if Vector.length v == 1
  then do
    let VarBind name res = v Vector.! 0
    when (name /= oid) $ Left $ SnmpExceptionMismatchedBinding oid name
    case res of
      BindingResultValue obj -> Right obj
      BindingResultUnspecified -> Left SnmpExceptionUnspecified
      BindingResultNoSuchObject -> Left (SnmpExceptionNoSuchObject oid)
      BindingResultNoSuchInstance -> Left (SnmpExceptionNoSuchInstance oid)
      BindingResultEndOfMibView -> Left SnmpExceptionEndOfMibView
  else Left (SnmpExceptionMultipleBindings (Vector.length v))

onlyBindings :: Pdu -> Either SnmpException (Vector VarBind)
onlyBindings (Pdu _ errStatus@(ErrorStatus e) errIndex bindings) =
  if e == 0 then Right bindings else Left (SnmpExceptionPduError errStatus errIndex)

data SnmpException
  = SnmpExceptionNotAllBytesSent !Int !Int
  | SnmpExceptionTimeout
  | SnmpExceptionTimeoutV3 !MessageV3
  | SnmpExceptionPduError !ErrorStatus !ErrorIndex
  | SnmpExceptionMultipleBindings !Int
  | SnmpExceptionMismatchedBinding !ObjectIdentifier !ObjectIdentifier
  | SnmpExceptionUnspecified -- ^ Should not happen
  | SnmpExceptionNoSuchObject !ObjectIdentifier
  | SnmpExceptionNoSuchInstance !ObjectIdentifier
  | SnmpExceptionEndOfMibView
  | SnmpExceptionMissedResponse !RequestId !RequestId
  | SnmpExceptionNonPduResponseV2 !MessageV2
  | SnmpExceptionNonPduResponseV3 !MessageV3
  | SnmpExceptionDecoding !String
  | SnmpExceptionSocketClosed
  | SnmpExceptionAuthenticationFailure !ByteString !ByteString
  | SnmpExceptionBadEngineId !MessageV3 !MessageV3
  | SnmpExceptionDecryptionFailure
  deriving (Show,Eq)

instance Exception SnmpException

fini :: TVar Bool -> STM ()
fini = check <=< readTVar

nextRequestId :: TVar RequestId -> IO RequestId
nextRequestId requestIdVar = atomically $ do
  RequestId i1 <- readTVar requestIdVar
  let !i2 = mod (i1 + 1) 100000000
      !i3 = if i2 == 0 then 1 else i2
  writeTVar requestIdVar (RequestId i3)
  return (RequestId i3)

mySockFd :: NS.Socket -> IO System.Posix.Types.Fd
mySockFd s = fmap System.Posix.Types.Fd (NS.fdSocket s)

hexByteStringInternal :: ByteString -> String
hexByteStringInternal = ByteString.foldr (\w xs -> printf "%02X" w ++ xs) []

inDebugMode :: Bool
inDebugMode = False