module Web.ClientSession
(
getKey
, getDefaultKey
, encrypt
, decrypt
, IsByteString (..)
) where
import qualified Control.Exception as E
import Data.Word (Word8)

import Codec.Encryption.AES (AESKey)
import qualified Codec.Encryption.AES as AES
import qualified Codec.Binary.Base64Url as Base64
import Codec.Utils (listFromOctets, listToOctets)
import qualified Data.ByteString as BS
import qualified Data.ByteString as BS
import qualified Data.ByteString.UTF8 as BSU
import qualified Data.Digest.MD5 as MD5
import Data.LargeWord (Word256)
import System.Random (getStdGen, randoms, Random, randomR, random)
import System.Random (newStdGen)
-- | Types that can be converted to and from a strict 'BS.ByteString'.
class IsByteString a where
    -- | Build a value from a strict 'BS.ByteString'.
    fromByteString :: BS.ByteString -> a
    -- | Render a value as a strict 'BS.ByteString'.
    toByteString :: a -> BS.ByteString
-- | A strict 'BS.ByteString' is already in the target representation, so
-- both conversions hand their argument back unchanged.
instance IsByteString BS.ByteString where
    fromByteString bytes = bytes
    toByteString bytes = bytes
-- | 'String' values are converted through UTF-8 using 'BSU.fromString' and
-- 'BSU.toString'.
--
-- NOTE(review): an instance on 'String' (= @[Char]@) needs
-- TypeSynonymInstances/FlexibleInstances; presumably enabled via the cabal
-- file or a pragma outside this view — confirm.
instance IsByteString String where
    fromByteString bytes = BSU.toString bytes
    toByteString str = BSU.fromString str
