module Crypto.Hash.Keccak where import Data.Bits import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as LBS import Data.Word type State = [[Word64]] emptyState :: State emptyState = replicate 5 (replicate 5 0) -- truncated when w is smaller than 64 roundConstants :: [Word64] roundConstants = [ 0x0000000000000001, 0x0000000000008082, 0x800000000000808A , 0x8000000080008000, 0x000000000000808B, 0x0000000080000001 , 0x8000000080008081, 0x8000000000008009, 0x000000000000008A , 0x0000000000000088, 0x0000000080008009, 0x000000008000000A , 0x000000008000808B, 0x800000000000008B, 0x8000000000008089 , 0x8000000000008003, 0x8000000000008002, 0x8000000000000080 , 0x000000000000800A, 0x800000008000000A, 0x8000000080008081 , 0x8000000000008080, 0x0000000080000001, 0x8000000080008008 ] rotationConstants :: [[Int]] rotationConstants = [ [ 0, 36, 3, 41, 18 ] , [ 1, 44, 10, 45, 2 ] , [ 62, 6, 43, 15, 61 ] , [ 28, 55, 25, 21, 56 ] , [ 27, 20, 39, 8, 14 ] ] paddingKeccak :: BS.ByteString -> [Word8] paddingKeccak = multiratePadding 0x1 paddingSha3 :: BS.ByteString -> [Word8] paddingSha3 = multiratePadding 0x6 multiratePadding :: Word -> BS.ByteString -> [Word8] multiratePadding pad input = BS.unpack . BS.append input $ if padlen == 1 then BS.pack [0x81] else BS.pack $ 0x01 : replicate (padlen - 2) 0x00 ++ [0x80] where bitRateBytes = 136 -- TODO: modulo bitRateBytes? usedBytes = BS.length input padlen = bitRateBytes - mod usedBytes bitRateBytes -- r (bitrate) = 1088 -- c (capacity) = 512 keccak256 :: BS.ByteString -> BS.ByteString keccak256 = squeeze 32 . absorb . toBlocks 136 . paddingKeccak -- Sized inputs to this? toBlocks :: Int -> [Word8] -> [[Word64]] toBlocks _ [] = [] toBlocks sizeInBytes input = let (a, b) = splitAt sizeInBytes input in toLanes a : toBlocks sizeInBytes b where toLanes :: [Word8] -> [Word64] toLanes [] = [] toLanes octets = let (a, b) = splitAt 8 octets in toLane a : toLanes b toLane :: [Word8] -> Word64 toLane octets = foldl1 xor $ zipWith (\offset octet -> shiftL (fromIntegral octet) (offset * 8)) [0..7] octets -- for each block Pi in P -- S[x,y] = S[x,y] xor Pi[x+5*y], for (x,y) such that x+5*y < r/w -- S = Keccak-f[r+c](S) -- TODO support `input` larger than single block absorb :: [[Word64]] -> State absorb = foldl absorbBlock emptyState absorbBlock :: State -> [Word64] -> State absorbBlock state input = keccakF state' where r = 1088 w = 64 state' = [ [ if x + 5 * y < div r w then ((state !! x) !! y) `xor` (input !! (x + 5 * y)) else (state !! x) !! y | y <- [0..4] ] | x <- [0..4] ] -- # Squeezing phase -- Z = empty string -- while output is requested -- Z = Z || S[x,y], for (x,y) such that x+5*y < r/w -- S = Keccak-f[r+c](S) -- TODO handle longer outputs squeeze :: Int -> State -> BS.ByteString squeeze len = BS.pack . take len . stateToBytes where comma = 44 stateToBytes :: State -> [Word8] stateToBytes state = concat [ laneToBytes (state !! x !! y) | y <- [0..4] , x <- [0..4] ] laneToBytes :: Word64 -> [Word8] laneToBytes word = fmap (\x -> fromIntegral (shiftR word (x * 8) .&. 0xFF)) [0..7] keccakF :: State -> State keccakF state = foldl (\s r -> iota r . chi . rhoPi $ theta s) state [0 .. (rounds - 1)] where rounds = 24 -- # θ step -- C[x] = A[x,0] xor A[x,1] xor A[x,2] xor A[x,3] xor A[x,4], for x in 0…4 -- D[x] = C[x-1] xor rot(C[x+1],1), for x in 0…4 -- A[x,y] = A[x,y] xor D[x], for (x,y) in (0…4,0…4) theta :: State -> State theta state = [ [ ((state !! x) !! y) `xor` (d !! x) | y <- [0..4] ] | x <- [0..4] ] where c = [ foldl1 xor [ (state !! x) !! y | y <- [0..4] ] | x <- [0..4] ] d = [ c !! ((x - 1) `mod` 5) `xor` rotateL (c !! ((x + 1) `mod` 5)) 1 | x <- [0..4] ] -- # ρ and π steps -- B[y,2*x+3*y] = rot(A[x,y], r[x,y]), for (x,y) in (0…4,0…4) rhoPi :: State -> [[Word64]] rhoPi state = fmap (fmap rotFunc) [ [ ((x + 3 * y) `mod` 5, x) | y <- [0..4] ] | x <- [0..4] ] where rotFunc (x, y) = rotateL ((state !! x) !! y) ((rotationConstants !! x) !! y) -- # χ step -- A[x,y] = B[x,y] xor ((not B[x+1,y]) and B[x+2,y]), for (x,y) in (0…4,0…4) chi :: [[Word64]] -> State chi b = [ [ ((b !! x) !! y) `xor` (complement ((b !! ((x + 1) `mod` 5)) !! y) .&. ((b !! ((x + 2) `mod` 5)) !! y)) | y <- [0..4] ] | x <- [0..4] ] -- # ι step -- A[0,0] = A[0,0] xor RC -- TODO Data.List.Lens iota :: Int -> State -> State iota round ((first : rest) : restRows) = (xor (roundConstants !! round) first : rest) : restRows