FlyAI小课堂:【文本分类-中文】textCNN

article/2025/9/19 10:48:08

摘要: 随着深度学习的快速发展,人们创建了一整套神经网络结构来解决各种各样的任务和问题。在英文分类基础上,中文文本分类的处理也同样重要...

人工智能学习离不开实践的验证,推荐大家可以多在FlyAI-AI竞赛服务平台多参加训练和竞赛,以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。

目录

  1. 概述
  2. 数据集合
  3. 代码
  4. 结果展示

一、概述

在英文分类的基础上,再看看中文分类的,是一种10分类问题(体育,科技,游戏,财经,房产,家居等)的处理。

二、数据集合

数据集为新闻,总共有四个数据文件,在/data/cnews目录下,包括内容如下图所示测试集,训练集和验证集,和单词表(最后的单词表cnews.vocab.txt可以不要,因为训练可以自动产生)。数据格式:前面为类别,后面为描述内容。

训练数据地址:链接: https://pan.baidu.com/s/1ZHh98RrjQpG5Tm-yq73vBQ 提取码:2r04

其中训练集的格式:

vocab.txt的格式:每个字一行,其中前面加上PAD。

三、代码

3.1 数据采集cnews_loader.py

# coding: utf-8import sysfrom collections import Counterimport numpy as npimport tensorflow.contrib.keras as krif sys.version_info[0] > 2:is_py3 = Trueelse:reload(sys)sys.setdefaultencoding("utf-8")is_py3 = Falsedef native_word(word, encoding='utf-8'):"""如果在python2下面使用python3训练的模型,可考虑调用此函数转化一下字符编码"""if not is_py3:return word.encode(encoding)else:return worddef native_content(content):if not is_py3:return content.decode('utf-8')else:return contentdef open_file(filename, mode='r'):"""常用文件操作,可在python2和python3间切换.mode: 'r' or 'w' for read or write"""if is_py3:return open(filename, mode, encoding='utf-8', errors='ignore')else:return open(filename, mode)def read_file(filename):"""读取文件数据"""contents, labels = [], []with open_file(filename) as f:for line in f:try:label, content = line.strip().split('\t')if content:contents.append(list(native_content(content)))labels.append(native_content(label))except:passreturn contents, labelsdef build_vocab(train_dir, vocab_dir, vocab_size=5000):"""根据训练集构建词汇表,存储"""data_train, _ = read_file(train_dir)all_data = []for content in data_train:all_data.extend(content)counter = Counter(all_data)count_pairs = counter.most_common(vocab_size - 1)words, _ = list(zip(*count_pairs))# 添加一个 <PAD> 来将所有文本pad为同一长度words = ['<PAD>'] + list(words)open_file(vocab_dir, mode='w').write('\n'.join(words) + '\n')def read_vocab(vocab_dir):"""读取词汇表"""# words = open_file(vocab_dir).read().strip().split('\n')with open_file(vocab_dir) as fp:# 如果是py2 则每个值都转化为unicodewords = [native_content(_.strip()) for _ in fp.readlines()]word_to_id = dict(zip(words, range(len(words))))return words, word_to_iddef read_category():"""读取分类目录,固定"""categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']categories = [native_content(x) for x in categories]cat_to_id = dict(zip(categories, range(len(categories))))return categories, cat_to_iddef to_words(content, words):"""将id表示的内容转换为文字"""return ''.join(words[x] for x in content)def process_file(filename, word_to_id, cat_to_id, max_length=600):"""将文件转换为id表示"""contents, labels = read_file(filename)data_id, label_id = [], []for i in range(len(contents)):data_id.append([word_to_id[x] for x in contents[i] if x in word_to_id])label_id.append(cat_to_id[labels[i]])# 使用keras提供的pad_sequences来将文本pad为固定长度x_pad = kr.preprocessing.sequence.pad_sequences(data_id, max_length)y_pad = kr.utils.to_categorical(label_id, num_classes=len(cat_to_id))  # 将标签转换为one-hot表示return x_pad, y_paddef batch_iter(x, y, batch_size=64):"""生成批次数据"""data_len = len(x)num_batch = int((data_len - 1) / batch_size) + 1indices = np.random.permutation(np.arange(data_len))x_shuffle = x[indices]y_shuffle = y[indices]for i in range(num_batch):start_id = i * batch_sizeend_id = min((i + 1) * batch_size, data_len)yield x_shuffle[start_id:end_id], y_shuffle[start_id:end_id]

3.2 模型搭建cnn_model.py

