finetune

article/2025/9/24 20:01:06

finetune的含义是获取预训练好的网络的部分结构和权重,与自己新增的网络部分一起训练。下面介绍几种finetune的方法。

完整代码:https://github.com/toyow/learn_tensorflow/tree/master/finetune

一,如何恢复预训练的网络

方法一:

思路:恢复原图所有的网络结构(op)以及权重,获取中间层的tensor,自己只需要编写新的网络结构,然后把中间层的tensor作为新网络结构的输入。

存在的问题:

1.这种方法是把原图所有结构载入到新图中,也就是说不需要的那部分也被载入了,浪费资源。

2.在执行优化器操作的时候,如果不锁定共有的结构(layer2=tf.stop_gradient(layer2,name='layer2_stop')),会导致重名提示报错,因为原结构已经有一个优化器操作了,你再优化一下就重名了。

核心代码:
1.把原网络加载到新图里
def train():
#恢复原网络的op tensorwith tf.Graph().as_default() as g:saver=tf.train.import_meta_graph('./my_ckpt_save_dir/wdy_model-15.meta')#把原网络载入到图g中
2.获取原图中间层tensor作为新网络的输入x_input=g.get_tensor_by_name('input/x:0')#恢复原op的tensory_input = g.get_tensor_by_name('input/y:0')layer2=g.get_tensor_by_name('layer2/layer2:0')#layer2=tf.stop_gradient(layer2,name='layer2_stop')#layer2及其以前的op均不进行反向传播softmax_linear=inference(layer2)#继续前向传播cost=loss(y_input,softmax_linear)train_op=tf.train.AdamOptimizer(0.001,name='Adma2').minimize(cost)#重名,所以改名3.恢复所有权重saver.restore(sess,save_path=tf.train.latest_checkpoint('./my_ckpt_save_dir/'))

方法二:

思路:重新定义网络结构,保持共有部分与原来同名。在恢复权重时,只恢复共有部分。

1.自定义网络结构
def inference(x):with tf.variable_scope('layer1') as scope:weights=weights_variabel('weights',[784,256],0.04)bias=bias_variabel('bias',[256],tf.constant_initializer(0.0))layer1=tf.nn.relu(tf.add(tf.matmul(x,weights),bias),name=scope.name)with tf.variable_scope('layer2') as scope:weights=weights_variabel('weights',[256,128],0.02)bias=bias_variabel('bias',[128],tf.constant_initializer(0.0))layer2=tf.nn.relu(tf.add(tf.matmul(layer1,weights),bias),name=scope.name)# layer2=tf.stop_gradient(layer2,name='layer2_stop')#layer2及其以前的op均不进行反向传播with tf.variable_scope('layer3') as scope:weights=weights_variabel('weights',[128,64],0.001)bias=bias_variabel('bias',[64],tf.constant_initializer(0.0))layer3=tf.nn.relu(tf.add(tf.matmul(layer2,weights),bias),name=scope.name)with tf.variable_scope('softmax_linear_1') as scope:weights = weights_variabel('weights', [64, 10], 0.0001)bias = bias_variabel('bias', [10], tf.constant_initializer(0.0))softmax_linear = tf.add(tf.matmul(layer3, weights), bias,name=scope.name)return softmax_linear2.恢复指定的权重variables_to_restore = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)[:4]#这里获取权重列表,只选择自己需要的部分saver = tf.train.Saver(variables_to_restore)with tf.Session(graph=g) as sess:#恢复权重saver.restore(sess,save_path=tf.train.latest_checkpoint('./my_ckpt_save_dir/'))#这个时候就是只恢复需要的权重了

二,如何获取锁层部分的变量名称,如何避免名称不匹配的问题。

   锁住了也可以显示所有变量。params_1=slim.get_model_variables()#放心大胆地获取纯净的参数变量,包括batchnormparams_2 = slim.get_variables_to_restore()  # 包含优化函数里面定义的动量等等变量,exclude       只能写全名。params_2 = [val for val in params_2 if 'Logits' not in val.name]#剔除含有这个字符的变量锁住了(trianable=False)就不显示。params_3 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)  params_4 = tf.trainable_variables()# 不包含优化器参数

解决方法,利用slim.get_variables_to_restore(),紧跟在原网络结构后面。之后再写自己定义的操作。

 

三,如何给不同层设置不同的学习率

思路:minizie()函数实际由compute_gradients()和apply_gradients()两个步骤完成。

compute_gradients()返回的是(gradent,varibel)元组对的列表,把这个列表varibel对应的gradent乘以学习率,再把新列表传入apply_gradients()就搞定了。

