tensorflow 2.0+ 基于预训练BERT模型的多标签文本分类_tensorflow2 bert文本分类-程序员宅基地

技术标签: tensorflow  nlp  深度学习  人工智能  自然语言处理  bert  

在多标签分类的问题中,模型的训练集由实例组成,每个实例可以被分配多个类别,表示为一组目标标签,最终任务是准确预测测试数据的标签集。例如:

  • 文本可以同时涉及宗教、政治、金融或教育,也可以不属于其中任何一个。
  • 电影按其抽象内容可分为动作片、喜剧片和浪漫片。电影有可能属于多种类型,比如周星驰的《大话西游》,同时属于浪漫片与喜剧片。


多标签和多分类有什么区别?

在多分类中,每个样本被分配到一个且只有一个标签:水果可以是苹果或梨,但不能同时是苹果和梨。让我们考虑三个类别的例子C = [“Sun,”Moon,Cloud“]。在多分类中,每个样本只可以属于其中一个C类;在多标签中,每个样本可以属于一个或多个类。

在这里插入图片描述



数据集

在这篇文章, 我们将使用Kaggle的 Toxic Comment Classification Challenge数据集,该数据集由大量维基百科评论组成,这些评论已经被专业评估者标记为恶意行为。恶意的类型为:

toxic(恶意),severetoxic(穷凶极恶),obscene(猥琐),threat(恐吓),insult(侮辱),identityhate(种族歧视)

例:

“Hi! I am back again! Last warning! Stop undoing my edits or die!”

被标记为[1,0,0,1,0,0]。意思是它同时属于toxic 和threat。



BERT简介


2018 年 10 月,Google 发布了一种名为 BERT 的新语言表示模型, 它代表来自Transformers的双向编码器表示。BERT建立在预训练上下文表示模型—半监督序列学习、生成预训练、ELMo和ULMFit 的基础上。但是,与之前的模型不同,BERT 是第一个深度双向、无监督的语言表示形式。仅使用纯文本语料库(维基百科)进行预训练。

预训练表示可以分为无上下文模型与上下文模型:

  1. 无上下文模型(如 word2vec 或 GloVe)为词汇中的每个单词生成单个单词嵌入表示形式,例如,单词”bank“在“bank account” 和“bank of the river” 中有相同的单词嵌入表示。
  2. 相反,上下文模型生成基于句子中其他单词的每个单词的表示形式。上下文表示可以进一步区分为单向的或双向的,例如,句子“I accessed the bank account”,单向上下文模型将是基于“ I accessed the ”来表示“bank”,而不是后面的“ account账户 ”。然而,BERT同时使用它的前问和后文- “ I accessed the … account ”来表示“bank” - 从深度神经网络的底部开始,使其深度双向。

基于双向 LSTM 的语言模型会训练一个标准的从左到右的语言模型,并训练从右到左(反向)的语言模型。该模型可预测后续单词(如 ELMO 中的单词)中的先前单词,在ELMo中,前向语言模型和后向语言模型都分别是一个LSTM模型,关键的区别在于,LSTM都不会同时考虑前一个和后一个令牌。



为什么 BERT 优于其他双向模型?


直观地说,深度双向模型比从左到右模型或从左到右和从右到左模型的串联更为严格。遗憾的是,标准条件语言模型只能从左到右或从右到左进行训练,因为双向调节将允许每个单词在多层上下文中间接地“看到自己”。

为了解决这个问题,Bert使用“掩蔽”技术(MASKING)在输入中屏蔽一些单词,然后双向调节每个单词以预测被屏蔽的单词。例如:

在这里插入图片描述
在这里插入图片描述


BERT 还学会根据一个非常简单的任务对句子之间的关系进行建模, 该任务可以从任何文本语料库生成: 给定两个句子 A 和 B,B 是语料库中 A 之后的实际下一句,还是一个随机句子?例如:

在这里插入图片描述


多分类的问题我在上一篇文章中已经详细讨论过: tensorflow 2.0+ 基于BERT模型的文本分类 。本文将重点研究BERT在多标签文本分类中的应用。因此,我们只需修改相应代码,使其适合多标签方案。



