欢迎访问宙启技术站
智能推送

在Python中使用Wandbwatch()函数来记录您的数据科学实验

发布时间:2024-01-10 16:51:59

在Python中使用Wandb.watch()函数来记录数据科学实验,可以方便地追踪和记录实验的指标和结果。以下是一个使用Wandb.watch()的例子:

import wandb
import torch
import torch.nn as nn
import torch.optim as optim

# 初始化Wandb
wandb.init(project='my-project', entity='my-entity')

# 加载数据集
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# 定义模型
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 启用Wandb.watch()
wandb.watch(model, criterion, log='gradients', log_freq=100)

# 训练模型
for epoch in range(10):
    running_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    for i, (inputs, labels) in enumerate(train_loader):
        # 将输入和标签数据加载到GPU上(如果可用)
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # 重置梯度
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(inputs)
        
        # 计算损失
        loss = criterion(outputs, labels)
        
        # 反向传播和优化
        loss.backward()
        optimizer.step()
        
        # 记录损失
        running_loss += loss.item()
        
        # 计算准确率
        _, predicted = torch.max(outputs.data, 1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)
        
        # 每100个batch打印一次损失和准确率
        if (i+1) % 100 == 0:
            print('[Epoch: %d, Batch: %5d] Loss: %.3f, Accuracy: %.2f %%'
                  % (epoch+1, i+1, running_loss/100, total_correct/total_samples * 100))
            
            # 使用Wandb记录损失和准确率
            wandb.log({'loss': running_loss/100, 'accuracy': total_correct/total_samples})
            
            running_loss = 0.0
            total_correct = 0
            total_samples = 0
            
# 结束Wandb记录
wandb.finish()

在这个例子中,我们使用了Wandb库来跟踪和记录训练过程中的损失和准确率。首先,我们初始化了Wandb,并指定了项目名称和实体名称。然后,我们加载数据集、定义模型、损失函数和优化器。

然后,我们使用Wandb.watch()函数来启用模型的梯度记录。我们将模型、损失函数和记录类型(这里是梯度)作为参数传递给Wandb.watch()函数。

在训练模型的循环中,我们将输入数据加载到GPU上(如果可用),定义优化器的梯度为零,执行前向传播和反向传播,记录损失和准确率,并在每100个batch打印损失和准确率。

在每次记录损失和准确率时,我们使用Wandb.log()函数将损失和准确率作为字典参数传递给它,并将它们记录下来。

最后,我们使用wandb.finish()函数结束Wandb记录。

使用Wandb.watch()函数可以轻松地记录数据科学实验的各种指标和结果,方便后续分析和可视化。