Tensorflow 中的卷積神經(jīng)網(wǎng)絡(luò) (CNN)

卷積神經(jīng)網(wǎng)絡(luò) (CNN) 通過從圖像中自動學(xué)習(xí)特征的空間層次結(jié)構(gòu),徹底改變了計算機(jī)視覺領(lǐng)域。在本文中,我們將探討 CNN 的基本構(gòu)建塊,并向您展示如何使用 TensorFlow 實現(xiàn) CNN 模型

CNN 的構(gòu)建塊

CNN 由各層組成,每個層在處理和提取輸入圖像中的特征時執(zhí)行特定任務(wù)。主要構(gòu)建塊是:

卷積神經(jīng)-

卷積神經(jīng)網(wǎng)絡(luò)架構(gòu)

1. 卷積層

它接收一個輸入特征圖(可以是圖像)并應(yīng)用一組過濾器(或內(nèi)核)來創(chuàng)建新的特征圖。這些濾鏡從圖像中捕獲不同的特征,例如邊緣、角落和紋理。卷積作由濾波器大小步幅填充等參數(shù)控制

import tensorflow as tf 

conv_layer = tf.keras.layers.Conv2D( 
    filters=32, kernel_size=(3, 3), strides=(1, 1), padding='valid', 
    activation='relu', kernel_initializer='glorot_uniform', 
)

2. 池化層

它用于對特征圖進(jìn)行下采樣,即在保留最重要的信息的同時減小它們的大小。池化作有兩種主要類型:

  • Max Pooling:從特征圖的區(qū)域獲取最大值。
  • Average Pooling:從特征圖的區(qū)域中獲取平均值。
import tensorflow as tf

max_pooling_layer = tf.keras.layers.MaxPool2D(
    pool_size=(2, 2), strides=None, padding='valid', data_format=None
)

avg_pooling_layer = tf.keras.layers.AveragePooling2D(
    pool_size=(2, 2), strides=None, padding='valid', data_format=None
)

注: 本文假定您正在處理圖像數(shù)據(jù)(2D 數(shù)據(jù))。如果您正在處理其他類型的數(shù)據(jù),請參閱 TensorFlow API 了解特定維度選項。

3. 全連接層

它將上一層中的每個神經(jīng)元連接到下一層中的每個神經(jīng)元。它在 CNN 的最后層用于進(jìn)行預(yù)測。它執(zhí)行線性變換,后跟非線性激活函數(shù)。

import tensorflow as tf 

fully_connected_layer = tf.keras.layers.Dense( 
    units=128, activation='relu', kernel_initializer='glorot_uniform', 
)

TensorFlow 中的 CNN 實現(xiàn)

現(xiàn)在我們已經(jīng)了解了構(gòu)建塊,讓我們看看如何在 TensorFlow 中使用這些層實現(xiàn) CNN 模型。我們將在 cifar 數(shù)據(jù)集上實現(xiàn)它。

1. 導(dǎo)入庫

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt

2. 加載和預(yù)處理數(shù)據(jù)集

我們將使用 CIFAR-10 數(shù)據(jù)集。 它是一個流行的基準(zhǔn)數(shù)據(jù)集,用于機(jī)器學(xué)習(xí)和計算機(jī)視覺任務(wù),特別是用于圖像分類。它包含 60,000 張 32×32 彩色圖像,分為 10 個類,每個類 6,000 張圖像。

  • 歸一化:圖像中的像素值范圍為 0 到 255。我們通過除以 255 來將圖像縮放到 0 到 1 的范圍來標(biāo)準(zhǔn)化圖像。這有助于在訓(xùn)練期間實現(xiàn)模型收斂。
  • to_categorical()):將整數(shù)標(biāo)簽轉(zhuǎn)換為獨熱編碼格式,其中每個標(biāo)簽都表示為指示類的二進(jìn)制向量。
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_labels = to_categorical(train_labels, 10)
test_labels = to_categorical(test_labels, 10)

3. 定義 CNN 模型

  • 模型。Sequential():初始化層的線性堆棧,其中每個層只有一個輸入和一個輸出。
  • Conv2D:添加一個具有 32 個過濾器的卷積層,每個過濾器的大小為 (3, 3)。
  • MaxPool2D:添加最大池化層,以從上一個卷積層對特征圖進(jìn)行下采樣。
  • 展平:將卷積層的輸出展平為完全連接(密集)層所需的 1D 向量。
  • 密集:具有 128 個單元和 ReLU 激活的全連接層。

4. 編譯模型

  • Adam 優(yōu)化器用于基于梯度的優(yōu)化。它根據(jù)梯度的第一和第二矩調(diào)整學(xué)習(xí)率。
  • 分類交叉熵用作多類分類問題的損失函數(shù)。
  • metrics=['accuracy']:指定我們希望在訓(xùn)練和評估期間跟蹤準(zhǔn)確性。
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

5. 訓(xùn)練模型

  • 該模型將在整個數(shù)據(jù)集上訓(xùn)練 10 次迭代。
  • 在更新權(quán)重之前,該模型將一次處理 64 張圖像。
  • 測試集用于在每個 epoch 之后進(jìn)行驗證,以跟蹤模型在未見過的數(shù)據(jù)上的性能。
history = model.fit(train_images, train_labels, epochs=10, batch_size=64, validation_data=(test_images, test_labels))

輸出:

屏幕截圖-2025-02-28-133847

6. Evaluating the Model

test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc * 100:.2f}%")

Output:

Test accuracy: 70.05%

測試準(zhǔn)確率為 70%,這對于簡單的 CNN 模型來說是很好的,我們可以根據(jù)我們的任務(wù)優(yōu)化模型來進(jìn)一步提高其準(zhǔn)確率。

7. 繪制訓(xùn)練歷史

plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.show()
下載-

紀(jì)元與準(zhǔn)確性

從圖表中,我們可以觀察到訓(xùn)練準(zhǔn)確率穩(wěn)步提高,這表明模型正在隨著時間的推移而學(xué)習(xí)和改進(jìn)。然而,驗證準(zhǔn)確性顯示出一些波動,尤其是在穩(wěn)定之前的早期 epoch 中。這表明該模型可以很好地泛化到看不見的驗證數(shù)據(jù),盡管仍有改進(jìn)的余地,特別是在縮小訓(xùn)練和驗證準(zhǔn)確性之間的差距方面。


登錄后免費查看全文
立即登錄
App下載
技術(shù)鄰APP
工程師必備
  • 項目客服
  • 培訓(xùn)客服
  • 平臺客服

TOP

1