{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    UndecidableInstances
  #-}

{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}

module Data.Random.Distribution.Bernoulli where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform

import Data.Ratio
import Data.Complex
import Data.Int
import Data.Word

-- |Generate a Bernoulli variate with the given probability.  For @Bool@ results,
-- @bernoulli p@ will return True (p*100)% of the time and False otherwise.
-- For numerical types, True is replaced by 1 and False by 0.
--
-- This is 'rvar' applied to the 'Bernoulli' distribution built from @p@; the
-- result type @a@ selects which 'Distribution' instance does the sampling.
bernoulli :: Distribution (Bernoulli b) a => b -> RVar a
bernoulli p = rvar (Bernoulli p)

-- |Generate a Bernoulli variate in 'RVarT' with the given probability.  For
-- @Bool@ results, @bernoulliT p@ will return True (p*100)% of the time and
-- False otherwise.  For numerical types, True is replaced by 1 and False by 0.
--
-- This is 'rvarT' applied to the 'Bernoulli' distribution built from @p@.
bernoulliT :: Distribution (Bernoulli b) a => b -> RVarT m a
bernoulliT p = rvarT (Bernoulli p)

-- |A random variable whose value is 'True' the given fraction of the time
-- and 'False' the rest.
--
-- A single value is drawn with 'stdUniformT' and compared against @p@; the
-- result is 'True' exactly when the draw is less than or equal to @p@.
boolBernoulli :: (Fractional a, Ord a, Distribution StdUniform a) => a -> RVarT m Bool
boolBernoulli p = do
    x <- stdUniformT
    return (x <= p)

-- |Cumulative distribution function of a 'Bool'-valued Bernoulli variable
-- with success probability @p@, taking 'False' to sort before 'True'.
--
-- At 'True' the result is always 1 (the probability argument is ignored);
-- at 'False' it is @1 - p@, with @p@ converted to 'Double' via 'realToFrac'.
boolBernoulliCDF :: (Real a) => a -> Bool -> Double
boolBernoulliCDF _ True  = 1
boolBernoulliCDF p False = 1 - realToFrac p

-- | @generalBernoulli f t p@ generates a random variable whose value is @t@
-- with probability @p@ and @f@ with probability @1-p@.
--
-- Note the argument order: the \"failure\" value @f@ comes first, then the
-- \"success\" value @t@.  A 'Bool' is drawn with 'bernoulliT' and mapped to
-- @t@ when it is 'True' and to @f@ otherwise.
generalBernoulli :: Distribution (Bernoulli b) Bool => a -> a -> b -> RVarT m a
generalBernoulli f t p = do
    x <- bernoulliT p
    return (if x then t else f)

-- | @generalBernoulliCDF gte f t p x@ evaluates, at the point @x@, the CDF of
-- a Bernoulli variable taking the value @t@ on success and @f@ on failure.
--
-- @gte@ is the caller-supplied \"greater than or equal\" comparison on the
-- value type.  The guards are tried in order:
--
-- * if @f \`gte\` t@ the call is rejected with 'error' (the failure value
--   must sort strictly below the success value);
-- * if @x \`gte\` t@ the result is the 'Bool' CDF at 'True';
-- * if @x \`gte\` f@ the result is the 'Bool' CDF at 'False';
-- * otherwise @x@ lies below both outcomes and the result is 0.
generalBernoulliCDF :: CDF (Bernoulli b) Bool => (a -> a -> Bool) -> a -> a -> b -> a -> Double
generalBernoulliCDF gte f t p x
    | f `gte` t = error "generalBernoulliCDF: f >= t"
    | x `gte` t = cdf (Bernoulli p) True
    | x `gte` f = cdf (Bernoulli p) False
    | otherwise = 0

-- |A Bernoulli distribution, represented by its probability parameter of
-- type @b@.  The second type parameter @a@ is phantom (it does not occur in
-- the constructor's field); it names the type of value being sampled and so
-- selects among the 'Distribution' and 'CDF' instances in this module.
newtype Bernoulli b a = Bernoulli b

-- |Samples a 'Bool' by handing the stored probability to 'boolBernoulli'.
instance (Fractional b, Ord b, Distribution StdUniform b)
       => Distribution (Bernoulli b) Bool
    where
        rvarT (Bernoulli p) = boolBernoulli p
-- |The 'Bool' CDF is 'boolBernoulliCDF' applied to the stored probability.
instance (Distribution (Bernoulli b) Bool, Real b)
       => CDF (Bernoulli b) Bool
    where
        cdf  (Bernoulli p) = boolBernoulliCDF p

-- |Samples 0 or 1 as an 'Integer' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Integer where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Integer' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Integer where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as an 'Int' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Int' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as an 'Int8' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int8 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Int8' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int8 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as an 'Int16' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int16 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Int16' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int16 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as an 'Int32' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int32 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Int32' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int32 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as an 'Int64' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int64 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Int64' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int64 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as a 'Word' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Word' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as a 'Word8' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word8 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Word8' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word8 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as a 'Word16' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word16 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Word16' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word16 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as a 'Word32' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word32 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Word32' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word32 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as a 'Word64' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word64 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Word64' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word64 where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p

-- |Samples 0 or 1 as a 'Float' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Float where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Float' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool => CDF (Bernoulli b) Float where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples 0 or 1 as a 'Double' via 'generalBernoulli' (0 = failure, 1 = success).
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Double where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Double' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance CDF (Bernoulli b) Bool => CDF (Bernoulli b) Double where
    cdf   (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p

-- |Samples 0 or 1 as a 'Ratio' via 'generalBernoulli' (0 = failure, 1 = success).
instance (Distribution (Bernoulli b) Bool, Integral a)
       => Distribution (Bernoulli b) (Ratio a)
       where
           rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Ratio' CDF via 'generalBernoulliCDF', ordering with '>=' over outcomes 0 and 1.
instance (CDF (Bernoulli b) Bool, Integral a)
       => CDF (Bernoulli b) (Ratio a)
       where
           cdf  (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- |Samples the complex literals 0 or 1 via 'generalBernoulli' (0 = failure,
-- 1 = success).
instance (Distribution (Bernoulli b) Bool, RealFloat a)
       => Distribution (Bernoulli b) (Complex a)
       where
           rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- |'Complex' CDF via 'generalBernoulliCDF' over outcomes 0 and 1.  The
-- comparison passed in looks only at 'realPart', so the imaginary part of
-- the query point is ignored.
instance (CDF (Bernoulli b) Bool, RealFloat a)
       => CDF (Bernoulli b) (Complex a)
       where
           cdf  (Bernoulli p) = generalBernoulliCDF (\x y -> realPart x >= realPart y) 0 1 p