{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Token (
Config,
interval,
tokenLifetime,
defaultConfig,
TokenManager,
spawnTokenManager,
killTokenManager,
encryptToken,
decryptToken,
) where
import Control.Concurrent
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types (AEADMode (..), AuthTag (..))
import qualified Crypto.Cipher.Types as C
import Crypto.Error (maybeCryptoError, throwCryptoError)
import Crypto.Random (getRandomBytes)
import Data.Array.IO
import Data.Bits (xor)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BS
import qualified Data.IORef as I
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import Network.ByteOrder
-- | Slot number of a secret in the rotating secret table.  It is carried,
-- XOR-masked, in every token header.
type Index = Word16

-- | Per-secret message counter.  It is XORed into the secret's IV to build
-- the AES-GCM nonce.
type Counter = Word64
-- | Configuration for a 'TokenManager'.
data Config = Config
    { interval :: Int
    -- ^ Seconds between rotations.  Each rotation generates a new secret
    -- that overwrites the oldest one.
    , tokenLifetime :: Int
    -- ^ Approximate number of seconds a token stays decryptable.
    -- @tokenLifetime \`div\` interval@ is the number of secrets kept.
    }
    deriving (Eq, Show)
-- | Default configuration.  A new secret is generated every 30 minutes
-- (1,800 seconds), and tokens live for one day (86,400 seconds).
defaultConfig :: Config
defaultConfig = Config{interval = 30 * 60, tokenLifetime = 24 * 60 * 60}
-- | The abstract data type for a token manager.  The constructor is not
-- exported; obtain a value with 'spawnTokenManager'.
data TokenManager = TokenManager
    { headerMask :: Header
    -- ^ Random mask XORed into every token header (see 'addHeader').
    , getEncryptSecret :: IO (Secret, Index)
    -- ^ Returns the current secret together with its slot index.
    , getDecryptSecret :: Index -> IO Secret
    -- ^ Returns the secret stored in the given slot.
    , threadId :: ThreadId
    -- ^ The background thread that rotates secrets.
    }
-- | Spawn a token manager.
--
-- Allocates a table of @tokenLifetime \`div\` interval@ secret slots.  The
-- count is clamped to @[1, 65535]@ so that it fits the 16-bit 'Index'.
-- Slot 0 is filled with a fresh secret.  A background thread then replaces
-- the next slot with a newly generated secret every 'interval' seconds.
--
-- 'interval' must be non-zero; 0 makes the slot computation divide by zero.
spawnTokenManager :: Config -> IO TokenManager
spawnTokenManager Config{..} = do
    emp <- emptySecret
    arr <- newArray (0, lim - 1) emp
    ent <- generateSecret
    writeArray arr 0 ent
    ref <- I.newIORef 0
    tid <- forkIO $ loop arr ref
    msk <- newHeaderMask
    return $ TokenManager msk (readCurrentSecret arr ref) (readSecret arr) tid
  where
    -- Number of slots.  Clamp before narrowing to 'Index'.  A count of 0
    -- (tokenLifetime < interval) would make @lim - 1@ wrap to 65535 and
    -- later make @n + 1@ wrap to 0, so every lookup would divide by zero.
    -- A count above 65535 would otherwise be silently truncated.
    lim :: Index
    lim = fromIntegral (max 1 (min maxSlots (tokenLifetime `div` interval)))

    maxSlots :: Int
    maxSlots = fromIntegral (maxBound :: Index)

    -- Rotate one secret every 'interval' seconds, forever.
    loop :: IOArray Index Secret -> I.IORef Index -> IO ()
    loop arr ref = do
        threadDelay (interval * 1000000)
        update arr ref
        loop arr ref
-- | Overwrite the slot after the current one with a freshly generated
-- secret, then make that slot current.  Slots are used cyclically, so the
-- overwritten slot is the least recently written one.
update :: IOArray Index Secret -> I.IORef Index -> IO ()
update arr ref = do
    idx0 <- I.readIORef ref
    (_, n) <- getBounds arr
    -- Compute in 'Int'.  In 'Index', @n + 1@ wraps to 0 when the upper
    -- bound is 'maxBound', which would turn the 'mod' into a division by
    -- zero.
    let size = fromIntegral n + 1 :: Int
        idx = fromIntegral ((fromIntegral idx0 + 1) `mod` size) :: Index
    sec <- generateSecret
    writeArray arr idx sec
    I.writeIORef ref idx
-- | Stop a token manager by killing its background secret-rotation
-- thread.
killTokenManager :: TokenManager -> IO ()
killTokenManager = killThread . threadId
-- | Read the secret stored for an index.  The index is reduced modulo the
-- table size, so any 'Index' selects some slot, including one decoded from
-- a token supplied to 'decryptToken'.
readSecret :: IOArray Index Secret -> Index -> IO Secret
readSecret secrets idx0 = do
    (_, n) <- getBounds secrets
    -- Compute in 'Int' so that @n + 1@ cannot wrap to 0 when the upper
    -- bound is 'maxBound'; that would make the 'mod' divide by zero.
    let size = fromIntegral n + 1 :: Int
        idx = fromIntegral (fromIntegral idx0 `mod` size) :: Index
    readArray secrets idx
-- | Return the current secret paired with the slot index it lives in.
readCurrentSecret :: IOArray Index Secret -> I.IORef Index -> IO (Secret, Index)
readCurrentSecret arr ref =
    I.readIORef ref >>= \cur ->
        (\sec -> (sec, cur)) <$> readSecret arr cur
-- | One entry of the secret table.
data Secret = Secret
    { secretIV :: ByteString
    -- ^ IV that is XORed with the counter to build each nonce.
    , secretKey :: ByteString
    -- ^ AES-256 key.
    , secretCounter :: I.IORef Counter
    -- ^ Next counter value to use.  It is incremented once per encryption
    -- with this secret.
    }
-- | A placeholder secret with an empty IV and key.  It pre-fills table slots
-- that have not yet received a generated secret.
emptySecret :: IO Secret
emptySecret = do
    counterRef <- I.newIORef 0
    return (Secret BS.empty BS.empty counterRef)
-- | Create a fresh secret: a random 'ivLength'-byte IV, a random
-- 'keyLength'-byte key, and a counter starting at 0.
generateSecret :: IO Secret
generateSecret = do
    iv <- genIV
    key <- genKey
    counterRef <- I.newIORef 0
    return Secret{secretIV = iv, secretKey = key, secretCounter = counterRef}
-- | Generate 'keyLength' random bytes for an AES-256 key.
genKey :: IO ByteString
genKey = getRandomBytes len
  where
    len = keyLength
-- | Generate 'ivLength' random bytes for a secret's IV.
genIV :: IO ByteString
genIV = getRandomBytes len
  where
    len = ivLength
-- | Length in bytes of a secret's IV (and therefore of each nonce).
ivLength :: Int
ivLength = 64 `div` 8

-- | Length in bytes of an AES-256 key.
keyLength :: Int
keyLength = 256 `div` 8

-- | Encoded size in bytes of a header's 'Index' field.
indexLength :: Int
indexLength = 16 `div` 8

-- | Encoded size in bytes of a header's 'Counter' field.
counterLength :: Int
counterLength = 64 `div` 8

-- | Length in bytes of the GCM authentication tag appended to each
-- ciphertext.
tagLength :: Int
tagLength = 128 `div` 8
-- | Token header: the slot index of the secret and the counter used to
-- build the nonce.  Tokens carry it XORed with the manager's mask.
data Header = Header
    { headerIndex :: Index
    , headerCounter :: Counter
    }
-- | Serialize a header as its 16-bit index followed by its 64-bit counter,
-- @indexLength + counterLength@ bytes in total.
encodeHeader :: Header -> IO ByteString
encodeHeader Header{..} =
    withWriteBuffer (indexLength + counterLength) $ \wbuf -> do
        write16 wbuf headerIndex
        write64 wbuf headerCounter
-- | Parse a header in the layout written by 'encodeHeader': a 16-bit
-- index, then a 64-bit counter.  Callers in this module always pass at
-- least @indexLength + counterLength@ bytes.
decodeHeader :: ByteString -> IO Header
decodeHeader bs = withReadBuffer bs $ \rbuf ->
    Header <$> read16 rbuf <*> read64 rbuf
-- | Generate a random header.  It is used as the mask XORed into every
-- token header.
newHeaderMask :: IO Header
newHeaderMask = do
    bin <- getRandomBytes (indexLength + counterLength) :: IO ByteString
    decodeHeader bin
-- | Field-wise XOR of two headers.  XORing with the same mask twice gives
-- back the original header.
xorHeader :: Header -> Header -> Header
xorHeader x y =
    Header
        { headerIndex = headerIndex x `xor` headerIndex y
        , headerCounter = headerCounter x `xor` headerCounter y
        }
-- | Prepend the masked, encoded header for @(idx, counter)@ to a
-- ciphertext.
addHeader :: TokenManager -> Index -> Counter -> ByteString -> IO ByteString
addHeader TokenManager{..} idx counter cipher = do
    hdrbin <- encodeHeader (headerMask `xorHeader` Header idx counter)
    return (hdrbin `BS.append` cipher)
-- | Split a token into its unmasked index, its counter, and the remaining
-- ciphertext.  Returns 'Nothing' when the token is shorter than a header.
delHeader
    :: TokenManager -> ByteString -> IO (Maybe (Index, Counter, ByteString))
delHeader TokenManager{..} token
    | BS.length token < minlen = return Nothing
    | otherwise = do
        let (hdrbin, cipher) = BS.splitAt minlen token
        mskhdr <- decodeHeader hdrbin
        let hdr = headerMask `xorHeader` mskhdr
        return $ Just (headerIndex hdr, headerCounter hdr, cipher)
  where
    minlen :: Int
    minlen = indexLength + counterLength
-- | Encrypt a value with the current secret.  The resulting token is the
-- masked header followed by the AES-256-GCM ciphertext and its tag.
encryptToken
    :: TokenManager
    -> ByteString
    -> IO ByteString
encryptToken mgr x =
    getEncryptSecret mgr >>= \(secret, idx) ->
        encrypt secret x >>= \(counter, cipher) ->
            addHeader mgr idx counter cipher
-- | Encrypt with the given secret, consuming the next value of its counter.
-- Returns the counter used (needed to rebuild the nonce) and the
-- ciphertext with its tag appended.
encrypt
    :: Secret
    -> ByteString
    -> IO (Counter, ByteString)
encrypt Secret{..} plain = do
    -- Take the current counter and bump it in one atomic step, so that
    -- concurrent callers never get the same counter (and thus nonce) for
    -- this key.
    counter <- I.atomicModifyIORef' secretCounter $ \i -> (i + 1, i)
    nonce <- makeNonce counter secretIV
    return (counter, aes256gcmEncrypt plain secretKey nonce)
-- | Decrypt a token produced by 'encryptToken'.  Returns 'Nothing' if the
-- token is too short or fails authentication, for example because it was
-- tampered with or its secret has since been overwritten by rotation.
decryptToken
    :: TokenManager
    -> ByteString
    -> IO (Maybe ByteString)
decryptToken mgr token = delHeader mgr token >>= maybe (return Nothing) openBody
  where
    openBody (idx, counter, cipher) = do
        secret <- getDecryptSecret mgr idx
        decrypt secret counter cipher
-- | Decrypt and authenticate a ciphertext (tag appended) using the secret
-- and the counter recovered from the token header.
decrypt
    :: Secret
    -> Counter
    -> ByteString
    -> IO (Maybe ByteString)
decrypt Secret{..} counter cipher =
    aes256gcmDecrypt cipher secretKey <$> makeNonce counter secretIV
-- | Build a nonce by XORing the IV with the counter's 8-byte in-memory
-- representation.
--
-- The counter is written with 'poke', which uses host byte order.
-- 'ivLength' must be at least 8 (the size of a 'Counter') for that write
-- to stay inside the buffer.
-- NOTE(review): because of host byte order, nonces (and thus tokens) are
-- only interoperable between hosts of the same endianness.
makeNonce :: Counter -> ByteString -> IO ByteString
makeNonce counter iv = BA.xor iv <$> counterBytes
  where
    counterBytes :: IO ByteString
    counterBytes = BS.create ivLength (\p -> poke (castPtr p) counter)
-- | Additional authenticated data passed to AES-GCM; always empty.
constantAdditionalData :: ByteString
constantAdditionalData = mempty
-- | AES-256-GCM encryption.  Returns the ciphertext followed by the
-- 'tagLength'-byte authentication tag.  Key or nonce initialisation
-- failures are raised via 'throwCryptoError'.
aes256gcmEncrypt
    :: ByteString
    -- ^ plaintext
    -> ByteString
    -- ^ key
    -> ByteString
    -- ^ nonce
    -> ByteString
aes256gcmEncrypt plain key nonce = cipher <> BA.convert tag
  where
    aes :: AES256
    aes = throwCryptoError (C.cipherInit key)
    ctx = throwCryptoError (C.aeadInit AEAD_GCM aes nonce)
    (AuthTag tag, cipher) =
        C.aeadSimpleEncrypt ctx constantAdditionalData plain tagLength
-- | AES-256-GCM decryption of a ciphertext with its authentication tag
-- appended.  Returns 'Nothing' when:
--
-- * the input is shorter than 'tagLength',
-- * the key or nonce is rejected, or
-- * authentication fails.
aes256gcmDecrypt
    :: ByteString
    -- ^ ciphertext followed by the tag
    -> ByteString
    -- ^ key
    -> ByteString
    -- ^ nonce
    -> Maybe ByteString
aes256gcmDecrypt ctexttag key nonce
    -- Reject inputs that cannot hold a full tag.  Otherwise the split below
    -- yields a truncated (possibly empty) tag.  'C.aeadSimpleDecrypt' only
    -- computes and compares as many tag bytes as it is given, so an empty
    -- input (a header-only token) would authenticate and decrypt to "".
    | BS.length ctexttag < tagLength = Nothing
    | otherwise = do
        aes <- maybeCryptoError $ C.cipherInit key :: Maybe AES256
        aead <- maybeCryptoError $ C.aeadInit AEAD_GCM aes nonce
        let (ctext, tag) = BS.splitAt (BS.length ctexttag - tagLength) ctexttag
            authtag = AuthTag $ BA.convert tag
        C.aeadSimpleDecrypt aead constantAdditionalData ctext authtag