在本章前面的部分中,我們?yōu)?SNLI 數(shù)據(jù)集上的自然語言推理任務(如第 16.4 節(jié)所述)設計了一個基于注意力的架構(gòu)(第16.5節(jié))。現(xiàn)在我們通過微調(diào) BERT 重新審視這個任務。正如16.6 節(jié)所討論的 ,自然語言推理是一個序列級文本對分類問題,微調(diào) BERT 只需要一個額外的基于 MLP 的架構(gòu),如圖 16.7.1所示。
圖 16.7.1本節(jié)將預訓練的 BERT 提供給基于 MLP 的自然語言推理架構(gòu)。
在本節(jié)中,我們將下載預訓練的小型 BERT 版本,然后對其進行微調(diào)以在 SNLI 數(shù)據(jù)集上進行自然語言推理。
import json import multiprocessing import os import torch from torch import nn from d2l import torch as d2l
import json import multiprocessing import os from mxnet import gluon, np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np()
16.7.1。加載預訓練的 BERT
我們已經(jīng)在第 15.9 節(jié)和第 15.10 節(jié)中解釋了如何在 WikiText-2 數(shù)據(jù)集上預訓練 BERT (請注意,原始 BERT 模型是在更大的語料庫上預訓練的)。如15.10 節(jié)所述,原始 BERT 模型有數(shù)億個參數(shù)。在下文中,我們提供了兩個版本的預訓練 BERT:“bert.base”與需要大量計算資源進行微調(diào)的原始 BERT 基礎模型差不多大,而“bert.small”是一個小版本方便演示。
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip', '225d66f04cae318b841a13d32af3acc165f253ac') d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip', 'c72329e68a732bef0452e4b96a1c341c8910f81f')
d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip', '7b3820b35da691042e5d34c0971ac3edbd80d3f4') d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip', 'a4e718a47137ccd1809c9107ab4f5edd317bae2c')
預訓練的 BERT 模型都包含一個定義詞匯集的“vocab.json”文件和一個預訓練參數(shù)的“pretrained.params”文件。我們實現(xiàn)以下load_pretrained_model 函數(shù)來加載預訓練的 BERT 參數(shù)。
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, max_len, devices): data_dir = d2l.download_extract(pretrained_model) # Define an empty vocabulary to load the predefined vocabulary vocab = d2l.Vocab() vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json'))) vocab.token_to_idx = {token: idx for idx, token in enumerate( vocab.idx_to_token)} bert = d2l.BERTModel( len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens, num_heads=4, num_blks=2, dropout=0.2, max_len=max_len) # Load pretrained BERT parameters bert.load_state_dict(torch.load(os.path.join(data_dir, 'pretrained.params'))) return bert, vocab
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, max_len, devices): data_dir = d2l.download_extract(pretrained_model) # Define an empty vocabulary to load the predefined vocabulary vocab = d2l.Vocab() vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json'))) vocab.token_to_idx = {token: idx for idx, token in enumerate( vocab.idx_to_token)} bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads, num_blks, dropout, max_len) # Load pretrained BERT parameters bert.load_parameters(os.path.join(data_dir, 'pretrained.params'), ctx=devices) return bert, vocab
為了便于在大多數(shù)機器上進行演示,我們將在本節(jié)中加載和微調(diào)預訓練 BERT 的小型版本(“bert.small”)。在練習中,我們將展示如何微調(diào)更大的“bert.base”以顯著提高測試準確性。
devices = d2l.try_all_gpus() bert, vocab = load_pretrained_model( 'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4, num_blks=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...
devices = d2l.try_all_gpus() bert, vocab = load_pretrained_model( 'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4, num_blks=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...
16.7.2。微調(diào) BERT 的數(shù)據(jù)集
對于 SNLI 數(shù)據(jù)集上的下游任務自然語言推理,我們定義了一個自定義的數(shù)據(jù)集類SNLIBERTDataset。在每個示例中,前提和假設形成一對文本序列,并被打包到一個 BERT 輸入序列中,如圖 16.6.2所示。回想第 15.8.4 節(jié) ,段 ID 用于區(qū)分 BERT 輸入序列中的前提和假設。對于 BERT 輸入序列 ( max_len) 的預定義最大長度,輸入文本對中較長者的最后一個標記會不斷被刪除,直到max_len滿足為止。為了加速生成用于微調(diào) BERT 的 SNLI 數(shù)據(jù)集,我們使用 4 個工作進程并行生成訓練或測試示例。
class SNLIBERTDataset(torch.utils.data.Dataset): def __init__(self, dataset, max_len, vocab=None): all_premise_hypothesis_tokens = [[ p_tokens, h_tokens] for p_tokens, h_tokens in zip( *[d2l.tokenize([s.lower() for s in sentences]) for sentences in dataset[:2]])] self.labels = torch.tensor(dataset[2]) self.vocab = vocab self.max_len = max_len (self.all_token_ids, self.all_segments, self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens) print('read ' + str(len(self.all_token_ids)) + ' examples') def _preprocess(self, all_premise_hypothesis_tokens): pool = multiprocessing.Pool(4) # Use 4 worker processes out = pool.map(self._mp_worker, all_premise_hypothesis_tokens) all_token_ids = [ token_ids for token_ids, segments, valid_len in out] all_segments = [segments for token_ids, segments, valid_len in out] valid_lens = [valid_len for token_ids, segments, valid_len in out] return (torch.tensor(all_token_ids, dtype=torch.long), torch.tensor(all_segments, dtype=torch.long), torch.tensor(valid_lens)) def _mp_worker(self, premise_hypothesis_tokens): p_tokens, h_tokens = premise_hypothesis_tokens self._truncate_pair_of_tokens(p_tokens, h_tokens) tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens) token_ids = self.vocab[tokens] + [self.vocab['']] * (self.max_len - len(tokens)) segments = segments + [0] * (self.max_len - len(segments)) valid_len = len(tokens) return token_ids, segments, valid_len def _truncate_pair_of_tokens(self, p_tokens, h_tokens): # Reserve slots for '', '', and '' tokens for the BERT # input while len(p_tokens) + len(h_tokens) > self.max_len - 3: if len(p_tokens) > len(h_tokens): p_tokens.pop() else: h_tokens.pop() def __getitem__(self, idx): return (self.all_token_ids[idx], self.all_segments[idx], self.valid_lens[idx]), self.labels[idx] def __len__(self): return len(self.all_token_ids)
class SNLIBERTDataset(gluon.data.Dataset): def __init__(self, dataset, max_len, vocab=None): all_premise_hypothesis_tokens = [[ p_tokens, h_tokens] for p_tokens, h_tokens in zip( *[d2l.tokenize([s.lower() for s in sentences]) for sentences in dataset[:2]])] self.labels = np.array(dataset[2]) self.vocab = vocab self.max_len = max_len (self.all_token_ids, self.all_segments, self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens) print('read ' + str(len(self.all_token_ids)) + ' examples') def _preprocess(self, all_premise_hypothesis_tokens): pool = multiprocessing.Pool(4) # Use 4 worker processes out = pool.map(self._mp_worker, all_premise_hypothesis_tokens) all_token_ids = [ token_ids for token_ids, segments, valid_len in out] all_segments = [segments for token_ids, segments, valid_len in out] valid_lens = [valid_len for token_ids, segments, valid_len in out] return (np.array(all_token_ids, dtype='int32'), np.array(all_segments, dtype='int32'), np.array(valid_lens)) def _mp_worker(self, premise_hypothesis_tokens): p_tokens, h_tokens = premise_hypothesis_tokens self._truncate_pair_of_tokens(p_tokens, h_tokens) tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens) token_ids = self.vocab[tokens] + [self.vocab['']] * (self.max_len - len(tokens)) segments = segments + [0] * (self.max_len - len(segments)) valid_len = len(tokens) return token_ids, segments, valid_len def _truncate_pair_of_tokens(self, p_tokens, h_tokens): # Reserve slots for '', '', and '' tokens for the BERT # input while len(p_tokens) + len(h_tokens) > self.max_len - 3: if len(p_tokens) > len(h_tokens): p_tokens.pop() else: h_tokens.pop() def __getitem__(self, idx): return (self.all_token_ids[idx], self.all_segments[idx], self.valid_lens[idx]), self.labels[idx] def __len__(self): return len(self.all_token_ids)
下載 SNLI 數(shù)據(jù)集后,我們通過實例化SNLIBERTDataset類來生成訓練和測試示例。此類示例將在自然語言推理的訓練和測試期間以小批量讀取。
# Reduce `batch_size` if there is an out of memory error. In the original BERT # model, `max_len` = 512 batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers() data_dir = d2l.download_extract('SNLI') train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab) test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab) train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers) test_iter = torch.utils.data.DataLoader(test_set, batch_size, num_workers=num_workers)
read 549367 examples read 9824 examples
# Reduce `batch_size` if there is an out of memory error. In the original BERT # model, `max_len` = 512 batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers() data_dir = d2l.download_extract('SNLI') train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab) test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab) train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers) test_iter = gluon.data.DataLoader(test_set, batch_size, num_workers=num_workers)
read 549367 examples read 9824 examples
16.7.3。微調(diào) BERT
如圖16.6.2所示,為自然語言推理微調(diào) BERT 只需要一個額外的 MLP,該 MLP 由兩個完全連接的層組成(參見下一類中的self.hidden和)。該 MLP 將特殊“”標記的 BERT 表示形式(對前提和假設的信息進行編碼)轉(zhuǎn)換為自然語言推理的三個輸出:蘊含、矛盾和中性。self.outputBERTClassifier
class BERTClassifier(nn.Module): def __init__(self, bert): super(BERTClassifier, self).__init__() self.encoder = bert.encoder self.hidden = bert.hidden self.output = nn.LazyLinear(3) def forward(self, inputs): tokens_X, segments_X, valid_lens_x = inputs encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x) return self.output(self.hidden(encoded_X[:, 0, :]))
class BERTClassifier(nn.Block): def __init__(self, bert): super(BERTClassifier, self).__init__() self.encoder = bert.encoder self.hidden = bert.hidden self.output = nn.Dense(3) def forward(self, inputs): tokens_X, segments_X, valid_lens_x = inputs encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x) return self.output(self.hidden(encoded_X[:, 0, :]))
接下來,預訓練的 BERT 模型bert被輸入到 下游應用程序的BERTClassifier實例中。net在 BERT 微調(diào)的常見實現(xiàn)中,只會net.output從頭學習附加 MLP ( ) 輸出層的參數(shù)。net.encoder預訓練的 BERT 編碼器 ( ) 和附加 MLP 的隱藏層 ( )的所有參數(shù)都net.hidden將被微調(diào)。
net = BERTClassifier(bert)
net = BERTClassifier(bert) net.output.initialize(ctx=devices)
回想一下15.8 節(jié)中類MaskLM和 NextSentencePred類在它們使用的 MLP 中都有參數(shù)。這些參數(shù)是預訓練 BERT 模型中參數(shù)bert的一部分,因此也是net. 然而,這些參數(shù)僅用于計算預訓練期間的掩碼語言建模損失和下一句預測損失。MaskLM這兩個損失函數(shù)與微調(diào)下游應用程序無關(guān),因此在微調(diào) BERT 時,在和中使用的 MLP 的參數(shù)NextSentencePred不會更新(失效)。
為了允許具有陳舊梯度的參數(shù),在的函數(shù) ignore_stale_grad=True中設置了標志 。我們使用此函數(shù)使用SNLI 的訓練集 ( ) 和測試集 ( )來訓練和評估模型。由于計算資源有限,訓練和測試的準確性可以進一步提高:我們將其討論留在練習中。stepd2l.train_batch_ch13nettrain_itertest_iter
lr, num_epochs = 1e-4, 5 trainer = torch.optim.Adam(net.parameters(), lr=lr) loss = nn.CrossEntropyLoss(reduction='none') net(next(iter(train_iter))[0]) d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
loss 0.519, train acc 0.791, test acc 0.782 9226.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]
lr, num_epochs = 1e-4, 5 trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr}) loss = gluon.loss.SoftmaxCrossEntropyLoss() d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices, d2l.split_batch_multi_inputs)
loss 0.477, train acc 0.810, test acc 0.785 4626.9 examples/sec on [gpu(0), gpu(1)]
16.7.4。概括
我們可以為下游應用微調(diào)預訓練的 BERT 模型,例如 SNLI 數(shù)據(jù)集上的自然語言推理。
在微調(diào)期間,BERT 模型成為下游應用模型的一部分。僅與預訓練損失相關(guān)的參數(shù)在微調(diào)期間不會更新。
16.7.5。練習
如果您的計算資源允許,微調(diào)一個更大的預訓練 BERT 模型,該模型與原始 BERT 基礎模型差不多大。將函數(shù)中的參數(shù)設置load_pretrained_model為:將“bert.small”替換為“bert.base”,將 、 、 和 的值分別增加到 num_hiddens=256768、3072、12ffn_num_hiddens=512和num_heads=412 num_blks=2。通過增加微調(diào)周期(并可能調(diào)整其他超參數(shù)),您能否獲得高于 0.86 的測試精度?
如何根據(jù)長度比截斷一對序列?比較這對截斷方法和類中使用的方法 SNLIBERTDataset。他們的優(yōu)缺點是什么?
-
自然語言
+關(guān)注
關(guān)注
1文章
291瀏覽量
13587 -
pytorch
+關(guān)注
關(guān)注
2文章
809瀏覽量
13742
發(fā)布評論請先 登錄
如何開始使用PyTorch進行自然語言處理
python自然語言
自然語言處理怎么最快入門_自然語言處理知識了解
自然語言入門之ESIM

PyTorch教程16.4之自然語言推理和數(shù)據(jù)集

PyTorch教程16.5之自然語言推理:使用注意力

PyTorch教程16.6之針對序列級和令牌級應用程序微調(diào)BERT

PyTorch教程16.7之自然語言推理:微調(diào)BERT

PyTorch教程-16.4。自然語言推理和數(shù)據(jù)集
PyTorch教程-16.5。自然語言推理:使用注意力
PyTorch教程-16.6. 針對序列級和令牌級應用程序微調(diào) BERT
自然語言處理的概念和應用 自然語言處理屬于人工智能嗎
ChatGPT是一個好的因果推理器嗎?

評論