-- CSE (Cryptographic Service Engine) emulation implementation
module Codec.Automotive.CSE.Internal (
  M1, unM1, makeM1, extractM1, refineM1,
  M2, unM2, makeM2, extractM2, refineM2,
  M3, unM3, makeM3,
  M4, unM4, makeM4, extractM4, makeM4', extractM4',
  M5, unM5, makeM5,

  K1, K1', makeK1,
  K2, K2', makeK2,
  K3, K3', makeK3,
  K4, K4', makeK4,

  UID, unUID, makeUID,

  Derived, unDerived,
  kdf, keyUpdateEncC, keyUpdateMacC,

  DerivedCipher, derivedCipher,

  KeyAuthUse, Auth, NotAuth,
  makeKeyAuthUse, unKeyAuthUse,

  UpdateC, Enc, Mac,

  unsafeMakeDerived, unsafeMakeDerivedCipher,
  ) where

import Control.Monad (MonadPlus, guard, liftM)
import Data.Monoid ((<>), mconcat, Endo (..))
import Data.Bits (shiftL, shiftR, (.&.), (.|.))
import Data.Word (Word8, Word32)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as Short
import Data.Serialize.Get (runGet, getWord64be)
import Data.Serialize.Put (runPut, putWord64be)
import Numeric (showHex)

import qualified Data.ByteArray as B
import Crypto.Cipher.Types (cipherInit, ecbEncrypt, cbcEncrypt, nullIV, ecbDecrypt, cbcDecrypt)
import Crypto.Cipher.AES (AES128)
import Crypto.Error (eitherCryptoError)

import Backport.Crypto.MAC.CMAC (CMAC(..), cmac)
import Backport.Crypto.ConstructHash.MiyaguchiPreneel (MiyaguchiPreneel(..), mp)


dump8 :: [Word8] -> String
dump8 = concatMap byteHex
  where
    -- Render one octet as exactly two lowercase hex digits (zero-padded).
    byteHex b = (if b < 16 then ('0' :) else id) (showHex b "")

hdump :: ByteString -> String
hdump bytes = dump8 (BS.unpack bytes)

hdump' :: ShortByteString -> String
hdump' shortBytes = dump8 (Short.unpack shortBytes)


-- | Phantom tag for values tied to the encryption key-update constant
-- ('keyUpdateEncC' has type @UpdateC Enc@).
data Enc
-- | Phantom tag for values tied to the MAC key-update constant
-- ('keyUpdateMacC' has type @UpdateC Mac@).
data Mac

-- | Key-update constant fed to 'kdf', tagged by the phantom @c@
-- ('Enc' or 'Mac' for the two constants defined in this module).
newtype UpdateC c = UpdateC ByteString
  deriving (Eq)

instance Show (UpdateC c) where
  show (UpdateC constant) = "UpdateC " ++ hdump constant

-- | Encryption key-update constant (16 octets, big-endian):
-- 0x010153484500800000000000000000B0
keyUpdateEncC :: UpdateC Enc
keyUpdateEncC =
  UpdateC (runPut (mapM_ putWord64be [0x0101534845008000, 0x00000000000000B0]))

-- | MAC key-update constant (16 octets, big-endian):
-- 0x010253484500800000000000000000B0
keyUpdateMacC :: UpdateC Mac
keyUpdateMacC = UpdateC . runPut $ do
  putWord64be 0x0102534845008000
  putWord64be 0x00000000000000B0

-- | Phantom tag for the authorizing key: 'makeK1' and 'makeK2' take a
-- @KeyAuthUse Auth@.
data Auth
-- | Phantom tag for plain key data: 'makeK3', 'makeK4' and 'makeM2' take a
-- @KeyAuthUse NotAuth@.
data NotAuth

-- | A 16-octet key, tagged with its role @k@ ('Auth' or 'NotAuth').
newtype KeyAuthUse k = KeyAuthUse ShortByteString
  deriving (Eq)

-- | Wrap raw key bytes; fails in @m@ unless exactly 16 octets are given.
makeKeyAuthUse :: MonadPlus m => ByteString -> m (KeyAuthUse k)
makeKeyAuthUse raw =
  guard (BS.length raw == 16) >> return (KeyAuthUse (Short.toShort raw))

-- | Recover the raw key bytes.
unKeyAuthUse :: KeyAuthUse k -> ByteString
unKeyAuthUse key = case key of
  KeyAuthUse short -> Short.fromShort short

-- | A derived key tagged with the key role @k@ and the update constant @c@.
-- 'unsafeMakeDerived' only accepts 16 octets; 'kdf' output is assumed to be
-- the same size (Miyaguchi-Preneel over AES-128).
newtype Derived k c =
  Derived ShortByteString
  deriving Eq

