module Torch.Distributions.Categorical
( Categorical (..),
fromProbs,
fromLogits,
)
where
import qualified Torch.DType as D
import qualified Torch.Distributions.Constraints as Constraints
import Torch.Distributions.Distribution
import qualified Torch.Functional as F
import qualified Torch.Functional.Internal as I
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
-- | A categorical distribution. The last axis of both parameter tensors
-- indexes the classes and any leading axes are batch axes. Both
-- parameterizations are stored side by side; use 'fromProbs' or 'fromLogits'
-- to build a value whose two fields agree.
data Categorical = Categorical
  { -- | Per-class probabilities (last axis = classes).
    probs :: D.Tensor,
    -- | Per-class logits (last axis = classes).
    logits :: D.Tensor
  }
  deriving (Show)
-- | PyTorch-style categorical distribution: the last axis of 'probs' /
-- 'logits' indexes the classes, and every leading axis is a batch axis.
instance Distribution Categorical where
  -- Parameter shape without its trailing class axis. Decided by rank (not
  -- element count), so a @[1, 1]@ parameter has batch shape @[1]@; rank-0
  -- and rank-1 parameters have batch shape @[]@. Total, unlike 'init'.
  batchShape = reverse . drop 1 . reverse . D.shape . probs

  -- Each draw is a single class index, so events are scalars.
  eventShape _ = []

  -- Broadcast the probabilities to @batchShape' ++ [numEvents]@ and rebuild
  -- through 'fromProbs', so 'logits' is recomputed from the expanded 'probs'.
  expand d batchShape' =
    fromProbs $ F.expand (probs d) False (batchShape' <> [numEvents d])

  -- Valid class indices are the integers @0 .. numEvents - 1@.
  support d = Constraints.integerInterval 0 (numEvents d - 1)

  -- The mean / variance of class labels is undefined, so return all-NaN
  -- tensors of the extended shape. Multiply by NaN: dividing ones by 0 (as
  -- before) yields +inf, not NaN.
  mean d = F.mulScalar (0 / 0 :: Float) (D.ones (extendedShape d []) D.float_opts)

  variance d = F.mulScalar (0 / 0 :: Float) (D.ones (extendedShape d []) D.float_opts)

  -- Flatten the batch axes into rows and draw @product sampleShape@ indices
  -- per row, with replacement. Then transpose and reshape to
  -- @extendedShape d sampleShape@.
  sample d sampleShape = do
    let probs2d = D.reshape [-1, numEvents d] (probs d)
    samples2d <- F.transpose2D <$> D.multinomialIO probs2d (product sampleShape) True
    pure $ D.reshape (extendedShape d sampleShape) samples2d

  -- Log-probability of the class indices in @value@ (cast to Int64). The
  -- stored 'logits' are read directly, so they are expected to be
  -- normalized log-probabilities.
  --
  -- Mirrors PyTorch: the index gets a trailing singleton axis, and the index
  -- and 'logits' are broadcast against each other. The index keeps a single
  -- trailing slot, 'gather' picks the logit for each index, and the
  -- singleton axis is dropped. @value@ may therefore carry extra leading
  -- sample axes.
  logProb d value =
    F.squeezeDim (-1) $ I.gather logits' (-1 :: Int) index' False
    where
      index = I.unsqueeze (F.toDType D.Int64 value) (-1 :: Int)
      target = broadcastShape (D.shape index) (D.shape (logits d))
      logits' = F.expand (logits d) False target
      -- Same leading axes as @target@ but a single trailing index slot
      -- (the equivalent of PyTorch's @value[..., :1]@).
      index' = F.expand index False (reverse (1 : drop 1 (reverse target)))

      -- NumPy-style broadcast of two shapes (right-aligned). Incompatible
      -- sizes are left for 'F.expand' to reject.
      broadcastShape :: [Int] -> [Int] -> [Int]
      broadcastShape xs ys = reverse (go (reverse xs) (reverse ys))
        where
          go (a : as') (b : bs') = (if a == 1 then b else a) : go as' bs'
          go as' [] = as'
          go [] bs' = bs'

  -- Shannon entropy @-sum_k p_k * logit_k@ over the class axis. Logits are
  -- first clamped to the most negative finite Float (as PyTorch does), so a
  -- zero-probability class with a -inf logit contributes 0 instead of NaN.
  entropy d =
    F.mulScalar (-1.0 :: Float) (F.sumDim (F.Dim (-1)) F.RemoveDim (D.dtype pLogP) pLogP)
    where
      pLogP = F.clampMin (-3.4028235e38) (logits d) `F.mul` probs d

  -- Every class index @0 .. numEvents - 1@ (as Float) along a leading axis,
  -- followed by one singleton axis per batch axis. When @doExpand@ is set,
  -- the singleton axes are broadcast to the full batch shape.
  enumerateSupport d doExpand
    | doExpand = F.expand values False ([-1] <> batchShape d)
    | otherwise = values
    where
      classes = map fromIntegral [0 .. numEvents d - 1] :: [Float]
      values = D.reshape ([-1] <> replicate (length (batchShape d)) 1) (D.asTensor classes)
-- | Number of classes: the size of the last axis of 'probs'.
numEvents :: Categorical -> Int
numEvents = D.size (-1) . probs
-- | Build a 'Categorical' from per-class probabilities (last axis = classes).
-- The matching 'logits' are derived with @'probsToLogits' False@; the
-- @False@ flag presumably selects the multi-class (non-binary) conversion.
fromProbs :: D.Tensor -> Categorical
fromProbs ps =
  Categorical
    { probs = ps,
      logits = probsToLogits False ps
    }
-- | Build a 'Categorical' from per-class logits (last axis = classes). The
-- logits may be unnormalized.
--
-- * 'probs' is derived with @'logitsToProbs' False@.
-- * 'logits' stores the log-softmax of the input along the class axis, so it
--   holds normalized log-probabilities, as PyTorch does.
--
-- This matters because 'logProb' and 'entropy' read 'logits' directly.
-- Shifting logits by a per-row constant does not change the softmax, so
-- 'probs' is unaffected by the normalization.
fromLogits :: D.Tensor -> Categorical
fromLogits logits' =
  Categorical
    { probs = logitsToProbs False logits',
      logits = F.logSoftmax (F.Dim (-1)) logits'
    }