女人自慰AV免费观看内涵网,日韩国产剧情在线观看网址,神马电影网特片网,最新一级电影欧美,在线观看亚洲欧美日韩,黄色视频在线播放免费观看,ABO涨奶期羡澄,第一导航fulione,美女主播操b

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

降低Transformer復(fù)雜度O(N^2)的方法匯總

新機(jī)器視覺(jué) ? 來(lái)源:極市平臺(tái) ? 2023-12-04 15:31 ? 次閱讀

導(dǎo)讀

文章總結(jié)了降低Transformer模型復(fù)雜度的方法,包括Softmax Attention的計(jì)算復(fù)雜度、稀疏Attention方法等。

Transformer最重要的特性是Global Interaction,也就是說(shuō)對(duì)于任意兩個(gè)位置的token(不論它們離的有多遠(yuǎn)),它們之間都能直接進(jìn)行信息交互。這個(gè)特性解決了傳統(tǒng)序列建模中長(zhǎng)依賴的問(wèn)題。

但Transformer也有一個(gè)典型問(wèn)題:它的計(jì)算復(fù)雜度和空間復(fù)雜度均為 , 其中 為序列長(zhǎng)度。

因此實(shí)際應(yīng)用中很難將Transformer應(yīng)用到長(zhǎng)序列任務(wù)上,如包數(shù)萬(wàn)個(gè)token的論文閱讀、書(shū)籍閱讀等任務(wù)。

解決Transformer計(jì)算復(fù)雜度的方法多種多樣。本文介紹其中最主流、最常見(jiàn)的一些方法。

Note:

為簡(jiǎn)化,本文不單獨(dú)討論multi-head的情況。大多數(shù)方法都可以平移到到multi-head中。

本文主要討論Transformer的Decoder。通常Encoder和Decoder的唯一區(qū)別是Encoder中當(dāng)前token可以attend到左邊和右邊的其它token,而Decoder中當(dāng)前token只能attend到左邊token。所以本文介紹的這些方法都可以輕易地?cái)U(kuò)展到Encoder中。

1. Transformer的計(jì)算復(fù)雜度

首先來(lái)詳細(xì)說(shuō)明為什么Transformer的計(jì)算復(fù)雜度是 。將Transformer中標(biāo)準(zhǔn)的Attention稱為Softmax Attention。令 為長(zhǎng)度為 的序列, 其維度為 , 。 可看作Softmax Attention的輸入。

Softmax Attention首先使用線性變換將輸入 變換為Query、Key和Value:

(1)

(2)

(3)

其中 和 都是待訓(xùn)練的參數(shù)矩陣; 是 和 的維度; 是 的維度。由此可得 的shape分別為:

(4)

(5)

(6)

在常見(jiàn)的Transformer中, 通常 。因此為了簡(jiǎn)化符號(hào), 我們假設(shè)后文中 , 并且只用符號(hào) (Dimension)。

有了Q、K、V, Softmax Attention(SA)的計(jì)算如下:

(7)

容易看到,Softmax Attention的計(jì)算主要包含兩次矩陣乘法操作。

首先回憶一下矩陣乘法的計(jì)算復(fù)雜度。對(duì)于矩陣 和 , 它們的矩陣乘法共需要 次乘法運(yùn)算??梢阅脟?guó)內(nèi)線性代數(shù)教材使用最多的計(jì)算方法來(lái)理解:為了計(jì)算這兩個(gè)矩陣的乘積, 需要拿矩陣 的每一行去與矩陣 的每一列做點(diǎn)積。因此總共需要 次點(diǎn)積。每次點(diǎn)積包含 次乘法和 次加法。考慮到加法復(fù)雜度遠(yuǎn)小于乘法, 所以總的計(jì)算復(fù)雜度就是 。

這個(gè) 可以使用兩種方法理解:

第一種理解方法, 因?yàn)榧臃◤?fù)雜度遠(yuǎn)小于乘法, 所以忽略加法, 那么 計(jì)算復(fù)雜度中的base operator指的是乘法操作。

第二種理解方法, 因?yàn)?與 的量級(jí)一致, 所以 計(jì)算復(fù)雜度中的base operator 指的是乘加操作 (乘法和加法) 。

回到Transformer的復(fù)雜度問(wèn)題上,前面提到Softmax Attention的計(jì)算主要包含兩次矩陣乘法操作。

第一次矩陣乘法是 , 結(jié)合上文關(guān)于矩陣乘法復(fù)雜度的結(jié)論和這兩個(gè)矩陣的大?。ü?(4)和公式(5)),可知 的復(fù)雜度為 。

第二次矩陣乘法是 sof tmax 的結(jié)果與 的乘積。sof tmax 輸出的矩陣大小為 , 矩陣 的大小為 (公式(6), 前文假設(shè)了 ), 所以這一次矩陣乘法的復(fù)雜度為 。

因?yàn)檫@兩次矩陣乘法是順序執(zhí)行的, 所以總的復(fù)雜度為它們各自復(fù)雜度之和。因?yàn)檫@兩個(gè)復(fù)雜度相等, 相加只是引入了一個(gè)常數(shù)項(xiàng), 所以可以忽略, 因此Softmax Attention總的復(fù)雜度就為

當(dāng)我們只關(guān)心復(fù)雜度與序列長(zhǎng)度 之間的關(guān)系時(shí), 可以忽略 并將其寫(xiě)為 。

這就是通常說(shuō)的Transformer計(jì)算復(fù)雜度隨序列長(zhǎng)度呈二次方增長(zhǎng)的由來(lái)。容易看到,Transformer的空間復(fù)雜隨序列長(zhǎng)度也呈二次方增長(zhǎng),即空間復(fù)雜度也為 。

這一節(jié)最后,我們用一幅簡(jiǎn)單的圖來(lái)說(shuō)明Softmax Attention中參與每個(gè)token的Attention Score計(jì)算的其它token的位置(只考慮Decoder)。該圖主要是為了與后文的一些其它復(fù)雜方法作對(duì)比。

e1e4cb64-91f0-11ee-939d-92fbcf53809c.jpg

圖1 Softmax Attention中參與每個(gè)token的Attention Score計(jì)算的其它token的位置

這幅圖按如下方法理解:行和列都表示位置;藍(lán)色表示當(dāng)前token,綠色表示參與當(dāng)前token計(jì)算的其它token的位置。

例如,圖中有12行,可以看作該示例中序列長(zhǎng)度為12。以第二行為例,它表示對(duì)于第二個(gè)位置的token(藍(lán)色位置,當(dāng)前token),只有第一個(gè)位置的token會(huì)參與它Attention Score的計(jì)算。這其實(shí)就是Transformer中Decoder采用的方式:只能看當(dāng)前token左邊的token。

為了簡(jiǎn)化表述,后文會(huì)使用如下方式來(lái)表述:第二行中,第二個(gè)token只能attend到第一個(gè)token。

同理,在第三行中,第三個(gè)token可以attend到第一個(gè)和第二個(gè)token。

以此類推。

同時(shí),也會(huì)采用被動(dòng)表述。例如,在第二行中,第一個(gè)token被attended到第二個(gè)token。此時(shí),第一個(gè)token也可以被稱為attended token。

2. Sparse Attention

再看一次圖1中的Softmax Attention,容易看到對(duì)于每一個(gè)token,它都會(huì)attend到它前面的所有token。所以通常說(shuō)Softmax Attention是密集的(dense)。

與密集相對(duì)的就是稀疏(Sparse)了。Sparse Attention的主要思路是減少每個(gè)token需要attend的token數(shù)量。

比如,Softmax Attention對(duì)于每個(gè)token都要attend它之前的所有token。那么為了減少計(jì)算量,能不能只去attend之前的部分token?

2.1 Factorized Self-Attention (Sparse Transformer)

Paper:Generating Long Sequences with Sparse Transformers (2019)

Key Contribution:提出了兩種稀疏Attention方法:Strided Attention和Fixed Attention。這二者均可將Transformer的 復(fù)雜度降低至 。

