-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.Crypto.AES
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- An implementation of AES (Advanced Encryption Standard), using SBV.
-- For details on AES, see <http://en.wikipedia.org/wiki/Advanced_Encryption_Standard>.
--
-- We do a T-box implementation, which leads to good C code as we can take
-- advantage of look-up tables. Note that we make virtually no attempt to
-- optimize our Haskell code. The concern here is not with getting Haskell running
-- fast at all. The idea is to program the T-Box implementation as naturally and clearly
-- as possible in Haskell, and have SBV's code-generator generate fast C code automatically.
-- Therefore, we merely use ordinary Haskell lists as our data-structures, and do not
-- bother with any unboxing or strictness annotations. Thus, we achieve the separation
-- of concerns: Correctness via clarity and simplicity and proofs on the Haskell side,
-- performance by relying on SBV's code generator. If necessary, the generated code
-- can be FFI'd back into Haskell to complete the loop.
--
-- All 3 valid key sizes (128, 192, and 256) as required by the FIPS-197 standard
-- are supported.
-----------------------------------------------------------------------------

{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE ParallelListComp #-}

{-# OPTIONS_GHC -Wall -Werror -Wno-incomplete-uni-patterns #-}

module Documentation.SBV.Examples.Crypto.AES where

import Control.Monad (void, when)

import Data.SBV
import Data.SBV.Tools.CodeGen
import Data.SBV.Tools.Polynomial

import Data.List (transpose)
import Data.Maybe (fromJust)

import Numeric (showHex)

import Test.QuickCheck hiding (verbose)

-- $setup
-- >>> -- For doctest purposes only:
-- >>> import Data.SBV

-----------------------------------------------------------------------------
-- * Formalizing GF(2^8)
-----------------------------------------------------------------------------

-- | An element of the Galois Field GF(2^8). Such elements are essentially polynomials
-- over GF(2) with maximum degree 7, conveniently represented as values between 0 and 255.
type GF28 = SWord 8  -- addition is 'xor'; multiplication is 'gf28Mult'

-- | Multiplication in GF(2^8). This is simple polynomial multiplication, followed
-- by reduction modulo the irreducible polynomial @x^8+x^4+x^3+x^1+1@. We simply use
-- the 'pMult' function exported by SBV to do the operation.
gf28Mult :: GF28 -> GF28 -> GF28
gf28Mult a b = pMult (a, b, aesPolyExponents)
  where -- Exponents of the nonzero terms of the AES reduction polynomial
        -- x^8 + x^4 + x^3 + x^1 + 1, in the form expected by 'pMult'.
        aesPolyExponents :: [Int]
        aesPolyExponents = [8, 4, 3, 1, 0]

-- | Exponentiation by a constant in GF(2^8). The implementation uses the usual
-- square-and-multiply trick to speed up the computation.
gf28Pow :: GF28 -> Int -> GF28
-- Raises the base @n@ to the non-negative power @k@. A negative exponent is
-- rejected with an error: the square-and-multiply recursion below would never
-- reach its base case, since @(-1) `shiftR` 1 == -1@.
gf28Pow n k
  | k < 0     = error $ "gf28Pow: negative exponent: " ++ show k
  | otherwise = pow k
  where -- Squaring in GF(2^8).
        sq :: GF28 -> GF28
        sq x = x `gf28Mult` x

        -- Square-and-multiply: n^i = (n^(i `div` 2))^2, times one more n when i is odd.
        -- Only called with i >= 0.
        pow :: Int -> GF28
        pow 0 = 1
        pow i
          | odd i     = n `gf28Mult` sq (pow (i `shiftR` 1))
          | otherwise = sq (pow (i `shiftR` 1))

-- | Computing inverses in GF(2^8). By the mathematical properties of GF(2^8)
-- and the particular irreducible polynomial used @x^8+x^4+x^3+x^1+1@, it
-- turns out that raising to the 254th power gives us the multiplicative inverse.
-- Of course, we can prove this using SBV:
--
-- >>> prove $ \x -> x ./= 0 .=> x `gf28Mult` gf28Inverse x .== 1
-- Q.E.D.
--
-- Note that we exclude @0@ in our theorem, as it does not have a
-- multiplicative inverse.
gf28Inverse :: GF28 -> GF28
-- The nonzero elements of GF(2^8) form a multiplicative group of order 255,
-- so x^255 = 1 and hence x^254 = x^(-1) for every x /= 0. For x = 0 the
-- square-and-multiply in 'gf28Pow' simply yields 0.
gf28Inverse = (`gf28Pow` 254)

-----------------------------------------------------------------------------
-- * Implementing AES
-----------------------------------------------------------------------------

-----------------------------------------------------------------------------
-- ** Types and basic operations
-----------------------------------------------------------------------------
-- | AES state. The state consists of four 32-bit words, each of which is in turn treated
-- as four GF28's, i.e., 4 bytes. The T-Box implementation keeps the four-bytes together
-- for efficient representation.
type State = [SWord 32]  -- four words ('aesEncrypt' checks the length); 'invMixColumns' mixes bytes within each word

-- | The key, which can be 128, 192, or 256 bits. Represented as a sequence of 32-bit words.
type Key = [SWord 32]  -- 4, 6, or 8 words for a cipher key (checked in 'aesKeySchedule'); also used for individual round keys

-- | The key schedule. AES executes in rounds, and it treats first and last round keys slightly
-- differently than the middle ones. We reflect that choice by being explicit about it in our type.
-- The length of the middle list of keys depends on the key-size, which in turn determines
-- the number of rounds.
type KS = (Key, [Key], Key)  -- (initial round key, middle round keys, final round key), as consumed by 'doRounds'

-- | Rotating a state row by a fixed amount to the right.
rotR :: [GF28] -> Int -> [GF28]
-- Only rows of exactly four bytes and rotation amounts 1, 2 or 3 are
-- supported; anything else is reported as an error.
rotR row amount = case (row, amount) of
  ([a, b, c, d], 1) -> [d, a, b, c]
  ([a, b, c, d], 2) -> [c, d, a, b]
  ([a, b, c, d], 3) -> [b, c, d, a]
  _                 -> error $ "rotR: Unexpected input: " ++ show (row, amount)

-----------------------------------------------------------------------------
-- ** The key schedule
-----------------------------------------------------------------------------

-- | Definition of round-constants, as specified in Section 5.2 of the AES standard.
roundConstants :: [GF28]
-- Entry 0 is a placeholder: 'keyExpansion' indexes this list with
-- @i `div` nk@ for @i >= nk@, so the smallest index used is 1.
-- Entry i (i >= 1) is x^(i-1), i.e. 2 raised to the (i-1)th power in GF(2^8).
roundConstants = 0 : map (gf28Pow 2) [0 ..]

-- | The @InvMixColumns@ transformation, as described in Section 5.3.3 of the standard. Note
-- that this transformation is only used explicitly during key-expansion in the T-Box implementation
-- of AES.
invMixColumns :: State -> State
invMixColumns state = map fromBytes (transpose products)
  where -- The bytes of each state word, one list per word.
        columns :: [[GF28]]
        columns = map toBytes state

        -- products !! k !! c is output byte k for word c: the GF(2^8) dot product
        -- of coefficient row k with the bytes of word c ('xor' is the addition).
        -- 'transpose' then regroups these bytes per word.
        products :: [[GF28]]
        products = [ [ foldr1 xor (zipWith ($) coeffs col) | col <- columns ]
                   | coeffs <- coeffMatrix
                   ]

        -- The fixed InvMixColumns matrix with entries {0e, 0b, 0d, 09}, rotated per row.
        coeffMatrix :: [[GF28 -> GF28]]
        coeffMatrix = [ [mE, mB, mD, m9]
                      , [m9, mE, mB, mD]
                      , [mD, m9, mE, mB]
                      , [mB, mD, m9, mE]
                      ]

        -- Table-lookup version of multiplication by a fixed constant.
        mulBy :: GF28 -> GF28 -> GF28
        mulBy c = select [c `gf28Mult` v | v <- [0 .. 255]] 0

        mE, mB, mD, m9 :: GF28 -> GF28
        mE = mulBy 0xE
        mB = mulBy 0xB
        mD = mulBy 0xD
        m9 = mulBy 0x9

-- | Key expansion. Starting with the given key, returns an infinite sequence of
-- round keys: the expanded key words described by the AES standard, Section 5.2,
-- Figure 11, grouped by 'chop4'.
keyExpansion :: Int -> Key -> [Key]
keyExpansion nk key = chop4 keys
  where -- The expanded key: the nk given words, followed by each new word w[i]
        -- (i >= nk), computed from w[i-1] and w[i-nk]. The list refers to itself,
        -- which is fine because every new word only looks backwards.
        keys :: [SWord 32]
        keys = key ++ zipWith3 nextWord [nk ..] (drop (nk - 1) keys) keys

        -- w[i] = w[i-nk] `xor` temp, where temp depends on i's position within
        -- the nk-word period.
        nextWord :: Int -> SWord 32 -> SWord 32 -> SWord 32
        nextWord i prev old = old `xor` temp
          where phase = i `mod` nk
                temp
                  | phase == 0           = subWordRcon (prev `rotateL` 8) (roundConstants !! (i `div` nk))
                  | phase == 4 && nk > 6 = subWordRcon prev 0
                  | otherwise            = prev

        -- Apply the S-box to every byte of the word, and xor the round constant
        -- into the first byte produced by 'toBytes'.
        subWordRcon :: SWord 32 -> GF28 -> SWord 32
        subWordRcon w rc = fromBytes ((firstByte `xor` rc) : otherBytes)
          where firstByte : otherBytes = map sbox (toBytes w)

-----------------------------------------------------------------------------
-- ** The S-box transformation
-----------------------------------------------------------------------------

-- | The values of the AES S-box table. Note that we describe the S-box programmatically
-- using the mathematical construction given in Section 5.1.1 of the standard. However,
-- the code-generation will turn this into a mere look-up table, as it is just a
-- constant table, all computation being done at \"compile-time\".
sboxTable :: [GF28]
sboxTable = map (affine . gf28Inverse) [0 .. 255]
  where -- Affine transformation of Section 5.1.1: the byte xored with four of its
        -- rotations and the constant 0x63. Right rotations by 4, 5, 6 and 7 are
        -- the same as left rotations by 4, 3, 2 and 1.
        affine :: GF28 -> GF28
        affine b = foldr xor 0x63 (map (b `rotateR`) [0, 4, 5, 6, 7])

-- | The sbox transformation. We simply select from the sbox table. Note that we
-- are obliged to give a default value (here @0@) to be used if the index is out-of-bounds
-- as required by SBV's 'select' function. However, that will never happen since
-- the table has all 256 elements in it.
sbox :: GF28 -> GF28
-- Look the byte up in the 256-entry table; 0 is only the (unreachable)
-- out-of-bounds default that 'select' requires.
sbox idx = select sboxTable 0 idx

-----------------------------------------------------------------------------
-- ** The inverse S-box transformation
-----------------------------------------------------------------------------

-- | The values of the inverse S-box table. Again, the construction is programmatic.
unSBoxTable :: [GF28]
unSBoxTable = map (gf28Inverse . invAffine) [0 .. 255]
  where -- Inverse of the S-box affine transformation: the byte's right rotations
        -- by 2, 5 and 7, xored together with the constant 0x05.
        invAffine :: GF28 -> GF28
        invAffine b = foldr xor 0x05 (map (b `rotateR`) [2, 5, 7])

-- | The inverse s-box transformation.
unSBox :: GF28 -> GF28
-- Table lookup; 0 is the out-of-bounds default required by 'select', which is
-- never used since the table has all 256 entries.
unSBox idx = select unSBoxTable 0 idx

-- | Prove that the 'sbox' and 'unSBox' are inverses. We have:
--
-- >>> prove sboxInverseCorrect
-- Q.E.D.
--
sboxInverseCorrect :: GF28 -> SBool
sboxInverseCorrect x = undoes unSBox sbox .&& undoes sbox unSBox
  where -- @undoes g f@ holds when @g@ recovers @x@ after @f@ has been applied.
        undoes :: (GF28 -> GF28) -> (GF28 -> GF28) -> SBool
        undoes g f = g (f x) .== x

-----------------------------------------------------------------------------
-- ** AddRoundKey transformation
-----------------------------------------------------------------------------

-- | Adding the round-key to the current state. We simply exploit the fact
-- that addition is just xor in implementing this transformation.
addRoundKey :: Key -> State -> State
-- Word-wise xor of the round key into the state; like 'zip', stops at the
-- shorter of the two lists.
addRoundKey roundKey st = [k `xor` w | (k, w) <- zip roundKey st]

-----------------------------------------------------------------------------
-- ** Tables for T-Box encryption
-----------------------------------------------------------------------------

-- | T-box table generation function for encryption
t0Func :: GF28 -> [GF28]
-- The S-box output s multiplied by the MixColumns column {02, 01, 01, 03}.
t0Func a = [doubled, s, s, tripled]
  where s, doubled, tripled :: GF28
        s       = sbox a
        doubled = s `gf28Mult` 2
        tripled = s `gf28Mult` 3

-- | First look-up table used in encryption
t0 :: GF28 -> SWord 32
t0 = select lut 0
  where -- Entry a packs the four bytes of 't0Func' a into one word.
        lut :: [SWord 32]
        lut = map (fromBytes . t0Func) [0 .. 255]

-- | Second look-up table used in encryption
t1 :: GF28 -> SWord 32
t1 = select lut 0
  where -- Entry a is 't0Func' a rotated right by one byte, packed into a word.
        lut :: [SWord 32]
        lut = map (fromBytes . (`rotR` 1) . t0Func) [0 .. 255]

-- | Third look-up table used in encryption
t2 :: GF28 -> SWord 32
t2 = select lut 0
  where -- Entry a is 't0Func' a rotated right by two bytes, packed into a word.
        lut :: [SWord 32]
        lut = map (fromBytes . (`rotR` 2) . t0Func) [0 .. 255]

-- | Fourth look-up table used in encryption
t3 :: GF28 -> SWord 32
t3 = select lut 0
  where -- Entry a is 't0Func' a rotated right by three bytes, packed into a word.
        lut :: [SWord 32]
        lut = map (fromBytes . (`rotR` 3) . t0Func) [0 .. 255]

-----------------------------------------------------------------------------
-- ** Tables for T-Box decryption
-----------------------------------------------------------------------------

-- | T-box table generating function for decryption
u0Func :: GF28 -> [GF28]
-- The inverse S-box output s multiplied by the InvMixColumns column {0e, 09, 0d, 0b}.
u0Func a = map (s `gf28Mult`) [0xE, 0x9, 0xD, 0xB]
  where s :: GF28
        s = unSBox a

-- | First look-up table used in decryption
u0 :: GF28 -> SWord 32
u0 = select lut 0
  where -- Entry a packs the four bytes of 'u0Func' a into one word.
        lut :: [SWord 32]
        lut = map (fromBytes . u0Func) [0 .. 255]

-- | Second look-up table used in decryption
u1 :: GF28 -> SWord 32
u1 = select lut 0
  where -- Entry a is 'u0Func' a rotated right by one byte, packed into a word.
        lut :: [SWord 32]
        lut = map (fromBytes . (`rotR` 1) . u0Func) [0 .. 255]

-- | Third look-up table used in decryption
u2 :: GF28 -> SWord 32
u2 = select lut 0
  where -- Entry a is 'u0Func' a rotated right by two bytes, packed into a word.
        lut :: [SWord 32]
        lut = map (fromBytes . (`rotR` 2) . u0Func) [0 .. 255]

-- | Fourth look-up table used in decryption
u3 :: GF28 -> SWord 32
u3 = select lut 0
  where -- Entry a is 'u0Func' a rotated right by three bytes, packed into a word.
        lut :: [SWord 32]
        lut = map (fromBytes . (`rotR` 3) . u0Func) [0 .. 255]

-----------------------------------------------------------------------------
-- ** AES rounds
-----------------------------------------------------------------------------

-- | Generic round function. Given the function to perform one round, a key-schedule,
-- and a starting state, it performs the AES rounds.
doRounds :: (Bool -> State -> Key -> State) -> KS -> State -> State
doRounds rnd (ikey, rkeys, fkey) sIn = rnd True (last intermediate) fkey
  where -- The state after the initial AddRoundKey, followed by the state after
        -- each middle round (one per key in rkeys). The final round is then
        -- applied to the last of these.
        intermediate :: [State]
        intermediate = scanl (rnd False) (ikey `addRoundKey` sIn) rkeys

-- | One encryption round. The first argument indicates whether this is the final round.
-- The final round has no MixColumns step, so it applies the S-box directly instead of
-- using the T-box tables.
aesRound :: Bool -> State -> Key -> State
aesRound isFinal s key = map outColumn [0 .. 3] `addRoundKey` key
  where -- The state as a byte matrix: stateBytes !! c !! r is byte r of word c.
        stateBytes :: [[GF28]]
        stateBytes = map toBytes s

        -- ShiftRows folded into the indexing: byte r of output word j is taken
        -- from byte r of word (j + r) mod 4.
        pickByte :: Int -> Int -> GF28
        pickByte j r = stateBytes !! ((j + r) `mod` 4) !! r

        -- Middle rounds look the four selected bytes up in t0..t3 and xor the
        -- results together (left-nested, as in e0 `xor` e1 `xor` e2 `xor` e3).
        -- The final round only applies the S-box to each selected byte.
        outColumn :: Int -> SWord 32
        outColumn j
          | isFinal   = fromBytes (map (sbox . pickByte j) [0 .. 3])
          | otherwise = foldl1 xor (zipWith ($) [t0, t1, t2, t3] (map (pickByte j) [0 .. 3]))

-- | One decryption round. Similar to the encryption round, the first argument
-- indicates whether this is the final round or not.
aesInvRound :: Bool -> State -> Key -> State
aesInvRound isFinal s key = map outColumn [0 .. 3] `addRoundKey` key
  where -- The state as a byte matrix: stateBytes !! c !! r is byte r of word c.
        stateBytes :: [[GF28]]
        stateBytes = map toBytes s

        -- InvShiftRows folded into the indexing: byte r of output word j is taken
        -- from byte r of word (j - r) mod 4, i.e. offsets 0, 3, 2, 1 for r = 0..3.
        pickByte :: Int -> Int -> GF28
        pickByte j r = stateBytes !! ((j - r) `mod` 4) !! r

        -- Middle rounds look the four selected bytes up in u0..u3 and xor the
        -- results together (left-nested). The final round only applies the
        -- inverse S-box to each selected byte.
        outColumn :: Int -> SWord 32
        outColumn j
          | isFinal   = fromBytes (map (unSBox . pickByte j) [0 .. 3])
          | otherwise = foldl1 xor (zipWith ($) [u0, u1, u2, u3] (map (pickByte j) [0 .. 3]))

-----------------------------------------------------------------------------
-- * AES API
-----------------------------------------------------------------------------

-- | Key schedule. Given a 128, 192, or 256 bit key, expand it to get key-schedules
-- for encryption and decryption. The key is given as a sequence of 32-bit words.
-- (4 elements for 128-bits, 6 for 192, and 8 for 256.) Compare this function to 'aesInvKeySchedule'
-- which can calculate the key-expansion for decryption on the fly, as opposed to calculating
-- the forward key-expansion first.
aesKeySchedule :: Key -> (KS, KS)
aesKeySchedule key
  | nk `notElem` [4, 6, 8] = error "aesKeySchedule: Invalid key size"
  | otherwise              = (encKS, decKS)
  where -- Number of key words, and number of rounds (nk + 6).
        nk, nr :: Int
        nk = length key
        nr = nk + 6

        -- Round key 0 is used up front, round keys 1 .. nr-1 in the middle
        -- rounds, and round key nr in the final round.
        initialKey :: Key
        laterKeys  :: [Key]
        initialKey : laterKeys = keyExpansion nk key

        middleKeys :: [Key]
        middleKeys = take (nr - 1) laterKeys

        finalKey :: Key
        finalKey = laterKeys !! (nr - 1)

        -- Encryption uses the round keys in order. Decryption uses them in
        -- reverse, with InvMixColumns applied to the middle round keys.
        encKS, decKS :: KS
        encKS = (initialKey, middleKeys, finalKey)
        decKS = (finalKey, map invMixColumns (reverse middleKeys), initialKey)

-- | Block encryption. Encrypts one 128-bit block of plain-text, given as
-- exactly 4 32-bit words, using an encryption key-schedule such as the first
-- component returned by 'aesKeySchedule'. The result is the 4-word
-- cipher-text. Any other plain-text length is rejected with an error.
aesEncrypt :: [SWord 32] -> KS -> [SWord 32]
aesEncrypt pt encKS =
  case length pt of
    4 -> doRounds aesRound encKS pt
    _ -> error "aesEncrypt: Invalid plain-text size"

-- | Block decryption, the counterpart of 'aesEncrypt'. Takes one block of
-- cipher-text (exactly 4 32-bit words) and a decryption key-schedule, such as
-- the second component returned by 'aesKeySchedule'. Returns the
-- corresponding 4-word plain-text. Any other cipher-text length is rejected
-- with an error.
aesDecrypt :: [SWord 32] -> KS -> [SWord 32]
aesDecrypt ct decKS
  | length ct /= 4 = error "aesDecrypt: Invalid cipher-text size"
  | otherwise      = doRounds aesInvRound decKS ct

-----------------------------------------------------------------------------
-- * On-the-fly decryption
-- ${ontheflyintro}
-----------------------------------------------------------------------------
{- $ontheflyintro
   While regular encryption can be fused with key-generation, the standard method of AES
   decryption has to perform the key-expansion before decryption starts. This can be undesirable
   as it necessarily serializes the action of key-expansion before decryption. An
   alternative is to do on-the-fly decryption: We can expand the key in reverse, and thus
   need not save the key-schedule. One downside of this approach, however, is that we need
   to keep the "unwound" key: That is, instead of the common key used for encryption and
   decryption, we need to hold on to the final value of key-expansion, so it can be run
   in reverse. In this section, we implement on-the-fly decryption using this idea.
-}

-- | Inverse key expansion. The input is the final @nk@ words of the forward
-- key-expansion, most recent word first. This runs the key-generation
-- recurrence backwards to recover the words of all the earlier rounds. The
-- words are then regrouped into 4-word round keys, each in its natural word
-- order, starting with the key of the final round. Used in on-the-fly
-- decryption.
invKeyExpansion :: Int -> Key -> [Key]
invKeyExpansion nk rkey = map reverse (chop4 keys)
   where -- Every key word, most recent first. Each word past the seed is
         -- recovered from the word nk positions newer than it and that word's
         -- forward predecessor. The list is defined lazily in terms of itself.
         keys :: [SWord 32]
         keys = rkey ++ zipWith3 invNextWord positions (drop 1 keys) keys

         -- Schedule index used for each recovered word. Only its value modulo
         -- nk and its quotient by nk are consulted in 'invNextWord'.
         positions :: [Int]
         positions = reverse [0 .. remaining - 1]

         totalWords, remaining :: Int
         totalWords = 4 * (nk + 6 + 1)
         remaining  = totalWords - nk

         -- Undo one step of the forward recurrence. 'old' is the newer word and
         -- 'prev' is the word just before it. The result is the word nk
         -- positions before 'old'.
         invNextWord :: Int -> SWord 32 -> SWord 32 -> SWord 32
         invNextWord i prev old = old `xor` temp
           where temp | i `mod` nk == 0           = subWordRcon (prev `rotateL` 8) (roundConstants !! (1 + i `div` nk))
                      | i `mod` nk == 4 && nk > 6 = subWordRcon prev 0
                      | otherwise                 = prev

         -- Run every byte of the word through the S-box, then xor the given
         -- round constant into the first byte.
         subWordRcon :: SWord 32 -> GF28 -> SWord 32
         subWordRcon w rc = fromBytes ((topByte `xor` rc) : lowBytes)
            where (topByte : lowBytes) = map sbox (toBytes w)

-- | AES inverse key schedule. Builds the key-schedule for on-the-fly
-- decryption, starting from the last-round (unwound) key and unwinding it with
-- 'invKeyExpansion'. Unlike 'aesKeySchedule', no encryption schedule is
-- produced, so this function can be fused with the decryption operation.
-- Accepts 4, 6, or 8 key words; any other size is rejected with an error.
aesInvKeySchedule :: Key -> KS
aesInvKeySchedule key
  | nk `notElem` [4, 6, 8] = error "aesInvKeySchedule: Invalid key size"
  | otherwise              = (firstRK, midRKs, finalRK)
  where nk = length key
        nr = nk + 6

        -- Round keys in the order decryption consumes them.
        (firstRK : restRKs)   = invKeyExpansion nk key
        (midRKs, finalRK : _) = splitAt (nr - 1) restRKs

-- | Block decryption, starting from the unwound key. The key-schedule starts
-- with the final key, as produced by 'aesInvKeySchedule'. Also, this does not
-- use the T-box implementation: it is a plain AES inverse cipher, with the
-- inverse mix-columns step done through look-up tables. The cipher-text must
-- be exactly 4 32-bit words; any other length is rejected with an error.
aesDecryptUnwoundKey :: [SWord 32] -> KS -> [SWord 32]
aesDecryptUnwoundKey ct decKS
  | length ct == 4
  = doRounds aesInvRoundRegular decKS ct
  | True
  -- Name this function in the message (it previously said "aesDecrypt").
  = error "aesDecryptUnwoundKey: Invalid cipher-text size"
  where -- One inverse-cipher round. Each output column takes its bytes from
        -- the shifted input columns and passes them through the inverse
        -- S-box, then mixes in the round key. Middle rounds then apply the
        -- {0e, 09, 0d, 0b} column mix via the otfU tables; the final round
        -- simply xors the key word.
        aesInvRoundRegular :: Bool -> State -> Key -> State
        aesInvRoundRegular isFinal s key = map (outCol isFinal) [0 .. 3]
          where stateBytes = map toBytes s
                keyBytes   = map toBytes key

                -- Byte r of output column j: read from input column (j - r) mod 4,
                -- then mapped through the inverse S-box.
                invSubShift :: Int -> Int -> GF28
                invSubShift j r = unSBox (stateBytes !! ((j + 4 - r) `mod` 4) !! r)

                outCol :: Bool -> Int -> SWord 32
                outCol True  j = fromBytes [invSubShift j r | r <- [0 .. 3]] `xor` (key !! j)
                outCol False j = e 0 `xor` e 1 `xor` e 2 `xor` e 3
                  where e r = (otfTables !! r) (invSubShift j r `xor` (keyBytes !! j !! r))

        -- A byte multiplied by the inverse mix-columns coefficients {0e, 09, 0d, 0b}.
        otfU0Func :: GF28 -> [GF28]
        otfU0Func b = [b `gf28Mult` 0xE, b `gf28Mult` 0x9, b `gf28Mult` 0xD, b `gf28Mult` 0xB]

        -- One 256-entry table per byte position. otfU1..otfU3 are byte
        -- rotations of otfU0. These do not depend on the round, so they live
        -- outside 'aesInvRoundRegular'.
        otfU0, otfU1, otfU2, otfU3 :: GF28 -> SWord 32
        otfU0 = select [fromBytes (otfU0Func x)          | x <- [0 .. 255]] 0
        otfU1 = select [fromBytes (otfU0Func x `rotR` 1) | x <- [0 .. 255]] 0
        otfU2 = select [fromBytes (otfU0Func x `rotR` 2) | x <- [0 .. 255]] 0
        otfU3 = select [fromBytes (otfU0Func x `rotR` 3) | x <- [0 .. 255]] 0

        otfTables :: [GF28 -> SWord 32]
        otfTables = [otfU0, otfU1, otfU2, otfU3]

-----------------------------------------------------------------------------
-- * Test vectors
-----------------------------------------------------------------------------

-- | Plain-text shared by all the encryption test vectors below.
commonPT :: [SWord 32]
commonPT = [ 0x00112233
           , 0x44556677
           , 0x8899aabb
           , 0xccddeeff
           ]

-- | Key for the 128-bit encryption test: the bytes 00 through 0f.
aes128Key :: Key
aes128Key = [ 0x00010203
            , 0x04050607
            , 0x08090a0b
            , 0x0c0d0e0f
            ]

-- | Key for the 192-bit encryption test: 'aes128Key' followed by the bytes 10 through 17.
aes192Key :: Key
aes192Key = aes128Key ++ [ 0x10111213
                         , 0x14151617
                         ]

-- | Key for the 256-bit encryption test: 'aes192Key' followed by the bytes 18 through 1f.
aes256Key :: Key
aes256Key = aes192Key ++ [ 0x18191a1b
                         , 0x1c1d1e1f
                         ]

-- | Expected cipher-text when 'commonPT' is encrypted under 'aes128Key'.
aes128CT :: [SWord 32]
aes128CT = [ 0x69c4e0d8
           , 0x6a7b0430
           , 0xd8cdb780
           , 0x70b4c55a
           ]

-- | Expected cipher-text when 'commonPT' is encrypted under 'aes192Key'.
aes192CT :: [SWord 32]
aes192CT = [ 0xdda97ca4
           , 0x864cdfe0
           , 0x6eaf70a0
           , 0xec0d7191
           ]

-- | Expected cipher-text when 'commonPT' is encrypted under 'aes256Key'.
aes256CT :: [SWord 32]
aes256CT = [ 0x8ea2b7ca
           , 0x516745bf
           , 0xeafc4990
           , 0x4b496089
           ]

-- | Calculate the 128-bit final-round key from the on-the-fly decryption key
-- schedule, via 'extractFinalKey' applied to 'aes128Key'.
aes128InvKey :: Key
aes128InvKey = extractFinalKey aes128Key

-- | Calculate the 192-bit final-round key from the on-the-fly decryption key
-- schedule, via 'extractFinalKey' applied to 'aes192Key'.
aes192InvKey :: Key
aes192InvKey = extractFinalKey aes192Key

-- | Calculate the extended 192-bit final-round key from the on-the-fly
-- decryption key schedule. Compare this to 'aes192InvKey': typically we only
-- need the final 6 words, but it is advantageous to have the entire last 8
-- words even for 192-bit keys. That is, we store the final 256 bits of the
-- key-expansion for speed purposes for both the 192 and 256-bit versions, but
-- only the final 128 bits for the 128-bit version.
aes192InvKeyExtended :: Key
aes192InvKeyExtended = extractFinalKeyExtended aes192Key

-- | Calculate the 256-bit final-round key from the on-the-fly decryption key
-- schedule, via 'extractFinalKey' applied to 'aes256Key'.
aes256InvKey :: Key
aes256InvKey = extractFinalKey aes256Key

-- | Extract the final key for on-the-fly decryption. Keeps exactly as many
-- words as the initial key has (4, 6, or 8) from the front of
-- 'extractFinalKeyExtended'.
extractFinalKey :: [SWord 32] -> [SWord 32]
extractFinalKey initKey = take (length initKey) (extractFinalKeyExtended initKey)

-- | Extract the extended key for on-the-fly decryption. The result is the
-- trailing words of the forward key-expansion, most recent word first. It has
-- 4 words (128 bits) for a 128-bit key, and 8 words (256 bits) for both the
-- 192 and 256-bit variants.
extractFinalKeyExtended :: [SWord 32] -> [SWord 32]
extractFinalKeyExtended initKey = take feed newestFirst
  where nk   = length initKey
        feed = if nk == 4 then 4 else 8

        -- Round keys from last to first; each keeps its own word order.
        ((encFirst, encMid, encLast), _) = aesKeySchedule initKey
        roundKeys = encLast ++ concat (reverse encMid) ++ encFirst

        -- Reverse the words inside each round key so that the whole prefix
        -- reads most recent word first.
        newestFirst = concatMap reverse (chop4 (take feed roundKeys))

-----------------------------------------------------------------------------
-- ** 128-bit enc/dec test
-----------------------------------------------------------------------------

-- | 128-bit encryption test, from Appendix C.1 of the AES standard:
--
-- >>> map hex8 t128Enc
-- ["69c4e0d8","6a7b0430","d8cdb780","70b4c55a"]
t128Enc :: [SWord 32]
t128Enc = aesEncrypt commonPT (fst (aesKeySchedule aes128Key))

-- | 128-bit decryption test, from Appendix C.1 of the AES standard:
--
-- >>> map hex8 t128Dec
-- ["00112233","44556677","8899aabb","ccddeeff"]
t128Dec :: [SWord 32]
t128Dec = aesDecrypt aes128CT (snd (aesKeySchedule aes128Key))

-----------------------------------------------------------------------------
-- ** 192-bit enc/dec test
-----------------------------------------------------------------------------

-- | 192-bit encryption test, from Appendix C.2 of the AES standard:
--
-- >>> map hex8 t192Enc
-- ["dda97ca4","864cdfe0","6eaf70a0","ec0d7191"]
t192Enc :: [SWord 32]
t192Enc = aesEncrypt commonPT (fst (aesKeySchedule aes192Key))

-- | 192-bit decryption test, from Appendix C.2 of the AES standard:
--
-- >>> map hex8 t192Dec
-- ["00112233","44556677","8899aabb","ccddeeff"]
--
t192Dec :: [SWord 32]
t192Dec = aesDecrypt aes192CT (snd (aesKeySchedule aes192Key))

-----------------------------------------------------------------------------
-- ** 256-bit enc/dec test
-----------------------------------------------------------------------------

-- | 256-bit encryption, from Appendix C.3 of the AES standard:
--
-- >>> map hex8 t256Enc
-- ["8ea2b7ca","516745bf","eafc4990","4b496089"]
t256Enc :: [SWord 32]
t256Enc = aesEncrypt commonPT (fst (aesKeySchedule aes256Key))

-- | 256-bit decryption, from Appendix C.3 of the AES standard:
--
-- >>> map hex8 t256Dec
-- ["00112233","44556677","8899aabb","ccddeeff"]
t256Dec :: [SWord 32]
t256Dec = aesDecrypt aes256CT (snd (aesKeySchedule aes256Key))

-- | Various tests for round-trip properties. We have:
--
-- >>> runAESTests False
-- GOOD: Key generation AES128
-- GOOD: Key generation AES192
-- GOOD: Key generation AES256
-- GOOD: Encryption     AES128
-- GOOD: Decryption     AES128
-- GOOD: Decryption-OTF AES128
-- GOOD: Encryption     AES192
-- GOOD: Decryption     AES192
-- GOOD: Decryption-OTF AES192
-- GOOD: Encryption     AES256
-- GOOD: Decryption     AES256
-- GOOD: Decryption-OTF AES256
runAESTests :: Bool -> IO ()
runAESTests :: Bool -> IO ()
runAESTests Bool
runQC = do
                 IO ()
testInvKeyExpansion

                 [Char] -> [SWord 32] -> [SWord 32] -> [SWord 32] -> IO ()
check [Char]
"AES128" [SWord 32]
aes128Key [SWord 32]
aes128InvKey [SWord 32]
aes128CT
                 [Char] -> [SWord 32] -> [SWord 32] -> [SWord 32] -> IO ()
check [Char]
"AES192" [SWord 32]
aes192Key [SWord 32]
aes192InvKey [SWord 32]
aes192CT
                 [Char] -> [SWord 32] -> [SWord 32] -> [SWord 32] -> IO ()
check [Char]
"AES256" [SWord 32]
aes256Key [SWord 32]
aes256InvKey [SWord 32]
aes256CT

                 -- Quick-check tests are rather slow. So only run when requested.
                 Bool -> IO () -> IO ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
runQC (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
                   [Char] -> IO ()
putStrLn [Char]
"Quick-check AES128 roundtrip" IO () -> IO () -> IO ()
forall a b. IO a -> IO b -> IO b
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> ((SWord32, SWord32, SWord32, SWord32)
 -> (SWord32, SWord32, SWord32, SWord32) -> SBool)
-> IO ()
forall prop. Testable prop => prop -> IO ()
quickCheck (SWord32, SWord32, SWord32, SWord32)
-> (SWord32, SWord32, SWord32, SWord32) -> SBool
roundTrip128
                   [Char] -> IO ()
putStrLn [Char]
"Quick-check AES192 roundtrip" IO () -> IO () -> IO ()
forall a b. IO a -> IO b -> IO b
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> ((SWord32, SWord32, SWord32, SWord32)
 -> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32) -> SBool)
-> IO ()
forall prop. Testable prop => prop -> IO ()
quickCheck (SWord32, SWord32, SWord32, SWord32)
-> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32) -> SBool
roundTrip192
                   [Char] -> IO ()
putStrLn [Char]
"Quick-check AES256 roundtrip" IO () -> IO () -> IO ()
forall a b. IO a -> IO b -> IO b
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> ((SWord32, SWord32, SWord32, SWord32)
 -> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32, SWord32,
     SWord32)
 -> SBool)
