欢迎访问宙启技术站
智能推送

使用Python编写的datasets.download_and_convert_cifar10函数:下载和转换CIFAR-10数据集

发布时间:2023-12-19 04:23:15

以下是一个使用Python编写的数据集下载和转换函数datasets.download_and_convert_cifar10,用于下载和转换CIFAR-10数据集。该函数会将原始数据集下载到本地,并将其转换成可以直接在机器学习模型中使用的格式。

import os
import tarfile
import urllib.request
import numpy as np
import pickle
import tensorflow as tf

def download_and_convert_cifar10(dataset_dir):
    """
    下载并转换CIFAR-10数据集
    
    参数:
        - dataset_dir:数据集保存的文件夹路径
    
    返回:
        - 无
        
    """
    if not os.path.exists(dataset_dir):
        os.makedirs(dataset_dir)

    download_url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = download_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    if not os.path.exists(filepath):
        print("开始下载并解压CIFAR-10数据集...")
        urllib.request.urlretrieve(download_url, filepath)
        tar = tarfile.open(filepath)
        tar.extractall(dataset_dir)
        tar.close()
        print("数据集下载完成!")
    
    print("开始转换数据集...")
    train_images, train_labels = _read_data_file(os.path.join(dataset_dir, "cifar-10-batches-py/data_batch_"))
    test_images, test_labels = _read_data_file(os.path.join(dataset_dir, "cifar-10-batches-py/test_batch"))
    
    train_images = np.reshape(train_images, (-1, 3, 32, 32))
    train_images = np.transpose(train_images, [0, 2, 3, 1])
    test_images = np.reshape(test_images, (-1, 3, 32, 32))
    test_images = np.transpose(test_images, [0, 2, 3, 1])
    
    train_labels = np.array(train_labels)
    test_labels = np.array(test_labels)
    
    # 存储处理后的数据
    output_path = os.path.join(dataset_dir, "cifar-10")
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    
    np.save(os.path.join(output_path, "train_images.npy"), train_images)
    np.save(os.path.join(output_path, "train_labels.npy"), train_labels)
    np.save(os.path.join(output_path, "test_images.npy"), test_images)
    np.save(os.path.join(output_path, "test_labels.npy"), test_labels)
    print("数据集转换完成!")

def _read_data_file(filename):
    with open(filename, 'rb') as f:
        data_dict = pickle.load(f, encoding='bytes')
        images = data_dict[b'data']
        labels = data_dict[b'labels']
    return images, labels

使用例子:

dataset_dir = "./cifar10"
download_and_convert_cifar10(dataset_dir)

以上函数download_and_convert_cifar10首先会检查本地是否已存在CIFAR-10数据集文件夹,如果不存在则创建文件夹。然后,它会从指定的URL下载tar.gz文件,并将其解压到指定的文件夹中。接下来,它将读取解压后的数据文件,并对图像数据进行转换和整理。最后,它将处理后的数据保存为.npy文件。

使用例子中,首先指定数据集保存的文件夹路径dataset_dir,然后调用download_and_convert_cifar10函数即可开始下载和转换CIFAR-10数据集。

注意:该函数的实现假定了CIFAR-10数据集的结构,并使用Python标准库的tarfilepickle模块进行解压和读取数据文件。