-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Crypto.RC4
-- Copyright : (c) Austin Seipp
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- An implementation of RC4 (AKA Rivest Cipher 4 or Alleged RC4/ARC4),
-- using SBV. For information on RC4, see: <http://en.wikipedia.org/wiki/RC4>.
--
-- We make no effort to optimize the code, and instead focus on a clear
-- implementation. In fact, the RC4 algorithm relies heavily on in-place
-- updates of its state for efficiency, and is therefore not well suited to a
-- purely functional implementation.
-----------------------------------------------------------------------------

{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.Crypto.RC4 where

import Data.Char  (ord, chr)
import Data.List  (genericIndex)
import Data.Maybe (fromJust)
import Data.SBV

import Data.SBV.Tools.STree

import Numeric (showHex)

-----------------------------------------------------------------------------
-- * Types
-----------------------------------------------------------------------------

-- | The RC4 state: 256 8-bit values, indexed by a 'Word8'. It starts out as
-- the identity permutation ('initS') and is only ever modified by swapping
-- two entries ('swap'), so it always remains a permutation of @0 .. 255@.
-- We use the symbolically accessible full-binary-tree type 'STree' for it,
-- because RC4 reads and writes the array at symbolic indices, and keeping
-- the cost of each such access low is important.
type S = STree Word8 Word8

-- | The initial RC4 state: a fully balanced 'STree' holding the identity
-- permutation. Leaf @k@ stores the (concrete) value @k@, for every @k@ in
-- @0 .. 255@.
initS :: S
initS = mkSTree leaves
  where leaves = [literal w | w <- [minBound .. maxBound :: Word8]]

-- | An RC4 key: a list of (possibly symbolic) byte values. 'initRC4'
-- accepts keys that are between 1 and 256 bytes long.
type Key = [SWord8]

-- | The running state of the RC4 keystream generator: the @S@ array together
-- with the @i@ and @j@ index values that each PRGA step ('prga') updates.
type RC4 = (S, SWord8, SWord8)

-----------------------------------------------------------------------------
-- * The PRGA
-----------------------------------------------------------------------------

-- | Exchange the entries at positions @i@ and @j@ of the RC4 array.
--
-- Both entries are read from the incoming state before anything is written,
-- so the result is also correct when @i@ and @j@ denote the same position.
swap :: SWord8 -> SWord8 -> S -> S
swap i j st = writeSTree afterFirst j atI
  where
    atI        = readSTree st i
    atJ        = readSTree st j
    afterFirst = writeSTree st i atJ

-- | One step of the RC4 pseudo-random generation algorithm (PRGA).
--
-- Starting from state @(S, i, j)@, the step:
--
--   1. increments @i@;
--   2. adds @S[i]@ to @j@;
--   3. swaps @S[i]@ and @S[j]@;
--   4. outputs the entry of the swapped array at index @S[i] + S[j]@.
--
-- It returns that keystream byte together with the updated state. All index
-- arithmetic is on 'SWord8' and therefore wraps modulo 256.
prga :: RC4 -> (SWord8, RC4)
prga (arr0, i0, j0) = (output, (arr1, i1, j1))
  where
    i1     = i0 + 1
    j1     = j0 + readSTree arr0 i1
    arr1   = swap i1 j1 arr0
    output = readSTree arr1 (readSTree arr1 i1 + readSTree arr1 j1)

-----------------------------------------------------------------------------
-- * Key schedule
-----------------------------------------------------------------------------

-- | The RC4 key-scheduling algorithm (KSA). It constructs the state used by
-- the PRGA from the given key.
--
-- Starting from 'initS' with @j = 0@, it runs 256 rounds. For @i = 0 .. 255@,
-- round @i@ sets @j := j + S[i] + key[i mod keyLength]@ and then swaps @S[i]@
-- with @S[j]@.
--
-- The key must contain between 1 and 256 bytes; otherwise this calls 'error'.
--
-- The round counter @i@ is always concrete, so each round selects its key
-- byte with ordinary list indexing. Only the key bytes and the state
-- contents may be symbolic.
initRC4 :: Key -> S
initRC4 key
  | keyLength < 1 || keyLength > 256
  = error $ "RC4 requires a key of length between 1 and 256, received: " ++ show keyLength
  | otherwise
  = snd $ foldl mix (0, initS) [0 .. 255]
  where
    keyLength :: Int
    keyLength = length key

    -- One KSA round for the concrete array position @w@.
    mix :: (SWord8, S) -> Word8 -> (SWord8, S)
    mix (j', s) w = (j, swap i j s)
      where
        i = literal w
        -- Reduce the index in 'Int', not 'Word8': a 256-byte key would give
        -- @fromIntegral keyLength == 0 :: Word8@, and 'mod' by zero throws.
        k = genericIndex key (fromIntegral w `mod` keyLength)
        j = j' + readSTree s i + k

-- | The RC4 keystream for the given key.
--
-- The key is first expanded with 'initRC4', and the PRGA is then run
-- repeatedly starting from @i = j = 0@. The resulting list is infinite, so
-- callers must consume only as many bytes as they need.
keySchedule :: Key -> [SWord8]
keySchedule key = map fst (iterate (prga . snd) (prga start))
  where
    start :: RC4
    start = (initRC4 key, 0, 0)

-- | The RC4 keystream for a key given as a 'String'.
--
-- Each character is converted to a byte by taking its code point and
-- truncating it to 8 bits with 'fromIntegral'. Only characters with code
-- points below 256 are therefore represented faithfully.
keyScheduleString :: String -> [SWord8]
keyScheduleString str = keySchedule [literal (fromIntegral (ord c)) | c <- str]

-----------------------------------------------------------------------------
-- * Encryption and Decryption
-----------------------------------------------------------------------------

-- | RC4 encryption: each plaintext byte is xor'ed with the corresponding
-- keystream byte generated from the key. Plaintext characters are converted
-- to bytes by truncating their code points to 8 bits. The following
-- test-vectors are from Wikipedia <http://en.wikipedia.org/wiki/RC4>:
--
-- >>> concatMap hex2 $ encrypt "Key" "Plaintext"
-- "bbf316e8d940af0ad3"
--
-- >>> concatMap hex2 $ encrypt "Wiki" "pedia"
-- "1021bf0420"
--
-- >>> concatMap hex2 $ encrypt "Secret" "Attack at dawn"
-- "45a01f645fc35b383552544b9bf5"
encrypt :: String -> String -> [SWord8]
encrypt key pt = zipWith xor keystream plainBytes
  where
    keystream  = keyScheduleString key
    plainBytes = [literal (fromIntegral (ord c)) | c <- pt]

-- | RC4 decryption. This is essentially the same as encryption: each
-- ciphertext byte is xor'ed with the corresponding keystream byte, and the
-- result is turned back into a character. The ciphertext bytes must be
-- concrete; a symbolic byte makes 'fromJust' fail. For the above test
-- vectors we have:
--
-- >>> decrypt "Key" [0xbb, 0xf3, 0x16, 0xe8, 0xd9, 0x40, 0xaf, 0x0a, 0xd3]
-- "Plaintext"
--
-- >>> decrypt "Wiki" [0x10, 0x21, 0xbf, 0x04, 0x20]
-- "pedia"
--
-- >>> decrypt "Secret" [0x45, 0xa0, 0x1f, 0x64, 0x5f, 0xc3, 0x5b, 0x38, 0x35, 0x52, 0x54, 0x4b, 0x9b, 0xf5]
-- "Attack at dawn"
decrypt :: String -> [SWord8] -> String
decrypt key ct = [toChar (k `xor` c) | (k, c) <- zip (keyScheduleString key) ct]
  where
    toChar :: SWord8 -> Char
    toChar b = chr (fromIntegral (fromJust (unliteral b)))

-----------------------------------------------------------------------------
-- * Verification
-----------------------------------------------------------------------------

-- | Prove that a round trip of encryption followed by decryption leaves the
-- plain-text unchanged. The theorem is stated parametrically over key and
-- plain-text sizes; this instance performs the proof for a 40-bit key
-- (5 bytes) and a 40-bit plain-text (again 5 bytes), both fully symbolic.
--
-- Note that this theorem is trivial to prove, since it essentially
-- establishes that xor'ing the same value twice leaves a word unchanged
-- (i.e., @x `xor` y `xor` y = x@). However, the proof takes quite a while to
-- complete, as it gives rise to a fairly large symbolic trace.
rc4IsCorrect :: IO ThmResult
rc4IsCorrect = prove roundTrip
  where
    roundTrip :: Symbolic SBool
    roundTrip = do
      key <- mkFreeVars 5
      pt  <- mkFreeVars 5
      let ks  = keySchedule key
          ct  = zipWith xor ks pt
          pt' = zipWith xor ks ct
      pure (pt .== pt')

--------------------------------------------------------------------------------------------
-- | For doctest purposes only: render a concrete symbolic value in
-- lowercase hexadecimal, left-padded with @0@ to at least two digits.
-- Calls 'fromJust' on the concrete value, so it fails on a symbolic argument.
hex2 :: (SymVal a, Show a, Integral a) => SBV a -> String
hex2 v = pad (showHex (fromJust (unliteral v)) "")
  where
    pad digits = replicate (2 - length digits) '0' ++ digits