定义训练的参数,TextCNN()模型

# coding: utf-8import tensorflow as tfclass TCNNConfig(object):"""CNN配置参数"""embedding_dim = 64  # 词向量维度seq_length = 600  # 序列长度num_classes = 10  # 类别数num_filters = 256  # 卷积核数目kernel_size = 5  # 卷积核尺寸vocab_size = 5000  # 词汇表达小hidden_dim = 128  # 全连接层神经元dropout_keep_prob = 0.5  # dropout保留比例learning_rate = 1e-3  # 学习率batch_size = 64  # 每批训练大小num_epochs = 10  # 总迭代轮次print_per_batch = 100  # 每多少轮输出一次结果save_per_batch = 10  # 每多少轮存入tensorboardclass TextCNN(object):"""文本分类,CNN模型"""def __init__(self, config):self.config = config# 三个待输入的数据self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')self.cnn()def cnn(self):"""CNN模型"""# 词向量映射with tf.device('/cpu:0'):embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)with tf.name_scope("cnn"):# CNN layerconv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')# global max pooling layergmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')with tf.name_scope("score"):# 全连接层,后面接dropout以及relu激活fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')fc = tf.contrib.layers.dropout(fc, self.keep_prob)fc = tf.nn.relu(fc)# 分类器self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 预测类别with tf.name_scope("optimize"):# 损失函数,交叉熵cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)self.loss = tf.reduce_mean(cross_entropy)# 优化器self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)with tf.name_scope("accuracy"):# 准确率correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

