对以下代码进行修改，使其有更多的评判标准（如正确率、F1、混淆矩阵等），并给出修改后的代码：# 调用resnet50模型 paddle.vision.set_image_backend('cv2') model = paddle.vision.models.resnet50(pretrained=True, num_classes=12) # 定义数据迭代器 train_dataloader = DataLoader(train_data …

# Train ResNet-50 on a 12-class dataset and, after every epoch, report a
# richer set of validation metrics: accuracy, balanced accuracy, top-k
# accuracy, macro / weighted precision, recall and F1, Cohen's kappa, a
# per-class classification report and the confusion matrix.  The weights
# with the best validation accuracy are saved to 'best_model.pdparams'.
#
# NOTE(review): assumes `train_data` and `valid_data` are paddle.io.Dataset
# objects yielding (image, label) pairs defined earlier in the file -- confirm.
import numpy as np
import paddle
from paddle.io import DataLoader
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    classification_report,
    cohen_kappa_score,
    confusion_matrix,
    precision_recall_fscore_support,
)

NUM_CLASSES = 12
NUM_EPOCHS = 15
BATCH_SIZE = 256
LOG_INTERVAL = 2  # print the training loss every LOG_INTERVAL batches
TOP_K = 3  # k used for the top-k accuracy metric
# Pass the full label set explicitly so the confusion matrix is always
# NUM_CLASSES x NUM_CLASSES, even if a class is absent from a validation run.
CLASS_IDS = list(range(NUM_CLASSES))

# Select the device BEFORE building the model so its parameters are created
# there; fall back to CPU when Paddle was built without CUDA support.
paddle.set_device('gpu:0' if paddle.device.is_compiled_with_cuda() else 'cpu')

# Build ResNet-50 with pretrained weights and a NUM_CLASSES-way head.
paddle.vision.set_image_backend('cv2')
model = paddle.vision.models.resnet50(pretrained=True, num_classes=NUM_CLASSES)

# Data iterators.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
valid_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

# Optimizer: Adam with L2 weight decay.
opt = paddle.optimizer.Adam(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=paddle.regularizer.L2Decay(1e-4),
)

# Loss function (reduction='mean' by default, so it returns the batch average).
loss_fn = paddle.nn.CrossEntropyLoss()


def _top_k_accuracy(logits, y_true, k):
    """Return the fraction of samples whose true class is among the k highest logits."""
    if len(y_true) == 0:
        return 0.0
    k = min(k, logits.shape[1])
    top_k = np.argsort(-logits, axis=1)[:, :k]
    return float(np.mean(np.any(top_k == y_true[:, None], axis=1)))


def _evaluate(net, loader):
    """Run `net` over `loader` with gradients disabled.

    Returns (sample-weighted mean loss, y_true, y_pred, logits) where the
    label arrays are flattened NumPy arrays of shape (N,).
    """
    net.eval()
    total_loss, total_seen = 0.0, 0
    labels_all, logits_all = [], []
    with paddle.no_grad():
        for images, labels in loader:
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)
            logits = net(images)
            # Labels may arrive as [N, 1]; flatten so comparisons with the
            # [N] predictions do not broadcast into an N x N matrix.
            labels_np = labels.numpy().reshape(-1)
            total_loss += loss_fn(logits, labels).item() * labels_np.size
            total_seen += labels_np.size
            labels_all.append(labels_np)
            logits_all.append(logits.numpy())
    logits_np = np.concatenate(logits_all, axis=0)
    y_true = np.concatenate(labels_all, axis=0)
    y_pred = logits_np.argmax(axis=-1)
    return total_loss / max(total_seen, 1), y_true, y_pred, logits_np


def _classification_metrics(y_true, y_pred, logits):
    """Compute the scalar validation metrics as an ordered name -> value dict."""
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
        'top{}_accuracy'.format(TOP_K): _top_k_accuracy(logits, y_true, TOP_K),
        'cohen_kappa': cohen_kappa_score(y_true, y_pred),
    }
    for average in ('macro', 'weighted'):
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, labels=CLASS_IDS, average=average, zero_division=0)
        metrics[average + '_precision'] = precision
        metrics[average + '_recall'] = recall
        metrics[average + '_f1'] = f1
    return metrics


# Full training loop.
best_accuracy = 0.0  # must be initialised before the first comparison below
for epoch_id in range(NUM_EPOCHS):
    # ---- training ----
    model.train()
    train_loss = []
    train_correct, train_seen = 0, 0
    for batch_id, (features, labels) in enumerate(train_dataloader):
        features = paddle.to_tensor(features)
        labels = paddle.to_tensor(labels)

        # Forward pass and loss (already averaged over the batch).
        predicts = model(features)
        loss = loss_fn(predicts, labels)

        # Backward pass, parameter update, gradient reset.
        loss.backward()
        opt.step()
        opt.clear_grad()

        # Running training statistics.
        train_loss.append(loss.item())
        batch_true = labels.numpy().reshape(-1)
        batch_pred = predicts.numpy().argmax(axis=-1)
        train_correct += int((batch_pred == batch_true).sum())
        train_seen += batch_true.size

        if batch_id % LOG_INTERVAL == 0:
            print('epoch_id:{}, batch_id:{}, loss:{}'.format(epoch_id, batch_id, loss.item()))

    # ---- validation ----
    val_loss, y_true, y_pred, val_logits = _evaluate(model, valid_dataloader)
    metrics = _classification_metrics(y_true, y_pred, val_logits)
    # Named `conf_mat` so the sklearn `confusion_matrix` function is not
    # overwritten (the original rebinding crashed on the second epoch).
    conf_mat = confusion_matrix(y_true, y_pred, labels=CLASS_IDS)

    print('epoch_id: {}, train_loss: {:.4f}, train_acc: {:.4f}, valid_loss: {:.4f}'.format(
        epoch_id, np.mean(train_loss), train_correct / max(train_seen, 1), val_loss))
    for metric_name, metric_value in metrics.items():
        print('    {}: {:.4f}'.format(metric_name, metric_value))
    print('classification report:\n', classification_report(
        y_true, y_pred, labels=CLASS_IDS, digits=4, zero_division=0))
    print('confusion matrix (rows = true class, columns = predicted class):\n', conf_mat)

    # Keep the checkpoint with the best validation accuracy.
    accuracy = metrics['accuracy']
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        paddle.save(model.state_dict(), 'best_model.pdparams')
        print('saved best model, accuracy: {:.4f}'.format(best_accuracy))

标签: 财经


原文地址: https://cveoy.top/t/topic/hljo 著作权归作者所有。请勿转载和采集!