Factorized Self-Attention的一個(gè)基礎(chǔ)假設(shè)是:在Softmax Attention中,真正為目標(biāo)token提供信息的attended token非常少。

換言之,該假設(shè)意味著:對(duì)于Softmax Attention,在經(jīng)softmax得到的Attention Weights中,其中大部分的值都趨于0,只有少數(shù)值明顯大于0。因此Attention Weight比較稀疏。

論文作者將Transformer用到了圖像自回歸任務(wù)中來(lái)表明他們假設(shè)的合理性,如圖2.1.1所示(圖2.1.1不容易懂,看后文解釋)。

e1e9e694-91f0-11ee-939d-92fbcf53809c.jpg

圖2.1.1 Softmax Attention中Weight Vector稀疏性示意圖

解釋一下圖2.1.1。作者們用了128層的Transformer在CIFRA-10上做自回歸訓(xùn)練。自回歸訓(xùn)練是逐行逐像素來(lái)做的。

以圖a)中左上方的紅色汽車(chē)圖為例,圖中黑色區(qū)域(下方)是mask。模型下一步需要去預(yù)測(cè)mask中的第一個(gè)點(diǎn)。所謂第一個(gè)點(diǎn),就是逐行看,看到的一個(gè)mask黑色點(diǎn)。圖中白色區(qū)域是Attention Weights。可以看到,有效的Attention Weights幾乎全部集中在當(dāng)前待預(yù)測(cè)點(diǎn)周?chē)?。所以此時(shí)的Attention Weights很像卷積的局部性。同時(shí)它也很稀疏,因?yàn)锳ttention Weights在較遠(yuǎn)的位置幾乎全為0。

圖2中a、b、c、d是來(lái)自不同網(wǎng)絡(luò)層的Attention Weights??梢钥吹?,雖然Attention Weights表現(xiàn)出的空間規(guī)律有所差異,但它們總體上都很稀疏:只有極少部分的位置被有效attend(Attention Weights明顯大于0,即圖中白色區(qū)域)。

基于這種稀疏性,作者們提出了兩種Attention方法。

注:針對(duì)這篇paper的Attention方法,本文不列具體公式。這是因?yàn)椋@些方法其實(shí)都非常簡(jiǎn)單,但公式反而繁瑣、不直觀。

第一種方法稱為Strided Attention。它又由兩種Attention機(jī)制構(gòu)成,我們把它們分別記為SA1和SA2(原文沒(méi)有這種命名法,這里只是為了指代方便):

SA1: 每個(gè)token只能Attend它左邊相鄰的L個(gè)token。

SA2:每個(gè)token只能Attend它左邊部分token,這些attened token用如下方法選出:從自己開(kāi)始往左邊數(shù),每隔L就會(huì)有一個(gè)token可以attend(參見(jiàn)圖2.1.3,比較直觀)。

為便于理解,請(qǐng)參見(jiàn)圖2.1.2和圖2.1.3,我們假設(shè)L=3。

e1f54584-91f0-11ee-939d-92fbcf53809c.jpg

圖2.1.2 Strided Attention的SA1。圖中每個(gè)token只能attend到它左邊相鄰的L個(gè)token,圖中L=3。

e1fbe48e-91f0-11ee-939d-92fbcf53809c.jpg

圖2.1.3 Strided Attention的SA2,圖中L=3。

圖2.1.2中的SA1很容易理解,每一個(gè)當(dāng)前token(每一行的藍(lán)色區(qū)域)只能attend到它左邊的L個(gè)token,圖中L=3。圖2.1.3中的SA2稍微復(fù)雜一點(diǎn),從自己開(kāi)始往左邊數(shù),每隔L就會(huì)有一個(gè)token可以attend。比如圖2.1.3中最后一行,從當(dāng)前token(藍(lán)色區(qū)域)開(kāi)始往左邊數(shù),相隔L個(gè)空格(3個(gè)空格)處遇到第一個(gè)綠色方塊可以attend(最后一行,第8列),然后再往左數(shù)L個(gè)(3個(gè))空格,遇到第二個(gè)綠色方塊可以attend(最后一行,第4列),以此類推。

Strided Attention的SA1方法和SA2方法的本質(zhì)是在選擇哪些token可以attend。

然后我們來(lái)看這兩種Attention方法怎么用在Transformer結(jié)構(gòu)中。有三種方法:

交替使用。在第1個(gè)Transformer Block中使用SA1,然后在第2個(gè)Transformer Block中使用SA2,然后在第3個(gè)Transformer Block中又使用SA1,在第4個(gè)Transformer Block中又使用SA2,以此類推。這種方法能work的原因是:雖然SA1只能看左邊的L個(gè)相鄰位置,但可以認(rèn)為在SA1中,每個(gè)token聚合了它左邊L個(gè)token的信息。因此在SA2,雖然它是跳著L個(gè)位置看的,但整體感受野等價(jià)于整個(gè)序列(因?yàn)槊總€(gè)attended token聚合了其左邊L個(gè)token的信息)。

聯(lián)合使用。將SA1選擇的attended token和SA2選擇的attended token合在一起使用。這個(gè)方法很簡(jiǎn)單,就是在計(jì)算Attention時(shí),首先用SA1去選擇一些token,再用SA2去選擇一些token,然后計(jì)算Attention時(shí)只使用選擇出的token參與計(jì)算即可。

多頭使用。類似Transformer采用的多頭機(jī)制,這里每個(gè)頭可以使用SA1、SA2或Transformer中的Softmax Attention。

然后來(lái)看 的選擇。只要將 的值設(shè)為 , 那么容易看到整個(gè)Strided Attention的計(jì)算復(fù)雜度就是 。雖然這個(gè)做法很不自然, 但是它確實(shí)能實(shí)現(xiàn) 的復(fù)雜度。

至此,我們介紹完了Strided Attention。

作者們提出的第二種Attention稱為Fixed Attention。Fixed Attention也有兩種機(jī)制,將它們分別稱為FA1和FA2。為了便于理解,需要把這兩種機(jī)制畫(huà)到一個(gè)圖里,如圖2.1.4所示。

e206a478-91f0-11ee-939d-92fbcf53809c.jpg

圖2.1.4 Fixed Attention中的FA1(綠色)和FA2(橙色),L=3

先看FA2,如圖中橙色區(qū)域。橙色區(qū)域的位置是固定的,即從左往右數(shù),每隔L個(gè)位置,選中一個(gè)token。

理解了FA2,F(xiàn)A1的選擇方式就會(huì)容易理解了。對(duì)于每個(gè)當(dāng)前token(藍(lán)色),往它左邊遍歷(綠色),直到遇到第一個(gè)FA2選中的token(橙色)。

Fixed Attention的使用方法和上文介紹的Strided Attention的三種方法一致(交替使用、聯(lián)合使用、多頭使用),不再贅述。

作者們的結(jié)論:Strided Attention適用于圖像、音頻;Fixed Attention適用于文本。

理由如下:Strided Attention在attended token的位置上做了強(qiáng)假設(shè):哪些位置的token應(yīng)該被attened,與當(dāng)前token位置強(qiáng)相關(guān)。作者們認(rèn)為這種適合圖像、音頻這類數(shù)據(jù)。而在文本上這類假設(shè)不成立。所以在Fixed Attention中,哪些位置的token應(yīng)該被attened,與當(dāng)前token位置無(wú)關(guān)。

講的再簡(jiǎn)單點(diǎn),圖像、音頻的局部信息很重要;而文本全局信息更重要。

總結(jié):paper對(duì)新手不友好,簡(jiǎn)單的事情用了公式來(lái)解釋,非常繁瑣。希望本文能比原文容易理解一點(diǎn)。

2.2 Blockwise Self-Attention

Paper:Blockwise Self-Attention for Long Document Understanding (2019)

Key Contribution:通過(guò)分塊來(lái)降低Softmax Attention的計(jì)算復(fù)雜度,方法簡(jiǎn)單,且實(shí)驗(yàn)效果較好。

前文提到了Transformer的時(shí)間復(fù)雜度和空間復(fù)雜度都為。Blockwise Self-Attention這篇Paper對(duì)空間復(fù)雜度做了更細(xì)致的分析。