-> IO ()
forall prop. Testable prop => prop -> IO ()
quickCheck (SWord32, SWord32, SWord32, SWord32)
-> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32, SWord32,
    SWord32)
-> SBool
roundTrip256

  where check :: String -> Key -> Key -> [SWord 32] -> IO ()
        check :: [Char] -> [SWord 32] -> [SWord 32] -> [SWord 32] -> IO ()
check [Char]
what [SWord 32]
key [SWord 32]
invKey [SWord 32]
ctExpected = do [Char] -> [SWord 32] -> [SWord 32] -> IO ()
eq ([Char]
"Encryption     " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
what) [SWord 32]
ctExpected [SWord 32]
ctGot
                                              [Char] -> [SWord 32] -> [SWord 32] -> IO ()
eq ([Char]
"Decryption     " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
what) [SWord 32]
commonPT   [SWord 32]
ptGot
                                              [Char] -> [SWord 32] -> [SWord 32] -> IO ()
eq ([Char]
"Decryption-OTF " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
what) [SWord 32]
commonPT   [SWord 32]
ptGotInv
           where (KS
encKS, KS
decKS) = [SWord 32] -> (KS, KS)
aesKeySchedule [SWord 32]
key
                 ctGot :: [SWord 32]
ctGot          = [SWord 32] -> KS -> [SWord 32]
aesEncrypt           [SWord 32]
commonPT   KS
encKS
                 ptGot :: [SWord 32]
