生成式對抗網絡 (GAN) |簡介

?

生成對抗網絡 (GAN) 由 Ian Goodfellow 于 2014 年首次提出。GAN 是一類功能強大的神經網絡,用于無監督學習。GAN 可以創造任何東西,無論你提供給他們什么,因為它是 Learn-Generate-Improve。要首先了解 GAN,您必須對
卷積神經網絡知之甚少。如果將圖像饋送到 CNN,CNN 經過訓練可以根據圖像的標簽對圖像進行分類,它會逐個像素分析圖像并通過 CNN 隱藏層中存在的節點,作為輸出,它會告訴圖像是關于什么的或它在圖像中看到什么。例如:如果 CNN 經過訓練對狗和貓進行分類,并且圖像被提供給該 CNN,它可以判斷該圖像中是狗還是貓。因此,它也可以稱為分類算法。GAN 有何不同?GAN 可以分為兩部分,即 Generator 和 Discriminator鑒別器–GANs 的這一部分可以被認為類似于 CNN 的作用。判別器是一個卷積神經網絡,由許多隱藏層和一個輸出層組成,這里的主要區別是 GAN 的輸出層只能有兩個輸出,這與 CNN 不同,CNN 可以有相對于它訓練的標簽數量的輸出。判別器的輸出可以是 1 或 0,因為為此任務專門選擇了激活函數,如果輸出為 1,則提供的數據是真實的,如果輸出為 0,則將其稱為假數據。Discriminator 在真實數據上進行訓練,因此它學會識別實際數據的外觀以及數據應該將哪些特征歸類為真實數據。 發電機–從名稱本身,我們可以理解它是一種生成算法。Generator 是一個逆卷積神經網絡,它的作用與 CNN 完全相反,因為在 CNN 中,實際圖像作為輸入給出,分類標簽預期作為輸出,但在 Generator 中,隨機噪聲(具有一些確切值的向量)作為該逆 CNN 的輸入,實際圖像預期作為輸出。簡單來說,它利用自己的想象力從一段數據中生成數據。 如上圖所示,一個隨機值向量作為 Inverse-CNN 的輸入,在通過隱藏層和激活函數后,接收圖像作為輸出。Generator 和 Discriminator 一起工作:正如我們已經討論過的,Discriminator 是在實際數據上訓練的,以分類給定的數據是否真實,因此 Discriminator 的工作是分辨什么是真實的,什么是假的。現在生成器開始從隨機輸入生成數據,然后將生成的數據作為輸入傳遞給判別器,現在判別器分析數據并檢查它被歸類為真實數的接近程度,如果生成的數據不包含足夠的特征而被判別器歸類為真實數,那么這些數據和與之相關的權重將使用反向傳播發送回生成器, 這樣它就可以重新調整與數據關聯的權重并創建比前一個更好的新數據。此新生成的數據再次傳遞給 Discriminator 并繼續。只要 Discriminator 每次數據都不斷將生成的數據分類為假數據,這個過程就會不斷重復被歸類為假數據,并且隨著每一次反向傳播,數據的質量會越來越好,并且總有一天 Generator 變得如此準確,以至于很難區分真實數據和 Generator 生成的數據。 簡單來說,Discriminator 是一個訓練有素的人,他可以分辨什么是真的,什么是假的,而 Generator 正試圖欺騙 Discriminator,讓他相信生成的數據是真實的,每一次失敗的嘗試,Generator 都會學習和改進自己以產生更真實的數據。它也可以說是 Generator 和 Discriminator 之間的競爭。

?

生成式對抗網絡 (GAN) |簡介的圖2 編輯

?

生成式對抗網絡 (GAN) |簡介的圖4 編輯

?

生成式對抗網絡 (GAN) |簡介的圖6 編輯

生成器和鑒別器的示例代碼:

1. 構建生成器

一個。What input to pass input to first layer of generator in initial stage :random_normal_dimensions這是一個超參數,用于定義您希望將向量中的多少個隨機數輸入生成器,作為生成圖像的起點。

灣。接下來需要注意的一點是,這里我們使用了 “selu” 激活函數而不是 “relu”,因為 “relu” 在對數據進行分類時具有去除噪聲的作用,可以防止負值抵消正值,但在 GAN 中,我們不想刪除數據。

  • Python3 語言
# You'll pass the random_normal_dimensions to the first dense layer of the generator
random_normal_dimensions = 32
 
### START CODE HERE ###
generator = keras.models.Sequential([
    keras.layers.Dense(7 * 7 * 128, input_shape=[random_normal_dimensions]),
    keras.layers.Reshape([7, 7, 128]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding="SAME",
                                 activation="selu"),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding="SAME",
                                 activation="tanh")
     
     
])
### END CODE HERE ###
生成式對抗網絡 (GAN) |簡介的圖7

2. 構建判別器:

  • Python3 語言
### START CODE HERE ###
discriminator = keras.models.Sequential([
    keras.layers.Conv2D(64, kernel_size=5, strides=2, padding="SAME",
                        activation=keras.layers.LeakyReLU(0.2),
                        input_shape=[28, 28, 1]),
    keras.layers.Dropout(0.4),
    keras.layers.Conv2D(128, kernel_size=5, strides=2, padding="SAME",
                        activation=keras.layers.LeakyReLU(0.2)),
    keras.layers.Dropout(0.4),
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation="sigmoid") 
     
     
])
### END CODE HERE ###
生成式對抗網絡 (GAN) |簡介的圖8

3. 編譯判別器:

這里我們用 binary_crossentropy loss 和 rmsprop 優化器編譯鑒別器。
將判別器設置為不訓練其權重(設置其 “trainable” 字段)。

  • Python3 語言
### START CODE HERE ###
discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
discriminator.trainable = False
### END CODE HERE ###
生成式對抗網絡 (GAN) |簡介的圖9

4. 構建和編譯 GAN 模型 :
為 GAN 構建順序模型,傳遞包含生成器和判別器的列表。
使用二進制交叉熵損失和 rmsprop 優化器編譯模型。

  • Python3 語言
### START CODE HERE ###
gan = keras.models.Sequential([generator, discriminator])
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
### END CODE HERE ###
生成式對抗網絡 (GAN) |簡介的圖10

5. 訓練 GAN:
第 1 階段

real_batch_size:獲取輸入批次的批次大小(它是張量的第 0 維)
noise:使用 tf.random.normal 生成噪聲。形狀為批量大小 x random_normal_dimension
個假圖像:使用您剛剛創建的生成器。傳入噪聲并產生假圖像。
mixed_images:將假圖與真圖拼接起來。
將軸設置為 0。
discriminator_labels:設置為 0。用于真實圖像和 1.用于虛假圖像。
將判別器設置為 trainable。
使用判別器的 train_on_batch() 方法對混合圖像和判別器標簽進行訓練。

第 2 階段

雜色:生成維度為 batch_size x 的隨機法線值
random_normal_dimensions Use real_batch_size。
Generator_labels:設置為 1。將假圖像標記為真實
生成器將生成標記為真實圖像的假圖像,并試圖欺騙判別器。
將判別器設置為 NOT be trainable。
在噪聲和生成器標簽上訓練 GAN。

?

登錄后免費查看全文
立即登錄
App下載
技術鄰APP
工程師必備
  • 項目客服
  • 培訓客服
  • 平臺客服

TOP

1
1