核心代码:

softmax_linear=inference(x_input)#继续前向传播
cost=loss(y_input,softmax_linear)
train_op=tf.train.AdamOptimizer()
grads=train_op.compute_gradients(cost)#返回的是(gradent,varibel)元组对的列表
variables_low_LR = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)[:4]#获取低学习率的变量列表
low_rate=0.0001
high_rate=0.001
new_grads_varible=[]#新的列表
for grad in grads:#对属于低学习率的变量的梯度,乘以一个低学习率if grad[1] in variables_low_LR:new_grads_varible.append((low_rate*grad[0],grad[1]))else:new_grads_varible.append((high_rate * grad[0], grad[1]))
apply_gradient_op = train_op.apply_gradients(new_grads_varible)
sess.run(apply_gradient_op,feed_dict={x_input:x_train_batch,y_input:y_train_batch})

三,关于PB文件

一,保存:

ckpt类型文件,是把结构(mate)与权重(checkpoint)分开保存,恢复的时候也是可以单独恢复。而PB文件是把结构与权重保存进了一个文件里。其中权重被固化成了常量,无法进行再次训练了。

可以看到,我指定保存最后一个tensor。只保存了之前的结构和权重,甚至y都没保存。

核心代码:

graph = convert_variables_to_constants(sess,sess.graph_def,['softmax_linear/softmax_linear'])
tf.train.write_graph(graph,'.','graph.pb',as_text=False)

二,恢复

恢复的思路跟ckpt恢复网络结构类似,不过因为只保存了我指定tensor之前的结构,所以自然也只能恢复保存了的网络结构。

with tf.Graph().as_default() as g:x_place = tf.placeholder(tf.float32, shape=[None, 784], name='x')y_place = tf.placeholder(tf.float32, shape=[None, 10], name='y')with open('./graph.pb','rb') as f:graph_def = tf.GraphDef()graph_def.ParseFromString(f.read())tf.import_graph_def(graph_def, name='')
#恢复tensorgraph_op = tf.import_graph_def(graph_def,name='',input_map={'input/x:0':x_place},return_elements=['layer2/layer2:0','layer1/weights:0'])

或者可以用

# x_place = g.get_tensor_by_name('input/x:0')#y_place = g.get_tensor_by_name('input/y:0')#layer2 = g.get_tensor_by_name('layer2/layer2:0')

 


http://chatgpt.dhexx.cn/article/IlREkYwV.shtml

相关文章

11.2 模型finetune

一、Transform Learning 与 Model Finetune 二、pytorch中的Finetune 一、Transfer Learning 与 Model Finetune 1. 什么是Transfer Learning? 迁移学习是机器学习的一个分支,主要研究源域的知识如何应用到目标域当中。迁移学习是一个很大的概念。 怎么理解源域…

飞桨深度学习学院零基础深度学习7日入门-CV疫情特辑学习笔记(四)DAY03 车牌识别

本课分为理论和实战两个部分 理论:卷积神经网络 1.思考全连接神经网络的问题 一般来收机器学习模型实践分为三个步骤,(1)建立模型 (2)选择损失函数 (3)参数调整学习 1.1 模型结构不…

unity sdk(android)-友盟推送SDK接入

注意:一开始想接友盟Unity的SDk,但是导入后缺少各种jar,所以最后还是接了android的,demo文档齐全 官方文档:开发者中心 按照官方文档对接即可, 接入流程 1、项目中com.android.tools.build:gradle配置&…

友盟推送学习

一、首次使用U_Push 1、首先注册友盟账号,进入工作台,选择产品U_Push。 2、创建应用 3、在自己的项目中自动集成SDK 开发环境要求: Android Studio 3.0以上 Android minSdkVersion: 14 Cradle: 4.4以上 在根目录build.gradle中添加mav…

Android 学习之如何集成友盟推送

我是利用Android studio 新建一个空的Android项目。 步骤一 导入第三方库 1.切换Android项目状态为Project状态 2.在main文件下新建 jniLibs文件夹(用来导入PushSDK项目下lib文件中的so文件) 3.在libs文件夹下添加友盟PuskSDK中的 jar 文件&#xff…

用PaddlePaddle(飞浆)实现车牌识别

项目描述:本次实践是一个多分类任务,需要将照片中的每个字符分别进行识别,完成车牌的识别 实践平台:百度AI实训平台-AI Studio、PaddlePaddle1.8.0 动态图 数据集介绍(自己去网上下载车牌识别数据集) 数据…

