module Network.QUIC.Packet.Token (
    CryptoToken(..)
  , isRetryToken
  , generateToken
  , generateRetryToken
  , encryptToken
  , decryptToken
  ) where

import qualified Crypto.Token as CT
import Data.UnixTime
import Foreign.C.Types
import Foreign.Ptr
import Foreign.Storable
import Network.ByteOrder

import Network.QUIC.Imports
import Network.QUIC.Types

----------------------------------------------------------------

-- | Plaintext content of a token. 'encryptToken' seals it into an opaque
-- 'Token' and 'decryptToken' recovers it.
data CryptoToken = CryptoToken
  { tokenQUICVersion :: Version
    -- ^ QUIC version the token was generated for.
  , tokenCreatedTime :: TimeMicrosecond
    -- ^ Time the token was generated. Only whole seconds survive the
    -- 'Storable' encoding used for encryption.
  , tokenCIDs        :: Maybe (CID, CID, CID)
    -- ^ @Just (local, remote, original local)@ for tokens built by
    -- 'generateRetryToken'; 'Nothing' for those built by 'generateToken'.
  }

-- | 'True' exactly when the token carries connection IDs, i.e. it has the
-- shape produced by 'generateRetryToken'.
isRetryToken :: CryptoToken -> Bool
isRetryToken CryptoToken { tokenCIDs = Just _ } = True
isRetryToken _                                  = False

----------------------------------------------------------------

-- | Create a token without connection IDs for the given version,
-- stamped with the current time from 'getTimeMicrosecond'.
generateToken :: Version -> IO CryptoToken
generateToken ver = stamp <$> getTimeMicrosecond
  where
    stamp now = CryptoToken ver now Nothing

-- | Create a token for the given version that records the local, remote
-- and original local connection IDs (in that order), stamped with the
-- current time from 'getTimeMicrosecond'.
generateRetryToken :: Version -> CID -> CID -> CID -> IO CryptoToken
generateRetryToken ver localCID remoteCID origCID =
    fmap (\now -> CryptoToken ver now (Just (localCID, remoteCID, origCID)))
         getTimeMicrosecond

----------------------------------------------------------------

-- | Seal a 'CryptoToken' into an opaque 'Token' using "Crypto.Token".
-- 'CT.encryptToken' requires a 'Storable' plaintext; the instance defined
-- in this module supplies it.
encryptToken :: CT.TokenManager -> CryptoToken -> IO Token
encryptToken manager token = CT.encryptToken manager token

-- | Inverse of 'encryptToken': delegate to 'CT.decryptToken', which yields
-- 'Nothing' when it cannot recover a 'CryptoToken' from the input.
decryptToken :: CT.TokenManager -> Token -> IO (Maybe CryptoToken)
decryptToken manager token = CT.decryptToken manager token

----------------------------------------------------------------

-- | Size in bytes of the 'Storable' encoding of a 'CryptoToken':
-- version (4) + creation seconds (8) + tag (1) + three CID slots, each a
-- length byte plus a fixed 20-byte field. Evaluates to 76.
cryptoTokenSize :: Int
cryptoTokenSize = 4 + 8 + 1 + 3 * (1 + 20)

-- | Fixed-size ('cryptoTokenSize' bytes) encoding of a 'CryptoToken'.
-- It is the plaintext handed to "Crypto.Token" by 'encryptToken'.
--
-- Layout, in write order (integers go through "Network.ByteOrder"):
--
-- * 4 bytes: QUIC version ('write32').
-- * 8 bytes: whole seconds of the creation time ('write64'); the
--   sub-second part is not stored and decodes as 0.
-- * 1 byte: tag. 'poke' writes 0 for 'Nothing' and 1 for 'Just';
--   'peek' treats any non-zero tag as 'Just'.
-- * Only when the tag is non-zero: three CID slots (local, remote,
--   original local). Each slot is one length byte followed by a fixed
--   20-byte field (hence @1 + 20@ per slot). The unused tail of a field
--   is skipped with 'ff' rather than written.
instance Storable CryptoToken where
    sizeOf _    = cryptoTokenSize
    alignment _ = 4

    peek ptr = do
        rbuf  <- newReadBuffer (castPtr ptr) cryptoTokenSize
        ver   <- Version <$> read32 rbuf
        secs  <- read64 rbuf
        -- Only seconds were stored; microseconds come back as 0.
        let created = UnixTime (CTime (fromIntegral secs)) 0
        tag   <- read8 rbuf
        mcids <- if tag == 0
                   then return Nothing
                   else do
                       l <- readSlot rbuf
                       r <- readSlot rbuf
                       o <- readSlot rbuf
                       return (Just (l, r, o))
        return (CryptoToken ver created mcids)
      where
        -- Read one length-prefixed CID slot and move past its 20-byte field.
        readSlot :: ReadBuffer -> IO CID
        readSlot buf = do
            rawLen <- read8 buf
            -- Clamp so a corrupt length byte cannot read beyond the slot.
            let cidLen = min 20 (fromIntegral rawLen) :: Int
            cid <- makeCID <$> extractShortByteString buf cidLen
            ff buf (20 - cidLen)
            return cid

    poke ptr (CryptoToken (Version ver) created mcids) = do
        wbuf <- newWriteBuffer (castPtr ptr) cryptoTokenSize
        write32 wbuf ver
        -- Only the seconds part of the timestamp is encoded.
        let CTime secs = utSeconds created
        write64 wbuf (fromIntegral secs)
        maybe (write8 wbuf 0) (writeCIDs wbuf) mcids
      where
        -- Tag 1 followed by the local, remote and original local slots.
        writeCIDs :: WriteBuffer -> (CID, CID, CID) -> IO ()
        writeCIDs wbuf (l, r, o) = do
            write8 wbuf 1
            mapM_ (writeSlot wbuf) [l, r, o]

        -- Length byte, CID bytes, then skip the rest of the 20-byte field.
        writeSlot :: WriteBuffer -> CID -> IO ()
        writeSlot wbuf cid = do
            let (cidBytes, cidLen) = unpackCID cid
            write8 wbuf cidLen
            copyShortByteString wbuf cidBytes
            ff wbuf (20 - fromIntegral cidLen)