module Data.Radius.Implements ( signPacket, signedPacket, AuthenticatorError (..), checkSignedRequest, checkSignedResponse, ) where import Control.Monad (unless) import Data.Monoid ((<>)) import Data.Word (Word16) import Data.ByteString (ByteString) import qualified Data.ByteString as BS import Data.Serialize.Put (Put, runPut) import qualified Data.ByteArray as BA import Crypto.Hash (Digest, hash, MD5) import Crypto.MAC.HMAC (HMAC, hmac, hmacGetDigest) import Data.Radius.Packet (Code (..), Header (..), Packet (..)) import Data.Radius.Scalar (AtString (..), Bin128, mayBin128, fromBin128, bin128Zero) import Data.Radius.Attribute (Number (MessageAuthenticator), messageAuthenticator, NumberAbstract (Standard), Attribute' (Attribute'), TypedNumberSets, ) import Data.Radius.StreamGet (Attributes) import qualified Data.Radius.StreamGet as Get import qualified Data.Radius.StreamPut as Put hmacMD5 :: ByteString -> ByteString -> Bin128 hmacMD5 rsk bs = maybe (error "hmacMD5: BUG? Invalid result length") id . mayBin128 . BA.convert $ hmacGetDigest (hmac rsk bs :: HMAC MD5) md5 :: ByteString -> Bin128 md5 bs = maybe (error "md5: BUG? Invalid result length") id . mayBin128 $ BA.convert (hash bs :: Digest MD5) -- | Make signatures for response packet. -- When you don't want to use message authenticator attribute, -- pass a function to make attributes which doesn't use message authenticator argument. signPacket :: (a -> ByteString -> Put) -- ^ Printer for vendor specific attribute -> ByteString -- ^ Radius secret key -> Bin128 -- ^ Request authenticator -> (Word16 -> Bin128 -> Header) -- ^ Function to make header -> (Bin128 -> [Attribute' a]) -- ^ Function to make attributes from message authenticator -> (Word16, Bin128, Bin128) -- ^ Packet length, message authenticator and response authenticator signPacket va rsk auth mkH mkA = (len, msgAuth, respAuth) where asMsgAuth0 = mkA bin128Zero pput = runPut . Put.upacket va len = fromIntegral . BS.length . pput $ Packet { header = mkH 0 auth, attributes = asMsgAuth0 } msgAuth = hmacMD5 rsk . pput $ Packet { header = mkH len auth, attributes = asMsgAuth0 } respAuth = md5 $ (pput $ Packet { header = mkH len auth, attributes = mkA msgAuth }) <> rsk signedPacket :: (a -> ByteString -> Put) -- ^ Printer for vendor specific attribute -> ByteString -- ^ Radius secret key -> Bin128 -- ^ Request authenticator -> (Word16 -> Bin128 -> Header) -- ^ Function to make header -> (Bin128 -> [Attribute' a]) -- ^ Function to make attributes from message authenticator -> Packet [Attribute' a] -- ^ Signed packet signedPacket va rsk auth mkH mkA = case code $ mkH len auth of AccessAccept -> response AccessReject -> response AccessChallenge -> response AccessRequest -> other Other _ -> other where (len, msgAuth, respAuth) = signPacket va rsk auth mkH mkA response = Packet { header = mkH len respAuth, attributes = mkA msgAuth } other = Packet { header = mkH len auth , attributes = mkA msgAuth } data AuthenticatorError v = NoMessageAuthenticator (Attributes v) -- ^ No Message-Authenticator attribute | BadMessageAuthenticator -- ^ Message-Authenticator attribute is not matched | MoreThanOneMessageAuthenticator -- ^ More than one Message-Authenticator attribute pairs found | BadAuthenticator -- ^ Radius packet authenticator is not matched | AttributesDecodeError String -- ^ Fail to decode attributes, attribute type error etc. | NotRequestPacket Code -- ^ Not request packet is passed to function to check request packet | NotResponsePacket Code -- ^ Not response packet is passed to function to check response packet instance Show (AuthenticatorError v) where show = d where d (NoMessageAuthenticator _) = "no messageAuthenticator found" d BadMessageAuthenticator = "bad messageAuthenticator" d MoreThanOneMessageAuthenticator = "more than one messageAuthenticator found" d BadAuthenticator = "bad radius packet authenticator" d (AttributesDecodeError s) = "fail to decode attributes: " ++ s d (NotRequestPacket c) = "not request packet: code: " ++ show c d (NotResponsePacket c) = "not response packet: code: " ++ show c checkSignedRequest :: (TypedNumberSets a, Ord a) => (a -> ByteString -> Put) -- ^ Printer for vendor specific attribute -> ByteString -> Packet [Attribute' a] -> Either (AuthenticatorError a) (Attributes a) checkSignedRequest va rsk upkt = case code $ header upkt of c@AccessAccept -> notRequestCode c c@AccessReject -> notRequestCode c c@AccessChallenge -> notRequestCode c AccessRequest -> check Other _ -> check where notRequestCode = Left . NotRequestPacket check = checkMA calcMsgAuth $ attrs attrs = attributes upkt calcMsgAuth = hmacMD5 rsk . runPut . Put.upacket va $ upkt { attributes = replace0MA attrs } checkSignedResponse :: (TypedNumberSets a, Ord a) => (a -> ByteString -> Put) -- ^ Printer for vendor specific attribute -> ByteString -> Bin128 -> Packet [Attribute' a] -> Either (AuthenticatorError a) (Attributes a) checkSignedResponse va rsk reqAuth upkt = case code $ header upkt of AccessAccept -> check AccessReject -> check AccessChallenge -> check c@AccessRequest -> notResponseCode c c@(Other _) -> notResponseCode c where notResponseCode = Left . NotResponsePacket check = do unless (authenticator (header upkt) == calcRespAuth) $ Left BadAuthenticator checkMA calcMsgAuth attrs attrs = attributes upkt calcRespAuth = md5 $ (runPut . Put.upacket va $ upkt { header = (header upkt) { authenticator = reqAuth } }) <> rsk calcMsgAuth = hmacMD5 rsk . runPut . Put.upacket va $ upkt { header = (header upkt) { authenticator = reqAuth } , attributes = replace0MA attrs } checkMA :: (TypedNumberSets a, Ord a) => Bin128 -> [Attribute' a] -> Either (AuthenticatorError a) (Attributes a) checkMA calcMsgAuth attrs = do ta <- either (Left . AttributesDecodeError) return . Get.extractAttributes $ mapM Get.tellT attrs case Get.takeTyped ta messageAuthenticator of [] -> Left $ NoMessageAuthenticator ta [AtString bs] -> do unless (bs == fromBin128 calcMsgAuth) $ Left BadMessageAuthenticator return ta _:_:_ -> Left MoreThanOneMessageAuthenticator replace0MA :: [Attribute' a] -> [Attribute' a] replace0MA = rec' where rec' [] = [] rec' (Attribute' n@(Standard MessageAuthenticator) _ : xs) = Attribute' n (fromBin128 bin128Zero) : rec' xs rec' (x : xs) = x : rec' xs