-- CSE (Cryptographic Service Engine) emulation implementation
module Codec.Automotive.CSE (
  M1, unM1, makeM1, extractM1,
  M2, unM2, makeM2,
  M3, unM3, makeM3,
  M4, unM4, makeM4, extractM4, makeM4', extractM4',
  M5, unM5, makeM5,

  K1, K1', makeK1,
  K2, K2', makeK2,
  K3, K3', makeK3,
  K4, K4', makeK4,

  UID, unUID, makeUID,

  Derived, unDerived,
  kdf, keyUpdateEncC, keyUpdateMacC,

  DerivedCipher, derivedCipher,

  KeyAuthUse, Auth, NotAuth,
  makeKeyAuthUse, unKeyAuthUse,

  UpdateC, Enc, Mac,
  ) where

import Control.Monad (MonadPlus, guard)
import Data.Monoid ((<>), mconcat, Endo (..))
import Data.Bits (shiftL, shiftR, (.&.), (.|.))
import Data.Word (Word8, Word32)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.Serialize.Get (runGet, getWord64be)
import Data.Serialize.Put (runPut, putWord64be)
import Numeric (showHex)

import qualified Data.ByteArray as B
import Crypto.Cipher.Types (cipherInit, ecbEncrypt, cbcEncrypt, nullIV, ecbDecrypt)
import Crypto.Cipher.AES (AES128)
import Crypto.Error (eitherCryptoError)

import Backport.Crypto.MAC.CMAC (CMAC(..), cmac)
import Backport.Crypto.ConstructHash.MiyaguchiPreneel (MiyaguchiPreneel(..), mp)


-- | Render a byte string as hexadecimal text, two digits per octet,
--   with no separators.  An empty input yields an empty string.
hdump :: ByteString -> String
hdump bytes = foldr showOctet "" (BS.unpack bytes)
  where
    -- 'showHex' emits a single digit for values below 16, so pad those
    -- with a leading zero to keep every octet two characters wide.
    showOctet octet
      | octet < 16 = showChar '0' . showHex octet
      | otherwise  = showHex octet


-- | Uninhabited tag type; used as the phantom parameter of 'UpdateC' and
--   'Derived' to mark values belonging to the encryption-key derivation.
data Enc
-- | Uninhabited tag type; used as the phantom parameter of 'UpdateC' and
--   'Derived' to mark values belonging to the MAC-key derivation.
data Mac

-- | A key-update constant fed to 'kdf'.  The phantom parameter @c@ records
--   which derivation ('Enc' or 'Mac') the constant belongs to.
newtype UpdateC c = UpdateC ByteString
  deriving Eq

-- | Shows the constructor name followed by the constant as a hex dump.
instance Show (UpdateC c) where
  show (UpdateC bytes) = "UpdateC " ++ hdump bytes

-- | The 16-octet update constant used when deriving encryption keys:
--   @0x010153484500800000000000000000B0@, serialised big-endian.
keyUpdateEncC :: UpdateC Enc
keyUpdateEncC = UpdateC (runPut constant)
  where
    -- Written as two 64-bit halves, most significant half first.
    constant = do
      putWord64be 0x0101534845008000
      putWord64be 0x00000000000000B0

-- | The 16-octet update constant used when deriving MAC keys:
--   @0x010253484500800000000000000000B0@, serialised big-endian.
keyUpdateMacC :: UpdateC Mac
keyUpdateMacC = UpdateC (runPut constant)
  where
    -- Written as two 64-bit halves, most significant half first.
    constant = do
      putWord64be 0x0102534845008000
      putWord64be 0x00000000000000B0

-- | Uninhabited tag type; phantom parameter of 'KeyAuthUse' marking a key
--   that is used as the authorising key (see 'makeK1' and 'makeK2').
data Auth
-- | Uninhabited tag type; phantom parameter of 'KeyAuthUse' marking a key
--   that is not the authorising key (see 'makeK3', 'makeK4' and 'makeM2').
data NotAuth

-- | Raw 16-octet key material.  The phantom parameter @k@ ('Auth' or
--   'NotAuth') records the role the key plays in the update protocol.
newtype KeyAuthUse k = KeyAuthUse ByteString
  deriving Eq

