module Network.CommSec
(
OutContext(..)
, InContext(..)
, newInContext, newOutContext
, recv
, send
) where
import Prelude hiding (seq)
import qualified Crypto.Cipher.AES128.Internal as AES
import Crypto.Cipher.AES128.Internal (AESKey)
import Crypto.Cipher.AES128 ()
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Crypto.Classes (buildKey)
import Data.ByteString (ByteString)
import Data.Bits
import Data.Maybe (fromMaybe)
import Data.Word
import Data.List
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable
import Foreign.Marshal.Alloc (allocaBytes)
import System.IO.Unsafe
-- | State kept by the sending side of a channel.
data OutContext = Out
  { aesCtr  :: !Word64
    -- ^ Counter used for the next outgoing message; 'send' passes it to
    -- the encryption routine and then increments it.
  , saltOut :: !Word32
    -- ^ Salt written in front of the counter when the IV is built.
  , outKey  :: AESKey
    -- ^ Key handed to the GCM encryption primitive.
  }
-- | State kept by the receiving side of a channel.
data InContext = In
  { base   :: !Word64
    -- ^ Lowest sequence number tracked by the replay window.
  , mask   :: !Word64
    -- ^ Replay bitmap: bit @i@ records that sequence number @base + i@
    -- has already been accepted.
  , saltIn :: !Word32
    -- ^ Salt written in front of the counter when the IV is rebuilt.
  , inKey  :: AESKey
    -- ^ Key handed to the GCM decryption primitive.
  }
-- | Build a fresh sending context from shared secret bytes.
--
-- The first four bytes, read big-endian, become the salt; everything after
-- them is passed to 'buildKey'. The counter starts at zero.
--
-- Calls 'error' when fewer than 24 bytes are supplied, or (lazily, when the
-- key is first used) when 'buildKey' yields no key.
newOutContext :: ByteString -> OutContext
newOutContext bs
  | B.length bs >= 24 = Out { aesCtr = 0, saltOut = salt, outKey = key }
  | otherwise         = error "Not enough entropy"
  where
    -- Pure read of an immutable buffer that the guard has shown to be long
    -- enough, hence the unsafePerformIO.
    salt = unsafePerformIO (B.unsafeUseAsCString bs (peekBE32 . castPtr))
    key  = case buildKey (B.drop (sizeOf salt) bs) of
             Just k  -> k
             Nothing -> error "Could not build a key"
-- | Build a fresh receiving context from shared secret bytes.
--
-- The first four bytes, read big-endian, become the salt; everything after
-- them is passed to 'buildKey'. The replay window starts at sequence number
-- zero with an empty bitmap.
--
-- Calls 'error' when fewer than 24 bytes are supplied, or (lazily, when the
-- key is first used) when 'buildKey' yields no key.
newInContext :: ByteString -> InContext
newInContext bs
  | B.length bs >= 24 =
      In { base = 0, mask = 0, saltIn = salt, inKey = key }
  | otherwise = error "Not enough entropy"
  where
    -- Pure read of an immutable buffer that the guard has shown to be long
    -- enough, hence the unsafePerformIO.
    salt = unsafePerformIO (B.unsafeUseAsCString bs (peekBE32 . castPtr))
    key  = case buildKey (B.drop (sizeOf salt) bs) of
             Just k  -> k
             Nothing -> error "Could not build a key"
-- | Encrypt a plaintext with AES-GCM and frame the result.
--
-- The IV is the salt followed by the counter, both big-endian. No
-- additional authenticated data is supplied. The returned string is laid out
-- as: 8-byte big-endian counter, ciphertext (same length as the plaintext),
-- 16-byte tag.
encryptGCM :: AESKey      -- ^ encryption key
           -> Word64      -- ^ message counter
           -> Word32      -- ^ salt
           -> ByteString  -- ^ plaintext
           -> ByteString
encryptGCM key ctr salt pt = unsafePerformIO $
    allocaBytes ivBytes $ \iv -> do
      -- IV = salt || counter
      pokeBE32 iv salt
      pokeBE (iv `plusPtr` saltBytes) ctr
      B.unsafeUseAsCString pt $ \src ->
        B.create (ptBytes + ctrBytes + tagBytes) $ \out -> do
          let body = out `plusPtr` ctrBytes   -- ciphertext starts after the counter
              tag  = body `plusPtr` ptBytes   -- tag follows the ciphertext
          pokeBE out ctr
          AES.encryptGCM key iv ivBytes nullPtr 0 (castPtr src) ptBytes (castPtr body) tag
  where
    saltBytes = sizeOf salt
    ctrBytes  = sizeOf ctr
    ivBytes   = ctrBytes + saltBytes
    ptBytes   = B.length pt
    tagBytes  = 16
