TensorFlow有了替代品,竟然還是谷歌自己做出來的?這其實是TensorFlow的一個簡化庫,名為JAX,可以支持部分TensorFlow的功能,但是比TensorFlow更加簡潔易用。
什么?TensorFlow 有了替代品?什么?竟然還是谷歌自己做出來的?先別慌,從各種意義上來說,這個所謂的 “替代品” 其實是 TensorFlow 的一個簡化庫,名為JAX,結(jié)合 Autograd 和 XLA,可以支持部分 TensorFlow 的功能,但是比 TensorFlow 更加簡潔易用。
雖然還不至于替代 TensorFlow,但已經(jīng)有 Reddit 網(wǎng)友對 JAX 寄予厚望,并表示“早就期待能有一個可以直接調(diào)用 Numpy API 接口的庫了!”,“希望它可以取代 TensorFlow!”。
JAX 結(jié)合了 Autograd 和 XLA,是專為高性能機器學(xué)習(xí)研究打造的產(chǎn)品。
有了新版本的Autograd,JAX 能夠自動對 Python 和 NumPy 的自帶函數(shù)求導(dǎo),支持循環(huán)、分支、遞歸、閉包函數(shù)求導(dǎo),而且可以求三階導(dǎo)數(shù)。它支持自動模式反向求導(dǎo)(也就是反向傳播)和正向求導(dǎo),且二者可以任意組合成任何順序。
JAX 的創(chuàng)新之處在于,它基于XLA在 GPU 和 TPU 上編譯和運行 NumPy 程序。默認(rèn)情況下,編譯是在底層進行的,庫調(diào)用能夠及時編譯和執(zhí)行。但是 JAX 還允許使用單一函數(shù) API jit將自己的 Python 函數(shù)及時編譯成經(jīng)過 XLA 優(yōu)化的內(nèi)核。編譯和自動求導(dǎo)可以任意組合,因此可以在不脫離 Python 環(huán)境的情況下實現(xiàn)復(fù)雜算法并獲得最優(yōu)性能。
JAX 最初由 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 發(fā)起,他們均任職于谷歌大腦團隊。在 GitHub 的說明文檔中,作者明確表示:JAX 目前還只是一個研究項目,不是谷歌的官方產(chǎn)品,因此可能會有一些 bug。從作者的 GitHub 簡介來看,這應(yīng)該是谷歌大腦正在嘗試的新項目,在同一個 GitHub 目錄下的開源項目還包括 8 月份在業(yè)內(nèi)引起熱議的強化學(xué)習(xí)框架 Dopamine。
以下是 JAX 的簡單使用示例。
GitHub 項目傳送門:https://github.com/google/JAX
有關(guān)具體的安裝和簡單的入門指導(dǎo)大家可以在 GitHub 中自行查看,在此不做過多贅述。
JAX 庫的實現(xiàn)原理
機器學(xué)習(xí)中的編程是關(guān)于函數(shù)的表達和轉(zhuǎn)換。轉(zhuǎn)換包括自動微分、加速器編譯和自動批處理。像 Python 這樣的高級語言非常適合表達函數(shù),但是通常使用者只能應(yīng)用它們。我們無法訪問它們的內(nèi)部結(jié)構(gòu),因此無法執(zhí)行轉(zhuǎn)換。
JAX 可以用于專門化高級Python+NumPy函數(shù),并將其轉(zhuǎn)換為可轉(zhuǎn)換的表示形式,然后再提升為 Python 函數(shù)。
JAX 通過跟蹤專門處理 Python 函數(shù)。跟蹤一個函數(shù)意味著:監(jiān)視應(yīng)用于其輸入,以產(chǎn)生其輸出的所有基本操作,并在有向無環(huán)圖 (DAG) 中記錄這些操作及其之間的數(shù)據(jù)流。為了執(zhí)行跟蹤,JAX 包裝了基本的操作,就像基本的數(shù)字內(nèi)核一樣,這樣一來,當(dāng)調(diào)用它們時,它們就會將自己添加到執(zhí)行的操作列表以及輸入和輸出中。為了跟蹤這些原語之間的數(shù)據(jù)流,跟蹤的值被包裝在 Tracer 類的實例中。
當(dāng) Python 函數(shù)被提供給 grad 或 jit 時,它被包裝起來以便跟蹤并返回。當(dāng)調(diào)用包裝的函數(shù)時,我們將提供的具體參數(shù)抽象到 AbstractValue 類的實例中,將它們框起來用于跟蹤跟蹤器類的實例,并對它們調(diào)用函數(shù)。
抽象參數(shù)表示一組可能的值,而不是特定的值:例如,jit 將 ndarray 參數(shù)抽象為抽象值,這些值表示具有相同形狀和數(shù)據(jù)類型的所有 ndarray。相反,grad 抽象 ndarray 參數(shù)來表示底層值的無窮小鄰域。通過在這些抽象值上跟蹤 Python 函數(shù),我們確保它足夠?qū)iT化,以便轉(zhuǎn)換是可處理的,并且它仍然足夠通用,以便轉(zhuǎn)換后的結(jié)果是有用的,并且可能是可重用的。然后將這些轉(zhuǎn)換后的函數(shù)提升回 Python 可調(diào)用函數(shù),這樣就可以根據(jù)需要跟蹤并再次轉(zhuǎn)換它們。
JAX 跟蹤的基本函數(shù)大多與 XLA HLO 1:1 對應(yīng),并在 lax.py 中定義。這種 1:1 的對應(yīng)關(guān)系使得到 XLA 的大多數(shù)轉(zhuǎn)換基本上都很簡單,并且確保我們只有一小組原語來覆蓋其他轉(zhuǎn)換,比如自動微分。 jax.numpy 層是用純 Python 編寫的,它只是用 LAX 函數(shù) (以及我們已經(jīng)編寫的其他 numpy 函數(shù)) 表示 numpy 函數(shù)。這使得 jax.numpy 易于延展。
當(dāng)你使用 jax.numpy 時,底層 LAX 原語是在后臺進行 jit 編譯的,允許你在加速器上執(zhí)行每個原語操作的同時編寫不受限制的 Python+ numpy 代碼。
但是 JAX 可以做更多的事情:你可以在越來越大的函數(shù)上使用jit來進行端到端編譯和優(yōu)化,而不僅僅是編譯和調(diào)度到一組固定的單個原語。例如,可以編譯整個網(wǎng)絡(luò),或者編譯整個梯度計算和優(yōu)化器更新步驟,而不僅僅是編譯和調(diào)度卷積運算。
折衷之處是,jit 函數(shù)必須滿足一些額外的專門化需求:因為我們希望編譯專門針對形狀和數(shù)據(jù)類型的跟蹤,但不是專門針對具體值的跟蹤,所以 jit 裝飾器下的 Python 代碼必須適用于抽象值。如果我們嘗試在一個抽象的 x 上求 x >0 的值,結(jié)果是一個抽象的值,表示集合 {True, False},所以 Python 分支就像 if x > 0 會引起報錯。
有關(guān)使用 jit 的更多要求,請參見:https://github.com/google/jax#whats-supported
好消息是,jit 是可選的:JAX 庫在后臺對單個操作和函數(shù)使用 jit,允許編寫不受限制的 Python+Numpy,同時仍然使用硬件加速器。但是,當(dāng)你希望最大化性能時,通常可以在自己的代碼中使用 jit 編譯和端到端優(yōu)化更大的函數(shù)。
后續(xù)計劃
目前項目小組還將對以下幾項做更多嘗試和更新:
完善說明文檔
支持 Cloud TPU
支持多 GPU 和多 TPU
支持完整的 NumPy 功能和部分 SciPy 功能
全面支持 vmap
加速
降低 XLA 函數(shù)調(diào)度開銷
線性代數(shù)例程(CPU 上的 MKL 和 GPU 上的 MAGMA)
高效自動微分原語cond和while
有關(guān) JAX 庫的介紹大致如此。
-
谷歌
+關(guān)注
關(guān)注
27文章
6230瀏覽量
107854 -
機器學(xué)習(xí)
+關(guān)注
關(guān)注
66文章
8499瀏覽量
134331 -
tensorflow
+關(guān)注
關(guān)注
13文章
330瀏覽量
61102
原文標(biāo)題:要替代 TensorFlow?谷歌開源機器學(xué)習(xí)庫 JAX
文章出處:【微信號:AI_era,微信公眾號:新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
請問有沒有關(guān)于SN74HC1G14的替代品?
超級電容是電池的替代品,你認(rèn)同嗎?

ADS8361輸入不接的時候,輸出端的時序竟然有波形出來,是哪里的問題?
愛普生停產(chǎn)產(chǎn)品/替代品

汽車應(yīng)用中有刷DC電機驅(qū)動的繼電器替代品

FCB-CV7520一體化機芯的卓越升級替代品——索尼FCB-EV9520L

tlc4502的替代品有哪些?
利用TINA仿真了一個10階10M巴特沃斯濾波器,做出來的電路,輸入信號會隨著頻率的變化而變化,為什么?
如何考慮將TI Smart DAC作為555定時器的替代品

評論