module FPE.FF1 (encrypt, decrypt, BlockCipher, Crypter, Tweak) where
import Data.Bits
import Control.Arrow
import Control.Monad
import Data.Tuple (swap)
import Math.NumberTheory.Logarithms
import Data.Vector.Generic (Vector)
import qualified Data.Vector.Generic as V
import qualified Data.ByteString as S
import qualified Data.ByteString.Lazy as L
type BlockCipher = S.ByteString -> S.ByteString
type Tweak = S.ByteString
type Crypter v a = BlockCipher -> Int -> Tweak -> v a -> v a
bytesFor :: Int -> Int -> Int
bytesFor radix len =
integerLog2 ((fromIntegral radix ^ len) - 1) `div` 8 + 1
xorBytes :: S.ByteString -> S.ByteString -> S.ByteString
xorBytes a b = S.pack $ S.zipWith xor a b
vecToNum :: (Vector v a, Integral a) => Int -> v a -> Integer
vecToNum radix = V.foldl go 0 where
go val c = val * fromIntegral radix + fromIntegral c
numToVec :: (Vector v a, Integral a) => Int -> Int -> Integer -> v a
numToVec radix len num = V.reverse $ V.fromListN len $
map (fromIntegral . (`mod` radix_)) $ iterate (`div` radix_) num
where radix_ = fromIntegral radix
bytesToNum :: Integral a => S.ByteString -> a
bytesToNum = S.foldl (\val c -> val * 256 + fromIntegral c) 0
{-# SPECIALIZE bytesToNum :: S.ByteString -> Integer #-}
numToBytes :: Integral a => Int -> a -> S.ByteString
numToBytes len num = S.reverse $ S.pack $ map fromIntegral $
take len $ iterate (`div` 256) num
prf :: BlockCipher -> S.ByteString -> S.ByteString
prf cipher = loop (S.replicate 16 0) where
loop y src = if S.null rest then y' else loop y' rest where
(x, rest) = S.splitAt 16 src
y' = cipher $ x `xorBytes` y
extend :: BlockCipher -> Int -> S.ByteString -> S.ByteString
extend cipher len blk = L.toStrict $ L.take (fromIntegral len) $ L.fromChunks $
blk : [ cipher $ blk `xorBytes` numToBytes @Int 16 i | i <- [1..] ]
crypt :: (Vector v a, Integral a) => Bool -> Crypter v a
crypt isEncrypt cipher radix tweak msg =
numToVec radix u finalA V.++ numToVec radix v finalB where
t = S.length tweak
n = V.length msg; u = n `div` 2; v = n - u
b = bytesFor radix v
d = 4*((b-1)`div`4) + 8
rpow = (fromIntegral radix ^)
bP = S.concat [
S.pack [1, 2, 1], numToBytes 3 radix,
S.pack [10, fromIntegral u], numToBytes 4 n, numToBytes 4 t]
pfxQ = tweak <> S.pack (replicate ((-t-b-1)`mod`16) 0)
(numA0, numB0) = join (***) (vecToNum radix) $ V.splitAt u msg
loop (numA, numB) i = (numB, numC) where
y = bytesToNum $ extend cipher d $ prf cipher $ S.concat [
bP, pfxQ, S.singleton i, numToBytes b numB]
op = if isEncrypt then (+ y) else subtract y
numC = op numA `mod` (if even i then rpow u else rpow v)
wrap = if isEncrypt then id else swap
(finalA, finalB) = wrap $
foldl loop (wrap (numA0, numB0)) $
if isEncrypt then [0..9] else [9,8..0]
encrypt, decrypt :: (Vector v a, Integral a) => Crypter v a
encrypt = crypt True
decrypt = crypt False