使用TensorFlow 2.0+ keras API微调BERT

现在,我们需要在所有样本中应用 BERT tokenizer 。我们将token映射到词嵌入。这可以通过encode_plus完成。

def convert_example_to_feature(review):
  
  # combine step for tokenization, WordPiece vector mapping, adding special tokens as well as truncating reviews longer than the max length
    return tokenizer.encode_plus(review, 
                add_special_tokens = True, # add [CLS], [SEP]
                max_length = max_length, # max length of the text that can go to BERT
                pad_to_max_length = True, # add [PAD] tokens
                return_attention_mask = True, # add attention mask to not focus on pad tokens
                truncation=True
              )
# map to the expected input to TFBertForSequenceClassification, see here 
def map_example_to_dict(input_ids, attention_masks, token_type_ids, label):
    return {
    
      "input_ids": input_ids,
      "token_type_ids": token_type_ids,
      "attention_mask": attention_masks,
  }, label

def encode_examples(ds, limit=-1):
    # prepare list, so that we can build up final TensorFlow dataset from slices.
    input_ids_list = []
    token_type_ids_list = []
    attention_mask_list = []
    label_list = []
    if (limit > 0):
        ds = ds.take(limit)
    
    for (i, row) in enumerate(ds.values):
#     for index, row in ds.iterrows():
#         review = row["text"]
#         label = row["y"]
        review = row[1]
        label = list(row[2:])
        bert_input = convert_example_to_feature(review)
  
        input_ids_list.append(bert_input['input_ids'])
        token_type_ids_list.append(bert_input['token_type_ids'])
        attention_mask_list.append(bert_input['attention_mask'])
        label_list.append(label)
    return tf.data.Dataset.from_tensor_slices((input_ids_list, attention_mask_list, token_type_ids_list, label_list)).map(map_example_to_dict)


我们可以使用以下函数对数据集进行编码:

# train dataset
ds_train_encoded = encode_examples(train_data).shuffle(10000).batch(batch_size)
# val dataset
ds_val_encoded = encode_examples(val_data).batch(batch_size)
# test dataset
ds_test_encoded = encode_examples(test_data).batch(batch_size)

创建模型

from transformers import TFBertPreTrainedModel,TFBertMainLayer
import tensorflow as tf
from transformers.modeling_tf_utils import (
    TFQuestionAnsweringLoss,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)

class TFBertForMultilabelClassification(TFBertPreTrainedModel):

    def __init__(self, config, *inputs, **kwargs):
        super(TFBertForMultilabelClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.bert = TFBertMainLayer(config, name='bert')
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='classifier',
                                                activation='sigmoid')#--------------------- sigmoid激活函数

    def call(self, inputs, **kwargs):
        outputs = self.bert(inputs, **kwargs)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
        logits = self.classifier(pooled_output)
        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        return outputs  # logits, (hidden_states), (attentions)
        

编译与训练模型

# model initialization
model = TFBertForMultilabelClassification.from_pretrained(model_path, num_labels=6)#------------6个标签
# optimizer Adam recommended
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,epsilon=1e-08, clipnorm=1)
# we do not have one-hot vectors, we can use sparce categorical cross entropy and accuracy
loss = tf.keras.losses.BinaryCrossentropy()#-----------------------------------binary_crossentropy 损失函数
metric = tf.keras.metrics.CategoricalAccuracy()
model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

# fit model
bert_history = model.fit(ds_train_encoded, epochs=number_of_epochs, validation_data=ds_val_encoded)

计算每一个标签AUC

def measure_auc(label,pred):
  auc = [roc_auc_score(label[:,i],pred[:,i]) for i in list(range(6))]
  return pd.DataFrame({
    "label_name":["toxic","severe_toxic","obscene","threat","insult","identity_hate"],"auc":auc})

pred=model.predict(ds_val_encoded)[0]#------------------------------------------------predict dataset
df_auc = measure_auc(val_data.iloc[:,2:].astype(np.float32).values,pred)
print("val set mean column auc:",df_auc)

以下是2个epochs的训练结果:

