-- | Advanced Encryption Standard (specification can be found in FIPS-197)
module Codec.Encryption.AESAux(
	aes128Encrypt,
	aes192Encrypt,
	aes256Encrypt,
	aes128Decrypt,
	aes192Decrypt,
	aes256Decrypt,
) where

import Data.Bits
import Data.Int(Int)
import Data.Word(Word32)

import Codec.Utils(Octet)

-- | Encrypt a single 16-octet block with AES-128 (Nr = 10 rounds,
-- Nk = 4 key words).
aes128Encrypt :: [Octet] -- ^ key (16 octets)
              -> [Octet] -- ^ msg (16 octets)
              -> [Octet] -- ^ enciphered msg (16 octets)
aes128Encrypt key msg = aesEncrypt 10 4 key msg

-- | Encrypt a single 16-octet block with AES-192 (Nr = 12 rounds,
-- Nk = 6 key words).
aes192Encrypt :: [Octet] -- ^ key (24 octets)
              -> [Octet] -- ^ msg (16 octets)
              -> [Octet] -- ^ enciphered msg (16 octets)
aes192Encrypt key msg = aesEncrypt 12 6 key msg

-- | Encrypt a single 16-octet block with AES-256 (Nr = 14 rounds,
-- Nk = 8 key words).
aes256Encrypt :: [Octet] -- ^ key (32 octets)
              -> [Octet] -- ^ msg (16 octets)
              -> [Octet] -- ^ enciphered msg (16 octets)
aes256Encrypt key msg = aesEncrypt 14 8 key msg

-- | Decrypt a single 16-octet block with AES-128 (Nr = 10 rounds,
-- Nk = 4 key words).
aes128Decrypt :: [Octet] -- ^ key (16 octets)
              -> [Octet] -- ^ enciphered msg (16 octets)
              -> [Octet] -- ^ deciphered msg (16 octets)
aes128Decrypt key msg = aesDecrypt 10 4 key msg

-- | Decrypt a single 16-octet block with AES-192 (Nr = 12 rounds,
-- Nk = 6 key words).
aes192Decrypt :: [Octet] -- ^ key (24 octets)
              -> [Octet] -- ^ enciphered msg (16 octets)
              -> [Octet] -- ^ deciphered msg (16 octets)
aes192Decrypt key msg = aesDecrypt 12 6 key msg

-- | Decrypt a single 16-octet block with AES-256 (Nr = 14 rounds,
-- Nk = 8 key words).
aes256Decrypt :: [Octet] -- ^ key (32 octets)
              -> [Octet] -- ^ enciphered msg (16 octets)
              -> [Octet] -- ^ deciphered msg (16 octets)
aes256Decrypt key msg = aesDecrypt 14 8 key msg

-- | Encrypt one 16-octet block with the AES cipher (FIPS-197, Section 5.1).
--
-- Octet k of the message becomes state byte (row k mod 4, column k div 4),
-- and the ciphertext is read back out in the same order. A message that is
-- not exactly 16 octets is rejected with a descriptive error.
aesEncrypt :: Int     -- ^ nr: number of rounds (10, 12 or 14)
           -> Int     -- ^ nk: key length in 32-bit words (4, 6 or 8)
           -> [Octet] -- ^ key (4 * nk octets)
           -> [Octet] -- ^ msg (16 octets)
           -> [Octet] -- ^ enciphered msg (16 octets)
aesEncrypt nr nk key
    [i00, i10, i20, i30,
     i01, i11, i21, i31,
     i02, i12, i22, i32,
     i03, i13, i23, i33] =
    [fo o00, fo o10, fo o20, fo o30,
     fo o01, fo o11, fo o21, fo o31,
     fo o02, fo o12, fo o22, fo o32,
     fo o03, fo o13, fo o23, fo o33]
  where
    State (o00, o01, o02, o03)
          (o10, o11, o12, o13)
          (o20, o21, o22, o23)
          (o30, o31, o32, o33) =
      transform (State (fi i00, fi i01, fi i02, fi i03)
                       (fi i10, fi i11, fi i12, fi i13)
                       (fi i20, fi i21, fi i22, fi i23)
                       (fi i30, fi i31, fi i32, fi i33))
    fi = (fromIntegral :: Octet -> Word32)
    fo = (fromIntegral :: Word32 -> Octet)
    -- One AddRoundKey transformation per round key: kt0 is the initial
    -- key addition, kts are the keys for rounds 1 .. nr.
    (kt0:kts) = genAddRoundKey (generateKeys nr nk key)
    -- Rounds 1 .. nr-1 apply SubBytes, ShiftRows, MixColumns; the final
    -- round omits MixColumns.
    rounds = replicate (nr - 1) (mixColumns . shiftRows . subBytes)
             ++ [shiftRows . subBytes]
    -- Each round ends with its own round key.
    rest = zipWith (.) kts rounds
    -- transform = r_nr . ... . r_1 . kt0, so kt0 is applied first.
    transform = foldr (.) kt0 (reverse rest)
aesEncrypt _ _ _ msg =
    error ("aesEncrypt: message must be exactly 16 octets, got "
           ++ show (length msg))
	      

-- | Decrypt one 16-octet block with the AES inverse cipher (FIPS-197,
-- Section 5.3).
--
-- The octet-to-state mapping is the same as in 'aesEncrypt': octet k is
-- state byte (row k mod 4, column k div 4). A ciphertext that is not
-- exactly 16 octets is rejected with a descriptive error.
aesDecrypt :: Int     -- ^ nr: number of rounds (10, 12 or 14)
           -> Int     -- ^ nk: key length in 32-bit words (4, 6 or 8)
           -> [Octet] -- ^ key (4 * nk octets)
           -> [Octet] -- ^ enciphered msg (16 octets)
           -> [Octet] -- ^ deciphered msg (16 octets)