深度学习(五) CNN卷积神经网络

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 CNN卷积神经网络 前言一、CNN是什么?二、为什么要使用CNN?三、CNN的结构1.图片的结构2.卷积层1.感受野(Receptive Field)2.卷积…

CNN网络实现手写数字(MNIST)识别 代码分析

CNN网络实现手写数字(MNIST)识别 代码分析(自学用) Github代码源文件 本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别 #导入需要的包 import numpy as np //第三方库,用于进行科学计算 import torc…

Android删除chartty证书,C/C++知识点之android应用安全分析

本文主要向大家介绍了C/C知识点之android应用安全分析,通过具体的内容向大家展示,希望对大家学习C/C知识点有所帮助。 应用名 :OKEx(OKEx-android.apk) 包名 :com.okinc.okex MD5 :1ffbd328d13e91b661592cdf58516bd2 版…

代码编写过程 - 正确率折线图

获取绘图函数 首先,看到需要画acc和loss图。先去参考现成的,于是打开猫12分类,找到生成折线图的地方。 发现框内的两个函数绘制了折线图。既然是作为函数出现,说明已经有一定的封装,考虑能不能把整个函数搬走用。 由…

李宏毅机器学习课程HW03代码解释

作业3任务是将图片进行分类 从官网上下载数据到data文件里面。此外,将代码分为三个模块,分别是dataset,model以及main。 一、dataset模块 此模块作用是读取图片数据。 重要函数:os.path.join(path,x) 将path和x路径组合在一起 #导入库…

接入友盟厂商push通道遇到的坑

目录 调试友盟Push问题的检查清单 客户端、服务端数据协议 客户端接入方式 小米厂商通道 华为厂商通道 魅族厂商通道 VIVO厂商通道 OPPO厂商通道 支持桌面角标的厂商 吐槽一下集成友盟厂商通道时发现的问题 调试友盟Push问题的检查清单 过滤UmengPushAgent开头的日志…

Android集成友盟消息推送SDK

消息推送SDK快速集成: 申请AppKey ——> 接入Push SDK ——> 基础接口引入 ——> 消息推送测试 ——> 接入完成 1.申请AppKey 2.接入Push SDK 1)、加入依赖 //友盟push相关依赖(必须)implementationcom.umeng.umsdk:push:6.1.0impleme…

机器学习之手写决策树以及sklearn中的决策树及其可视化

文章目录 决策树理论部分基本算法划分选择信息熵 信息增益信息增益率基尼系数基尼指数 决策树代码实现参考 决策树理论部分 决策树的思路很简单,就是从数据集中挑选一个特征,然后进行分类。 基本算法 从伪代码中可以看出,分三种情况考虑&…

android使用友盟推送注册失败获取不到token accs bindapp error!

使用友盟推送注册失败获取不到token public void onFailure(String s, String s1)的值分别是“-9”和“accs bindapp error!”或者s的值为-11.都是同一个问题 就是主工程(除友盟PushSDK 其他的module均看成为主工程)so目录与PushSDK下的so目录不一致…

同时集成阿里云旺与友盟推送,初始化失败s:-11,s1:accs bindapp error!的解决办法

在应用中需要同时集成聊天和推送功能,聊天选用阿里的sdk(百川云旺),推送选用友盟的pushSDK。 这时候悲剧就出现了,注册友盟的时候报错。 I/com.umeng.message.PushAgent: register-->onFailure-->s:-11,s1:accs …

关于友盟s=-11;s1=accs bindapp error!的解决处理

项目使用了友盟推送之后,在部分手机上出现accs bindapp error,错误码-11的问题,为什么会出现这个问题呢,网上查找了很久,友盟给出的解释是so文件不正确。 具体链接:http://bbs.umeng.com/thread-23018-1-1…

友盟register failed: -11 accs bindapp error!

下载官方Demo后,替换自己的id包名后出现 register failed: -11 accs bindapp error! 经过一番搜索之后,都是说这二种原因 1、检查appkey和secret key是否配置正确,如果正确无误,请看步骤2。2、so文件配置有误,需重新配置: Pus…

阿里无线11.11 | 手机淘宝移动端接入网关基础架构演进之路

移动网络优化是超级App永恒的话题,对于无线电商来说更为重要,网络请求体验跟用户的购买行为息息相关,手机淘宝从过去的HTTP API网关,到2014年升级支持SPDY,2015年双十一自研高性能、全双工、安全的ACCS(阿里…

VS2015 realease模式下调试

一、将项目属性设置为Release,生成--->配置管理器: 二、按AltF7,弹出属性页进行设置: