TF2.0模型训练

article/2025/10/29 18:18:36

TF2.0模型训练

    • 概述
    • 数据集介绍
    • 1、通过fit方法训练模型
      • 准备数据
      • 创建模型
      • 编译模型
      • 训练模型
    • 2、通过fit_generator方法训练模型
      • 构建生成器
      • 创建模型
      • 编译模型
      • 训练模型
    • 3、自定义训练
      • 准备数据
      • 创建模型
      • 定义损失函数及优化器
      • 训练模型
    • 下一篇
      • TF2.0模型保存

概述

这是TF2.0入门笔记【TF2.0模型创建、TF2.0模型训练、TF2.0模型保存】中第二篇【TF2.0模型训练】,本篇将介绍模型的训练

  • 这里我会介绍用以下三种方法去演示模型训练(仅用图像分类举例)。
    • 1、通过fit方法训练模型
    • 2、通过fit_generator方法训练模型
    • 3、自定义训练

数据集介绍

该数据集为tf_flowers,数据集为五种花朵数据集,分别为雏菊(daisy),郁金香(tulips),向日葵(sunflowers),玫瑰(roses),蒲公英(dandelion)。

import pathlib
from tensorflow.keras.utils import get_filedata_root = get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',fname='flower_photos', untar=True, cache_dir='./', cache_subdir='datasets')data_path = pathlib.Path(data_root)print("data_path:",data_path)
for item in data_path.iterdir():print(item)

运行输出:

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 1s 0us/step
data_path: datasets/flower_photos
datasets/flower_photos/daisy
datasets/flower_photos/tulips
datasets/flower_photos/sunflowers
datasets/flower_photos/roses
datasets/flower_photos/LICENSE.txt
datasets/flower_photos/dandelion

1、通过fit方法训练模型

第一种:通过fit方法训练模型
步骤:
1、准备数据
2、创建模型
3、编译模型
4、训练模型

准备数据

获取所有花朵图片的路径

import randomall_image_paths = list(data_path.glob('*/*'))#获取子目录下所有文件
all_image_paths = [str(path) for path in all_image_paths]#把<class 'pathlib.WindowsPath'>转换成str类型
random.shuffle(all_image_paths)#打乱顺序
print(all_image_paths[0])
#输出:
#datasets\flower_photos\roses\3422228549_f147d6e642.jpg

获取所有花朵的标签

label_names = []
for item in data_path.glob('*/'):#获取目录下所有文件if item.is_dir():#判断是否是文件夹label_names.append(item.name)label_names.sort()#整理一下label_name_index = dict((name, index) for index, name in enumerate(label_names))
print(label_names)
print(label_name_index)
#输出:
#['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
#{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}#获取文件的目录,得到的目录根据字典映射出标签
all_image_labels = [label_name_index[pathlib.Path(path).parent.name]for path in all_image_paths]
print(all_image_labels[0])
#输出:
#2

定义一些变量

input_shape=(192,192,3)
classes    =len(label_names)
batch_size =64
epochs     =10
steps_per_epoch=len(all_image_paths)//batch_size

现在我们有了所有图片的路径all_image_paths,以及标签all_image_labels
现在我们写一个函数load_preprocess_image去加载并处理图片,make_image_label_datasets这个函数用来将图片和标签整合到一起

import tensorflow as tfdef load_preprocess_image(image_paths):image = tf.io.read_file(image_paths)            #img_stringimage = tf.image.decode_jpeg(image, channels=3) #img_tensorimage = tf.image.resize(image, [192,192])       #img_resizeimage = image/255.0                             #img_normal    return imagedef make_image_label_datasets(image_paths, image_labels):return load_preprocess_image(image_paths), image_labels

制作数据集

datasets = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
image_label_datasets = datasets.map(make_image_label_datasets)

取出两个图片及标签进行可视化

import matplotlib.pyplot as plt
import numpy as npplt.figure(figsize=(6,6))
n=0
for img,leb in image_label_datasets.take(2):n=n+1image=np.array(img.numpy()*255.0).astype("uint8")plt.subplot(1,2,n)plt.title('lebel:'+str(leb.numpy()))plt.imshow(image)plt.show()

创建模型

考虑到是入门教学,这里不进行迁移学习,我们来创建一个类似VGG系列的模型,这里的创建方法用到的是上一节所说的方法二

from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Input
def my_model(input_shape, classes):inputs=Input(input_shape)# Block 1x = Conv2D(64,  (3, 3), activation='relu', padding='same')(inputs)x = MaxPooling2D((2, 2), strides=(2, 2))(x)# Block 2x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)# Block 3x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)x = Conv2D(256, (3, 3), activation='relu', padding='same')(x)x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)x = Flatten()(x)x = Dense(512, activation='relu')(x)x = Dense(256, activation='relu')(x)x = Dense(classes, activation='softmax')(x)model = Model(inputs, x)return model
model = my_model(input_shape, classes)

编译模型

优化器用的是Adam优化器,因为标签是0,1,2,…而不是one-hot 编码[1, 0, 0,…], [0, 1 0,…], [0, 0, 1,…]。所以损失函数用sparse_categorical_crossentropy而不是categorical_crossentropy

from tensorflow.keras.optimizers import Adam
opt=Adam()
model.compile(optimizer=opt,loss='sparse_categorical_crossentropy',metrics=["accuracy"])

训练模型

image_label_datasets = image_label_datasets.shuffle(buffer_size=len(all_image_paths))
image_label_datasets = image_label_datasets.repeat()
image_label_datasets = image_label_datasets.batch(batch_size)
# 当模型在训练的时候,`prefetch` 使数据集在后台取得 batch,也就是流水线进行。
image_label_datasets = image_label_datasets.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)model.fit(image_label_datasets, epochs=epochs, steps_per_epoch=steps_per_epoch)

运行输出:

Train for 57 steps
Epoch 1/10
57/57 [==============================] - 32s 569ms/step - loss: 1.4404 - accuracy: 0.4046
Epoch 2/10
57/57 [==============================] - 21s 377ms/step - loss: 1.0551 - accuracy: 0.5762
Epoch 3/10
57/57 [==============================] - 21s 368ms/step - loss: 0.9082 - accuracy: 0.6417
Epoch 4/10
57/57 [==============================] - 21s 363ms/step - loss: 0.7993 - accuracy: 0.6853
Epoch 5/10
57/57 [==============================] - 21s 360ms/step - loss: 0.6667 - accuracy: 0.7410
Epoch 6/10
57/57 [==============================] - 19s 337ms/step - loss: 0.4645 - accuracy: 0.8331
Epoch 7/10
57/57 [==============================] - 17s 299ms/step - loss: 0.3154 - accuracy: 0.8890
Epoch 8/10
57/57 [==============================] - 15s 257ms/step - loss: 0.2015 - accuracy: 0.9328
Epoch 9/10
57/57 [==============================] - 14s 254ms/step - loss: 0.1692 - accuracy: 0.9487
Epoch 10/10
57/57 [==============================] - 14s 253ms/step - loss: 0.1205 - accuracy: 0.9638
<tensorflow.python.keras.callbacks.History at 0x7fb7d4467f28>

2、通过fit_generator方法训练模型

第二种:通过fit_generator方法训练模型
通过实践方法一,如果你是新手的话,你一定感受到了制作数据是一件很麻烦的事。
接下来我们将介绍使用ImageDataGenerator类及其flow_from_directory方法进行便捷地读取数据进行训练
步骤:
1、构建生成器
2、创建模型
3、编译模型
4、训练模型

构建生成器

ImageDataGenerator类及其flow_from_directory方法是有很多参数可用的,详情可以点击去官网看手册(如果你是新手的话,看手册会很频繁哦)

from tensorflow.keras.preprocessing.image import ImageDataGenerator
import osdef make_Gen(data_path):train_dataNums = 0train_gen  = ImageDataGenerator(rescale=1/255.0)#只用归一化for root, dirs, files in os.walk(data_path):for file in files:train_dataNums += 1return train_gen, train_dataNumsdata_path='datasets/flower_photos'
train_gen, train_dataNums = make_Gen(data_path)
train_generator = train_gen.flow_from_directory(directory   = data_path,target_size = (192,192),batch_size  = batch_size,class_mode  = 'categorical')#class_mode选'categorical',这时它会自动帮我们处理图片和标签print(train_generator.class_indices)
#输出:
#Found 3670 images belonging to 5 classes.
#{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

创建模型

使用方法一创建的模型

model = my_model((192,192,3), 5)

编译模型

损失函数用categorical_crossentropy

from tensorflow.keras.optimizers import Adam
opt=Adam()
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

训练模型

