-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Crypto.RC4
-- Copyright : (c) Austin Seipp
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- An implementation of RC4 (AKA Rivest Cipher 4 or Alleged RC4/ARC4),
-- using SBV. For information on RC4, see: <http://en.wikipedia.org/wiki/RC4>.
--
-- We make no effort to optimize the code, and instead focus on a clear
-- implementation. In fact, the RC4 algorithm relies on in-place update of
-- its state heavily for efficiency, and is therefore unsuitable for a purely
-- functional implementation.
-----------------------------------------------------------------------------

{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Crypto.RC4 where

import Data.Char  (ord, chr)
import Data.List  (genericIndex)
import Data.Maybe (fromJust)
import Data.SBV

import Data.SBV.Tools.STree

import Numeric (showHex)

-----------------------------------------------------------------------------
-- * Types
-----------------------------------------------------------------------------

-- | The RC4 state: an array holding 256 8-bit values. It is represented with
-- the symbolically indexable full binary tree 'STree', because RC4 reads and
-- writes the array at symbolic indices and access time must be kept small.
type S = STree Word8 Word8

-- | The initial, fully balanced state tree. Its leaves are every 'Word8'
-- value in increasing order, i.e. the numbers @0@ through @255@.
initS :: S
initS = mkSTree [literal w | w <- [minBound .. maxBound]]

-- | A key is a sequence of 8-bit symbolic words.
type Key = [SWord8]

-- | The running state of the RC4 stream: the @S@ array together with the
-- @i@ and @j@ indices that the PRGA threads from one step to the next.
type RC4 = (S, SWord8, SWord8)

-----------------------------------------------------------------------------
-- * The PRGA
-----------------------------------------------------------------------------

-- | Exchanges the entries at indices @i@ and @j@ of the RC4 array.
--
-- Both entries are read from the incoming state before any write happens. As a
-- result, the array is left unchanged when @i@ and @j@ denote the same index.
swap :: SWord8 -> SWord8 -> S -> S
swap i j st = writeSTree afterFirst j atI
  where atI        = readSTree st i
        atJ        = readSTree st j
        -- Put the old @S[j]@ at position @i@; position @j@ is written last.
        afterFirst = writeSTree st i atJ

-- | One step of the RC4 pseudo-random generation algorithm (PRGA).
--
-- The step takes the current state @(S, i, j)@ and returns the next keystream
-- byte paired with the updated state. All arithmetic is 'SWord8' arithmetic,
-- i.e. modulo 256.
--
--   * @i := i + 1@, then @j := j + S[i]@, using the array before the swap;
--
--   * @S[i]@ and @S[j]@ are swapped;
--
--   * the output byte is @S[S[i] + S[j]]@, read from the swapped array.
prga :: RC4 -> (SWord8, RC4)
prga (oldS, oldI, oldJ) = (outByte, (newS, newI, newJ))
  where newI    = oldI + 1
        newJ    = oldJ + readSTree oldS newI
        newS    = swap newI newJ oldS
        outByte = readSTree newS (readSTree newS newI + readSTree newS newJ)

-----------------------------------------------------------------------------
-- * Key schedule
-----------------------------------------------------------------------------

-- | Constructs the state to be used by the PRGA from the given key, i.e. the
-- RC4 key-scheduling algorithm (KSA).
--
-- Starting from 'initS' and @j = 0@, it runs 256 steps for @i = 0 .. 255@.
-- Each step sets @j := j + S[i] + key[i mod keyLength]@ (modulo 256) and then
-- swaps @S[i]@ with @S[j]@.
--
-- The key must contain between 1 and 256 bytes, inclusive. Any other length
-- results in a call to 'error'.
initRC4 :: Key -> S
initRC4 key
 | keyLength < 1 || keyLength > 256
 = error $ "RC4 requires a key of length between 1 and 256, received: " ++ show keyLength
 | otherwise
 = snd $ foldl mix (0, initS) [0..255]
 where keyLength = length key

       -- The key byte used at step @i@. The loop index always comes from the
       -- literal list @[0..255]@, so 'unliteral' cannot return 'Nothing' here.
       -- The position is computed in 'Int' on purpose: converting @keyLength@
       -- to 'Word8' would wrap a 256-byte key's length to 0, and @mod@ would
       -- then fail with a divide-by-zero exception.
       keyByte :: SWord8 -> SWord8
       keyByte i = genericIndex key (fromIntegral (fromJust (unliteral i)) `mod` keyLength)

       -- One KSA step: accumulate @j@ and swap @S[i]@ with @S[j]@.
       mix :: (SWord8, S) -> SWord8 -> (SWord8, S)
       mix (j', s) i = let j = j' + readSTree s i + keyByte i
                       in (j, swap i j s)

-- | The keystream produced by RC4 for the given key.
--
-- The state is initialised with 'initRC4' and @i = j = 0@. Each 'prga' step
-- then contributes one byte. The result is an infinite list, so consumers
-- must take only as much as they need.
keySchedule :: Key -> [SWord8]
keySchedule key = go (initRC4 key, 0, 0)
  where go :: RC4 -> [SWord8]
        go cur = kbyte : go nxt
          where (kbyte, nxt) = prga cur

-- | Like 'keySchedule', but the key is given as a 'String'.
--
-- Each character becomes the literal byte @fromIntegral (ord c)@. Characters
-- whose code point exceeds 255 are therefore truncated to their low 8 bits.
keyScheduleString :: String -> [SWord8]
keyScheduleString str = keySchedule [literal (fromIntegral (ord c)) | c <- str]

-----------------------------------------------------------------------------
-- * Encryption and Decryption
-----------------------------------------------------------------------------

-- | RC4 encryption. We generate keystream bytes and xor them with the input.
-- The result has the same length as the plaintext. The following test-vectors
-- are from Wikipedia <http://en.wikipedia.org/wiki/RC4>:
--
-- >>> concatMap hex2 $ encrypt "Key" "Plaintext"
-- "bbf316e8d940af0ad3"
--
-- >>> concatMap hex2 $ encrypt "Wiki" "pedia"
-- "1021bf0420"
--
-- >>> concatMap hex2 $ encrypt "Secret" "Attack at dawn"
-- "45a01f645fc35b383552544b9bf5"
encrypt :: String -> String -> [SWord8]
encrypt key pt = [ ks `xor` literal (fromIntegral (ord c))
                 | (ks, c) <- zip (keyScheduleString key) pt
                 ]

-- | RC4 decryption. RC4 xors a keystream with its input, so decryption is
-- essentially the same operation as encryption. For the above test vectors we
-- have:
--
-- >>> decrypt "Key" [0xbb, 0xf3, 0x16, 0xe8, 0xd9, 0x40, 0xaf, 0x0a, 0xd3]
-- "Plaintext"
--
-- >>> decrypt "Wiki" [0x10, 0x21, 0xbf, 0x04, 0x20]
-- "pedia"
--
-- >>> decrypt "Secret" [0x45, 0xa0, 0x1f, 0x64, 0x5f, 0xc3, 0x5b, 0x38, 0x35, 0x52, 0x54, 0x4b, 0x9b, 0xf5]
-- "Attack at dawn"
--
-- The result is a plain 'String', so every decrypted byte must be concrete.
-- A symbolic ciphertext byte causes a call to 'error' when the corresponding
-- character is demanded.
decrypt :: String -> [SWord8] -> String
decrypt key ct = map cvt $ zipWith xor (keyScheduleString key) ct
  where cvt :: SWord8 -> Char
        cvt w = case unliteral w of
                  Just b  -> chr (fromIntegral b)
                  Nothing -> error "decrypt: ciphertext must consist of concrete bytes to produce a String"

-----------------------------------------------------------------------------
-- * Verification
-----------------------------------------------------------------------------

-- | Prove that round-trip encryption/decryption leaves the plain-text unchanged.
-- The proof is carried out for a 40-bit key (5 bytes) and a 40-bit plaintext
-- (again 5 bytes), both universally quantified.
--
-- Note that this theorem is trivial to prove, since it essentially establishes
-- that xor'ing the same value twice leaves a word unchanged
-- (i.e., @x `xor` y `xor` y = x@). However, the proof takes quite a while to
-- complete, as it gives rise to a fairly large symbolic trace.
rc4IsCorrect :: IO ThmResult
rc4IsCorrect = prove $ do
    key <- mkForallVars 5
    pt  <- mkForallVars 5
    let ks    = keySchedule key
        -- Encryption and decryption are the same keystream xor.
        xorKs = zipWith xor ks
    return $ pt .== xorKs (xorKs pt)

--------------------------------------------------------------------------------------------
-- | For doctest purposes only: renders a concrete symbolic value in hex,
-- left-padded with @0@ to at least two digits. It fails via 'fromJust' if the
-- value is not concrete.
hex2 :: (SymVal a, Show a, Integral a) => SBV a -> String
hex2 v = zeroPad (showHex (fromJust (unliteral v)) "")
  where -- Prefix enough zeros to reach two characters; longer strings are unchanged.
        zeroPad :: String -> String
        zeroPad hexStr = drop (length hexStr) "00" ++ hexStr