1.可视化train,test的loss acc
1.1 案例:交通指示牌识别案例-history数组
代码地址:E:\项目例程\猫狗分类\迁移学习\猫狗_resnet18_2 \猫狗分类_迁移学习可视化
导入库
from collections import defaultdict
训练函数中构建一个默认value为list的字典
history = defaultdict(list) # 构建一个默认value为list的字典
训练函数中保存train_loss,train_acc,test_loss,test_acc结果
history['train_acc'].append(train_accuracy)
history['train_loss'].append(train_loss)
history['val_acc'].append(val_accuracy)
history['val_loss'].append(val_loss)
训练函数返回
return model, history
训练模型调用时
# 调用训练函数训练
model_conv, history = train_model(
model_conv,
criterion,
optimizer_conv,
exp_lr_scheduler,
num_epochs=30
)
绘制函数 两张图,每个图两个曲线,写法固定
注意:若运行不出图,则加plt.show()
# 绘制 loss, acc 写法固定:两张表
def plot_training_history(history):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
ax1.plot(history['train_loss'], label='train loss')
ax1.plot(history['val_loss'], label='val loss')
ax1.set_ylim([-0.05, 1.05])
ax1.legend()
ax1.set_ylabel('Loss')
ax1.set_xlabel('Epoch')
ax2.plot(history['train_acc'], label='train acc')
ax2.plot(history['val_acc'], label='val acc')
ax2.set_ylim([-0.05, 1.05])
ax2.legend()
ax2.set_ylabel('Accuracy')
ax2.set_xlabel('Epoch')
fig.suptitle('Training History')
plt.show()
plot_training_history(history)
结果曲线展示:
1.2 一张图两条曲线loss 写法总结
案例:gan-手写数字 画G_loss,D_loss
代码位置 :E:\项目例程\GNN\手写数字\3_可视化
步骤:
step1 定义数组
G_losses = []
D_losses = []
1
2
step2 添加数据
G_losses.append(sum_loss.item())
D_losses.append(d_loss.item())
1
2
step3 画图
x=[i for i in range(len(G_losses))]
figure = plt.figure(figsize=(20, 8), dpi=80)
plt.plot(x,G_losses,label='G_losses')
plt.plot(x,D_losses,label='D_losses')
plt.xlabel("iterations",fontsize=15)
plt.ylabel("loss",fontsize=15)
plt.legend()
plt.grid()
plt.show()
效果
代码地址:E:\项目例程\猫狗分类\迁移学习\猫狗_resnet18_2 \猫狗分类_迁移学习可视化
# 测试结果可视化函数
def visualize_model(model):
model.eval()
with torch.no_grad():
inputs, labels = next(iter(dataloaders['val']))
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
preds = outputs.argmax(1)
plt.figure(figsize=(9, 9))
for i in range(inputs.size(0)):
plt.subplot(4,4,i+1) #根据batch_size修改
plt.axis('off')
plt.title(f'pred: {class_names[preds[i]]}|true: {class_names[labels[i]]}')
im = no_normalize(inputs[i].cpu())
plt.imshow(im)
plt.savefig('train.jpg')
plt.show()
- 调用函数
# 测试结果可视化
visualize_model(model_conv)效果展示:
3 loss曲线生成及保存
代码参考:E:\项目例程\GNN\DCGAN\02DCGAN_oxford17
DCGAN生成鲜花figure_save_path="./figures/"
plt.figure(1,figsize=(8, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(self.G_loss_list[::10], label="G")
plt.plot(self.D_loss_list[::10], label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.axhline(y=0, label="0", c="g") # asymptote
plt.legend()
plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'loss.jpg', bbox_inches='tight')4GAN -真实图片与生成图片
代码参考:E:\项目例程\GNN\DCGAN\02DCGAN_oxford17
(1)真实图片与生成图片对比
plt.figure(4,figsize=(8, 4))
# Plot the real images
plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
real = next(iter(self.dataloader)) # real[0]image,real[1]label
plt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))
# Load the Best Generative Model
# self.G.load_state_dict(
# torch.load(self.model_save_path + 'disc_{}.pth'.format(epoch), map_location=torch.device(self.device)))
self.G.eval()
# Generate the Fake Images
with torch.no_grad():
fake = self.G(self.fixed_noise).cpu()
# Plot the fake images
plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
fake = utils.make_grid(fake[:self.num_showimage] * 0.5 + 0.5, nrow=10)
plt.imshow(fake.permute(1, 2, 0))
# Save the comparation result
plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'result.jpg', bbox_inches='tight')
plt.show()(2)生成图片的gif
import matplotlib.pyplot as plt
import matplotlib.animation as animation
fig = plt.figure(3,figsize=(5, 5))
plt.axis("off")
ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in self.img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
#HTML(ani.to_jshtml())
# ani.to_html5_video()
ani.save(self.figure_save_path + str(num_epochs) + 'epochs_' + 'generation.gif')
————————————————
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/zhe470719/article/details/120610934
文章
10.5W+人气
19粉丝
1关注
©Copyrights 2016-2022 杭州易知微科技有限公司 浙ICP备2021017017号-3 浙公网安备33011002011932号
互联网信息服务业务 合字B2-20220090