{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

-- | Bernoulli distribution over @{0, 1}@, parameterised either by success
-- probabilities ('fromProbs') or by log-odds ('fromLogits').
module Torch.Distributions.Bernoulli
  ( Bernoulli (..),
    fromProbs,
    fromLogits,
  )
where

import qualified Torch.DType as D
import qualified Torch.Distributions.Constraints as Constraints
import Torch.Distributions.Distribution
import qualified Torch.Functional as F
import qualified Torch.Functional.Internal as I
import Torch.Scalar
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.TensorOptions
import Torch.Typed.Functional (reductionVal)

-- | A (possibly batched) Bernoulli distribution.
--
-- The two fields describe the same distribution. The smart constructors
-- 'fromProbs' and 'fromLogits' keep them consistent:
-- @probs = sigmoid logits@ and @logits = log probs - log (1 - probs)@.
data Bernoulli = Bernoulli
  { -- | Probability of drawing @1@, element-wise.
    probs :: D.Tensor,
    -- | Log-odds of drawing @1@, element-wise.
    logits :: D.Tensor
  }
  deriving (Show)

instance Distribution Bernoulli where
  -- Every element of 'probs' is an independent Bernoulli variable, so the
  -- batch shape is the parameter tensor's shape ([] for a scalar parameter).
  -- Without this, 'sample' cannot expand a non-scalar 'probs' tensor.
  batchShape d = D.shape (probs d)

  -- Each individual draw is a scalar.
  eventShape _d = []

  expand d = fromProbs . F.expand (probs d) False

  support _d = Constraints.boolean

  mean = probs

  -- Var = p * (1 - p)
  variance d = p `F.mul` (D.onesLike p `F.sub` p)
    where
      p = probs d

  -- Broadcast the parameters to sampleShape <> batchShape <> eventShape,
  -- then draw element-wise.
  sample d = D.bernoulliIO' . F.expand (probs d) False . extendedShape d

  -- log P(value) = -BCE(logits, value), element-wise (no reduction).
  logProb d value = F.mulScalar (-1 :: Int) (bce' (logits d) value)

  -- H = BCE(logits, probs), element-wise (no reduction).
  entropy d = bce' (logits d) (probs d)

  -- The support values [0, 1], reshaped to [2, 1, ..., 1] so that they
  -- broadcast against the batch dimensions; optionally expanded to
  -- [2] <> batchShape.
  enumerateSupport d doExpand
    | doExpand = F.expand values False ([-1] <> batchShape d)
    | otherwise = values
    where
      values =
        D.reshape ([-1] <> replicate (length (batchShape d)) 1) $
          D.asTensor [0.0, 1.0 :: Float]

-- | Element-wise (unreduced) binary cross-entropy of @target@ against
-- @lgts@, interpreted as logits.
--
-- Both the weight and pos_weight arguments are all-ones tensors built with
-- 'D.onesLike'. That leaves the loss unweighted while matching the shape,
-- dtype and device of the logits, so 0-d logits and non-float dtypes work too.
bce' :: D.Tensor -> D.Tensor -> D.Tensor
bce' lgts target =
  I.binary_cross_entropy_with_logits
    lgts
    target
    (D.onesLike lgts) -- weight
    (D.onesLike lgts) -- pos_weight
    $ reductionVal @(F.ReduceNone)

-- | Build a 'Bernoulli' from success probabilities.
--
-- The logits are the binary log-odds @log p - log (1 - p)@. The 'True' flag
-- selects the binary transform of 'probsToLogits'; with 'False' it would
-- produce @log p@, which is wrong for a sigmoid-parameterised distribution.
fromProbs :: D.Tensor -> Bernoulli
fromProbs ps = Bernoulli ps (probsToLogits True ps)

-- | Build a 'Bernoulli' from log-odds; the probabilities are their sigmoid.
fromLogits :: D.Tensor -> Bernoulli
fromLogits lgts = Bernoulli (F.sigmoid lgts) lgts