在我們之前發(fā)布的文章《一個(gè)新的 TensorFlow Lite 示例應(yīng)用:棋盤游戲》中,展示了如何使用 TensorFlow 和 TensorFlow Agents 來(lái)訓(xùn)練強(qiáng)化學(xué)習(xí) (RL) agent,使其玩一個(gè)簡(jiǎn)單的棋盤游戲“Plane Strike”。我們還將訓(xùn)練后的模型轉(zhuǎn)換為 TensorFlow Lite,然后將其部署到功能完備的 Android 應(yīng)用中。本文,我們將演示一種全新路徑:使用 Flax/JAX 訓(xùn)練相同的強(qiáng)化學(xué)習(xí) agent,然后將其部署到我們之前構(gòu)建的同一款 Android 應(yīng)用中。
簡(jiǎn)單回顧一下游戲規(guī)則:我們基于強(qiáng)化學(xué)習(xí)的 agent 需要根據(jù)真人玩家的棋盤位置預(yù)測(cè)擊打位置,以便能早于真人玩家完成游戲。如需進(jìn)一步了解游戲規(guī)則,請(qǐng)參閱我們之前發(fā)布的文章。
“Plane Strike”游戲演示
背景:JAX 和 TensorFlow
JAX 是一個(gè)與 NumPy 類似的內(nèi)容庫(kù),由 Google Research 部門專為實(shí)現(xiàn)高性能計(jì)算而開發(fā)。JAX 使用 XLA 針對(duì) GPU 和 TPU 優(yōu)化的程序進(jìn)行編譯。
JAX
https://github.com/google/jax
XLA
https://tensorflow.google.cn/xla
TPU
https://cloud.google.com/tpu
而 Flax 則是在 JAX 基礎(chǔ)上構(gòu)建的一款熱門神經(jīng)網(wǎng)絡(luò)庫(kù)。研究人員一直在使用 JAX/Flax 來(lái)訓(xùn)練包含數(shù)億萬(wàn)個(gè)參數(shù)的超大模型(如用于語(yǔ)言理解和生成的 PaLM,或者用于圖像生成的 Imagen),以便充分利用現(xiàn)代硬件。
Flax
https://github.com/google/flax
PaLM
https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html
Imagen
https://imagen.research.google/
如果您不熟悉 JAX 和 Flax,可以先從 JAX 101 教程和 Flax 入門示例開始。
JAX 101 教程
https://jax.readthedocs.io/en/latest/jax-101/index.html
Flax 入門示例
https://flax.readthedocs.io/en/latest/getting_started.html
2015 年底,TensorFlow 作為 Machine Learning (ML) 內(nèi)容庫(kù)問世,現(xiàn)已發(fā)展為一個(gè)豐富的生態(tài)系統(tǒng),其中包含用于實(shí)現(xiàn) ML 流水線生產(chǎn)化 (TFX)、數(shù)據(jù)可視化 (TensorBoard),和將 ML 模型部署到邊緣設(shè)備 (TensorFlow Lite) 的工具,以及在網(wǎng)絡(luò)瀏覽器上運(yùn)行的裝置,或能夠執(zhí)行 JavaScript (TensorFlow.js) 的任何裝置。
TFX
https://tensorflow.google.cn/tfx
TensorBoard
https://tensorboard.dev/
TensorFlow Lite
https://tensorflow.google.cn/lite
TensorFlow.js
https://tensorflow.google.cn/js
在 JAX 或 Flax 中開發(fā)的模型也可以利用這一豐富的生態(tài)系統(tǒng)。方法是首先將此類模型轉(zhuǎn)換為 TensorFlow SavedModel 格式,然后使用與它們?cè)?TensorFlow 中原生開發(fā)相同的工具。
SavedModel
https://tensorflow.google.cn/guide/saved_model
如果您已經(jīng)擁有經(jīng) JAX 訓(xùn)練的模型并希望立即進(jìn)行部署,我們整合了一份資源列表供您參考:
視頻 “使用 TensorFlow Serving 為 JAX 模型提供服務(wù)”,展示了如何使用 TensorFlow Serving 部署 JAX 模型。
https://youtu.be/I4dx7OI9FJQ?t=36
文章《借助 TensorFlow.js 在網(wǎng)絡(luò)上使用 JAX》,對(duì)如何將 JAX 模型轉(zhuǎn)換為 TFJS,并在網(wǎng)絡(luò)應(yīng)用中運(yùn)行進(jìn)行了詳細(xì)講解。
https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html
本篇文章演示了如何將 Flax/JAX 模型轉(zhuǎn)換為 TFLite,并在原生 Android 應(yīng)用中運(yùn)行該模型。
總而言之,無(wú)論您的部署目標(biāo)是服務(wù)器、網(wǎng)絡(luò)還是移動(dòng)設(shè)備,我們都會(huì)為您提供相應(yīng)的幫助。
使用 Flax/JAX 實(shí)現(xiàn)游戲 agent
將目光轉(zhuǎn)回到棋盤游戲。為了實(shí)現(xiàn)強(qiáng)化學(xué)習(xí) agent,我們將會(huì)利用與之前相同的 OpenAI gym 環(huán)境。這次,我們將使用 Flax/JAX 訓(xùn)練相同的策略梯度模型。回想一下,在數(shù)學(xué)層面上策略梯度的定義是:
OpenAI gym
https://github.com/tensorflow/examples/tree/master/lite/examples/reinforcement_learning/ml/tf_and_jax/gym_planestrike/gym_planestrike/envs
其中:
T:每段的時(shí)步數(shù),各段的時(shí)步數(shù)可能有所不同
st:時(shí)步上的狀態(tài) t
at:時(shí)步上的所選操作 t 指定狀態(tài)s
πθ:參數(shù)為 θ 的策略
R(*):在指定策略下,收集到的獎(jiǎng)勵(lì)
我們定義了一個(gè) 3 層 MLP 作為策略網(wǎng)絡(luò),該網(wǎng)絡(luò)可以預(yù)測(cè) agent 的下一個(gè)擊打位置。
class PolicyGradient(nn.Module): """Neural network to predict the next strike position.""" @nn.compact def __call__(self, x): dtype = jnp.float32 x = x.reshape((x.shape[0], -1)) x = nn.Dense( features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)( x) x = nn.relu(x) x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x) x = nn.relu(x) x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x) policy_probabilities = nn.softmax(x) return policy_probabilities
在我們訓(xùn)練循環(huán)的每次迭代中,我們都會(huì)使用神經(jīng)網(wǎng)絡(luò)玩一局游戲、收集軌跡信息(游戲棋盤位置、采取的操作和獎(jiǎng)勵(lì))、對(duì)獎(jiǎng)勵(lì)進(jìn)行折扣,然后使用相應(yīng)軌跡訓(xùn)練模型。
for i in tqdm(range(iterations)): predict_fn = functools.partial(run_inference, params) board_log, action_log, result_log = common.play_game(predict_fn) rewards = common.compute_rewards(result_log) optimizer, params, opt_state = train_step(optimizer, params, opt_state, board_log, action_log, rewards)
在 train_step() 方法中,我們首先會(huì)使用軌跡計(jì)算損失,然后使用 jax.grad() 計(jì)算梯度,最后,使用 Optax(用于 JAX 的梯度處理和優(yōu)化庫(kù))來(lái)更新模型參數(shù)。
Optax
https://github.com/deepmind/optax
def compute_loss(logits, labels, rewards): one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2) loss = -jnp.mean( jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards)) return loss def train_step(model_optimizer, params, opt_state, game_board_log, predicted_action_log, action_result_log): """Run one training step.""" def loss_fn(model_params): logits = run_inference(model_params, game_board_log) loss = compute_loss(logits, predicted_action_log, action_result_log) return loss def compute_grads(params): return jax.grad(loss_fn)(params) grads = compute_grads(params) updates, opt_state = model_optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) return model_optimizer, params, opt_state @jax.jit def run_inference(model_params, board): logits = PolicyGradient().apply({'params': model_params}, board) return logits
這就是訓(xùn)練循環(huán)。如下圖所示,我們可以在 TensorBoard 中觀察訓(xùn)練進(jìn)度;其中,我們使代理指標(biāo)“game_length”(完成游戲所需的步驟數(shù))來(lái)跟蹤進(jìn)度:若 agent 變得更聰明,它便能以更少的步驟完成游戲。
將 Flax/JAX 模型轉(zhuǎn)換為
TensorFlow Lite 并與
Android 應(yīng)用集成
完成模型訓(xùn)練后,我們使用 jax2tf(一款 TensorFlow-JAX 互操作工具),將 JAX 模型轉(zhuǎn)換為 TensorFlow concrete function。最后一步是調(diào)用 TensorFlow Lite 轉(zhuǎn)換器來(lái)將 concrete function 轉(zhuǎn)換為 TFLite 模型。
jax2tf
https://github.com/google/jax/tree/main/jax/experimental/jax2tf
# Convert to tflite model model = PolicyGradient() jax_predict_fn = lambda input: model.apply({'params': params}, input) tf_predict = tf.function( jax2tf.convert(jax_predict_fn, enable_xla=False), input_signature=[ tf.TensorSpec( shape=[1, common.BOARD_SIZE, common.BOARD_SIZE], dtype=tf.float32, name='input') ], autograph=False, ) converter = tf.lite.TFLiteConverter.from_concrete_functions( [tf_predict.get_concrete_function()], tf_predict) tflite_model = converter.convert() # Save the model with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f: f.write(tflite_model)
經(jīng) JAX 轉(zhuǎn)換的 TFLite 模型與任何經(jīng) TensorFlow 訓(xùn)練的 TFLite 模型會(huì)有完全一致的行為。您可以使用 Netron 進(jìn)行可視化:
使用 Netron 對(duì) Flax/JAX 轉(zhuǎn)換的 TFLite 模型進(jìn)行可視化
我們可以使用與之前完全一樣的 Java 代碼來(lái)調(diào)用模型并獲取預(yù)測(cè)結(jié)果。
convertBoardStateToByteBuffer(board); tflite.run(boardData, outputProbArrays); float[] probArray = outputProbArrays[0]; int agentStrikePosition = -1; float maxProb = 0; for (int i = 0; i < probArray.length; i++) { int x = i / Constants.BOARD_SIZE; int y = i % Constants.BOARD_SIZE; if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) { agentStrikePosition = i; maxProb = probArray[i]; } }
總結(jié)
本文詳細(xì)介紹了如何使用 Flax/JAX 訓(xùn)練簡(jiǎn)單的強(qiáng)化學(xué)習(xí)模型、利用 jax2tf 將其轉(zhuǎn)換為 TensorFlow Lite,以及將轉(zhuǎn)換后的模型集成到 Android 應(yīng)用。
現(xiàn)在,您已經(jīng)了解了如何使用 Flax/JAX 構(gòu)建神經(jīng)網(wǎng)絡(luò)模型,以及如何利用強(qiáng)大的 TensorFlow 生態(tài)系統(tǒng),在幾乎任何您想要的位置部署模型。我們十分期待看到您使用 JAX 和 TensorFlow 構(gòu)建出色應(yīng)用!
審核編輯:劉清
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4779瀏覽量
101171 -
TPU
+關(guān)注
關(guān)注
0文章
143瀏覽量
20783 -
MLP
+關(guān)注
關(guān)注
0文章
57瀏覽量
4288
原文標(biāo)題:使用 JAX 構(gòu)建強(qiáng)化學(xué)習(xí) agent,并借助 TensorFlow Lite 將其部署到 Android 應(yīng)用中
文章出處:【微信號(hào):tensorflowers,微信公眾號(hào):Tensorflowers】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論