module Data.HMAC(
   
   hmac, hmac_sha1, hmac_md5,
   
   HashMethod(HashMethod, digest, input_blocksize),
   ) where
import Data.Digest.SHA1 as SHA1
import Data.Digest.MD5 as MD5
import Data.Word (Word32)
import Data.Bits (shiftR, xor, bitSize, Bits)
import Codec.Utils (Octet)
-- | A hash function packaged with the input block size that HMAC pads the
-- key to.
data HashMethod = HashMethod
    { digest          :: [Octet] -> [Octet]
      -- ^ The underlying hash: maps a message to its digest octets.
    , input_blocksize :: Int
      -- ^ Size of one input block of the hash, measured in bits
      -- (this module uses 512 for both SHA-1 and MD5).
    }
-- | SHA-1 packaged for 'hmac': the 'Word160' result is flattened to octets
-- (five words in order, each most-significant byte first) and the input
-- block is 512 bits.
sha1_hm :: HashMethod
sha1_hm = HashMethod (w160_to_w8s . SHA1.hash) 512

-- | MD5 packaged for 'hmac', with a 512-bit input block.
md5_hm :: HashMethod
md5_hm = HashMethod MD5.hash 512
-- | HMAC instantiated with SHA-1.
hmac_sha1 :: [Octet] -- ^ secret key (any length; normalised by 'hmac')
          -> [Octet] -- ^ message to authenticate
          -> [Octet] -- ^ the MAC, as octets of the SHA-1 digest
hmac_sha1 key msg = hmac sha1_hm key msg
-- | HMAC instantiated with MD5.
hmac_md5 :: [Octet] -- ^ secret key (any length; normalised by 'hmac')
         -> [Octet] -- ^ message to authenticate
         -> [Octet] -- ^ the MAC, as returned by @MD5.hash@
hmac_md5 key msg = hmac md5_hm key msg
-- | Serialise a 160-bit digest as 20 octets: the five 32-bit words in field
-- order, each emitted most-significant byte first.
w160_to_w8s :: Word160 -> [Octet]
w160_to_w8s = concatMap w32_to_w8s . w160_to_w32s
-- | Unpack the five 32-bit words of a 'Word160', in constructor-field order.
w160_to_w32s :: Word160 -> [Word32]
w160_to_w32s (Word160 w0 w1 w2 w3 w4) = [w0, w1, w2, w3, w4]
-- | Split a 32-bit word into four octets, most significant first.
-- Each 'fromIntegral' keeps only the low 8 bits of the shifted word.
w32_to_w8s :: Word32 -> [Octet]
w32_to_w8s w = [fromIntegral (w `shiftR` n) | n <- [24, 16, 8, 0]]
-- | Generic HMAC (RFC 2104): @H((K xor opad) ++ H((K xor ipad) ++ m))@,
-- where @K@ is the user key normalised to one input block by
-- 'key_from_user'.
hmac :: HashMethod -- ^ hash function and its input block size
     -> [Octet]    -- ^ secret key, any length
     -> [Octet]    -- ^ message to authenticate
     -> [Octet]    -- ^ the MAC (a digest produced by the given hash)
hmac h uk m = digest h (outerKey ++ innerDigest)
    where innerDigest          = digest h (innerKey ++ m)
          (outerKey, innerKey) =
              process_pads (key_from_user h uk)
                           (make_start_pad blockBits opad_pattern)
                           (make_start_pad blockBits ipad_pattern)
          blockBits            = input_blocksize h
-- | Normalise a user-supplied key to exactly one input block (RFC 2104):
--
-- * a key longer than the block is replaced by its digest, then padded;
-- * a shorter key is right-padded with zero octets;
-- * a key of exactly one block is used unchanged.
key_from_user :: HashMethod -> [Octet] -> [Octet]
key_from_user h uk =
    case compare (bitcount uk) blockBits of
      GT -> fill_key (digest h uk)
      LT -> fill_key uk
      EQ -> uk
    where blockBits = input_blocksize h
          -- Append zero octets until kd fills the block: the shortfall in
          -- bits, divided by 8 to get a count of octets.
          fill_key kd = kd ++ replicate ((blockBits - bitcount kd) `div` 8) 0
-- | XOR the block-sized key against two pad blocks at once, returning
-- @(key xor first pad, key xor second pad)@. The three lists are walked in
-- lockstep, so both results are as long as the shortest input.
process_pads :: [Octet]            -- ^ key
             -> [Octet]            -- ^ first pad block ('hmac' passes opad)
             -> [Octet]            -- ^ second pad block ('hmac' passes ipad)
             -> ([Octet], [Octet]) -- ^ (key xor first pad, key xor second pad)
process_pads ks os is =
    unzip [ (k `xor` o, k `xor` i) | (k, o, i) <- zip3 ks os is ]
-- | A block of @size@ bits (rounded down to whole octets) filled with copies
-- of a single pad octet.
make_start_pad :: Int -> Octet -> [Octet]
make_start_pad size pad = replicate octets pad
    where octets = size `div` bitSize pad
-- | RFC 2104 outer pad byte (0x5C), repeated across a block and XORed with
-- the key for the outer hash.
opad_pattern :: Octet
opad_pattern = 0x5c

-- | RFC 2104 inner pad byte (0x36), repeated across a block and XORed with
-- the key for the inner hash.
ipad_pattern :: Octet
ipad_pattern = 0x36
-- | Length of an octet list in bits.
--
-- The per-octet width is taken from a typed literal rather than from
-- @head k@, so an empty key needs no partial function.
bitcount :: [Octet] -> Int
bitcount k = length k * bitSize (0 :: Octet)