-- | Decrypt an AES-GCM ciphertext and check its authentication tag.
--
-- The IV is rebuilt as salt followed by counter, both big-endian, and no
-- additional authenticated data is supplied. Only the first 16 bytes of the
-- supplied tag are compared.
--
-- Returns @Left@ when the supplied tag is shorter than 16 bytes or does not
-- match the tag produced during decryption; otherwise @Right plaintext@.
--
-- The failure message deliberately carries no tag material: echoing the
-- locally computed tag would give whoever can observe the error a valid tag
-- for a ciphertext of their choosing.
decryptGCM :: AESKey      -- ^ decryption key
           -> Word64      -- ^ message counter
           -> Word32      -- ^ salt
           -> ByteString  -- ^ ciphertext
           -> ByteString  -- ^ received tag (at least 16 bytes)
           -> Either String ByteString
decryptGCM key ctr salt ct tag
  | B.length tag < tagLen = Left "Tag too small"
  | otherwise = unsafePerformIO $
      allocaBytes ivLen $ \ptrIV -> allocaBytes tagLen $ \ctagPtr -> do
        -- IV = salt || counter
        pokeBE32 ptrIV salt
        pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
        B.unsafeUseAsCString tag $ \tagPtr ->
          B.unsafeUseAsCString ct $ \ptrCT -> do
            -- ctagPtr is read back below as the computed tag (presumably
            -- filled in by AES.decryptGCM).
            pt <- B.create ctLen $ \ptrPT ->
              AES.decryptGCM key ptrIV ivLen nullPtr 0 (castPtr ptrCT) ctLen (castPtr ptrPT) ctagPtr
            w1 <- peekBE ctagPtr
            w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
            y1 <- peekBE (castPtr tagPtr)
            y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
            -- Fold both halves into one word so the check does not
            -- short-circuit on the first half.
            let diff = (w1 `xor` y1) .|. (w2 `xor` y2)
            return $ if diff /= 0
                       then Left "Tags do not match"
                       else Right pt
  where
    ivLen  = sizeOf ctr + sizeOf salt
    tagLen = 16
    ctLen  = B.length ct
-- | Pad and encrypt one message.
--
-- Returns the framed packet together with the context to use for the next
-- message, whose counter is one higher. The packet is forced before the
-- result pair is returned.
--
-- NOTE(review): the counter is incremented without an overflow check.
send :: OutContext -> ByteString -> (ByteString, OutContext)
send ctx pt = finish $! encryptGCM (outKey ctx) counter (saltOut ctx) (pad pt)
  where
    counter = aesCtr ctx
    finish packet = (packet, ctx { aesCtr = 1 + counter })
-- | Authenticate, decrypt and unpad one received packet.
--
-- The packet is expected to be laid out as: 8-byte big-endian counter,
-- ciphertext, 16-byte tag. On success the plaintext is returned together
-- with the context whose replay window has been updated; the window is only
-- advanced when decryption and unpadding both succeed.
--
-- Returns @Left@ when
--
--   * the window base is within 64 of the maximum counter value,
--   * the packet is too short to hold a counter and a tag,
--   * the counter was already seen or lies below the window (@"Dup!"@),
--   * authentication fails, or
--   * the padding is invalid.
recv :: InContext -> ByteString -> Either String (ByteString, InContext)
recv (In { base = curBase, mask = curMask, saltIn = salt, inKey = key }) pkg
  | curBase >= maxBound - 64 = Left "This cipher context has been used too long."
  -- The counter is read through a raw pointer, so the length must be checked
  -- before anything below is evaluated.
  | B.length pkg < cntLen + tagLen = Left "Package too small"
  | otherwise =
      case updateBaseMask curBase curMask cnt of
        Nothing -> Left "Dup!"
        Just (base', mask') ->
          case decryptGCM key cnt salt ct tag of
            Left err -> Left err
            Right ptPad ->
              case unpad ptPad of
                Nothing -> Left "Bad padding"
                Just pt -> Right (pt, In base' mask' salt key)
  where
    cntLen = sizeOf cnt
    tagLen = 16
    cnt    = unsafePerformIO $ B.unsafeUseAsCString pkg (peekBE . castPtr)
    -- Everything after the counter, split into ciphertext and trailing tag.
    (ct, tag) = B.splitAt (B.length pkg - cntLen - tagLen) (B.drop cntLen pkg)
