欢迎访问宙启技术站
智能推送

介绍PyTorch中torch.distributions的常见概率分布模型

发布时间:2023-12-18 06:11:51

PyTorch是一个广泛应用于深度学习的开源框架。PyTorch中的torch.distributions模块提供了一组常见的概率分布模型,用于生成随机变量,并进行概率计算和采样。下面将介绍一些常见的概率分布模型并附上使用例子。

1. Bernoulli分布:

Bernoulli分布用于描述二元随机变量,取值为0或1,其概率质量函数为:

p(x|p) = p^x * (1-p)^(1-x)

其中,p是成功的概率。

使用例子:

   import torch
   from torch.distributions import Bernoulli

   p = torch.tensor([0.3, 0.7])  # 成功的概率
   dist = Bernoulli(p)
   x = dist.sample()  # 从Bernoulli分布中采样
   prob = dist.log_prob(x)  # 计算采样点的概率密度
   

2. Normal分布:

Normal分布(也称为高斯分布)是一种连续随机变量的分布,其概率密度函数为:

p(x|μ, σ) = (1/sqrt(2*pi*σ^2)) * exp(-(x-μ)^2 / (2σ^2))

其中,μ是均值,σ是标准差。

使用例子:

   import torch
   from torch.distributions import Normal

   mu = torch.tensor([0.0, 1.0])  # 均值
   sigma = torch.tensor([1.0, 2.0])  # 标准差
   dist = Normal(mu, sigma)
   x = dist.sample()  # 从Normal分布中采样
   prob = dist.log_prob(x)  # 计算采样点的概率密度
   

3. Categorical分布:

Categorical分布用于描述离散型随机变量,其概率质量函数为:

p(x|p) = ∏(p_i^x_i)

其中,p是每个类别的概率。

使用例子:

   import torch
   from torch.distributions import Categorical

   p = torch.tensor([0.2, 0.3, 0.5])  # 类别的概率
   dist = Categorical(p)
   x = dist.sample()  # 从Categorical分布中采样
   prob = dist.log_prob(x)  # 计算采样点的概率密度
   

4. Exponential分布:

Exponential分布是一种连续随机变量的分布,其概率密度函数为:

p(x|λ) = λ * exp(-λx)

其中,λ是速度参数。

使用例子:

   import torch
   from torch.distributions import Exponential

   rate = torch.tensor([0.5, 1.0])  # 速度参数
   dist = Exponential(rate)
   x = dist.sample()  # 从Exponential分布中采样
   prob = dist.log_prob(x)  # 计算采样点的概率密度
   

这里只介绍了部分常见的概率分布模型,PyTorch中的torch.distributions模块还提供了其他许多分布,如Uniform、Beta、Gamma等。使用torch.distributions模块可以方便地进行概率计算和随机采样,有助于深度学习中涉及概率的任务,如生成模型、贝叶斯推断等。