aesDecrypt nr nk key
    [i00, i10, i20, i30,
     i01, i11, i21, i31,
     i02, i12, i22, i32,
     i03, i13, i23, i33] =
    [fo o00, fo o10, fo o20, fo o30,
     fo o01, fo o11, fo o21, fo o31,
     fo o02, fo o12, fo o22, fo o32,
     fo o03, fo o13, fo o23, fo o33]
  where
    State (o00, o01, o02, o03)
          (o10, o11, o12, o13)
          (o20, o21, o22, o23)
          (o30, o31, o32, o33) =
      transform (State (fi i00, fi i01, fi i02, fi i03)
                       (fi i10, fi i11, fi i12, fi i13)
                       (fi i20, fi i21, fi i22, fi i23)
                       (fi i30, fi i31, fi i32, fi i33))
    fi = (fromIntegral :: Octet -> Word32)
    fo = (fromIntegral :: Word32 -> Octet)
    -- Round keys in reverse order: kt0 adds the last round key, and the
    -- final element of kts adds the original cipher key.
    (kt0:kts) = reverse (genAddRoundKey (generateKeys nr nk key))
    -- The first step is InvShiftRows then InvSubBytes. Each later step
    -- starts with InvMixColumns (completing the previous inverse round),
    -- then InvShiftRows and InvSubBytes.
    rounds = (subBytesRev . shiftRowsRev)
             : replicate (nr - 1) (subBytesRev . shiftRowsRev . mixColumnsRev)
    -- Each step ends with its round key.
    rest = zipWith (.) kts rounds
    -- transform = r_nr . ... . r_1 . kt0, so kt0 is applied first.
    transform = foldr (.) kt0 (reverse rest)
aesDecrypt _ _ _ msg =
    error ("aesDecrypt: message must be exactly 16 octets, got "
           ++ show (length msg))
	      


-- | The AES state (FIPS-197, Section 3.4) stored as four rows of four bytes.
--
-- The first tuple is row 0 and the last is row 3; within a tuple the
-- components are columns 0 to 3. Each component holds a byte value
-- (0x00..0xff) widened to 'Word32', since 'sbox' and 'sboxRev' are only
-- defined on that range. The strictness flags force each row tuple, not
-- the tuple's components.
data State
  = State
      !(Word32, Word32, Word32, Word32) -- ^ row 0
      !(Word32, Word32, Word32, Word32) -- ^ row 1
      !(Word32, Word32, Word32, Word32) -- ^ row 2
      !(Word32, Word32, Word32, Word32) -- ^ row 3

