{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    UndecidableInstances
  #-}

{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}

module Data.Random.Distribution.Bernoulli where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform

import Data.Ratio
import Data.Complex
import Data.Int
import Data.Word

-- |Generate a Bernoulli variate with the given probability.  For @Bool@ results,
-- @bernoulli p@ will return True (p*100)% of the time and False otherwise.
-- For numerical types, True is replaced by 1 and False by 0.
--
-- The result type @a@ picks which 'Distribution' instance of 'Bernoulli' is
-- used to build the variable.
bernoulli :: Distribution (Bernoulli b) a => b -> RVar a
bernoulli p = rvar (Bernoulli p)

-- |Generate a Bernoulli process with the given probability.  For @Bool@ results,
-- @bernoulliT p@ will return True (p*100)% of the time and False otherwise.
-- For numerical types, True is replaced by 1 and False by 0.
--
-- This is the 'RVarT' counterpart of 'bernoulli': it wraps the probability in
-- 'Bernoulli' and hands it to 'rvarT'.
bernoulliT :: Distribution (Bernoulli b) a => b -> RVarT m a
bernoulliT p = rvarT (Bernoulli p)

-- |A random variable whose value is 'True' the given fraction of the time
-- and 'False' the rest.
--
-- One value is drawn with 'stdUniformT' and the result is 'True' exactly when
-- that draw is less than or equal to the probability @p@ (the comparison is
-- inclusive).
boolBernoulli :: (Fractional a, Ord a, Distribution StdUniform a) => a -> RVarT m Bool
boolBernoulli p = do
    x <- stdUniformT
    return (x <= p)

-- |Cumulative distribution function of a @Bool@-valued Bernoulli variable
-- with success probability @p@.
--
-- At 'True' the cumulative probability is always @1@, whatever @p@ is.
-- At 'False' it is @1 - p@, with @p@ converted to 'Double' by 'realToFrac'.
boolBernoulliCDF :: (Real a) => a -> Bool -> Double
boolBernoulliCDF _ True  = 1
boolBernoulliCDF p False = 1 - realToFrac p

-- | @generalBernoulli f t p@ generates a random variable whose value is @t@
-- with probability @p@ and @f@ with probability @1-p@.
--
-- Note the argument order: the value for the 'False' outcome comes first and
-- the value for the 'True' outcome second.  A @Bool@ is sampled with
-- 'bernoulliT' and then mapped onto one of the two supplied values.
generalBernoulli :: Distribution (Bernoulli b) Bool => a -> a -> b -> RVarT m a
generalBernoulli f t p = do
    x <- bernoulliT p
    return (if x then t else f)

-- | @generalBernoulliCDF gte f t p x@ is the cumulative distribution function,
-- evaluated at @x@, of a two-valued variable that takes the values @f@ and
-- @t@, expressed through the @Bool@ 'CDF' instance of @'Bernoulli' p@.
--
-- @gte@ is the \"greater than or equal\" relation used to order values:
--
-- * when @x@ is at least @t@, the result is the @Bool@ CDF at 'True';
-- * otherwise, when @x@ is at least @f@, it is the @Bool@ CDF at 'False';
-- * otherwise it is @0@.
--
-- The value @f@ must be strictly below @t@ under @gte@; if @f \`gte\` t@
-- holds, the function calls 'error'.
generalBernoulliCDF :: CDF (Bernoulli b) Bool => (a -> a -> Bool) -> a -> a -> b -> a -> Double
generalBernoulliCDF gte f t p x
    | f `gte` t = error "generalBernoulliCDF: f >= t"
    | x `gte` t = cdf (Bernoulli p) True
    | x `gte` f = cdf (Bernoulli p) False
    | otherwise = 0

-- | A Bernoulli distribution carrying its probability parameter as a value of
-- type @b@.  The second parameter @a@ is phantom: it does not appear in the
-- constructor and only selects the type of the values being generated.
newtype Bernoulli b a = Bernoulli b

-- | @Bool@ samples are produced by 'boolBernoulli'.
instance (Fractional b, Ord b, Distribution StdUniform b)
       => Distribution (Bernoulli b) Bool
    where
        rvarT (Bernoulli p) = boolBernoulli p
-- | The @Bool@ CDF is given by 'boolBernoulliCDF'.
instance (Distribution (Bernoulli b) Bool, Real b)
       => CDF (Bernoulli b) Bool
    where
        cdf (Bernoulli p) = boolBernoulliCDF p

-- | 'Integer' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Integer where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Integer' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Integer where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Int' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Int' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Int8' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int8 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Int8' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int8 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Int16' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int16 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Int16' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int16 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Int32' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int32 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Int32' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int32 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Int64' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Int64 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Int64' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Int64 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Word' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Word' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Word8' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word8 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Word8' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word8 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Word16' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word16 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Word16' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word16 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Word32' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word32 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Word32' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word32 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Word64' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Word64 where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Word64' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool          => CDF (Bernoulli b) Word64 where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p

-- | 'Float' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Float where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Float' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool => CDF (Bernoulli b) Float where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Double' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance Distribution (Bernoulli b) Bool => Distribution (Bernoulli b) Double where
    rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Double' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance CDF (Bernoulli b) Bool => CDF (Bernoulli b) Double where
    cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p

-- | 'Ratio' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance (Distribution (Bernoulli b) Bool, Integral a)
       => Distribution (Bernoulli b) (Ratio a)
       where
           rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Ratio' CDF over the outcomes @0@ and @1@, ordered with '>='.
instance (CDF (Bernoulli b) Bool, Integral a)
       => CDF (Bernoulli b) (Ratio a)
       where
           cdf (Bernoulli p) = generalBernoulliCDF (>=) 0 1 p
-- | 'Complex' samples: @0@ for the 'False' outcome, @1@ for 'True'.
instance (Distribution (Bernoulli b) Bool, RealFloat a)
       => Distribution (Bernoulli b) (Complex a)
       where
           rvarT (Bernoulli p) = generalBernoulli 0 1 p
-- | 'Complex' CDF over the outcomes @0@ and @1@.  Values are ordered by
-- comparing their real parts only; the imaginary part is ignored.
instance (CDF (Bernoulli b) Bool, RealFloat a)
       => CDF (Bernoulli b) (Complex a)
       where
           cdf (Bernoulli p) = generalBernoulliCDF (\x y -> realPart x >= realPart y) 0 1 p