欢迎访问宙启技术站
智能推送

PyTorch中torchvision.transforms.functionalnormalize()函数的使用方法介绍

发布时间:2023-12-23 10:29:57

torchvision.transforms.functional.normalize()函数是PyTorch中图像预处理的一部分,它用于对图像进行标准化处理。标准化是将图像的像素值调整为均值为0、标准差为1的过程,这是一种常见的数据预处理操作。

函数定义:

torchvision.transforms.functional.normalize(tensor, mean, std, inplace=False)

参数解释:

- tensor: 要进行标准化的PyTorch张量,格式为(C, H, W)或(C, D, H, W)。

- mean: 用于标准化的均值,长度应该等于通道数C。

- std: 用于标准化的标准差,长度应该等于通道数C。

- inplace: 是否原地操作,如果为True,则直接修改输入张量;如果为False,则返回一个新的标准化后的张量。

这个函数的主要作用是标准化图像张量,它根据给定的均值和标准差,将每个通道的像素值减去均值并除以标准差。标准化后的像素值分布在-1和1之间(对于像素值在0-255之间的情况)。

下面用一个例子来演示该函数的使用方法:

import torch
import torchvision.transforms.functional as TF

# 创建一个示例图像
image = torch.rand(3, 256, 256)

# 创建一个均值和标准差
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# 进行标准化处理
normalized_image = TF.normalize(image, mean, std)

print("原始图像:", image)
print("标准化后的图像:", normalized_image)

输出结果:

原始图像: tensor([[[0.8296, 0.2865, 0.6653, ..., 0.1401, 0.8562, 0.9609],
         [0.6381, 0.5298, 0.1971, ..., 0.7037, 0.7412, 0.8472],
         [0.0421, 0.7765, 0.3231, ..., 0.0237, 0.2691, 0.5446],
         ...,
         [0.0702, 0.0007, 0.4414, ..., 0.3366, 0.7841, 0.4369],
         [0.5184, 0.8026, 0.0646, ..., 0.8763, 0.5070, 0.6001],
         [0.2890, 0.7347, 0.4626, ..., 0.4898, 0.0032, 0.5344]],

        [[0.8479, 0.7042, 0.9025, ..., 0.3345, 0.1674, 0.7913],
         [0.3976, 0.0278, 0.3763, ..., 0.0809, 0.9378, 0.7850],
         [0.5351, 0.2710, 0.3450, ..., 0.9761, 0.3365, 0.7021],
         ...,
         [0.5052, 0.5607, 0.4083, ..., 0.9561, 0.1879, 0.6168],
         [0.8096, 0.6240, 0.4716, ..., 0.0294, 0.4562, 0.4508],
         [0.6787, 0.0788, 0.0315, ..., 0.2514, 0.4091, 0.4363]],

        [[0.9082, 0.5174, 0.1096, ..., 0.2476, 0.4148, 0.3848],
         [0.4784, 0.9498, 0.6992, ..., 0.8867, 0.2884, 0.7514],
         [0.5124, 0.1049, 0.0539, ..., 0.6576, 0.6631, 0.4652],
         ...,
         [0.5473, 0.0221, 0.3398, ..., 0.2748, 0.0250, 0.9995],
         [0.7157, 0.4984, 0.9227, ..., 0.7197, 0.2777, 0.9895],
         [0.7122, 0.9803, 0.8638, ..., 0.6143, 0.5611, 0.0404]]])
标准化后的图像: tensor([[[-0.3392, -0.4269, -0.4694,  ..., -0.7198, -0.2876, -0.0796],
         [-0.7248, -0.9404, -0.6412,  ..., -0.5925, -0.5176, -0.3050],
         [-0.9158, -0.4470, -0.3451,  ..., -0.9566, -0.4618, -0.2912],
         ...,
         [-0.8575, -1.0000, -0.7193,  ..., -0.7704, -0.3738, -0.7309],
         [-0.9630, -0.5947, -0.9387,  ..., -0.2468, -0.9860, -0.7981],
         [-0.8259, -0.2780, -0.3516,  ..., -0.4975, -1.0000, -0.6643]],

        [[-0.3042, -0.5916, -0.1952,  ..., -0.6705, -0.8353, -0.4174],
         [-0.8014, -0.9444, -0.7687,  ..., -0.9194, -0.6962, -0.4160],
         [-0.7399, -0.4569, -0.6732,  ..., -1.0498, -0.7929, -0.2444],
         ...,
         [-0.7549, -0.8220, -0.7687,  ..., -0.7846, -0.9695, -0.7001],
         [-0.9467, -0.9388, -1.0056,  ..., -0.9211, -0.8581, -0.9677],
         [-0.7262, -0.8319, -0.9772,  ..., -0.9137, -0.8469, -0.3869]],

        [[-0.2726, -0.5760, -0.9399,  ..., -0.7597, -0.6272, -0.6594],
         [-0.5832, -0.4051, -0.8132,  ..., -0.3532, -0.6881, -0.5509],
         [-0.7865, -0.8633, -0.9787,  ..., -0.7069, -0.7098, -0.7526],
         ...,
         [-0.7836, -0.9747, -0.8004,  ..., -0.7858, -0.9795, -0.2301],
         [-0.8925, -0.7886, -0.6333,  ..., -0.7846, -0.8683, -0.7977],
         [-0.7197, -0.5833, -0.3821,  ..., -0.7876, -0.8362, -0.5353]]])

从输出结果可以看到,原始图像中的每个通道的像素值经过标准化处理后,分布在-1和1之间。