介绍PyTorch中torch.distributions的常见概率分布模型
发布时间:2023-12-18 06:11:51
PyTorch是一个广泛应用于深度学习的开源框架。PyTorch中的torch.distributions模块提供了一组常见的概率分布模型,用于生成随机变量,并进行概率计算和采样。下面将介绍一些常见的概率分布模型并附上使用例子。
1. Bernoulli分布:
Bernoulli分布用于描述二元随机变量,取值为0或1,其概率质量函数为:
p(x|p) = p^x * (1-p)^(1-x)
其中,p是成功的概率。
使用例子:
import torch from torch.distributions import Bernoulli p = torch.tensor([0.3, 0.7]) # 成功的概率 dist = Bernoulli(p) x = dist.sample() # 从Bernoulli分布中采样 prob = dist.log_prob(x) # 计算采样点的概率密度
2. Normal分布:
Normal分布(也称为高斯分布)是一种连续随机变量的分布,其概率密度函数为:
p(x|μ, σ) = (1/sqrt(2*pi*σ^2)) * exp(-(x-μ)^2 / (2σ^2))
其中,μ是均值,σ是标准差。
使用例子:
import torch from torch.distributions import Normal mu = torch.tensor([0.0, 1.0]) # 均值 sigma = torch.tensor([1.0, 2.0]) # 标准差 dist = Normal(mu, sigma) x = dist.sample() # 从Normal分布中采样 prob = dist.log_prob(x) # 计算采样点的概率密度
3. Categorical分布:
Categorical分布用于描述离散型随机变量,其概率质量函数为:
p(x|p) = ∏(p_i^x_i)
其中,p是每个类别的概率。
使用例子:
import torch from torch.distributions import Categorical p = torch.tensor([0.2, 0.3, 0.5]) # 类别的概率 dist = Categorical(p) x = dist.sample() # 从Categorical分布中采样 prob = dist.log_prob(x) # 计算采样点的概率密度
4. Exponential分布:
Exponential分布是一种连续随机变量的分布,其概率密度函数为:
p(x|λ) = λ * exp(-λx)
其中,λ是速度参数。
使用例子:
import torch from torch.distributions import Exponential rate = torch.tensor([0.5, 1.0]) # 速度参数 dist = Exponential(rate) x = dist.sample() # 从Exponential分布中采样 prob = dist.log_prob(x) # 计算采样点的概率密度
这里只介绍了部分常见的概率分布模型,PyTorch中的torch.distributions模块还提供了其他许多分布,如Uniform、Beta、Gamma等。使用torch.distributions模块可以方便地进行概率计算和随机采样,有助于深度学习中涉及概率的任务,如生成模型、贝叶斯推断等。