ptGot          = [SWord 32] -> KS -> [SWord 32]
aesDecrypt           [SWord 32]
ctExpected KS
decKS
                 ptGotInv :: [SWord 32]
ptGotInv       = [SWord 32] -> KS -> [SWord 32]
aesDecryptUnwoundKey [SWord 32]
ctExpected ([SWord 32] -> KS
aesInvKeySchedule [SWord 32]
invKey)

                 eq :: [Char] -> [SWord 32] -> [SWord 32] -> IO ()
eq [Char]
tag [SWord 32]
expected [SWord 32]
got
                   | [SWord 32] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SWord 32]
expected Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= [SWord 32] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SWord 32]
got
                   = [Char] -> IO ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> IO ()) -> [Char] -> IO ()
forall a b. (a -> b) -> a -> b
$ [[Char]] -> [Char]
unlines [ [Char]
"BAD!: " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
tag
                                     , [Char]
"Comparing different sized lists:"
                                     , [Char]
"Expected: " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [SWord 32] -> [Char]
forall a. Show a => a -> [Char]
show [SWord 32]
expected
                                     , [Char]
"Got     : " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [SWord 32] -> [Char]
forall a. Show a => a -> [Char]
show [SWord 32]
got
                                     ]
                   | (SWord 32 -> Integer) -> [SWord 32] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map SWord 32 -> Integer
