多标签分类任务
总览¶
在网络世界中,有效过滤和管理恶意评论是一项挑战。Toxic Comment Classification Challenge 项目致力于利用深度学习技术解决这一难题。该项目源于 Kaggle 竞赛,旨在开发能够精准识别在线对话中有毒评论的算法。项目利用深度学习技术和预训练的词向量,对评论进行多类别分类,例如识别威胁、淫秽色情、侮辱和基于身份的仇恨言论等,从而帮助减少网络环境中的不良影响。
数据¶
代码¶
Python | |
---|---|
数据集加载¶
Python | |
---|---|
使用 train_test_split
方法划分出训练集和测试集。
Python | |
---|---|
数据预处理¶
Python | |
---|---|
Python | |
---|---|
Python | |
---|---|
Python | |
---|---|
Python | |
---|---|
train_and_test | |
---|---|
多标签分类和单标签分类在数据集格式上大同小异
- 在单标签分类中,单个样本的
labels
为一个整数 - 在多标签分类中,单个样本的
labels
为类别个数长度个0或1组成列表。其中1代表该样本属于该类,否则不属于该类。
预训练模型¶
Python | |
---|---|
Python | |
---|---|
如果任务类型为多任务标签,那么需要加载预训练模型时指定 problem_type
为 multi_label_classification
。
Note
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized:
['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
- 将网络参数都冻结,
- 将
classifier.requires_grad
设置为True
只允许分类头参与梯度更新。
评价指标¶
Python | |
---|---|
训练¶
Python | |
---|---|
Python | |
---|---|
Python | |
---|---|
Python | |
---|---|
模型训练结果¶
Epoch | Training Loss | Validation Loss | Accuracy | F1 |
---|---|---|---|---|
1 | 0.046600 | 0.042239 | 0.984312 | 0.882476 |
2 | 0.036300 | 0.041644 | 0.984067 | 0.881441 |
3 | 0.029900 | 0.043375 | 0.983816 | 0.883072 |
4 | 0.024800 | 0.048329 | 0.983503 | 0.882226 |
5 | 0.021100 | 0.049663 | 0.983153 | 0.881117 |
推理¶
Python | |
---|---|
Python | |
---|---|
经过后续的数据集格式整理,提交评测后: