{-# LANGUAGE BangPatterns, MagicHash #-}
-- |
-- Module      : Crypto.Cipher.AES
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good

module Crypto.Cipher.AES
	( Key
	, IV
	-- * Basic encryption and decryption
	, encrypt
	, decrypt
	-- * CBC encryption and decryption
	, encryptCBC
	, decryptCBC
	-- * key building mechanism
	, initKey128
	, initKey192
	, initKey256
	-- * Wrappers for "crypto-api" instances
	, AES128
	, AES192
	, AES256
	) where

import Data.Word
import Data.Vector.Unboxed (Vector)
import qualified Data.Vector.Unboxed as V
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Internal as B

import GHC.Prim (indexWord8OffAddr#, indexWord32OffAddr#, word2Int#, Addr#, remInt#)
import GHC.Word
import GHC.Types

import Foreign.Ptr
import Foreign.Storable

import Data.Tagged (Tagged(..))
import Crypto.Classes (BlockCipher(..))
import Data.Serialize (Serialize(..), getByteString, putByteString)

import Control.Monad (forM_)
import Control.Monad.Primitive

import Data.Primitive.ByteArray

import System.Endian (littleEndian)

-- | Wrapper around an expanded 'Key' used for the 128-bit "crypto-api" instances.
newtype AES128 = A128 { unA128 :: Key }
-- | Wrapper around an expanded 'Key' used for the 192-bit "crypto-api" instances.
newtype AES192 = A192 { unA192 :: Key }
-- | Wrapper around an expanded 'Key' used for the 256-bit "crypto-api" instances.
newtype AES256 = A256 { unA256 :: Key }

-- | Block-cipher interface for 'AES128': block size and key length are both
-- tagged as 128, and 'buildKey' succeeds only when 'initKey128' accepts the
-- raw bytes.
instance BlockCipher AES128 where
	blockSize             = Tagged 128
	keyLength             = Tagged 128
	encryptBlock (A128 k) = encrypt k
	decryptBlock (A128 k) = decrypt k
	buildKey              = either (const Nothing) (Just . A128) . initKey128

-- | Block-cipher interface for 'AES192': block size is tagged as 128 and key
-- length as 192; 'buildKey' succeeds only when 'initKey192' accepts the raw
-- bytes.
instance BlockCipher AES192 where
	blockSize             = Tagged 128
	keyLength             = Tagged 192
	encryptBlock (A192 k) = encrypt k
	decryptBlock (A192 k) = decrypt k
	buildKey              = either (const Nothing) (Just . A192) . initKey192

-- | Block-cipher interface for 'AES256': block size is tagged as 128 and key
-- length as 256; 'buildKey' succeeds only when 'initKey256' accepts the raw
-- bytes.
instance BlockCipher AES256 where
	blockSize             = Tagged 128
	keyLength             = Tagged 256
	encryptBlock (A256 k) = encrypt k
	decryptBlock (A256 k) = decrypt k
	buildKey              = either (const Nothing) (Just . A256) . initKey256

-- | Recover the raw key bytes from an expanded key.
--
-- The raw key is the leading prefix of the expanded vector: 16 bytes when the
-- vector holds 176 bytes, 24 bytes when it holds 208, and 32 bytes otherwise.
serializeKey :: Key -> ByteString
serializeKey (Key v) = B.pack [ V.unsafeIndex v i | i <- [0 .. rawLen - 1] ]
	where
		rawLen = case V.length v of
			176 -> 16
			208 -> 24
			_   -> 32

-- | Only the raw key bytes are written (see 'serializeKey'); reading consumes
-- 16 bytes and rebuilds the key through 'buildKey', failing the parser when
-- that returns 'Nothing'.
instance Serialize AES128 where
	put (A128 k) = putByteString (serializeKey k)
	get = getByteString rawLen >>= maybe (fail "Invalid raw key material.") return . buildKey
		where rawLen = 128 `div` 8

-- | Only the raw key bytes are written (see 'serializeKey'); reading consumes
-- 24 bytes and rebuilds the key through 'buildKey', failing the parser when
-- that returns 'Nothing'.
instance Serialize AES192 where
	put (A192 k) = putByteString (serializeKey k)
	get = getByteString rawLen >>= maybe (fail "Invalid raw key material.") return . buildKey
		where rawLen = 192 `div` 8

-- | Only the raw key bytes are written (see 'serializeKey'); reading consumes
-- 32 bytes and rebuilds the key through 'buildKey', failing the parser when
-- that returns 'Nothing'.
instance Serialize AES256 where
	put (A256 k) = putByteString (serializeKey k)
	get = getByteString rawLen >>= maybe (fail "Invalid raw key material.") return . buildKey
		where rawLen = 256 `div` 8

-- | An expanded AES key: the byte vector produced by 'coreExpandKey'.
-- Its length is what 'getNbr' and 'serializeKey' use to tell key sizes apart.
data Key = Key (Vector Word8)
	deriving (Show,Eq)

-- | Initialisation vector for the CBC functions, which require it to be 16 bytes long.
type IV = B.ByteString

-- | Mutable byte array holding the cipher state; allocated by 'newAESState'.
type AESState = MutableByteArray RealWorld

{- | Encrypt using CBC mode.
 -
 - The IV must be 16 bytes and the data to encrypt a multiple of 16 bytes;
 - otherwise 'error' is called with "invalid IV length" or
 - "invalid data length". Empty data encrypts to the empty string.
 -
 - Each block is XORed with the previous ciphertext block (the IV for the
 - first block) before being encrypted. -}
encryptCBC :: Key -> IV -> B.ByteString -> B.ByteString
encryptCBC key iv b
	| B.length iv /= 16        = error "invalid IV length"
	| B.length b `mod` 16 /= 0 = error "invalid data length"
	-- Empty input must not reach coreEncrypt, which always reads and writes
	-- 16 bytes regardless of the input length.
	| B.null b                 = B.empty
	| otherwise                = B.concat $ encryptIter iv (makeChunks b)
		where
			encryptIter _   []     = []
			encryptIter iv' (x:xs) =
				-- the ciphertext block just produced chains into the next block
				let r = coreEncrypt key $ B.pack $ B.zipWith xor iv' x in
				r : encryptIter r xs

{- | Encrypt using simple ECB mode: every 16-byte block is encrypted
 - independently with the same key.
 -
 - The data must be a multiple of 16 bytes, otherwise 'error' is called with
 - "invalid data length". Empty data encrypts to the empty string. -}
encrypt :: Key -> B.ByteString -> B.ByteString
encrypt key b
	| B.length b `mod` 16 /= 0 = error "invalid data length"
	-- Empty input must not reach coreEncrypt, which always reads and writes
	-- 16 bytes regardless of the input length.
	| B.null b                 = B.empty
	| otherwise                = B.concat $ doChunks (coreEncrypt key) b

{- | Decrypt using CBC mode.
 -
 - The IV must be 16 bytes and the data to decrypt a multiple of 16 bytes;
 - otherwise 'error' is called with "invalid IV length" or
 - "invalid data length". Empty data decrypts to the empty string.
 -
 - Each block is decrypted and then XORed with the previous ciphertext block
 - (the IV for the first block). -}
decryptCBC :: Key -> IV -> B.ByteString -> B.ByteString
decryptCBC key iv b
	| B.length iv /= 16        = error "invalid IV length"
	| B.length b `mod` 16 /= 0 = error "invalid data length"
	-- Empty input must not reach coreDecrypt, which always reads and writes
	-- 16 bytes regardless of the input length.
	| B.null b                 = B.empty
	| otherwise                = B.concat $ decryptIter iv (makeChunks b)
		where
			decryptIter _   []     = []
			decryptIter iv' (x:xs) =
				let r = B.pack $ B.zipWith xor iv' $ coreDecrypt key x in
				-- the ciphertext block x, not the plaintext r, chains forward
				r : decryptIter x xs

{- | Decrypt using simple ECB mode: every 16-byte block is decrypted
 - independently with the same key.
 -
 - The data must be a multiple of 16 bytes, otherwise 'error' is called with
 - "invalid data length". Empty data decrypts to the empty string. -}
decrypt :: Key -> B.ByteString -> B.ByteString
decrypt key b
	| B.length b `mod` 16 /= 0 = error "invalid data length"
	-- Empty input must not reach coreDecrypt, which always reads and writes
	-- 16 bytes regardless of the input length.
	| B.null b                 = B.empty
	| otherwise                = B.concat $ doChunks (coreDecrypt key) b

-- | Split the input into consecutive 16-byte chunks and apply @f@ to each.
--
-- Empty input yields no chunks, so @f@ is never called on an empty string.
-- For non-empty input, chunking stops as soon as fewer than 16 bytes remain
-- after a chunk; those trailing bytes are dropped.
doChunks :: (B.ByteString -> B.ByteString) -> B.ByteString -> [B.ByteString]
doChunks f b
	| B.null b  = []
	| otherwise =
		let (x, rest) = B.splitAt 16 b in
		if B.length rest >= 16
			then f x : doChunks f rest
			else [ f x ]

-- | Split the input into 16-byte chunks without transforming them
-- ('doChunks' with the identity function).
makeChunks :: B.ByteString -> [B.ByteString]
makeChunks bytes = doChunks id bytes

-- | Allocate a fresh pinned state buffer of 16 bytes with 16-byte alignment.
newAESState :: IO AESState
newAESState = newAlignedPinnedByteArray stateBytes stateAlign
	where
		stateBytes = 16
		stateAlign = 16

-- | Encrypt one block.
--
-- The state is filled from bytes 0..15 of the input and 16 bytes are written
-- to the output, whose buffer is allocated with the input's length; callers
-- are therefore expected to pass exactly 16 bytes.
coreEncrypt :: Key -> ByteString -> ByteString
coreEncrypt key input = B.unsafeCreate outLen fill
	where
		outLen   = B.length input
		fill out = do
			state <- newAESState
			swapBlock input state
			aesMain (getNbr key) key state
			swapBlockInv state out

-- | Decrypt one block.
--
-- The state is filled from bytes 0..15 of the input and 16 bytes are written
-- to the output, whose buffer is allocated with the input's length; callers
-- are therefore expected to pass exactly 16 bytes.
coreDecrypt :: Key -> ByteString -> ByteString
coreDecrypt key input = B.unsafeCreate outLen fill
	where
		outLen   = B.length input
		fill out = do
			state <- newAESState
			swapBlock input state
			aesMainInv (getNbr key) key state
			swapBlockInv state out

-- | Number of rounds for a key, derived from the expanded key's length:
-- 10 for 176 bytes, 12 for 208 bytes, 14 for anything else.
getNbr :: Key -> Int
getNbr (Key v) = case V.length v of
	176 -> 10
	208 -> 12
	_   -> 14

-- | Build an expanded key from raw key bytes of exactly 16, 24 or 32 bytes
-- respectively; any other length gives @Left "wrong key size"@ (see 'initKey').
initKey128, initKey192, initKey256 :: ByteString -> Either String Key

initKey128 raw = initKey 16 raw
initKey192 raw = initKey 24 raw
initKey256 raw = initKey 32 raw

-- | Expand raw key material of the expected size @sz@.
--
-- Returns @Left "wrong key size"@ when the input is not exactly @sz@ bytes
-- long; otherwise the bytes are copied into a vector and handed to
-- 'coreExpandKey'.
initKey :: Int -> ByteString -> Either String Key
initKey sz b =
	if B.length b /= sz
		then Left "wrong key size"
		else Right (coreExpandKey (V.generate sz (B.unsafeIndex b)))

-- | Run the forward cipher in place on the state.
--
-- Round key 0 is added first; rounds @1 .. nbr-1@ each apply 'shiftRows',
-- 'mixColumns' and the round key; the final round @nbr@ omits 'mixColumns'.
aesMain :: Int -> Key -> AESState -> IO ()
aesMain nbr key blk = do
	addRoundKey key 0 blk
	mapM_ middleRound [1 .. nbr - 1]
	shiftRows blk
	addRoundKey key nbr blk
	where
		middleRound i = do
			shiftRows blk
			mixColumns blk
			addRoundKey key i blk

-- | Run the inverse cipher in place on the state.
--
-- Round key @nbr@ is added first; rounds @nbr-1@ down to @1@ each apply
-- 'shiftRowsInv', the round key and 'mixColumnsInv'; the last step applies
-- 'shiftRowsInv' and round key 0.
aesMainInv :: Int -> Key -> AESState -> IO ()
aesMainInv nbr key blk = do
	addRoundKey key nbr blk
	mapM_ middleRound [nbr - 1, nbr - 2 .. 1]
	shiftRowsInv blk
	addRoundKey key 0 blk
	where
		middleRound i = do
			shiftRowsInv blk
			addRoundKey key i blk
			mixColumnsInv blk

{-# INLINE swapIndex #-}
-- | Transpose an index into a 4x4 byte matrix: position @4*r + c@ maps to
-- @4*c + r@ for indices 0..15. Any index outside that range maps to 0.
swapIndex :: Int -> Int
swapIndex n
	| n >= 0 && n < 16 = (n `rem` 4) * 4 + n `quot` 4
	| otherwise        = 0

-- | Expand a raw key of 16, 24 or 32 bytes into the full expanded key.
--
-- The result is the raw key followed by the generated chunks:
--
-- * 16-byte key: ten 16-byte chunks (176 bytes in total);
-- * 24-byte key: seven 24-byte chunks plus one final 16-byte chunk (208 bytes);
-- * 32-byte key: six 32-byte chunks plus one final 16-byte chunk (240 bytes).
--
-- Any other input length yields a 'Key' wrapping the empty vector.
coreExpandKey :: Vector Word8 -> Key
coreExpandKey vkey
	| V.length vkey == 16 = Key (V.concat (ek0 : ekN16))
	| V.length vkey == 24 = Key (V.concat (ek0 : ekN24))
	| V.length vkey == 32 = Key (V.concat (ek0 : ekN32))
	| otherwise           = Key (V.empty)
	where
		-- the raw key is the first chunk of the expanded key
		ek0 = vkey
		ekN16 = reverse $ snd $ foldl (generateFold generate16) (ek0, []) [1..10]

		-- the last chunk is only 16 bytes, so it is produced by generate16
		ekN24 =
			let (lk, acc) = foldl (generateFold generate24) (ek0, []) [1..7] in
			let nk = generate16 lk 8 in
			reverse (nk : acc)

		-- the last chunk is only 16 bytes, so it is produced by generate16
		ekN32 =
			let (lk, acc) = foldl (generateFold generate32) (ek0, []) [1..6] in
			let nk = generate16 lk 7 in
			reverse (nk : acc)

		-- fold step: derive the next chunk from the previous one and
		-- accumulate the chunks in reverse order
		generateFold gen (prevk, accK) it = let nk = gen prevk it in (nk, nk : accK)

		-- next 16 bytes: the first word is seeded from the last 4 bytes of the
		-- previous chunk, each following word chains on the word before it
		generate16 prevk it =
			let len = V.length prevk in
			let v0 = cR0 it (V.unsafeIndex prevk $ len - 4)
			                (V.unsafeIndex prevk $ len - 3)
			                (V.unsafeIndex prevk $ len - 2)
			                (V.unsafeIndex prevk $ len - 1) in
			let eg0@(e0,e1,e2,e3)     = xorVector prevk 0 v0   in
			let eg1@(e4,e5,e6,e7)     = xorVector prevk 4 eg0  in
			let eg2@(e8,e9,e10,e11)   = xorVector prevk 8 eg1  in
			let     (e12,e13,e14,e15) = xorVector prevk 12 eg2 in
			V.fromList [e0,e1,e2,e3,e4,e5,e6,e7,e8,e9,e10,e11,e12,e13,e14,e15]

		-- next 24 bytes, same chaining as generate16 over six words
		generate24 prevk it =
			let len = V.length prevk in
			let v0 = cR0 it (V.unsafeIndex prevk $ len - 4)
			                (V.unsafeIndex prevk $ len - 3)
			                (V.unsafeIndex prevk $ len - 2)
			                (V.unsafeIndex prevk $ len - 1) in
			let eg0@(e0,e1,e2,e3)     = xorVector prevk 0 v0   in
			let eg1@(e4,e5,e6,e7)     = xorVector prevk 4 eg0  in
			let eg2@(e8,e9,e10,e11)   = xorVector prevk 8 eg1  in
			let eg3@(e12,e13,e14,e15) = xorVector prevk 12 eg2 in
			let eg4@(e16,e17,e18,e19) = xorVector prevk 16 eg3 in
			let     (e20,e21,e22,e23) = xorVector prevk 20 eg4 in
			V.fromList [e0,e1,e2,e3,e4,e5,e6,e7,e8,e9,e10,e11,e12,e13,e14,e15,e16,e17,e18,e19,e20,e21,e22,e23]

		-- next 32 bytes; the fifth word additionally passes the chained word
		-- through the S-box (xorSboxVector)
		generate32 prevk it =
			let len = V.length prevk in
			let v0 = cR0 it (V.unsafeIndex prevk $ len - 4)
			                (V.unsafeIndex prevk $ len - 3)
			                (V.unsafeIndex prevk $ len - 2)
			                (V.unsafeIndex prevk $ len - 1) in
			let eg0@(e0,e1,e2,e3)     = xorVector prevk 0 v0   in
			let eg1@(e4,e5,e6,e7)     = xorVector prevk 4 eg0  in
			let eg2@(e8,e9,e10,e11)   = xorVector prevk 8 eg1  in
			let eg3@(e12,e13,e14,e15) = xorVector prevk 12 eg2 in
			let eg4@(e16,e17,e18,e19) = xorSboxVector prevk 16 eg3 in
			let eg5@(e20,e21,e22,e23) = xorVector prevk 20 eg4 in
			let eg6@(e24,e25,e26,e27) = xorVector prevk 24 eg5 in
			let     (e28,e29,e30,e31) = xorVector prevk 28 eg6 in
			V.fromList [e0,e1,e2,e3,e4,e5,e6,e7,e8,e9,e10,e11,e12,e13,e14,e15,e16,e17,e18,e19,e20,e21,e22,e23,e24,e25,e26,e27,e28,e29,e30,e31]

		-- XOR the 4 bytes of k starting at index i with the given 4-tuple
		xorVector k i (t0,t1,t2,t3) =
			( V.unsafeIndex k (i+0) `xor` t0
			, V.unsafeIndex k (i+1) `xor` t1
			, V.unsafeIndex k (i+2) `xor` t2
			, V.unsafeIndex k (i+3) `xor` t3
			)

		-- like xorVector, but each tuple byte goes through 'sbox' first
		xorSboxVector k i (t0,t1,t2,t3) =
			( V.unsafeIndex k (i+0) `xor` sbox t0
			, V.unsafeIndex k (i+1) `xor` sbox t1
			, V.unsafeIndex k (i+2) `xor` sbox t2
			, V.unsafeIndex k (i+3) `xor` sbox t3
			)

		-- rotate the word by one byte, apply 'sbox' to every byte and XOR the
		-- round constant for iteration `it` into the first byte
		cR0 it r0 r1 r2 r3 =
			(sbox r1 `xor` rcon it, sbox r2, sbox r3, sbox r0)

{-# INLINE shiftRows #-}
-- | Combined byte substitution and row shift, in place.
--
-- Each of the four 32-bit rows is passed through 'msbox32'; rows 1, 2 and 3
-- are then rotated right by 8, 16 and 24 bits respectively.
shiftRows :: AESState -> IO ()
shiftRows blk = do
	row0 <- r32 blk 0
	w32 blk 0 (msbox32 row0)
	row1 <- r32 blk 1
	w32 blk 1 (rotateR (msbox32 row1) 8)
	row2 <- r32 blk 2
	w32 blk 2 (rotateR (msbox32 row2) 16)
	row3 <- r32 blk 3
	w32 blk 3 (rotateR (msbox32 row3) 24)

{-# INLINE addRoundKey #-}
-- | XOR round key @i@ into the state, in place.
--
-- Round key @i@ occupies bytes @16*i .. 16*i+15@ of the expanded key; it is
-- read through 'swapIndex' so that it lines up with the transposed state.
addRoundKey :: Key -> Int -> AESState -> IO ()
addRoundKey (Key key) i blk = mapM_ xorByte [0..15]
	where
		base = 16 * i
		xorByte n = do
			cur <- r8 blk n
			w8 blk n (cur `xor` V.unsafeIndex key (base + swapIndex n))

{-# INLINE mixColumns #-}
-- | Mix each of the four columns of the state, in place.
--
-- Column @i@ consists of state bytes @i@, @4+i@, @8+i@ and @12+i@. Every
-- output byte is the XOR of the four column bytes, each first passed through
-- one of 'gm2', 'gm3' or the identity 'gm1'.
mixColumns :: AESState -> IO ()
mixColumns state = pr 0 >> pr 1 >> pr 2 >> pr 3
	where
		{-# INLINE pr #-}
		pr i = do
			-- read the whole column before writing anything back
			cpy0 <- r8 state (0 * 4 + i)
			cpy1 <- r8 state (1 * 4 + i)
			cpy2 <- r8 state (2 * 4 + i)
			cpy3 <- r8 state (3 * 4 + i)

			let state0 = gm2 cpy0 `xor` gm1 cpy3 `xor` gm1 cpy2 `xor` gm3 cpy1
			let state4 = gm2 cpy1 `xor` gm1 cpy0 `xor` gm1 cpy3 `xor` gm3 cpy2
			let state8 = gm2 cpy2 `xor` gm1 cpy1 `xor` gm1 cpy0 `xor` gm3 cpy3
			let state12 = gm2 cpy3 `xor` gm1 cpy2 `xor` gm1 cpy1 `xor` gm3 cpy0

			w8 state (0 * 4 + i) state0
			w8 state (1 * 4 + i) state4
			w8 state (2 * 4 + i) state8
			w8 state (3 * 4 + i) state12
		{-# INLINE gm1 #-}
		-- identity; named only to keep the four terms visually uniform
		gm1 a = a

{-# INLINE shiftRowsInv #-}
-- | Combined inverse row shift and inverse byte substitution, in place.
--
-- Rows 1, 2 and 3 are rotated left by 8, 16 and 24 bits respectively, and
-- every row is then passed through 'mrsbox32'.
shiftRowsInv :: AESState -> IO ()
shiftRowsInv blk = do
	row0 <- r32 blk 0
	w32 blk 0 (mrsbox32 row0)
	row1 <- r32 blk 1
	w32 blk 1 (mrsbox32 (rotateL row1 8))
	row2 <- r32 blk 2
	w32 blk 2 (mrsbox32 (rotateL row2 16))
	row3 <- r32 blk 3
	w32 blk 3 (mrsbox32 (rotateL row3 24))

{-# INLINE mixColumnsInv #-}
-- | Inverse of 'mixColumns', applied in place to each of the four columns.
--
-- Column @i@ consists of state bytes @i@, @4+i@, @8+i@ and @12+i@. Every
-- output byte is the XOR of the four column bytes, each first passed through
-- one of 'gm14', 'gm9', 'gm13' or 'gm11'.
mixColumnsInv :: AESState -> IO ()
mixColumnsInv state = pr 0 >> pr 1 >> pr 2 >> pr 3
	where
		{-# INLINE pr #-}
		pr i = do
			-- read the whole column before writing anything back
			cpy0 <- r8 state (0 * 4 + i)
			cpy1 <- r8 state (1 * 4 + i)
			cpy2 <- r8 state (2 * 4 + i)
			cpy3 <- r8 state (3 * 4 + i)

			let state0  = gm14 cpy0 `xor` gm9 cpy3 `xor` gm13 cpy2 `xor` gm11 cpy1
			let state4  = gm14 cpy1 `xor` gm9 cpy0 `xor` gm13 cpy3 `xor` gm11 cpy2
			let state8  = gm14 cpy2 `xor` gm9 cpy1 `xor` gm13 cpy0 `xor` gm11 cpy3
			let state12 = gm14 cpy3 `xor` gm9 cpy2 `xor` gm13 cpy1 `xor` gm11 cpy0

			w8 state (0 * 4 + i) state0
			w8 state (1 * 4 + i) state4
			w8 state (2 * 4 + i) state8
			w8 state (3 * 4 + i) state12

{-# INLINE r8 #-}
-- | Read the byte at the given index of the state.
r8 :: AESState -> Int -> IO Word8
r8 st ix = readByteArray st ix

{-# INLINE w8 #-}
-- | Write a byte at the given index of the state.
w8 :: AESState -> Int -> Word8 -> IO ()
w8 st ix byte = writeByteArray st ix byte

{-# INLINE r32 #-}
-- | Read the state as 'Word32' elements; the index selects the element.
r32 :: AESState -> Int -> IO Word32
r32 st ix = readByteArray st ix

{-# INLINE w32 #-}
-- | Write a 'Word32' element of the state; the index selects the element.
w32 :: AESState -> Int -> Word32 -> IO ()
w32 st ix word = writeByteArray st ix word

-- | Substitute all four bytes of a 32-bit word at once.
--
-- Each byte of @w@ is looked up in its own table ('sbox4' for bits 24..31
-- down to 'sbox1' for bits 0..7) and the four 'Word32' results are OR-ed
-- together.
msbox32 :: Word32 -> Word32
msbox32 w = sbox4 a .|. sbox3 b .|. sbox2 c .|. sbox1 d
	where
		-- fromIntegral to Word8 keeps only the low 8 bits after the shift
		a = fromIntegral (w `shiftR` 24)
		b = fromIntegral (w `shiftR` 16)
		c = fromIntegral (w `shiftR` 8)
		d = fromIntegral w

-- | Apply 'rsbox' to each of the four bytes of a 32-bit word, keeping every
-- byte in its original position.
mrsbox32 :: Word32 -> Word32
mrsbox32 w = fromIntegral (rsbox a) `shiftL` 24 .|.
             fromIntegral (rsbox b) `shiftL` 16 .|.
             fromIntegral (rsbox c) `shiftL` 8  .|.
             fromIntegral (rsbox d)
	where
		-- fromIntegral to Word8 keeps only the low 8 bits after the shift
		a = fromIntegral (w `shiftR` 24)
		b = fromIntegral (w `shiftR` 16)
		c = fromIntegral (w `shiftR` 8)
		d = fromIntegral w

{-# INLINE swapBlock #-}
-- | Load the first 16 bytes of the input into the state, transposed:
-- state byte @i@ receives input byte @swapIndex i@.
swapBlock :: ByteString -> AESState -> IO ()
swapBlock b blk = mapM_ load [0..15]
	where
		load i = w8 blk i (B.unsafeIndex b (swapIndex i))

{-# INLINE swapBlockInv #-}
-- | Store the 16 state bytes to the output pointer, transposed back:
-- output byte @i@ receives state byte @swapIndex i@.
swapBlockInv :: AESState -> Ptr Word8 -> IO ()
swapBlockInv blk ptr = mapM_ copyOut [0..15]
	where
		copyOut i = do
			byte <- r8 blk (swapIndex i)
			pokeByteOff ptr i byte

{-# INLINE sbox #-}
-- | Byte substitution: returns the byte stored at offset @w@ of @table@.
-- NOTE(review): presumably the AES forward S-box; the table literal is not
-- visible here, so its contents could not be checked.
sbox :: Word8 -> Word8
sbox (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =

{-# INLINE sbox1 #-}
{-# INLINE sbox2 #-}
{-# INLINE sbox3 #-}
{-# INLINE sbox4 #-}
-- | Per-position substitution lookups used by 'msbox32': each returns the
-- 'Word32' element at index @w@ of its table ('sbox1Tab' .. 'sbox4Tab').
sbox1, sbox2, sbox3, sbox4 :: Word8 -> Word32
sbox1 (W8# w) = case sbox1Tab of
	Table table -> W32# (indexWord32OffAddr# table (word2Int# w))
sbox2 (W8# w) = case sbox2Tab of
	Table table -> W32# (indexWord32OffAddr# table (word2Int# w))
sbox3 (W8# w) = case sbox3Tab of
	Table table -> W32# (indexWord32OffAddr# table (word2Int# w))
sbox4 (W8# w) = case sbox4Tab of
	Table table -> W32# (indexWord32OffAddr# table (word2Int# w))

-- | Tables behind 'sbox1' .. 'sbox4'. Which underlying table each name refers
-- to depends on the host byte order: the assignment is mirrored between
-- little-endian and big-endian hosts.
sbox1Tab, sbox2Tab, sbox3Tab, sbox4Tab :: Table
sbox1Tab
	| littleEndian = sbox_x000
	| otherwise    = sbox_000x
sbox2Tab
	| littleEndian = sbox_0x00
	| otherwise    = sbox_00x0
sbox3Tab
	| littleEndian = sbox_00x0
	| otherwise    = sbox_0x00
sbox4Tab
	| littleEndian = sbox_000x
	| otherwise    = sbox_x000

-- | Boxed wrapper around the raw address of a lookup table, so that the
-- address can be bound at top level and selected at run time.
data Table = Table !Addr#

-- | The four tables that 'sbox1Tab' .. 'sbox4Tab' choose between.
-- NOTE(review): the names suggest each table holds the substituted byte in a
-- different byte position of a 32-bit word; the table literals are not
-- visible here, so this could not be confirmed.
sbox_000x, sbox_00x0, sbox_0x00, sbox_x000 :: Table
sbox_000x = Table

sbox_00x0 = Table

sbox_0x00 = Table

sbox_x000 = Table

{-# INLINE rsbox #-}
-- | Inverse byte substitution as used by 'mrsbox32': returns the byte stored
-- at offset @w@ of @table@.
-- NOTE(review): presumably the AES inverse S-box; the table literal is not
-- visible here, so its contents could not be checked.
rsbox :: Word8 -> Word8
rsbox (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =

{-# INLINE rcon #-}
-- | Round constant for the given key-expansion iteration: the byte stored at
-- offset @i `rem` 51@ of @table@.
-- NOTE(review): the modulus implies a 51-entry table; the table literal is
-- not visible here, so this could not be confirmed.
rcon :: Int -> Word8
rcon (I# i) = W8# (indexWord8OffAddr# table (i `remInt#` 51#))
	where !table =

{-# INLINE gm2 #-}
{-# INLINE gm3 #-}
{-# INLINE gm9 #-}
{-# INLINE gm11 #-}
{-# INLINE gm13 #-}
{-# INLINE gm14 #-}
-- | Byte-to-byte table lookups used by 'mixColumns' ('gm2', 'gm3') and
-- 'mixColumnsInv' ('gm9', 'gm11', 'gm13', 'gm14'): each returns the byte at
-- offset @w@ of its own @table@.
-- NOTE(review): presumably multiplication by 2, 3, 9, 11, 13 and 14 in the
-- AES field; the table literals are not visible here, so this could not be
-- confirmed.
gm2, gm3, gm9, gm11, gm13, gm14 :: Word8 -> Word8
gm2 (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =

gm3 (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =

gm9 (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =

gm11 (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =

gm13 (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =

gm14 (W8# w) = W8# (indexWord8OffAddr# table (word2Int# w))
	where !table =