在深度學(xué)習(xí)中,我們經(jīng)常使用 CNN 或 RNN 對序列進行編碼。現(xiàn)在考慮到注意力機制,想象一下將一系列標(biāo)記輸入注意力機制,這樣在每個步驟中,每個標(biāo)記都有自己的查詢、鍵和值。在這里,當(dāng)在下一層計算令牌表示的值時,令牌可以(通過其查詢向量)參與每個其他令牌(基于它們的鍵向量進行匹配)。使用完整的查詢鍵兼容性分數(shù)集,我們可以通過在其他標(biāo)記上構(gòu)建適當(dāng)?shù)募訖?quán)和來為每個標(biāo)記計算表示。因為每個標(biāo)記都關(guān)注另一個標(biāo)記(不同于解碼器步驟關(guān)注編碼器步驟的情況),這種架構(gòu)通常被描述為自注意力模型 (Lin等。, 2017 年, Vaswani等人。, 2017 ),以及其他地方描述的內(nèi)部注意力模型 ( Cheng et al. , 2016 , Parikh et al. , 2016 , Paulus et al. , 2017 )。在本節(jié)中,我們將討論使用自注意力的序列編碼,包括使用序列順序的附加信息。
11.6.1。自注意力
給定一系列輸入標(biāo)記 x1,…,xn任何地方 xi∈Rd(1≤i≤n), 它的self-attention輸出一個相同長度的序列 y1,…,yn, 在哪里
根據(jù) (11.1.1)中attention pooling的定義。使用多頭注意力,以下代碼片段計算具有形狀(批量大小、時間步數(shù)或標(biāo)記中的序列長度, d). 輸出張量具有相同的形狀。
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, valid_lens = 2, 4, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
11.6.2。比較 CNN、RNN 和自注意力
讓我們比較一下映射一系列的架構(gòu)n標(biāo)記到另一個等長序列,其中每個輸入或輸出標(biāo)記由一個d維向量。具體來說,我們將考慮 CNN、RNN 和自注意力。我們將比較它們的計算復(fù)雜度、順序操作和最大路徑長度。請注意,順序操作會阻止并行計算,而序列位置的任意組合之間的較短路徑可以更容易地學(xué)習(xí)序列內(nèi)的遠程依賴關(guān)系 (Hochreiter等人,2001 年)。
圖 11.6.1比較 CNN(省略填充標(biāo)記)、RNN 和自注意力架構(gòu)。
考慮一個卷積層,其內(nèi)核大小為k. 我們將在后面的章節(jié)中提供有關(guān)使用 CNN 進行序列處理的更多詳細信息。現(xiàn)在,我們只需要知道,因為序列長度是n,輸入和輸出通道的數(shù)量都是 d, 卷積層的計算復(fù)雜度為 O(knd2). 如圖11.6.1 所示,CNN 是分層的,因此有O(1) 順序操作和最大路徑長度是
評論