Modify the following code so that it reports more evaluation criteria, such as accuracy, F1 and the confusion matrix, and give the modified code.
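For reference, every metric named in the question can be derived from a confusion matrix whose rows are the true classes and whose columns are the predicted classes (the layout `sklearn.metrics.confusion_matrix` returns). Accuracy is the diagonal sum divided by the total. Per-class precision divides the diagonal by the column sums, recall divides it by the row sums, and F1 is the harmonic mean of the two. The NumPy sketch below is illustrative only and is not part of the original code. The function name `metrics_from_confusion` and the toy counts are made up.

```python
import numpy as np

def metrics_from_confusion(cm):
    """Derive accuracy and per-class/macro precision, recall and F1 from a
    confusion matrix (rows = true class, columns = predicted class)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)              # correctly classified samples per class
    pred_totals = cm.sum(axis=0)  # column sums: how often each class was predicted
    true_totals = cm.sum(axis=1)  # row sums: how many samples truly belong to each class
    # Use 0 instead of dividing by zero for classes never predicted / never present
    precision = np.divide(tp, pred_totals, out=np.zeros_like(tp), where=pred_totals > 0)
    recall = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return {
        "accuracy": tp.sum() / cm.sum(),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "macro_f1": f1.mean(),  # unweighted mean of the per-class F1 scores
    }

# Toy 3-class example with made-up counts
cm = [[5, 1, 0],
      [2, 3, 1],
      [0, 0, 4]]
m = metrics_from_confusion(cm)
print(m["accuracy"])  # → 0.75
```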
import numpy as np
import paddle
from paddle.io import DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support

# Set up the GPU environment (before the model is built)
paddle.set_device('gpu:0')

# Load the ResNet-50 model (ImageNet-pretrained backbone, 12-class head)
paddle.vision.set_image_backend('cv2')
model = paddle.vision.models.resnet50(pretrained=True, num_classes=12)

# Define the data loaders
# train_data and valid_data are assumed to be paddle.io.Dataset objects defined earlier (not shown)
train_dataloader = DataLoader(train_data, batch_size=256, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_data, batch_size=256, shuffle=False, drop_last=False)

# Define the optimizer
opt = paddle.optimizer.Adam(learning_rate=1e-4, parameters=model.parameters(), weight_decay=paddle.regularizer.L2Decay(1e-4))

# Define the loss function
loss_fn = paddle.nn.CrossEntropyLoss()

# Best validation accuracy so far (must be initialized before it is compared)
best_accuracy = 0.0

# Overall training procedure
for epoch_id in range(15):
    model.train()
    train_loss = []
    for batch_id, (features, labels) in enumerate(train_dataloader):
        # Forward pass
        predicts = model(features)
        # Compute the loss (CrossEntropyLoss already averages over the batch)
        loss = loss_fn(predicts, labels)
        train_loss.append(loss.item())
        # Backward pass
        loss.backward()
        # Update the parameters
        opt.step()
        # Reset the gradients
        opt.clear_grad()
        # Print the loss
        if batch_id % 2 == 0:
            print('epoch_id:{}, batch_id:{}, loss:{}'.format(epoch_id, batch_id, loss.item()))

    # Evaluate the model on the validation set
    model.eval()
    y_true = []
    y_pred = []
    with paddle.no_grad():
        for image, label in valid_dataloader:
            # Predict
            pre = paddle.argmax(model(image), axis=-1)
            # Collect labels and predictions as flat 1-D arrays so shapes [N] and [N, 1] both work
            y_true.extend(label.numpy().flatten())
            y_pred.extend(pre.numpy().flatten())

    # Compute the metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    # Don't name the result confusion_matrix: that shadows the sklearn function and crashes the next epoch
    cm = confusion_matrix(y_true, y_pred, labels=list(range(12)))
    print('epoch_id: {}, train_loss: {}, accuracy: {}, precision: {}, recall: {}, f1: {}'.format(epoch_id, np.mean(train_loss), accuracy, precision, recall, f1))
    print('confusion matrix:\n', cm)

    # Save the best model
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        paddle.save(model.state_dict(), 'best_model.pdparams')
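To report even more criteria (macro averages next to the weighted ones, plus a per-class precision/recall/F1 table), one option is to move the metric computation out of the training loop into a helper. The sketch below assumes scikit-learn 0.22 or newer (for `zero_division`) and integer class ids `0..num_classes-1`. The name `summarize_predictions` is hypothetical and not part of the original code.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, precision_recall_fscore_support)

def summarize_predictions(y_true, y_pred, num_classes):
    """Compute several evaluation metrics from integer labels and predictions."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    labels = list(range(num_classes))  # fixed label set -> fixed-size confusion matrix
    results = {"accuracy": accuracy_score(y_true, y_pred)}
    # Macro treats every class equally; weighted scales each class by its support
    for avg in ("macro", "weighted"):
        p, r, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, average=avg, zero_division=0)
        results[avg] = {"precision": p, "recall": r, "f1": f1}
    results["confusion_matrix"] = confusion_matrix(y_true, y_pred, labels=labels)
    # Text table with per-class precision, recall, F1 and support
    results["report"] = classification_report(
        y_true, y_pred, labels=labels, zero_division=0)
    return results

# Toy usage with made-up labels
res = summarize_predictions([0, 1, 2, 2, 1, 0], [0, 2, 2, 2, 1, 0], num_classes=3)
print(res["confusion_matrix"])
print(res["report"])
```

In the loop above, this would be called once per epoch on the collected `y_true` and `y_pred` lists with `num_classes=12`.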
Original article: https://cveoy.top/t/topic/hljo. Copyright belongs to the author; please do not repost or scrape.