{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

module Net.Snmp.Client where

import Control.Applicative
import Control.Concurrent
import Control.Concurrent (forkIO)
import Control.Concurrent.Chan
import Control.Exception (Exception,bracket,bracketOnError,onException,throwIO)
import Control.Monad
import Data.Bits
import Data.Coerce
import Data.Functor
import Data.Int
import Data.Maybe
import Data.Word
import Debug.Trace
import qualified System.Posix.Types
import Text.Printf (printf)

import Control.Concurrent.STM.TMVar
import Control.Concurrent.STM.TVar
import Control.Monad.STM
import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as LB
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Vector (Vector)
import qualified Data.Vector as Vector
import qualified Network.Socket as NS
import qualified Network.Socket.ByteString as NSB

import qualified Language.Asn.Decoding as AsnDecoding
import qualified Language.Asn.Encoding as AsnEncoding
import Language.Asn.Types
import qualified Net.Snmp.Decoding as SnmpDecoding
import qualified Net.Snmp.Encoding as SnmpEncoding
import Net.Snmp.Types

-- | Shared state for issuing requests: a pool of UDP sockets plus the
-- counters used to hand out request IDs and AES salts. Created by
-- 'openSession' and torn down by 'closeSession'.
data Session = Session
  { sessionSockets :: !(Chan NS.Socket)
    -- ^ Pool of bound sockets. A request takes one out with 'readChan' and
    -- puts it back with 'writeChan' when it is finished, so each socket is
    -- used by at most one request at a time.

  -- NOTE(review): leftover, unfinished field sketch kept as a comment:
  -- , sessionCredsTimestamps :: !(TVar (Map Word32
  , sessionSocketCount :: !Int
    -- ^ Number of sockets created for the pool; 'closeSession' takes this
    -- many sockets back out of 'sessionSockets' before closing them.
  , sessionRequestId :: !(TVar RequestId)
    -- ^ Most recently issued request ID (see 'nextRequestId').
  , sessionAesSalt :: !(TVar AesSalt)
    -- ^ Most recently issued SNMPv3 privacy salt (see 'nextSalt').
  , sessionTimeoutMicroseconds :: !Int
    -- ^ How long to wait for a reply after each transmission.
  , sessionMaxTries :: !Int
    -- ^ Number of transmissions allowed per request (from 'configRetries').
  }

-- | Settings for 'openSession'.
data Config = Config
  { configSocketPoolSize :: !Int
    -- ^ Number of UDP sockets to open. Each in-flight request holds one
    -- socket, so this bounds the number of concurrent requests.
  , configTimeoutMicroseconds :: !Int
    -- ^ Time to wait for a reply after each transmission, in microseconds.
  , configRetries :: !Int
    -- ^ Despite the name, the total number of transmissions per request
    -- (not extra attempts after a first one). A value of 0 or less means
    -- nothing is ever sent and every request reports a timeout.
  } deriving (Show,Eq)

-- | The agent requests are sent to: an IPv4 address and a UDP port.
data Destination = Destination
  { destinationHost :: !(Word8,Word8,Word8,Word8)
    -- ^ IPv4 address as four octets, converted with
    -- 'NS.tupleToHostAddress' (e.g. @(192,168,1,1)@).
  , destinationPort :: !Word16
    -- ^ Destination port.
  } deriving (Show,Eq)

-- | How requests are authenticated: an SNMPv2 community string or SNMPv3
-- user-based security settings. The choice also selects the message format.
data Credentials
  = CredentialsConstructV2 CredentialsV2
  | CredentialsConstructV3 CredentialsV3
  deriving (Show,Eq)

-- | SNMPv2 credentials: the community string placed in every 'MessageV2'.
newtype CredentialsV2 = CredentialsV2
  { credentialsV2CommunityString :: ByteString
  } deriving (Show,Eq)


-- | SNMPv3 credentials.
data CredentialsV3 = CredentialsV3
  { credentialsV3Crypto :: !Crypto
    -- ^ Authentication/privacy settings. They determine the message flags
    -- and whether messages are signed and/or encrypted.
  , credentialsV3ContextName :: !ByteString
    -- ^ Context name placed in the scoped PDU.
  , credentialsV3User :: !ByteString
    -- ^ User name placed in the USM security parameters.
  } deriving (Show,Eq)

-- | Everything a single request needs.
data Context = Context
  { contextSession :: !Session
    -- ^ Session whose socket pool and counters the request uses.
  , contextDestination :: !Destination
    -- ^ Agent to talk to.
  , contextCredentials :: !Credentials
    -- ^ Credentials to send with the request.
  }

-- | Per-agent SNMPv3 parameters that outgoing messages are built from.
-- 'generalRequest' starts with a placeholder value and replaces it with
-- the values carried by the agent's Report.
data PerHostV3 = PerHostV3
  { perHostV3AuthoritativeEngineId :: !EngineId
    -- ^ Authoritative engine ID: used to localize keys and as the context
    -- engine ID of the scoped PDU.
  , perHostV3ReceiverTime :: !Int32
    -- ^ Engine time placed in the USM parameters (also used by AES
    -- encryption).
  , perHostV3ReceiverBoots :: !Int32
    -- ^ Engine boots placed in the USM parameters (also used by AES and
    -- DES encryption).
  }



-- | Open a session.
--
-- This binds 'configSocketPoolSize' UDP sockets to the wildcard address
-- @0.0.0.0@ and puts them in the session's pool. No service/port is passed to
-- 'NS.getAddrInfo', so presumably the OS picks a port for each socket. It also
-- initialises the request-ID and AES-salt counters.
--
-- Each socket is used by only one request at a time, because a request
-- borrows a socket from the pool for its whole exchange (see
-- 'generalRequest').
--
-- If opening or binding any socket fails, every socket opened so far is
-- closed before the exception propagates, so a failed call leaks nothing.
openSession :: Config -> IO Session
openSession (Config socketPoolSize timeout retries) = do
    addrinfos <- NS.getAddrInfo
      (Just (NS.defaultHints {NS.addrFlags = [NS.AI_PASSIVE]}))
      (Just "0.0.0.0")
      Nothing
    -- getAddrInfo normally throws rather than returning an empty list, but
    -- avoid the partial 'head' anyway.
    serveraddr <- case addrinfos of
      addr : _ -> return addr
      [] -> ioError (userError "Net.Snmp.Client.openSession: getAddrInfo returned no addresses")
    allSockets <- openBoundSockets serveraddr
    requestIdVar <- newTVarIO (RequestId 1)
    aesSaltVar <- newTVarIO (AesSalt 1)
    socketChan <- newChan
    writeList2Chan socketChan allSockets
    return (Session socketChan socketPoolSize requestIdVar aesSaltVar timeout retries)
  where
    -- Open and bind socketPoolSize sockets, returned in creation order.
    -- On any failure, close every socket created so far and rethrow.
    openBoundSockets :: NS.AddrInfo -> IO [NS.Socket]
    openBoundSockets serveraddr = go [] socketPoolSize
      where
        go acc n
          | n <= 0 = return (reverse acc)
          | otherwise = do
              sock <- openOne `onException` mapM_ NS.close acc
              go (sock : acc) (n - 1)
        -- A socket whose bind fails is closed here; the caller cleans up
        -- the earlier ones.
        openOne = bracketOnError
          (NS.socket (NS.addrFamily serveraddr) NS.Datagram NS.defaultProtocol)
          NS.close
          (\sock -> sock <$ NS.bind sock (NS.addrAddress serveraddr))

-- | Close every socket in the session's pool.
--
-- 'sessionSocketCount' sockets are taken out of the pool, and each one is
-- closed as soon as it is taken. Taking from an empty pool blocks, so this
-- waits for in-flight requests to hand their sockets back.
closeSession :: Session -> IO ()
closeSession session = drain (sessionSocketCount session)
  where
    drain remaining
      | remaining <= 0 = return ()
      | otherwise = do
          readChan (sessionSockets session) >>= NS.close
          drain (remaining - 1)

-- | Send one request to the agent in the 'Context' and wait for its reply.
--
-- @pdusFromRequestId@ builds the request for a given request ID, and
-- @fromPdu@ interprets the response PDU.
--
-- A socket is borrowed from the session's pool for the whole exchange. It is
-- returned to the pool on every exit path, including when an exception
-- escapes, so the pool never shrinks and 'closeSession' cannot deadlock.
--
-- Each transmission waits up to 'sessionTimeoutMicroseconds' for a reply, and
-- at most 'sessionMaxTries' transmissions are made.
--
-- For SNMPv3, the first message uses a placeholder engine ID. When the agent
-- answers with a Report, the engine ID, boots and time it carries are used to
-- rebuild and resend the request; that round trip does not use up a try.
-- Every SNMPv3 transmission is built with a fresh request ID (also used as
-- the msgID) and a fresh salt.
--
-- A reply with a lower request ID than expected is treated as stale and
-- ignored. A higher one yields 'SnmpExceptionMissedResponse'.
--
-- Most failures are returned in 'Left'. For SNMPv3, however, authentication
-- failures, decryption failures, and decoding failures of a decrypted scoped
-- PDU are thrown with 'throwIO'.
generalRequest ::
     (RequestId -> Pdus)
  -> (Pdu -> Either SnmpException a)
  -> Context
  -> IO (Either SnmpException a)
generalRequest pdusFromRequestId fromPdu (Context session (Destination ip port) creds) =
    -- bracket hands the socket back to the pool on every exit path,
    -- including the throwIO calls in the SNMPv3 branch.
    bracket (readChan socketPool) (writeChan socketPool) $ \sock -> do
      e <- case creds of
        CredentialsConstructV2 (CredentialsV2 commStr) -> requestV2 sock commStr
        CredentialsConstructV3 credsV3 -> requestV3 sock credsV3
      return (e >>= fromPdu)
  where
    socketPool :: Chan NS.Socket
    socketPool = sessionSockets session

    destination :: NS.SockAddr
    destination = NS.SockAddrInet (fromIntegral port) (NS.tupleToHostAddress ip)

    -- Block until the socket is readable (True) or the per-transmission
    -- timeout fires (False). The event-manager registration is released
    -- even if the wait is interrupted by an exception.
    waitReadable :: NS.Socket -> IO Bool
    waitReadable sock =
      bracket (threadWaitReadSTM (mySockFd sock)) snd $ \(isReadyAction,_) -> do
        delay <- registerDelay (sessionTimeoutMicroseconds session)
        atomically $ (isReadyAction $> True) <|> (fini delay $> False)

    -- SNMPv2: one message, built once and retransmitted unchanged.
    requestV2 :: NS.Socket -> ByteString -> IO (Either SnmpException Pdu)
    requestV2 sock commStr = do
      requestId <- nextRequestId (sessionRequestId session)
      let !bs = LB.toStrict
            $ AsnEncoding.der SnmpEncoding.messageV2
            $ MessageV2 commStr
            $ pdusFromRequestId requestId
          !bsLen = ByteString.length bs
          -- n1 is the number of transmissions still allowed.
          go1 :: Int -> IO (Either SnmpException Pdu)
          go1 !n1 = if n1 > 0
            then do
              when inDebugMode $ putStrLn "Sending:"
              when inDebugMode $ putStrLn (hexByteStringInternal bs)
              bytesSentLen <- NSB.sendTo sock bs destination
              if bytesSentLen /= bsLen
                then return $ Left $ SnmpExceptionNotAllBytesSent bytesSentLen bsLen
                else do
                  let go2 = do
                        isContentReady <- waitReadable sock
                        if not isContentReady
                          then go1 (n1 - 1)
                          else do
                            bsRecv <- NSB.recv sock 10000
                            when inDebugMode $ putStrLn "Received:"
                            when inDebugMode $ print bsRecv
                            if ByteString.null bsRecv
                              then return (Left SnmpExceptionSocketClosed)
                              else case AsnDecoding.ber SnmpDecoding.messageV2 bsRecv of
                                Left err -> return (Left $ SnmpExceptionDecoding err)
                                Right msg -> case messageV2Data msg of
                                  PdusResponse pdu@(Pdu respRequestId _ _ _) ->
                                    case compare requestId respRequestId of
                                      -- stale reply to an earlier request: keep waiting
                                      GT -> go2
                                      EQ -> return (Right pdu)
                                      LT -> return $ Left $ SnmpExceptionMissedResponse requestId respRequestId
                                  _ -> return (Left (SnmpExceptionNonPduResponseV2 msg))
                  go2
            else return $ Left SnmpExceptionTimeout
      go1 (sessionMaxTries session)

    -- SNMPv3: the message depends on per-agent parameters (engine ID, boots,
    -- time) learned from the agent's Report, and it embeds the request ID,
    -- so it is rebuilt for every transmission.
    requestV3 :: NS.Socket -> CredentialsV3 -> IO (Either SnmpException Pdu)
    requestV3 sock (CredentialsV3 crypto contextName user) =
        sendFresh (sessionMaxTries session) originalPhv3 False
      where
        -- boots and estimated time are made up for this, we could do better
        originalPhv3 = PerHostV3 (EngineId "initial-engine-id") 0xFFFFFF 0xEEEEEE
        -- setting the reportable flags is very important
        -- for AuthPriv
        flags = cryptoFlags crypto .|. 0x04
        -- The signature is computed over the message serialized with a
        -- 12-byte zero placeholder in the authentication-parameters field.
        mkAuthParams :: RequestId -> PerHostV3 -> (ByteString,ScopedPduData) -> ByteString
        mkAuthParams reqId phv3 privPair = case cryptoAuth crypto of
          Nothing -> ByteString.empty
          Just (AuthParameters typ password) ->
            -- figure out a way to cache this
            let key = SnmpEncoding.passwordToKey typ password (perHostV3AuthoritativeEngineId phv3)
                serializationWithoutAuth = snd (makeBs (ByteString.replicate 12 0x00) reqId privPair phv3)
             in SnmpEncoding.mkSign typ key serializationWithoutAuth
        -- Privacy parameters (salt bytes) and the possibly encrypted
        -- scoped PDU.
        mkPrivParams :: AesSalt -> RequestId -> PerHostV3 -> (ByteString,ScopedPduData)
        mkPrivParams theSalt reqId phv3 = case crypto of
          AuthPriv (AuthParameters authType authPass) (PrivParameters privType privPass) -> case privType of
            PrivTypeAes ->
              let (encrypted,actualSaltBs) = SnmpEncoding.aesEncrypt
                    key
                    (perHostV3ReceiverBoots phv3)
                    (perHostV3ReceiverTime phv3)
                    theSalt
                    (LB.toStrict (AsnEncoding.der SnmpEncoding.scopedPdu spdu))
               in (actualSaltBs,ScopedPduDataEncrypted encrypted)
            PrivTypeDes ->
              let (encrypted,actualSaltBs) = SnmpEncoding.desEncrypt
                    key
                    (perHostV3ReceiverBoots phv3)
                    (fromIntegral (getAesSalt theSalt))
                    -- (perHostV3ReceiverTime phv3)
                    (LB.toStrict (AsnEncoding.der SnmpEncoding.scopedPdu spdu))
               in (actualSaltBs,ScopedPduDataEncrypted encrypted)
            where key = SnmpEncoding.passwordToKey authType privPass (perHostV3AuthoritativeEngineId phv3)
          _ -> (ByteString.empty,ScopedPduDataPlaintext spdu)
          where spdu = ScopedPdu (perHostV3AuthoritativeEngineId phv3) contextName (pdusFromRequestId reqId)
        makeBs :: ByteString -> RequestId -> (ByteString,ScopedPduData) -> PerHostV3 -> (MessageV3,ByteString)
        makeBs activeAuthParams reqId (activePrivParams,spdud) (PerHostV3 authoritativeEngineId receiverTime boots) =
          let myMsg = MessageV3
                (HeaderData reqId 1500 flags) -- making up a max size
                (Usm authoritativeEngineId boots receiverTime user activeAuthParams activePrivParams)
                spdud
           in (myMsg, LB.toStrict $ AsnEncoding.der SnmpEncoding.messageV3 $ myMsg)
        fullMakeBs :: AesSalt -> RequestId -> PerHostV3 -> (MessageV3, ByteString)
        fullMakeBs theSalt reqId phv3 =
          let privPair = mkPrivParams theSalt reqId phv3
              authParams = mkAuthParams reqId phv3 privPair
           in makeBs authParams reqId privPair phv3
        -- Build a brand-new message (fresh salt, fresh request ID) and send
        -- it via go1.
        sendFresh :: Int -> PerHostV3 -> Bool -> IO (Either SnmpException Pdu)
        sendFresh n1 phv3 engineIdsAcquired = do
          theSalt <- atomically $ nextSalt (sessionAesSalt session)
          requestId <- nextRequestId (sessionRequestId session)
          go1 n1 requestId phv3 (fullMakeBs theSalt requestId phv3) engineIdsAcquired
        -- Send one prepared message and wait for its reply. n1 is the number
        -- of transmissions still allowed. engineIdsAcquired records whether
        -- phv3 came from a Report rather than the placeholder.
        go1 :: Int -> RequestId -> PerHostV3 -> (MessageV3,ByteString) -> Bool -> IO (Either SnmpException Pdu)
        go1 !n1 !requestId phv3 (!sentMsg,!bsSent) !engineIdsAcquired = if n1 > 0
          then do
            when inDebugMode $ putStrLn "Sending:"
            when inDebugMode $ putStrLn (hexByteStringInternal bsSent)
            let bsLen = ByteString.length bsSent
            bytesSentLen <- NSB.sendTo sock bsSent destination
            if bytesSentLen /= bsLen
              then return $ Left $ SnmpExceptionNotAllBytesSent bytesSentLen bsLen
              else do
                let go2 :: IO (Either SnmpException Pdu)
                    go2 = do
                      isContentReady <- waitReadable sock
                      if not isContentReady
                        then do
                          when inDebugMode $ putStrLn "NO RESPONSE"
                          -- Rebuild instead of resending bsSent. Replies are matched
                          -- on the request ID inside the message, so the expected ID
                          -- and the one on the wire must agree. RFC 3412 also asks
                          -- for a new msgID on every retransmission.
                          if n1 > 1
                            then sendFresh (n1 - 1) phv3 engineIdsAcquired
                            else return $ Left $ SnmpExceptionTimeoutV3 sentMsg
                        else do
                          bsRecv <- NSB.recv sock 10000
                          when inDebugMode $ putStrLn "Received:"
                          when inDebugMode $ putStrLn (hexByteStringInternal bsRecv)
                          if ByteString.null bsRecv
                            then return (Left SnmpExceptionSocketClosed)
                            else case AsnDecoding.ber SnmpDecoding.messageV3 bsRecv of
                              Left err -> return (Left $ SnmpExceptionDecoding err)
                              Right msg -> do
                                case cryptoAuth crypto of
                                  Nothing -> return ()
                                  Just (AuthParameters typ password) -> do
                                    when inDebugMode $ putStrLn "THE RECEIVED MESSAGE"
                                    when inDebugMode $ print msg
                                    let reencoded = LB.toStrict $ AsnEncoding.der SnmpEncoding.messageV3 msg
                                    when inDebugMode $ putStrLn $ hexByteStringInternal $ reencoded
                                    -- only pay for re-encoding when debugging
                                    when (inDebugMode && reencoded /= bsRecv) $
                                      putStrLn "NOT THE SAME"
                                    let key = SnmpEncoding.passwordToKey typ password (usmAuthoritativeEngineId (messageV3SecurityParameters msg))
                                    case SnmpEncoding.checkSign typ key msg of
                                      Nothing -> return ()
                                      Just (expected,actual) -> do
                                        -- An empty signature is tolerated (presumably an
                                        -- unauthenticated Report during discovery).
                                        when (not $ ByteString.null actual) $ do
                                          throwIO $ SnmpExceptionAuthenticationFailure expected actual
                                let handleSpdu :: ScopedPdu -> IO (Either SnmpException Pdu)
                                    handleSpdu spdu = case scopedPduData spdu of
                                      -- check to make sure that we requested an unencrypted response
                                      -- somehow check the message id in here too
                                      PdusResponse pdu@(Pdu respRequestId _ _ _) ->
                                        case compare requestId respRequestId of
                                          -- stale reply to an earlier transmission: keep waiting
                                          GT -> go2
                                          EQ -> return (Right pdu)
                                          LT -> return $ Left $ SnmpExceptionMissedResponse requestId respRequestId
                                      PdusReport (Pdu respRequestId _ _ _) -> do
                                        when inDebugMode $ putStrLn $ "Expected Request ID: " ++ show requestId
                                        when inDebugMode $ putStrLn $ "Received Request ID: " ++ show respRequestId
                                        if engineIdsAcquired
                                          then return $ Left (SnmpExceptionBadEngineId sentMsg msg)
                                          else do
                                            let usm = messageV3SecurityParameters msg
                                                discovered = PerHostV3
                                                  (usmAuthoritativeEngineId usm)
                                                  (usmAuthoritativeEngineTime usm)
                                                  (usmAuthoritativeEngineBoots usm)
                                            -- Notice that n1 is not decremented in this
                                            -- situation. This is intentional.
                                            sendFresh n1 discovered True
                                      _ -> return (Left (SnmpExceptionNonPduResponseV3 msg))
                                case messageV3Data msg of
                                  ScopedPduDataEncrypted encrypted -> case crypto of
                                    AuthPriv (AuthParameters authType _) (PrivParameters privType privPass) -> do
                                      let usm = messageV3SecurityParameters msg
                                          key = SnmpEncoding.passwordToKey authType privPass (usmAuthoritativeEngineId usm)
                                          mdecrypted = case privType of
                                            PrivTypeDes -> SnmpEncoding.desDecrypt key (usmPrivacyParameters usm) encrypted
                                            PrivTypeAes -> SnmpEncoding.aesDecrypt key (usmPrivacyParameters usm) (usmAuthoritativeEngineBoots usm) (usmAuthoritativeEngineTime usm) encrypted
                                      case mdecrypted of
                                        Just bs -> case AsnDecoding.ber SnmpDecoding.scopedPdu bs of
                                          Left err -> throwIO (SnmpExceptionDecoding err)
                                          Right spdu -> handleSpdu spdu
                                        Nothing -> throwIO SnmpExceptionDecryptionFailure
                                    -- An encrypted reply without privacy credentials cannot
                                    -- be decrypted; fail explicitly rather than with an
                                    -- incomplete-pattern crash.
                                    _ -> throwIO SnmpExceptionDecryptionFailure
                                  ScopedPduDataPlaintext spdu -> handleSpdu spdu
                go2
          else return $ Left $ SnmpExceptionTimeoutV3 sentMsg

-- | Atomically bump the salt counter by one and return the new salt.
nextSalt :: TVar AesSalt -> STM AesSalt
nextSalt saltVar = do
  AesSalt current <- readTVar saltVar
  let bumped = AesSalt (current + 1)
  bumped <$ writeTVar saltVar bumped

-- | Run an action that reports failure as 'Left', rethrowing that
-- 'SnmpException' with 'throwIO' and returning the 'Right' value otherwise.
throwSnmpException :: IO (Either SnmpException a) -> IO a
throwSnmpException action = do
  result <- action
  either throwIO return result

-- | Throwing variant of 'get'': fetch the value of a single OID, throwing
-- any 'SnmpException' instead of returning it.
get :: Context -> ObjectIdentifier -> IO ObjectSyntax
get ctx = throwSnmpException . get' ctx

-- | Throwing variant of 'getBulkStep'': a single GETBULK request.
getBulkStep :: Context -> Int -> ObjectIdentifier -> IO (Vector (ObjectIdentifier,ObjectSyntax))
getBulkStep ctx maxRep = throwSnmpException . getBulkStep' ctx maxRep

-- | Throwing variant of 'getBulkChildren'': walk the subtree under an OID.
getBulkChildren :: Context -> Int -> ObjectIdentifier -> IO (Vector (ObjectIdentifier,ObjectSyntax))
getBulkChildren ctx maxRep = throwSnmpException . getBulkChildren' ctx maxRep

-- | Fetch the value bound to a single OID with a GET request.
--
-- A non-zero error status, a reply without exactly one binding for @ident@,
-- or an exception value (noSuchObject, ...) all come back as 'Left'.
get' :: Context -> ObjectIdentifier -> IO (Either SnmpException ObjectSyntax)
get' ctx ident = generalRequest buildRequest extractValue ctx
  where
    buildRequest reqId = PdusGetRequest $
      Pdu reqId (ErrorStatus 0) (ErrorIndex 0) (Vector.singleton (VarBind ident BindingResultUnspecified))
    extractValue pdu = onlyBindings pdu >>= singleBindingValue ident

-- | Issue one GETBULK request starting after @ident@, asking for up to
-- @maxRep@ repetitions (the other count field is sent as 0, presumably
-- non-repeaters). Only bindings that carry a value are returned (see
-- 'multipleBindings'); a non-zero error status comes back as 'Left'.
getBulkStep' :: Context -> Int -> ObjectIdentifier -> IO (Either SnmpException (Vector (ObjectIdentifier,ObjectSyntax)))
getBulkStep' ctx maxRep ident = generalRequest bulkRequest bindingsOf ctx
  where
    bulkRequest reqId = PdusGetBulkRequest $
      BulkPdu reqId 0 (fromIntegral maxRep) (Vector.singleton (VarBind ident BindingResultUnspecified))
    bindingsOf pdu = multipleBindings <$> onlyBindings pdu

-- | Walk the subtree rooted at @oid1@ with repeated GETBULK requests of up to
-- @maxRep@ repetitions each. Every value binding whose OID lies under @oid1@
-- (see 'oidIsPrefixOf') is collected, in the order received.
--
-- The walk stops when either:
--
-- * a step returns no bindings under @oid1@, or
-- * the agent makes no progress, i.e. the last OID it returned is the OID
--   the step started from. A misbehaving agent would otherwise loop this
--   forever while growing the result.
--
-- The first request error is returned as 'Left', and anything collected so
-- far is discarded.
getBulkChildren' :: Context -> Int -> ObjectIdentifier -> IO (Either SnmpException (Vector (ObjectIdentifier,ObjectSyntax)))
getBulkChildren' ctx maxRep oid1 = go [] oid1
  where
    -- chunks holds the batches collected so far, newest first; they are
    -- concatenated once at the end instead of appending on every step.
    go chunks ident = do
      epairsUnfiltered <- getBulkStep' ctx maxRep ident
      case epairsUnfiltered of
        Left e -> return (Left e)
        Right pairsUnfiltered -> do
          let pairs = Vector.filter (\(oid,_) -> oidIsPrefixOf oid1 oid) pairsUnfiltered
              -- only forced when pairs is non-empty (|| short-circuits)
              nextIdent = fst (Vector.last pairs)
          if Vector.null pairs || nextIdent == ident
            then return (Right (Vector.concat (reverse chunks)))
            else go (pairs : chunks) nextIdent

-- | @oidIsPrefixOf a b@ holds when the components of @a@ are exactly the
-- leading components of @b@, i.e. @b@ is @a@ itself or lies in the subtree
-- rooted at @a@.
oidIsPrefixOf :: ObjectIdentifier -> ObjectIdentifier -> Bool
oidIsPrefixOf (ObjectIdentifier prefix) (ObjectIdentifier whole)
  | Vector.length prefix > Vector.length whole = False
  | otherwise = prefix == Vector.take (Vector.length prefix) whole

-- | Keep only the bindings whose result is a 'BindingResultValue', pairing
-- each OID with its value; all other binding results are dropped.
--
-- This goes through a list because @Vector.mapMaybe@ only exists from
-- vector 0.12.0.0 onwards.
multipleBindings :: Vector VarBind -> Vector (ObjectIdentifier,ObjectSyntax)
multipleBindings bindings = Vector.fromList
  [ (ident, obj) | VarBind ident (BindingResultValue obj) <- Vector.toList bindings ]

-- | Extract the value from a binding list that must contain exactly one
-- binding, for exactly @oid@. Exception values (unspecified, noSuchObject,
-- noSuchInstance, endOfMibView) become the corresponding 'SnmpException'.
singleBindingValue :: ObjectIdentifier -> Vector VarBind -> Either SnmpException ObjectSyntax
singleBindingValue oid v = case Vector.toList v of
  [VarBind name res]
    | name /= oid -> Left (SnmpExceptionMismatchedBinding oid name)
    | otherwise -> case res of
        BindingResultValue obj -> Right obj
        BindingResultUnspecified -> Left SnmpExceptionUnspecified
        BindingResultNoSuchObject -> Left (SnmpExceptionNoSuchObject oid)
        BindingResultNoSuchInstance -> Left (SnmpExceptionNoSuchInstance oid)
        BindingResultEndOfMibView -> Left SnmpExceptionEndOfMibView
  _ -> Left (SnmpExceptionMultipleBindings (Vector.length v))

-- | The bindings of a response PDU, or 'SnmpExceptionPduError' when the
-- agent reported a non-zero error status.
onlyBindings :: Pdu -> Either SnmpException (Vector VarBind)
onlyBindings (Pdu _ status@(ErrorStatus code) index binds)
  | code == 0 = Right binds
  | otherwise = Left (SnmpExceptionPduError status index)

-- | Failures reported by this module.
--
-- The primed request functions ('get'', 'getBulkStep'', 'getBulkChildren'')
-- return these in 'Left', and the unprimed ones throw them. Note that
-- 'generalRequest' throws, rather than returns, SNMPv3 authentication
-- failures, decryption failures, and decoding failures of a decrypted
-- scoped PDU, so those escape even the primed functions.
data SnmpException
  = SnmpExceptionNotAllBytesSent !Int !Int
    -- ^ The socket accepted fewer bytes than the datagram held
    -- (bytes sent, datagram length).
  | SnmpExceptionTimeout
    -- ^ SNMPv2: no matching reply after every allowed transmission.
  | SnmpExceptionTimeoutV3 !MessageV3
    -- ^ SNMPv3: no matching reply after every allowed transmission; carries
    -- the message prepared for the last attempt.
  | SnmpExceptionPduError !ErrorStatus !ErrorIndex
    -- ^ The response PDU had a non-zero error status (status, index).
  | SnmpExceptionMultipleBindings !Int
    -- ^ Exactly one binding was expected; carries the number received.
  | SnmpExceptionMismatchedBinding !ObjectIdentifier !ObjectIdentifier
    -- ^ The single binding was for a different OID (requested, received).
  | SnmpExceptionUnspecified -- ^ Should not happen
  | SnmpExceptionNoSuchObject !ObjectIdentifier
    -- ^ The binding for this OID came back as noSuchObject.
  | SnmpExceptionNoSuchInstance !ObjectIdentifier
    -- ^ The binding for this OID came back as noSuchInstance.
  | SnmpExceptionEndOfMibView
    -- ^ The binding came back as endOfMibView.
  | SnmpExceptionMissedResponse !RequestId !RequestId
    -- ^ A reply arrived with a higher request ID than the one being waited
    -- for (expected, received).
  | SnmpExceptionNonPduResponseV2 !MessageV2
    -- ^ SNMPv2 reply that did not carry a Response PDU.
  | SnmpExceptionNonPduResponseV3 !MessageV3
    -- ^ SNMPv3 reply that carried neither a Response nor a Report PDU.
  | SnmpExceptionDecoding !String
    -- ^ BER decoding of received bytes failed (the decoder's message).
  | SnmpExceptionSocketClosed
    -- ^ Receiving from the socket returned no bytes.
  | SnmpExceptionAuthenticationFailure !ByteString !ByteString
    -- ^ The reply's signature did not verify (expected, actual).
  | SnmpExceptionBadEngineId !MessageV3 !MessageV3
    -- ^ A Report arrived even though engine discovery had already been done
    -- (message sent, message received).
  | SnmpExceptionDecryptionFailure
    -- ^ An encrypted scoped PDU could not be decrypted.
  deriving (Show,Eq)

-- | Lets 'SnmpException' be thrown with 'throwIO' and caught as an exception.
instance Exception SnmpException

-- | Read (without taking) the contents of a 'TMVar'. Gives up with 'Nothing'
-- once @timeoutAfter@ microseconds have passed without the variable being
-- filled.
readTMVarTimeout :: Int -> TMVar a -> IO (Maybe a)
readTMVarTimeout timeoutAfter pktChannel = do
  timer <- registerDelay timeoutAfter
  let fromVar = Just <$> readTMVar pktChannel
      onTimeout = fini timer >> pure Nothing
  atomically (fromVar <|> onTimeout)

-- | Succeed once the flag is 'True' and retry the transaction while it is
-- 'False'. Used with the flags from 'registerDelay' to wait for a timeout
-- inside STM.
fini :: TVar Bool -> STM ()
fini flagVar = readTVar flagVar >>= check

-- | Atomically advance the request-ID counter and return the new ID.
--
-- IDs cycle through 1..99999999: after 99999999 comes 1, and 0 is never
-- issued.
nextRequestId :: TVar RequestId -> IO RequestId
nextRequestId requestIdVar = atomically $ do
  RequestId current <- readTVar requestIdVar
  let !wrapped = (current + 1) `mod` 100000000
      !nonZero = if wrapped == 0 then 1 else wrapped
      issued = RequestId nonZero
  writeTVar requestIdVar issued
  return issued

-- | The raw file descriptor behind a socket, for 'threadWaitReadSTM'.
-- NOTE(review): this depends on the 'NS.MkSocket' constructor, which newer
-- versions of the network package no longer export; confirm when upgrading.
mySockFd :: NS.Socket -> System.Posix.Types.Fd
mySockFd sock = case sock of
  NS.MkSocket rawFd _ _ _ _ -> System.Posix.Types.Fd rawFd

-- | Render bytes as contiguous upper-case hex, two digits per byte
-- (e.g. the bytes @[0x0a,0xff]@ become @"0AFF"@). In this module it is only
-- used for debug output.
hexByteStringInternal :: ByteString -> String
hexByteStringInternal bytes = concat [printf "%02X" byte | byte <- ByteString.unpack bytes]

-- | Hard-coded switch for the diagnostic output in 'generalRequest'
-- (hex dumps and prints of sent and received messages). Change it to 'True'
-- and rebuild to enable that output.
inDebugMode :: Bool
inDebugMode = False