李宏毅机器学习课程HW03代码解释

article/2025/9/24 20:58:09

作业3任务是将图片进行分类

从官网上下载数据到data文件里面。此外,将代码分为三个模块,分别是dataset,model以及main。

 一、dataset模块

此模块作用是读取图片数据。

重要函数:os.path.join(path,x)    将path和x路径组合在一起

#导入库
import os
from PIL import Image
from torch.utils.data import Dataset#定义读入数据集
class FoodDataset(Dataset):def __init__(self, path, tfm, files=None):super(FoodDataset).__init__()self.path = path  # 图片文件路径# 找到文件夹里的每一个图片文件,os.path.join()是将多个路径组合在一起,path是文件夹路径,x是文件夹下每一个图片的局部路径self.files = sorted([os.path.join(path, x) for x in os.listdir(path) if x.endswith(".jpg")])  if files != None:self.files = filesprint(f"One {path} sample", self.files[0])self.transform = tfm  # 转换图片大小,在main函数中有定义#  返回文件大小def __len__(self):return len(self.files)def __getitem__(self, idx):fname = self.files[idx]im = Image.open(fname)  # 打开图片文件im = self.transform(im)# im = self.data[idx]try:label = int(fname.split("\\")[-1].split("_")[0])  # 找到每一个图片的类别,为了后面计算分类的准确度except:label = -1  # test has no label 倘若没有找到类别,则返回-1return im, label

二、model 函数

model函数的作用是建立模型,在此设置5次CNN卷积后输出512组4*4大小的图片数据,再通过线性层后输出[64,11]大小的二维数据。

from torch import nnclass Classifier(nn.Module):def __init__(self):#  继承父类super(Classifier, self).__init__()# input 維度 [3, 128, 128]self.cnn = nn.Sequential(#  3:输入通道数,64:输出通道数,3:卷积核大小,1:步长,1:填充大小nn.Conv2d(3, 64, 3, 1, 1),  # [64, 128, 128]nn.BatchNorm2d(64),  # 传入数字需和输出通道数相同nn.ReLU(),  # 激活函数nn.MaxPool2d(2, 2, 0),  # [64, 64, 64]#池化层改变图片的宽、高,128/2=64nn.Conv2d(64, 128, 3, 1, 1),  # [128, 64, 64]nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(2, 2, 0),  # [128, 32, 32]nn.Conv2d(128, 256, 3, 1, 1),  # [256, 32, 32]nn.BatchNorm2d(256),nn.ReLU(),nn.MaxPool2d(2, 2, 0),  # [256, 16, 16]nn.Conv2d(256, 512, 3, 1, 1),  # [512, 16, 16]nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2, 2, 0),  # [512, 8, 8]nn.Conv2d(512, 512, 3, 1, 1),  # [512, 8, 8]nn.BatchNorm2d(512),nn.ReLU(),nn.MaxPool2d(2, 2, 0),  # [512, 4, 4]  512组特征)self.fc = nn.Sequential(nn.Linear(512 * 4 * 4, 1024),  # 输入512 * 4 * 4,输出1024大小nn.ReLU(),nn.Linear(1024, 512),nn.ReLU(),nn.Linear(512, 11)  # 11:按要求需要分成11个类)def forward(self, x):out = self.cnn(x)out = out.view(out.size()[0], -1)  # [64,512,4,4]-->[64,512*4*4]return self.fc(out) #[64,11]

三、main函数

导入库,包括model和dataset两部分

import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
from dataset import FoodDataset
from model import Classifier

准备测试集,所有图片要求是相同大小,此处是128*128,也可以制定其他大小 。之后转为tensor格式

test_tfm = transforms.Compose([transforms.Resize((128, 128)),  # 转换测试集的图片大小,设置为相同大小transforms.ToTensor(),
])

准备训练集,训练集还可以做其他变换,例如放大、缩小、对称变换等,此处只是单纯更改了图片大小。

train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)transforms.Resize((128, 128)),训练集图片还可以做其他变化,例如图片对称翻转等等transforms.ToTensor(),
])

各种定义

batch_size = 64  # 每组64张图片
_dataset_dir = "./data/food11"  # 数据路径
train_set = FoodDataset(os.path.join(_dataset_dir,"training"), tfm=train_tfm)  # 读取文件
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)  # 打开文件并根据batch分配
valid_set = FoodDataset(os.path.join(_dataset_dir,"validation"), tfm=test_tfm)  # 读取文件
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)  # 打开文件并根据batch分配n_epochs = 3
patience = 300  # 如果300次仍没有土生,则提前终止
model = Classifier().to(device)
criterion = nn.CrossEntropyLoss()  # 交叉熵  
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-5)  #指定优化器,也可以用其他优化器,传入模型参数和学习率
stale = 0
best_acc = 0
_exp_name = "sample"

