在线观看www成人影院-在线观看www日本免费网站-在线观看www视频-在线观看操-欧美18在线-欧美1级

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

如何將Flax/JAX模型轉(zhuǎn)換為TFLite并在原生Android應(yīng)用中運(yùn)行呢

Tensorflowers ? 來(lái)源:TensorFlow ? 作者:TensorFlow ? 2022-11-02 10:13 ? 次閱讀

在我們之前發(fā)布的文章《一個(gè)新的 TensorFlow Lite 示例應(yīng)用:棋盤游戲》中,展示了如何使用 TensorFlow 和 TensorFlow Agents 來(lái)訓(xùn)練強(qiáng)化學(xué)習(xí) (RL) agent,使其玩一個(gè)簡(jiǎn)單的棋盤游戲“Plane Strike”。我們還將訓(xùn)練后的模型轉(zhuǎn)換為 TensorFlow Lite,然后將其部署到功能完備的 Android 應(yīng)用中。本文,我們將演示一種全新路徑:使用 Flax/JAX 訓(xùn)練相同的強(qiáng)化學(xué)習(xí) agent,然后將其部署到我們之前構(gòu)建的同一款 Android 應(yīng)用中。

簡(jiǎn)單回顧一下游戲規(guī)則:我們基于強(qiáng)化學(xué)習(xí)的 agent 需要根據(jù)真人玩家的棋盤位置預(yù)測(cè)擊打位置,以便能早于真人玩家完成游戲。如需進(jìn)一步了解游戲規(guī)則,請(qǐng)參閱我們之前發(fā)布的文章。

23754442-59d4-11ed-a3b6-dac502259ad0.gif

“Plane Strike”游戲演示

背景:JAX 和 TensorFlow

JAX 是一個(gè)與 NumPy 類似的內(nèi)容庫(kù),由 Google Research 部門專為實(shí)現(xiàn)高性能計(jì)算而開發(fā)。JAX 使用 XLA 針對(duì) GPU 和 TPU 優(yōu)化的程序進(jìn)行編譯。

JAX

https://github.com/google/jax

XLA

https://tensorflow.google.cn/xla

TPU

https://cloud.google.com/tpu

而 Flax 則是在 JAX 基礎(chǔ)上構(gòu)建的一款熱門神經(jīng)網(wǎng)絡(luò)庫(kù)。研究人員一直在使用 JAX/Flax 來(lái)訓(xùn)練包含數(shù)億萬(wàn)個(gè)參數(shù)的超大模型(如用于語(yǔ)言理解和生成的 PaLM,或者用于圖像生成的 Imagen),以便充分利用現(xiàn)代硬件。

Flax

https://github.com/google/flax

PaLM

https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html

Imagen

https://imagen.research.google/

如果您不熟悉 JAX 和 Flax,可以先從 JAX 101 教程和 Flax 入門示例開始。

JAX 101 教程

https://jax.readthedocs.io/en/latest/jax-101/index.html

Flax 入門示例

https://flax.readthedocs.io/en/latest/getting_started.html

2015 年底,TensorFlow 作為 Machine Learning (ML) 內(nèi)容庫(kù)問世,現(xiàn)已發(fā)展為一個(gè)豐富的生態(tài)系統(tǒng),其中包含用于實(shí)現(xiàn) ML 流水線生產(chǎn)化 (TFX)、數(shù)據(jù)可視化 (TensorBoard),和將 ML 模型部署到邊緣設(shè)備 (TensorFlow Lite) 的工具,以及在網(wǎng)絡(luò)瀏覽器上運(yùn)行的裝置,或能夠執(zhí)行 JavaScript (TensorFlow.js) 的任何裝置。

TFX

https://tensorflow.google.cn/tfx

TensorBoard

https://tensorboard.dev/

TensorFlow Lite

https://tensorflow.google.cn/lite

TensorFlow.js

https://tensorflow.google.cn/js