一個(gè)模型的Memory Usage主要來(lái)自三部分:Model Memory、Optimizer Memory、Activation Memory。按照Transformer模型通常使用的Adam類優(yōu)化器來(lái)看,Optimizer Memory是Model Memory的三倍。這是因?yàn)镺ptimizer Memory需要為每個(gè)參數(shù)存儲(chǔ)梯度、first momentum和second momentum。

Model Memory和Optimizer Memory可以直接計(jì)算出來(lái)。比如對(duì)于Model Memory,可以直接通過(guò)模型大小與參數(shù)類型(如FP16、FP32、INT8)來(lái)推算出精確值。同理,Optimizer Memory也可以精確計(jì)算出。而Activation Memory則與具體實(shí)現(xiàn)相關(guān)。所以在Paper中,作者們用PyTorch的內(nèi)存分析工具來(lái)看訓(xùn)練時(shí)總的內(nèi)存開(kāi)銷(xiāo),然后減去Model Memory和Optimizer Memory,以此來(lái)估算Activation Memory。

作者們以BERT-base為例,分析了Model Memory、Optimizer Memory、Activation Memory三者的占比,其中Activation Memory獨(dú)占87.6%,屬于內(nèi)存開(kāi)銷(xiāo)最大的部分。畫(huà)一幅圖來(lái)總結(jié)上面提到的內(nèi)容(注意圖中的memory usage的比例是針對(duì)BERT-base而言的):

e212fc8c-91f0-11ee-939d-92fbcf53809c.jpg

圖2.2.1 BERT-base中內(nèi)存分布示意圖

我們說(shuō)的空間復(fù)雜度 主要指的就是Activation Memory這一部分。因?yàn)镸odel Memory和Optimizer Memory是線性復(fù)雜度 。

Blockwise Self-Attention的核心思想非常簡(jiǎn)單:將一個(gè)長(zhǎng)度為N的序列,平均分成n個(gè)短序列。當(dāng)原始序列長(zhǎng)度N無(wú)法被n除盡時(shí),對(duì)原始序列進(jìn)行padding,使它能被除盡。舉一個(gè)例子來(lái)說(shuō)明Blockwise Self-Attention的計(jì)算過(guò)程。

假設(shè)序列長(zhǎng)度 , 每個(gè)token的維度為 。在Transformer中, Q、K、V三個(gè)矩陣的大小都為 。在Blockwise Self-Attention中, 假設(shè)分塊數(shù) , 那么每個(gè)分塊中的序列長(zhǎng)度為 。所以輸入序列 可以劃分為 個(gè)子序列: 、 , 它們的大小都為 。同理可以把 ( 同理) 劃分成 個(gè)子矩陣:、 , 它們的大小也都為 。在計(jì)算Self-Attention時(shí), 每個(gè) 會(huì)去選擇一個(gè) 和 來(lái)計(jì)算:(2.2.1)

在只有一個(gè)Attention頭的情況下, 選擇 和 的方法是:shifting one position。很簡(jiǎn)單, 選擇 和 選擇 和 選擇 和 。換言之, 始終選下一個(gè) 和 ; 當(dāng) 是最后一個(gè)block時(shí), 選擇 和 。這個(gè)過(guò)程可以用取余數(shù)的符號(hào)寫(xiě)出來(lái), 但看著太繁瑣, 所以文字描述了。

多頭Attention情況下稍微麻煩一點(diǎn)。我們記序列 為單頭Attention情況下每個(gè) 對(duì)應(yīng)的 和 的編號(hào) :(2.2.2)

仍以上面的示例為例, 在單頭情況下, 的值為:(2.2.3)

它表示, 對(duì)應(yīng)的 的值是2, 對(duì)應(yīng)的 的值是 對(duì)應(yīng)的 的值是1。

在多頭情況下, 第 個(gè)頭的 定義如下:(2.2.4)

例如, 按照上述示例, 第一個(gè)頭的 為:(2.2.5)

第二個(gè)頭的 為:(2.2.6)

因?yàn)榉謮K數(shù) , 所以需要取余數(shù)(注意下標(biāo)從1開(kāi)始, 所以余 0 時(shí)替換為 即可),得到最終的結(jié)果:(2.2.7)

過(guò)程其實(shí)很簡(jiǎn)單, 只是寫(xiě)出來(lái)稍微麻煩一點(diǎn)。

最后來(lái)分析復(fù)雜度。由本文第一部分分析Transformer復(fù)雜度的結(jié)論可知, 公式(2.2.1)中的復(fù)雜度為 。因?yàn)閷?duì)每一個(gè)分塊, 都需要用公式 (2.2.1) 進(jìn)行計(jì)算, 所以總復(fù)雜度為:(2.2.8)

這既是計(jì)算復(fù)雜度,也是空間復(fù)雜度。

在原文中, 通常選為2。注意, 在大 計(jì)法中一般會(huì)忽略掉常數(shù)項(xiàng)。所以在這種意義下, Blockwise Self-Attention的復(fù)雜度仍為 。

但是大 計(jì)法的主要目的是理論分析, 并不為實(shí)際工程優(yōu)化。所以即使在大 意義下復(fù)雜度沒(méi)有變, 但它實(shí)際計(jì)算量仍然減少了。沒(méi)有改變的 仍然意味著,Blockwise Self-Attention不能擴(kuò)展到太大的 上,這就是大 計(jì)法的作用。

具體來(lái)看, 當(dāng) 時(shí), RoBERTa的訓(xùn)練時(shí)間由原來(lái)的9.7天減少至7.5天。

總結(jié):相比于Sparse Transformer中的Factorized Self-Attention,Blockwise Self-Attention更簡(jiǎn)單,且從效果上來(lái)看,優(yōu)于Factorized Self-Attention。

2.3 Longformer

paper:Longformer: The Long-Document Transformer (2020)

Key Contribution:設(shè)計(jì)了多種不同的Local Attention和Global Attention方法。

首先重新看一下Factorized Self-Attention (2.1小節(jié))中的兩種Attention方法:Strided Attention和Fixed Attention。在Strided Attention中,又有兩種Attention機(jī)制,在前文中我們把它們分別稱為SA1和SA2(參考圖2.1.2和2.1.3)。SA1的作用是Local Interaction,而SA2的作用是Global Interation。類似的,在Fixed Attention中(參考圖2.1.4),F(xiàn)A1的作用是Local Interaction,而FA2的作用是Global Interation。

在Factorized Self-Attention中,它主要依靠?jī)深怉ttention的組合使用來(lái)實(shí)現(xiàn)長(zhǎng)距離依賴,例如SA1+SA2(或FA1+FA2)。

Longformer的核心idea和Factorized Self-Attention很像,只是Longformer中的部分Attention只有Local Interaction,沒(méi)有Global Interaction。

Longformer一共提出了三種Attention,分別是SlidingWindow basedAttention(SW-Attention)、DilatedSlidingWindow basedAttention(DSW-Attention)和GlobalAttention(G-Attention)。下面分別介紹。

先看SlidingWindow basedAttention(SW-Attention),它其實(shí)和Strided Attention中的SA1完全一樣。為了方便大家查看,重新把Strided Attention的SA1圖copy一份到此處。

e1f54584-91f0-11ee-939d-92fbcf53809c.jpg

圖2.3.1 SW-Attention示意圖,它和Strided Attention中的SA1完全一樣。圖中L=3

SW-Attention只Attend它左邊的L個(gè)token。在SW-Attention中,L被稱為“窗口大小”,而在Strided Attention中,L被稱為“步長(zhǎng)(Stride)”,它們本質(zhì)一樣。

實(shí)際上我們可以在Transformer中只使用SW-Attention來(lái)構(gòu)建具有Global Interaction的網(wǎng)絡(luò)。其方法很簡(jiǎn)單,只需要堆疊多個(gè)SW-Attention網(wǎng)絡(luò)層即可,就如同CNN增大感受野的方式。假設(shè)窗口大小為K,一個(gè)M層的SW-Attention結(jié)構(gòu)中最頂層的“感受野”大小為KM,如圖2.3.2所示。