-- Render as the constructor name followed by the key in hex. The label is
-- "Derived"; "DerivedCipher" names the separate cipher-context type.
instance Show (Derived k c) where
  show (Derived k) = unwords ["Derived", hdump' k]

-- | Specifying derived key directly.
--
-- Builds a 'Derived' key from raw bytes without running 'kdf'. Fails in @m@
-- unless exactly 16 octets are given.
unsafeMakeDerived :: MonadPlus m => ByteString -> m (Derived k c)
unsafeMakeDerived raw =
  guard (BS.length raw == 16) >> return (Derived (Short.toShort raw))

-- | Raw bytes of a derived key.
unDerived :: Derived k c -> ByteString
unDerived derived = case derived of
  Derived key -> Short.fromShort key

-- | Key derivation: Miyaguchi-Preneel compression (AES-128 based) over the
-- key bytes followed by the update constant.
kdf :: KeyAuthUse k -> UpdateC c -> Derived k c
kdf key (UpdateC constant) = Derived (Short.toShort digest)
  where
    compressed = mp (unKeyAuthUse key <> constant) :: MiyaguchiPreneel AES128
    digest :: ByteString
    digest = B.convert (chashGetBytes compressed)

-- | 'kdf' with the encryption key-update constant.
kdfEnc :: KeyAuthUse k -> Derived k Enc
kdfEnc key = kdf key keyUpdateEncC

-- | 'kdf' with the MAC key-update constant.
kdfMac :: KeyAuthUse k -> Derived k Mac
kdfMac key = kdf key keyUpdateMacC

-- | AES-128 cipher context keyed with a 'Derived' key; it keeps the same
-- role/constant tags as the key it was built from.
newtype DerivedCipher k c = DerivedCipher AES128

-- | Initialise the AES-128 context for a derived key.
derivedCipher :: Derived k c -> DerivedCipher k c
derivedCipher k =
  case eitherCryptoError (cipherInit (unDerived k)) of
    -- Assumes the refined 16-octet length of a Derived key (the
    -- Miyaguchi-Preneel AES128 result), so initialisation is not expected
    -- to fail.
    Left err  -> error ("Codec.Automotive.CSE.derivedCipher: internal error: " ++ show err)
    Right aes -> DerivedCipher aes

-- | Build a cipher directly from raw key bytes (see 'unsafeMakeDerived').
unsafeMakeDerivedCipher :: MonadPlus m => ByteString -> m (DerivedCipher k c)
unsafeMakeDerivedCipher raw = do
  key <- unsafeMakeDerived raw
  return (derivedCipher key)

-- | K1 key bytes: the authorizing key run through 'kdf' with the encryption constant.
type K1' = Derived Auth Enc
-- | Cipher keyed with K1 (consumed by 'makeM2' / 'extractM2').
type K1  = DerivedCipher Auth Enc

makeK1 :: KeyAuthUse Auth -- ^ AuthKey Data
       -> K1'             -- ^ Result Hash value
makeK1 authKey = kdf authKey keyUpdateEncC

-- | K2 key bytes: the authorizing key run through 'kdf' with the MAC constant.
type K2' = Derived Auth Mac
-- | Cipher keyed with K2 (consumed by 'makeM3').
type K2  = DerivedCipher Auth Mac

makeK2 :: KeyAuthUse Auth -- ^ AuthKey Data
       -> K2'             -- ^ Result Hash value
makeK2 authKey = kdf authKey keyUpdateMacC

-- | K3 key bytes: the plain key data run through 'kdf' with the encryption constant.
type K3' = Derived NotAuth Enc
-- | Cipher keyed with K3 (consumed by 'makeM4'' / 'extractM4'').
type K3  = DerivedCipher NotAuth Enc

makeK3 :: KeyAuthUse NotAuth -- ^ Key Data
       -> K3'                -- ^ Result Hash value
makeK3 keyData = kdf keyData keyUpdateEncC

-- | K4 key bytes: the plain key data run through 'kdf' with the MAC constant.
type K4' = Derived NotAuth Mac
-- | Cipher keyed with K4 (consumed by 'makeM5').
type K4  = DerivedCipher NotAuth Mac

makeK4 :: KeyAuthUse NotAuth -- ^ Key Data
       -> K4'                -- ^ Result Hash value
makeK4 keyData = kdf keyData keyUpdateMacC


-- | Unique identifier; 'makeUID' only accepts 15 octets.
newtype UID = UID ShortByteString
  deriving (Eq, Ord)

instance Show UID where
  show (UID bytes) = "UID " ++ hdump' bytes

-- | Raw UID bytes.
unUID :: UID -> ByteString
unUID uid = case uid of
  UID raw -> Short.fromShort raw

-- | Wrap UID bytes; fails in @m@ unless exactly 15 octets are given.
makeUID :: MonadPlus m => ByteString -> m UID
makeUID raw =
  guard (BS.length raw == 15) >> return (UID (Short.toShort raw))

-- | Message M1: UID followed by one octet packing Key ID and Auth key ID.
newtype M1 = M1 ByteString
  deriving (Eq)

instance Show M1 where
  show (M1 bytes) = "M1 " ++ hdump bytes

-- | Raw M1 bytes.
unM1 :: M1 -> ByteString
unM1 m = case m of
  M1 bytes -> bytes

-- | Build M1 = UID || (Key ID << 4 | Auth key ID).
--
-- Each ID occupies one 4-bit nibble of the final octet. Both are masked to
-- 4 bits, so an out-of-range Auth key ID cannot bleed into the Key ID
-- nibble. This matches the masking done by 'extractM1'.
makeM1 :: UID        -- ^ UID          - 15 octet
       -> Word8      -- ^ Key ID       -  4 bit
       -> Word8      -- ^ Auth key ID  -  4 bit
       -> M1
makeM1 uid kid akid =
  M1 $ unUID uid <> BS.singleton (((kid .&. 0x0F) `shiftL` 4) .|. (akid .&. 0x0F))

-- | Split an M1 into UID, Key ID (high nibble of octet 16) and Auth key ID
-- (low nibble of octet 16).
extractM1 :: M1 -> (UID, Word8, Word8)
extractM1 (M1 m1) = (UID (Short.toShort uidBytes), idByte `shiftR` 4, idByte .&. 0x0F)
  where
    uidBytes = BS.take 15 m1
    -- assume refined M1: input shorter than 16 octets makes this 'head' fail
    idByte   = head (BS.unpack (BS.drop 15 m1))

-- | Validate raw bytes as an M1 and return it together with its decoded
-- fields. Fails in @m@ unless the input is exactly 16 octets and
-- re-encodes to itself.
refineM1 :: MonadPlus m => ByteString -> m (M1, (UID, Word8, Word8))
refineM1 s = do
  -- Check the length first: 'extractM1' reads the 16th octet with 'head'
  -- and would crash on shorter input instead of failing in @m@.
  guard $ BS.length s == 16
  let m1 = M1 s
      ps@(uid, kid, akid) = extractM1 m1
  guard $ makeM1 uid kid akid == m1
  return (m1, ps)

-- | Message M2: the CBC-encrypted counter/flags block and key data.
newtype M2 = M2 ByteString
  deriving (Eq)

instance Show M2 where
  show (M2 bytes) = "M2 " ++ hdump bytes

-- | Raw M2 bytes.
unM2 :: M2 -> ByteString
unM2 m = case m of
  M2 bytes -> bytes

-- | Build M2 = ENC_CBC(K1, IV = 0, counter || flags || 0... || key).
--
-- The plaintext is 32 octets. The first 64-bit big-endian word holds the
-- 28-bit counter in bits 63..36 and the 6-bit flags in bits 35..30, with
-- the rest zero. It is followed by a zero word and the 16-octet key data.
-- Flags are masked to 6 bits so out-of-range values cannot corrupt the
-- counter field; 'extractM2' applies the same mask. Counter bits above 28
-- are already discarded by the 64-bit shift.
makeM2 :: K1                 -- ^ K1 value
       -> Word32             -- ^ Counter   - 28 bit
       -> Word8              -- ^ Key Flag  -  6 bit
       -> KeyAuthUse NotAuth -- ^ Key Data for AES128
       -> M2
makeM2 (DerivedCipher k1) counter flags keyData =
    M2 $ cbcEncrypt k1 nullIV plain
  where
    plain = (runPut $ do
                putWord64be $
                  fromIntegral counter `shiftL` 36 .|.
                  fromIntegral (flags .&. 0x3f) `shiftL` 30
                  ---  fromIntegral (flags `shiftR` 1) `shiftL` 31  ---  SHE standard
                putWord64be 0)
            <> unKeyAuthUse keyData

-- | Decrypt M2 with K1 and split it into counter (28 bits), flags (6 bits)
-- and key data. Assumes a well-formed 32-octet M2.
extractM2 :: K1
          -> M2
          -> (Word32, Word8, KeyAuthUse NotAuth)
extractM2 (DerivedCipher k1) (M2 m2) = (counter, flags, key)
  where
    decrypted        = cbcDecrypt k1 nullIV m2
    (header, rawKey) = BS.splitAt 16 decrypted
    word0            = either (error . ("extractM2: " ++)) id (runGet getWord64be header)
    counter          = fromIntegral (word0 `shiftR` 36)
    flags            = fromIntegral ((word0 `shiftR` 30) .&. 0x3f)
    key              = case makeKeyAuthUse rawKey of
                         Just k  -> k
                         Nothing -> error "extractM2: wrong M2 length?"

-- | Validate raw bytes as an M2 under K1. Fails in @m@ unless the input is
-- 32 octets and re-encodes to itself; on success it returns the decoded fields.
refineM2 :: MonadPlus m
         => K1
         -> ByteString
         -> m (M2, (Word32, Word8, KeyAuthUse NotAuth))
refineM2 k1 raw = do
  guard (BS.length raw == 32)
  let candidate = M2 raw
      fields@(counter, flags, key) = extractM2 k1 candidate
  guard (makeM2 k1 counter flags key == candidate)
  return (candidate, fields)

-- | Message M3: the CMAC (under K2) of M1 || M2.
newtype M3 = M3 ByteString
  deriving (Eq)

instance Show M3 where
  show (M3 bytes) = "M3 " ++ hdump bytes

-- | Raw M3 bytes.
unM3 :: M3 -> ByteString
unM3 m = case m of
  M3 bytes -> bytes

-- | Build M3 = CMAC(K2, M1 || M2).
makeM3 :: K2
       -> M1
       -> M2
       -> M3
makeM3 (DerivedCipher k2) (M1 m1) (M2 m2) = M3 tag
  where
    tag = B.convert (cmacGetBytes (cmac k2 (m1 <> m2)))

-- | Message M4: M1 followed by the ECB-encrypted (under K3) counter block.
newtype M4 = M4 ByteString
  deriving (Eq)

instance Show M4 where
  show (M4 bytes) = "M4 " ++ hdump bytes

-- | Raw M4 bytes.
unM4 :: M4 -> ByteString
unM4 m = case m of
  M4 bytes -> bytes

-- | Build M4 = M1 || ENC_ECB(K3, counter || 1 || 0...).
--
-- The encrypted block is two big-endian 64-bit words. The first holds the
-- counter in bits 63..36 with a single 1 bit at bit 35; the second is zero.
makeM4' :: K3
        -> M1
        -> Word32
        -> M4
makeM4' (DerivedCipher k3) (M1 m1) counter = M4 (m1 <> ecbEncrypt k3 block)
  where
    block = runPut (mapM_ putWord64be [fromIntegral counter `shiftL` 36 .|. 1 `shiftL` 35, 0])

-- | Build M4 from the individual M1 fields (UID, Key ID, Auth key ID) and
-- the counter.
makeM4 :: K3
       -> UID
       -> Word8
       -> Word8
       -> Word32
       -> M4
makeM4 k3 uid kid akid = makeM4' k3 (makeM1 uid kid akid)

-- | Split M4 into its leading 16-octet M1 and the counter decrypted (with
-- K3) from the following block.
extractM4' :: K3 -> M4 -> (M1, Word32)
extractM4' (DerivedCipher k3) (M4 m4) = (M1 m1Bytes, counter)
  where
    (m1Bytes, encBlock) = BS.splitAt 16 m4
    decoded = runGet getWord64be (ecbDecrypt k3 encBlock)
    counter = case decoded of
      Left err  -> error ("Codec.Automotive.CSE.extractM4: internal error: " ++ err)
      Right w64 -> fromIntegral (w64 `shiftR` 36)

-- | Like 'extractM4'', but also decodes the embedded M1 into its fields.
extractM4 :: K3 -> M4 -> ((UID, Word8, Word8), Word32)
extractM4 k3 m4 =
  let (m1, counter) = extractM4' k3 m4
  in  (extractM1 m1, counter)

-- | Message M5: the CMAC (under K4) of M4.
newtype M5 = M5 ByteString
  deriving (Eq)

instance Show M5 where
  show (M5 bytes) = "M5 " ++ hdump bytes

-- | Raw M5 bytes.
unM5 :: M5 -> ByteString
unM5 m = case m of
  M5 bytes -> bytes

-- | Build M5 = CMAC(K4, M4).
makeM5 :: K4
       -> M4
       -> M5
makeM5 (DerivedCipher k4) (M4 m4) = M5 (B.convert (cmacGetBytes (cmac k4 m4)))