一旦我們選擇了一個(gè)架構(gòu)并設(shè)置了我們的超參數(shù),我們就進(jìn)入訓(xùn)練循環(huán),我們的目標(biāo)是找到最小化損失函數(shù)的參數(shù)值。訓(xùn)練后,我們將需要這些參數(shù)來進(jìn)行未來的預(yù)測。此外,我們有時(shí)會希望提取參數(shù)以在其他上下文中重用它們,將我們的模型保存到磁盤以便它可以在其他軟件中執(zhí)行,或者進(jìn)行檢查以期獲得科學(xué)理解。
大多數(shù)時(shí)候,我們將能夠忽略參數(shù)聲明和操作的具體細(xì)節(jié),依靠深度學(xué)習(xí)框架來完成繁重的工作。然而,當(dāng)我們遠(yuǎn)離具有標(biāo)準(zhǔn)層的堆疊架構(gòu)時(shí),我們有時(shí)需要陷入聲明和操作參數(shù)的困境。在本節(jié)中,我們將介紹以下內(nèi)容:
-
訪問用于調(diào)試、診斷和可視化的參數(shù)。
-
跨不同模型組件共享參數(shù)。
import tensorflow as tf
我們首先關(guān)注具有一個(gè)隱藏層的 MLP。
torch.Size([2, 1])
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize() # Use the default initialization method
X = np.random.uniform(size=(2, 4))
net(X).shape
(2, 1)
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)
6.2.1. 參數(shù)訪問
讓我們從如何從您已知的模型中訪問參數(shù)開始。
當(dāng)通過類定義模型時(shí)Sequential
,我們可以首先通過索引模型來訪問任何層,就好像它是一個(gè)列表一樣。每個(gè)層的參數(shù)都方便地位于其屬性中。
When a model is defined via the Sequential
class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.
Flax and JAX decouple the model and the parameters as you might have observed in the models defined previously. When a model is defined via the Sequential
class, we first need to initialize the network to generate the parameters dictionary. We can access any layer’s parameters through the keys of this dictionary.
When a model is defined via the Sequential
class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.
我們可以如下檢查第二個(gè)全連接層的參數(shù)。
OrderedDict([('weight',
tensor([[-0.2523, 0.2104, 0.2189, -0.0395, -0.0590, 0.3360, -0.0205, -0.1507]])),
('bias', tensor([0.0694]))])
dense1_ (
Parameter dense1_weight (shape=(1, 8), dtype=float32)
Parameter dense1_bias (shape=(1,), dtype=float32)
)
FrozenDict({
kernel: Array([[-0.20739523],
[ 0.16546965],
[-0.03713543],
[-0.04860032],
[-0.2102929 ],
[ 0.163712 ],
[ 0.27240783],
[-0.4046879 ]], dtype=float32),
bias: Array([0.], dtype=float32),
})
我們可以看到這個(gè)全連接層包含兩個(gè)參數(shù),分別對應(yīng)于該層的權(quán)重和偏差。
6.2.1.1. 目標(biāo)參數(shù)
請注意,每個(gè)參數(shù)都表示為參數(shù)類的一個(gè)實(shí)例。要對參數(shù)做任何有用的事情,我們首先需要訪問基礎(chǔ)數(shù)值。做這件事有很多種方法。有些更簡單,有些則更通用。以下代碼從返回參數(shù)類實(shí)例的第二個(gè)神經(jīng)網(wǎng)絡(luò)層中提取偏差,并進(jìn)一步訪問該參數(shù)的值。
(torch.nn.parameter.Parameter, tensor([0.0694]))
參數(shù)是復(fù)雜的對象,包含值、梯度和附加信息。這就是為什么我們需要顯式請求該值。
除了值之外,每個(gè)參數(shù)還允許我們訪問梯度。因?yàn)槲覀冞€沒有為這個(gè)網(wǎng)絡(luò)調(diào)用反向傳播,所以它處于初始狀態(tài)。
True
(mxnet.gluon.parameter.Parameter, array([0.]))
Parameters are complex objects, containing values, gradients, and additional information. That is why we need to request the value explicitly.
In addition to the value, each parameter also allows us to access the gradient. Because we have not invoked backpropagation for this network yet, it is in its initial state.
array([[0., 0., 0., 0., 0., 0., 0., 0.]])
(jaxlib.xla_extension.Array, Array([0.], dtype=float32))
Unlike the other frameworks, JAX does not keep a track of the gradients over the neural network parameters, instead the parameters and the network are decoupled. It allows the user to express their computation as a Python function, and use the grad
transformation for the same purpose.
(tensorflow.python.ops.resource_variable_ops.ResourceVariable,
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>)
評論