e21c8a90-91f0-11ee-939d-92fbcf53809c.jpg

圖2.3.2 基于SW-Attention構(gòu)建的Transformer的Global Interaction示意圖

圖中綠色方塊表示當(dāng)前token;藍(lán)色線表示信息流;每一層是上一層的輸入;L假設(shè)為2。在第一層中,第一個(gè)token和第二個(gè)token的信息會(huì)流入第二層中的第三個(gè)token。而在第二層中,第二個(gè)token和第三個(gè)token的信息會(huì)流入下一層中的第四個(gè)token,以此類推。在最頂層(第五層),雖然當(dāng)前token的信息只來(lái)自上一層的第四個(gè)token和第五個(gè)token,但從信息流的角度來(lái)看,它也隱含包含第一層中第一個(gè)和第二個(gè)token的信息。

可以看到,通過(guò)堆疊SW-Attention,Transformer也可以像CNN一樣增加感受野。但是很容易想到,這種非?!伴g接”的方法不會(huì)有太好效果,就像CNN對(duì)長(zhǎng)依賴建模的能力比較差一樣。

再來(lái)看DilatedSlidingWindow basedAttention(DSW-Attention),它其實(shí)和Strided Attention中的SA2完全一樣。為了方便大家查看,重新把Strided Attention的SA2圖copy一份到此處。

e220ba66-91f0-11ee-939d-92fbcf53809c.jpg

圖2.3.3 DSW-Attention示意圖,它和Strided Attention中的SA2完全一樣。圖中Dilation=3。

DSW-Attention是“空洞”版的SW-Attention,就像空洞卷積和卷積之間的關(guān)系。簡(jiǎn)單來(lái)說(shuō),被attended的token不再像SW-Attention中是連續(xù)排列的,而是按等間距排列(間距稱為“空洞率”,在圖2.3.3中為3)。

與SW-Attention類似,通過(guò)堆疊多個(gè)DSW-Attention也能增大網(wǎng)絡(luò)的感受野,從而實(shí)現(xiàn)Global Interaction。

最后再來(lái)看GlobalAttention(G-Attention)。G-Attention是SW-Attention的改進(jìn)版,它的主要改動(dòng)是:在SW-Attention基礎(chǔ)上,增加了部分固定位置,使得這些位置的token需要 1)attend到其它所有token;2)被其它位置tokenattend到。如圖2.3.4所示。

e2280ea6-91f0-11ee-939d-92fbcf53809c.jpg

圖2.3.4 G-Attention示意圖,L=3

圖中綠色token是SW-Attention會(huì)attend到的token。橙色token是在G-Attention中額外選中的token。以第五行的當(dāng)前token為例(橙色),因?yàn)樗潜活~外選中的token,所以它會(huì)attend它左邊的所有token。圖中用黃色標(biāo)出了相對(duì)于SW-Attention之外的額外被attended的token。此外,其它所有token也需要attend到第五個(gè)token,參見(jiàn)圖中最后四行中的靠左黃色列。

圖中第7行類似,大家可以自行對(duì)照?qǐng)D腦補(bǔ)一下這個(gè)過(guò)程。

在G-Attention中,哪些位置會(huì)被額外選中與具體下游任務(wù)相關(guān)。例如,在分類任務(wù)中,[CLS] token會(huì)被額外選中(Longformer一文中以RoBERTa為基礎(chǔ),將其中的Attention改為本文提到的Attention中的一種或多種);在問(wèn)答任務(wù)中,所有問(wèn)題的token都會(huì)被額外選中。

此外,G-Attention中有兩份不同的QKV,一份用于計(jì)算由SW-Attention選中的token(圖2.3.4中的綠色token),另一份用于計(jì)算由G-Attention額外選中的token(圖2.3.4中的黃色token)。

上述提到的三種Attention的復(fù)雜度都為 , 因?yàn)槟男﹖oken會(huì)被attend與序列長(zhǎng)度 無(wú)關(guān)。

2.4 Local attention and Memory-compressed attention

Paper: Generating wikipedia by summarizing long sequences (2018)

Key Contribution: 提出了Local Attention和Memory-compressed attention。Local Attention的計(jì)算復(fù)雜度隨序列長(zhǎng)度增長(zhǎng)呈線性增長(zhǎng);Memory-compressed attention可以將計(jì)算復(fù)雜度減少固定常數(shù)倍(超參控制)。

2.4.1 Local Attention

前文中的2.3節(jié)也有一個(gè)Local Attention,但與此處的Local Attention方法不同。

此處Local Attention的核心思想是使用一個(gè)固定的分塊大小n對(duì)輸入序列進(jìn)行分塊,并限制self-attention的計(jì)算只能在各個(gè)分塊內(nèi)單獨(dú)進(jìn)行,如圖2.4.1所示。

e2357a50-91f0-11ee-939d-92fbcf53809c.jpg

圖2.4.1 Local Attention的模式圖。圖中假設(shè)序列長(zhǎng)度N=12,分塊大小n=3。

在圖2.4.1中,每個(gè)位置的token只能attend到與它同顏色的其它token。例如圖中第五行(紅色標(biāo)注行),它表示在Decoder結(jié)構(gòu)中,對(duì)于輸入序列中的第5個(gè)token的attention模式:第五行的灰色區(qū)域表示mask,這些mask表示Decoder結(jié)構(gòu)中不能看到當(dāng)前token之后的信息;前五個(gè)token根據(jù)顏色進(jìn)行分塊,每個(gè)token只能attend到同分塊(同顏色)中的其它token,所以對(duì)于當(dāng)前token而言(第五個(gè)token),它只能attend到第四個(gè)token和它自己(綠色部分)。

作為對(duì)比,標(biāo)準(zhǔn)的self-attention的模式圖如下:

e239905e-91f0-11ee-939d-92fbcf53809c.jpg

圖2.4.2 標(biāo)準(zhǔn)self-attention的模式圖。圖中假設(shè)序列長(zhǎng)度N=12。

標(biāo)準(zhǔn)的Decoder結(jié)構(gòu)中,只有一個(gè)限制:所有token都不能attend到當(dāng)前序列之后的token。

Local Attention與2.2節(jié)介紹的Blockwise Self-Attention比較類似,其核心思想都是對(duì)輸入序列進(jìn)行分塊。Local Attention與Blockwise Self-Attention唯一的區(qū)別是:Local Attention將Self-attention的計(jì)算限制在組內(nèi);而B(niǎo)lockwise Self-Attention將Self-attention的計(jì)算限制在組間。

例如,考慮圖2.4.1中的最后一行,在最簡(jiǎn)單的情況下,Blockwise Self-Attention的attention模式為:每個(gè)分塊的token只能attend到下一個(gè)分塊(藍(lán)色token只能attend到綠色token;綠色token只能attend到橙色token;橙色token只能attend到黃色token;黃色token只能attend到藍(lán)色token)。

下面分析一下Local Attention的復(fù)雜度。Local Attention通常選擇一個(gè)固定長(zhǎng)度的分塊大小n(例如 )。假設(shè)總的序列長(zhǎng)度為 , 那么分塊數(shù)量為 。每一個(gè)分塊的復(fù)雜度為 個(gè)分塊的總復(fù)雜度為 。因?yàn)?為常數(shù)項(xiàng), 所以Local Attention的復(fù)雜度隨序列長(zhǎng)度 呈線性增長(zhǎng) 。

但在2.3節(jié)中, 曾分析到Blockwise Self-Attention的復(fù)雜度是 。為何兩個(gè)如此相似的方法復(fù)雜度卻有顯著差異?

在Blockwise Self-Attention中, 的含義不是分塊大小, 而是分塊數(shù)量。所以每個(gè)分塊的大小就為 。那么每個(gè)分塊的attention計(jì)算復(fù)雜度就是 。又因一共有 個(gè)分塊, 所以總復(fù)雜度是 。

