您可能已經(jīng)注意到,在回歸的情況下,從頭開始的實(shí)現(xiàn)和使用框架功能的簡(jiǎn)潔實(shí)現(xiàn)非常相似。分類也是如此。由于本書中的許多模型都處理分類,因此值得添加專門支持此設(shè)置的功能。本節(jié)為分類模型提供了一個(gè)基類,以簡(jiǎn)化以后的代碼。
import torch from d2l import torch as d2l
from mxnet import autograd, gluon, np, npx from d2l import mxnet as d2l npx.set_np()
from functools import partial import jax import optax from jax import numpy as jnp from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf from d2l import tensorflow as d2l
4.3.1. 類Classifier_
我們?cè)谙旅娑xClassifier類。在中,validation_step我們報(bào)告了驗(yàn)證批次的損失值和分類準(zhǔn)確度。我們?yōu)槊總€(gè)批次繪制一個(gè)更新num_val_batches 。這有利于在整個(gè)驗(yàn)證數(shù)據(jù)上生成平均損失和準(zhǔn)確性。如果最后一批包含的示例較少,則這些平均數(shù)并不完全正確,但我們忽略了這一微小差異以保持代碼簡(jiǎn)單。
class Classifier(d2l.Module): #@save """The base class of classification models.""" def validation_step(self, batch): Y_hat = self(*batch[:-1]) self.plot('loss', self.loss(Y_hat, batch[-1]), train=False) self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.
class Classifier(d2l.Module): #@save """The base class of classification models.""" def validation_step(self, batch): Y_hat = self(*batch[:-1]) self.plot('loss', self.loss(Y_hat, batch[-1]), train=False) self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.
We also redefine the training_step method for JAX since all models that will subclass Classifier later will have a loss that returns auxiliary data. This auxiliary data can be used for models with batch normalization (to be explained in Section 8.5), while in all other cases we will make the loss also return a placeholder (empty dictionary) to represent the auxiliary data.
class Classifier(d2l.Module): #@save """The base class of classification models.""" def training_step(self, params, batch, state): # Here value is a tuple since models with BatchNorm layers require # the loss to return auxiliary data value, grads = jax.value_and_grad( self.loss, has_aux=True)(params, batch[:-1], batch[-1], state) l, _ = value self.plot("loss", l, train=True) return value, grads def validation_step(self, params, batch, state): # Discard the second returned value. It is used for training models # with BatchNorm layers since loss also returns auxiliary data l, _ = self.loss(params, batch[:-1], batch[-1], state) self.plot('loss', l, train=False) self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state), train=False)
We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.
class Classifier(d2l.Module): #@save """The base class of classification models.""" def validation_step(self, batch): Y_hat = self(*batch[:-1]) self.plot('loss', self.loss(Y_hat, batch[-1]), train=False) self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)
默認(rèn)情況下,我們使用隨機(jī)梯度下降優(yōu)化器,在小批量上運(yùn)行,就像我們?cè)诰€性回歸的上下文中所做的那樣。
@d2l.add_to_class(d2l.Module) #@save def configure_optimizers(self): return torch.optim.SGD(self.parameters(), lr=self.lr)
@d2l.add_to_class(d2l.Module) #@save def configure_optimizers(self): params = self.parameters() if isinstance(params, list): return d2l.SGD(params, self.lr) return gluon.Trainer(params, 'sgd', {'learning_rate': self.lr})
@d2l.add_to_class(d2l.Module) #@save def configure_optimizers(self): return optax.sgd(self.lr)
@d2l.add_to_class(d2l.Module) #@save def configure_optimizers(self): return tf.keras.optimizers.SGD(self.lr)
4.3.2. 準(zhǔn)確性
給定預(yù)測(cè)概率分布y_hat,每當(dāng)我們必須輸出硬預(yù)測(cè)時(shí),我們通常會(huì)選擇預(yù)測(cè)概率最高的類別。事實(shí)上,許多應(yīng)用程序需要我們做出選擇。例如,Gmail 必須將電子郵件分類為“主要”、“社交”、“更新”、“論壇”或“垃圾郵件”。它可能會(huì)在內(nèi)部估計(jì)概率,但最終它必須在類別中選擇一個(gè)。
當(dāng)預(yù)測(cè)與標(biāo)簽 class 一致時(shí)y,它們是正確的。分類準(zhǔn)確度是所有正確預(yù)測(cè)的分?jǐn)?shù)。盡管直接優(yōu)化精度可能很困難(不可微分),但它通常是我們最關(guān)心的性能指標(biāo)。它通常是基準(zhǔn)測(cè)試中的相關(guān)數(shù)量。因此,我們幾乎總是在訓(xùn)練分類器時(shí)報(bào)告它。
準(zhǔn)確度計(jì)算如下。首先,如果y_hat是一個(gè)矩陣,我們假設(shè)第二個(gè)維度存儲(chǔ)每個(gè)類別的預(yù)測(cè)分?jǐn)?shù)。我們使用argmax每行中最大條目的索引來(lái)獲取預(yù)測(cè)類。然后我們將預(yù)測(cè)的類別與真實(shí)的元素進(jìn)行比較y。由于相等運(yùn)算符== 對(duì)數(shù)據(jù)類型敏感,因此我們轉(zhuǎn)換 的y_hat數(shù)據(jù)類型以匹配 的數(shù)據(jù)類型y。結(jié)果是一個(gè)包含條目 0(假)和 1(真)的張量。求和得出正確預(yù)測(cè)的數(shù)量。
@d2l.add_to_class(Classifier) #@save def accuracy(self, Y_hat, Y, averaged=True): """Compute the number of correct predictions.""" Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1])) preds = Y_hat.argmax(axis=1).type(Y.dtype) compare = (preds == Y.reshape(-1)).type(torch.float32) return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier) #@save def accuracy(self, Y_hat, Y, averaged=True): """Compute the number of correct predictions.""" Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1])) preds = Y_hat.argmax(axis=1).astype(Y.dtype) compare = (preds == Y.reshape(-1)).astype(np.float32) return compare.mean() if averaged else compare @d2l.add_to_class(d2l.Module) #@save def get_scratch_params(self): params = [] for attr in dir(self): a = getattr(self, attr) if isinstance(a, np.ndarray): params.append(a) if isinstance(a, d2l.Module): params.extend(a.get_scratch_params()) return params @d2l.add_to_class(d2l.Module) #@save def parameters(self): params = self.collect_params() return params if isinstance(params, gluon.parameter.ParameterDict) and len( params.keys()) else self.get_scratch_params()
@d2l.add_to_class(Classifier) #@save @partial(jax.jit, static_argnums=(0, 5)) def accuracy(self, params, X, Y, state, averaged=True): """Compute the number of correct predictions.""" Y_hat = state.apply_fn({'params': params, 'batch_stats': state.batch_stats}, # BatchNorm Only *X) Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1])) preds = Y_hat.argmax(axis=1).astype(Y.dtype) compare = (preds == Y.reshape(-1)).astype(jnp.float32) return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier) #@save def accuracy(self, Y_hat, Y, averaged=True): """Compute the number of correct predictions.""" Y_hat = tf.reshape(Y_hat, (-1, Y_hat.shape[-1])) preds = tf.cast(tf.argmax(Y_hat, axis=1), Y.dtype) compare = tf.cast(preds == tf.reshape(Y, -1), tf.float32) return tf.reduce_mean(compare) if averaged else compare
4.3.3. 概括
分類是一個(gè)足夠普遍的問(wèn)題,它保證了它自己的便利功能。分類中最重要的是 分類器的準(zhǔn)確性。請(qǐng)注意,雖然我們通常主要關(guān)心準(zhǔn)確性,但出于統(tǒng)計(jì)和計(jì)算原因,我們訓(xùn)練分類器以優(yōu)化各種其他目標(biāo)。然而,無(wú)論在訓(xùn)練過(guò)程中哪個(gè)損失函數(shù)被最小化,有一個(gè)方便的方法來(lái)根據(jù)經(jīng)驗(yàn)評(píng)估我們的分類器的準(zhǔn)確性是有用的。
4.3.4. 練習(xí)
表示為L(zhǎng)v驗(yàn)證損失,讓Lvq是通過(guò)本節(jié)中的損失函數(shù)平均計(jì)算的快速而骯臟的估計(jì)。最后,表示為lvb最后一個(gè)小批量的損失。表達(dá)Lv按照Lvq, lvb,以及樣本和小批量大小。
表明快速而骯臟的估計(jì)Lvq是公正的。也就是說(shuō),表明E[Lv]=E[Lvq]. 為什么你還想使用Lv反而?
給定多類分類損失,表示為l(y,y′) 估計(jì)的懲罰y′當(dāng)我們看到y(tǒng)并給出一個(gè)概率p(y∣x), 制定最佳選擇規(guī)則y′. 提示:表達(dá)預(yù)期損失,使用 l和p(y∣x).
-
pytorch
+關(guān)注
關(guān)注
2文章
809瀏覽量
13786
發(fā)布評(píng)論請(qǐng)先 登錄
Pytorch模型訓(xùn)練實(shí)用PDF教程【中文】
pyhanlp文本分類與情感分析
將pytorch模型轉(zhuǎn)化為onxx模型的步驟有哪些
通過(guò)Cortex來(lái)非常方便的部署PyTorch模型
將Pytorch模型轉(zhuǎn)換為DeepViewRT模型時(shí)出錯(cuò)怎么解決?
textCNN論文與原理——短文本分類

結(jié)合BERT模型的中文文本分類算法

融合文本分類和摘要的多任務(wù)學(xué)習(xí)摘要模型

基于不同神經(jīng)網(wǎng)絡(luò)的文本分類方法研究對(duì)比
基于LSTM的表示學(xué)習(xí)-文本分類模型
PyTorch文本分類任務(wù)的基本流程
PyTorch教程4.3之基本分類模型

評(píng)論