在我們對線性回歸的介紹中,我們介紹了各種組件,包括數據、模型、損失函數和優化算法。事實上,線性回歸是最簡單的機器學習模型之一。然而,訓練它使用許多與本書中其他模型所需的組件相同的組件。因此,在深入了解實現細節之前,有必要設計一些貫穿本書的 API。將深度學習中的組件視為對象,我們可以從為這些對象及其交互定義類開始。這種面向對象的實現設計將極大地簡化演示,您甚至可能想在您的項目中使用它。
受PyTorch Lightning等開源庫的啟發,在高層次上我們希望擁有三個類:(i)Module
包含模型、損失和優化方法;(ii)DataModule
提供用于訓練和驗證的數據加載器;(iii) 兩個類結合使用該類 Trainer
,這使我們能夠在各種硬件平臺上訓練模型。本書中的大部分代碼都改編自Module
and DataModule
。Trainer
只有在討論 GPU、CPU、并行訓練和優化算法時,我們才會涉及該類。
import time
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
import time
from dataclasses import field
from typing import Any
import jax
import numpy as np
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import time
import numpy as np
import tensorflow as tf
from d2l import torch as d2l
3.2.1. 公用事業
我們需要一些實用程序來簡化 Jupyter 筆記本中的面向對象編程。挑戰之一是類定義往往是相當長的代碼塊。筆記本電腦的可讀性需要簡短的代碼片段,穿插著解釋,這種要求與 Python 庫常見的編程風格不相容。第一個實用函數允許我們在創建類后將函數注冊為類中的方法。事實上,即使我們已經創建了類的實例,我們也可以這樣做!它允許我們將一個類的實現拆分成多個代碼塊。
def add_to_class(Class): #@save
"""Register functions as methods in created class."""
def wrapper(obj):
setattr(Class, obj.__name__, obj)
return wrapper
讓我們快速瀏覽一下如何使用它。我們計劃 A
用一個方法來實現一個類do
。我們可以先聲明類并創建一個實例,而不是在同一個代碼塊中A
同時 擁有兩者的代碼。do
A
a
do
接下來我們像往常一樣 定義方法,但不在 classA
的范圍內。相反,我們add_to_class
用類A
作為參數來裝飾這個方法。這樣做時,該方法能夠訪問 的成員變量,A
正如我們所期望的那樣,如果它已被定義為 的A
定義的一部分。讓我們看看當我們為實例調用它時會發生什么a
。
@add_to_class(A)
def do(self):
print('Class attribute "b" is', self.b)
a.do()
Class attribute "b" is 1
Class attribute "b" is 1
Class attribute "b" is 1
第二個是實用程序類,它將類 __init__
方法中的所有參數保存為類屬性。這使我們無需額外代碼即可隱式擴展構造函數調用簽名。
我們將其實施推遲到第 23.7 節。HyperParameters
要使用它,我們定義繼承自該方法并調用 save_hyperparameters
該方法的類__init__
。
self.a = 1 self.b = 2
There is no self.c = True
self.a = 1 self.b = 2
There is no self.c = True
self.a = 1 self.b = 2
There is no self.c = True
self.a = 1 self.b = 2
There is no self.c = True
最后一個實用程序允許我們在實驗進行時以交互方式繪制實驗進度。為了尊重更強大(和復雜)的TensorBoard,我們將其命名為ProgressBoard
。實現推遲到 第 23.7 節。現在,讓我們簡單地看看它的實際效果。
該方法在圖中 draw
繪制一個點,并在圖例中指定。可選的僅通過顯示來平滑線條(x, y)
label
every_n
1/n圖中的點。他們的價值是從平均n原始圖中的鄰居點。
class ProgressBoard(d2l.HyperParameters): #@save
"""The board that plots data points in animation."""
def __init__(self, xlabel=None, ylabel=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
fig=None, axes=None, figsize=(3.5, 2.5), display=True):
self.save_hyperparameters()
def draw(self, x, y, label, every_n=1):
raise NotImpleme
評論
查看更多