-- | AES forward S-box (FIPS-197, Figure 7), used by SubBytes and by the
-- SubWord step of key expansion.
--
-- Only byte values 0x00..0xff have an entry; any other argument is a
-- pattern-match failure.
sbox :: Word32 -> Word32
sbox w = case w of
    0x00 -> 0x63; 0x01 -> 0x7C; 0x02 -> 0x77; 0x03 -> 0x7B; 0x04 -> 0xF2; 0x05 -> 0x6B; 0x06 -> 0x6F; 0x07 -> 0xC5
    0x08 -> 0x30; 0x09 -> 0x01; 0x0a -> 0x67; 0x0b -> 0x2B; 0x0c -> 0xFE; 0x0d -> 0xD7; 0x0e -> 0xAB; 0x0f -> 0x76
    0x10 -> 0xCA; 0x11 -> 0x82; 0x12 -> 0xC9; 0x13 -> 0x7D; 0x14 -> 0xFA; 0x15 -> 0x59; 0x16 -> 0x47; 0x17 -> 0xF0
    0x18 -> 0xAD; 0x19 -> 0xD4; 0x1a -> 0xA2; 0x1b -> 0xAF; 0x1c -> 0x9C; 0x1d -> 0xA4; 0x1e -> 0x72; 0x1f -> 0xC0
    0x20 -> 0xB7; 0x21 -> 0xFD; 0x22 -> 0x93; 0x23 -> 0x26; 0x24 -> 0x36; 0x25 -> 0x3F; 0x26 -> 0xF7; 0x27 -> 0xCC
    0x28 -> 0x34; 0x29 -> 0xA5; 0x2a -> 0xE5; 0x2b -> 0xF1; 0x2c -> 0x71; 0x2d -> 0xD8; 0x2e -> 0x31; 0x2f -> 0x15
    0x30 -> 0x04; 0x31 -> 0xC7; 0x32 -> 0x23; 0x33 -> 0xC3; 0x34 -> 0x18; 0x35 -> 0x96; 0x36 -> 0x05; 0x37 -> 0x9A
    0x38 -> 0x07; 0x39 -> 0x12; 0x3a -> 0x80; 0x3b -> 0xE2; 0x3c -> 0xEB; 0x3d -> 0x27; 0x3e -> 0xB2; 0x3f -> 0x75
    0x40 -> 0x09; 0x41 -> 0x83; 0x42 -> 0x2C; 0x43 -> 0x1A; 0x44 -> 0x1B; 0x45 -> 0x6E; 0x46 -> 0x5A; 0x47 -> 0xA0
    0x48 -> 0x52; 0x49 -> 0x3B; 0x4a -> 0xD6; 0x4b -> 0xB3; 0x4c -> 0x29; 0x4d -> 0xE3; 0x4e -> 0x2F; 0x4f -> 0x84
    0x50 -> 0x53; 0x51 -> 0xD1; 0x52 -> 0x00; 0x53 -> 0xED; 0x54 -> 0x20; 0x55 -> 0xFC; 0x56 -> 0xB1; 0x57 -> 0x5B
    0x58 -> 0x6A; 0x59 -> 0xCB; 0x5a -> 0xBE; 0x5b -> 0x39; 0x5c -> 0x4A; 0x5d -> 0x4C; 0x5e -> 0x58; 0x5f -> 0xCF
    0x60 -> 0xD0; 0x61 -> 0xEF; 0x62 -> 0xAA; 0x63 -> 0xFB; 0x64 -> 0x43; 0x65 -> 0x4D; 0x66 -> 0x33; 0x67 -> 0x85
    0x68 -> 0x45; 0x69 -> 0xF9; 0x6a -> 0x02; 0x6b -> 0x7F; 0x6c -> 0x50; 0x6d -> 0x3C; 0x6e -> 0x9F; 0x6f -> 0xA8
    0x70 -> 0x51; 0x71 -> 0xA3; 0x72 -> 0x40; 0x73 -> 0x8F; 0x74 -> 0x92; 0x75 -> 0x9D; 0x76 -> 0x38; 0x77 -> 0xF5
    0x78 -> 0xBC; 0x79 -> 0xB6; 0x7a -> 0xDA; 0x7b -> 0x21; 0x7c -> 0x10; 0x7d -> 0xFF; 0x7e -> 0xF3; 0x7f -> 0xD2
    0x80 -> 0xCD; 0x81 -> 0x0C; 0x82 -> 0x13; 0x83 -> 0xEC; 0x84 -> 0x5F; 0x85 -> 0x97; 0x86 -> 0x44; 0x87 -> 0x17
    0x88 -> 0xC4; 0x89 -> 0xA7; 0x8a -> 0x7E; 0x8b -> 0x3D; 0x8c -> 0x64; 0x8d -> 0x5D; 0x8e -> 0x19; 0x8f -> 0x73
    0x90 -> 0x60; 0x91 -> 0x81; 0x92 -> 0x4F; 0x93 -> 0xDC; 0x94 -> 0x22; 0x95 -> 0x2A; 0x96 -> 0x90; 0x97 -> 0x88
    0x98 -> 0x46; 0x99 -> 0xEE; 0x9a -> 0xB8; 0x9b -> 0x14; 0x9c -> 0xDE; 0x9d -> 0x5E; 0x9e -> 0x0B; 0x9f -> 0xDB
    0xa0 -> 0xE0; 0xa1 -> 0x32; 0xa2 -> 0x3A; 0xa3 -> 0x0A; 0xa4 -> 0x49; 0xa5 -> 0x06; 0xa6 -> 0x24; 0xa7 -> 0x5C
    0xa8 -> 0xC2; 0xa9 -> 0xD3; 0xaa -> 0xAC; 0xab -> 0x62; 0xac -> 0x91; 0xad -> 0x95; 0xae -> 0xE4; 0xaf -> 0x79
    0xb0 -> 0xE7; 0xb1 -> 0xC8; 0xb2 -> 0x37; 0xb3 -> 0x6D; 0xb4 -> 0x8D; 0xb5 -> 0xD5; 0xb6 -> 0x4E; 0xb7 -> 0xA9
    0xb8 -> 0x6C; 0xb9 -> 0x56; 0xba -> 0xF4; 0xbb -> 0xEA; 0xbc -> 0x65; 0xbd -> 0x7A; 0xbe -> 0xAE; 0xbf -> 0x08
    0xc0 -> 0xBA; 0xc1 -> 0x78; 0xc2 -> 0x25; 0xc3 -> 0x2E; 0xc4 -> 0x1C; 0xc5 -> 0xA6; 0xc6 -> 0xB4; 0xc7 -> 0xC6
    0xc8 -> 0xE8; 0xc9 -> 0xDD; 0xca -> 0x74; 0xcb -> 0x1F; 0xcc -> 0x4B; 0xcd -> 0xBD; 0xce -> 0x8B; 0xcf -> 0x8A
    0xd0 -> 0x70; 0xd1 -> 0x3E; 0xd2 -> 0xB5; 0xd3 -> 0x66; 0xd4 -> 0x48; 0xd5 -> 0x03; 0xd6 -> 0xF6; 0xd7 -> 0x0E
    0xd8 -> 0x61; 0xd9 -> 0x35; 0xda -> 0x57; 0xdb -> 0xB9; 0xdc -> 0x86; 0xdd -> 0xC1; 0xde -> 0x1D; 0xdf -> 0x9E
    0xe0 -> 0xE1; 0xe1 -> 0xF8; 0xe2 -> 0x98; 0xe3 -> 0x11; 0xe4 -> 0x69; 0xe5 -> 0xD9; 0xe6 -> 0x8E; 0xe7 -> 0x94
    0xe8 -> 0x9B; 0xe9 -> 0x1E; 0xea -> 0x87; 0xeb -> 0xE9; 0xec -> 0xCE; 0xed -> 0x55; 0xee -> 0x28; 0xef -> 0xDF
    0xf0 -> 0x8C; 0xf1 -> 0xA1; 0xf2 -> 0x89; 0xf3 -> 0x0D; 0xf4 -> 0xBF; 0xf5 -> 0xE6; 0xf6 -> 0x42; 0xf7 -> 0x68
    0xf8 -> 0x41; 0xf9 -> 0x99; 0xfa -> 0x2D; 0xfb -> 0x0F; 0xfc -> 0xB0; 0xfd -> 0x54; 0xfe -> 0xBB; 0xff -> 0x16