extract [SWord 32]
expected [Integer] -> [Integer] -> Bool
forall a. Eq a => a -> a -> Bool
== (SWord 32 -> Integer) -> [SWord 32] -> [Integer]
forall a b. (a -> b) -> [a] -> [b]
map SWord 32 -> Integer
extract [SWord 32]
got
                   = [Char] -> IO ()
putStrLn ([Char] -> IO ()) -> [Char] -> IO ()
forall a b. (a -> b) -> a -> b
$ [Char]
"GOOD: " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
tag
                   | Bool
True
                   = [Char] -> IO ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> IO ()) -> [Char] -> IO ()
forall a b. (a -> b) -> a -> b
$ [[Char]] -> [Char]
unlines [ [Char]
"BAD!: " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
tag
                                     , [Char]
"Expected: " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [[Char]] -> [Char]
unwords ((SWord 32 -> [Char]) -> [SWord 32] -> [[Char]]
forall a b. (a -> b) -> [a] -> [b]
map SWord 32 -> [Char]
forall a. (SymVal a, Show a, Integral a) => SBV a -> [Char]
hex8 [SWord 32]
expected)
                                     , [Char]
"Got     : " [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [[Char]] -> [Char]
unwords ((SWord 32 -> [Char]) -> [SWord 32] -> [[Char]]
forall a b. (a -> b) -> [a] -> [b]
map SWord 32 -> [Char]
forall a. (SymVal a, Show a, Integral a) => SBV a -> [Char]
hex8 [SWord 32]
got)
                                     ]
                  where extract :: SWord 32 -> Integer
                        extract :: SWord 32 -> Integer
extract = WordN 32 -> Integer
forall a b. (Integral a, Num b) => a -> b
fromIntegral (WordN 32 -> Integer)
-> (SWord 32 -> WordN 32) -> SWord 32 -> Integer
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Maybe (WordN 32) -> WordN 32
forall a. HasCallStack => Maybe a -> a
fromJust (Maybe (WordN 32) -> WordN 32)
-> (SWord 32 -> Maybe (WordN 32)) -> SWord 32 -> WordN 32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SWord 32 -> Maybe (WordN 32)
forall a. SymVal a => SBV a -> Maybe a
unliteral

        -- Run the inverse key-expansion check ('goTestInvKey') for each of the three
        -- FIPS-197 key sizes, in increasing order; the first failure aborts via 'error'.
        testInvKeyExpansion :: IO ()
        testInvKeyExpansion = mapM_ (uncurry goTestInvKey) [ ("128", aes128Key)
                                                           , ("192", aes192Key)
                                                           , ("256", aes256Key)
                                                           ]
        -- Check the inverse key expansion against the forward key schedule for one key.
        --
        -- The forward schedule of @k@ is put into decryption order (last round key first,
        -- middle round keys reversed, first round key last). The leading @nk@ words of that
        -- sequence, with each 4-word group reversed, seed 'invKeyExpansion'. Its first @nr+1@
        -- round keys must reproduce the whole decryption-order sequence word for word.
        --
        --   what: label used in the messages ("128", "192", "256")
        --   k:    the concrete key words
        --
        -- Prints a GOOD line on success. On a length mismatch or a content mismatch it
        -- calls 'error', in the latter case with a per-round-key comparison table.
        goTestInvKey :: String -> [SWord 32] -> IO ()
        goTestInvKey what k = do
          let nk = length k         -- number of 32-bit words in the key
              nr = nk + 6           -- number of rounds

              -- How many words of 'required' to hand to 'chop4' so that the nk seed words
              -- are covered by whole 4-word groups (8 for both 192 and 256 bits).
              -- NOTE(review): presumably chop4 expects multiples of 4; confirm.
              feed = case nk of
                       4 -> 4
                       _ -> 8

              ((f, m, l), _) = aesKeySchedule k

              -- Forward round keys, rearranged into decryption order
              required = l ++ concat (reverse m) ++ f

              -- Seed for the inverse expansion: first nk words, group-reversed
              invSeed  = take nk (concatMap reverse (chop4 (take feed required)))
              obtained = concat (take (nr + 1) (invKeyExpansion nk invSeed))

              expected = map (fromJust . unliteral) required
              result   = map (fromJust . unliteral) obtained

              -- One numbered row of the comparison; both sides are shown only on a mismatch
              sh i a b
               | a == b = pad ++ show i ++ " " ++ disp a
               | True   = pad ++ show i ++ " " ++ disp a ++ " |vs| " ++ disp b
               where pad = if i < 10 then " " else ""

              disp = unwords . map (hex8 . literal)

              -- Full comparison table, headed by the given verdict
              comparison verdict = unlines $ ("Size " ++ what ++ verdict)
                                           : zipWith3 sh [(0::Int)..] (chop4 expected) (chop4 result)

              lexpected = length expected
              lresult   = length result

              -- Flip to True to print the full comparison table even when the check passes
              debugging = False

          when (lexpected /= lresult) $
             error $ what ++ ": BAD! Mismatching lengths: " ++ show (lexpected, lresult)

          if expected == result
             then if debugging
                     then putStrLn $ comparison ": Good"
                     else putStrLn $ "GOOD: Key generation AES" ++ what
             else error $ comparison ": BAD!"

        -- 'roundTrip' for a 128-bit key: a 4-word plain-text block and a 4-word key,
        -- taken as tuples and passed on as lists.
        roundTrip128 :: (SWord32, SWord32, SWord32, SWord32)
                     -> (SWord32, SWord32, SWord32, SWord32)
                     -> SBool
        roundTrip128 (i0, i1, i2, i3) (k0, k1, k2, k3) = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3]
        -- 'roundTrip' for a 192-bit key: a 4-word plain-text block and a 6-word key,
        -- taken as tuples and passed on as lists.
        roundTrip192 :: (SWord32, SWord32, SWord32, SWord32)
                     -> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32)
                     -> SBool
        roundTrip192 (i0, i1, i2, i3) (k0, k1, k2, k3, k4, k5) = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3, k4, k5]
        -- 'roundTrip' for a 256-bit key: a 4-word plain-text block and an 8-word key,
        -- taken as tuples and passed on as lists.
        roundTrip256 :: (SWord32, SWord32, SWord32, SWord32)
                     -> (SWord32, SWord32, SWord32, SWord32, SWord32, SWord32, SWord32, SWord32)
                     -> SBool
        roundTrip256 (i0, i1, i2, i3) (k0, k1, k2, k3, k4, k5, k6, k7) = roundTrip [i0, i1, i2, i3] [k0, k1, k2, k3, k4, k5, k6, k7]

        -- Round-trip property: encrypt ptIn under keyIn, then decrypt it in two ways, and
        -- require that both give ptIn back:
        --   * pt'  uses 'aesDecrypt' with the decryption schedule from 'aesKeySchedule';
        --   * pt'' uses 'aesDecryptUnwoundKey' with the schedule that 'aesInvKeySchedule'
        --     rebuilds from 'extractFinalKey' of the key.
        roundTrip :: [SWord32] -> [SWord32] -> SBool
        roundTrip ptIn keyIn = pt .== pt' .&& pt .== pt''
           where -- The AES functions work on SWord 32; convert from SWord32
                 pt, key :: [SWord 32]
                 pt  = map toSized ptIn
                 key = map toSized keyIn

                 (encKS, decKS) = aesKeySchedule key

                 ct, pt', pt'' :: [SWord 32]
                 ct   = aesEncrypt pt encKS
                 pt'  = aesDecrypt ct decKS
                 pt'' = aesDecryptUnwoundKey ct (aesInvKeySchedule (extractFinalKey key))