3.3 运行代码run_cnn.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
from __future__ import print_function
import os
import sys
import time
from datetime import timedelta
import numpy as np
import tensorflow as tf
from sklearn import metrics
from cnn_model import TCNNConfig, TextCNN
from  cnews_loader import read_vocab, read_category, batch_iter, process_file, build_vocabbase_dir = '../data/cnews'
train_dir = os.path.join(base_dir, 'cnews.train.txt')
test_dir = os.path.join(base_dir, 'cnews.test.txt')
val_dir = os.path.join(base_dir, 'cnews.val.txt')
vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')
save_dir = 'checkpoints/textcnn'
save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径def get_time_dif(start_time):"""获取已使用时间"""end_time = time.time()time_dif = end_time - start_timereturn timedelta(seconds=int(round(time_dif)))def feed_data(x_batch, y_batch, keep_prob):feed_dict = {model.input_x: x_batch,model.input_y: y_batch,model.keep_prob: keep_prob}return feed_dictdef evaluate(sess, x_, y_):"""评估在某一数据上的准确率和损失"""data_len = len(x_)batch_eval = batch_iter(x_, y_, 128)total_loss = 0.0total_acc = 0.0for x_batch, y_batch in batch_eval:batch_len = len(x_batch)feed_dict = feed_data(x_batch, y_batch, 1.0)loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)total_loss += loss * batch_lentotal_acc += acc * batch_lenreturn total_loss / data_len, total_acc / data_lendef train():print("Configuring TensorBoard and Saver...")# 配置 Tensorboard,重新训练时,请将tensorboard文件夹删除,不然图会覆盖tensorboard_dir = '../tensorboard/textcnn'if not os.path.exists(tensorboard_dir):os.makedirs(tensorboard_dir)tf.summary.scalar("loss", model.loss)tf.summary.scalar("accuracy", model.acc)merged_summary = tf.summary.merge_all()writer = tf.summary.FileWriter(tensorboard_dir)# 配置 Saversaver = tf.train.Saver()if not os.path.exists(save_dir):os.makedirs(save_dir)print("Loading training and validation data...")# 载入训练集与验证集start_time = time.time()x_train, y_train = process_file(train_dir, word_to_id, cat_to_id, config.seq_length)x_val, y_val = process_file(val_dir, word_to_id, cat_to_id, config.seq_length)time_dif = get_time_dif(start_time)print("Time usage:", time_dif)# 创建sessionsession = tf.Session()session.run(tf.global_variables_initializer())writer.add_graph(session.graph)print('Training and evaluating...')start_time = time.time()total_batch = 0  # 总批次best_acc_val = 0.0  # 最佳验证集准确率last_improved = 0  # 记录上一次提升批次require_improvement = 1000  # 如果超过1000轮未提升,提前结束训练flag = Falsefor epoch in range(config.num_epochs):print('Epoch:', epoch + 1)batch_train = batch_iter(x_train, y_train, config.batch_size)for x_batch, y_batch in batch_train:feed_dict = feed_data(x_batch, y_batch, config.dropout_keep_prob)#print("x_batch is {}".format(x_batch.shape))if total_batch % config.save_per_batch == 0:# 每多少轮次将训练结果写入tensorboard scalars = session.run(merged_summary, feed_dict=feed_dict)writer.add_summary(s, total_batch)if total_batch % config.print_per_batch == 0:# 每多少轮次输出在训练集和验证集上的性能feed_dict[model.keep_prob] = 1.0loss_train, acc_train = session.run([model.loss, model.acc], feed_dict=feed_dict)loss_val, acc_val = evaluate(session, x_val, y_val)  # todoif acc_val > best_acc_val:# 保存最好结果best_acc_val = acc_vallast_improved = total_batchsaver.save(sess=session, save_path=save_path)improved_str = '*'else:improved_str = ''time_dif = get_time_dif(start_time)msg = 'Iter: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \+ ' Val Loss: {3:>6.2}, Val Acc: {4:>7.2%}, Time: {5} {6}'print(msg.format(total_batch, loss_train, acc_train, loss_val, acc_val, time_dif, improved_str))session.run(model.optim, feed_dict=feed_dict)  # 运行优化total_batch += 1if total_batch - last_improved > require_improvement:# 验证集正确率长期不提升,提前结束训练print("No optimization for a long time, auto-stopping...")flag = Truebreak  # 跳出循环if flag:  # 同上breakdef test():print("Loading test data...")start_time = time.time()x_test, y_test = process_file(test_dir, word_to_id, cat_to_id, config.seq_length)session = tf.Session()session.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.restore(sess=session, save_path=save_path)  # 读取保存的模型print('Testing...')loss_test, acc_test = evaluate(session, x_test, y_test)msg = 'Test Loss: {0:>6.2}, Test Acc: {1:>7.2%}'print(msg.format(loss_test, acc_test))batch_size = 128data_len = len(x_test)num_batch = int((data_len - 1) / batch_size) + 1y_test_cls = np.argmax(y_test, 1)y_pred_cls = np.zeros(shape=len(x_test), dtype=np.int32)  # 保存预测结果for i in range(num_batch):  # 逐批次处理start_id = i * batch_sizeend_id = min((i + 1) * batch_size, data_len)feed_dict = {model.input_x: x_test[start_id:end_id],model.keep_prob: 1.0}y_pred_cls[start_id:end_id] = session.run(model.y_pred_cls, feed_dict=feed_dict)# 评估print("Precision, Recall and F1-Score...")print(metrics.classification_report(y_test_cls, y_pred_cls, target_names=categories))# 混淆矩阵print("Confusion Matrix...")cm = metrics.confusion_matrix(y_test_cls, y_pred_cls)print(cm)time_dif = get_time_dif(start_time)print("Time usage:" , time_dif)if __name__ == '__main__':config = TCNNConfig()if not os.path.exists(vocab_dir):  # 如果不存在词汇表,重建,这里存在,因此不用重建build_vocab(train_dir, vocab_dir, config.vocab_size)categories, cat_to_id = read_category()words, word_to_id = read_vocab(vocab_dir)config.vocab_size = len(words)model = TextCNN(config)option='train'if option == 'train':train()else:test()

3.4 预测predict.py

# coding: utf-8from __future__ import print_functionimport osimport tensorflow as tfimport tensorflow.contrib.keras as krfrom cnn_model import TCNNConfig, TextCNNfrom cnews_loader import read_category, read_vocabtry:bool(type(unicode))except NameError:unicode = strbase_dir = '../data/cnews'vocab_dir = os.path.join(base_dir, 'cnews.vocab.txt')save_dir = '../checkpoints/textcnn'save_path = os.path.join(save_dir, 'best_validation')  # 最佳验证结果保存路径class CnnModel:def __init__(self):self.config = TCNNConfig()self.categories, self.cat_to_id = read_category()self.words, self.word_to_id = read_vocab(vocab_dir)self.config.vocab_size = len(self.words)self.model = TextCNN(self.config)self.session = tf.Session()self.session.run(tf.global_variables_initializer())saver = tf.train.Saver()saver.restore(sess=self.session, save_path=save_path)  # 读取保存的模型def predict(self, message):# 支持不论在python2还是python3下训练的模型都可以在2或者3的环境下运行content = unicode(message)data = [self.word_to_id[x] for x in content if x in self.word_to_id]feed_dict = {self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),self.model.keep_prob: 1.0}y_pred_cls = self.session.run(self.model.y_pred_cls, feed_dict=feed_dict)return self.categories[y_pred_cls[0]]if __name__ == '__main__':cnn_model = CnnModel()test_demo = ['三星ST550以全新的拍摄方式超越了以往任何一款数码相机','热火vs骑士前瞻:皇帝回乡二番战 东部次席唾手可得新浪体育讯北京时间3月30日7:00']for i in test_demo:print(cnn_model.predict(i))