-- | AES inverse S-box (FIPS-197, Figure 14), used by InvSubBytes.
--
-- Each entry is the forward table read backwards (sbox i = o gives
-- sboxRev o = i), listed in the same order as the forward table. Only byte
-- values 0x00..0xff have an entry.
sboxRev :: Word32 -> Word32
sboxRev w = case w of
    0x63 -> 0x00; 0x7C -> 0x01; 0x77 -> 0x02; 0x7B -> 0x03; 0xF2 -> 0x04; 0x6B -> 0x05; 0x6F -> 0x06; 0xC5 -> 0x07
    0x30 -> 0x08; 0x01 -> 0x09; 0x67 -> 0x0a; 0x2B -> 0x0b; 0xFE -> 0x0c; 0xD7 -> 0x0d; 0xAB -> 0x0e; 0x76 -> 0x0f
    0xCA -> 0x10; 0x82 -> 0x11; 0xC9 -> 0x12; 0x7D -> 0x13; 0xFA -> 0x14; 0x59 -> 0x15; 0x47 -> 0x16; 0xF0 -> 0x17
    0xAD -> 0x18; 0xD4 -> 0x19; 0xA2 -> 0x1a; 0xAF -> 0x1b; 0x9C -> 0x1c; 0xA4 -> 0x1d; 0x72 -> 0x1e; 0xC0 -> 0x1f
    0xB7 -> 0x20; 0xFD -> 0x21; 0x93 -> 0x22; 0x26 -> 0x23; 0x36 -> 0x24; 0x3F -> 0x25; 0xF7 -> 0x26; 0xCC -> 0x27
    0x34 -> 0x28; 0xA5 -> 0x29; 0xE5 -> 0x2a; 0xF1 -> 0x2b; 0x71 -> 0x2c; 0xD8 -> 0x2d; 0x31 -> 0x2e; 0x15 -> 0x2f
    0x04 -> 0x30; 0xC7 -> 0x31; 0x23 -> 0x32; 0xC3 -> 0x33; 0x18 -> 0x34; 0x96 -> 0x35; 0x05 -> 0x36; 0x9A -> 0x37
    0x07 -> 0x38; 0x12 -> 0x39; 0x80 -> 0x3a; 0xE2 -> 0x3b; 0xEB -> 0x3c; 0x27 -> 0x3d; 0xB2 -> 0x3e; 0x75 -> 0x3f
    0x09 -> 0x40; 0x83 -> 0x41; 0x2C -> 0x42; 0x1A -> 0x43; 0x1B -> 0x44; 0x6E -> 0x45; 0x5A -> 0x46; 0xA0 -> 0x47
    0x52 -> 0x48; 0x3B -> 0x49; 0xD6 -> 0x4a; 0xB3 -> 0x4b; 0x29 -> 0x4c; 0xE3 -> 0x4d; 0x2F -> 0x4e; 0x84 -> 0x4f
    0x53 -> 0x50; 0xD1 -> 0x51; 0x00 -> 0x52; 0xED -> 0x53; 0x20 -> 0x54; 0xFC -> 0x55; 0xB1 -> 0x56; 0x5B -> 0x57
    0x6A -> 0x58; 0xCB -> 0x59; 0xBE -> 0x5a; 0x39 -> 0x5b; 0x4A -> 0x5c; 0x4C -> 0x5d; 0x58 -> 0x5e; 0xCF -> 0x5f
    0xD0 -> 0x60; 0xEF -> 0x61; 0xAA -> 0x62; 0xFB -> 0x63; 0x43 -> 0x64; 0x4D -> 0x65; 0x33 -> 0x66; 0x85 -> 0x67
    0x45 -> 0x68; 0xF9 -> 0x69; 0x02 -> 0x6a; 0x7F -> 0x6b; 0x50 -> 0x6c; 0x3C -> 0x6d; 0x9F -> 0x6e; 0xA8 -> 0x6f
    0x51 -> 0x70; 0xA3 -> 0x71; 0x40 -> 0x72; 0x8F -> 0x73; 0x92 -> 0x74; 0x9D -> 0x75; 0x38 -> 0x76; 0xF5 -> 0x77
    0xBC -> 0x78; 0xB6 -> 0x79; 0xDA -> 0x7a; 0x21 -> 0x7b; 0x10 -> 0x7c; 0xFF -> 0x7d; 0xF3 -> 0x7e; 0xD2 -> 0x7f
    0xCD -> 0x80; 0x0C -> 0x81; 0x13 -> 0x82; 0xEC -> 0x83; 0x5F -> 0x84; 0x97 -> 0x85; 0x44 -> 0x86; 0x17 -> 0x87
    0xC4 -> 0x88; 0xA7 -> 0x89; 0x7E -> 0x8a; 0x3D -> 0x8b; 0x64 -> 0x8c; 0x5D -> 0x8d; 0x19 -> 0x8e; 0x73 -> 0x8f
    0x60 -> 0x90; 0x81 -> 0x91; 0x4F -> 0x92; 0xDC -> 0x93; 0x22 -> 0x94; 0x2A -> 0x95; 0x90 -> 0x96; 0x88 -> 0x97
    0x46 -> 0x98; 0xEE -> 0x99; 0xB8 -> 0x9a; 0x14 -> 0x9b; 0xDE -> 0x9c; 0x5E -> 0x9d; 0x0B -> 0x9e; 0xDB -> 0x9f
    0xE0 -> 0xa0; 0x32 -> 0xa1; 0x3A -> 0xa2; 0x0A -> 0xa3; 0x49 -> 0xa4; 0x06 -> 0xa5; 0x24 -> 0xa6; 0x5C -> 0xa7
    0xC2 -> 0xa8; 0xD3 -> 0xa9; 0xAC -> 0xaa; 0x62 -> 0xab; 0x91 -> 0xac; 0x95 -> 0xad; 0xE4 -> 0xae; 0x79 -> 0xaf
    0xE7 -> 0xb0; 0xC8 -> 0xb1; 0x37 -> 0xb2; 0x6D -> 0xb3; 0x8D -> 0xb4; 0xD5 -> 0xb5; 0x4E -> 0xb6; 0xA9 -> 0xb7
    0x6C -> 0xb8; 0x56 -> 0xb9; 0xF4 -> 0xba; 0xEA -> 0xbb; 0x65 -> 0xbc; 0x7A -> 0xbd; 0xAE -> 0xbe; 0x08 -> 0xbf
    0xBA -> 0xc0; 0x78 -> 0xc1; 0x25 -> 0xc2; 0x2E -> 0xc3; 0x1C -> 0xc4; 0xA6 -> 0xc5; 0xB4 -> 0xc6; 0xC6 -> 0xc7
    0xE8 -> 0xc8; 0xDD -> 0xc9; 0x74 -> 0xca; 0x1F -> 0xcb; 0x4B -> 0xcc; 0xBD -> 0xcd; 0x8B -> 0xce; 0x8A -> 0xcf
    0x70 -> 0xd0; 0x3E -> 0xd1; 0xB5 -> 0xd2; 0x66 -> 0xd3; 0x48 -> 0xd4; 0x03 -> 0xd5; 0xF6 -> 0xd6; 0x0E -> 0xd7
    0x61 -> 0xd8; 0x35 -> 0xd9; 0x57 -> 0xda; 0xB9 -> 0xdb; 0x86 -> 0xdc; 0xC1 -> 0xdd; 0x1D -> 0xde; 0x9E -> 0xdf
    0xE1 -> 0xe0; 0xF8 -> 0xe1; 0x98 -> 0xe2; 0x11 -> 0xe3; 0x69 -> 0xe4; 0xD9 -> 0xe5; 0x8E -> 0xe6; 0x94 -> 0xe7
    0x9B -> 0xe8; 0x1E -> 0xe9; 0x87 -> 0xea; 0xE9 -> 0xeb; 0xCE -> 0xec; 0x55 -> 0xed; 0x28 -> 0xee; 0xDF -> 0xef
    0x8C -> 0xf0; 0xA1 -> 0xf1; 0x89 -> 0xf2; 0x0D -> 0xf3; 0xBF -> 0xf4; 0xE6 -> 0xf5; 0x42 -> 0xf6; 0x68 -> 0xf7
    0x41 -> 0xf8; 0x99 -> 0xf9; 0x2D -> 0xfa; 0x0F -> 0xfb; 0xB0 -> 0xfc; 0x54 -> 0xfd; 0xBB -> 0xfe; 0x16 -> 0xff