-- | Shorthand for 'getKey' on the relative path @client_session_key.aes@
-- (resolved against the process's current working directory).
getDefaultKey :: IO Word256
getDefaultKey = getKey defaultKeyFile
  where
    defaultKeyFile :: FilePath
    defaultKeyFile = "client_session_key.aes"
-- | Load a 256-bit AES key from the given file, creating one if necessary.
--
-- If reading the file raises any 'E.IOException' (for example, the file does
-- not exist), or the file holds fewer than 32 bytes, then 32 fresh random
-- bytes are generated. They are written to the file, replacing its previous
-- contents, and used as the key. Otherwise the key is the first 'Word256'
-- decoded from the file's bytes by 'listFromOctets'.
--
-- NOTE(review): generated key material comes from 'System.Random''s
-- 'StdGen', which is not a cryptographically secure generator. Prefer
-- provisioning key files produced by a CSPRNG over relying on
-- auto-generation.
getKey :: FilePath    -- ^ Path of the key file to read (or create).
       -> IO Word256
getKey keyFile = loadKeyFromFile `E.catch` onIOError
  where
    -- Only I/O errors trigger regeneration, matching the old Prelude 'catch'
    -- (removed from modern base). The 'fail' for a short file below raises
    -- an I/O error in 'IO', so that case is handled here too.
    onIOError :: E.IOException -> IO Word256
    onIOError _ = generateNewKey

    loadKeyFromFile :: IO Word256
    loadKeyFromFile = do
        contents <- BS.readFile keyFile
        if BS.length contents < 32
            then fail "Key too small"
            else octetsToKey $ BS.unpack contents

    generateNewKey :: IO Word256
    generateNewKey = do
        -- 'newStdGen' splits the global generator and stores one half back,
        -- so successive key generations in one process differ. 'getStdGen'
        -- returns the global generator without advancing it, which would
        -- yield the same key every time.
        stdGen <- newStdGen
        let word8s = take 32 $ randoms stdGen
        newKey <- octetsToKey word8s
        BS.writeFile keyFile $ BS.pack word8s
        return newKey

    -- Total replacement for @head . listFromOctets@. Both callers pass at
    -- least 32 octets, so the empty case should be unreachable.
    octetsToKey :: [Word8] -> IO Word256
    octetsToKey octets = case listFromOctets octets of
        (key:_) -> return key
        []      -> fail "Key too small"
-- | Byte-valued 'Random' instance, needed so 'getKey' can draw @[Word8]@ from
-- 'randoms'. Values are drawn as an 'Integer' in the requested range and then
-- narrowed to a byte.
--
-- NOTE(review): this is an orphan instance; newer versions of @random@ may
-- already provide @Random Word8@, which would clash — verify against the
-- pinned @random@ version.
instance Random Word8 where
    randomR (lo, hi) gen = (narrow n, gen')
      where
        (n, gen') = randomR (toInteger lo, toInteger hi) gen
        narrow v = toEnum (fromEnum (v `mod` 256))
    random gen = randomR (minBound, maxBound) gen
-- | Encrypt a value for storage on the client.
--
-- The value's bytes are prefixed with their MD5 digest. That payload is cut
-- into words by 'listFromOctets', each word is encrypted independently with
-- 'AES.encrypt', and the resulting octets are Base64url-encoded.
--
-- NOTE(review): blocks are encrypted independently with no IV or chaining
-- (ECB-style), so equal inputs give equal outputs. An MD5 digest under
-- encryption is also not a MAC. Check, too, how a final chunk shorter than
-- 16 octets (input length not a multiple of 16) survives the
-- 'listFromOctets' / 'listToOctets' round trip in 'decrypt'.
encrypt :: (IsByteString b, AES.AESKey k)
        => k       -- ^ Encryption key.
        -> b       -- ^ Value to encrypt.
        -> String  -- ^ Base64url-encoded ciphertext.
encrypt k x = Base64.encode (listToOctets cipherBlocks)
  where
    plainOctets  = BS.unpack (toByteString x)
    payload      = MD5.hash plainOctets ++ plainOctets
    cipherBlocks = map (AES.encrypt k) (listFromOctets payload)
-- | Inject a 'Maybe' into an arbitrary monad: 'Just' values are returned,
-- and 'Nothing' becomes @fail "Nothing"@.
liftMaybe :: Monad m => Maybe a -> m a
liftMaybe = maybe (fail "Nothing") return
-- | Decrypt and verify a value produced by 'encrypt'.
--
-- The input is Base64url-decoded, each word is decrypted with 'AES.decrypt',
-- and the first 16 octets are compared with the MD5 digest of the rest.
-- Failure is reported in @m@ via 'fail': @"Nothing"@ (from 'liftMaybe') when
-- Base64 decoding fails, or @"Invalid"@ when no digest matches.
--
-- Why several candidates are tried: when the plaintext length is not a
-- multiple of 16, 'encrypt' hands 'listFromOctets' a short final chunk, and
-- 'listToOctets' here writes that word back as a full 16-octet block with
-- leading zero octets. The untouched remainder then never matches the
-- digest, so such outputs of 'encrypt' could never be decrypted. The
-- remainder is tried unchanged first, so every input accepted before is
-- accepted with the same result. Then versions with a leading all-zero
-- prefix of the last block removed are tried, and the digest selects the
-- correct one.
--
-- NOTE(review): the zero-prefix shape relies on Codec.Utils' big-endian
-- 'fromOctets' / left-padding 'listToOctets' behaviour — confirm against the
-- pinned Crypto version.
decrypt :: (AES.AESKey k, Monad m, IsByteString b)
        => k       -- ^ Decryption key (same key as passed to 'encrypt').
        -> String  -- ^ Base64url-encoded ciphertext.
        -> m b
decrypt k x = do
    decoded <- liftMaybe $ Base64.decode x
    let decrypted = listToOctets $ map (AES.decrypt k)
                                 $ listFromOctets decoded
        (hash, rest) = splitAt 16 decrypted
    case filter ((== hash) . MD5.hash) (payloadCandidates rest) of
        (payload:_) -> return $ fromByteString $ BS.pack payload
        []          -> fail "Invalid"
  where
    -- The remainder as-is, followed by every variant in which a leading run
    -- of (16 - r) zero octets of the final 16-octet block is dropped
    -- (r = 1..15). Input without a full final block is only tried as-is.
    payloadCandidates :: [Word8] -> [[Word8]]
    payloadCandidates bytes
        | length lastBlock /= 16 = [bytes]
        | otherwise =
            bytes : [ front ++ drop (16 - r) lastBlock
                    | r <- [1 .. 15]
                    , all (== 0) (take (16 - r) lastBlock)
                    ]
      where
        (front, lastBlock) = splitAt (length bytes - 16) bytes