-- | Pad a string up to the next multiple of 16 bytes.
--
-- Between 1 and 16 bytes are appended, each holding the number of bytes
-- appended. Input whose length is already a multiple of 16 receives a full
-- 16-byte block, so the padding can always be removed unambiguously.
pad :: ByteString -> ByteString
pad bs = B.append bs (B.replicate padLen (fromIntegral padLen))
  where
    blockLen = 16
    -- Always in 1..16 because the remainder is in 0..15.
    padLen = blockLen - (B.length bs `rem` blockLen)
-- | Strip the padding added by 'pad'.
--
-- The last byte gives the number of trailing bytes to drop. Returns
-- 'Nothing' for empty input, and for a pad length of zero or one longer
-- than the input, since 'pad' never produces either.
unpad :: ByteString -> Maybe ByteString
unpad bs
  | len > 0 && padLen >= 1 && padLen <= len = Just (B.take (len - padLen) bs)
  | otherwise = Nothing
  where
    len = B.length bs
    -- Only evaluated once the input is known to be non-empty.
    padLen = fromIntegral (B.last bs)
-- | Update the anti-replay window for a received sequence number.
--
-- The window is described by @base@ and a 64-bit @mask@ in which bit @i@
-- records that sequence number @base + i@ has been accepted.
--
--   * A number below @base@ is rejected.
--   * A number inside the window is rejected if its bit is set; otherwise
--     its bit is set.
--   * A number beyond the window slides the window forward so that the
--     number lands on the top bit. Bits for numbers still inside the new
--     window are kept, so late but unseen packets remain acceptable.
--
-- Returns the new @(base, mask)@, or 'Nothing' when the number must be
-- rejected.
updateBaseMask :: Word64 -> Word64 -> Word64 -> Maybe (Word64,Word64)
updateBaseMask !base !mask !seq
  | base > seq       = Nothing
  | base < seqBase   = updateBaseMask seqBase slid seq
  | testBit mask pos = Nothing
  | otherwise        = Just (base, setBit mask pos)
  where
    -- Lowest base for which seq still fits in a 64-entry window.
    seqBase | seq < 64  = 0
            | otherwise = seq - 63
    -- Bit index of seq; in 0..63 whenever the last two guards are reached.
    pos = fromIntegral (seq - base)
    -- Distance the window has to move. Shifts of 64 or more are handled
    -- explicitly because the gap may not fit in an Int and shiftR is not
    -- defined for such amounts.
    gap = seqBase - base
    slid | gap >= 64 = 0
         | otherwise = mask `shiftR` fromIntegral gap
-- | Read eight bytes starting at the pointer as a big-endian 'Word64'.
peekBE :: Ptr Word8 -> IO Word64
peekBE p = go 0 0
  where
    -- Fold bytes 0..7 into the accumulator, most significant byte first.
    go :: Int -> Word64 -> IO Word64
    go i acc
      | i == 8    = return acc
      | otherwise = do
          b <- peekElemOff p i
          go (i + 1) ((acc `shiftL` 8) .|. fromIntegral b)
-- | Write a 'Word64' as eight big-endian bytes starting at the pointer.
pokeBE :: Ptr Word8 -> Word64 -> IO ()
pokeBE p w = mapM_ store [0 .. 7]
  where
    -- Byte n holds bits (63 - 8n) down to (56 - 8n): most significant first.
    store :: Int -> IO ()
    store n = pokeElemOff p n (fromIntegral (w `shiftR` (56 - 8 * n)))
-- | Write a 'Word32' as four big-endian bytes starting at the pointer.
pokeBE32 :: Ptr Word8 -> Word32 -> IO ()
pokeBE32 p w = mapM_ store [0 .. 3]
  where
    -- Byte n holds bits (31 - 8n) down to (24 - 8n): most significant first.
    store :: Int -> IO ()
    store n = pokeElemOff p n (fromIntegral (w `shiftR` (24 - 8 * n)))
-- | Read four bytes starting at the pointer as a big-endian 'Word32'.
peekBE32 :: Ptr Word8 -> IO Word32
peekBE32 p = assemble <$> mapM (peekElemOff p) [0 .. 3]
  where
    -- Most significant byte comes first.
    assemble :: [Word8] -> Word32
    assemble = foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0