model.fit_generator(train_generator,steps_per_epoch =train_dataNums//batch_size,epochs=epochs)

运行输出:
根据结果可以看到,这种方法速度比较慢,相同轮数下收敛也没这么快,原因可以自己思考一下哦。

Epoch 1/10
57/57 [==============================] - 27s 470ms/step - loss: 1.7245 - accuracy: 0.3236
Epoch 2/10
57/57 [==============================] - 24s 428ms/step - loss: 1.3447 - accuracy: 0.4010
Epoch 3/10
57/57 [==============================] - 24s 421ms/step - loss: 1.2717 - accuracy: 0.4323
Epoch 4/10
57/57 [==============================] - 24s 425ms/step - loss: 1.2436 - accuracy: 0.4507
Epoch 5/10
57/57 [==============================] - 24s 418ms/step - loss: 1.1845 - accuracy: 0.4907
Epoch 6/10
57/57 [==============================] - 24s 422ms/step - loss: 1.0594 - accuracy: 0.5657
Epoch 7/10
57/57 [==============================] - 24s 427ms/step - loss: 0.8960 - accuracy: 0.6521
Epoch 8/10
57/57 [==============================] - 24s 417ms/step - loss: 0.6565 - accuracy: 0.7570
Epoch 9/10
57/57 [==============================] - 24s 419ms/step - loss: 0.4401 - accuracy: 0.8464
Epoch 10/10
57/57 [==============================] - 24s 418ms/step - loss: 0.2753 - accuracy: 0.9121
<tensorflow.python.keras.callbacks.History at 0x7fb77d69ae10>

3、自定义训练

第三种:自定义训练
有时候,一些繁杂的任务,或者说你想根据自己的想法进行更多的选择以及自定义,这时你就可以进行自定义训练

准备数据

见方法一:准备数据

创建模型

见方法一:创建模型

定义损失函数及优化器

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy#模型最后的输出是通过了softmax激活函数,所以这里from_logits=False
my_loss=SparseCategoricalCrossentropy(from_logits=False)
my_opt =Adam()
def loss(real, pred):loss=my_loss(real, pred)return loss

@tf.function带上这句可以加速,train_per_step该函数求每一个steploss梯度更新变量
这里用到了tensorflow.GradientTape类,GradientTape会监控可训练变量,详情可查看文档,也可以查看我的这篇文章TF2.0 GradientTape()类讲解

@tf.function
def train_per_step(inputs, targets):with tf.GradientTape() as tape:predicts=model(inputs)#求lossloss_value = loss(real=targets,pred=predicts)#根据损失求梯度gradients=tape.gradient(loss_value, model.trainable_variables)  #把梯度和变量进行绑定grads_and_vars=zip(gradients, model.trainable_variables)  #进行梯度更新my_opt.apply_gradients(grads_and_vars)return loss_value

训练模型

打乱数据集,设定batch_size

epochs     = 10
batch_size = 64
#打乱并设定batch_size
image_label_datasets = image_label_datasets.shuffle(buffer_size=len(all_image_paths))
image_label_datasets = image_label_datasets.batch(batch_size, drop_remainder=True)

开始训练
这里用到了tensorflow.keras.metrics.Mean类以及tensorflow.keras.metrics.SparseCategoricalAccuracy类,它们都有三个Methods(reset_states, result, update_state),详情可以看手册。

import time train_loss_results = []#保存loss值
train_accuracy_results = []#保存accuracy值for epoch in range(epochs):start = time.time()#注意这两行代码是在epochs的for循环里面,#每次循环之后会进行重置(重新赋值),所以不用加reset_states()方法epoch_loss_avg = tf.keras.metrics.Mean()epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()for image, label in image_label_datasets:batch_loss = train_per_step(image, label)#求平均,只要不调用reset_states()方法,之前的值是会累计下来的epoch_loss_avg(batch_loss)epoch_accuracy(label, model(image))#保存loss、accuracy值,可用于可视化train_loss_results.append(epoch_loss_avg.result())train_accuracy_results.append(epoch_accuracy.result())#每一个epoch后打印Loss、Accuracy以及花费的时间print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(epoch,epoch_loss_avg.result(),epoch_accuracy.result()))print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

运行输出:

Epoch 000: Loss: 1.128, Accuracy: 54.441%
Time taken for 1 epoch 18.131535291671753 secEpoch 001: Loss: 1.002, Accuracy: 62.582%
Time taken for 1 epoch 18.448741674423218 secEpoch 002: Loss: 0.895, Accuracy: 66.859%
Time taken for 1 epoch 18.31860089302063 secEpoch 003: Loss: 0.761, Accuracy: 74.397%
Time taken for 1 epoch 17.966360569000244 secEpoch 004: Loss: 0.585, Accuracy: 81.168%
Time taken for 1 epoch 17.9322772026062 secEpoch 005: Loss: 0.410, Accuracy: 89.200%
Time taken for 1 epoch 18.117868900299072 secEpoch 006: Loss: 0.269, Accuracy: 94.545%
Time taken for 1 epoch 17.976419687271118 secEpoch 007: Loss: 0.139, Accuracy: 97.478%
Time taken for 1 epoch 17.916046380996704 secEpoch 008: Loss: 0.094, Accuracy: 98.629%
Time taken for 1 epoch 18.384119987487793 secEpoch 009: Loss: 0.154, Accuracy: 98.438%
Time taken for 1 epoch 17.962616682052612 sec

下一篇

TF2.0模型保存


http://chatgpt.dhexx.cn/article/SjG78PPP.shtml

相关文章

TensorFlow 2.0 —— 模型训练

目录 1、Keras版本模型训练1.1 构造模型&#xff08;顺序模型、函数式模型、子类模型&#xff09;1.2 模型训练&#xff1a;model.fit()1.3 模型验证&#xff1a;model.evaluate()1.4 模型预测&#xff1a;model.predict()1.5 使用样本加权和类别加权1.6 回调函数1.6.1 EarlySt…

如何在jupyter上运行Java代码(适用LINUX)

如何在jupyter上运行Java代码 1.下载必须软件 下载JDK且JDK版本必须 ≥ 9 ≥9 ≥9从github上下载ijava 附 &#xff1a; ijava下载链接.装有jupyter&#xff0c;我在LINUX上是直接装的anaconda 安装过程 将下载的ijava压缩包解压出来&#xff0c;并在此路径用该命令 : sudo…

Java单元测试介绍

文章目录 单元测试单元测试基本介绍单元测试快速入门单元测试常用注解 单元测试 单元测试基本介绍 单元测试: 单元测试就是针对最小的功能单元编写测试代码&#xff0c;Java程序最小的功能单元是方法&#xff0c;因此&#xff0c;单元测试就是针对Java方法的测试&#xff0c;…

Jupyter 配置 Java环境,写Java代码,测试成功

本次简单诉说下怎么通过jupyter安装iJava&#xff0c;写Java代码。 安装Java的不说了 我使用的是Java15 然后去&#xff1a;https://github.com/SpencerPark/IJava/releases 下载zip&#xff0c;不要下载其他的 得到就是一个py文件 下面就是一个 python install.py 我这里就…

java调用python执行脚本,附代码

最近有个功能需要java调用python脚本实现一些功能&#xff0c;前期需要做好的准备&#xff1a;配置好python环境&#xff0c;如下&#xff1a; 以下展示的为两种&#xff0c;一种为生成图片&#xff0c;另一种为生成字符串。 package com.msdw.tms.common.utils.py;import ja…

Selenium Java自动化测试环境搭建

IDE用的是Eclipse。 步骤1&#xff1a;因为是基于Java&#xff0c;所以首先要下载与安装JDK&#xff08;Java Development Kit&#xff09; 下载&#xff1a; 点击这里下载JDK 安装&#xff1a;按照默认安装一路点next就可以了。 验证&#xff1a;安装完成后&#xff0c;在命…

java单元测试(Junit)

相关代码下载链接&#xff1a; http://download.csdn.net/detail/stevenhu_223/4884357 在有些时候&#xff0c;我们需要对我们自己编写的代码进行单元测试&#xff08;好处是&#xff0c;减少后期维护的精力和费用&#xff09;&#xff0c;这是一些最基本的模块测试。当然&…

Java单元测试工具:JUnit4(一)——概述及简单例子

&#xff08;一&#xff09;JUnit概述及一个简单例子 看了慕课网的JUnit视频教程&#xff1a; http://www.imooc.com/learn/356&#xff0c;总结笔记。 这篇笔记记录JUnit的概述&#xff0c;以及一个快速入门的例子。 1.概述 1.1 什么是JUnit ①JUnit是用于编写可复用测试集的…

Linux下执行Python脚本

1.Linux Python环境 Linux系统一般集成Python&#xff0c;如果没有安装&#xff0c;可以手动安装&#xff0c;联网状态下可直接安装。Fedora下使用yum install&#xff0c;Ubuntu下使用apt-get install&#xff0c;前提都是root权限。安装完毕&#xff0c;可将Python加入环境变…

python pytest脚本执行工具

pytest脚本执行工具 支持获取当前路径下所有.py脚本 添加多个脚本&#xff0c;一起执行 import tkinter as tk from tkinter import filedialog import subprocess import os from datetime import datetimedef select_script():script_path filedialog.askopenfilename(fil…

linux上运行python(简单版)

linux上运行python&#xff08;简单版&#xff09; 一、前提准备1.centOS72.挂载yum源[http://t.csdn.cn/Isf0i](http://t.csdn.cn/Isf0i) 二、安装python3三、运行程序 一、前提准备 1.centOS7 2.挂载yum源http://t.csdn.cn/Isf0i 在终端进行安装python3 二、安装python3 …

linux怎么运行python脚本?

linux运行python脚本的方法&#xff1a; 1、命令行执行&#xff1a; 建立一个test.py文档&#xff0c;在其中书写python代码。之后&#xff0c;在命令行执行&#xff1a;python test.py 说明&#xff1a;其中python可以写成python的绝对路径。使用which python进行查询。 注…

java实现远程执行Linux下的shell脚本

java实现远程执行Linux下的shell脚本 背景导入Jar包第一步&#xff1a;远程连接第二步&#xff1a;开启Session第三步&#xff1a;新建测试脚本文件结果报错 背景 最近有个项目&#xff0c;需要在Linux下的服务器内写了一部分Python脚本&#xff0c;业务处理却是在Java内&…

Java运行Python脚本

前段时间遇到了在JavaWeb项目中嵌入运行Python脚本的功能的需求。想到的方案有两种&#xff0c;一种是使用Java技术&#xff08;Jython或Runtime.exec&#xff09;运行Python脚本&#xff0c;另一种是搭建一个Python工程对外提供相应http或webservice接口。两种方案我都有实现&…

Java项目分层

MVC模式 在实际的开发中有一种项目的程序组织架构方案叫做MVC模式&#xff0c;按照程序 的功能将他们分成三个层&#xff0c;如下图&#xff1a;Modle层&#xff08;模型层&#xff09;、View层&#xff08;显示层&#xff09;、Controller层&#xff08;控制层&#xff09;。…

java项目收获总结_java开发项目收获心得

1 java开发项目收获心得 it行业现在的发展如日中天,很多人都纷纷走进这个行业,而java作为跨平台的编程语言更是受欢迎。java其实相对其他语言来说的确很有优势,但是也有点缺陷,但是以后发展到什么程度,谁都不知道。那么下面小编给大家说说java开发项目收获心得,希望能对你…

java查看jar包依赖_java项目开发中如何查找到项目依赖的jar包?

不管是java普通工程,还是java web项目,甚至是android项目,依赖包的管理有2种: 1.直接依赖jar包 这种方式简单直白,项目下载后在正确的ide或者稍微做转换就可以运行起来。比如java web工程的WEB-INF/lib下 只要按这个步骤Java Build Path=>Add Libraty=>Web App Libr…

Java小白必看:开发一个编程项目的完整流程(附100套Java编程项目源码+视频)

我相信很多Java新手都会遇到这样一个问题&#xff1a;跟着教材敲代码&#xff0c;很容易&#xff1b;但是让他完整的实现一个应用项目&#xff0c;却不会&#xff1b;不知道从哪里开始&#xff0c;不知道实现一个项目的完整流程是怎样的&#xff0c;看似很简单的一个问题&#…

分享67套基于Java开发的Java毕业设计实战项目(含源码+毕业论文)【新星计划】

【新星计划】分享67套基于Java开发的Java毕业设计实战项目(含源码毕业论文) 基于Java开发的Java毕业设计实战项目 本文中的所有主题都来自互联网。如果您侵犯您的权利&#xff0c;请及时联系Blogger&#xff0c;博主将及时处理。 投诉邮箱&#xff1a;1919101926qq.com (没事…

分享一些我的学习方法

赖勇浩&#xff08;http://laiyonghao.com &#xff09; 经常听到和看到一些前辈提起搞编程这一行最大的痛苦在于知识的更新太频繁&#xff0c;如同逆水行舟&#xff0c;不进则退&#xff0c;稍一松懈&#xff0c;就跟不上潮流。的确如此&#xff0c;既然身在 IT 界&#xff0c…