module Network.CommSec.Package
(
OutContext(..)
, InContext(..)
, CommSecError(..)
, SequenceMode(..)
, newInContext, newOutContext
, decode
, encode
, decodePtr
, encodePtr
, encBytes, decBytes
, peekBE32
, pokeBE32
, peekBE
, pokeBE
) where
import Prelude hiding (seq)
import qualified Crypto.Cipher.AES128.Internal as AES
import Crypto.Cipher.AES128.Internal (AESKey)
import Crypto.Cipher.AES128 ()
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import Crypto.Classes (buildKey)
import Data.ByteString (ByteString)
import Data.Bits
import Data.Maybe (fromMaybe)
import Data.Word
import Data.List
import Foreign.Ptr
import Foreign.ForeignPtr
import Foreign.Storable
import Foreign.Marshal.Alloc (allocaBytes)
import Foreign.Marshal.Utils (copyBytes)
import System.IO.Unsafe
import Data.Data
import Data.Typeable
import Control.Exception
import Network.CommSec.Types
import Network.CommSec.BitWindow
-- | Number of bytes occupied by the authentication tag at the end of a
-- package.
gTagLen :: Int
gTagLen = 16

-- | Number of bytes occupied by the counter at the start of a package.
gCtrSize :: Int
gCtrSize = 8
-- | State held by the sending side.  'encode' and 'encodePtr' consume the
-- current counter and hand back a context whose counter is one larger.
data OutContext = Out
  { aesCtr  :: !Word64  -- ^ Counter used for the next outgoing package.
  , saltOut :: !Word32  -- ^ Salt written in front of the counter to form the IV.
  , outKey  :: AESKey   -- ^ Key used to encrypt outgoing packages.
  }
-- | State held by the receiving side.  The constructor decides how the
-- counter of an incoming package is checked after decryption.
data InContext
  -- Built for 'AllowOutOfOrder': counters are checked against a 'BitWindow'.
  = In
      { bitWindow :: !BitWindow
      , saltIn    :: !Word32
      , inKey     :: AESKey
      }
  -- Built for 'StrictOrdering': a counter is accepted only when it is greater
  -- than the last accepted one ('seqVal').
  | InStrict
      { seqVal :: !Word64
      , saltIn :: !Word32
      , inKey  :: AESKey
      }
  -- Built for 'Sequential': a counter is accepted only when it is exactly
  -- one more than the last accepted one ('seqVal').
  | InSequential
      { seqVal :: !Word64
      , saltIn :: !Word32
      , inKey  :: AESKey
      }
-- | Build a sending context from raw key material.
--
-- The first four bytes are read big-endian as the salt; the bytes after
-- them are handed to 'buildKey'.  The counter starts at 1.
--
-- Calls 'error' when fewer than 20 bytes are supplied or when 'buildKey'
-- cannot produce a key.
newOutContext :: ByteString -> OutContext
newOutContext bs
  | entropyLen < 20 = error ("Not enough entropy: " ++ show entropyLen)
  | otherwise = Out { aesCtr = 1, saltOut = salt, outKey = key }
  where
    entropyLen = B.length bs
    -- Only reads the (immutable) bytes of @bs@, hence the unsafePerformIO.
    salt = unsafePerformIO (B.unsafeUseAsCString bs (peekBE32 . castPtr))
    key = case buildKey (B.drop (sizeOf salt) bs) of
      Just k  -> k
      Nothing -> error "Could not build a key"
-- | Build a receiving context from raw key material and a sequence mode.
--
-- The first four bytes are read big-endian as the salt; the bytes after
-- them are handed to 'buildKey'.  'AllowOutOfOrder' starts from
-- 'zeroWindow'; the other two modes start with a last-seen counter of 0.
--
-- Calls 'error' when fewer than 20 bytes are supplied or when 'buildKey'
-- cannot produce a key.
newInContext :: ByteString -> SequenceMode -> InContext
newInContext bs md
  | entropyLen < 20 = error ("Not enough entropy: " ++ show entropyLen)
  | otherwise =
      case md of
        AllowOutOfOrder -> In { bitWindow = zeroWindow, saltIn = salt, inKey = key }
        StrictOrdering  -> InStrict { seqVal = 0, saltIn = salt, inKey = key }
        Sequential      -> InSequential { seqVal = 0, saltIn = salt, inKey = key }
  where
    entropyLen = B.length bs
    -- Only reads the (immutable) bytes of @bs@, hence the unsafePerformIO.
    salt = unsafePerformIO (B.unsafeUseAsCString bs (peekBE32 . castPtr))
    key = case buildKey (B.drop (sizeOf salt) bs) of
      Just k  -> k
      Nothing -> error "Could not build a key"
