提问 发文

深度学习:可视化-结果loss acc可视化及测试数据显示

微微菌

| 2024-03-12 10:28 281 0 0

1.可视化train,test的loss acc
1.1 案例:交通指示牌识别案例-history数组
代码地址:E:\项目例程\猫狗分类\迁移学习\猫狗_resnet18_2 \猫狗分类_迁移学习可视化

导入库
from collections import defaultdict

训练函数中构建一个默认value为list的字典
history = defaultdict(list) # 构建一个默认value为list的字典

训练函数中保存train_loss,train_acc,test_loss,test_acc结果
history['train_acc'].append(train_accuracy)
history['train_loss'].append(train_loss)
history['val_acc'].append(val_accuracy)
history['val_loss'].append(val_loss)

训练函数返回
return model, history

训练模型调用时
# 调用训练函数训练
model_conv, history = train_model(
model_conv,
criterion,
optimizer_conv,
exp_lr_scheduler,
num_epochs=30
)

绘制函数 两张图,每个图两个曲线,写法固定
注意:若运行不出图,则加plt.show()
# 绘制 loss, acc 写法固定:两张表
def plot_training_history(history):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
ax1.plot(history['train_loss'], label='train loss')
ax1.plot(history['val_loss'], label='val loss')

ax1.set_ylim([-0.05, 1.05])
ax1.legend()
ax1.set_ylabel('Loss')
ax1.set_xlabel('Epoch')

ax2.plot(history['train_acc'], label='train acc')
ax2.plot(history['val_acc'], label='val acc')

ax2.set_ylim([-0.05, 1.05])
ax2.legend()
ax2.set_ylabel('Accuracy')
ax2.set_xlabel('Epoch')

fig.suptitle('Training History')
plt.show()

plot_training_history(history)


结果曲线展示:
猫狗分类迁移学习

1.2 一张图两条曲线loss 写法总结
案例:gan-手写数字 画G_loss,D_loss
代码位置 :E:\项目例程\GNN\手写数字\3_可视化
步骤:

step1 定义数组
G_losses = []
D_losses = []
1
2
step2 添加数据
G_losses.append(sum_loss.item())
D_losses.append(d_loss.item())
1
2
step3 画图
x=[i for i in range(len(G_losses))]
figure = plt.figure(figsize=(20, 8), dpi=80)
plt.plot(x,G_losses,label='G_losses')
plt.plot(x,D_losses,label='D_losses')
plt.xlabel("iterations",fontsize=15)
plt.ylabel("loss",fontsize=15)
plt.legend()
plt.grid()
plt.show()

效果

在这里插入图片描述

2.可视化测试结果

代码地址:E:\项目例程\猫狗分类\迁移学习\猫狗_resnet18_2 \猫狗分类_迁移学习可视化

  1. 定义结果可视化函数
    注意:plt.subplot(4,4,i+1) 应根据batch_size修改,本案例batch_size=16.故为4x4

# 测试结果可视化函数
def visualize_model(model):
model.eval()
with torch.no_grad():
inputs, labels = next(iter(dataloaders['val']))
inputs = inputs.to(device)
labels = labels.to(device)

outputs = model(inputs)
preds = outputs.argmax(1)

plt.figure(figsize=(9, 9))
for i in range(inputs.size(0)):
plt.subplot(4,4,i+1) #根据batch_size修改
plt.axis('off')
plt.title(f'pred: {class_names[preds[i]]}|true: {class_names[labels[i]]}')
im = no_normalize(inputs[i].cpu())
plt.imshow(im)
plt.savefig('train.jpg')
plt.show()

  1. 调用函数
  2. # 测试结果可视化
    visualize_model(model_conv)

  3. 效果展示:在这里插入图片描述

  4. 3 loss曲线生成及保存

    代码参考:E:\项目例程\GNN\DCGAN\02DCGAN_oxford17
    DCGAN生成鲜花

  5. figure_save_path="./figures/"
    plt.figure(1,figsize=(8, 4))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(self.G_loss_list[::10], label="G")
    plt.plot(self.D_loss_list[::10], label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.axhline(y=0, label="0", c="g") # asymptote
    plt.legend()
    plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'loss.jpg', bbox_inches='tight')

  6. 4GAN -真实图片与生成图片

    代码参考:E:\项目例程\GNN\DCGAN\02DCGAN_oxford17

    (1)真实图片与生成图片对比

  7. plt.figure(4,figsize=(8, 4))
    # Plot the real images
    plt.subplot(1, 2, 1)
    plt.axis("off")
    plt.title("Real Images")
    real = next(iter(self.dataloader)) # real[0]image,real[1]label
    plt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))

    # Load the Best Generative Model
    # self.G.load_state_dict(
    # torch.load(self.model_save_path + 'disc_{}.pth'.format(epoch), map_location=torch.device(self.device)))
    self.G.eval()
    # Generate the Fake Images
    with torch.no_grad():
    fake = self.G(self.fixed_noise).cpu()
    # Plot the fake images
    plt.subplot(1, 2, 2)
    plt.axis("off")
    plt.title("Fake Images")
    fake = utils.make_grid(fake[:self.num_showimage] * 0.5 + 0.5, nrow=10)
    plt.imshow(fake.permute(1, 2, 0))

    # Save the comparation result
    plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'result.jpg', bbox_inches='tight')
    plt.show()

  8. (2)生成图片的gif

  9. import matplotlib.pyplot as plt
    import matplotlib.animation as animation

    fig = plt.figure(3,figsize=(5, 5))
    plt.axis("off")
    ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in self.img_list]
    ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
    #HTML(ani.to_jshtml())
    # ani.to_html5_video()
    ani.save(self.figure_save_path + str(num_epochs) + 'epochs_' + 'generation.gif')

————————————————

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

原文链接:https://blog.csdn.net/zhe470719/article/details/120610934

收藏 0
分享
分享方式
微信

评论

游客

全部 0条评论

10603

文章

10.5W+

人气

19

粉丝

1

关注

官方媒体

轻松设计高效搭建,减少3倍设计改稿与开发运维工作量

开始免费试用 预约演示

扫一扫关注公众号 扫一扫联系客服

©Copyrights 2016-2022 杭州易知微科技有限公司 浙ICP备2021017017号-3 浙公网安备33011002011932号

互联网信息服务业务 合字B2-20220090

400-8505-905 复制
免费试用
微信社区
易知微-数据可视化
微信扫一扫入群