配置使用GPU或者CPU,最好使用GPU

device = "cuda" if torch.cuda.is_available() else "cpu"

开始训练

for epoch in range(n_epochs):# 开始训练model.train()# These are used to record information in training.train_loss = []train_accs = []for batch in tqdm(train_loader):imgs, labels = batch# 选择cpu还是gpulogits = model(imgs.to(device))# 计算交叉熵loss = criterion(logits, labels.to(device))# 清零,否则每迭代一次就会加上前面的数据optimizer.zero_grad()# 反向传播loss.backward()grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)# 更新参数optimizer.step()acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()# Record the loss and accuracy.train_loss.append(loss.item())train_accs.append(acc)train_loss = sum(train_loss) / len(train_loss)train_acc = sum(train_accs) / len(train_accs)print(f"[ Train | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")

开始验证

model.eval()valid_loss = []valid_accs = []for batch in tqdm(valid_loader):# A batch consists of image data and corresponding labels.imgs, labels = batchwith torch.no_grad():logits = model(imgs.to(device))loss = criterion(logits, labels.to(device))acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()valid_loss.append(loss.item())valid_accs.append(acc)# breakvalid_loss = sum(valid_loss) / len(valid_loss)valid_acc = sum(valid_accs) / len(valid_accs)print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")if valid_acc > best_acc:with open(f"./{_exp_name}_log.txt", "a"):print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f} -> best")else:with open(f"./{_exp_name}_log.txt", "a"):print(f"[ Valid | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")if valid_acc > best_acc:print(f"Best model found at epoch {epoch}, saving model")torch.save(model.state_dict(), f"{_exp_name}_best.ckpt")  # only save best to prevent output memory exceed errorbest_acc = valid_accstale = 0else:stale += 1if stale > patience:print(f"No improvment {patience} consecutive epochs, early stopping")break

笔者初学机器学习,还有很多不懂的地方,如有错误恳请各位读者不吝告知,笔者不胜感激。


http://chatgpt.dhexx.cn/article/RFFZmxYB.shtml

相关文章

接入友盟厂商push通道遇到的坑

目录 调试友盟Push问题的检查清单 客户端、服务端数据协议 客户端接入方式 小米厂商通道 华为厂商通道 魅族厂商通道 VIVO厂商通道 OPPO厂商通道 支持桌面角标的厂商 吐槽一下集成友盟厂商通道时发现的问题 调试友盟Push问题的检查清单 过滤UmengPushAgent开头的日志…

Android集成友盟消息推送SDK

消息推送SDK快速集成: 申请AppKey ——> 接入Push SDK ——> 基础接口引入 ——> 消息推送测试 ——> 接入完成 1.申请AppKey 2.接入Push SDK 1)、加入依赖 //友盟push相关依赖(必须)implementationcom.umeng.umsdk:push:6.1.0impleme…

机器学习之手写决策树以及sklearn中的决策树及其可视化

文章目录 决策树理论部分基本算法划分选择信息熵 信息增益信息增益率基尼系数基尼指数 决策树代码实现参考 决策树理论部分 决策树的思路很简单,就是从数据集中挑选一个特征,然后进行分类。 基本算法 从伪代码中可以看出,分三种情况考虑&…

android使用友盟推送注册失败获取不到token accs bindapp error!

使用友盟推送注册失败获取不到token public void onFailure(String s, String s1)的值分别是“-9”和“accs bindapp error!”或者s的值为-11.都是同一个问题 就是主工程(除友盟PushSDK 其他的module均看成为主工程)so目录与PushSDK下的so目录不一致…

同时集成阿里云旺与友盟推送,初始化失败s:-11,s1:accs bindapp error!的解决办法

在应用中需要同时集成聊天和推送功能,聊天选用阿里的sdk(百川云旺),推送选用友盟的pushSDK。 这时候悲剧就出现了,注册友盟的时候报错。 I/com.umeng.message.PushAgent: register-->onFailure-->s:-11,s1:accs …

关于友盟s=-11;s1=accs bindapp error!的解决处理

项目使用了友盟推送之后,在部分手机上出现accs bindapp error,错误码-11的问题,为什么会出现这个问题呢,网上查找了很久,友盟给出的解释是so文件不正确。 具体链接:http://bbs.umeng.com/thread-23018-1-1…

友盟register failed: -11 accs bindapp error!