這里之所以把這兩個(gè)復(fù)雜度拿出來(lái)對(duì)比,是想說(shuō)明:小心對(duì)待復(fù)雜性分析中的常量。不同視角可能會(huì)導(dǎo)致不同的分析結(jié)果。

復(fù)雜度唯一能體現(xiàn)的僅僅是:計(jì)算量與變量之間的關(guān)系。

在上面的例子中,我們關(guān)心的變量是序列長(zhǎng)度N,所以直接忽略了常數(shù)項(xiàng)n。但如果我們要比較這兩個(gè)復(fù)雜度所對(duì)應(yīng)的計(jì)算量時(shí),常數(shù)項(xiàng)不能輕易忽略。

2.4.2 Memory-compressed Attention

在通常的基于Transformer的模型中,我們使用不同的線性變換來(lái)將輸入序列x映射為Q、K、V。這三者的尺寸通常一樣(維度一樣,長(zhǎng)度也一樣)。

Memory-compressed Attention的思路是使用額外的卷積來(lái)降低K和V的序列長(zhǎng)度,這樣整體Self-Attention的計(jì)算量就降低了。這樣的卷積很容易實(shí)現(xiàn)。假設(shè)輸入序列長(zhǎng)度為N,維度為D,且K和V的尺寸都為[N, D]。我們只需要使用一個(gè)步長(zhǎng)大于1的卷積,讓它沿著序列長(zhǎng)度維度進(jìn)行滑動(dòng)即可,如下圖所示。

e23f37a2-91f0-11ee-939d-92fbcf53809c.jpg

圖2.4.3 Query與Key的矩陣乘積示意圖。

圖中上部分表示標(biāo)準(zhǔn)Query與Key的計(jì)算示意圖。在Memory-compressed Attention中, 首先使用一個(gè)沿著序列長(zhǎng)度維度滑動(dòng)的卷積對(duì)該維度進(jìn)行下采樣,得到一個(gè)更小的Key的矩陣,如圖中下部分所示。假設(shè)下采樣后的序列長(zhǎng)度為 , 可知此時(shí)矩陣乘法的復(fù)雜度為 , 而標(biāo)準(zhǔn)Q、K計(jì)算的復(fù)雜度為 。一般來(lái)說(shuō), L在量級(jí)上并不會(huì)與N有明顯差異, 所以Memory-compressed Attention雖然能降低計(jì)算量, 但并不能顯著降低復(fù)雜度。

2.5 Reformer

paper: Reformer: the efficient Transformer

Key contribution: 1) 提出了LSH-attention, 能夠?qū)ransformer的復(fù)雜度由 降低至 ;2) 將Transformer中的跳躍連接改為了“可逆跳躍連接", 這樣在網(wǎng)絡(luò)的前向過(guò)程中不用為后續(xù)的梯度計(jì)算存儲(chǔ)激活值, 能夠極大降低訓(xùn)練過(guò)程的存儲(chǔ)開(kāi)銷(xiāo)。

從最原始的研究動(dòng)機(jī)來(lái)看,Reformer主要考慮的是:降低基于Transformer的模型在訓(xùn)練階段的存儲(chǔ)開(kāi)銷(xiāo)。

神經(jīng)網(wǎng)絡(luò)在訓(xùn)練過(guò)程中,最大的存儲(chǔ)開(kāi)銷(xiāo)主要來(lái)自兩方面。一是網(wǎng)絡(luò)參數(shù)本身的存儲(chǔ)開(kāi)銷(xiāo);二是整個(gè)前向過(guò)程中產(chǎn)生的激活值。由存儲(chǔ)激活值而導(dǎo)致的開(kāi)銷(xiāo)只會(huì)在訓(xùn)練階段產(chǎn)生,因?yàn)橛?xùn)練中為了計(jì)算每一層的梯度,需要用到當(dāng)前層的激活值。而在推理階段,因?yàn)椴恍枰偻ㄟ^(guò)梯度信息來(lái)更新網(wǎng)絡(luò),所以自然也就不用存儲(chǔ)每一層的激活值了。

基于這兩個(gè)部分,先來(lái)看一下基于標(biāo)準(zhǔn)Self-attention的單層Transformer的所涉及的存儲(chǔ)開(kāi)銷(xiāo):

以當(dāng)時(shí)(2020年)最大的單層Transformer為例,它的參數(shù)量是0.5B。每個(gè)參數(shù)32位,也就是4Byte,所以總的內(nèi)存開(kāi)銷(xiāo)就是2GB。

假設(shè)輸入序列長(zhǎng)度為 , embedding大小為 , batch size為 8 , 那么單個(gè)self-attention激活值所占的存儲(chǔ)開(kāi)銷(xiāo)是 。同理, 每個(gè)激活值也是4Byte, 所以總的內(nèi)存開(kāi)銷(xiāo)也是 。

上述兩點(diǎn)涉及到的存儲(chǔ)開(kāi)銷(xiāo)并不大,加起來(lái)一共也才4GB。但實(shí)際上除了這兩點(diǎn),就單層Transformer而言,它還包含另外兩點(diǎn)最大的開(kāi)銷(xiāo):

在標(biāo)準(zhǔn)Transformer結(jié)構(gòu)中,除了self-attention部分,在其后還有兩個(gè)全連層。兩個(gè)全連層的激活值的數(shù)量加起來(lái)通常遠(yuǎn)大于self-attention的激活值。例如,在標(biāo)準(zhǔn)Transformer中,第一個(gè)全連層的激活值數(shù)量是self-attention的四倍,第二個(gè)全連層與self-attention相同。那么兩個(gè)全連層總的激活值就是self-attention的五倍。按照上面第二點(diǎn)中的計(jì)算方法,兩個(gè)全連層總的存儲(chǔ)開(kāi)銷(xiāo)就是10GB。

self-attention的計(jì)算中包含 的矩陣乘法計(jì)算, 它的計(jì)算復(fù)雜度和空間復(fù)雜度都是 。例如, 當(dāng)輸入序列長(zhǎng)度為 時(shí), 的輸出矩陣大小為 , 內(nèi)存消耗約 16GB。

上述還僅僅是單層網(wǎng)絡(luò)的開(kāi)銷(xiāo)(注:上述也并沒(méi)有計(jì)算完單層所有的激活開(kāi)銷(xiāo)),對(duì)于一個(gè)N層的Transformer,這個(gè)開(kāi)銷(xiāo)還得乘以N倍。Reformer主要采用了兩種方法來(lái)降低整體存儲(chǔ)開(kāi)銷(xiāo),分別是LSH-attention和“可逆跳躍連接”。

2.5.1 Locality-Sensitive Hashing Attention(LSH-attention)

e24813d6-91f0-11ee-939d-92fbcf53809c.png

圖2.5.1 在Self-attention中(左圖),當(dāng)前token可以attend到它之前(包括自身)的所有token;LSH-attention中(右圖),當(dāng)前token只attend到部分“重要”的token。

在標(biāo)準(zhǔn)的self-attention中(以decoder為例),每一個(gè)位置的token可以attend到它之前的所有token(包括它自己)。但實(shí)際上,因?yàn)閟oftmax主要由較大的那些值所主導(dǎo),所以由softmax輸出的weight vector中可能會(huì)比較稀疏。也就是說(shuō),很多位置的權(quán)重很小,只有少部分位置的權(quán)重較大。因此,一個(gè)自然的想法是不是只要找到那些產(chǎn)生較大權(quán)重的token即可,而不用讓所有token都參與計(jì)算?這個(gè)想法的示意如圖2.5.1所示。

在self-attention的計(jì)算中, 與當(dāng)前Query越相似的Key, 它們點(diǎn)乘的值也會(huì)更大, 從而產(chǎn)生的權(quán)重也更大。為了后續(xù)描述方便, 對(duì)于某個(gè)token , 它對(duì)應(yīng)的Query記為 。如果某個(gè)token 所對(duì)應(yīng)的Key 能與 產(chǎn)生較大的點(diǎn)乘結(jié)果, 我們就說(shuō)“token 對(duì)于token 是重要的”。

