This article describes how to implement a network in Keras that uses BERT as a paragraph encoder for long-text classification. It may serve as a useful reference for anyone tackling a similar problem.

Problem description

I am working on a long-text classification task in which each document contains more than 10,000 words. I plan to use BERT as a paragraph encoder and then feed the paragraph embeddings into a BiLSTM step by step. The network is as follows:

bert layer: (max_paragraph_len, paragraph_embedding_size)

lstm layer: ???

output layer: (batch_size, classification_size)

How can I implement this with Keras? I am using keras-bert's load_trained_model_from_checkpoint to load the BERT model:

# load_trained_model_from_checkpoint comes from the keras-bert package
from keras_bert import load_trained_model_from_checkpoint

bert_model = load_trained_model_from_checkpoint(
    config_path,
    model_path,
    training=False,
    use_adapter=True,
    # fine-tune only the adapter and layer-norm weights of each encoder block
    trainable=['Encoder-{}-MultiHeadSelfAttention-Adapter'.format(i + 1) for i in range(layer_num)] +
        ['Encoder-{}-FeedForward-Adapter'.format(i + 1) for i in range(layer_num)] +
        ['Encoder-{}-MultiHeadSelfAttention-Norm'.format(i + 1) for i in range(layer_num)] +
        ['Encoder-{}-FeedForward-Norm'.format(i + 1) for i in range(layer_num)],
)
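To make the intended data flow concrete, here is a minimal numpy sketch of the shapes involved at each stage. All sizes are illustrative, and the random arrays merely stand in for the outputs of BERT and the BiLSTM; only the shape bookkeeping is the point:

```python
import numpy as np

batch_size, max_paragraphs, max_paragraph_len = 4, 10, 128
paragraph_embedding_size, lstm_units, num_classes = 768, 256, 5
rng = np.random.default_rng(0)

# Input: token ids, one row of max_paragraph_len tokens per paragraph.
token_ids = rng.integers(0, 30000, size=(batch_size, max_paragraphs, max_paragraph_len))

# 1) BERT stand-in: each paragraph's tokens -> one paragraph embedding.
paragraph_embs = rng.normal(size=(batch_size, max_paragraphs, paragraph_embedding_size))

# 2) BiLSTM stand-in: the paragraph sequence -> one document vector
#    (final forward and backward states, concatenated).
doc_vec = rng.normal(size=(batch_size, 2 * lstm_units))

# 3) Output layer: linear projection followed by softmax.
W = rng.normal(size=(2 * lstm_units, num_classes))
logits = doc_vec @ W
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
print(probs.shape)  # (4, 5)
```

So the answer to the "lstm layer: ???" line above is that the BiLSTM collapses (batch_size, max_paragraphs, paragraph_embedding_size) down to (batch_size, 2 * lstm_units) before the output layer.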

Recommended answer

I believe you can check the following article. The author shows how to load a pre-trained BERT model, wrap it in a Keras layer, and use it in a customized deep neural network. First install bert-for-tf2, a TensorFlow 2.0 Keras implementation of google-research/bert:

pip install bert-for-tf2

Then run:

import bert
import os

def createBertLayer():
    global bert_layer

    # modelBertDir is the folder containing the downloaded BERT checkpoints
    bertDir = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")

    bert_params = bert.params_from_pretrained_ckpt(bertDir)

    bert_layer = bert.BertModelLayer.from_params(bert_params, name="bert")

    # freeze everything except the adapter and layer-norm weights
    bert_layer.apply_adapter_freeze()

def loadBertCheckpoint():
    modelsFolder = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")
    checkpointName = os.path.join(modelsFolder, "bert_model.ckpt")

    bert.load_stock_weights(bert_layer, checkpointName)
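One hypothetical way to wire the pieces together (not taken from the linked article): wrap a paragraph encoder in TimeDistributed so it runs once per paragraph, then run a Bidirectional LSTM over the resulting sequence of paragraph embeddings. In the sketch below, the stand-in encoder (an Embedding followed by average pooling) is only a placeholder for the loaded bert_layer, and all sizes are illustrative:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

max_paragraphs = 8        # paragraphs per document (assumed)
max_paragraph_len = 64    # tokens per paragraph (assumed)
vocab_size = 1000
emb_size = 32             # stands in for BERT's hidden size (e.g. 768)
num_classes = 3

# Placeholder paragraph encoder: token ids -> one vector per paragraph.
# In practice this would be the bert_layer followed by pooling.
tokens = layers.Input(shape=(max_paragraph_len,), dtype="int32")
x = layers.Embedding(vocab_size, emb_size)(tokens)
para_vec = layers.GlobalAveragePooling1D()(x)
paragraph_encoder = Model(tokens, para_vec, name="paragraph_encoder")

# Document model: encode each paragraph, then run a BiLSTM over them.
doc_in = layers.Input(shape=(max_paragraphs, max_paragraph_len), dtype="int32")
para_embs = layers.TimeDistributed(paragraph_encoder)(doc_in)  # (batch, max_paragraphs, emb_size)
h = layers.Bidirectional(layers.LSTM(16))(para_embs)           # (batch, 2 * 16)
out = layers.Dense(num_classes, activation="softmax")(h)
model = Model(doc_in, out)

docs = np.random.randint(0, vocab_size, size=(2, max_paragraphs, max_paragraph_len))
preds = model.predict(docs, verbose=0)
print(preds.shape)  # (2, 3)
```

Swapping the placeholder encoder for the real bert_layer would also require feeding BERT's expected inputs (token ids plus any segment ids it needs) and pooling its per-token output down to one vector per paragraph.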

