{-|
 - Maintainer: Thomas.DuBuisson@gmail.com
 - Stability: beta
 - Portability: portable 
 -
 - PKCS5 (RFC 1423) and IPSec ESP (RFC 4303) padding methods are implemented both
 - as trivial functions operating on bytestrings and as 'Put' routines usable from
 - the "Data.Serialize" module.  These methods do not work for algorithms or pad
 - sizes in excess of 255 bytes (2040 bits, so extremely large as far as cipher needs
 - are concerned).
 -}

module Crypto.Padding
	(
	-- * PKCS5 (RFC 1423) based [un]padding routines
	  padPKCS5
	, padBlockSize
	, putPaddedPKCS5
	, unpadPKCS5safe
	, unpadPKCS5
	-- * ESP (RFC 4303) [un]padding routines
	, padESP, unpadESP
	, padESPBlockSize
	, putPadESPBlockSize, putPadESP
	) where

import Data.Serialize.Put
import Crypto.Classes
import Crypto.Types
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L

-- | Pad a strict bytestring with the PKCS5 (RFC 1423) scheme so that its
-- length becomes a multiple of the given byte length.
--
-- This simply runs 'putPaddedPKCS5' and collects the output.  The pad value
-- is stored in a single byte, so pad moduli above 255 are not supported.
padPKCS5 :: ByteLength -> B.ByteString -> B.ByteString
padPKCS5 modulus = runPut . putPaddedPKCS5 modulus

-- |
-- @
--     putPaddedPKCS5 m bs
-- @
--
-- Puts @bs@ followed by PKCS5 padding via 'Put', so that the total output
-- is a multiple of @m@ bytes (this saves on copying if you are already
-- using Cereal).
--
-- Padding is always emitted: between 1 and @m@ bytes, each holding the pad
-- length.  When @m@ is 0 the bytestring is followed by a single byte of
-- value 1.
putPaddedPKCS5 :: ByteLength -> B.ByteString -> Put
putPaddedPKCS5 0 bs = putByteString bs >> putWord8 1
putPaddedPKCS5 len bs = do
    putByteString bs
    putByteString (B.replicate padLen (fromIntegral padLen))
  where
    -- The remainder is strictly smaller than @len@, so for a positive @len@
    -- this lies in 1..len; an already aligned input gets a full extra block.
    padLen = len - (B.length bs `rem` len)

-- | PKCS5 (RFC 1423) padding where the pad modulus is taken from the
-- 'BlockCipher' instance of the supplied key rather than passed explicitly.
padBlockSize :: BlockCipher k => k -> B.ByteString -> B.ByteString
padBlockSize k bs = runPut (putPaddedBlockSize k bs)

-- | 'Put' the bytestring followed by PKCS5 padding, delegating to
-- 'putPaddedPKCS5' with the block size (in bytes) of the given cipher.
putPaddedBlockSize :: BlockCipher k => k -> B.ByteString -> Put
putPaddedBlockSize k = putPaddedPKCS5 (blockSizeBytes `for` k)

-- | Unpad a strict bytestring padded in the typical PKCS5 manner.
--
-- The final byte is read as the pad length @n@.  The result is 'Just' the
-- message with the last @n@ bytes removed only when all of the following
-- hold:
--
-- * the input is non-empty,
--
-- * @n@ is at least 1 (PKCS5 never emits a zero pad byte),
--
-- * @n@ does not exceed the input length, and
--
-- * every one of the last @n@ bytes equals @n@.
--
-- Otherwise 'Nothing' is returned.
unpadPKCS5safe :: B.ByteString -> Maybe B.ByteString
unpadPKCS5safe bs
  | B.null bs                  = Nothing
  | pLen == 0 || pLen > bsLen  = Nothing
  | B.all (== padByte) pad     = Just msg
  | otherwise                  = Nothing
  where
    bsLen   = B.length bs
    -- Only forced after the emptiness guard has failed, so 'B.last' is safe.
    padByte = B.last bs
    pLen    = fromIntegral padByte
    -- With 1 <= pLen <= bsLen the split index is non-negative and @pad@ is
    -- exactly @pLen@ bytes long.
    (msg, pad) = B.splitAt (bsLen - pLen) bs