這里我們并沒(méi)有定義“較大的點(diǎn)乘結(jié)果”究竟是多大。這可以認(rèn)為是具體策略問(wèn)題, 只要使用的策略能夠區(qū)分"較大"和"不較大"就行。

LSH-attention的核心思路是, 對(duì)于當(dāng)前token , 找到對(duì)它“重要的”所有token集合 , 并限制 在self-attention的計(jì)算中只能attend到集合 中的token。簡(jiǎn)而言之, 對(duì)于當(dāng)前token , 我們希望知道哪些 與 的點(diǎn)乘會(huì)比較大。這些 對(duì)于的token 就構(gòu)成了集合 。

尋找集合 中token的本質(zhì)是相似度問(wèn)題。計(jì)算兩個(gè)向量相似度最簡(jiǎn)單的辦法是計(jì)算它們的余弦相似度。但這不可取, 因?yàn)槲覀兊哪康木褪菫榱吮苊庥?jì)算和一些可能導(dǎo)致低權(quán)重的token的點(diǎn)乘來(lái)降低計(jì)算量。但如果為了找到它們又需要計(jì)算余弦相似度, 而余弦相似度的計(jì)算又包含點(diǎn)乘, 那么后續(xù)節(jié)約的計(jì)算實(shí)際上預(yù)先發(fā)生了, 所以這樣沒(méi)意義。

真正要解決的問(wèn)題是找到一種高效的計(jì)算方法來(lái)判斷兩個(gè)向量是否相似。LSH-attention采用的方法是Locality sensitive hashing(局部敏感哈希)。

一個(gè)“局部敏感”的哈希算法指的是非常相似的向量具有相同的哈希值。LSH-attention使用的方法如圖2.5.2所示。

e25335fe-91f0-11ee-939d-92fbcf53809c.jpg

圖2.5.2 Locality sensitive hashing示意圖。圖來(lái)自原論文。

解釋一下圖2.5.2。先看圖的上半部分。假設(shè)有相距較遠(yuǎn)的兩個(gè)點(diǎn)x和y,首先把它們投影到一個(gè)圓上(高維空間中對(duì)應(yīng)超球面)。然后用一個(gè)隨機(jī)的旋轉(zhuǎn)將圓上的兩個(gè)投影點(diǎn)進(jìn)行旋轉(zhuǎn),并記錄它們落在的區(qū)域編號(hào)。圖中區(qū)域由四個(gè)不同顏色的三角形區(qū)域構(gòu)成,從右沿著逆時(shí)針?lè)较蚓幪?hào)為0、1、2、3。所以在第一次旋轉(zhuǎn)后(對(duì)應(yīng)于圖中Random Rotation 0),x落在的區(qū)域0,y落在區(qū)域3。然后再次隨機(jī)旋轉(zhuǎn),并記錄第二次旋轉(zhuǎn)后落在的區(qū)域。圖中一共進(jìn)行了三次旋轉(zhuǎn),x分別落在區(qū)域0、2、1,因此它的哈希值就是021。而y三次分別落在區(qū)域3、2、0,所以它的哈希值是320。兩個(gè)哈希值021和320不同,那么認(rèn)為x和y不相似。

圖2.5.2中的下半部分中用了兩個(gè)更接近的點(diǎn)作為示例,不再展開(kāi)解釋了。LSH方法直覺(jué)上非常簡(jiǎn)單,它也有一些高效的實(shí)現(xiàn)方法。這里簡(jiǎn)單提一個(gè)要點(diǎn):判斷一個(gè)點(diǎn)落在哪個(gè)區(qū)域可以通過(guò)argmax操作實(shí)現(xiàn)(這實(shí)際上也同時(shí)隱含地確定了空間劃分方法,但解釋起來(lái)相對(duì)麻煩,故此文不展開(kāi))。

在二維平面中,如果一個(gè)點(diǎn)的坐標(biāo)是[x, y] (與上面例子中的x、y無(wú)關(guān)),我們可以把它擴(kuò)展成一個(gè)四維向量[x, y, -x, -y]。然后對(duì)這個(gè)向量使用argmax,也就是最大值對(duì)應(yīng)的索引。這個(gè)索引編號(hào)就是點(diǎn)[x, y]對(duì)應(yīng)的區(qū)域。

要證明這點(diǎn)只需要注意空間的劃分是依靠y=x和y=-x這兩條線實(shí)現(xiàn)的即可完成。

基于LSH,整個(gè)LSH-attention的計(jì)算可由下圖描述。

e25a634c-91f0-11ee-939d-92fbcf53809c.jpg

圖2.5.3 LSH-attention計(jì)算示意圖

圖2.5.3從上至下來(lái)解釋。

圖中第一行。在LSH-attention中,Query和Key是相同的,這和標(biāo)準(zhǔn)self-attention有所區(qū)別。

圖中第二行。使用LSH Hashing將token進(jìn)行分組,具有相同Hash值的token被分為同一組(相同顏色表示)。

圖中第三行。按照分組對(duì)token進(jìn)行重排序。同組中的token按照它們?cè)谠夹蛄兄械奈恢眠M(jìn)行排序:越靠后的排在越后面。

圖中第四行。按照固定長(zhǎng)度對(duì)重排后的序列進(jìn)行分塊。分塊的目的主要是為了并行化。

圖中第五行。每個(gè)token只能attend到同組(同顏色)之前的token。如果某個(gè)組被分成了多塊,那么后一塊中的token只能attend到前一個(gè)塊中同組的token(如果一個(gè)組被分成了三個(gè)塊,最后一個(gè)塊中的token不能attend到第一個(gè)塊中的token,即使它們是同組的)。

LSH-attention的優(yōu)勢(shì)在于它降低了每一個(gè)token可以attend到的token數(shù)量。原論文中沒(méi)有詳細(xì)分析為什么LSH-attention的復(fù)雜度是 。從LSH-attention的形式來(lái)看, 它的復(fù)雜度介于 和 之間。

LSH-attention主要解決了前文中提到標(biāo)準(zhǔn)self-attention開(kāi)銷(xiāo)中的第四點(diǎn): self-attention的計(jì)算中包含 的矩陣乘法計(jì)算, 它的計(jì)算復(fù)雜度和空間復(fù)雜度都是 。例如, 當(dāng)輸入序列長(zhǎng)度為 時(shí), 的輸出矩陣大小為 , 內(nèi)存消耗約 。

2.5.2 Reversible Transformer

因?yàn)榛诜聪騻鞑サ奶荻扔?jì)算需要用到網(wǎng)絡(luò)前向過(guò)程產(chǎn)生的激活值,所以在訓(xùn)練過(guò)程中必須將這些激活值存儲(chǔ)起來(lái)。對(duì)于較大的模型而言,這些激活值造成的存儲(chǔ)開(kāi)銷(xiāo)相當(dāng)巨大。

一種樸素的解決方案是利用checkpoint。在每次反向計(jì)算過(guò)程中,當(dāng)需要層i的激活值時(shí),使用上一次的checkpoint進(jìn)行一次前向計(jì)算,直到層i,然后取激活值。雖然基于checkpoint方法能在存儲(chǔ)不足時(shí)讓模型跑起來(lái),但增加了太多額外計(jì)算量。

另一種方法是讓網(wǎng)絡(luò)變得“可逆”。也就是說(shuō),我們可以由后一層的激活值來(lái)推出前一層的激活值?;谶@種方法的一個(gè)經(jīng)典工作是RevNet,它讓ResNet變得可逆。

Reversible Transformer基本照搬了RevNet的思想。在整個(gè)前向過(guò)程中,網(wǎng)絡(luò)始終處理兩個(gè)序列 和 :(2.5.1) (2.5.2) FeedForward

輸出 和 構(gòu)成下一層的輸入。對(duì)于網(wǎng)絡(luò)輸入層, 和 可由兩個(gè)線性層變換得到。對(duì)于任意一層, 當(dāng)知道它的輸出 和 時(shí), 利用公式 (2.5.2) 可以恢復(fù)出 :(2.5.3) FeedForward

代價(jià)是需要重新計(jì)算一次 FeedForward 。