-----------------------------------------------------------------------------
-- * Verification
-- ${verifIntro}
-----------------------------------------------------------------------------
{- $verifIntro
  While SMT based technologies can prove many small properties correct fairly quickly, it would
  be naive to expect them to automatically verify that our AES implementation is correct. (By correct,
  we mean that encryption followed by decryption yields the original plain-text.) However, we can state
  this property precisely using SBV, and use QuickCheck to gain some confidence.
-}

-- | Correctness theorem for 128-bit AES: encrypting a block and then decrypting the result
-- with the matching decryption key schedule yields the original block. Ideally, we would run:
--
-- @
--   prove aes128IsCorrect
-- @
--
-- to get a proof automatically. Unfortunately, while SBV will successfully generate the proof
-- obligation for this theorem and ship it to the SMT solver, it would be naive to expect the SMT-solver
-- to finish that proof in any reasonable time with the currently available SMT solving technologies.
-- Instead, we can issue:
--
-- @
--   quickCheck aes128IsCorrect
-- @
--
-- and get some degree of confidence in our code. Similar predicates can easily be constructed for the
-- 192- and 256-bit cases as well.
aes128IsCorrect :: (SWord 32, SWord 32, SWord 32, SWord 32)  -- ^ plain-text words
                -> (SWord 32, SWord 32, SWord 32, SWord 32)  -- ^ key-words
                -> SBool                                     -- ^ True if round-trip gives us plain-text back