四、结果展示

   

 

相关代码可见:https://github.com/yifanhunter/textClassifier_chinese


 

有关于‘文本分类’的竞赛项目,大家可移步官网进行查看和参赛!

更多精彩内容请访问FlyAI-AI竞赛服务平台;为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台;每周免费提供项目开源算法样例,支持算法能力变现以及快速的迭代算法模型。

挑战者,都在FlyAI!!!


http://chatgpt.dhexx.cn/article/hQhLmDVQ.shtml

相关文章

FlyAI人工智能社区参赛指南—用户体验版

Fly-AI竞赛服务平台 flyai.com 在开始学习之前推荐大家可以多在 FlyAI竞赛服务平台多参加训练和竞赛&#xff0c;以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例&#xff0c;支持算法能力变现以…

FlyAI资讯:GPT-3的威力,助力AI领域

摘要&#xff1a;大概一个月之前&#xff0c;史上最巨无霸NLP模型GPT-3问世。当时它向世界展示的能力是&#xff0c;“不仅会写短文&#xff0c;而且写出来的作文挺逼真的&#xff0c;几乎可以骗过人类&#xff0c;可以说几乎通过了图灵测试。”可能是因为它的前一代模型GPT-2也…

FlyAI赛题预告:时间序列初学者指南

近期&#xff0c;FlyAI服务竞赛平台将上线时间序列相关赛题&#xff0c;看到此文的童鞋们可以多多关注下&#xff1b;推荐大家可以多在FlyAI竞赛服务平台多参加训练和竞赛&#xff0c;以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。…

FlyAI资讯:人工智能的前世今生

摘要: 现代电子产品和设备在诸如通信 、娱乐 、安全和医疗保健等许多方面改善了我们的生活质量 &#xff0c;这主要是因为现代微电子技术的发展极大地改变了人们的日常工作和互动方式。在过去几十年中&#xff0c;摩尔定律一直是通过不断缩小芯 … 人工智能学习离不开实践的验…

FlyAI小课堂:Tensorflow基操

人工智能学习离不开实践的验证&#xff0c;推荐大家可以多在FlyAI-AI竞赛服务平台多参加训练和竞赛&#xff0c;以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例&#xff0c;支持算法能力变现以及快速的…

FlyAI资讯:台积电布局新存储技术

摘要&#xff1a;近年来&#xff0c;在人工智能&#xff08;AI&#xff09;、5G等推动下&#xff0c;以MRAM&#xff08;磁阻式随机存取存储器&#xff09;、铁电随机存取存储器 (FRAM)、相变随机存取存储器&#xff08;PRAM&#xff09;&#xff0c;以及可变电阻式随机存取存储…

强烈推荐!FlyAI机器学习数据竞赛启动,丰厚奖金等你来拿

FlyAI 数据竞赛平台 FlyAI 是隶属于北京智能工场科技有限公司旗下&#xff0c;为AI开发者 &#xff08;深度学习&#xff09;提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例&#xff0c;样例所使用开发框架涉及TensorFlow、Keras、PyTorch. 支持…

FlyAI竞赛服务平台赛题上新——手写英文字体识别(名企内推)

人工智能学习离不开实践的验证&#xff0c;推荐大家可以多在FlyAI-AI竞赛服务平台多参加训练和竞赛&#xff0c;以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例&#xff0c;支持算法能力变现以及快速的…

万元大奖,FlyAI算法新赛事,心理卡牌目标检测

本文转自竞赛社区FlyAI最新上线算法赛。 “心理卡牌目标检测算法赛”由测测星座与趣鸭心理联合发起的线上竞赛项目。2020年新冠肺炎疫情打破了我们平静的生活&#xff0c;影响了我们的身体健康和精神健康。在国家发布的《“健康中国2030”规划纲要》中也提到&#xff0c;要加大…

百万现金奖励,FlyAI实时竞赛等你来战