-- | Pure wrapper around 'encryptGCMPtr': allocates a fresh buffer of
-- @'encBytes' (length pt)@ bytes and lets 'encryptGCMPtr' fill it from the
-- plaintext.
encryptGCM :: AESKey
           -> Word64      -- ^ counter
           -> Word32      -- ^ salt
           -> ByteString  -- ^ plaintext
           -> ByteString
encryptGCM key ctr salt pt =
  unsafePerformIO $
    B.unsafeUseAsCString pt $ \src ->
      B.create pkgLen $ \dst ->
        encryptGCMPtr key ctr salt (castPtr src) ptLen (castPtr dst)
  where
    ptLen  = B.length pt
    pkgLen = encBytes ptLen
-- | Encrypt @ptLen@ bytes at @ptPtr@ and write a package starting at @ctPtr@.
--
-- The counter is written big-endian at the start of @ctPtr@; the cipher
-- output goes right after it, and the tag pointer given to the cipher sits
-- @ptLen@ bytes after the start of the cipher output.  The IV passed to the
-- cipher is the big-endian salt followed by the big-endian counter.
-- 'encryptGCM' in this module sizes the output buffer with 'encBytes'.
encryptGCMPtr :: AESKey
              -> Word64     -- ^ counter
              -> Word32     -- ^ salt
              -> Ptr Word8  -- ^ plaintext
              -> Int        -- ^ plaintext length in bytes
              -> Ptr Word8  -- ^ output package buffer
              -> IO ()
encryptGCMPtr key ctr salt ptPtr ptLen ctPtr =
  allocaBytes ivLen $ \ptrIV -> do
    -- IV = salt || counter
    pokeBE32 ptrIV salt
    pokeBE (ptrIV `plusPtr` saltLen) ctr
    -- The package itself starts with the counter in the clear.
    pokeBE ctPtr ctr
    AES.encryptGCM key ptrIV ivLen nullPtr 0 (castPtr ptPtr) ptLen (castPtr bodyPtr) tagPtr
  where
    saltLen = sizeOf salt
    ctrLen  = sizeOf ctr
    ivLen   = ctrLen + saltLen
    bodyPtr = ctPtr `plusPtr` ctrLen
    tagPtr  = bodyPtr `plusPtr` ptLen
-- | Decrypt @ctLen@ bytes at @ctPtr@ into @ptPtr@ and compare the tag the
-- cipher computed with the one stored at @tagPtr@.
--
-- The IV passed to the cipher is the big-endian salt followed by the
-- big-endian counter.  Returns @Left InvalidICV@ when @tagLen@ differs from
-- 'gTagLen' or when the tags do not match, @Right ()@ otherwise.
--
-- NOTE(review): the tag is compared only after 'AES.decryptGCM' has been
-- given the output buffer, so on failure @ptPtr@ may hold unauthenticated
-- data; callers should discard it.
decryptGCMPtr :: AESKey
              -> Word64     -- ^ counter
              -> Word32     -- ^ salt
              -> Ptr Word8  -- ^ ciphertext
              -> Int        -- ^ ciphertext length in bytes
              -> Ptr Word8  -- ^ received tag
              -> Int        -- ^ received tag length in bytes
              -> Ptr Word8  -- ^ plaintext output buffer
              -> IO (Either CommSecError ())
decryptGCMPtr key ctr salt ctPtr ctLen tagPtr tagLen ptPtr
  | tagLen /= gTagLen = return $ Left InvalidICV
  | otherwise =
      allocaBytes ivLen $ \ptrIV -> allocaBytes tagLen $ \ctagPtr -> do
        -- IV = salt || counter
        pokeBE32 ptrIV salt
        pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
        AES.decryptGCM key ptrIV ivLen nullPtr 0 (castPtr ctPtr) ctLen (castPtr ptPtr) ctagPtr
        -- Read both 16-byte tags as two 64-bit halves each.
        w1 <- peekBE ctagPtr
        w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
        y1 <- peekBE (castPtr tagPtr)
        y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
        -- Fold every difference into one word so both halves are always
        -- examined; a short-circuiting (||) would reveal through timing
        -- whether the first half already mismatched.
        let diff = (w1 `xor` y1) .|. (w2 `xor` y2)
        return $ if diff /= 0 then Left InvalidICV else Right ()
  where
    ivLen = sizeOf ctr + sizeOf salt