-- | Multiply a GF(2^8) element by x ({02}) (FIPS-197, Section 4.2.1).
--
-- The argument is shifted left one bit. If that sets bit 8, the result is
-- reduced by XOR with the AES polynomial m(x) = 0x11b.
xtime :: Word32 -> Word32
xtime x
  | testBit doubled 8 = doubled `xor` 0x11b
  | otherwise         = doubled
  where
    doubled = x `shiftL` 1

-- | Multiply a GF(2^8) element by x^2 ({04}).
--
-- For byte inputs this equals @xtime . xtime@, but it uses a single shift.
-- The value is shifted left two bits. Bit 9 is then cleared with the
-- shifted polynomial 0x236 (0x11b << 1), and bit 8 with 0x11b.
xtimeX2 :: Word32 -> Word32
xtimeX2 x = reduceAt 8 0x11b (reduceAt 9 0x236 (x `shiftL` 2))
  where
    -- XOR in the given polynomial only when the given bit is set.
    reduceAt bit poly v
      | testBit v bit = v `xor` poly
      | otherwise     = v

-- | Multiply a GF(2^8) element by x^3 ({08}).
--
-- For byte inputs this equals @xtime . xtime . xtime@. The value is shifted
-- left three bits. Bits 10, 9 and 8 are then cleared in turn with 0x46c
-- (0x11b << 2), 0x236 (0x11b << 1) and 0x11b.
xtimeX3 :: Word32 -> Word32
xtimeX3 x = reduceAt 8 0x11b (reduceAt 9 0x236 (reduceAt 10 0x46c (x `shiftL` 3)))
  where
    -- XOR in the given polynomial only when the given bit is set.
    reduceAt bit poly v
      | testBit v bit = v `xor` poly
      | otherwise     = v

-- | Multiply by {03}: x + {02}x. XOR is addition in GF(2^8).
xtime03 :: Word32 -> Word32
xtime03 x = xtime x `xor` x

-- | Multiply by {0e}: {02} * (x + {02} * {03}x) = {02} * {07}x.
xtime0e :: Word32 -> Word32
xtime0e x = xtime times7
  where
    times3 = x `xor` xtime x
    times7 = x `xor` xtime times3