當(dāng)恢復(fù)出 后, 可以用公式(2.5.1)再恢復(fù)出 :(2.5.4)

代價(jià)是需要重新計(jì)算一次 。

如果整個(gè)網(wǎng)絡(luò)使用的激活函數(shù)也是可逆的,那么在前向過(guò)程中不需要存儲(chǔ)任何激活值。

Reformer的論文中沒(méi)有講用的激活函數(shù)是什么。在一些開(kāi)源實(shí)現(xiàn)中有使用Gelu的,也有使用ReLU的。它們都不是可逆的激活函數(shù)。

所以對(duì)于這些激活函數(shù)而言,它們之前的輸入仍需要存儲(chǔ),因?yàn)閱慰考せ詈瘮?shù)的輸出無(wú)法恢復(fù)出輸入。

Reversible Transformer可以解決前文提到的第二點(diǎn)和第三點(diǎn):

假設(shè)輸入序列長(zhǎng)度為 , embedding大小為 , batch size為 8 , 那么單個(gè)self-attention激活值所占的存儲(chǔ)開(kāi)銷(xiāo)是 。同理, 每個(gè)激活值也是4Byte, 所以總的內(nèi)存開(kāi)銷(xiāo)也是2GB。

在標(biāo)準(zhǔn)Transformer結(jié)構(gòu)中,除了self-attention部分,在其后還有兩個(gè)全連層。兩個(gè)全連層的激活值的數(shù)量加起來(lái)通常遠(yuǎn)大于self-attention的激活值。例如,在標(biāo)準(zhǔn)Transformer中,第一個(gè)全連層的激活值數(shù)量是self-attention的四倍,第二個(gè)全連層與self-attention相同。那么兩個(gè)全連層總的激活值就是self-attention的五倍。按照上面第二點(diǎn)中的計(jì)算方法,兩個(gè)全連層總的存儲(chǔ)開(kāi)銷(xiāo)就是10GB。

2.6 Adaptive Attention

paper:Adaptive Attention Span in Transformers

Key contribution:提出了一種對(duì)不同attention head自適應(yīng)選擇attention長(zhǎng)度的方法。

在標(biāo)準(zhǔn)self-attention中,不同attention head的attention模式完全一樣,即每一個(gè)token能attend到它之前的token(包含自己)。Adaptive Attention的假設(shè)是:不同head可以具有不同的attention模式,比如有的head可能更關(guān)注較近的token,有些head可能會(huì)更注重遠(yuǎn)距離依賴,所以可以通過(guò)學(xué)習(xí)來(lái)讓不同head自適應(yīng)調(diào)整head可以attend到的token長(zhǎng)度。

這個(gè)思路與2.5.1節(jié)中介紹的LSH-attention的相似點(diǎn)是都在嘗試選擇部分token來(lái)attend,以減少參與計(jì)算的總token數(shù)。區(qū)別在于,adaptive attention選擇的是一個(gè)連續(xù)的子序列,而LSH-attention沒(méi)有這個(gè)要求,如圖2.6.1所示。

e2652a7a-91f0-11ee-939d-92fbcf53809c.png

圖2.6.1 LSH-attention v.s. Adaptive-attention

可以把a(bǔ)daptive attention理解為是在選一個(gè)“距離”:最遠(yuǎn)可attend到的token離當(dāng)前token的距離。只要這個(gè)距離確定了,那么可以被attend到的token就被確定了。

Adaptive attention的實(shí)現(xiàn)方法是為每一個(gè)attend到的token再加一個(gè)soft mask。下面詳細(xì)介紹。

在標(biāo)準(zhǔn)的self-attention中, 記當(dāng)前token為 , 它的Query和Key分別記為 和 。對(duì)于某個(gè)目標(biāo)token , 它對(duì) 的權(quán)重由如下公式計(jì)算得到:

Adaptive attention中,在計(jì)算公式(2.6.1) 中的權(quán)重時(shí),會(huì)為每一個(gè)位置再加上一個(gè)soft mask:

簡(jiǎn)單解釋一下。 是一個(gè)mask函數(shù), 它的輸入是“距離”, 輸出是一個(gè)0到1的值。例如公式 (2.6.2) 的分子中, 是當(dāng)前token與目標(biāo)token之間的距離。mask函數(shù) 根據(jù)這個(gè)距離計(jì)算出對(duì)應(yīng)的mask值。

我們說(shuō)這個(gè)mask是soft的, 是因?yàn)樗妮敵霾⒉皇?或1, 而是0到1, 這與平時(shí)的hard mask有所區(qū)別。

在Adaptive attention中, mask函數(shù) 定義為:

其中 是超參, 是需要學(xué)習(xí)的參數(shù)。公式很不優(yōu)雅, 可以借助mask函數(shù) 的圖像來(lái)理解:

e26a2eee-91f0-11ee-939d-92fbcf53809c.jpg

圖2.6.1 mask函數(shù)的圖像

模型會(huì)自動(dòng)學(xué)習(xí)到一個(gè)合適的 , 這個(gè) 可理解為“有效token的距離”。與當(dāng)前token距離 的所有token的soft mask值都為 1 , 表明它們都是有效的token。

超參數(shù) 表示一個(gè)“soft距離"。在距離 至 的范圍內(nèi), soft mask由1線性衰減到0, 表示它們的重要性逐漸降低。超過(guò) 距離的token的soft mask值為 0 , 表示它們?yōu)闊o(wú)效token。

對(duì)于multi-head attention中的每個(gè)head, 都需要單獨(dú)為它訓(xùn)練一個(gè) 。所以不同head可以attend 到的token距離也就不同。

Adaptive attention的最核心思路就是這樣。下面再介紹兩個(gè)其它細(xì)節(jié)。

首先,Adaptive attention采用的是相對(duì)位置編碼的方法,所以公式(2.6.2)需要更改為:

(2.6.3)

其中 表示與當(dāng)前token距離為 的位置編碼, 它是相對(duì)的, 且直接靠學(xué)習(xí)得到。相對(duì)位置編碼雖然簡(jiǎn)單, 但在后續(xù)的很多其它改進(jìn)版的Transformer結(jié)構(gòu)中應(yīng)用非常廣泛。

其次, Adaptive attention還提出了一種更復(fù)雜的參數(shù)化 的方法。在公式 (2.6.3) 中, 直接作為一個(gè)可學(xué)習(xí)參數(shù)參與訓(xùn)練, 由模型直接優(yōu)化。對(duì) 的一種新的參數(shù)化方法如下:

其中 表示當(dāng)前token; 表示最遠(yuǎn)可attend到的token距離, 可設(shè)置為一個(gè)期望的值, 或直接讓它等于 和 是可學(xué)習(xí)的參數(shù); 是sigmoid函數(shù)。

公式 (2.6.4) 的含義是為每一個(gè)當(dāng)前token單獨(dú)計(jì)算它的 , 而不僅僅是為每一個(gè)head計(jì)算一個(gè)

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 矩陣
    +關(guān)注

    關(guān)注

    0

    文章

    429

    瀏覽量

    35006
  • 線性
    +關(guān)注

    關(guān)注

    0

    文章

    200

    瀏覽量

    25505
  • Transformer
    +關(guān)注

    關(guān)注

    0

    文章

    148

    瀏覽量

    6376

原文標(biāo)題:降低Transformer復(fù)雜度O(N^2)的方法匯總