-- | Smart constructor: succeeds only for input that is exactly 16 octets
--   long, and fails in the surrounding 'MonadPlus' otherwise.
makeKeyAuthUse :: MonadPlus m => ByteString -> m (KeyAuthUse k)
makeKeyAuthUse raw =
  guard (BS.length raw == 16) >> return (KeyAuthUse raw)

-- | Recover the underlying key octets.
unKeyAuthUse :: KeyAuthUse k -> ByteString
unKeyAuthUse (KeyAuthUse raw) = raw

-- | Key material produced by 'kdf'.  The phantom parameters record the
--   role of the source key (@k@) and the update constant used (@c@).
newtype Derived k c =
  Derived ByteString
  deriving Eq

-- | Shows the constructor name followed by the derived octets as a hex dump.
--
--   The label is @Derived@, matching this type's constructor like every
--   other 'Show' instance in this module; 'DerivedCipher' is a distinct
--   type and must not be confused with this one in debug output.
instance Show (Derived k c) where
  show (Derived k) = unwords ["Derived", hdump k]

-- | Recover the raw derived key octets.
unDerived :: Derived k c -> ByteString
unDerived (Derived raw) = raw

-- | Key-derivation function: the Miyaguchi-Preneel hash (instantiated with
--   AES128) of the key octets followed by the update-constant octets.
kdf :: KeyAuthUse k -> UpdateC c -> Derived k c
kdf (KeyAuthUse key) (UpdateC constant) = Derived (B.convert digest)
  where
    digest = chashGetBytes hashed
    hashed :: MiyaguchiPreneel AES128
    hashed = mp (key <> constant)

-- | 'kdf' specialised to the encryption update constant 'keyUpdateEncC'.
kdfEnc :: KeyAuthUse k -> Derived k Enc
kdfEnc key = kdf key keyUpdateEncC

-- | 'kdf' specialised to the MAC update constant 'keyUpdateMacC'.
kdfMac :: KeyAuthUse k -> Derived k Mac
kdfMac key = kdf key keyUpdateMacC

-- | An AES128 cipher context initialised from derived key material.
--   The phantom parameters mirror those of the 'Derived' it came from.
newtype DerivedCipher k c =
  DerivedCipher AES128

-- | Initialise an AES128 cipher with the octets of a derived key.
--
--   Calls 'error' if cipher initialisation is rejected.
derivedCipher :: Derived k c -> DerivedCipher k c
derivedCipher (Derived k) =
  case eitherCryptoError (cipherInit k) of
    Right aes -> DerivedCipher aes
    -- Assumed unreachable: the derived key is expected to be the 16-octet
    -- result of the Miyaguchi-Preneel AES128 hash.
    Left err ->
      error $ "Codec.Automotive.CSE.derivedCipher: internal error: " ++ show err

-- | Derived key octets for K1: authorising key, encryption constant.
type K1' = Derived Auth Enc

-- | Cipher context obtained from a @K1'@ via 'derivedCipher'.
type K1 = DerivedCipher Auth Enc

-- | Derive K1 from the authorising key using 'keyUpdateEncC'.
makeK1
  :: KeyAuthUse Auth  -- ^ AuthKey Data
  -> K1'              -- ^ Result Hash value
makeK1 authKey = kdf authKey keyUpdateEncC

-- | Derived key octets for K2: authorising key, MAC constant.
type K2' = Derived Auth Mac

-- | Cipher context obtained from a @K2'@ via 'derivedCipher'.
type K2 = DerivedCipher Auth Mac

-- | Derive K2 from the authorising key using 'keyUpdateMacC'.
makeK2
  :: KeyAuthUse Auth  -- ^ AuthKey Data
  -> K2'              -- ^ Result Hash value
makeK2 authKey = kdf authKey keyUpdateMacC

-- | Derived key octets for K3: non-authorising key, encryption constant.
type K3' = Derived NotAuth Enc

-- | Cipher context obtained from a @K3'@ via 'derivedCipher'.
type K3 = DerivedCipher NotAuth Enc

-- | Derive K3 from the key data using 'keyUpdateEncC'.
makeK3
  :: KeyAuthUse NotAuth  -- ^ Key Data
  -> K3'                 -- ^ Result Hash value
makeK3 keyData = kdf keyData keyUpdateEncC

-- | Derived key octets for K4: non-authorising key, MAC constant.
type K4' = Derived NotAuth Mac

