最近,有一篇入門文章引發(fā)了不少關(guān)注。文章中詳細介紹了循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN),及其變體長短期記憶(LSTM)背后的原理。
具體內(nèi)容,從前饋網(wǎng)絡(luò)(Feedforward Networks)開始講起,先后講述了循環(huán)神經(jīng)網(wǎng)絡(luò)、時序反向傳播算法(BPTT)、LSTM等模型的原理與運作方式。
這篇文章來自Skymind,一家推動數(shù)據(jù)項目從原型到落地的公司。獲得了YCombinator、騰訊等的投資。
對于人工智能初學(xué)者來說,是一份非常不錯的入門資料。
循環(huán)網(wǎng)絡(luò),是一種人工神經(jīng)網(wǎng)絡(luò)(ANN),用來識別數(shù)據(jù)序列中的模式。
比如文本、基因組、筆記、口語或來自傳感器、股票市場和政府機構(gòu)的時間序列數(shù)據(jù)。
它的算法考慮了時間和順序,具有時間維度。
研究表明,RNN是最強大和最有用的神經(jīng)網(wǎng)絡(luò)之一,它甚至能夠適用于圖像處理。
把圖像分割成一系列的補丁,可以視為一個序列。
但是,想要理解循環(huán)網(wǎng)絡(luò),首先要必須了解前饋網(wǎng)絡(luò)的基本知識。
前饋網(wǎng)絡(luò)回顧
前饋網(wǎng)絡(luò)和循環(huán)網(wǎng)絡(luò)的命名,來自于它們在傳遞信息時,在網(wǎng)絡(luò)節(jié)點上執(zhí)行的一系列數(shù)學(xué)運算的方式。
前饋網(wǎng)絡(luò)直接向前遞送信息(不會再次接觸已經(jīng)經(jīng)過的節(jié)點),而循環(huán)網(wǎng)絡(luò)則是通過循環(huán)傳遞信息。
前饋網(wǎng)絡(luò)中的樣例,輸入網(wǎng)絡(luò)后被轉(zhuǎn)換成輸出;在監(jiān)督學(xué)習(xí)中,輸出將是一個標簽,一個應(yīng)用于輸入的名稱。
也就是說,前饋網(wǎng)絡(luò)將原始數(shù)據(jù)映射到類別,識別出信號的模式。例如,輸入圖像應(yīng)該被標記為“貓”還是“大象”。
前饋網(wǎng)絡(luò)根據(jù)標記的圖像進行訓(xùn)練,直到猜測圖像類別時產(chǎn)生的錯誤最小化。 通過一組經(jīng)過訓(xùn)練的參數(shù)(或者稱為權(quán)重,統(tǒng)稱為模型) ,網(wǎng)絡(luò)就可以對它從未見過的數(shù)據(jù)進行分類了。
一個訓(xùn)練好的前饋網(wǎng)絡(luò)可以應(yīng)用在任何隨機的照片數(shù)據(jù)集中,它識別的第一張照片,并不會影響它對第二張照片的預(yù)測。
看到一只貓的照片之后,不會導(dǎo)致網(wǎng)絡(luò)預(yù)下一張圖是大象。
也就是說,前饋網(wǎng)絡(luò)沒有時間順序的概念,它考慮的唯一輸入就是它所接觸到的當(dāng)前的輸入樣例。
循環(huán)網(wǎng)絡(luò)
與前饋網(wǎng)絡(luò)相比,循環(huán)網(wǎng)絡(luò)的輸入不僅包括當(dāng)前的輸入樣例,還包括之前的輸入信息。
下面是美國加州大學(xué)圣地亞哥分校教授Jeffrey Elman提出的一個早期的簡單循環(huán)網(wǎng)絡(luò)的示意圖。
圖底部的BTSXPE代表當(dāng)前時刻的輸入樣例,而CONTEXT UNIT代表前一時刻的輸出。
循環(huán)網(wǎng)絡(luò)在t-1個時間步的判定,會影響隨后在t時間步的判定。所以,循環(huán)網(wǎng)絡(luò)有兩個輸入源,現(xiàn)在和最近的過去,它們結(jié)合起來決定對新數(shù)據(jù)的反應(yīng),就像我們在生活中一樣。
循環(huán)網(wǎng)絡(luò)與前饋網(wǎng)絡(luò)的區(qū)別在于,循環(huán)網(wǎng)絡(luò)的反饋循環(huán)會連接到它們過去的判定,將自己的輸出作為輸入。
循環(huán)網(wǎng)絡(luò)是有記憶的。給神經(jīng)網(wǎng)絡(luò)增加記憶的目的在于:序列本身帶有信息,循環(huán)網(wǎng)絡(luò)用它來執(zhí)行前饋網(wǎng)絡(luò)不能執(zhí)行的任務(wù)。
這些連續(xù)的信息被保存在循環(huán)網(wǎng)絡(luò)的隱藏狀態(tài)中,這種隱藏狀態(tài)管理跨越多個時間步,并一層一層地向前傳遞,影響網(wǎng)絡(luò)對每一個新樣例的處理。
循環(huán)網(wǎng)絡(luò),需要尋找被許多時刻分開的各種事件之間的相關(guān)性,這些相關(guān)性被稱為“長距離依賴”,因為時間下游的事件依賴于之前的一個或多個事件,并且是這些事件的函數(shù)。
因此,你可以將RNN理解為是一種跨時間分享權(quán)重的方式。
正如人類的記憶在身體內(nèi)無形地循環(huán),影響我們的行為但不暴露全貌一樣,信息也在循環(huán)網(wǎng)絡(luò)的隱藏狀態(tài)中循環(huán)。
用數(shù)學(xué)的方式來描述記憶傳遞的過程是這樣的:
t代表時間步,ht代表第t個時間步的隱藏狀態(tài),是同一個時間步xt的輸入函數(shù)。W是權(quán)重函數(shù),用于修正xt。
U是隱藏狀態(tài)矩陣,也被稱為轉(zhuǎn)移矩陣,類似于馬爾可夫鏈。ht-1代表t的上一個時間步t-1的隱藏狀態(tài)。
權(quán)重矩陣,是決定當(dāng)前輸入和過去隱藏狀態(tài)的重要程度的過濾器。 它們產(chǎn)生的誤差會通過反向傳播返回,并用于調(diào)整相應(yīng)的權(quán)重,直到誤差不再降低。
權(quán)重輸入(Wxt)和隱藏狀態(tài)(Uht-1)的總和被函數(shù)φ壓縮,可能是邏輯S形函數(shù)或者是雙曲正切(tanh)函數(shù),視情況而定。
這是一個標準工具,用于將非常大或非常小的值壓縮到邏輯空間中,并使梯度可用于反向傳播。
因為這個反饋循環(huán)發(fā)生在序列中的每個時間步中,每個隱藏狀態(tài)不僅跟蹤前一個隱藏狀態(tài),只要記憶能夠持續(xù)存在,它會還包含h_t-1之前的所有的隱藏狀態(tài)。
給定一系列字母,循環(huán)網(wǎng)絡(luò)將使用第一個字符來幫助確定它對第二個字符的感知,比如,首字母是q,可能會導(dǎo)致它推斷下一個字母是u,而首字母是t,可能會導(dǎo)致它推斷下一個字母是h。
由于循環(huán)網(wǎng)絡(luò)跨越時間,用動畫來說明可能會更好。(可以將第一個垂直節(jié)點看作是一個前饋網(wǎng)絡(luò),隨著時間的推移,它會變成循環(huán)網(wǎng)絡(luò))。
在上圖中,每個x是一個輸入樣例,w是過濾輸入的權(quán)重,a是隱藏層的激活(加權(quán)輸入和先前隱藏狀態(tài)的和),b是隱藏層使用修正線性或sigmoid單元轉(zhuǎn)換或壓縮后的輸出。
時序反向傳播算法(BPTT)
循環(huán)網(wǎng)絡(luò)的目的是準確地對序列輸入進行分類。主要依靠誤差的反向傳播和梯度下降法來做到這一點。
前饋網(wǎng)絡(luò)中的反向傳播從最后的誤差開始,經(jīng)過每個隱藏層的輸出、權(quán)重和輸入反向移動,將一定比例的誤差分配給每個權(quán)重,方法是計算它們的偏導(dǎo)數(shù)?e/?w,或它們之間的變化率之間的關(guān)系。
隨后,這些偏導(dǎo)數(shù)會被用到梯度下降算法中,來調(diào)整權(quán)重減少誤差。
而循環(huán)網(wǎng)絡(luò)依賴于反向傳播的一種擴展,稱為時序反向傳播算法,即BPTT。
在這種情況下,時間通過一系列定義明確、有序的計算來表達,這些計算將一個時間步與下一個時間步聯(lián)系起來。
神經(jīng)網(wǎng)絡(luò),無論是循環(huán)的還是非循環(huán)的,都是簡單的嵌套復(fù)合函數(shù),比如f(g(h(x))。添加時間元素,只是擴展了我們用鏈式法則計算導(dǎo)數(shù)的函數(shù)序列。
截斷式BPTT
截斷式BPTT(Truncated BPTT)是完整BPTT的近似方法,是處理是長序列的首選。
在時間步較多的序列中,完整BPTT的每個參數(shù)更新的正向/反向運算成本變得非常高。
截斷式BPTT的缺點是,由于截斷,梯度反向移動的距離有限,因此網(wǎng)絡(luò)無法學(xué)習(xí)與完整BPTT一樣長的依賴。
梯度消失和梯度爆炸
和大多數(shù)神經(jīng)網(wǎng)絡(luò)一樣,循環(huán)網(wǎng)絡(luò)也有了一定的歷史。 到1990年代初,梯度消失問題成為影響網(wǎng)絡(luò)性能的主要障礙。
就像直線表示x的變化和y的變化一樣,梯度表示所有權(quán)重隨誤差變化的變化。如果我們不知道梯度,我們就不能在減少誤差的方向上調(diào)整權(quán)重,網(wǎng)絡(luò)也就會停止學(xué)習(xí)。
循環(huán)網(wǎng)絡(luò),在最終的輸入和之前許多時間步之間建立聯(lián)系時,也遇到了問題。因為很難知道一個遠距離的輸入有多么重要。
就像向前追溯曾曾曾曾曾……祖父母兄弟的數(shù)量一樣,會越來越多,越來越多。
這在一定程度上是因為,通過神經(jīng)網(wǎng)絡(luò)傳遞的信息要經(jīng)過多個乘法階段。
每個研究過復(fù)利的人都知道,任何數(shù)量循環(huán)乘以略大于一的量,都會變得不可估量的大(實際上,簡單的數(shù)學(xué)真理支撐著網(wǎng)絡(luò)效應(yīng)和社會不平等)。
反過來,乘以小于1的量,也會變得非常非常小。如果賭徒們每投入一美元,只能贏得97美分,那么他們很快就會破產(chǎn)。
由于深度神經(jīng)網(wǎng)絡(luò)的層和時間步通過乘法相互關(guān)聯(lián),導(dǎo)數(shù)很容易消失或爆炸。
梯度爆炸時,每一個權(quán)重就像諺語中的蝴蝶一樣,它拍打的翅膀會引起遠處的颶風(fēng)。
但是梯度爆炸解決起來相對容易,因為它們可以被截斷或壓縮。
梯度消失正好相反,是導(dǎo)數(shù)變得非常小,使計算機無法工作,網(wǎng)絡(luò)也無法學(xué)習(xí)。這是一個更難解決的問題。
下面你可以看到一遍又一遍應(yīng)用S形函數(shù)的效果。 數(shù)據(jù)曲線越來越平緩,直至在較長的距離上無法檢測到斜率。 這類似于通過許多層的梯度消失。
長短期記憶(LSTM)
在90年代中期,德國研究人員Sepp Hochreiter和Juergen Schmidhuber提出了一種具有長短期記憶單元( LSTM )的循環(huán)網(wǎng)絡(luò)變體,作為梯度消失問題的解決方案。
LSTM有助于保留可以通過時間和層進行反向傳播的誤差。
通過保留一個更為恒定的誤差,它們使循環(huán)網(wǎng)絡(luò)能夠在有許多時間步(超過1000步)的情況下繼續(xù)學(xué)習(xí),從而打開一個遠程鏈接因果關(guān)系的通道。
這是機器學(xué)習(xí)和人工智能面臨的主要挑戰(zhàn)之一,因為算法經(jīng)常遇到獎勵信號稀疏和延遲的環(huán)境。
LSTM將信息存放在循環(huán)網(wǎng)絡(luò)正常信息流之外的門控單元中。信息可以像計算機內(nèi)存中的數(shù)據(jù)一樣存儲、寫入單元,或者從單元中讀取。
單元通過打開和關(guān)閉的門來決定存儲什么,以及何時允許讀取、寫入和忘記。
但與計算機上的數(shù)字存儲器不同,這些門是模擬的,通過范圍在0~1之間的sigmoid函數(shù)的逐元素相乘來實現(xiàn)。
與數(shù)字信號相比,模擬信號的優(yōu)勢是可微分,因此適用于反向傳播。
這些門類似于神經(jīng)網(wǎng)絡(luò)的節(jié)點,會根據(jù)它們接收到的信號決定開關(guān),它們根據(jù)信息的強度和重要性來阻止或傳遞信息,然后用它們自己的權(quán)重過濾這些信息。
這些權(quán)重,就像調(diào)整輸入和隱藏狀態(tài)的權(quán)重一樣,可以在循環(huán)網(wǎng)絡(luò)學(xué)習(xí)過程中進行調(diào)整。
也就是說,記憶單元學(xué)習(xí)會通過猜測、反向傳播誤差和梯度下降法調(diào)整權(quán)重的迭代過程,來決定何時允許數(shù)據(jù)進入、離開或刪除。
下圖說明了數(shù)據(jù)如何通過記憶單元,以及門如何控制數(shù)據(jù)流動。
如果你剛剛接觸LSTM,不要著急,仔細研究一下。只需要幾分鐘,就能揭開其中的秘密。
從底部開始,三個箭頭顯示,信息由多個點流入記憶單元。 當(dāng)前輸入和過去單元狀態(tài)的組合不僅反饋到單元本身,而且反饋到它的三個門中的每一個,這將決定它們?nèi)绾翁幚磔斎搿?/p>
黑點是門本身,決定是否讓新的輸入進入、遺忘當(dāng)前的狀態(tài),還是讓這一狀態(tài)在當(dāng)前時間步影響網(wǎng)絡(luò)的輸出。
Sc是記憶單元的當(dāng)前狀態(tài),g_y_in是記憶單元的當(dāng)前輸入。
請記住,每個門都可以打開或關(guān)閉,它們會在每一步重新組合它們的打開和關(guān)閉狀態(tài)。記憶單元,在每個時間步都可以決定,是否遺忘、寫入、讀取它的狀態(tài),這些流都表示出來了。
大的、加粗的字母,給出了每個操作的結(jié)果。
下面是另一個示意圖,對比了簡單的循環(huán)網(wǎng)絡(luò)(左)和 LSTM 單元(右)。
值得注意的是,LSTM的記憶單元在輸入轉(zhuǎn)換中賦予加法和乘法不同的角色。
兩個圖中的中心加號,本質(zhì)上就是 LSTM 的秘密。
雖然這看起來非常非常簡單,但當(dāng)必須在深度上反向傳播時,這種變化有助于保持恒定的誤差。
LSTM不是將當(dāng)前狀態(tài)乘以新的輸入來確定后續(xù)的單元狀態(tài),而是將兩者相加,這就產(chǎn)生了差異。 (用于遺忘的門仍然依賴于乘法。)
不同的權(quán)重集對輸入信息進行篩選,決定是否輸入、輸出或遺忘。
不同的權(quán)重集對輸入信息進行過濾,決定是否輸出或遺忘。遺忘門被表示為一個線性恒等式函數(shù),因為如果門是打開的,那么記憶單元的當(dāng)前狀態(tài)就會被簡單地乘以1,從而向前傳播一個時間步。
此外,有一個簡單的竅門。將每個LSTM記憶單元遺忘門的偏差設(shè)定為1,可以提升網(wǎng)絡(luò)性能。(但另一方面,Sutskever建議將偏差設(shè)定為5。)
你可能會問,LSTM的目的是將遠距離事件與最終的輸出聯(lián)系起來,為什么它們會有一個遺忘門?
好吧,有時候遺忘是件好事。
如果分析一個文本語料庫,在到達一個文檔的末尾時,下一個文檔基本上跟它沒有關(guān)系,因此,在網(wǎng)絡(luò)攝取下一個文檔的第一個元素之前,應(yīng)該將記憶單元設(shè)置為零。
以分析一個文本語料庫為例,在到達文檔的末尾時,你可能會認為下一個文檔與這個文檔肯定沒有任何聯(lián)系,所以記憶單元在開始吸收下一個文檔的第一項元素前應(yīng)當(dāng)先歸零。
在下圖中,你可以看到在工作的門,直線表示關(guān)閉的門,空白圓圈代表打開的門。沿著隱藏層水平延伸的線條和圓圈是表示遺忘門。
需要注意的是,前饋網(wǎng)絡(luò)只是一對一,即將一個輸入映射到一個輸出。但循環(huán)網(wǎng)絡(luò)可以一對多,多對多,多對一。
涵蓋不同時間尺度和遠距離依賴
你可能還想知道,保護記憶單元不受新數(shù)據(jù)進入的輸入門和防止它影響 RNN 的某些輸出的輸出門的精確值是多少。你可以把 LSTM 看作是,允許一個神經(jīng)網(wǎng)絡(luò)同時在不同的時間尺度上運行。
讓我們以一個人的生命為例,想象一下我們在一個時間序列中收到了關(guān)于那個生命的各種數(shù)據(jù)流。
每個時間步的地理位置,對于下一個時間步來說都非常重要,因此時間尺度總是對最新信息開放的。
也許這個人是一個勤奮的公民,每兩年投票一次。在民主時代,我們會特別關(guān)注他們在選舉前后的所作所為。我們不想讓地理位置持續(xù)產(chǎn)生噪音影響我們的政治分析。
如果這個人也是一個勤奮的女兒,那么也許我們可以構(gòu)建一個家庭時間,學(xué)習(xí)每周日定期打電話的模式,每年假期前后,打電話的數(shù)量都會激增。這與政治周期或地理位置無關(guān)。
其他的數(shù)據(jù)也是這樣。音樂是多節(jié)奏的。文本中包含不同時間間隔的重復(fù)主題。股票市場和經(jīng)濟會有更長的波動周期。它們在不同的時間尺度上同時運行,LSTM可以捕捉到這些時間尺度。
門控循環(huán)單元(GRU)
門控循環(huán)單元( GRU )基本上是沒有輸出門的LSTM,因此在每個時間步,它都將內(nèi)容從其記憶單元完全寫入到較大的網(wǎng)絡(luò)中。
代碼示例
這里示例,是一個LSTM如何學(xué)習(xí)復(fù)制莎士比亞戲劇的評論,使用Deeplearning4j實現(xiàn)。在難以理解的地方,都有相應(yīng)的注釋。
傳送門:
https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/character/LSTMCharModellingExample.java
LSTM超參數(shù)調(diào)整
以下是手動優(yōu)化RNN超參數(shù)時需要注意的一些情況:
小心過擬合,神經(jīng)網(wǎng)絡(luò)基本在“記憶”訓(xùn)練數(shù)據(jù)時,就會發(fā)生過擬合。過擬合意味著你在訓(xùn)練數(shù)據(jù)上有很好的表現(xiàn),在其他數(shù)據(jù)集上基本無用。
正則化有好處:方法包括 l1、 l2和dropout等。
要有一個單獨的測試集,不要在這個測試集上訓(xùn)練網(wǎng)絡(luò)。
網(wǎng)絡(luò)越大,功能就越強,但也更容易過擬合。 不要試圖從10000個示例中學(xué)習(xí)一百萬個參數(shù),參數(shù)》樣例=麻煩。
數(shù)據(jù)越多越好,因為它有助于防止過度擬合。
訓(xùn)練要經(jīng)過多個epoch(算法遍歷訓(xùn)練數(shù)據(jù)集)。
每個epoch之后,評估測試集表現(xiàn),以了解何時停止(要提前停止)。
學(xué)習(xí)速率是最重要的超參數(shù)。
總體而言,堆疊層會有幫助。
對于LSTM,可以使用softsign(而不是softmax)函數(shù)替代雙曲正切函數(shù),它更快,更不容易飽和( 梯度大概為0 )。
更新器:RMSProp、AdaGrad或Nesterovs通常是不錯的選擇。AdaGrad也會降低學(xué)習(xí)率,這有時會有所幫助。
記住,要將數(shù)據(jù)標準化、MSE損失函數(shù)+恒等激活函數(shù)用于回歸、Xavier權(quán)重初始化。
評論