在 JAX 或 Flax 中開發(fā)的模型也可以利用這一豐富的生態(tài)系統(tǒng)。方法是首先將此類模型轉(zhuǎn)換為 TensorFlow SavedModel 格式,然后使用與它們?cè)?TensorFlow 中原生開發(fā)相同的工具。

SavedModel

https://tensorflow.google.cn/guide/saved_model

如果您已經(jīng)擁有經(jīng) JAX 訓(xùn)練的模型并希望立即進(jìn)行部署,我們整合了一份資源列表供您參考:

視頻 “使用 TensorFlow Serving 為 JAX 模型提供服務(wù)”,展示了如何使用 TensorFlow Serving 部署 JAX 模型。

https://youtu.be/I4dx7OI9FJQ?t=36

文章《借助 TensorFlow.js 在網(wǎng)絡(luò)上使用 JAX》,對(duì)如何將 JAX 模型轉(zhuǎn)換為 TFJS,并在網(wǎng)絡(luò)應(yīng)用中運(yùn)行進(jìn)行了詳細(xì)講解。

https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html

本篇文章演示了如何將 Flax/JAX 模型轉(zhuǎn)換為 TFLite,并在原生 Android 應(yīng)用中運(yùn)行該模型。

總而言之,無(wú)論您的部署目標(biāo)是服務(wù)器、網(wǎng)絡(luò)還是移動(dòng)設(shè)備,我們都會(huì)為您提供相應(yīng)的幫助。

使用 Flax/JAX 實(shí)現(xiàn)游戲 agent

將目光轉(zhuǎn)回到棋盤游戲。為了實(shí)現(xiàn)強(qiáng)化學(xué)習(xí) agent,我們將會(huì)利用與之前相同的 OpenAI gym 環(huán)境。這次,我們將使用 Flax/JAX 訓(xùn)練相同的策略梯度模型。回想一下,在數(shù)學(xué)層面上策略梯度的定義是:

OpenAI gym

https://github.com/tensorflow/examples/tree/master/lite/examples/reinforcement_learning/ml/tf_and_jax/gym_planestrike/gym_planestrike/envs

23e88678-59d4-11ed-a3b6-dac502259ad0.png

其中:

T:每段的時(shí)步數(shù),各段的時(shí)步數(shù)可能有所不同

st:時(shí)步上的狀態(tài) t

at:時(shí)步上的所選操作 t 指定狀態(tài)s

πθ:參數(shù)為 θ 的策略

R(*):在指定策略下,收集到的獎(jiǎng)勵(lì)

我們定義了一個(gè) 3 層 MLP 作為策略網(wǎng)絡(luò),該網(wǎng)絡(luò)可以預(yù)測(cè) agent 的下一個(gè)擊打位置。

class PolicyGradient(nn.Module):
  """Neural network to predict the next strike position."""


@nn.compact
  def __call__(self, x):
    dtype = jnp.float32
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(
        features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)(
           x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x)
    policy_probabilities = nn.softmax(x)
    return policy_probabilities

在我們訓(xùn)練循環(huán)的每次迭代中,我們都會(huì)使用神經(jīng)網(wǎng)絡(luò)玩一局游戲、收集軌跡信息(游戲棋盤位置、采取的操作和獎(jiǎng)勵(lì))、對(duì)獎(jiǎng)勵(lì)進(jìn)行折扣,然后使用相應(yīng)軌跡訓(xùn)練模型。

for i in tqdm(range(iterations)):
   predict_fn = functools.partial(run_inference, params)
   board_log, action_log, result_log = common.play_game(predict_fn)
   rewards = common.compute_rewards(result_log)
   optimizer, params, opt_state = train_step(optimizer, params, opt_state,
                                             board_log, action_log, rewards)

在 train_step() 方法中,我們首先會(huì)使用軌跡計(jì)算損失,然后使用 jax.grad() 計(jì)算梯度,最后,使用 Optax(用于 JAX 的梯度處理和優(yōu)化庫(kù))來(lái)更新模型參數(shù)。

Optax

https://github.com/deepmind/optax

def compute_loss(logits, labels, rewards):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2)
  loss = -jnp.mean(
      jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards))
  return loss