文章出處:【微信號(hào):vision263com,微信公眾號(hào):新機(jī)器視覺(jué)】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦
    熱點(diǎn)推薦

    業(yè)務(wù)復(fù)雜度治理方法論--十年系統(tǒng)設(shè)計(jì)經(jīng)驗(yàn)總結(jié)

    一、復(fù)雜度綜述 1、什么是復(fù)雜度 軟件設(shè)計(jì)的核心在于降低復(fù)雜性。 --《軟件設(shè)計(jì)的哲學(xué)》 業(yè)界對(duì)于復(fù)雜度并沒(méi)有統(tǒng)一的定義, 斯坦福教授Joh
    的頭像 發(fā)表于 09-05 14:11 ?1230次閱讀
    業(yè)務(wù)<b class='flag-5'>復(fù)雜度</b>治理<b class='flag-5'>方法</b>論--十年系統(tǒng)設(shè)計(jì)經(jīng)驗(yàn)總結(jié)

    時(shí)間復(fù)雜度O(n^2) 的排序算法

    作者:京東保險(xiǎn) 王奕龍 對(duì)于小規(guī)模數(shù)據(jù),我們可以選用時(shí)間復(fù)雜度O(n2) 的排序算法。因?yàn)闀r(shí)間復(fù)雜度并不代表實(shí)際代碼的執(zhí)行時(shí)間,它省去了低階、系數(shù)和常數(shù),僅代表的增長(zhǎng)趨勢(shì),所以在小
    的頭像 發(fā)表于 10-19 16:31 ?1605次閱讀
    時(shí)間<b class='flag-5'>復(fù)雜度</b>為 <b class='flag-5'>O</b>(<b class='flag-5'>n</b>^<b class='flag-5'>2</b>) 的排序算法

    JEM軟件復(fù)雜度的增加情況

    這篇文檔展示了幾個(gè)機(jī)構(gòu)關(guān)于JEM軟件復(fù)雜度的增加情況的看法,特別提出來(lái)創(chuàng)立一個(gè)新的Ad-hoc組,研究降低軟件一般性復(fù)雜度的可能方法。
    發(fā)表于 07-19 08:25

    如何降低LMS算法的計(jì)算復(fù)雜度,加快程序在DSP上運(yùn)行的速度,實(shí)現(xiàn)DSP?

    基于線性預(yù)測(cè)的FIR自適應(yīng)語(yǔ)音濾波器的系統(tǒng)結(jié)構(gòu)由那幾部分組成?如何降低LMS算法的計(jì)算復(fù)雜度,加快程序在DSP上運(yùn)行的速度,實(shí)現(xiàn)DSP?
    發(fā)表于 04-12 06:27

    時(shí)間復(fù)雜度是指什么

    原理->微機(jī)原理->軟件工程,編譯原理,數(shù)據(jù)庫(kù)數(shù)據(jù)結(jié)構(gòu)1.時(shí)間復(fù)雜度時(shí)間復(fù)雜度是指執(zhí)行算法所需要的計(jì)算工作量,因?yàn)檎麄€(gè)算法的執(zhí)行時(shí)間與基本操作重復(fù)執(zhí)行的...
    發(fā)表于 07-22 10:01

    各種排序算法的時(shí)間空間復(fù)雜度、穩(wěn)定性

    各種排序算法的時(shí)間空間復(fù)雜度、穩(wěn)定性一、排序算法分類:二、排序算法比較:注:1、歸并排序可以通過(guò)手搖算法將空間復(fù)雜度降到O(1),但是時(shí)間復(fù)雜度會(huì)提高。
    發(fā)表于 12-21 07:48

    降低高條件數(shù)信道下的球形譯碼算法復(fù)雜度方法

    MIMO 系統(tǒng)中,球形譯碼可以在保證接近ML 檢測(cè)性能的前提下大大降低檢測(cè)復(fù)雜度。但當(dāng)信道矩陣條件數(shù)很高時(shí),球形譯碼的復(fù)雜度仍然會(huì)很高。在分析了這一現(xiàn)象的原因后,本文提出
    發(fā)表于 11-21 13:52 ?8次下載

    圖像復(fù)雜度對(duì)信息隱藏性能影響分析

    針對(duì)信息隱藏中載體圖像的差異性,提出一種圖像復(fù)雜度評(píng)價(jià)方法,綜合考慮圖像的壓縮特性以及圖像紋理能量作為圖像復(fù)雜度指標(biāo),并基于閾值劃分準(zhǔn)則對(duì)栽體圖像進(jìn)行復(fù)雜度分類,以幾種經(jīng)典的基于直方圖
    發(fā)表于 11-14 09:57 ?5次下載

    Transformer復(fù)雜度和高效設(shè)計(jì)及Transformer的應(yīng)用

    幫助。 本文涉及25篇Transformer相關(guān)的文章,對(duì)原文感興趣的讀者可以關(guān)注公眾號(hào)回復(fù): ACL2021Transformers,下載本文所涉及的所有文章~本文主要內(nèi)容: 前言 ACL 2021中
    的頭像 發(fā)表于 09-01 09:27 ?6812次閱讀
    <b class='flag-5'>Transformer</b>的<b class='flag-5'>復(fù)雜度</b>和高效設(shè)計(jì)及<b class='flag-5'>Transformer</b>的應(yīng)用

    如何求遞歸算法的時(shí)間復(fù)雜度

    那么我通過(guò)一道簡(jiǎn)單的面試題,模擬面試的場(chǎng)景,來(lái)帶大家逐步分析遞歸算法的時(shí)間復(fù)雜度,最后找出最優(yōu)解,來(lái)看看同樣是遞歸,怎么就寫(xiě)成了O(n)的代碼。
    的頭像 發(fā)表于 07-13 11:30 ?2450次閱讀

    算法時(shí)空復(fù)雜度分析實(shí)用指南2

    類似的,想想之前說(shuō)的數(shù)據(jù)結(jié)構(gòu)擴(kuò)容的場(chǎng)景,也許`N`次操作中的某一次操作恰好觸發(fā)了擴(kuò)容,導(dǎo)致時(shí)間復(fù)雜度提高,但總的時(shí)間復(fù)雜度依然保持在`O(N
    的頭像 發(fā)表于 04-12 14:38 ?662次閱讀
    算法時(shí)空<b class='flag-5'>復(fù)雜度</b>分析實(shí)用指南<b class='flag-5'>2</b>

    算法時(shí)空復(fù)雜度分析實(shí)用指南(上)

    本文會(huì)篇幅較長(zhǎng),會(huì)涵蓋如下幾點(diǎn): 1、Big O 表示法的幾個(gè)基本特點(diǎn)。 2、非遞歸算法中的時(shí)間復(fù)雜度分析。 3、數(shù)據(jù)結(jié)構(gòu) API 的效率衡量方法(攤還分析)。
    的頭像 發(fā)表于 04-19 10:34 ?1026次閱讀
    算法時(shí)空<b class='flag-5'>復(fù)雜度</b>分析實(shí)用指南(上)

    算法時(shí)空復(fù)雜度分析實(shí)用指南(下)

    Big O 表示法的幾個(gè)基本特點(diǎn)。 2、非遞歸算法中的時(shí)間復(fù)雜度分析。 3、數(shù)據(jù)結(jié)構(gòu) API 的效率衡量方法(攤還分析)。 4、遞歸算法的時(shí)間/空間
    的頭像 發(fā)表于 04-19 10:35 ?886次閱讀
    算法時(shí)空<b class='flag-5'>復(fù)雜度</b>分析實(shí)用指南(下)

    如何計(jì)算時(shí)間復(fù)雜度

    來(lái)完成,那么該算法的用處就不會(huì)太大。同樣如果該算法需要若干個(gè)GB的內(nèi)存,那么在大部分機(jī)器上都無(wú)法使用。 一個(gè)算法的評(píng)價(jià)主要從時(shí)間復(fù)雜度和空間復(fù)雜度來(lái)考慮。 而時(shí)間復(fù)雜度是一個(gè)函數(shù),定性描述該算法的運(yùn)行時(shí)間,通常用大
    的頭像 發(fā)表于 10-13 11:19 ?3614次閱讀
    如何計(jì)算時(shí)間<b class='flag-5'>復(fù)雜度</b>

    如何降低SigmaDSP音頻系統(tǒng)復(fù)雜度的情形

    電子發(fā)燒友網(wǎng)站提供《如何降低SigmaDSP音頻系統(tǒng)復(fù)雜度的情形.pdf》資料免費(fèi)下載
    發(fā)表于 11-29 11:13 ?0次下載
    如何<b class='flag-5'>降低</b>SigmaDSP音頻系統(tǒng)<b class='flag-5'>復(fù)雜度</b>的情形