??
GAN網絡是近兩年深度學習領域的新秀,火的不行,本文旨在淺顯理解傳統GAN,分享學習心得。現有GAN網絡大多數代碼實現使用Python、torch等語言,這里,后面用matlab搭建一個簡單的GAN網絡,便于理解GAN原理。
GAN的鼻祖之作是2014年NIPS一篇文章:Generative Adversarial Net(https://arxiv.org/abs/1406.2661),可以細細品味。
分享一個目前各類GAN的一個論文整理集合
https://deephunt.in/the-gan-zoo-79597dc8c347
再分享一個目前各類GAN的一個代碼整理集合
https://github.com/zhangqianhui/AdversarialNetsPapers
▌開始
我們知道GAN的思想是是一種二人零和博弈思想(two-player game),博弈雙方的利益之和是一個常數,比如兩個人掰手腕,假設總的空間是一定的,你的力氣大一點,那你就得到的空間多一點,相應的我的空間就少一點,相反我力氣大我就得到的多一點,但有一點是確定的就是,我兩的總空間是一定的,這就是二人博弈,但是呢總利益是一定的。
引申到GAN里面就是可以看成,GAN中有兩個這樣的博弈者,一個人名字是生成模型(G),另一個人名字是判別模型(D)。他們各自有各自的功能。
相同點是:
這兩個模型都可以看成是一個黑匣子,接受輸入然后有一個輸出,類似一個函數,一個輸入輸出映射。
不同點是:
生成模型功能:比作是一個樣本生成器,輸入一個噪聲/樣本,然后把它包裝成一個逼真的樣本,也就是輸出。
判別模型:比作一個二分類器(如同0-1分類器),來判斷輸入的樣本是真是假。(就是輸出值大于0.5還是小于0.5)
直接上一張個人覺得解釋的好的圖說明:
在之前,我們首先明白在使用GAN的時候的2個問題
我們有什么?比如上面的這個圖,我們有的只是真實采集而來的人臉樣本數據集,僅此而已,而且很關鍵的一點是我們連人臉數據集的類標簽都沒有,也就是我們不知道那個人臉對應的是誰。
我們要得到什么?至于要得到什么,不同的任務得到的東西不一樣,我們只說最原始的GAN目的,那就是我們想通過輸入一個噪聲,模擬得到一個人臉圖像,這個圖像可以非常逼真以至于以假亂真。
好了再來理解下GAN的兩個模型要做什么。
首先判別模型,就是圖中右半部分的網絡,直觀來看就是一個簡單的神經網絡結構,輸入就是一副圖像,輸出就是一個概率值,用于判斷真假使用(概率值大于0.5那就是真,小于0.5那就是假),真假也不過是人們定義的概率而已。
其次是生成模型,生成模型要做什么呢,同樣也可以看成是一個神經網絡模型,輸入是一組隨機數Z,輸出是一個圖像,不再是一個數值而已。從圖中可以看到,會存在兩個數據集,一個是真實數據集,這好說,另一個是假的數據集,那這個數據集就是有生成網絡造出來的數據集。好了根據這個圖我們再來理解一下GAN的目標是要干什么:
判別網絡的目的:就是能判別出來屬于的一張圖它是來自真實樣本集還是假樣本集。假如輸入的是真樣本,網絡輸出就接近1,輸入的是假樣本,網絡輸出接近0,那么很完美,達到了很好判別的目的。
生成網絡的目的:生成網絡是造樣本的,它的目的就是使得自己造樣本的能力盡可能強,強到什么程度呢,你判別網絡沒法判斷我是真樣本還是假樣本。
有了這個理解我們再來看看為什么叫做對抗網絡了。判別網絡說,我很強,來一個樣本我就知道它是來自真樣本集還是假樣本集。生成網絡就不服了,說我也很強,我生成一個假樣本,雖然我生成網絡知道是假的,但是你判別網絡不知道呀,我包裝的非常逼真,以至于判別網絡無法判斷真假,那么用輸出數值來解釋就是,生成網絡生成的假樣本進去了判別網絡以后,判別網絡給出的結果是一個接近0.5的值,極限情況就是0.5,也就是說判別不出來了,這就是納什平衡了。
由這個分析可以發現,生成網絡與判別網絡的目的正好是相反的,一個說我能判別的好,一個說我讓你判別不好。所以叫做對抗,叫做博弈。那么最后的結果到底是誰贏呢?這就要歸結到設計者,也就是我們希望誰贏了。作為設計者的我們,我們的目的是要得到以假亂真的樣本,那么很自然的我們希望生成樣本贏了,也就是希望生成樣本很真,判別網絡能力不足以區分真假樣本位置。
▌再理解
知道了GAN大概的目的與設計思路,那么一個很自然的問題來了就是我們該如何用數學方法解決這么一個對抗問題。這就涉及到如何訓練這樣一個生成對抗網絡模型了,還是先上一個圖,用圖來解釋最直接:
需要注意的是生成模型與對抗模型可以說是完全獨立的兩個模型,好比就是完全獨立的兩個神經網絡模型,他們之間沒有什么聯系。
好了那么訓練這樣的兩個模型的大方法就是:單獨交替迭代訓練。
什么意思?因為是2個網絡,不好一起訓練,所以才去交替迭代訓練,我們一一來看。
假設現在生成網絡模型已經有了(當然可能不是最好的生成網絡),那么給一堆隨機數組,就會得到一堆假的樣本集(因為不是最終的生成模型,那么現在生成網絡可能就處于劣勢,導致生成的樣本就不咋地,可能很容易就被判別網絡判別出來了說這貨是假冒的),但是先不管這個,假設我們現在有了這樣的假樣本集,真樣本集一直都有,現在我們人為的定義真假樣本集的標簽,因為我們希望真樣本集的輸出盡可能為1,假樣本集為0,很明顯這里我們就已經默認真樣本集所有的類標簽都為1,而假樣本集的所有類標簽都為0.
有人會說,在真樣本集里面的人臉中,可能張三人臉和李四人臉不一樣呀,對于這個問題我們需要理解的是,我們現在的任務是什么,我們是想分樣本真假,而不是分真樣本中那個是張三label、那個是李四label。況且我們也知道,原始真樣本的label我們是不知道的。回過頭來,我們現在有了真樣本集以及它們的label(都是1)、假樣本集以及它們的label(都是0),這樣單就判別網絡來說,此時問題就變成了一個再簡單不過的有監督的二分類問題了,直接送到神經網絡模型中訓練就完事了。假設訓練完了,下面我們來看生成網絡。
對于生成網絡,想想我們的目的,是生成盡可能逼真的樣本。那么原始的生成網絡生成的樣本你怎么知道它真不真呢?就是送到判別網絡中,所以在訓練生成網絡的時候,我們需要聯合判別網絡一起才能達到訓練的目的。什么意思?就是如果我們單單只用生成網絡,那么想想我們怎么去訓練?誤差來源在哪里?細想一下沒有,但是如果我們把剛才的判別網絡串接在生成網絡的后面,這樣我們就知道真假了,也就有了誤差了。所以對于生成網絡的訓練其實是對生成-判別網絡串接的訓練,就像圖中顯示的那樣。好了那么現在來分析一下樣本,原始的噪聲數組Z我們有,也就是生成了假樣本我們有,此時很關鍵的一點來了,我們要把這些假樣本的標簽都設置為1,也就是認為這些假樣本在生成網絡訓練的時候是真樣本。
那么為什么要這樣呢?我們想想,是不是這樣才能起到迷惑判別器的目的,也才能使得生成的假樣本逐漸逼近為正樣本。好了,重新順一下思路,現在對于生成網絡的訓練,我們有了樣本集(只有假樣本集,沒有真樣本集),有了對應的label(全為1),是不是就可以訓練了?有人會問,這樣只有一類樣本,訓練啥呀?誰說一類樣本就不能訓練了?只要有誤差就行。還有人說,你這樣一訓練,判別網絡的網絡參數不是也跟著變嗎?沒錯,這很關鍵,所以在訓練這個串接的網絡的時候,一個很重要的操作就是不要判別網絡的參數發生變化,也就是不讓它參數發生更新,只是把誤差一直傳,傳到生成網絡那塊后更新生成網絡的參數。這樣就完成了生成網絡的訓練了。
在完成生成網絡訓練好,那么我們是不是可以根據目前新的生成網絡再對先前的那些噪聲Z生成新的假樣本了,沒錯,并且訓練后的假樣本應該是更真了才對。然后又有了新的真假樣本集(其實是新的假樣本集),這樣又可以重復上述過程了。我們把這個過程稱作為單獨交替訓練。我們可以實現定義一個迭代次數,交替迭代到一定次數后停止即可。這個時候我們再去看一看噪聲Z生成的假樣本會發現,原來它已經很真了。
看完了這個過程是不是感覺GAN的設計真的很巧妙,個人覺得最值得稱贊的地方可能在于這種假樣本在訓練過程中的真假變換,這也是博弈得以進行的關鍵之處。
▌進一步
文字的描述相信已經讓大多數的人知道了這個過程,下面我們來看看原文中幾個重要的數學公式描述,首先我們直接上原始論文中的目標公式吧:
上述這個公式說白了就是一個最大最小優化問題,其實對應的也就是上述的兩個優化過程。有人說如果不看別的,能達看到這個公式就拍案叫絕的地步,那就是機器學習的頂級專家,哈哈,真是前路漫漫。同時也說明這個簡單的公式意義重大。
這個公式既然是最大最小的優化,那就不是一步完成的,其實對比我們的分析過程也是這樣的,這里現優化D,然后在取優化G,本質上是兩個優化問題,把拆解就如同下面兩個公式:
優化D:
優化G:
可以看到,優化D的時候,也就是判別網絡,其實沒有生成網絡什么事,后面的G(z)這里就相當于已經得到的假樣本。優化D的公式的第一項,使的真樣本x輸入的時候,得到的結果越大越好,可以理解,因為需要真樣本的預測結果越接近于1越好嘛。對于假樣本,需要優化是的其結果越小越好,也就是D(G(z))越小越好,因為它的標簽為0。但是呢第一項是越大,第二項是越小,這不矛盾了,所以呢把第二項改成1-D(G(z)),這樣就是越大越好,兩者合起來就是越大越好。 那么同樣在優化G的時候,這個時候沒有真樣本什么事,所以把第一項直接卻掉了。這個時候只有假樣本,但是我們說這個時候是希望假樣本的標簽是1的,所以是D(G(z))越大越好,但是呢為了統一成1-D(G(z))的形式,那么只能是最小化1-D(G(z)),本質上沒有區別,只是為了形式的統一。之后這兩個優化模型可以合并起來寫,就變成了最開始的那個最大最小目標函數了。
所以回過頭來我們來看這個最大最小目標函數,里面包含了判別模型的優化,包含了生成模型的以假亂真的優化,完美的闡釋了這樣一個優美的理論。
▌再進一步
有人說GAN強大之處在于可以自動的學習原始真實樣本集的數據分布,不管這個分布多么的復雜,只要訓練的足夠好就可以學出來。針對這一點,感覺有必要好好理解一下為什么別人會這么說。
我們知道,傳統的機器學習方法,我們一般都會定義一個什么模型讓數據去學習。比如說假設我們知道原始數據屬于高斯分布呀,只是不知道高斯分布的參數,這個時候我們定義高斯分布,然后利用數據去學習高斯分布的參數得到我們最終的模型。再比如說我們定義一個分類器,比如SVM,然后強行讓數據進行東變西變,進行各種高維映射,最后可以變成一個簡單的分布,SVM可以很輕易的進行二分類分開,其實SVM已經放松了這種映射關系了,但是也是給了一個模型,這個模型就是核映射(什么徑向基函數等等),說白了其實也好像是你事先知道讓數據該怎么映射一樣,只是核映射的參數可以學習罷了。
所有的這些方法都在直接或者間接的告訴數據你該怎么映射一樣,只是不同的映射方法能力不一樣。那么我們再來看看GAN,生成模型最后可以通過噪聲生成一個完整的真實數據(比如人臉),說明生成模型已經掌握了從隨機噪聲到人臉數據的分布規律了,有了這個規律,想生成人臉還不容易。然而這個規律我們開始知道嗎?顯然不知道,如果讓你說從隨機噪聲到人臉應該服從什么分布,你不可能知道。這是一層層映射之后組合起來的非常復雜的分布映射規律。然而GAN的機制可以學習到,也就是說GAN學習到了真實樣本集的數據分布。
再拿原論文中的一張圖來解釋:
這張圖表明的是GAN的生成網絡如何一步步從均勻分布學習到正太分布的。原始數據x服從正太分布,這個過程你也沒告訴生成網絡說你得用正太分布來學習,但是生成網絡學習到了。假設你改一下x的分布,不管什么分布,生成網絡可能也能學到。這就是GAN可以自動學習真實數據的分布的強大之處。
還有人說GAN強大之處在于可以自動的定義潛在損失函數。什么意思呢,這應該說的是判別網絡可以自動學習到一個好的判別方法,其實就是等效的理解為可以學習到好的損失函數,來比較好或者不好的判別出來結果。雖然大的loss函數還是我們人為定義的,基本上對于多數GAN也都這么定義就可以了,但是判別網絡潛在學習到的損失函數隱藏在網絡之中,不同的問題這個函數就不一樣,所以說可以自動學習這個潛在的損失函數。
▌開始做小實驗
本節主要實驗一下如何通過隨機數組生成mnist圖像。mnist手寫體數據庫應該都熟悉的。這里簡單的使用matlab來實現,方便看到整個實現過程。這里用到了一個工具箱DeepLearnToolbox,關于該工具箱的一些其他使用說明:
DeepLearnToolbox
https://github.com/rasmusbergpalm/DeepLearnToolbox
其他使用說明
https://blog.csdn.net/dark_scope/article/details/9447967
網絡結構很簡單,就定義成下面這樣子:
將上述工具箱添加到路徑,然后運行下面代碼:
clcclear%% 構造真實訓練樣本 60000個樣本 1*784維(28*28展開)load mnist_uint8;train_x = double(train_x(1:60000,:)) / 255;% 真實樣本認為為標簽 [1 0];
生成樣本為[0 1];train_y = double(ones(size(train_x,1),1));% normalizetrain_x = mapminmax(train_x, 0, 1);rand('state',0)%% 構造模擬訓練樣本 60000個樣本 1*100維test_x = normrnd(0,1,[60000,100]); % 0-255的整數test_x = mapminmax(test_x, 0, 1);test_y =
double(zeros(size(test_x,1),1));test_y_rel = double(ones(size(test_x,1),1));%%nn_G_t = nnsetup([100 784]);nn_G_t.activation_function = 'sigm';nn_G_t.output = 'sigm';nn_D = nnsetup([784 100 1]);nn_D.weightPenaltyL2 = 1e-4; % L2 weight decaynn_D.dropoutFraction = 0.5; % Dropout fraction nn_D.learningRate = 0.01;
% Sigm require a lower learning ratenn_D.activation_function = 'sigm';nn_D.output = 'sigm';% nn_D.weightPenaltyL2 = 1e-4; % L2 weight decaynn_G = nnsetup([100 784 100 1]);nn_G.weightPenaltyL2 = 1e-4; % L2 weight decaynn_G.dropoutFraction = 0.5; % Dropout fraction nn_G.learningRate = 0.01;
% Sigm require a lower learning ratenn_G.activation_function = 'sigm';nn_G.output = 'sigm';% nn_G.weightPenaltyL2 = 1e-4; % L2 weight decayopts.numepochs = 1;
% Number of full sweeps through dataopts.batchsize = 100;
% Take a mean gradient step over this many samples%%num = 1000;ticfor each = 1:1500 %----------計算G的輸出:假樣本------------------- for i = 1:length(nn_G_t.W) %共享網絡參數
nn_G_t.W{i} = nn_G.W{i}; end G_output = nn_G_out(nn_G_t, test_x); %-----------訓練D------------------------------ index = randperm(60000); train_data_D = [train_x(index(1:num),:);G_output(index(1:num),:)];
train_y_D = [train_y(index(1:num),:);test_y(index(1:num),:)]; nn_D = nntrain(nn_D, train_data_D, train_y_D, opts);%訓練D %-----------訓練G------------------------------- for i = 1:length(nn_D.W) %共享訓練的D的網絡參數
nn_G.W{length(nn_G.W)-i+1} = nn_D.W{length(nn_D.W)-i+1}; end %訓練G:此時假樣本標簽為1,認為是真樣本
nn_G = nntrain(nn_G, test_x(index(1:num),:), test_y_rel(index(1:num),:), opts);endtocfor i = 1:length(nn_G_t.W)
nn_G_t.W{i} = nn_G.W{i};endfin_output = nn_G_out(nn_G_t, test_x);
函數nn_G_out為:
function output = nn_G_out(nn, x)
nn.testing = 1;
nn = nnff(nn, x, zeros(size(x,1), nn.size(end)));
nn.testing = 0;
output = nn.a{end};end
看一下這個及其簡單的函數,其實最值得注意的就是中間那個交替訓練的過程,這里我分了三步列出來:
重新計算假樣本(假樣本每次是需要更新的,產生越來越像的樣本)
訓練D網絡,一個二分類的神經網絡;
訓練G網絡,一個串聯起來的長網絡,也是一個二分類的神經網絡(不過只有假樣本來訓練),同時D部分參數在下一次的時候不能變了。
就這樣調一調參數,最終輸出在fin_output里面,多運行幾次顯示不同運行次數下的結果:
可以看到的是結果還是有點像模像樣的。
▌實驗總結
運行上述簡單的網絡我發現幾個問題:
網絡存在著不收斂問題;網絡不穩定;網絡難訓練;讀過原論文其實作者也提到過這些問題,包括GAN剛出來的時候,很多人也在致力于解決這些問題,當你實驗自己碰到的時候,還是很有意思的。那么這些問題怎么體現的呢,舉個例子,可能某一次你會發現訓練的誤差很小,在下一代訓練時,馬上又出現極限性的上升的很厲害,過幾代又發現訓練誤差很小,震蕩太嚴重。
其次網絡需要調才能出像樣的結果。交替迭代次數的不同結果也不一樣。比如每一代訓練中,D網絡訓練2回,G網絡訓練一回,結果就不一樣。
這是簡單的無條件GAN,所以每一代訓練完后,只能出現一個結果,那就是0-9中的某一個數。要想在一代訓練中出現好幾種結果,就需要使用到條件GAN了。
▌最后
現在的GAN已經到了五花八門的時候了,各種GAN應用也很多,理解底層原理再慢慢往上層擴展。GAN還是一個很厲害的東西,它使得現有問題從有監督學習慢慢過渡到無監督學習,而無監督學習才是自然界中普遍存在的,因為很多時候沒有辦法拿到監督信息的。要不Yann Lecun贊嘆GAN是機器學習近十年來最有意思的想法。
-
神經網絡
+關注
關注
42文章
4797瀏覽量
102364 -
GaN
+關注
關注
19文章
2138瀏覽量
75818 -
機器學習
+關注
關注
66文章
8481瀏覽量
133864
原文標題:一文詳解生成對抗網絡(GAN)的原理,通俗易懂
文章出處:【微信號:AI_Thinker,微信公眾號:人工智能頭條】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
STM32時鐘系統學習心得
嵌入式基礎學習心得
嵌入式學習心得

評論