def train_step(model_optimizer, params, opt_state, game_board_log,
              predicted_action_log, action_result_log):
"""Run one training step."""

  def loss_fn(model_params):
    logits = run_inference(model_params, game_board_log)
    loss = compute_loss(logits, predicted_action_log, action_result_log)
    return loss

  def compute_grads(params):
    return jax.grad(loss_fn)(params)

  grads = compute_grads(params)
  updates, opt_state = model_optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return model_optimizer, params, opt_state


@jax.jit
def run_inference(model_params, board):
  logits = PolicyGradient().apply({'params': model_params}, board)
  return logits

這就是訓(xùn)練循環(huán)。如下圖所示,我們可以在 TensorBoard 中觀察訓(xùn)練進(jìn)度;其中,我們使代理指標(biāo)“game_length”(完成游戲所需的步驟數(shù))來(lái)跟蹤進(jìn)度:若 agent 變得更聰明,它便能以更少的步驟完成游戲。

23f8d758-59d4-11ed-a3b6-dac502259ad0.png

將 Flax/JAX 模型轉(zhuǎn)換為

TensorFlow Lite 并與

Android 應(yīng)用集成

完成模型訓(xùn)練后,我們使用 jax2tf(一款 TensorFlow-JAX 互操作工具),將 JAX 模型轉(zhuǎn)換為 TensorFlow concrete function。最后一步是調(diào)用 TensorFlow Lite 轉(zhuǎn)換器來(lái)將 concrete function 轉(zhuǎn)換為 TFLite 模型。

jax2tf

https://github.com/google/jax/tree/main/jax/experimental/jax2tf

# Convert to tflite model
 model = PolicyGradient()
 jax_predict_fn = lambda input: model.apply({'params': params}, input)


 tf_predict = tf.function(
     jax2tf.convert(jax_predict_fn, enable_xla=False),
     input_signature=[
         tf.TensorSpec(
             shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
             dtype=tf.float32,
             name='input')
     ],
     autograph=False,
 )


 converter = tf.lite.TFLiteConverter.from_concrete_functions(
     [tf_predict.get_concrete_function()], tf_predict)


 tflite_model = converter.convert()


 # Save the model
 with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
   f.write(tflite_model)

經(jīng) JAX 轉(zhuǎn)換的 TFLite 模型與任何經(jīng) TensorFlow 訓(xùn)練的 TFLite 模型會(huì)有完全一致的行為。您可以使用 Netron 進(jìn)行可視化:

242392fe-59d4-11ed-a3b6-dac502259ad0.png

使用 Netron 對(duì) Flax/JAX 轉(zhuǎn)換的 TFLite 模型進(jìn)行可視化

我們可以使用與之前完全一樣的 Java 代碼來(lái)調(diào)用模型并獲取預(yù)測(cè)結(jié)果。

convertBoardStateToByteBuffer(board);
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];
int agentStrikePosition = -1;
float maxProb = 0;
for (int i = 0; i < probArray.length; i++) {
  int x = i / Constants.BOARD_SIZE;
  int y = i % Constants.BOARD_SIZE;
  if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
    agentStrikePosition = i;
    maxProb = probArray[i];
  }
}

總結(jié)

本文詳細(xì)介紹了如何使用 Flax/JAX 訓(xùn)練簡(jiǎn)單的強(qiáng)化學(xué)習(xí)模型、利用 jax2tf 將其轉(zhuǎn)換為 TensorFlow Lite,以及將轉(zhuǎn)換后的模型集成到 Android 應(yīng)用。

現(xiàn)在,您已經(jīng)了解了如何使用 Flax/JAX 構(gòu)建神經(jīng)網(wǎng)絡(luò)模型,以及如何利用強(qiáng)大的 TensorFlow 生態(tài)系統(tǒng),在幾乎任何您想要的位置部署模型。我們十分期待看到您使用 JAX 和 TensorFlow 構(gòu)建出色應(yīng)用!





