在前面幾節(jié)中,我們了解了概率論和隨機(jī)變量。為了將這一理論付諸實(shí)踐,讓我們介紹一下樸素貝葉斯分類器。這只使用概率基礎(chǔ)知識來讓我們執(zhí)行數(shù)字分類。
學(xué)習(xí)就是做假設(shè)。如果我們想要對以前從未見過的新數(shù)據(jù)示例進(jìn)行分類,我們必須對哪些數(shù)據(jù)示例彼此相似做出一些假設(shè)。樸素貝葉斯分類器是一種流行且非常清晰的算法,它假設(shè)所有特征彼此獨(dú)立以簡化計算。在本節(jié)中,我們將應(yīng)用此模型來識別圖像中的字符。
%matplotlib inline
import math
import tensorflow as tf
from d2l import tensorflow as d2l
d2l.use_svg_display()
22.9.1。光學(xué)字符識別
MNIST ( LeCun et al. , 1998 )是廣泛使用的數(shù)據(jù)集之一。它包含 60,000 張用于訓(xùn)練的圖像和 10,000 張用于驗(yàn)證的圖像。每個圖像包含一個從 0 到 9 的手寫數(shù)字。任務(wù)是將每個圖像分類為相應(yīng)的數(shù)字。
GluonMNIST
在模塊中提供了一個類data.vision
來自動從 Internet 檢索數(shù)據(jù)集。隨后,Gluon 將使用已經(jīng)下載的本地副本。train
我們通過將參數(shù)的值分別設(shè)置為True
或來指定我們是請求訓(xùn)練集還是測試集False
。每個圖像都是一個灰度圖像,寬度和高度都是28具有形狀(28,28,1). 我們使用自定義轉(zhuǎn)換來刪除最后一個通道維度。此外,數(shù)據(jù)集用無符號表示每個像素8位整數(shù)。我們將它們量化為二進(jìn)制特征以簡化問題。
data_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
lambda x: torch.floor(x * 255 / 128).squeeze(dim=0)
])
mnist_train = torchvision.datasets.MNIST(
root='./temp', train=True, transform=data_transform, download=True)
mnist_test = torchvision.datasets.MNIST(
root='./temp', train=False, transform=data_transform, download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./temp/MNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/9912422 [00:00
0%| | 0/28881 [00:00
Extracting ./temp/MNIST/raw/train-labels-idx1-ubyte.gz to ./temp/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./temp/MNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/1648877 [00:00
Extracting ./temp/MNIST/raw/t10k-images-idx3-ubyte.gz to ./temp/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./temp/MNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/4542 [00:00
Extracting ./temp/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./temp/MNIST/raw
((train_images, train_labels), (
test_images, test_labels)) = tf.keras.datasets.mnist.load_data()
# Original pixel values of MNIST range from 0-255 (as the digits are stored as
# uint8). For this section, pixel values that are greater than 128 (in the
# original image) are converted to 1 and values that are less than 128 are
# converted to 0. See section 18.9.2 and 18.9.3 for why
train_images = tf.floor(tf.constant(train_images / 128, dtype = tf.float32))
test_images = tf.floor(tf.constant(test_images / 128, dtype = tf.float32))
train_labels = tf.constant(train_labels, dtype = tf.int32)
test_labels = tf.constant(test_labels, dtype = tf.int32)
我們可以訪問一個特定的示例,其中包含圖像和相應(yīng)的標(biāo)簽。
我們的示例存儲在此處的變量中image
,對應(yīng)于高度和寬度為28像素。
我們的代碼將每個圖像的標(biāo)簽存儲為標(biāo)量。它的類型是 32位整數(shù)。
label,
評論