下载官方Demo后,替换自己的id包名后出现 register failed: -11 accs bindapp error! 经过一番搜索之后,都是说这二种原因 1、检查appkey和secret key是否配置正确,如果正确无误,请看步骤2。2、so文件配置有误,需重新配置: Pus…

阿里无线11.11 | 手机淘宝移动端接入网关基础架构演进之路

移动网络优化是超级App永恒的话题,对于无线电商来说更为重要,网络请求体验跟用户的购买行为息息相关,手机淘宝从过去的HTTP API网关,到2014年升级支持SPDY,2015年双十一自研高性能、全双工、安全的ACCS(阿里…

VS2015 realease模式下调试

一、将项目属性设置为Release,生成--->配置管理器: 二、按AltF7,弹出属性页进行设置:

AndroidStudio如何打包生成realease版本的arr包,并上传到Nexus搭建的maven仓库,供项目远程依赖(二)

AndroidStudio如何打包生成realease版本的arr包,并上传到Nexus搭建的maven仓库,供项目远程依赖(二) AndroidStudio如何打包生成realease版本的arr包,并上传到Nexus搭建的maven仓库,供项目远程依赖&#xff…

AndroidStudio如何打包生成realease版本的arr包,并上传到Nexus搭建的maven仓库,供项目远程依赖(一)

AndroidStudio如何打包生成realease版本的arr包,并上传到Nexus搭建的maven仓库,供项目远程依赖(一) 背景: 公司之前在eclipse上做开发,写了很多library库供项目依赖使用,现在转AS上了,并用Nexu…

QT debug 功能正常 realease和windeplayqt工具打包部分功能无法使用或者不正常

目录 说明开发环境错误说明结论 说明 在项目的开发中,一般程序员都是使用debug版本进行程序的编写和调试,习惯好一些的程序员可能会天天用realease跑一遍自己写的程序是否正常,但是很多程序员可能都不会这么做,直到程序功能完成时…

Python OpenCV10:OpenCV 视频基本操作

1. 读视频 1.1 获取视频对象 要在 OpenCV 中获取视频,需要创建一个 VideoCapture 对象并指定要读取的视频文件。 cv.VideoCapture(filepath) 参数: filepath 视频文件路径 返回值: cap 读取视频的对象 1.2 获取视频属性 cap.get(propId) 获…

Renderers

渲染器 (Renderers) 在将 TemplateResponse 实例返回给客户端之前,必须渲染它。渲染过程采用模板和上下文的中间表示,并将其转换为可以提供给客户端的最终字节流。—— Django 文档 REST framework 包含许多内置的渲染器 (Renderer) 类,允许…

python调用opencv实现视频读写

文章目录 一、从文件中读取视频并播放1.1 基本API讲解1.2 python实现 二、保存视频2.1 基本API讲解2.1 python实现范例 一、从文件中读取视频并播放 1.1 基本API讲解 在OpenCV中我们要获取一个视频,需要创建一个VideoCapture对象,指定你要读取的视频文…

记一次有趣的debug,VS编译器上Debug和Realease的差异

之前自己写过一个imageread的函数,用了好久一直没问题。最近两天,同事让我realease一个项目给他,其中就包含了我自己写的imageread函数。 我的函数就长这样,不包含公司的code,不算泄密哈。 在realse之前,我…

C++语言基础篇

✅作者简介:CSDN内容合伙人,全栈领域新星创作者,阿里云专家博主,华为云云享专家博主,掘金后端评审团成员 💕前言: 学长出的这一系列专栏适合有⼀点 C 基础&#xff0c…

PCL12.1 Realease 附加依赖项

PCL12.1 Realease 附加依赖项 libboost_atomic-vc142-mt-g-x64-1_78.lib libboost_bzip2-vc142-mt-g-x64-1_78.lib libboost_chrono-vc142-mt-g-x64-1_78.lib libboost_container-vc142-mt-g-x64-1_78.lib libboost_context-vc142-mt-g-x64-1_78.lib libboost_contract-vc142-…

Vue强制刷新页面重新加载数据方法

业务场景 在管理后台执行完增删改查的操作之后,需要重新加载页面刷新数据以便页面数据的更新 实现原理 就是通过控制router-view 的显示与隐藏,来重渲染路由区域,重而达到页面刷新的效果,show -> flase -> show 具体代码…

Linux 重新加载 nginx 配置命令

1. 查找 nginx 位置 whereis nginx2. 进入 nginx 目录 cd /usr/local/nginx/sbin3. 检查 nginx 配置文件是否正确 ./nginx -t 4. 重新加载配置文件 ./nginx -s reload