審核編輯:劉清

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 神經(jīng)網(wǎng)絡(luò)

    關(guān)注

    42

    文章

    4779

    瀏覽量

    101171
  • TPU
    TPU
    +關(guān)注

    關(guān)注

    0

    文章

    143

    瀏覽量

    20783
  • MLP
    MLP
    +關(guān)注

    關(guān)注

    0

    文章

    57

    瀏覽量

    4288

原文標(biāo)題:使用 JAX 構(gòu)建強(qiáng)化學(xué)習(xí) agent,并借助 TensorFlow Lite 將其部署到 Android 應(yīng)用中

文章出處:【微信號(hào):tensorflowers,微信公眾號(hào):Tensorflowers】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    使用電腦上tensorflow創(chuàng)建的模型,轉(zhuǎn)換為tflite格式了,導(dǎo)入后進(jìn)度條反復(fù)出現(xiàn)0-100%變化,為什么?

    使用電腦上tensorflow創(chuàng)建的模型,轉(zhuǎn)換為tflite格式了,導(dǎo)入后,進(jìn)度條反復(fù)出現(xiàn)0-100%變化,卡了一個(gè)晚上了還沒分析好?
    發(fā)表于 03-19 06:20

    如何將采樣位移轉(zhuǎn)換為采樣速度

    我是新手,在Labview編程如何將采樣的位移轉(zhuǎn)換為速度?求圖解,謝謝
    發(fā)表于 04-25 14:56

    如何將秒數(shù)轉(zhuǎn)換為時(shí)間字符串?

    請(qǐng)問如何將數(shù)值型秒數(shù)轉(zhuǎn)換為時(shí)間字符串?比如3600s轉(zhuǎn)換為01:00:00
    發(fā)表于 03-30 13:15

    如何將傳統(tǒng)ANN轉(zhuǎn)換為SNN?

    SNN和ANN的區(qū)別是什么?如何將傳統(tǒng)ANN轉(zhuǎn)換為SNN?
    發(fā)表于 09-28 06:15

    如何將觸控芯片的IIC接口轉(zhuǎn)換為USB接口

    CH554是什么?CH554如何實(shí)現(xiàn)數(shù)據(jù)轉(zhuǎn)換如何將觸控芯片的IIC接口轉(zhuǎn)換為USB接口?
    發(fā)表于 02-24 07:54

    EIQ onnx模型轉(zhuǎn)換為tf-lite失敗怎么解決?

    問題: 而我們需要您幫助我們回答這些問題:a) Dose eIQ(版本 2.7.12)支持 onnx 模型轉(zhuǎn)換為 tflte 格式?(文件見附件)b) 找不到float16 的量化選項(xiàng),你知道
    發(fā)表于 03-31 08:03

    如何在MIMXRT1064評(píng)估套件上部署tflite模型?

    我有一個(gè)嬰兒哭聲檢測(cè) tflite (tensorflow lite) 文件,其中包含模型本身。我如何將模型部署到 MIMXRT1064-evk 以通過 MCUXpresso IDE
    發(fā)表于 04-06 06:24

    如何將DS_CNN_S.pb轉(zhuǎn)換為ds_cnn_s.tflite?

    用于圖像分類(eIQ tensflowlite 庫(kù))。從廣義上講,我正在尋找該腳本,您可能已經(jīng)使用該腳本 DS_CNN_S.pb 轉(zhuǎn)換為 ds_cnn_s.tflite我能夠查看兩個(gè)模型
    發(fā)表于 04-19 06:11

    Pytorch模型轉(zhuǎn)換為DeepViewRT模型時(shí)出錯(cuò)怎么解決?

    我最終可以在 i.MX 8M Plus 處理器上部署 .rtm 模型。 我遵循了 本指南,我 Pytorch 模型轉(zhuǎn)換為 ONNX 模型,
    發(fā)表于 06-09 06:42

    如何將Detectron2和Layout-LM模型轉(zhuǎn)換為OpenVINO中間表示(IR)和使用CPU插件進(jìn)行推斷?

    無(wú)法確定如何將 Detectron2* 和 Layout-LM* 模型轉(zhuǎn)換為OpenVINO?中間表示 (IR) 和使用 CPU 插件進(jìn)行推斷。
    發(fā)表于 08-15 06:23

    數(shù)學(xué)原理:如何將ADC代碼轉(zhuǎn)換為電壓(第1篇)

    許多初步了解模數(shù)轉(zhuǎn)換器(ADC)的人想知道如何將ADC代碼轉(zhuǎn)換為電壓?;蛘?,他們的問題是針對(duì)特定應(yīng)用,例如:如何將ADC代碼轉(zhuǎn)換回物理量,如
    發(fā)表于 04-18 03:30 ?4137次閱讀

    如何將Altera的SDC約束轉(zhuǎn)換為Xilinx XDC約束

    了解如何將Altera的SDC約束轉(zhuǎn)換為Xilinx XDC約束,以及需要更改或修改哪些約束以使Altera的約束適用于Vivado設(shè)計(jì)軟件。
    的頭像 發(fā)表于 11-27 07:17 ?5183次閱讀

    Android中使用TFLite c++部署

    之前的文章,我們跟大家介紹過如何使用NNAPI來(lái)加速TFLite-Android的inference(可參考使用NNAPI加速android-tflite的Mobilenet分類器...
    發(fā)表于 02-07 11:57 ?7次下載
    在<b class='flag-5'>Android</b>中使用<b class='flag-5'>TFLite</b> c++部署

    如何將簡(jiǎn)單的汽車轉(zhuǎn)換為無(wú)線遙控汽車

    電子發(fā)燒友網(wǎng)站提供《如何將簡(jiǎn)單的汽車轉(zhuǎn)換為無(wú)線遙控汽車.zip》資料免費(fèi)下載
    發(fā)表于 10-21 14:51 ?2次下載
    <b class='flag-5'>如何將</b>簡(jiǎn)單的汽車<b class='flag-5'>轉(zhuǎn)換為</b>無(wú)線遙控汽車

    如何將Android代碼轉(zhuǎn)換成JS代碼運(yùn)行

    Autojs這個(gè)工具,因?yàn)樗旧硎鞘褂玫腞hino引擎開發(fā)的,因此它可以把Android代碼轉(zhuǎn)換成JavaScript語(yǔ)法的代碼來(lái)運(yùn)行,Autojs提供了幾個(gè)相關(guān)的方法來(lái)輔助
    的頭像 發(fā)表于 03-03 14:05 ?2760次閱讀
    主站蜘蛛池模板: 色香视频在线 | 亚洲一级毛片免费在线观看 | 色老头·com 色老头成人免费综合视频 色老头久久久久 | 夜夜夜爽爽爽久久久 | 全黄h全肉边做边吃奶在线观看 | 色视频www在线播放国产人成 | 国产经典三级在线 | 久久综合婷婷 | 手机免费在线视频 | 四虎美女 | 欧美成人免费 | 夜夜操天天射 | 天天舔天天干天天操 | 黄色网在线播放 | 特黄大片aaaaa毛片 | 国产精品一区在线观看你懂的 | 日韩欧美一区二区三区视频 | 日日摸人人拍人人澡 | 女人张腿让男桶免费视频观看 | 久久综合偷偷噜噜噜色 | 中文字幕色婷婷在线精品中 | 久久精品国产99精品最新 | 青青热久免费精品视频在线观看 | 特级黄aaaaaaaaa毛片 | 亚洲国产成人久久午夜 | 日韩加勒比在线 | 欧美成人午夜精品免费福利 | 4438x成人网最大色成网站 | 色综合久久综合欧美综合图片 | 在线播放91灌醉迷j高跟美女 | 免费看黄色录像 | 亚洲欧美一区二区三区四区 | 一区二区3区免费视频 | 天天操天天爱天天干 | 91p0rn永久备用地址二 | 59日本人xxxxxxxxx69 | 日本一区二区免费看 | 亚欧乱色束缚一区二区三区 | 国外一级毛片 | 巨乳色最新网址 | 黄 色 录像成 人播放免费 |