Epoch 1/2
4488/4488 [==============================] - 3922s 874ms/step - loss: 0.0500 - categorical_accuracy: 0.9701 - val_loss: 0.0388 - val_categorical_accuracy: 0.9938
Epoch 2/2
4488/4488 [==============================] - 3927s 875ms/step - loss: 0.0333 - categorical_accuracy: 0.9796 - val_loss: 0.0408 - val_categorical_accuracy: 0.9918

val set mean column auc:       label_name       auc
0          toxic  0.986974
1   severe_toxic  0.991380
2        obscene  0.992404
3         threat  0.993322
4         insult  0.988814
5  identity_hate  0.992388

可以看到,训练集正确率99.38%,验证集正确率99.18%,还有下面每一个标签的auc值

0 label_name auc
1 toxic 0.987
2 severe_toxic 0.991
3 obscene 0.992
4 threat 0.993
5 insult 0.989
6 identity_hate 0.992

由于类别严重不平衡,auc值(ROC曲线)并不能完全衡量预测效果,可以用precision-recall curve进行评估,详细请参考Precision-Recall



代码与数据


数据

链接:https://pan.baidu.com/s/17BHBSXdtJOUBG402tmWWBw
提取码:kces

bert模型

https://huggingface.co/models : bert-base-uncased > List all files in model

代码

https://github.com/NZbryan/NLP_bert/blob/master/tf2.0_bert_emb_en_MultiLabel.py



运行环境

linux: CentOS Linux release 7.6.1810

python: Python 3.6.10

packages:

tensorflow==2.3.0
transformers==3.02
pandas==1.1.0
scikit-learn==0.22.2

由于数据量较大,训练时间长,建议在GPU下运行,或者到colab去跑。



多标签分类注意事项

​ 1.不要使用softmax

​ 2.使用sigmoid函数作为最后输出层

​ 3.使用binary_crossentropy 作为损失函数

​ 4.使用predict对测试集进行评估







参考:

https://towardsdatascience.com/building-a-multi-label-text-classifier-using-bert-and-tensorflow-f188e0ecdc5d

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/xiaoniu0991/article/details/108737333

智能推荐

HTML5 Web SQL 数据库_方式准则的定义-程序员宅基地

文章浏览阅读1k次。1、HTML5 Web SQL 数据库 Web SQL 数据库 API 并不是 HTML5 规范的一部分,但是它是一个独立的规范,引入了一组使用 SQL 操作客户端数据库的 APIs。如果你是一个 Web 后端程序员,应该很容易理解 SQL 的操作。Web SQL 数据库可以在最新版的 Safari, Chrome 和 Opera 浏览器中工作。2、核心方法 以下是规范中定义的三个_方式准则的定义

spring Boot 中使用线程池异步执行多个定时任务_springboot启动后自动开启多个线程程序-程序员宅基地

文章浏览阅读4.1k次,点赞2次,收藏6次。spring Boot 中使用线程池异步执行多个定时任务在启动类中添加注解@EnableScheduling配置自定义线程池在启动类中添加注解@EnableScheduling第一步添加注解,这样才会使定时任务启动配置自定义线程池@Configurationpublic class ScheduleConfiguration implements SchedulingConfigurer..._springboot启动后自动开启多个线程程序

Maven编译打包项目 mvn clean install报错ERROR_mvn clean install有errors-程序员宅基地

文章浏览阅读1.1k次。在项目的target文件夹下把之前"mvn clean package"生成的压缩包(我的是jar包)删掉重新执行"mvn clean package"再执行"mvn clean install"即可_mvn clean install有errors

navacate连接不上mysql_navicat连接mysql失败怎么办-程序员宅基地

文章浏览阅读974次。Navicat连接mysql数据库时,不断报1405错误,下面是针对这个的解决办法:MySQL服务器正在运行,停止它。如果是作为Windows服务运行的服务器,进入计算机管理--->服务和应用程序------>服务。如果服务器不是作为服务而运行的,可能需要使用任务管理器来强制停止它。创建1个文本文件(此处命名为mysql-init.txt),并将下述命令置于单一行中:SET PASSW..._nvarchar链接不上数据库

Python的requests参数及方法_python requests 参数-程序员宅基地

