module Control.Monad.Random (
    module System.Random,
    module Control.Monad.Random.Class,
    evalRandT,
    runRandT,
    evalRand,
    runRand,
    evalRandIO,
    fromList,
    uniform,
    Rand, RandT, 
    
    liftRand,
    liftRandT
    
    
    ) where
import           Control.Applicative
import           Control.Arrow
import           Control.Monad                ()
import           Control.Monad.Cont
import           Control.Monad.Error
import           Control.Monad.Identity
import           Control.Monad.Random.Class
import           Control.Monad.Reader
import qualified Control.Monad.RWS.Lazy       as RWSL
import qualified Control.Monad.RWS.Strict     as RWSS
import           Control.Monad.State
import qualified Control.Monad.State.Lazy     as SL
import qualified Control.Monad.State.Strict   as SS
import           Control.Monad.Trans          ()
import           Control.Monad.Trans.Except
import           Control.Monad.Trans.Identity
import           Control.Monad.Trans.Maybe
import           Control.Monad.Writer.Class
import qualified Control.Monad.Writer.Lazy    as WL
import qualified Control.Monad.Writer.Strict  as WS
import           Data.Monoid                  (Monoid)
import           System.Random
-- | A monad transformer that threads a random-number generator of type @g@
-- through computations in the underlying monad @m@.  It is a newtype over
-- @'StateT' g m@, and every class in the deriving clause is obtained from
-- that 'StateT'.
-- NOTE(review): deriving these classes on a newtype needs
-- GeneralizedNewtypeDeriving, which is not enabled in the visible part of
-- this file; confirm it is set elsewhere (e.g. the .cabal file).
newtype RandT g m a = RandT (StateT g m a)
    deriving (Functor, Monad, MonadPlus, MonadTrans, MonadIO, MonadFix, MonadReader r, MonadWriter w)
-- | Applicative structure defined in terms of the underlying monad:
-- 'pure' wraps 'return', and '<*>' runs the function action before the
-- argument action.
instance (Functor m, Monad m) => Applicative (RandT g m) where
  pure x = RandT (return x)
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)
-- | Choice and failure are inherited from the underlying 'MonadPlus'.
instance (Functor m, MonadPlus m) => Alternative (RandT g m) where
  empty = RandT mzero
  l <|> r = l `mplus` r
-- | Build a 'RandT' from a generator-passing action in @m@.  The function
-- receives the current generator and yields the result paired with the
-- generator to continue with.
liftRandT :: (g -> m (a, g)) 
             -> RandT g m a
liftRandT step = RandT (StateT step)
-- | Build a pure 'Rand' computation from a generator-passing function.  The
-- function receives the current generator and returns the result together
-- with the next generator.
liftRand :: (g -> (a, g)) 
            -> Rand g a
liftRand step = RandT (state step)
-- | Draws values by threading the generator through the wrapped 'StateT'.
-- The infinite-stream methods 'split' the generator.  One half feeds the
-- lazy stream and the other half becomes the new state, so the stream does
-- not consume the generator used by later draws.
instance (Monad m, RandomGen g) => MonadRandom (RandT g m) where
    getRandom = RandT (state random)
    getRandomR (lo, hi) = RandT (state (randomR (lo, hi)))
    getRandoms = RandT (state (\gen ->
        let (streamGen, nextGen) = split gen
        in (randoms streamGen, nextGen)))
    getRandomRs (lo, hi) = RandT (state (\gen ->
        let (streamGen, nextGen) = split gen
        in (randomRs (lo, hi) streamGen, nextGen)))
-- | Splits the current generator.  One half is returned and the other is
-- kept as the new state.
instance (Monad m, RandomGen g) => MonadSplit g (RandT g m) where
    getSplit = RandT (state split)
-- | Run a 'RandT' computation from the given initial generator.  Only the
-- result is kept; the final generator is discarded.
evalRandT :: (Monad m) => RandT g m a -> g -> m a
evalRandT action gen = case action of
    RandT st -> evalStateT st gen