-- | Cipher context obtained from a @K4'@ via 'derivedCipher'.
type K4 = DerivedCipher NotAuth Mac

-- | Derive K4 from the key data using 'keyUpdateMacC'.
makeK4
  :: KeyAuthUse NotAuth  -- ^ Key Data
  -> K4'                 -- ^ Result Hash value
makeK4 keyData = kdf keyData keyUpdateMacC


-- | A device identifier; 'makeUID' only admits 15-octet values.
newtype UID = UID ByteString
  deriving Eq

-- | Shows the constructor name followed by the identifier as a hex dump.
instance Show UID where
  show (UID bytes) = "UID " ++ hdump bytes

-- | Recover the raw identifier octets.
unUID :: UID -> ByteString
unUID (UID raw) = raw

-- | Smart constructor: succeeds only for input that is exactly 15 octets
--   long, and fails in the surrounding 'MonadPlus' otherwise.
makeUID :: MonadPlus m => ByteString -> m UID
makeUID raw =
  guard (BS.length raw == 15) >> return (UID raw)

-- | The M1 message: UID octets followed by one octet of key identifiers
--   (see 'makeM1').
newtype M1 = M1 ByteString
  deriving Eq

-- | Shows the constructor name followed by the message as a hex dump.
instance Show M1 where
  show (M1 bytes) = "M1 " ++ hdump bytes

-- | Build M1: the UID octets followed by a single octet that carries the
--   key ID in its high nibble and the auth key ID in its low nibble.
--
--   Both identifiers are truncated to their low 4 bits.  The key ID is
--   truncated by the 'Word8' shift itself; the auth key ID is masked
--   explicitly so that an out-of-range value cannot spill into the key ID
--   nibble.
makeM1 :: UID        -- ^ UID          - 15 octet
       -> Word8      -- ^ Key ID       -  4 bit
       -> Word8      -- ^ Auth key ID  -  4 bit
       -> M1
makeM1 (UID uid) kid akid = M1 $ uid <> BS.singleton idOctet
  where
    idOctet = (kid `shiftL` 4) .|. (akid .&. 0x0F)

-- | Recover the raw M1 octets.
unM1 :: M1 -> ByteString
unM1 (M1 raw) = raw

-- | Split M1 back into its UID (first 15 octets), key ID (high nibble of
--   the 16th octet) and auth key ID (low nibble of the 16th octet).
--
--   An 'M1' is expected to be at least 16 octets long.  If that invariant
--   is broken, forcing either ID raises an internal error naming this
--   function rather than an anonymous empty-list failure.
extractM1 :: M1 -> (UID, Word8, Word8)
extractM1 (M1 m1) = (UID uid, idOctet `shiftR` 4, idOctet .&. 0x0F)
  where
    (uid, rest) = BS.splitAt 15 m1
    idOctet = case BS.uncons rest of
      Just (w, _) -> w
      Nothing     ->
        error "Codec.Automotive.CSE.extractM1: internal error: M1 shorter than 16 octets"

-- | The M2 message: the ciphertext produced by 'makeM2'.
newtype M2 = M2 ByteString
  deriving Eq

-- | Shows the constructor name followed by the message as a hex dump.
instance Show M2 where
  show (M2 bytes) = "M2 " ++ hdump bytes

-- | Build M2: CBC encryption under K1, with 'nullIV', of a 16-octet header
--   followed by the new key data.
--
--   Header layout (big-endian, first 64-bit word): the counter in bits
--   63..36, the key flags in bits 35..30, all remaining bits zero; the
--   second 64-bit word is zero.
--
--   The counter is truncated to 28 bits by the shift itself.  The flags
--   are masked to 6 bits explicitly so that an out-of-range value cannot
--   overwrite the low bits of the counter field.
makeM2 :: K1                 -- ^ K1 value
       -> Word32             -- ^ Counter   - 28 bit
       -> Word8              -- ^ Key Flag  -  6 bit
       -> KeyAuthUse NotAuth -- ^ Key Data for AES128
       -> M2
makeM2 (DerivedCipher k1) counter flags (KeyAuthUse keyData) =
    M2 $ cbcEncrypt k1 nullIV plain
  where
    plain = header <> keyData
    header = runPut $ do
      putWord64be $
        fromIntegral counter `shiftL` 36 .|.
        fromIntegral (flags .&. 0x3F) `shiftL` 30
        -- SHE standard variant, kept for reference:
        --   fromIntegral (flags `shiftR` 1) `shiftL` 31
      putWord64be 0