文章浏览阅读2.2k次。Python的requests模块是一个常用的HTTP库,用于发送HTTP请求和处理响应。_python requests 参数

近5年典型的的APT攻击事件_2010谷歌网络被极光黑客攻击-程序员宅基地

文章浏览阅读2.7w次,点赞7次,收藏50次。APT攻击APT攻击是近几年来出现的一种高级攻击,具有难检测、持续时间长和攻击目标明确等特征。本文中,整理了近年来比较典型的几个APT攻击,并其攻击过程做了分析(为了加深自己对APT攻击的理解和学习)Google极光攻击2010年的Google Aurora(极光)攻击是一个十分著名的APT攻击。Google的一名雇员点击即时消息中的一条恶意链接,引发了一系列事件导致这个搜_2010谷歌网络被极光黑客攻击

随便推点

微信小程序api视频课程-定时器-setTimeout的使用_微信小程序 settimeout 向上层传值-程序员宅基地

文章浏览阅读1.1k次。JS代码 /** * 生命周期函数--监听页面加载 */ onLoad: function (options) { setTimeout( function(){ wx.showToast({ title: '黄菊华老师', }) },2000 ) },说明该代码只执行一次..._微信小程序 settimeout 向上层传值

uploadify2.1.4如何能使按钮显示中文-程序员宅基地

文章浏览阅读48次。uploadify2.1.4如何能使按钮显示中文博客分类:uploadify网上关于这段话的搜索恐怕是太多了。方法多也试过了不知怎么,反正不行。最终自己想办法给解决了。当然首先还是要有fla源码。直接去管网就可以下载。[url]http://www.uploadify.com/wp-content/uploads/uploadify-v2.1.4...

戴尔服务器安装VMware ESXI6.7.0教程(U盘安装)_vmware-vcsa-all-6.7.0-8169922.iso-程序员宅基地

文章浏览阅读9.6k次,点赞5次,收藏36次。戴尔服务器安装VMware ESXI6.7.0教程(U盘安装)一、前期准备1、下载镜像下载esxi6.7镜像:VMware-VMvisor-Installer-6.7.0-8169922.x86_64.iso这里推荐到戴尔官网下载,Baidu搜索“戴尔驱动下载”,选择进入官网,根据提示输入服务器型号搜索适用于该型号服务器的所有驱动下一步选择具体类型的驱动选择一项下载即可待下载完成后打开软碟通(UItraISO),在“文件”选项中打开刚才下载好的镜像文件然后选择启动_vmware-vcsa-all-6.7.0-8169922.iso

百度语音技术永久免费的语音自动转字幕介绍 -程序员宅基地

文章浏览阅读2k次。百度语音技术永久免费的语音自动转字幕介绍基于百度语音技术,识别率97%无时长限制,无文件大小限制永久免费,简单,易用,速度快支持中文,英文,粤语永久免费的语音转字幕网站: http://thinktothings.com视频介绍 https://www.bilibili.com/video/av42750807 ...

Dyninst学习笔记-程序员宅基地

文章浏览阅读7.6k次,点赞2次,收藏9次。Instrumentation是一种直接修改程序二进制文件的方法。其可以用于程序的调试,优化,安全等等。对这个词一般的翻译是“插桩”,但这更多使用于软件测试领域。【找一些相关的例子】Dyninst可以动态或静态的修改程序的二进制代码。动态修改是在目标进程运行时插入代码(dynamic binary instrumentation)。静态修改则是直接向二进制文件插入代码(static b_dyninst

在服务器上部署asp网站,部署asp网站到云服务器-程序员宅基地

文章浏览阅读2.9k次。部署asp网站到云服务器 内容精选换一换通常情况下,需要结合客户的实际业务环境和具体需求进行业务改造评估,建议您进行服务咨询。这里仅描述一些通用的策略供您参考,主要分如下几方面进行考虑:业务迁移不管您的业务是否已经上线华为云,业务迁移的策略是一致的。建议您将时延敏感型,有快速批量就近部署需求的业务迁移至IEC;保留数据量大,且需要长期稳定运行的业务在中心云上。迁移方法请参见如何计算隔离独享计算资源..._nas asp网站