斯坦福大學聯合谷歌大腦使用「兩步蒸餾方法」提升無分類器指導的采樣效率,在生成樣本質量和采樣速度上都有非常亮眼的表現。
去噪擴散概率模型(DDPM)在圖像生成、音頻合成、分子生成和似然估計領域都已經實現了 SOTA 性能。同時無分類器(classifier-free)指導進一步提升了擴散模型的樣本質量,并已被廣泛應用在包括 GLIDE、DALL·E 2 和 Imagen 在內的大規模擴散模型框架中。
然而,無分類器指導的一大關鍵局限是它的采樣效率低下,需要對兩個擴散模型評估數百次才能生成一個樣本。這一局限阻礙了無分類指導模型在真實世界設置中的應用。盡管已經針對擴散模型提出了蒸餾方法,但目前這些方法不適用無分類器指導擴散模型。
為了解決這一問題,近日斯坦福大學和谷歌大腦的研究者在論文《On Distillation of Guided Diffusion Models》中提出使用兩步蒸餾(two-step distillation)方法來提升無分類器指導的采樣效率。
在第一步中,他們引入單一學生模型來匹配兩個教師擴散模型的組合輸出;在第二步中,他們利用提出的方法逐漸地將從第一步學得的模型蒸餾為更少步驟的模型。
利用提出的方法,單個蒸餾模型能夠處理各種不同的指導強度,從而高效地對樣本質量和多樣性進行權衡。此外為了從他們的模型中采樣,研究者考慮了文獻中已有的確定性采樣器,并進一步提出了隨機采樣過程。
研究者在 ImageNet 64x64 和 CIFAR-10 上進行了實驗,結果表明提出的蒸餾模型只需 4 步就能生成在視覺上與教師模型媲美的樣本,并且在更廣泛的指導強度上只需 8 到 16 步就能實現與教師模型媲美的 FID/IS 分數,具體如下圖 1 所示。
此外,在 ImageNet 64x64 上的其他實驗結果也表明了,研究者提出的框架在風格遷移應用中也表現良好。
方法介紹
接下來本文討論了蒸餾無分類器指導擴散模型的方法( distilling a classifier-free guided diffusion model)。給定一個訓練好的指導模型,即教師模型之后本文分兩步完成。
第一步引入一個連續時間學生模型,該模型具有可學習參數η_1,以匹配教師模型在任意時間步 t∈[0,1] 處的輸出。給定一個優化范圍 [w_min, w_max],對學生模型進行優化:
其中,。為了合并指導權重 w,本文引入了一個 w - 條件模型,其中 w 作為學生模型的輸入。為了更好地捕捉特征,本文還對 w 應用傅里葉嵌入。此外,由于初始化在模型性能中起著關鍵作用,因此本文初始化學生模型的參數與教師模型相同。
在第二步中,本文將離散時間步(discrete time-step)考慮在內,并逐步將第一步中的蒸餾模型轉化為步數較短的學生模型
,其可學習參數為η_2,每次采樣步數減半。設 N 為采樣步數,給定 w ~ U[w_min, w_max] 和 t∈{1,…, N},然后根據 Salimans & Ho 等人提出的方法訓練學生模型。在將教師模型中的 2N 步蒸餾為學生模型中的 N 步之后,之后使用 N 步學生模型作為新的教師模型,這個過程不斷重復,直到將教師模型蒸餾為 N/2 步學生模型。
N 步可確定性和隨機采樣:一旦模型訓練完成,給定一個指定的 w ∈ [w_min, w_max],然后使用 DDIM 更新規則執行采樣。
實際上,本文也可以執行 N 步隨機采樣,使用兩倍于原始步長的確定性采樣步驟,然后使用原始步長向后執行一個隨機步驟 。對于,當 t > 1/N 時,本文使用以下更新規則
實驗
實驗評估了蒸餾方法的性能,本文主要關注模型在 ImageNet 64x64 和 CIFAR-10 上的結果。他們探索了指導權重的不同范圍,并觀察到所有范圍都具有可比性,因此實驗采用 [w_min, w_max] = [0, 4]。圖 2 和表 1 報告了在 ImageNet 64x64 上所有方法的性能。
本文還進行了如下實驗。具體來說,為了在兩個域 A 和 B 之間執行風格遷移,本文使用在域 A 上訓練的擴散模型對來自域 A 的圖像進行編碼,然后使用在域 B 上訓練的擴散模型進行解碼。由于編碼過程可以理解為反向 DDIM 采樣過程,本文在無分類器指導下對編碼器和解碼器進行蒸餾,并與下圖 3 中的 DDIM 編碼器和解碼器進行比較。
-
編碼器
+關注
關注
45文章
3669瀏覽量
135245 -
模型
+關注
關注
1文章
3309瀏覽量
49224 -
分類器
+關注
關注
0文章
152瀏覽量
13225
原文標題:采樣提速256倍,蒸餾擴散模型生成圖像質量媲美教師模型,只需4步
文章出處:【微信號:CVSCHOOL,微信公眾號:OpenCV學堂】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
基于擴散模型的圖像生成過程
![基于<b class='flag-5'>擴散</b><b class='flag-5'>模型</b>的圖像生成過程](https://file1.elecfans.com/web2/M00/8C/E2/wKgaomS0rtWABJl7AAAWIxBb_zY535.png)
傳感器的故障分類與診斷方法
深度學習:知識蒸餾的全過程
針對遙感圖像場景分類的多粒度特征蒸餾方法
![針對遙感圖像場景<b class='flag-5'>分類</b>的多粒度特征<b class='flag-5'>蒸餾</b><b class='flag-5'>方法</b>](https://file.elecfans.com/web1/M00/E4/F3/pIYBAGBJ4VCAHLfkAAJmmzSr0ts539.png)
如何改進和加速擴散模型采樣的方法2
![如何改進和加速<b class='flag-5'>擴散</b><b class='flag-5'>模型</b>采樣的<b class='flag-5'>方法</b>2](https://file.elecfans.com/web2/M00/41/D9/pYYBAGJ2FOKAdn0EAAFSsy-pVec759.png)
若干蒸餾方法之間的細節以及差異
如何度量知識蒸餾中不同數據增強方法的好壞?
蒸餾也能Step-by-Step:新方法讓小模型也能媲美2000倍體量大模型
![<b class='flag-5'>蒸餾</b>也能Step-by-Step:新<b class='flag-5'>方法</b>讓小<b class='flag-5'>模型</b>也能媲美2000倍體量大<b class='flag-5'>模型</b>](https://file1.elecfans.com/web2/M00/82/BF/wKgaomRhjOGAQwvHAAAhs0sFkR4437.png)
基于移動自回歸的時序擴散預測模型
![基于移動自回歸的時序<b class='flag-5'>擴散</b>預測<b class='flag-5'>模型</b>](https://file1.elecfans.com/web3/M00/04/B1/wKgZPGd3fv-Abpt4AABRrqe_ON8694.png)
評論