-- | 'ByteString' variant of GCM decryption with tag verification.
--
-- The IV passed to the cipher is the big-endian salt followed by the
-- big-endian counter.  Returns @Left InvalidICV@ when @tag@ is shorter than
-- 'gTagLen' or when the computed tag differs from the first 'gTagLen' bytes
-- of @tag@; otherwise returns the decrypted bytes, which have the same
-- length as @ct@.
decryptGCM :: AESKey
           -> Word64      -- ^ counter
           -> Word32      -- ^ salt
           -> ByteString  -- ^ ciphertext
           -> ByteString  -- ^ received tag
           -> Either CommSecError ByteString
decryptGCM key ctr salt ct tag
  | B.length tag < gTagLen = Left InvalidICV
  | otherwise = unsafePerformIO $
      allocaBytes ivLen $ \ptrIV -> allocaBytes gTagLen $ \ctagPtr -> do
        -- IV = salt || counter
        pokeBE32 ptrIV salt
        pokeBE (ptrIV `plusPtr` sizeOf salt) ctr
        B.unsafeUseAsCString tag $ \tagPtr ->
          B.unsafeUseAsCString ct $ \ptrCT -> do
            pt <- B.create ctLen $ \ptrPT ->
              AES.decryptGCM key ptrIV ivLen nullPtr 0 (castPtr ptrCT) ctLen (castPtr ptrPT) ctagPtr
            -- Read both 16-byte tags as two 64-bit halves each.
            w1 <- peekBE ctagPtr
            w2 <- peekBE (ctagPtr `plusPtr` sizeOf w1)
            y1 <- peekBE (castPtr tagPtr)
            y2 <- peekBE (castPtr tagPtr `plusPtr` sizeOf y1)
            -- Fold every difference into one word so both halves are always
            -- examined; a short-circuiting (||) would reveal through timing
            -- whether the first half already mismatched.
            let diff = (w1 `xor` y1) .|. (w2 `xor` y2)
            return $ if diff /= 0 then Left InvalidICV else Right pt
  where
    ivLen = sizeOf ctr + sizeOf salt
    ctLen = B.length ct
-- | Encrypt a message with the context's current counter.
--
-- Returns the package produced by 'encryptGCM' together with a context
-- whose counter has been incremented.  The package is forced before the
-- pair is returned.  Throws 'OldContext' once the counter has reached
-- 'maxBound'.
encode :: OutContext -> ByteString -> (ByteString, OutContext)
encode ctx@Out { aesCtr = ctr, saltOut = salt, outKey = key } pt
  | ctr == maxBound = throw OldContext
  | otherwise = bundle $! encryptGCM key ctr salt pt
  where
    bundle pkg = (pkg, ctx { aesCtr = 1 + ctr })
-- | Size in bytes of the package produced for a message of @lenMsg@ bytes:
-- the counter, the message itself, and the tag.
encBytes :: Int -> Int
encBytes lenMsg = gCtrSize + lenMsg + gTagLen
-- | Size in bytes of the message carried by a package of @lenPkg@ bytes,
-- i.e. the package length minus the tag and the counter.  This is the
-- inverse of 'encBytes'; the result is negative when @lenPkg@ is smaller
-- than @gCtrSize + gTagLen@.
decBytes :: Int -> Int
decBytes lenPkg =
  let tagLen = gTagLen
      ctrLen = gCtrSize
  in lenPkg - tagLen - ctrLen
-- | Pointer variant of 'encode'.
--
-- Encrypts @ptLen@ bytes at @ptPtr@ into the package buffer @pkgPtr@ via
-- 'encryptGCMPtr' and returns a context whose counter has been incremented.
-- Throws 'OldContext' (with 'throwIO', so the exception is raised when the
-- action runs) once the counter has reached 'maxBound'.
encodePtr :: OutContext -> Ptr Word8 -> Ptr Word8 -> Int -> IO OutContext
encodePtr ctx@Out { aesCtr = ctr, saltOut = salt, outKey = key } ptPtr pkgPtr ptLen
  | ctr == maxBound = throwIO OldContext
  | otherwise = do
      encryptGCMPtr key ctr salt ptPtr ptLen pkgPtr
      return ctx { aesCtr = 1 + ctr }
