{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Trustworthy #-}

module Crypto.RandomMonad (RndT, RndST, RndIO, Rnd, RndState, getRandomM, getRandom2M, runRndT, newRandomElementST, getRandomElement, randomElementsLength, replaceSeedM, addSeedM, getRandomByteStringM, RandomElementsListST, RndStateList(..), BitStringToRandomExceptions(..)) where

import Control.Exception (Exception, throw)
import Control.Monad.Identity (Identity)
import Control.Monad.Primitive (PrimMonad, PrimState, primitive)
import Control.Monad.ST (ST)
import Control.Monad.Trans.Class (MonadTrans, lift)
import Control.Monad.Trans.State.Lazy (StateT, put, state, runStateT)
import Control.Parallel.Strategies (parMap, rpar, using, rseq)
import Data.Bits (shiftL, (.|.), xor)
import Data.Typeable (Typeable)
import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef)

import qualified Data.BitString as BS
import qualified Data.ByteString.Lazy as ByS
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as VM

-- | Exceptions thrown by the random-element helpers of this module.
--
-- 'OutOfElementsException' is thrown by 'getRandomElement' when the element
-- list it draws from has no elements left.
data BitStringToRandomExceptions
  = OutOfElementsException
  deriving (Show, Typeable)

instance Exception BitStringToRandomExceptions

-- | A collection of elements to draw from, stored as an unboxed vector
-- behind an 'STRef' so that it can be updated inside 'ST'.
data RandomElementsListST s a
  = RandomElementsListST (STRef s (V.Vector a))

-- | Number of bits needed to represent @x@, i.e. @floor (log2 x) + 1@ for
-- every @x >= 1@.
--
-- The count is computed with exact 'Integer' arithmetic. Going through
-- 'Double' (as @logBase 2 (fromIntegral x)@ does) loses precision above 2^53
-- and turns into @Infinity@ for values beyond the 'Double' range, which
-- yields a meaningless bit count for large bounds.
--
-- Inputs below 1 yield 1, so the result is always a positive width and a
-- caller that consumes this many bits per attempt always makes progress.
bitsNeeded :: Integer -> Integer
bitsNeeded x
  | x <= 1 = 1
  -- Halving drops exactly one binary digit, so count one bit per halving.
  | otherwise = 1 + bitsNeeded (x `div` 2)

-- | Interpret a bit string as an unsigned integer, with the first bit
-- visited by the left fold being the most significant one.
convertBitStringToInteger :: BS.BitString -> Integer
convertBitStringToInteger = BS.foldl' appendBit 0
  where
    -- Shift the accumulated value left by one and put the new bit in the
    -- lowest position.
    appendBit :: Integer -> Bool -> Integer
    appendBit acc bit
      | bit = shifted .|. 1
      | otherwise = shifted
      where
        shifted = acc `shiftL` 1

-- | Split every bit string of the state at position @i@.
--
-- Returns the list of prefixes together with a new state (same constructor
-- as the input) holding the remainders. For the sequential state the split
-- is a plain 'map'; for the parallel state it goes through 'parMap' with
-- 'rpar', evaluating each prefix with 'rseq'.
--
-- Note on ordering: a state with exactly one bit string is handled directly;
-- for any other number of bit strings both result lists come out in the
-- reverse order of the input list.
multipleBitstringsSplitAt i pool = case pool of
  RndStateListSequencial [single] -> splitSingle RndStateListSequencial single
  RndStateListSequencial strings -> regroup RndStateListSequencial (map (BS.splitAt i) strings)
  RndStateListParallel [single] -> splitSingle RndStateListParallel single
  RndStateListParallel strings -> regroup RndStateListParallel (parMap rpar splitForced strings)
  where
    -- One bit string: no regrouping needed.
    splitSingle wrap bits =
      let (front, rest) = BS.splitAt i bits
       in ([front], wrap [rest])
    -- Split used by the parallel variant; the prefix is evaluated with 'rseq'.
    splitForced bits =
      let (front, rest) = BS.splitAt i bits
       in (front `using` rseq, rest)
    -- Turn the list of (prefix, remainder) pairs into a pair of lists, both
    -- reversed relative to the input.
    regroup wrap pairs = (reverse (map fst pairs), wrap (reverse (map snd pairs)))

