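-- | Compression stage of XSHA1 (the overall pipeline is padding, then message
-- extension, then compression).
--
-- See <http://en.wikipedia.org/wiki/Sha1> for the real thing (standard SHA-1).
--
-- Both the forward and the reverse direction of a single round are
-- implemented ('iter' and 'reti'). To run the full compression, fold 'iter'
-- over the round numbers [0..79]; to undo it, fold 'reti' over [79,78..0].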
module Data.Digest.XSHA1.Compress where
import qualified Data.Vector as V
import Data.Vector (Vector, (!))
import Data.Bits
import qualified Test.QuickCheck as Q
import Numeric.Taint.Word32


-- | The five working registers, in the order @(a, b, c, d, e)@.
type Regs = (N, N, N, N, N)

-- | Initial values for registers a..e.
--
-- 'finalize' also adds these same values onto the registers at the end.
consts :: Regs
consts =
    ( N 0x67452301
    , N 0xefcdab89
    , N 0x98badcfe
    , N 0x10325476
    , N 0xc3d2e1f0
    )


-- | Round function f, selected by the round number @r@ and applied to the
-- registers @b@, @c@ and @d@:
--
--   * rounds 0-19:  choose,   @(b .&. c) .|. (complement b .&. d)@
--   * rounds 20-39: parity,   @b `xor` c `xor` d@
--   * rounds 40-59: majority, @(b .&. c) .|. (c .&. d) .|. (d .&. b)@
--   * rounds 60-79: parity again
--
-- A negative @r@ falls into the first band. @r >= 80@ is rejected with a
-- descriptive error instead of a bare non-exhaustive-guard failure.
f :: Int -> N -> N -> N -> N
f r b c d
    | r < 20    = (b .&. c) .|. (complement b .&. d)
    | r < 40    = b `xor` c `xor` d
    | r < 60    = (b .&. c) .|. (c .&. d) .|. (d .&. b)
    | r < 80    = b `xor` c `xor` d
    | otherwise =
        error ("Data.Digest.XSHA1.Compress.f: round out of range: " ++ show r)

-- | Round constant k, one value per band of 20 rounds.
--
-- A negative @r@ falls into the first band. @r >= 80@ is rejected with a
-- descriptive error instead of a bare non-exhaustive-guard failure.
k :: Int -> N
k r
    | r < 20    = N 0x5a827999
    | r < 40    = N 0x6ed9eba1
    | r < 60    = N 0x8f1bbcdc
    | r < 80    = N 0xca62c1d6
    | otherwise =
        error ("Data.Digest.XSHA1.Compress.k: round out of range: " ++ show r)

-- | One round of the XSHA1 compression function, for round number @r@ and
-- message schedule @xs@ (which is indexed at @r@).
--
-- The new @a@ is the sum of @e@, the round function, the schedule word, the
-- round constant and @a@ rotated by 5. The remaining registers shift down one
-- place, with @b@ rotated by 30 on its way into the @c@ slot.
iter :: Vector N -> Regs -> Int -> Regs
iter xs (a,b,c,d,e) r =
    let sumA     = e + f r b c d + xs ! r + k r + rotate a 5
        rotatedB = rotate b 30
    in  (sumA, a, rotatedB, c, d)

-- | Inverse of 'iter'. Given the registers after round @r@ (computed with the
-- same schedule @xs@), recover the registers from before that round.
--
-- Four registers come straight back from their shifted slots; @b@ is undone
-- by rotating by -30. The old @e@ is found by subtracting every other summand
-- from the new @a@.
reti :: Vector N -> Regs -> Int -> Regs
reti xs (a1,b1,c1,d1,e1) r = (a0, b0, c0, d0, e0)
  where
    -- Registers that 'iter' merely moved along (b also rotated).
    a0 = b1
    b0 = rotate c1 (-30)
    c0 = d1
    d0 = e1
    -- Peel the remaining summands off the new a to get the old e back.
    e0 = a1 - f r b0 c0 d0 - xs ! r - k r - rotate a0 5

-- | Final step: add the initial register values 'consts' onto the registers,
-- component-wise (register + constant).
finalize :: Regs -> Regs
finalize (a,b,c,d,e) = (a + a0, b + b0, c + c0, d + d0, e + e0)
  where
    (a0, b0, c0, d0, e0) = consts

-- | Inverse of 'finalize': subtract the initial register values 'consts' from
-- the registers, component-wise (register - constant).
unfinalize :: Regs -> Regs
unfinalize (a,b,c,d,e) = (a - a0, b - b0, c - c0, d - d0, e - e0)
  where
    (a0, b0, c0, d0, e0) = consts

-- | QuickCheck generator for 'N', used by 'test' to check that 'reti' really
-- is the inverse of 'iter'.
--
-- Caveat: that round-trip check is weak; it failed to catch an earlier
-- negative-sign bug.
instance Q.Arbitrary N where
    -- Draw a uniformly random value in [0, 0xFFFFFFFF].
    -- It is generated as 'Integer' so the upper bound stays representable even
    -- on platforms where 'Int' is only 32 bits; there 0xFFFFFFFF :: Int would
    -- wrap to -1. 'fromInteger' replaces the former partial @read . show@
    -- round trip.
    arbitrary = do
        int <- Q.choose (0, 0xFFFFFFFF) :: Q.Gen Integer
        return (N (fromInteger int))
-- | Property test: for random registers, a random round number in [0, 79]
-- and a random 80-word schedule, applying 'iter' then 'reti' gives back the
-- original registers.
--
-- The round number and schedule are drawn with 'Q.forAll', so a failing
-- case reports all three inputs rather than only the registers.
test :: IO ()
test = Q.quickCheck $ \regs ->
    Q.forAll (Q.choose (0, 79)) $ \r ->
    Q.forAll (fmap V.fromList (Q.vector 80)) $ \xs ->
    reti xs (iter xs regs r) r == regs