-- | Unpad a strict bytestring without verifying the pad bytes.
--
-- The final byte is taken as the pad length and that many bytes are dropped
-- from the end.  An empty input is returned unchanged; if the final byte
-- claims more bytes than the input holds, the result is empty.
unpadPKCS5 :: B.ByteString -> B.ByteString
unpadPKCS5 bs
  | B.null bs = bs
  | otherwise = B.take keep bs
  where
    -- Number of message bytes left once the declared padding is removed;
    -- a negative count makes 'B.take' yield the empty bytestring.
    keep = B.length bs - fromIntegral (B.last bs)

-- | Pad a bytestring following the IPSec ESP (RFC 4303) layout.
--
-- > padESP m payload
--
-- is equivalent to:
--
-- @
--               (msg)       (padding)       (length field)
--     B.concat [payload, B.pack [1,2,3,4..], B.pack [padLen]]
-- @
--
-- Where:
--
-- * the msg is any payload, including TFC.
--
-- * the padding is at most 255 bytes.
--
-- * the length field is one byte.
--
--  The length of the resulting bytestring is a multiple of @m@.  There is
--  no \"next header\" field, so this function is not directly usable for an
--  IPSec implementation (copy 'putPadESP' and add a \"next header\" field
--  if you are making IPSec ESP).
padESP :: Int -> B.ByteString -> B.ByteString
padESP m = runPut . putPadESP m

-- | Like 'padESP', but the pad modulus comes from the 'BlockCipher'
-- instance of the supplied key.
padESPBlockSize :: BlockCipher k => k -> B.ByteString -> B.ByteString
padESPBlockSize k = runPut . putPadESPBlockSize k

-- | Like 'putPadESP', but the pad modulus is the block size (in bytes) of
-- the cipher identified by the supplied key.
putPadESPBlockSize :: BlockCipher k => k -> B.ByteString -> Put
putPadESPBlockSize k = putPadESP (blockSizeBytes `for` k)

-- | Pad a bytestring to the IPSec ESP layout using 'Put'.
-- This can reduce copying if you are already using 'Put'.
--
-- The output is the payload, then the leading bytes of 'espPad'
-- (@1, 2, 3, ...@), then one byte holding the pad length.  A modulus of 0
-- emits the payload followed by a zero length byte and no padding.
putPadESP :: Int -> B.ByteString -> Put
putPadESP 0 bs = putByteString bs >> putWord8 0
putPadESP l bs =
    putByteString bs
      >> putByteString (B.take padLen espPad)
      >> putWord8 (fromIntegral padLen)
  where
    -- The @+ 1@ accounts for the trailing length byte.  When payload plus
    -- length byte is already a multiple of @l@ this evaluates to @l@, not 0.
    padLen = l - ((B.length bs + 1) `rem` l)

-- | The bytes @1, 2, ..., 255@ packed once, so that every call to
-- 'putPadESP' and 'unpadESP' can take a prefix of the same shared pad
-- instead of packing a fresh one.
espPad :: B.ByteString
espPad = B.pack [1 .. 255]

-- | Remove ESP padding and return the original message.
--
-- The final byte is read as the pad length @n@; the @n@ bytes before it
-- must equal the first @n@ bytes of 'espPad' (@1, 2, ..., n@).  'Nothing'
-- is returned when the input is empty, when it is too short to hold @n@
-- pad bytes plus the length byte, or when the pad bytes do not match.
unpadESP :: B.ByteString -> Maybe B.ByteString
unpadESP bs
  | B.null bs                             = Nothing
  | trailerLen > bsLen                    = Nothing
  | B.take pLen pad == B.take pLen espPad = Just msg
  | otherwise                             = Nothing
  where
    bsLen      = B.length bs
    -- Only forced after the emptiness guard has failed, so 'B.last' is safe.
    pLen       = fromIntegral (B.last bs)
    -- Padding bytes plus the one-byte length field.
    trailerLen = pLen + 1
    -- The bounds guard keeps the split index non-negative; without it a
    -- short input such as @[1]@ would be compared whole against the
    -- expected pad and wrongly accepted.
    (msg, pad) = B.splitAt (bsLen - trailerLen) bs