点击我爱计算机视觉标星&#xff0c;更快获取CVML新技术 这是一种新的算法竞赛方式&#xff01; FlyAI是一个为算法工程师提供&#xff08;深度学习&#xff09;项目竞赛并支持GPU离线训练的网站。目前每周更新两个以上现金奖励的竞赛项目。项目涉及领域包括图像识别/分类/检测…

使用FlyAI进行科学数据竞赛

转自百度经验 https://jingyan.baidu.com/article/acf728fd9c3e6af8e510a392.html 工具/原料 python3 Windows命令行&#xff08;mac/Linux都可运行&#xff0c;此经验以Windows为例&#xff09; 百度经验:jingyan.baidu.com 注册账号 1 使用浏览器打开 www.flyai.com,并点击注…

FlyAI资讯:收藏!深度学习必读10篇经典算法论文总结!

前言 目录 前言 1998年&#xff1a;LeNet 2012年&#xff1a;AlexNet 2014年&#xff1a;VGG 2014年&#xff1a;GoogLeNet 2015年&#xff1a;Batch Normalization 2015年&#xff1a;ResNet 2016年&#xff1a;Xception 2017年&#xff1a;MobileNet 2017年&#…

FlyAi实战之MNIST手写数字识别练习赛(准确率99.55%)

欢迎关注WX公众号&#xff1a;【程序员管小亮】 文章目录 欢迎关注WX公众号&#xff1a;【程序员管小亮】一、介绍二、代码实现1_数据加载2_归一化3_定义网络结构4_设置优化器和退火函数5_数据增强6_拟合数据7_训练轮数和批大小8_准确率和损失 三、总结 一、介绍 最近发现了一…

FlyAI图像识别类竞赛:什么蘑菇?

人工智能学习离不开实践的验证&#xff0c;推荐大家可以多在FlyAI-AI竞赛服务平台多参加训练和竞赛&#xff0c;以此来提升自己的能力。FlyAI是为AI开发者提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例&#xff0c;支持算法能力变现以及快速的…

flyai下载预训练的keras模型

进入FlyAI预训练模型地址找到需要的keras模型&#xff0c;相应链接后确定 3.得到复制后的内容 # 必须使用该方法下载模型&#xff0c;然后加载 from flyai.utils import remote_helper path remote_helper.get_remote_date("https://www.flyai.com/m/v0.8|NASNet-mobile…

FlyAI实验室使用教程【完整版】

FlyAI使用教程 文章目录 FlyAI使用教程1、FlyAI是什么&#xff1f;2、账号注册3、文件上传4、代码提交5、怎么训练 1、FlyAI是什么&#xff1f; 想知道FlyAI如何使用&#xff0c;首先你要知道FlyAI是个什么平台&#xff0c;真的蛮良心的一个平台&#xff0c;地址是&#xff1a…

FlyAI资讯:强大如GPT-3,1750亿参数也搞不定中国话

摘要&#xff1a;2019 年&#xff0c;GPT-2 凭借将近 30 亿条参数的规模拿下来“最强 NLP 模型”的称号&#xff1b;2020 年&#xff0c;1750 亿条参数的 GPT-3 震撼发布&#xff0c;并同时在工业界与学术界掀起了各种争论。随着时间的推移&#xff0c;争论的焦点也发生了变化&…

FlyAI-遥感影像场景分类预测经验总结

文章目录 数据介绍经验1. 准确率92.55%&#xff1a;SENet、PyTorch&#xff08;1&#xff09;数据预处理&#xff1a;加权采样&#xff08;2&#xff09;数据增强&#xff1a;采用随机裁剪&#xff0c;随机旋转&#xff0c;随机翻转&#xff0c;随机擦除&#xff08;3&#xff…

FlyAI平台竞赛入门记录

FlyAI算法竞赛平台官方介绍如下&#xff1a; FlyAI 是隶属于北京智能工场科技有限公司旗下&#xff0c;为AI开发者 &#xff08;深度学习&#xff09;提供数据竞赛并支持GPU离线训练的一站式服务平台。每周免费提供项目开源算法样例&#xff0c;样例所使用开发框架涉及TensorFl…

电子邮箱免费注册,比较好用的电子邮箱怎么注册?如何申请?

电子邮箱免费的很多&#xff0c;我们常用的163、TOM、QQ等&#xff0c;如果公司用&#xff0c;就要用企业电子邮箱了。申请企业电子邮箱&#xff0c;注册3年用6年&#xff0c;注册5年用10年&#xff0c;这是在网上看到TOM企业邮箱的优惠&#xff0c;以下是企业邮箱总结。 TOM企…