-- | Multiply by {09}: x + {08}x.
xtime09 :: Word32 -> Word32
xtime09 x = xtimeX3 x `xor` x

-- | Multiply by {0d}: x + {04} * {03}x.
xtime0d :: Word32 -> Word32
xtime0d x = x `xor` xtimeX2 times3
  where
    times3 = x `xor` xtime x

-- | Multiply by {0b}: x + {02} * {05}x.
xtime0b :: Word32 -> Word32
xtime0b x = x `xor` xtime times5
  where
    times5 = x `xor` xtimeX2 x

-- | Compute word i of the expanded key (FIPS-197, Section 5.2).
--
-- Arguments: Nk, the word index i, w[i-1] and w[i-Nk]. When i is a multiple
-- of Nk, w[i-1] goes through RotWord, SubWord and an XOR with Rcon. When
-- Nk > 6 and i mod Nk == 4, it goes through SubWord only. Otherwise it is
-- used unchanged. The result is that value XORed with w[i-Nk].
generateKey :: Int -> Int -> Word32 -> Word32 -> Word32
generateKey nk i wPrev wBack = transformed `xor` wBack
  where
    phase = i `mod` nk
    transformed
      | phase == 0           = subWord (rotWord wPrev) `xor` rcon
      | nk > 6 && phase == 4 = subWord wPrev
      | otherwise            = wPrev
    -- SubWord: apply the S-box to each of the four bytes of the word.
    subWord :: Word32 -> Word32
    subWord w = subAt 24 .|. subAt 16 .|. subAt 8 .|. subAt 0
      where
        subAt sh = sbox ((w `shiftR` sh) .&. 0xff) `shiftL` sh
    -- RotWord: cyclic byte rotation, [a0,a1,a2,a3] -> [a1,a2,a3,a0].
    rotWord :: Word32 -> Word32
    rotWord w = w `rotateL` 8
    -- Rcon[i/Nk]: {02}^(i/Nk - 1) in GF(2^8), placed in the top byte.
    rcon :: Word32
    rcon = (iterate xtime 0x01 !! (i `div` nk - 1)) `shiftL` 24
	
-- | Split an octet list into 32-bit words, four octets at a time, with the
-- first octet of each group in the most significant byte (see 'getWord32').
-- The list length must be a multiple of four.
wordify :: [Octet] -> [Word32]
wordify octets
  | null octets = []
  | otherwise   = let (w, remaining) = getWord32 octets
                  in w : wordify remaining

-- | Expand a cipher key into the 4 * (nr + 1) words of the key schedule
-- (FIPS-197, Section 5.2).
--
-- The key must be exactly 4 * nk octets, and any other length is rejected
-- with an explicit error. A longer key would shift the recurrence out of
-- alignment with the word index and silently produce a wrong schedule. A
-- shorter key would leave the self-referential schedule depending on words
-- it has not produced (a <<loop>>) or on an incomplete final word.
generateKeys :: Int -> Int -> [Octet] -> [Word32]
generateKeys nr nk mainKey
  | keyLen /= 4 * nk =
      error ("generateKeys: expected a key of " ++ show (4 * nk)
             ++ " octets, got " ++ show keyLen)
  | otherwise = take (4 * (nr + 1)) schedule
  where
    keyLen = length mainKey
    -- w[i] = generateKey nk i w[i-1] w[i-nk] for i >= nk. The list is
    -- defined in terms of itself, and 'take' demands only the needed prefix.
    schedule = wordify mainKey
               ++ zipWith3 (generateKey nk)
                           [nk ..]
                           (drop (nk - 1) schedule)
                           schedule

-- | SubBytes (FIPS-197, Section 5.1.1): pass every state byte through 'sbox'.
subBytes :: State -> State
subBytes (State r0 r1 r2 r3) = State (row r0) (row r1) (row r2) (row r3)
  where
    row (a, b, c, d) = (sbox a, sbox b, sbox c, sbox d)

-- | InvSubBytes (FIPS-197, Section 5.3.2): pass every state byte through
-- 'sboxRev'.
subBytesRev :: State -> State
subBytesRev (State r0 r1 r2 r3) = State (row r0) (row r1) (row r2) (row r3)
  where
    row (a, b, c, d) = (sboxRev a, sboxRev b, sboxRev c, sboxRev d)

-- | ShiftRows (FIPS-197, Section 5.1.2): row r is rotated left by r
-- positions, so row 0 is unchanged.
shiftRows :: State -> State
shiftRows (State r0 r1 r2 r3) =
    State r0 (rotl r1) (rotl (rotl r2)) (rotl (rotl (rotl r3)))
  where
    rotl (a, b, c, d) = (b, c, d, a)

-- | InvShiftRows (FIPS-197, Section 5.3.1): row r is rotated right by r
-- positions, undoing 'shiftRows'.
shiftRowsRev :: State -> State
shiftRowsRev (State r0 r1 r2 r3) =
    State r0 (rotr r1) (rotr (rotr r2)) (rotr (rotr (rotr r3)))
  where
    rotr (a, b, c, d) = (d, a, b, c)

-- | MixColumns for one column (FIPS-197, Section 5.1.3). The column
-- (s0, s1, s2, s3) is multiplied by the circulant matrix whose first row is
-- {02} {03} {01} {01}.
mixColumn :: (Word32, Word32, Word32, Word32) -> (Word32, Word32, Word32, Word32)
mixColumn (a, b, c, d) = (r0, r1, r2, r3)
  where
    r0 = xtime a   `xor` xtime03 b `xor` c         `xor` d
    r1 = a         `xor` xtime b   `xor` xtime03 c `xor` d
    r2 = a         `xor` b         `xor` xtime c   `xor` xtime03 d
    r3 = xtime03 a `xor` b         `xor` c         `xor` xtime d

