在本工作中,來自阿德萊德大學、烏魯姆大學的研究者針對當前一致性學習出現的三個問題做了針對性的處理, 使得經典的 teacher-student 架構 (A.K.A Mean-Teacher) 在半監督圖像切割任務上得到了顯著的提升。
該研究已被計算機視覺頂會 CVPR 2022 大會接收,論文標題為《Perturbed and Strict Mean Teachers for Semi-supervised Semantic Segmentation》:
背景
語義分割是一項重要的像素級別分類任務。但是由于其非常依賴于數據的特性(data hungary), 模型的整體性能會因為數據集的大小而產生大幅度變化。同時, 相比于圖像級別的標注, 針對圖像切割的像素級標注會多花費十幾倍的時間。因此, 在近些年來半監督圖像切割得到了越來越多的關注。
半監督分割的任務依賴于一部分像素級標記圖像和無標簽圖像 (通常來說無標簽圖像個數大于等于有標簽個數),其中兩種類型的圖像都遵從相同的數據分布。該任務的挑戰之處在于如何從未標記的圖像中提取額外且有用的訓練信號,以使模型的訓練能夠加強自身的泛化能力。
在當前領域內有兩個比較火熱的研究方向, 分別是自監督訓練(self-training) 和 一致性學習 (consistency learning)。我們的項目主要基于后者來進行。
一致性學習的介紹
簡單來說, 一致性學習(consistency learning)過程可以分為 3 步來描述: 1)。 用不做數據增強的 “簡單” 圖像來給像素區域打上偽標簽, 2)。 用數據增強 (或擾動) 之后的 “復雜” 圖片進行 2 次預測, 和 3)。 用偽標簽的結果來懲罰增強之后的結果。
可是, 為什么要進行這 3 步呢? 先用簡單圖像打標簽, 復雜圖像學習的意義在哪?
從細節來說, 如上圖所示, 假設我們有一個像素的分類問題 (在此簡化為 2 分類, 左下的三角和右上的圓圈) 。我們假設中間虛線為真實分布, 藍色曲線為模型的判別邊界。
在這個例子中, 假設這個像素的標簽是圓圈, 并且由 1)。 得到的偽標簽結果是正確的 (y_tilde=Circ.)。在 2)。 中如果像素的增強或擾動可以讓預測成三角類, 那么隨著 3)步驟的懲罰, 模型的判別邊界會 (順著紅色箭頭) 挪向真實分布。由此, 模型的泛化能力得到加強。
由此得出, 在 1)。 中使用 “簡單” 的樣本更容易確保偽標簽的正確性, 在 2)。 時使用增強后的 “復雜” 樣本來確保預測掉在邊界的另一端來增強泛化能力。可是在實踐中,
1)。 沒有經受過增強的樣本也很可能被判斷錯 (hard samples), 導致模型在學習過程中打的偽標簽正確性下降。
2)。 隨著訓練的進行, 一般的圖像增強將不能讓模型做出錯誤判斷。這時, 一致性學習的效率會大幅度下降。
3)。 被廣泛實用的半監督 loss 例如 MSE, 在切割任務里不能給到足夠的力量來有效的推動判別邊界。而 Cross-entropy 很容易讓模型過擬合錯誤標簽, 造成認知偏差 (confirmation bias)。
針對這三個問題, 我們提出了:
1)。 新的基于一致性的半監督語義分割 MT 模型。通過新引入的 teacher 模型提高未標記訓練圖像的分割精度。同時, 用置信加權 CE 損失 (Conf-CE) 代替 MT 的 MSE 損失,從而實現更強的收斂性和整體上更好的訓練準確性。
2)。 一種結合輸入、特征和網絡擾動結合的數據增強方式,以提高模型的泛化能力。
3)。 一種新型的特征擾動,稱為 T-VAT。它基于 Teacher 模型的預測結果生成具有挑戰性的對抗性噪聲進一步加強了 student 模型的學習效率。
方法介紹
1)。 Dual-Teacher Architecture
我們的方法基于 Mean-Teacher, 其中 student 的模型基于反向傳播做正常訓練。在每個 iteration 結束后, student 模型內的參數以 expotional moving average (EMA)的方式轉移給 teacher 模型。
在我們的方法中, 我們使用了兩個 Teacher 模型。在做偽標簽時, 我們用兩個 teacher 預測的結果做一個 ensemble 來進一步增強偽標簽的穩定性。我們在每一個 epoch 的訓練內只更新其中一個 teacher 模型的參數, 來增加兩個 teacher 之間的 diversity。
由于雙 teacher 模型并沒有參加到反向傳播的運算中, 在每個 iteration 內他們只會消耗很小的運算成本來更新參數。
2)。 Semi-supervised Loss
在訓練中, teacher 模型的輸出經過 softmax 后的置信度代表著它對對應偽標簽的信心。置信度越高, 說明這個偽標簽潛在的準確率可能會更高。在我們的模型中, 我們首先對同一張圖兩個 teacher 的預測取平均值。然后通過最后的 confidence 作為權重, 對 student 模型的輸出做一個基于 cross-entropy 懲罰。同時, 我們會舍棄掉置信度過低的像素標簽, 因為他們是噪音的可能性會更大。
3)。 Teacher-based Virtual Adversarial Training (T-VAT)
Virtual Adversarial Training (VAT) 是半監督學習中常用的添加擾動的方式, 它以部分反向傳播的方式來尋找能最大化預測和偽標簽距離的噪音。
在我們的模型中, dual-teacher 的預測比學生的更加準確, 并且 (由于 EMA 的更新方式使其) 更加穩定。我們使用 teacher 模型替代 student 來尋找擾動性最強的對抗性噪音, 進而讓 student 的預測出錯的可能性加大, 最后達到增強一致性學習效率的目的。
4)。 訓練流程
i)。 supervised part: 我們用 strong-augmentation 后的圖片通過 cross-entropy 來訓練 student 模型。
ii)。 unsupervised part: 我們首先喂給 dual-teacher 模型們一個 weak-augmentation 的圖片, 并且用他們 ensemble 的結果生成標簽。之后我們用 strong-augmentation 后的圖片喂給 student 模型。在通過 encoder 之后, 我們用 dual-teachers 來通過 T-VAT 尋找具有最強擾動性的噪音并且注入到 (student encoded 之后的) 特征圖里, 并讓其 decoder 來做最終預測。
iii)。 我們通過 dual-teachers 的結果用 conf-ce 懲罰 student 的預測
iv)。 基于 student 模型的內部參數, 以 EMA 的方式更新一個 teacher 模型。
實驗
1)。 Compare with SOTAs.
Pascal VOC12 Dataset:
訓練 log 可視化鏈接: https://wandb.ai/pyedog1976/PS-MT(VOC12)?workspace=user-pyedog1976
該數據集包含超過 13,000 張圖像和 21 個類別。它提供了 1,464 張高質量標簽的圖像用于訓練,1,449 圖像用于驗證,1,456 圖像用于測試。我們 follow 以往的工作, 使了 10582 張低質量標簽來做擴展學習, 并且使用了和相同的 label id。
Low-quality Experiments
該實驗從整個數據集中隨機 sample 不同 ratio 的樣本來當作訓練集 (其中包含高質量和低質量兩種標簽), 旨在測試模型在有不同數量的標簽時所展示的泛化能力。
在此實驗中, 我們使用了 DeeplabV3 + 當作架構, 并且用 ResNet50 和 ResNet101 得到了所有 ratio 的 SOTA。
High-quality Experiments
該實驗從數據集提供的高質量標簽內隨機挑取不同 ratio 的標簽, 來測試模型在極少標簽下的泛化能力。我們的模型在不同的架構下 (e.g., Deeplabv3+ and PSPNet) 都取得了最好的結果。
Cityscapes Dataset
訓練 log 可視化鏈接: https://wandb.ai/pyedog1976/PS-MT(City)?workspace=user-pyedog1976
Cityscapes 是城市駕駛場景數據集,其中包含 2,975 張訓練圖像、500 張驗證圖像和 1,525 張測試圖像。數據集中的每張圖像的分辨率為 2,048 ×1,024,總共有 19 個類別。
在 2021 年之前, 大多數方法用 712x712 作為訓練的 resolution, 并且拿 Cross-entropy 當作 supervised 的 loss function。在最近, 越來越多的方式傾向于用大 resolution (800x800)當作輸入, OHEM 當作 supervised loss function。為了公平的對比之前的工作, 我們分別對兩種 setting 做了單獨的訓練并且都拿到了 SOTA 的結果。
2)。 Ablation Learnings.
我們使用 VOC 數據集中 1/8 的 ratio 來進行消融實驗。原本的 MT 我們依照之前的工作使用了 MSE 的 loss 方式。可以看到, conf-CE 帶來了接近 3 個點的巨大提升。在這之后, T-VAT (teacher-based virtual adversarial training)使 student 模型的一致性學習更有效率, 它對兩個架構帶來了接近 1% 的提升。最后, dual-teacher 的架構給兩個 backbone 分別帶來了 0.83% 和 0.84% 的提升。
同時我們對比了多種針對 feature 的擾動的方法, 依次分別為不使用 perturbation, 使用 uniform sample 的噪音, 使用原本的 VAT 和我們提出的 T-VAT。T-VAT 依然帶來了最好的結果。
3)。 Improvements over Supervised Baseline.
我們的方法相較于相同架構但只使用 label part 的數據集的結果有了巨大提升。以 Pascal VOC12 為例, 在 1/16 的比率中 (即 662 張標記圖像), 我們的方法分別 (在 ResNet50 和 ResNet101 中) 超過了基于全監督訓練的結果 6.01% 和 5.97%。在其他 ratio 上,我們的方法也顯示出一致的改進。
總結
在本文中,我們提出了一種新的基于一致性的半監督語義分割方法。在我們的貢獻中,我們引入了一個新的 MT 模型,它基于多個 teacher 和一個 student 模型,它顯示了對促進一致性學習的未標記圖像更準確的預測,使我們能夠使用比原始 MT 的 MSE 更嚴格的基于置信度的 CE 來增強一致性學習的效率。這種更準確的預測還使我們能夠使用網絡、特征和輸入圖像擾動的具有挑戰性的組合,從而顯示出更好的泛化性。
此外,我們提出了一種新的對抗性特征擾動 (T-VAT),進一步增強了我們模型的泛化性。
-
模型
+關注
關注
1文章
3340瀏覽量
49269 -
計算機視覺
+關注
關注
8文章
1701瀏覽量
46144 -
數據集
+關注
關注
4文章
1209瀏覽量
24848
原文標題:基于一致性的半監督語義分割方法:刷新多項SOTA,還有更好泛化性
文章出處:【微信號:CVSCHOOL,微信公眾號:OpenCV學堂】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
LTE基站一致性測試的類別
順序一致性和TSO一致性分別是什么?SC和TSO到底哪個好?
一致性規劃研究
CMP中Cache一致性協議的驗證
EMI一致性測試調試方法
![EMI<b class='flag-5'>一致性</b>測試調試<b class='flag-5'>方法</b>](https://file.elecfans.com/web2/M00/49/BA/pYYBAGKhvFWAXkWbAAAYurwEqNE739.jpg)
加速器一致性接口
Cache一致性協議優化研究
![Cache<b class='flag-5'>一致性</b>協議優化研究](https://file.elecfans.com/web2/M00/49/86/poYBAGKhwMOAGKC_AAAc6iUvydw085.jpg)
基于業務目標和業務場景的語義一致性驗證方法
搞定緩存一致性驗證,多核SoC設計就成功了一半
DDR一致性測試的操作步驟
深入理解數據備份的關鍵原則:應用一致性與崩潰一致性的區別
![深入理解數據備份的關鍵原則:應用<b class='flag-5'>一致性</b>與崩潰<b class='flag-5'>一致性</b>的區別](https://file1.elecfans.com/web2/M00/C4/A2/wKgaomXueUOAUC9kAAUkG4ifnAc542.png)
評論