-- | Check that a non-empty list of bit strings all have length @len@.
--
-- Returns 'False' for the empty list and whenever at least one bit string
-- has a different length; 'True' otherwise.
multipleBitstringsAssertLength len strings = case strings of
  [] -> False
  _ -> all hasExpectedLength strings
  where
    hasExpectedLength bits = BS.length bits == len


-- | Draw an integer between 0 and the given bound (inclusive) from the state.
--
-- A bound of 0 returns 0 without touching the state. Otherwise
-- @bitsNeeded bound@ bits are taken from every bit string in the state, each
-- chunk is converted to an integer and all of them are XORed together. If
-- the candidate exceeds the bound it is discarded and the draw is repeated
-- on the remaining state.
--
-- Calls 'error' when the state holds no bit strings or when any taken chunk
-- does not have exactly the requested number of bits.
getRandom :: Integer -> RndStateList -> (Integer, RndStateList)
getRandom 0 pool = (0, pool)
getRandom limit pool
  | not lengthsOk = error "There was an error acquiring random data"
  | candidate <= limit = (candidate, leftover)
  | otherwise = getRandom limit leftover
  where
    width = bitsNeeded limit
    (chunks, leftover) = multipleBitstringsSplitAt (fromIntegral width) pool
    lengthsOk = multipleBitstringsAssertLength (fromInteger width) chunks
    -- Combine the per-bit-string values into a single candidate.
    candidate = foldl xor 0 (map convertBitStringToInteger chunks)

-- | Draw an integer between the smaller and the larger of the two bounds
-- (inclusive); the bounds may be given in either order.
--
-- Implemented by drawing from @[0, larger - smaller]@ with 'getRandom' and
-- shifting the result up by the smaller bound.
getRandom2 :: Integer -> Integer -> RndStateList -> (Integer, RndStateList)
getRandom2 a b string =
  case getRandom (upper - lower) string of
    (offset, rest) -> (offset + lower, rest)
  where
    lower = min a b
    upper = max a b

-- | Draw @len@ random bytes from the state as a lazy 'ByS.ByteString'.
--
-- Each byte is obtained with @getRandom 255@, threading the state from one
-- byte to the next; the returned state is the one left after the last byte.
--
-- A length of zero or less yields the empty byte string and leaves the
-- state untouched. (Matching only on @0@ would make a negative length skip
-- the base case and keep drawing bytes until the state runs out.)
getRandomByteString :: Integer -> RndStateList -> (ByS.ByteString, RndStateList)
getRandomByteString len x
  | len <= 0 = (ByS.pack [], x)
  | otherwise = (ByS.cons (fromIntegral byte) allBytes, lastState)
  where
    (byte, newState) = getRandom 255 x
    (allBytes, lastState) = getRandomByteString (len - 1) newState

-- | Build a 'RandomElementsListST' from a list: the elements are copied into
-- an unboxed vector that is stored in a fresh 'STRef'.
newRandomElementST :: VM.Unbox a => [a] -> ST s (RandomElementsListST s a)
newRandomElementST acc = fmap RandomElementsListST (newSTRef (V.fromList acc))

