{-# LANGUAGE NamedFieldPuns, RecordWildCards, OverloadedStrings, ScopedTypeVariables #-}
{-|
Module      : Network.EAP.Authentication
Description : Collection of EAP authentication methods and related utility functions
Copyright   : (c) Erick Gonzalez, 2017
License     : BSD3
Maintainer  : erick@codemonkeylabs.de
Stability   : experimental
Portability : POSIX

This module provides functions implementing EAP authentication schemes. New authentication types
will be added as needed and contributions are very welcome.

-}
module Network.EAP.Authentication (authenticateMSCHAPv2,
                                   generateNTResponse) where

import Prelude hiding (concatMap)
import Data.Bits                 ((.|.), (.&.), complement, popCount, shiftL, shiftR, xor)

import Control.Monad.Except      (ExceptT(..), Except, throwError)
import Control.Monad.State.Lazy  (execState, modify)
import Crypto.Cipher.DES         (DES)
import Crypto.Cipher.Types       (cipherInit, ecbEncrypt)
import Crypto.Error              (CryptoError, CryptoFailable(..))
import Crypto.Hash               (Digest, hashFinalize, hashInitWith, hashUpdate, hashlazy)
import Crypto.Hash.Algorithms    (MD4, SHA1(..))
import Data.ByteArray            (constEq, convert)
import Data.ByteString.Lazy      (ByteString, cons, concatMap, toStrict)
import qualified Data.ByteString     as SB

import Network.EAP.Types

-- | Authenticate the MSCHAPv2 response data to a given challenge request, using the supplied
-- cleartext password.
--
-- The expected NT-Response is recomputed with 'generateNTResponse'. Its inputs are the
-- authenticator challenge, the peer challenge and user name carried in the response, and the
-- password. The result is compared with the NT-Response sent by the peer. The comparison is
-- done with 'constEq', so its running time does not depend on how many leading octets match.
--
-- Only the 'MSCHAPv2ResponseData' constructor is handled; any other value aborts with 'error'.
-- NOTE(review): if callers can pass peer-controlled non-response packets here, that branch
-- can crash the caller. Confirm that callers filter by opcode first.
authenticateMSCHAPv2
    :: MSCHAPv2Data -- ^ Decoded data from the MSCHAPv2 response
     -> ByteString   -- ^ Authenticator challenge sent to the peer on a previous request
     -> ByteString   -- ^ Authenticating user password
     -> Except CryptoError Bool -- ^ Returns either an error from one of the encryption
                               -- routines or a boolean indicating whether the user
                               -- response matches the expected value
authenticateMSCHAPv2
  MSCHAPv2ResponseData{ getMSCHAPv2ResponseData = MSCHAPv2ResponseDataField{..}, .. }
  challenge
  password = do
  expected <- generateNTResponse challenge
                                 getMSCHAPv2ResponsePeerChallenge
                                 getMSCHAPv2ResponseName
                                 password
  -- Constant-time comparison of the secret-derived value. Like '==', constEq yields False
  -- when the lengths differ.
  return $ expected `constEq` toStrict getMSCHAPv2ResponseNTResponse

authenticateMSCHAPv2 msCHAPv2Data _ _ =
  error $ "Invalid authentication attempt of " ++ show msCHAPv2Data

-- | Compute the NT-Response defined by [RFC2759], Section 8.1 (GenerateNTResponse).
--
-- The computation mirrors the RFC helpers:
--
--   * ChallengeHash: SHA1 over the peer challenge, the authenticator challenge and the user
--     name, in that order, truncated to its first 8 octets.
--
--   * NtPasswordHash: MD4 over the password with a zero octet placed after every octet.
--
--   * ChallengeResponse: the password hash is padded with 5 zero octets and cut into three
--     7 octet DES keys. Each key encrypts the 8 octet challenge, and the three cipher
--     blocks are concatenated into the response.
generateNTResponse :: ByteString -- ^ Authenticator challenge sent to the peer on a previous
                                 -- request
                   ->  ByteString -- ^ Challenge sent back by authenticating peer
                   ->  ByteString -- ^ MSCHAP username
                   ->  ByteString -- ^ Cleartext user password
                   ->  Except CryptoError SB.ByteString -- ^ Returns either an error from one of
                                                       -- the encryption routines or the
                                                       -- calculated NT response
generateNTResponse authenticatorChallenge peerChallenge username password =
    SB.concat <$> mapM (`encryptDES` challenge) [desKey0, desKey1, desKey2]
  where
    -- ChallengeHash: the first 8 octets of SHA1(peer ++ authenticator ++ username)
    challenge = SB.take 8 . convert . hashFinalize $
                    hashInitWith SHA1
                        `hashUpdate` toStrict peerChallenge
                        `hashUpdate` toStrict authenticatorChallenge
                        `hashUpdate` toStrict username
    -- NtPasswordHash: MD4 over the password with a zero octet after every octet.
    -- NOTE(review): this is UTF-16LE only for ASCII passwords. Confirm that callers never pass
    -- non-ASCII passwords, or that they re-encode them first.
    passwordHash = convert (hashlazy (concatMap (`cons` "\NUL") password) :: Digest MD4)
    -- ChallengeResponse keys: the hash zero padded to 21 octets, cut into 7 octet slices
    (desKey0, keys12) = SB.splitAt 7 (passwordHash `SB.append` SB.replicate 5 0)
    (desKey1, desKey2) = SB.splitAt 7 keys12

-- | Used internally to encrypt a message with a DES cipher in ECB mode.
--
-- The supplied key is passed through 'addParity' before the cipher is initialised.
-- 'generateNTResponse' hands in 7 octet keys. If 'cipherInit' reports a failure, that
-- 'CryptoError' is raised in the 'Except' monad; otherwise the ECB ciphertext is returned.
encryptDES :: SB.ByteString -> SB.ByteString -> Except CryptoError SB.ByteString
encryptDES key msg =
  case cipherInit (addParity key) of
    CryptoFailed err          -> throwError err
    CryptoPassed (des :: DES) -> return (ecbEncrypt des msg)

-- | Used internally to expand a 56 bit (7 octet) DES key into the 8 octet form taken by
-- the DES cipher.
--
-- The key bits are regrouped into eight 7 bit groups. Each group is stored in the high seven
-- bits of an output octet. The lowest bit of every output octet, including the last one, is
-- then set so that the octet holds an odd number of set bits.
addParity :: SB.ByteString -> SB.ByteString
addParity = SB.map withOddParity . flush . SB.foldl' step ((0, 0), SB.empty)
  where
    -- Output octet i takes the i bits carried over from the previous input octet, followed
    -- by the leading bits of input octet i. The low (i + 1) bits of input octet i are
    -- carried on to the next output octet.
    step ((i, carry), acc) word =
        let keyBits = carry .|. (word `shiftR` i)
        in ((i + 1, word `shiftL` (7 - i)), acc `SB.snoc` keyBits)
    -- The bits still carried after the last input octet form the final output octet.
    flush ((_, carry), acc) = acc `SB.snoc` carry
    -- Overwrite the lowest bit so the octet ends up with odd parity.
    withOddParity octet =
        let bits = octet .&. 0xfe
        in if even (popCount bits) then bits .|. 1 else bits