As with most of our from-scratch implementations, Section 9.5 was designed to provide insight into how each component works. But when you are using RNNs every day, or writing production code, you will want to rely more on libraries that cut down on both implementation time (by supplying library code for common models and functions) and computation time (by optimizing these library implementations). This section will show you how to implement the same language model more efficiently using the high-level API provided by your deep learning framework. As before, we begin by loading The Time Machine dataset.
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l
9.6.1. Defining the Model
We define the following class using the RNN implemented by high-level APIs.
class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = nn.RNN(num_inputs, num_hiddens)

    def forward(self, inputs, H=None):
        return self.rnn(inputs, H)
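As a quick sanity check, the sketch below verifies the shape conventions of the wrapped nn.RNN: inputs have shape (num_steps, batch_size, num_inputs), outputs have shape (num_steps, batch_size, num_hiddens), and the returned state has shape (number of hidden layers, batch_size, num_hiddens). The sizes used here are illustrative assumptions, not from the book.

# A minimal shape check for the RNN wrapper above.
# Sizes (num_steps=5, batch_size=2, num_inputs=28, num_hiddens=32)
# are illustrative assumptions.
X = torch.randn(5, 2, 28)  # (num_steps, batch_size, num_inputs)
rnn = RNN(num_inputs=28, num_hiddens=32)
outputs, H = rnn(X)
print(outputs.shape)  # torch.Size([5, 2, 32])
print(H.shape)        # torch.Size([1, 2, 32])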
Specifically, to initialize the hidden state, we invoke the member method begin_state. This returns a list that contains an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list will also contain other information.
class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = rnn.RNN(num_hiddens)

    def forward(self, inputs, H=None):
        if H is None:
            H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
        outputs, (H, ) = self.rnn(inputs, (H, ))
        return outputs, H
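PyTorch's nn.RNN has no begin_state method; passing H=None simply lets the library start from zeros. Should you need the initial state explicitly, a hedged equivalent is to allocate the zero tensor by hand with the shape described above. This is a sketch with illustrative sizes, not the library API:

# Hand-rolled analogue of begin_state for the PyTorch wrapper
# (illustrative sizes; not part of the book's pipeline).
num_layers, batch_size, num_hiddens = 1, 2, 32
H0 = torch.zeros(num_layers, batch_size, num_hiddens)
X = torch.randn(5, batch_size, 28)  # (num_steps, batch_size, num_inputs)
outputs, H = RNN(num_inputs=28, num_hiddens=32)(X, H0)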
As of this writing, Flax does not provide an RNNCell for a concise implementation of vanilla RNNs. More advanced RNN variants, such as LSTMs and GRUs, are available in the Flax linen API.
class RNN(nn.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    num_hiddens: int

    @nn.compact
    def __call__(self, inputs, H=None):
        raise NotImplementedError
class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = tf.keras.layers.SimpleRNN(
            num_hiddens, return_sequences=True, return_state=True,
            time_major=True)

    def forward(self, inputs, H=None):
        outputs, H = self.rnn(inputs, H)
        return outputs, H
Inheriting from the RNNLMScratch class in Section 9.5, the following RNNLM class defines a complete RNN-based language model. Note that we need to create a separate fully connected output layer.
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        self.linear = nn.LazyLinear(self.vocab_size)

    def output_layer(self, hiddens):
        return self.linear(hiddens).swapaxes(0, 1)
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        self.linear = nn.Dense(self.vocab_size, flatten=False)
        self.initialize()

    def output_layer(self, hiddens):
        return self.linear(hiddens).swapaxes(0, 1)
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    training: bool = True

    def setup(self):
        self.linear = nn.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        return self.linear(hiddens).swapaxes(0, 1)

    def forward(self, X, state=None):
        embs = self.one_hot(X)
        rnn_outputs, _ = self.rnn(embs, state, self.training)
        return self.output_layer(rnn_outputs)
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        self.linear = tf.keras.layers.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        return tf.transpose(self.linear(hiddens), (1, 0, 2))
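Why the swapaxes (or tf.transpose)? The RNN emits hidden states with the time step as the leading axis, whereas the loss computation in RNNLMScratch expects batch-first logits. The following PyTorch shape trace is a sketch with illustrative sizes (vocab_size=28 is an assumption):

# Illustrative shape trace through output_layer (sizes are assumptions).
hiddens = torch.randn(5, 2, 32)          # (num_steps, batch_size, num_hiddens)
linear = nn.LazyLinear(28)               # vocab_size = 28 for illustration
logits = linear(hiddens).swapaxes(0, 1)  # -> (batch_size, num_steps, vocab_size)
print(logits.shape)                      # torch.Size([2, 5, 28])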
9.6.2. Training and Predicting
Before training the model, let's make a prediction with a model initialized with random weights. Given that we have not trained the network, it will generate nonsensical predictions.
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasgggggggggggggggggggg'
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasxlxlxlxlxlxlxlxlxlxl'
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
'it hasnvjdtagwbcsxvcjwuyby'
Next, we train our model, leveraging the high-level API.
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)
with d2l.try_gpu():
    trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
    trainer.fit(model, data)
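The gradient_clip_val=1 argument guards against exploding gradients during backpropagation through time. The self-contained PyTorch sketch below shows roughly what such clipping amounts to inside a training step; the toy model and data are our own assumptions, and d2l.Trainer's internals may differ in detail.

import torch
from torch import nn

# Rescale gradients so their global norm never exceeds the threshold.
# Toy model and data are illustrative, not part of the book's pipeline.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
X, y = torch.randn(8, 4), torch.randn(8, 2)
optimizer.zero_grad()
loss = nn.functional.mse_loss(model(X), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip
optimizer.step()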
Compared with Section 9.5, this model achieves comparable perplexity, but runs faster thanks to the optimized implementations. As before, we can generate predicted tokens following the specified prefix string.
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has and the time trave '
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has and the thi baid th'
model.predict('it has', 20, data.vocab)
'it has our in the time tim'
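To see what predict is doing under the hood, here is a hedged sketch of warm-up followed by greedy character-level decoding, in the spirit of RNNLMScratch.predict from Section 9.5; the attribute names model.one_hot, model.rnn, and model.output_layer, and the Vocab interface, are assumptions about that class.

def greedy_predict(prefix, num_preds, model, vocab):
    # Warm up the hidden state on the prefix, then repeatedly feed back
    # the argmax token (a sketch; Section 9.5's implementation may differ).
    state, outputs = None, [vocab[prefix[0]]]
    for i in range(len(prefix) + num_preds - 1):
        X = torch.tensor([[outputs[-1]]])
        embs = model.one_hot(X)              # (num_steps=1, 1, vocab_size)
        rnn_outputs, state = model.rnn(embs, state)
        if i < len(prefix) - 1:              # warm-up: force the prefix tokens
            outputs.append(vocab[prefix[i + 1]])
        else:                                # decoding: take the argmax
            Y = model.output_layer(rnn_outputs)
            outputs.append(int(Y.argmax(axis=2).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

# e.g. greedy_predict('it has', 20, model, data.vocab)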
9.6.3. Summary
High-level APIs in deep learning frameworks provide implementations of standard RNNs. These libraries help you to avoid wasting time reimplementing standard models. Moreover, framework implementations are often highly optimized, leading to significant (computational) performance gains over implementations from scratch.
9.6.4. Exercises
1. Can you make the RNN model overfit using the high-level APIs?
2. Implement the autoregressive model of Section 9.1 using an RNN.