-- |
-- Module      : Crypto.Data.Padding
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Various cryptographic padding schemes commonly used with block ciphers
-- or asymmetric systems.
--
module Crypto.Data.Padding
    ( Format(..)
    , pad
    , unpad
    ) where

import           Data.ByteArray (ByteArray, Bytes)
import qualified Data.ByteArray as B

-- | Format of padding
data Format =
      PKCS5     -- ^ PKCS5: PKCS7 with a hardcoded block size of 8
    | PKCS7 Int -- ^ PKCS7 with the given block size, which must be between 1 and 255
    | ZERO Int  -- ^ zero padding up to a multiple of the given block size
    deriving (Show, Eq)

-- | Apply a padding 'Format' to a bytearray.
--
-- * 'PKCS5' is exactly @'PKCS7' 8@.
--
-- * @'PKCS7' sz@ appends @n@ bytes, each with value @n@. Here @n@ is @sz@
--   minus the input length modulo @sz@, so @n@ lies between 1 and @sz@.
--   At least one byte is always added; an already-aligned input gets a
--   whole extra block. The byte value is @fromIntegral n@, which is
--   truncated to a byte, so @sz@ must be between 1 and 255 for the padding
--   to encode @n@ faithfully. @sz == 0@ fails with a division by zero.
--
-- * @'ZERO' sz@ appends zero bytes up to the next multiple of @sz@. A
--   non-empty input that is already aligned is returned unchanged. An empty
--   input gets a full block of @sz@ zero bytes.
pad :: ByteArray byteArray => Format -> byteArray -> byteArray
pad PKCS5 bin = pad (PKCS7 8) bin
pad (PKCS7 sz) bin = bin `B.append` paddingString
  where
    -- paddingByte copies of the byte paddingByte. fromIntegral narrows the
    -- Int to a byte, hence the 1..255 restriction on sz.
    paddingString = B.replicate paddingByte (fromIntegral paddingByte)
    paddingByte :: Int
    paddingByte = sz - (B.length bin `mod` sz)
pad (ZERO sz) bin = bin `B.append` paddingString
  where
    paddingString = B.replicate paddingSz 0
    paddingSz :: Int
    paddingSz
      | len == 0  = sz      -- empty input still gets one full block
      | m == 0    = 0       -- already aligned: nothing to add
      | otherwise = sz - m
    m :: Int
    m = len `mod` sz
    len :: Int
    len = B.length bin

-- | Try to remove some padding from a bytearray.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the block size is not positive;
-- * the input is empty;
-- * the input length is not a multiple of the block size;
-- * the trailing bytes are not acceptable for the chosen format.
--
-- The formats behave as follows:
--
-- * 'PKCS5' is exactly @'PKCS7' 8@.
--
-- * @'PKCS7' sz@ reads the last byte as the padding length @n@. It requires
--   @n@ to be between 1 and @sz@ and the final @n@ bytes to all equal @n@.
--   It then returns the input without those bytes.
--
-- * @'ZERO' sz@ removes nothing. An aligned input whose last byte is
--   non-zero is returned unchanged, and an input ending in a zero byte
--   yields 'Nothing'. Consequently it rejects most outputs of
--   @'pad' ('ZERO' sz)@. NOTE(review): confirm this refusal is intended.
unpad :: ByteArray byteArray => Format -> byteArray -> Maybe byteArray
unpad PKCS5 bin = unpad (PKCS7 8) bin
unpad (PKCS7 sz) bin
    | sz < 1                             = Nothing  -- avoids `mod` by zero below
    | len == 0                           = Nothing
    | (len `mod` sz) /= 0                = Nothing
    -- 'pad' never emits more than one block of padding, so a length above
    -- sz is malformed. len is a positive multiple of sz here, so this bound
    -- also guarantees paddingSz <= len.
    | paddingSz < 1 || paddingSz > sz    = Nothing
    -- B.constEq is used rather than (==); memory documents it as a
    -- constant-time comparison.
    | paddingWitness `B.constEq` padding = Just content
    | otherwise                          = Nothing
  where
    len :: Int
    len = B.length bin
    paddingByte = B.index bin (len - 1)
    paddingSz :: Int
    paddingSz = fromIntegral paddingByte
    (content, padding) = B.splitAt (len - paddingSz) bin
    -- The padding we expect to see: paddingSz copies of paddingByte.
    paddingWitness = B.replicate paddingSz paddingByte :: Bytes
unpad (ZERO sz) bin
    | sz < 1                     = Nothing  -- avoids `mod` by zero below
    | len == 0                   = Nothing
    | (len `mod` sz) /= 0        = Nothing
    | B.index bin (len - 1) /= 0 = Just bin
    | otherwise                  = Nothing
  where
    len :: Int
    len = B.length bin