-- | MixColumns (FIPS-197, Section 5.1.3): apply 'mixColumn' to each of the
-- four state columns. Column j is the j-th component of every row.
mixColumns :: State -> State
mixColumns (State (a0, a1, a2, a3)
                  (b0, b1, b2, b3)
                  (c0, c1, c2, c3)
                  (d0, d1, d2, d3)) =
    State (a0', a1', a2', a3')
          (b0', b1', b2', b3')
          (c0', c1', c2', c3')
          (d0', d1', d2', d3')
  where
    (a0', b0', c0', d0') = mixColumn (a0, b0, c0, d0)
    (a1', b1', c1', d1') = mixColumn (a1, b1, c1, d1)
    (a2', b2', c2', d2') = mixColumn (a2, b2, c2, d2)
    (a3', b3', c3', d3') = mixColumn (a3, b3, c3, d3)

-- | InvMixColumns for one column (FIPS-197, Section 5.3.3). The column is
-- multiplied by the circulant matrix whose first row is {0e} {0b} {0d} {09}.
mixColumnRev :: (Word32, Word32, Word32, Word32)
             -> (Word32, Word32, Word32, Word32)
mixColumnRev (a, b, c, d) = (r0, r1, r2, r3)
  where
    r0 = xtime0e a `xor` xtime0b b `xor` xtime0d c `xor` xtime09 d
    r1 = xtime09 a `xor` xtime0e b `xor` xtime0b c `xor` xtime0d d
    r2 = xtime0d a `xor` xtime09 b `xor` xtime0e c `xor` xtime0b d
    r3 = xtime0b a `xor` xtime0d b `xor` xtime09 c `xor` xtime0e d

-- | InvMixColumns (FIPS-197, Section 5.3.3): apply 'mixColumnRev' to each
-- of the four state columns. Column j is the j-th component of every row.
mixColumnsRev :: State -> State
mixColumnsRev (State (a0, a1, a2, a3)
                     (b0, b1, b2, b3)
                     (c0, c1, c2, c3)
                     (d0, d1, d2, d3)) =
    State (a0', a1', a2', a3')
          (b0', b1', b2', b3')
          (c0', c1', c2', c3')
          (d0', d1', d2', d3')
  where
    (a0', b0', c0', d0') = mixColumnRev (a0, b0, c0, d0)
    (a1', b1', c1', d1') = mixColumnRev (a1, b1, c1, d1)
    (a2', b2', c2', d2') = mixColumnRev (a2, b2, c2, d2)
    (a3', b3', c3', d3') = mixColumnRev (a3, b3, c3, d3)

-- | AddRoundKey (FIPS-197, Section 5.1.4): XOR the round key (first
-- argument) into the state (second argument), byte by byte. Being its own
-- inverse, it also serves the inverse cipher.
addRoundKey :: State -> State -> State
addRoundKey (State k0 k1 k2 k3) (State s0 s1 s2 s3) =
    State (addRow s0 k0) (addRow s1 k1) (addRow s2 k2) (addRow s3 k3)
  where
    addRow (a, b, c, d) (ka, kb, kc, kd) =
      (a `xor` ka, b `xor` kb, c `xor` kc, d `xor` kd)

-- | Turn a key schedule into 'addRoundKey' transformations, one per group of
-- four words (one round key).
--
-- Word j of a group becomes column j of the key state, with its most
-- significant byte in row 0. The list length must be a multiple of four.
genAddRoundKey :: [Word32] -> [State -> State]
genAddRoundKey ws =
    case ws of
      [] -> []
      (w0 : w1 : w2 : w3 : more) ->
        addRoundKey (roundKey w0 w1 w2 w3) : genAddRoundKey more
  where
    -- Build the key state from four schedule words, one word per column.
    roundKey a b c d = State (row 24) (row 16) (row 8) (row 0)
      where
        row sh = (byteAt sh a, byteAt sh b, byteAt sh c, byteAt sh d)
    -- The byte whose least significant bit sits at position sh.
    byteAt sh w = (w `shiftR` sh) .&. 0xff

-- | Read one 32-bit word from the front of an octet list, first octet in the
-- most significant byte, and return it together with the remaining octets.
-- At least four octets are required.
getWord32 :: [Octet] -> (Word32, [Octet])
getWord32 (b3 : b2 : b1 : b0 : rest) = (word, rest)
  where
    word = widen b3 `shiftL` 24
       .|. widen b2 `shiftL` 16
       .|. widen b1 `shiftL` 8
       .|. widen b0
    widen o = fromIntegral o :: Word32

-- | Split a 32-bit word into its four octets, most significant first.
--
-- Each conversion is written out explicitly. A @map fromIntegral@ over the
-- four bytes was found to be a bit slower.
putWord32 :: Word32 -> [Octet]
putWord32 w =
  [ fromIntegral (w `shiftR` 24)
  , fromIntegral ((w `shiftR` 16) .&. 0xff)
  , fromIntegral ((w `shiftR` 8) .&. 0xff)
  , fromIntegral (w .&. 0xff)
  ]