aes128IsCorrect (i0, i1, i2, i3) (k0, k1, k2, k3) = pt .== pt'
   where pt, key :: [SWord 32]
         pt  = [i0, i1, i2, i3]
         key = [k0, k1, k2, k3]

         -- Both schedules come from the same key
         (encKS, decKS) = aesKeySchedule key

         ct, pt' :: [SWord 32]
         ct  = aesEncrypt pt encKS
         pt' = aesDecrypt ct decKS

-----------------------------------------------------------------------------
-- * Code generation
-- ${codeGenIntro}
-----------------------------------------------------------------------------
{- $codeGenIntro
   We have emphasized that our T-Box implementation in Haskell was guided by clarity and correctness, not
   performance. Indeed, our implementation is hardly the fastest AES implementation in Haskell. However,
   we can use it to automatically generate straight-line C-code that can run fairly fast.

   For the purposes of illustration, we only show here how to generate code for a 128-bit AES block-encrypt
   function, that takes 8 32-bit words as an argument. The first 4 are the 128-bit input, and the final
   four are the 128-bit key. The impact of this is that the generated function would expand the key for
   each block of encryption, a needless task unless we change the key in every block. In a more serious application,
   we would instead generate code for both the 'aesKeySchedule' and the 'aesEncrypt' functions, thus reusing the
   key-schedule over many applications of the encryption call. (Unfortunately doing this is rather cumbersome right
   now, since Haskell does not support fixed-size lists.)
-}