-- | Remove one randomly chosen element from the list and return it.
--
-- An index between 0 and @length - 1@ is drawn with 'getRandomM'. The
-- element at that index is returned; the element at index 0 is written into
-- its slot and the stored vector is replaced by its tail, so the list
-- shrinks by one.
--
-- Throws 'OutOfElementsException' when the list is empty.
--
-- NOTE(review): the stored vector is thawed and frozen with the @unsafe@
-- variants, i.e. it is presumably modified in place rather than copied —
-- this relies on nobody else holding a reference to that vector; confirm.
getRandomElement :: VM.Unbox a => RandomElementsListST s a -> RndST s a
getRandomElement (RandomElementsListST ref) = do
  frozen <- lift (readSTRef ref)
  pool <- lift (V.unsafeThaw frozen)
  let remaining = toInteger (VM.length pool)
  picked <-
    if remaining <= 0
      then throw OutOfElementsException
      else getRandomM (remaining - 1)
  let slot = fromInteger picked
  chosen <- lift $ do
    front <- VM.read pool 0
    element <- VM.read pool slot
    -- Keep the head element alive by moving it into the freed slot.
    VM.write pool slot front
    return element
  refrozen <- lift (V.unsafeFreeze pool)
  -- Drop index 0, whose value now lives at the chosen slot.
  lift (writeSTRef ref (V.unsafeTail refrozen))
  return chosen

-- | Number of elements currently stored in the list.
randomElementsLength :: (VM.Unbox a) => RandomElementsListST s a -> RndST s Int
randomElementsLength (RandomElementsListST ref) =
  fmap V.length (lift (readSTRef ref))

-- | The raw source of randomness: a list of bit strings.
type RndStatePrimitive = [BS.BitString]

-- | The randomness state. Both constructors carry the same list of bit
-- strings; they differ in how bits are taken from it: sequentially with
-- 'map', or via 'parMap' for the parallel variant.
data RndStateList
  = RndStateListSequencial RndStatePrimitive
  | RndStateListParallel RndStatePrimitive

-- | Alias for the state threaded through 'RndT'.
type RndState = RndStateList

-- | Monad transformer that threads a 'RndState' through a computation by
-- wrapping a lazy 'StateT'.
newtype RndT m a = RndT {unRndT :: StateT RndState m a}
  deriving (Functor, Applicative, Monad, MonadTrans)

-- | 'RndT' is a 'PrimMonad' whenever its base monad is: primitive actions
-- are run in the base monad and lifted, leaving the random state untouched.
instance PrimMonad m => PrimMonad (RndT m) where
  type PrimState (RndT m) = PrimState m
  {-# INLINE primitive #-}
  primitive = \action -> lift (primitive action)

-- | 'RndT' over the 'ST' monad.
type RndST s a = RndT (ST s) a
-- | 'RndT' over 'IO'.
type RndIO a = RndT IO a
-- | 'RndT' over 'Identity', i.e. with no underlying effects.
type Rnd a = RndT Identity a

-- | Discard the current random state and continue with the given one.
replaceSeedM :: Monad m => RndState -> RndT m ()
replaceSeedM = RndT . put

-- | Add more bit strings to the random state.
--
-- The new bit strings are placed in front of the ones already present; the
-- state keeps its constructor (sequential or parallel).
addSeedM :: Monad m => RndStatePrimitive -> RndT m ()
addSeedM s = RndT (state prepend)
  where
    prepend current = case current of
      RndStateListSequencial existing -> ((), RndStateListSequencial (s ++ existing))
      RndStateListParallel existing -> ((), RndStateListParallel (s ++ existing))


-- | Monadic version of 'getRandom': draw an integer between 0 and the given
-- bound (inclusive), updating the random state.
getRandomM :: Monad m => Integer -> RndT m Integer
getRandomM = RndT . state . getRandom

-- | Monadic version of 'getRandom2': draw an integer between the two bounds
-- (inclusive, in either order), updating the random state.
getRandom2M :: Monad m => Integer -> Integer -> RndT m Integer
getRandom2M x = RndT . state . getRandom2 x

-- | Monadic version of 'getRandomByteString': draw the given number of
-- random bytes, updating the random state.
getRandomByteStringM :: Monad m => Integer -> RndT m ByS.ByteString
getRandomByteStringM = RndT . state . getRandomByteString

-- | Run a 'RndT' computation starting from the given random state, returning
-- the result together with the final state in the base monad.
runRndT :: RndState -> RndT m a -> m (a, RndState)
runRndT rnd (RndT action) = runStateT action rnd