{-
testGenerateKeys128 :: Test
testGenerateKeys128 = 
	let key = [0x2b, 0x7e, 0x15, 0x16,
		   0x28, 0xae, 0xd2, 0xa6,
		   0xab, 0xf7, 0x15, 0x88,
		   0x09, 0xcf, 0x4f, 0x3c]
	    expected = [0x2b7e1516, 0x28aed2a6, 0xabf71588, 0x09cf4f3c,
	    		0xa0fafe17, 0x88542cb1, 0x23a33939, 0x2a6c7605,
			0xf2c295f2, 0x7a96b943, 0x5935807a, 0x7359f67f,
			0x3d80477d, 0x4716fe3e, 0x1e237e44, 0x6d7a883b,
			0xef44a541, 0xa8525b7f, 0xb671253b, 0xdb0bad00,
			0xd4d1c6f8, 0x7c839d87, 0xcaf2b8bc, 0x11f915bc,
			0x6d88a37a, 0x110b3efd, 0xdbf98641, 0xca0093fd,
			0x4e54f70e, 0x5f5fc9f3, 0x84a64fb2, 0x4ea6dc4f,
			0xead27321, 0xb58dbad2, 0x312bf560, 0x7f8d292f,
			0xac7766f3, 0x19fadc21, 0x28d12941, 0x575c006e,
			0xd014f9a8, 0xc9ee2589, 0xe13f0cc8, 0xb6630ca6
			]
	in TestCase (do
		assertEqual "" expected (generateKeys 10 4 key)
	)


testGenerateKeys192 :: Test
testGenerateKeys192 = 
	let key = [0x8e, 0x73, 0xb0, 0xf7,
		   0xda, 0x0e, 0x64, 0x52,
		   0xc8, 0x10, 0xf3, 0x2b,
		   0x80, 0x90, 0x79, 0xe5,
		   0x62, 0xf8, 0xea, 0xd2,
		   0x52, 0x2c, 0x6b, 0x7b]
	    expected = [0x8e73b0f7, 0xda0e6452, 0xc810f32b, 0x809079e5,
	    		0x62f8ead2, 0x522c6b7b, 0xfe0c91f7, 0x2402f5a5,
			0xec12068e, 0x6c827f6b, 0x0e7a95b9, 0x5c56fec2,
			0x4db7b4bd, 0x69b54118, 0x85a74796, 0xe92538fd,
			0xe75fad44, 0xbb095386, 0x485af057, 0x21efb14f
			]
	in TestCase (do
		assertEqual "" expected
			(take (length expected) (generateKeys 12 6 key))
	)

testGenerateKeys256 :: Test
testGenerateKeys256 = 
	let key = [0x60, 0x3d, 0xeb, 0x10,
		   0x15, 0xca, 0x71, 0xbe,
		   0x2b, 0x73, 0xae, 0xf0,
		   0x85, 0x7d, 0x77, 0x81,
		   0x1f, 0x35, 0x2c, 0x07,
		   0x3b, 0x61, 0x08, 0xd7,
		   0x2d, 0x98, 0x10, 0xa3,
		   0x09, 0x14, 0xdf, 0xf4]
	    expected = [0x603deb10, 0x15ca71be, 0x2b73aef0, 0x857d7781,
	    		0x1f352c07, 0x3b6108d7, 0x2d9810a3, 0x0914dff4,
			0x9ba35411, 0x8e6925af, 0xa51a8b5f, 0x2067fcde,
			0xa8b09c1a, 0x93d194cd, 0xbe49846e, 0xb75d5b9a,
			0xd59aecb8, 0x5bf3c917, 0xfee94248, 0xde8ebe96
			]
	in TestCase (do
		assertEqual "" expected 
			(take (length expected) (generateKeys 14 8 key))
	)

testAes128 :: Test
testAes128 = 
	let key = [0x2b, 0x7e, 0x15, 0x16,
		   0x28, 0xae, 0xd2, 0xa6,
		   0xab, 0xf7, 0x15, 0x88,
		   0x09, 0xcf, 0x4f, 0x3c]
	    input = [0x32, 0x43, 0xf6, 0xa8, 
	    	     0x88, 0x5a, 0x30, 0x8d,
		     0x31, 0x31, 0x98, 0xa2,
		     0xe0, 0x37, 0x07, 0x34]
	    output = [0x39, 0x25, 0x84, 0x1d,
	    	      0x02, 0xdc, 0x09, 0xfb,
		      0xdc, 0x11, 0x85, 0x97,
		      0x19, 0x6a, 0x0b, 0x32]
	in TestCase (do
		assertEqual "encrypt test" output (aes128Encrypt key input)
		assertEqual "encrypt/decrypt test" input 
				(aes128Decrypt key (aes128Encrypt key input))
	)

testAesRandom :: Test
testAesRandom = 
	TestCase (do
		key128 <- getRandomOctets 16 
		key192 <- getRandomOctets 24 
		key256 <- getRandomOctets 32 
		msg <- getRandomOctets 16 
		assertEqual "aes128" msg 
			(aes128Decrypt key128 (aes128Encrypt key128 msg))
		assertEqual "aes192" msg 
			(aes192Decrypt key192 (aes192Encrypt key192 msg))
		assertEqual "aes256" msg 
			(aes256Decrypt key256 (aes256Encrypt key256 msg))
	)

-- | HUnit tests
tests :: Test
tests = TestList [
		TestLabel "testGenerateKeys128" testGenerateKeys128,
		TestLabel "testGenerateKeys192" testGenerateKeys192,
		TestLabel "testGenerateKeys256" testGenerateKeys256,
		TestLabel "testAes128" testAes128,
		TestLabel "testAesRandom" testAesRandom
	]
-}