-- | Run a 'RandT' computation from the given initial generator.  The result
-- is returned together with the final generator.
runRandT  :: RandT g m a -> g -> m (a, g)
runRandT (RandT st) = runStateT st
-- | A pure random computation: 'RandT' over 'Identity'.  Declared with only
-- the generator parameter so that @Rand g@ can be used partially applied.
type Rand g = RandT g Identity
-- | Run a pure 'Rand' computation from the given generator.  Only the
-- result is returned.
evalRand :: Rand g a -> g -> a
evalRand action = runIdentity . evalRandT action
-- | Run a pure 'Rand' computation from the given generator.  The result is
-- returned together with the final generator.
runRand :: Rand g a -> g -> (a, g)
runRand action = runIdentity . runRandT action
-- | Run a pure 'Rand' computation in 'IO', seeded from 'newStdGen'
-- (a split of the global generator).
evalRandIO :: Rand StdGen a -> IO a
evalRandIO action = do
    gen <- newStdGen
    return (evalRand action gen)
-- | Sample one value from a list of @(value, weight)@ pairs.  Each value is
-- chosen with probability proportional to its weight.
--
-- * An empty list is an 'error'.
-- * A singleton list returns its only value without drawing randomness.
--
-- NOTE(review): weights are assumed non-negative; negative weights are not
-- rejected.
fromList :: (MonadRandom m) => [(a,Rational)] -> m a
fromList [] = error "MonadRandom.fromList called with empty list"
fromList [(x,_)] = return x
fromList xs = do
  -- The threshold is drawn as a Double, so the total weight is converted to
  -- Double.  The cumulative weights themselves stay exact Rationals.
  let total      = fromRational (sum (map snd xs)) :: Double
      cumulative = scanl1 (\(_,q) (y,s') -> (y, s'+q)) xs
  p <- liftM toRational $ getRandomR (0.0, total)
  -- 'fromRational' may round the total *up*.  A threshold drawn at (or
  -- near) that rounded total can then exceed the exact cumulative weight of
  -- the last element.  Fall back to the last element instead of taking
  -- 'head' of an empty list.  'cumulative' has at least two entries here.
  return $ case dropWhile (\(_,q) -> q < p) cumulative of
    (y,_) : _ -> y
    []        -> fst (last cumulative)