-- | Decrypt and verify the package of @pkgLen@ bytes at @pkgPtr@, writing
-- the message to @msgPtr@.
--
-- The package is read as: big-endian counter, ciphertext, tag (the last
-- 'gTagLen' bytes).  On success returns the message length and the updated
-- context.
--
-- Errors:
--
-- * @Left InvalidICV@ when the package is too short to hold a counter and
--   a tag, or when 'decryptGCMPtr' rejects the tag.
-- * @Left DuplicateSeq@ when an 'InStrict' context sees a counter that is
--   not greater than the last one, or an 'InSequential' context sees a
--   counter that is not exactly the last one plus one.
-- * Whatever 'updateBitWindow' returns for an 'In' context.
decodePtr :: InContext -> Ptr Word8 -> Ptr Word8 -> Int -> IO (Either CommSecError (Int,InContext))
decodePtr ctx pkgPtr msgPtr pkgLen
  -- Refuse packages that cannot even hold the counter and the tag; without
  -- this the reads below would run past the buffer and ctLen would be
  -- negative.
  | pkgLen < gCtrSize + gTagLen = return (Left InvalidICV)
  | otherwise = do
      cnt <- peekBE pkgPtr
      let tagLen = gTagLen
          ctPtr  = pkgPtr `plusPtr` sizeOf cnt
          ctLen  = pkgLen - tagLen - sizeOf cnt
          tagPtr = pkgPtr `plusPtr` (pkgLen - tagLen)
      r <- decryptGCMPtr (inKey ctx) cnt (saltIn ctx) ctPtr ctLen tagPtr tagLen msgPtr
      return $ case r of
        Left err -> Left err
        Right () -> fmap (\newCtx -> (ctLen, newCtx)) (advance ctx cnt)
  where
    -- Check the received counter against the context's ordering rule and
    -- produce the context to use for the next package.
    advance :: InContext -> Word64 -> Either CommSecError InContext
    advance (InStrict prev salt key) cnt
      | cnt > prev = Right (InStrict cnt salt key)
      | otherwise  = Left DuplicateSeq
    advance (InSequential prev salt key) cnt
      | cnt == prev + 1 = Right (InSequential cnt salt key)
      | otherwise       = Left DuplicateSeq
    advance (In win salt key) cnt =
      case updateBitWindow win cnt of
        Left e        -> Left e
        Right newMask -> Right (In newMask salt key)
-- | Decrypt and verify a package, returning the message and the updated
-- context.
--
-- Allocates a buffer of @'decBytes' (length pkg)@ bytes and delegates to
-- 'decodePtr'.  Returns @Left InvalidICV@ for a package too short to hold
-- a counter and a tag; any other @Left@ comes from 'decodePtr'.
decode :: InContext -> ByteString -> Either CommSecError (ByteString, InContext)
decode ctx pkg
  -- A shorter package would give a negative buffer size and make
  -- 'decodePtr' read outside @pkg@.
  | pkgLen < gCtrSize + gTagLen = Left InvalidICV
  | otherwise = unsafePerformIO $ do
      pt <- B.mallocByteString ptLen
      r <- withForeignPtr pt $ \ptPtr ->
        B.unsafeUseAsCString pkg $ \pkgPtr ->
          decodePtr ctx (castPtr pkgPtr) (castPtr ptPtr) pkgLen
      return $ case r of
        Left e       -> Left e
        Right (_, c) -> Right (B.fromForeignPtr pt 0 ptLen, c)
  where
    pkgLen = B.length pkg
    ptLen  = decBytes pkgLen
-- | Read eight bytes starting at @p@ and assemble them into a 'Word64',
-- treating the byte at offset 0 as the most significant.
peekBE :: Ptr Word8 -> IO Word64
peekBE p = fmap assemble (mapM (peekElemOff p) [0 .. 7])
  where
    assemble = foldl' (\acc byte -> (acc `shiftL` 8) .|. fromIntegral byte) 0
-- | Write a 'Word64' as eight bytes starting at @p@, most significant byte
-- at offset 0.  Inverse of 'peekBE'.
pokeBE :: Ptr Word8 -> Word64 -> IO ()
pokeBE p w = mapM_ put [0 .. 7]
  where
    -- Byte n holds bits [56 - 8n, 63 - 8n] of the word.
    put n = pokeElemOff p n (fromIntegral (w `shiftR` (56 - 8 * n)))
-- | Write a 'Word32' as four bytes starting at @p@, most significant byte
-- at offset 0.  Inverse of 'peekBE32'.
pokeBE32 :: Ptr Word8 -> Word32 -> IO ()
pokeBE32 p w = mapM_ put [0 .. 3]
  where
    -- Byte n holds bits [24 - 8n, 31 - 8n] of the word.
    put n = pokeElemOff p n (fromIntegral (w `shiftR` (24 - 8 * n)))
-- | Read four bytes starting at @p@ and assemble them into a 'Word32',
-- treating the byte at offset 0 as the most significant.
peekBE32 :: Ptr Word8 -> IO Word32
peekBE32 p = fmap assemble (mapM (peekElemOff p) [0 .. 3])
  where
    assemble = foldl' (\acc byte -> (acc `shiftL` 8) .|. fromIntegral byte) 0