-- | Code generation for 128-bit AES encryption.
--
-- The generated C function @aes128BlockEncrypt@ takes two input arrays, @pt@ (4 plain-text words) and
-- @key@ (4 key words). It produces the output array @ct@ (4 cipher-text words). The key is expanded
-- inside the function on every call.
--
-- The following sample from the generated code-lines show how T-Boxes are rendered as C arrays:
--
-- @
--   static const SWord32 table1[] = {
--       0xc66363a5UL, 0xf87c7c84UL, 0xee777799UL, 0xf67b7b8dUL,
--       0xfff2f20dUL, 0xd66b6bbdUL, 0xde6f6fb1UL, 0x91c5c554UL,
--       0x60303050UL, 0x02010103UL, 0xce6767a9UL, 0x562b2b7dUL,
--       0xe7fefe19UL, 0xb5d7d762UL, 0x4dababe6UL, 0xec76769aUL,
--       ...
--       }
-- @
--
-- The generated program has 5 tables (one sbox table, and 4-Tboxes), all converted to fast C arrays. Here
-- is a sample of the generated straightline C-code:
--
-- @
--   const SWord8  s1915 = (SWord8) s1912;
--   const SWord8  s1916 = table0[s1915];
--   const SWord16 s1917 = (((SWord16) s1914) << 8) | ((SWord16) s1916);
--   const SWord32 s1918 = (((SWord32) s1911) << 16) | ((SWord32) s1917);
--   const SWord32 s1919 = s1844 ^ s1918;
--   const SWord32 s1920 = s1903 ^ s1919;
-- @
--
-- The GNU C-compiler does a fine job of optimizing this straightline code to generate a fairly efficient C implementation.
cgAES128BlockEncrypt :: IO ()
cgAES128BlockEncrypt = compileToC Nothing "aes128BlockEncrypt" $ do
        pt  <- cgInputArr 4 "pt"        -- plain-text as an array of 4 Word32's
        key <- cgInputArr 4 "key"       -- key as an array of 4 Word32s

        -- Use the test values from Appendix C.1 of the AES standard as the driver values.
        -- The driver inputs are listed in the same order as the inputs above: pt, then key.
        cgSetDriverValues $ map (fromIntegral . fromJust . unliteral) $ commonPT ++ aes128Key

        -- Encryption only needs the encryption half of the key schedule
        let (encKs, _) = aesKeySchedule key
        cgOutputArr "ct" $ aesEncrypt pt encKs

-----------------------------------------------------------------------------
-- * C-library generation
-- ${libraryIntro}
-----------------------------------------------------------------------------
{- $libraryIntro
   The 'cgAES128BlockEncrypt' example shows how to generate code for 128-bit AES encryption. Since the generated
   function encrypts a single block and takes the raw key as input, it must expand the key on every call. This is
   not quite practical: We would like to expand the key only once, and encrypt the stream of plain-text blocks using
   the same expanded key (potentially using some crypto-mode), until we decide to change the key. In this
   section, we show how to use SBV to instead generate a library of functions that can be used in such a scenario.
   The generated library is a typical @.a@ archive, that can be linked using the C-compiler as usual.
-}