-- | Sample uniformly from a list by giving every element weight 1 and
-- delegating to 'fromList'.  Duplicate elements are proportionally more
-- likely, and an empty list is an 'error'.
uniform :: (MonadRandom m) => [a] -> m a
uniform xs = fromList [ (x, 1) | x <- xs ]
-- | Random draws pass straight through 'IdentityT' to the underlying monad.
instance MonadRandom m => MonadRandom (IdentityT m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in a lazy 'SL.StateT' are lifted from the underlying monad.
instance MonadRandom m => MonadRandom (SL.StateT s m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in a strict 'SS.StateT' are lifted from the underlying
-- monad.
instance MonadRandom m => MonadRandom (SS.StateT s m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in a lazy 'WL.WriterT' are lifted from the underlying
-- monad.
instance (MonadRandom m, Monoid w) => MonadRandom (WL.WriterT w m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in a strict 'WS.WriterT' are lifted from the underlying
-- monad.
instance (MonadRandom m, Monoid w) => MonadRandom (WS.WriterT w m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in 'ReaderT' are lifted from the underlying monad.
instance MonadRandom m => MonadRandom (ReaderT r m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in a lazy 'RWSL.RWST' are lifted from the underlying monad.
instance (MonadRandom m, Monoid w) => MonadRandom (RWSL.RWST r w s m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in a strict 'RWSS.RWST' are lifted from the underlying
-- monad.
instance (MonadRandom m, Monoid w) => MonadRandom (RWSS.RWST r w s m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in 'ExceptT' are lifted from the underlying monad.
instance MonadRandom m => MonadRandom (ExceptT e m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in the legacy 'ErrorT' are lifted from the underlying
-- monad.
-- NOTE(review): 'ErrorT' / 'Error' are deprecated in mtl in favour of
-- 'ExceptT'; kept here for compatibility.
instance (Error e, MonadRandom m) => MonadRandom (ErrorT e m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in 'MaybeT' are lifted from the underlying monad.
instance MonadRandom m => MonadRandom (MaybeT m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Random draws in 'ContT' are lifted from the underlying monad.
instance (MonadRandom m) => MonadRandom (ContT r m) where
    getRandom          = lift getRandom
    getRandoms         = lift getRandoms
    getRandomR bounds  = lift (getRandomR bounds)
    getRandomRs bounds = lift (getRandomRs bounds)
-- | Splitting is delegated to the underlying monad via 'lift'.
instance (MonadSplit g m) => MonadSplit g (IdentityT m) where
    getSplit = lift getSplit
-- | Splitting in a lazy 'SL.StateT' is delegated to the underlying monad.
instance (MonadSplit g m) => MonadSplit g (SL.StateT s m) where
    getSplit = lift getSplit
-- | Splitting in a strict 'SS.StateT' is delegated to the underlying monad.
instance (MonadSplit g m) => MonadSplit g (SS.StateT s m) where
    getSplit = lift getSplit
-- | Splitting in a lazy 'WL.WriterT' is delegated to the underlying monad.
instance (MonadSplit g m, Monoid w) => MonadSplit g (WL.WriterT w m) where
    getSplit = lift getSplit
-- | Splitting in a strict 'WS.WriterT' is delegated to the underlying monad.
instance (MonadSplit g m, Monoid w) => MonadSplit g (WS.WriterT w m) where
    getSplit = lift getSplit
-- | Splitting in 'ReaderT' is delegated to the underlying monad.
instance (MonadSplit g m) => MonadSplit g (ReaderT r m) where
    getSplit = lift getSplit
-- | Splitting in a lazy 'RWSL.RWST' is delegated to the underlying monad.
instance (MonadSplit g m, Monoid w) => MonadSplit g (RWSL.RWST r w s m) where
    getSplit = lift getSplit
-- | Splitting in a strict 'RWSS.RWST' is delegated to the underlying monad.
instance (MonadSplit g m, Monoid w) => MonadSplit g (RWSS.RWST r w s m) where
    getSplit = lift getSplit
-- | Splitting in 'ExceptT' is delegated to the underlying monad.
instance (MonadSplit g m) => MonadSplit g (ExceptT e m) where
    getSplit = lift getSplit
-- | Splitting in the legacy 'ErrorT' is delegated to the underlying monad.
-- NOTE(review): 'ErrorT' is deprecated in mtl in favour of 'ExceptT'.
instance (Error e, MonadSplit g m) => MonadSplit g (ErrorT e m) where
    getSplit = lift getSplit
-- | Splitting in 'MaybeT' is delegated to the underlying monad.
instance (MonadSplit g m) => MonadSplit g (MaybeT m) where
    getSplit = lift getSplit
-- | Splitting in 'ContT' is delegated to the underlying monad.
instance (MonadSplit g m) => MonadSplit g (ContT r m) where
    getSplit = lift getSplit
-- | 'get' and 'put' address the state of the /underlying/ monad @m@.  The
-- generator carried by 'RandT' is not exposed through this interface.
instance (MonadState s m) => MonadState s (RandT g m) where
    get   = RandT (lift get)
    put x = RandT (lift (put x))
-- | Random values in 'IO' come from the global generator.  Single draws use
-- 'randomIO' / 'randomRIO'.  The infinite streams are built from a fresh
-- 'newStdGen' split on each call.
instance MonadRandom IO where
    getRandom         = randomIO
    getRandomR        = randomRIO
    getRandoms        = randoms <$> newStdGen
    getRandomRs range = randomRs range <$> newStdGen
-- | Splitting in 'IO' uses 'newStdGen'.  It splits the global 'StdGen',
-- keeps one half as the new global generator and returns the other.
instance MonadSplit StdGen IO where
    getSplit = newStdGen