module Crypto.Random
( CryptoRandomGen(..)
, genInteger
, GenError (..)
, newGenIO
) where
import System.Crypto.Random (getEntropy)
import Crypto.Types
import Control.Monad (liftM)
import qualified Data.ByteString as B
import Data.Tagged
import Data.Bits (xor, setBit, shiftR, shiftL, (.&.))
import Data.List (foldl')
-- | Errors that 'CryptoRandomGen' operations, and the helpers built on
-- them in this module, report through 'Left'.
data GenError
  = GenErrorOther String
    -- ^ A failure not covered by the other constructors, described by
    -- the accompanying message.
  | RequestedTooManyBytes
    -- ^ Presumably: a single request asked for more bytes than the
    -- generator is willing to produce (instance-defined; TODO confirm).
  | RangeInvalid
    -- ^ A requested range cannot be served; 'genInteger' returns this
    -- when the number of bytes it would need does not fit in an 'Int'.
  | NeedReseed
    -- ^ Presumably: the generator must be reseeded before further use
    -- (instance-defined; TODO confirm).
  | NotEnoughEntropy
    -- ^ Presumably: too little entropy was available or supplied
    -- (instance-defined; TODO confirm).
  deriving (Eq, Ord, Show)
#if !MIN_VERSION_base(4,3,0)
-- Compatibility shim that is compiled only against base older than 4.3,
-- whose Prelude does not provide a Monad instance for Either.
-- Bind short-circuits on Left and continues with the Right payload.
instance Monad (Either GenError) where
  return = Right
  m >>= k = either Left k m
#endif
-- | Random generators whose state is threaded explicitly: each operation
-- returns the successor generator together with its result, or a
-- 'GenError' on failure.
class CryptoRandomGen g where
  -- | Build a generator from a seed. 'newGenIO' supplies exactly
  -- 'genSeedLength' bytes here; an instance may reject a seed by
  -- returning 'Left'.
  newGen :: B.ByteString -> Either GenError g

  -- | Seed length, in bytes, expected by 'newGen'. The 'Tagged' wrapper
  -- carries the generator type so the length can be queried without a
  -- generator value.
  genSeedLength :: Tagged g ByteLength

  -- | Produce the requested number of bytes and the successor generator.
  genBytes :: g -> ByteLength -> Either GenError (B.ByteString, g)

  -- | Like 'genBytes', but XORs caller-supplied entropy into the output.
  --
  -- The default implementation draws @len@ bytes with 'genBytes' and XORs
  -- them with @entropy@, zero-padded up to @len@ bytes when shorter. The
  -- XOR truncates to the shorter input, so entropy beyond the generated
  -- bytes is ignored. The successor generator comes straight from
  -- 'genBytes'; the supplied entropy does not influence it.
  genBytesWithEntropy :: g -> ByteLength -> B.ByteString -> Either GenError (B.ByteString, g)
  genBytesWithEntropy g len entropy =
    case genBytes g len of
      Left err -> Left err
      Right (bs, g') ->
        -- Pad to len bytes; B.replicate yields an empty string for a
        -- non-positive count, so longer entropy is simply not padded.
        let padding  = B.replicate (len - B.length entropy) 0
            entropy' = B.append entropy padding
        in Right (zwp' entropy' bs, g')

  -- | Presumably reseeds the generator with the given bytes, returning the
  -- updated generator (semantics are instance-defined; TODO confirm).
  reseed :: g -> B.ByteString -> Either GenError g
-- | Construct a generator seeded from system entropy.
--
-- Reads 'genSeedLength' bytes with 'getEntropy' and hands them to
-- 'newGen'. If 'newGen' rejects the seed, fresh entropy is fetched and the
-- attempt repeated. After a fixed number of consecutive rejections the
-- instance is assumed to be broken, and an 'IOError' is raised with 'fail'
-- rather than looping forever.
newGenIO :: CryptoRandomGen g => IO g
newGenIO = attempt maxTries undefined
  where
    -- Generous cap: a generator that rejects this many independent seeds
    -- is almost certainly broken rather than unlucky.
    maxTries :: Int
    maxTries = 1000

    -- The second argument is never evaluated. Its only job is to pin the
    -- generator type, so that 'genSeedLength' is resolved at the same
    -- instance that is returned (the signature below makes that explicit).
    attempt :: CryptoRandomGen g => Int -> g -> IO g
    attempt n witness
      | n <= 0 =
          fail ("Crypto.Random.newGenIO: newGen rejected "
                ++ show maxTries
                ++ " consecutive seeds; the CryptoRandomGen instance is probably broken")
      | otherwise = do
          res <- liftM newGen (getEntropy (genSeedLength `for` witness))
          case res of
            Left _ -> attempt (n - 1) witness
            Right g -> return (g `asTypeOf` witness)
-- | Unwrap a 'Tagged' value, using the second argument purely as a type
-- witness for the tag. That argument is never evaluated, so 'undefined'
-- is an acceptable witness.
for :: Tagged a b -> a -> b
for = const . unTagged
-- | Draw an 'Integer' from the inclusive range @[low, high]@. The bounds
-- may be given in either order, and a one-element range is returned
-- without consuming any randomness.
--
-- Uses rejection sampling. Just enough bytes are requested to cover the
-- width of the range, and bits above that width are masked off. Any
-- candidate above @high@ is discarded and redrawn from the successor
-- generator. The mask is the smallest all-ones value covering the range,
-- so each draw is accepted with probability greater than one half. The
-- result is uniform provided 'genBytes' output is uniform.
--
-- Returns 'Left' with any error reported by 'genBytes', or 'RangeInvalid'
-- when the byte count needed does not fit in an 'Int'.
genInteger :: CryptoRandomGen g => g -> (Integer, Integer) -> Either GenError (Integer, g)
genInteger g (low, high)
  | high < low = genInteger g (high, low)
  | high == low = Right (high, g)
  -- Reject oversized requests before calling genBytes: converting the
  -- count with fromIntegral would otherwise silently wrap.
  | nrBytes > fromIntegral (maxBound :: Int) = Left RangeInvalid
  | otherwise = go g
  where
    range = high - low
    -- Bit length of the range (number of binary digits).
    nrBits = base2Log range
    nrBytes = (nrBits + 7) `div` 8
    -- All ones in the low nrBits bits: 2^nrBits - 1.
    mask = (1 `shiftL` fromIntegral nrBits) - 1
    go gen =
      case genBytes gen (fromIntegral nrBytes) of
        Left err -> Left err
        Right (bs, g') ->
          let res = low + (bs2i bs .&. mask)
          in if res > high then go g' else Right (res, g')
-- | Bit length of a positive 'Integer': the number of binary digits
-- needed to write it, i.e. @floor (logBase 2 i) + 1@ for @i >= 1@. Despite
-- the name this is not the floor of the logarithm. Zero and negative
-- inputs yield 0.
base2Log :: Integer -> Integer
base2Log = count 0
  where
    -- Strip whole chunks of 64, 32, 16 or 8 bits while the value is at
    -- least that wide, then peel the remaining bits off one at a time.
    count acc i =
      case [w | w <- [64, 32, 16, 8], i >= setBit 0 w] of
        (w : _) -> count (acc + toInteger w) (i `shiftR` w)
        []
          | i >= 1 -> count (acc + 1) (i `shiftR` 1)
          | otherwise -> acc
-- | Read a 'B.ByteString' as an unsigned big-endian integer: the first
-- byte is the most significant. The empty string yields 0.
bs2i :: B.ByteString -> Integer
bs2i = B.foldl' step 0
  where
    -- Make room for one more byte, then add it in as the new low byte.
    step acc byte = acc `shiftL` 8 + fromIntegral byte
-- | Bytewise XOR of two 'B.ByteString's. Like 'zipWith', the result has
-- the length of the shorter argument.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' a = B.pack . B.zipWith xor a