-- | Components of the AES implementation that the library is generated from. For each case, we provide
-- the driver values from the AES test-vectors.
--
-- For a key size @sz@ (128, 192, or 256; any other size is an 'error' once a driver value or a
-- code-generation action is demanded) the result lists these functions, in order:
--
--   * @aesNNNKeySchedule@
--   * @aesNNNBlockEncrypt@
--   * @aesNNNBlockDecrypt@
--   * @aesNNNInvKeySchedule@
--   * @aesNNNOTFDecrypt@
--
-- Here @NNN@ is @sz@. Each function is paired with its driver values and its code-generation action.
aesLibComponents :: Int -> [(String, [Integer], SBVCodeGen ())]
aesLibComponents sz = [ (nmPrefix ++ "KeySchedule",    keyDriverVals,    keySchedule)
                      , (nmPrefix ++ "BlockEncrypt",   encDriverVals,    enc)
                      , (nmPrefix ++ "BlockDecrypt",   decDriverVals,    dec)
                      , (nmPrefix ++ "InvKeySchedule", invKeyDriverVals, invKeySchedule)
                      , (nmPrefix ++ "OTFDecrypt",     invDecDriverVals, otfDec)
                      ]
  where nmPrefix :: String
        nmPrefix = "aes" ++ show sz

        badSize :: a
        badSize = error $ "aesLibComponents: Size must be one of 128, 192, or 256; received: " ++ show sz

        -- key-schedule: number of 32-bit words in the key
        nk :: Int
        nk
         | sz == 128 = 4
         | sz == 192 = 6
         | sz == 256 = 8
         | True      = badSize

        -- We get 4*(nr+1) keys, where nr = nk + 6
        nr, xk :: Int
        nr = nk + 6
        xk = 4 * (nr + 1)

        -- Driver values (from the test vectors) for each generated function. Each list follows
        -- the order in which that function declares its inputs.
        (keyDriverVals, invKeyDriverVals, encDriverVals, decDriverVals, invDecDriverVals)
           | sz == 128 = drivers aes128Key aes128InvKey aes128CT
           | sz == 192 = drivers aes192Key aes192InvKey aes192CT
           | sz == 256 = drivers aes256Key aes256InvKey aes256CT
           | True      = badSize
           where drivers :: [SWord 32] -> [SWord 32] -> [SWord 32] -> ([Integer], [Integer], [Integer], [Integer], [Integer])
                 drivers key invKey ct = ( -- 'keySchedule' feeds its input straight to 'aesKeySchedule',
                                           -- so its driver is the test-vector key as-is (as for the
                                           -- encryption driver below).
                                           map cvt key
                                           -- 'invKeySchedule' undoes a 'swizzle' on its input, so
                                           -- pre-apply it here.
                                         , map cvt (swizzle invKey)
                                         , map cvt (commonPT ++ ksToXKey encKS)
                                         , map cvt (ct       ++ ksToXKey decKS)
                                         , map cvt (ct       ++ ksToXKey (aesInvKeySchedule invKey))
                                         )
                   where (encKS, decKS) = aesKeySchedule key

        -- Concrete integer value of a literal word, for use as a driver value
        cvt :: SWord 32 -> Integer
        cvt = fromIntegral . fromJust . unliteral

        -- Reverse the word order within each 4-word group. It is applied to the input of the
        -- generated InvKeySchedule function before 'aesInvKeySchedule' sees it.
        swizzle :: [SWord 32] -> [SWord 32]
        swizzle = concatMap reverse . chop4

        -- Key expansion: nk key words in; both expanded key schedules (xk words each) out
        keySchedule :: SBVCodeGen ()
        keySchedule = do key <- cgInputArr nk "key"     -- key
                         let (encKS, decKS) = aesKeySchedule key
                         cgOutputArr "encKS" (ksToXKey encKS)
                         cgOutputArr "decKS" (ksToXKey decKS)

        -- Inverse key expansion: nk key words in (swizzled); the schedule for on-the-fly decryption out
        invKeySchedule :: SBVCodeGen ()
        invKeySchedule = do key <- cgInputArr nk "key"     -- key
                            cgOutputArr "decKS" (ksToXKey (aesInvKeySchedule (swizzle key)))

        -- encryption
        enc :: SBVCodeGen ()
        enc = do pt   <- cgInputArr 4  "pt"    -- plain-text
                 xkey <- cgInputArr xk "xkey"  -- expanded key
                 cgOutputArr "ct" $ aesEncrypt pt (xkeyToKS xkey)

        -- decryption
        dec :: SBVCodeGen ()
        dec = do ct   <- cgInputArr 4  "ct"    -- cipher-text
                 xkey <- cgInputArr xk "xkey"  -- expanded key
                 cgOutputArr "pt" $ aesDecrypt ct (xkeyToKS xkey)

        -- on-the-fly decryption
        otfDec :: SBVCodeGen ()
        otfDec = do ct   <- cgInputArr 4  "ct"    -- cipher-text
                    xkey <- cgInputArr xk "xkey"  -- expanded key
                    cgOutputArr "pt" $ aesDecryptUnwoundKey ct (xkeyToKS xkey)

        -- Transforming back and forth from our KS type to a flat array used by the generated C code.
        -- Turn a series of xk expanded key words into our internal KS type.
        xkeyToKS :: [SWord 32] -> KS
        xkeyToKS xs = (f, m, l)
           where f = take 4 xs                              -- first round key
                 m = chop4 (take (xk - 8) (drop 4 xs))     -- middle rounds
                 l = drop (xk - 4) xs                       -- last round key

        -- Turn a KS into a series of expanded key words: first, middle, then last round keys
        ksToXKey :: KS -> [SWord 32]
        ksToXKey (f, m, l) = f ++ concat m ++ l

-- | Generate code for AES functionality; given the key size.
--
-- Every component from 'aesLibComponents' for the given key size (128, 192, or 256) is compiled
-- into one C library named @aesNNNLib@, where @NNN@ is the key size. Each function uses its own
-- driver values. The 'Maybe' 'FilePath' argument is passed unchanged to 'compileToCLib' (see its
-- documentation for where output goes). Any other key size is an 'error'.
cgAESLibrary :: Int -> Maybe FilePath -> IO ()
cgAESLibrary sz mbd
  | sz `elem` [128, 192, 256] = void $ compileToCLib mbd nm [(fnm, configure dvals f) | (fnm, dvals, f) <- aesLibComponents sz]
  | True                      = error $ "cgAESLibrary: Size must be one of 128, 192, or 256, received: " ++ show sz
  where nm :: String
        nm = "aes" ++ show sz ++ "Lib"

        -- Install a component's driver values before running its code-generation action
        configure :: [Integer] -> SBVCodeGen b -> SBVCodeGen b
        configure dvals code = cgSetDriverValues dvals >> code

-- | Generate a C library containing functions for 128-bit AES encryption, decryption,
-- and key expansion. This is simply @'cgAESLibrary' 128 Nothing@.
--
-- A note on performance: In a very rough speed test, the generated code was able to do
-- 6.3 million block encryptions per second on a decent MacBook Pro. On the same machine, OpenSSL
-- reports 8.2 million block encryptions per second. So, the generated code is about 25% slower
-- as compared to the highly optimized OpenSSL implementation. (Note that the speed test was done
-- somewhat simplistically, so these numbers should be considered very rough estimates.)
cgAES128Library :: IO ()
cgAES128Library = cgAESLibrary 128 Nothing

--------------------------------------------------------------------------------------------
-- | Render a concrete value as hexadecimal, left-padded with zeros to at least
-- 8 digits. For doctest purposes only.
--
-- The argument must be a literal: 'unliteral' is unwrapped with 'fromJust', so a
-- symbolic argument raises an error. A hex string longer than 8 digits is
-- returned as is, neither padded nor truncated.
hex8 :: (SymVal a, Show a, Integral a) => SBV a -> String
hex8 v = replicate (8 - length s) '0' ++ s
  where s = showHex (fromJust (unliteral v)) ""

-- | Split a list into consecutive chunks of four elements.
--
-- If the length is not a multiple of four, the final chunk holds the remaining
-- one to three elements. The empty list yields no chunks. This is @chunksOf 4@
-- from the @split@ package; base provides no such function, so we define it here.
chop4 :: [a] -> [[a]]
chop4 [] = []
chop4 xs = chunk : chop4 rest
  where (chunk, rest) = splitAt 4 xs

{- HLint ignore aesRound             "Use head" -}
{- HLint ignore aesInvRound          "Use head" -}
{- HLint ignore aesDecryptUnwoundKey "Use head" -}