-- | Recover the raw M2 octets.
unM2 :: M2 -> ByteString
unM2 (M2 raw) = raw

-- | The M3 message: the CMAC tag produced by 'makeM3'.
newtype M3 = M3 ByteString
  deriving Eq

-- | Shows the constructor name followed by the message as a hex dump.
instance Show M3 where
  show (M3 bytes) = "M3 " ++ hdump bytes

-- | Build M3: the CMAC, keyed with K2, over the octets of M1 followed by
--   the octets of M2.
makeM3 :: K2
       -> M1
       -> M2
       -> M3
makeM3 (DerivedCipher k2) (M1 m1) (M2 m2) = M3 (B.convert tag)
  where
    tag = cmacGetBytes (cmac k2 (m1 <> m2))

-- | Recover the raw M3 octets.
unM3 :: M3 -> ByteString
unM3 (M3 raw) = raw

-- | The M4 message: the M1 octets followed by an encrypted counter block
--   (see @makeM4'@).
newtype M4 = M4 ByteString
  deriving Eq

-- | Shows the constructor name followed by the message as a hex dump.
instance Show M4 where
  show (M4 bytes) = "M4 " ++ hdump bytes

-- | Build M4 from an existing M1: the M1 octets followed by the ECB
--   encryption under K3 of a 16-octet block.
--
--   Block layout (big-endian, first 64-bit word): the counter in bits
--   63..36, a single set bit at position 35, all remaining bits zero; the
--   second 64-bit word is zero.
makeM4' :: K3
        -> M1
        -> Word32
        -> M4
makeM4' (DerivedCipher k3) (M1 m1) counter = M4 (m1 <> ecbEncrypt k3 block)
  where
    block = runPut (putWord64be upper >> putWord64be 0)
    upper = (fromIntegral counter `shiftL` 36) .|. (1 `shiftL` 35)

-- | Build M4 from its components: assembles M1 from the UID, key ID and
--   auth key ID with 'makeM1', then delegates to @makeM4'@.
makeM4 :: K3
       -> UID
       -> Word8
       -> Word8
       -> Word32
       -> M4
makeM4 k3 uid kid akid counter = makeM4' k3 m1 counter
  where
    m1 = makeM1 uid kid akid

-- | Recover the raw M4 octets.
unM4 :: M4 -> ByteString
unM4 (M4 raw) = raw

-- | Take M4 apart: the first 16 octets are returned as an 'M1'; the
--   remainder is ECB-decrypted under K3 and the counter is read from the
--   top 28 bits of its first big-endian 64-bit word.
--
--   The other bits of the decrypted block are not checked.  Calls 'error'
--   if the decrypted data is too short to hold a 64-bit word.
extractM4' :: K3 -> M4 -> (M1, Word32)
extractM4' (DerivedCipher k3) (M4 m4) =
    (M1 header, fromIntegral (word `shiftR` 36))
  where
    (header, cipherText) = BS.splitAt 16 m4
    word = case runGet getWord64be (ecbDecrypt k3 cipherText) of
      Right w  -> w
      Left msg -> error ("Codec.Automotive.CSE.extractM4: internal error: " ++ msg)

-- | Take M4 apart completely: like @extractM4'@, but with the embedded M1
--   further split into UID, key ID and auth key ID by 'extractM1'.
extractM4 :: K3 -> M4 -> ((UID, Word8, Word8), Word32)
extractM4 k3 m4 =
  let (m1, counter) = extractM4' k3 m4
  in  (extractM1 m1, counter)

-- | The M5 message: the CMAC tag produced by 'makeM5'.
newtype M5 = M5 ByteString
  deriving Eq

-- | Shows the constructor name followed by the message as a hex dump.
instance Show M5 where
  show (M5 bytes) = "M5 " ++ hdump bytes

-- | Build M5: the CMAC, keyed with K4, over the octets of M4.
makeM5 :: K4
       -> M4
       -> M5
makeM5 (DerivedCipher k4) (M4 m4) = M5 (B.convert tag)
  where
    tag = cmacGetBytes (cmac k4 m4)

-- | Recover the raw M5 octets.
unM5 :: M5 -> ByteString
unM5 (M5 raw) = raw