歡迎光臨
每天分享高質量文章

變分自編碼器VAE:一步到位的聚類方案


作者丨蘇劍林

單位丨廣州火焰信息科技有限公司

研究方向丨NLP,神經網絡

個人主頁丨kexue.fm

由於 VAE 中既有編碼器又有解碼器(生成器),同時隱變數分佈又被近似編碼為標準正態分佈,因此 VAE 既是一個生成模型,又是一個特征提取器。


在圖像領域中,由於 VAE 生成的圖片偏模糊,因此大家通常更關心 VAE 作為圖像特征提取器的作用。提取特征都是為了下一步的任務準備的,而下一步的任務可能有很多,比如分類、聚類等。本文來關心“聚類”這個任務。


一般來說,用 AE 或者 VAE 做聚類都是分步來進行的,即先訓練一個普通的 VAE,然後得到原始資料的隱變數,接著對隱變數做一個 K-Means 或 GMM 之類的。但是這樣的思路的整體感顯然不夠,而且聚類方法的選擇也讓我們糾結。


本文介紹基於 VAE 的一個“一步到位”聚類思路,它同時允許我們完成無監督地完成聚類和條件生成。

理論


一般框架

回顧 VAE 的 loss(如果沒印象請參考再談變分自編碼器VAE:從貝葉斯觀點出發):

通常來說,我們會假設 q(z) 是標準正態分佈,p(z|x),q(x|z) 是條件正態分佈,然後代入計算,就得到了普通的 VAE 的 loss。


然而,也沒有誰規定隱變數一定是連續變數吧?這裡我們就將隱變數定為 (z,y),其中 z 是一個連續變數,代表編碼向量;y 是離散的變數,代表類別。直接把 (1) 中的 z 替換為 (z,y),就得到:



這就是用來做聚類的 VAE 的 loss 了。


分步假設


啥?就完事了?呃,是的,如果只考慮一般化的框架,(2) 確實就完事了。 


不過落實到實踐中,(2) 可以有很多不同的實踐方案,這裡介紹比較簡單的一種。首先我們要明確,在 (2 )中,我們只知道 p̃(x)(通過一批資料給出的經驗分佈),其他都是沒有明確下來的。於是為了求解 (2),我們需要設定一些形式。一種選取方案為:


代入 (2) 得到:


其實 (4) 式還是相當直觀的,它分佈描述了編碼和生成過程:


1. 從原始資料中採樣到 x,然後通過 p(z|x) 可以得到編碼特征 z,然後通過分類器 p(y|z) 對編碼特征進行分類,從而得到類別;


2. 從分佈 q(y) 中選取一個類別 y,然後從分佈 q(z|y) 中選取一個隨機隱變數 z,再通過生成器 q(x|z) 解碼為原始樣本。


具體模型


(4) 式其實已經很具體了,我們只需要沿用以往 VAE 的做法:p(z|x) 一般假設為均值為 μ(x) 方差為的正態分佈,q(x|z) 一般假設為均值為 G(z) 方差為常數的正態分佈(等價於用 MSE 作為 loss),q(z|y) 可以假設為均值為 μy 方差為 1 的正態分佈,至於剩下的 q(y),p(y|z),q(y) 可以假設為均勻分佈(它就是個常數),也就是希望每個類大致均衡,而 p(y|z) 是對隱變數的分類器,隨便用個 softmax 的網絡就可以擬合了。 


最後,可以形象地將 (4) 改寫為:



其中 z∼p(z|x) 是重引數操作,而方括號中的三項 loss,各有各的含義:


1. −log q(x|z) 希望重構誤差越小越好,也就是 z 儘量保留完整的信息;


2.希望 z 能儘量對齊某個類別的“專屬”的正態分佈,就是這一步起到聚類的作用;


3. KL(p(y|z)‖q(y)) 希望每個類的分佈儘量均衡,不會發生兩個幾乎重合的情況(坍縮為一個類)。當然,有時候可能不需要這個先驗要求,那就可以去掉這一項。

實驗

實驗代碼自然是 Keras 完成的了,在 MNIST 和 Fashion-MNIST 上做了實驗,表現都還可以。實驗環境:Keras 2.2 + TensorFlow 1.8 + Python 2.7。


代碼實現


代碼位於:

https://github.com/bojone/vae/blob/master/vae_keras_cluster.py 

其實註釋應該比較清楚了,而且相比普通的 VAE 改動不大。可能稍微有難度的是這個怎麼實現。因為 y 是離散的,所以事實上這就是一個矩陣乘法(相乘然後對某個公共變數求和,就是矩陣乘法的一般形式),用 K.batch_dot 實現。 


其他的話,讀者應該先弄清楚普通的 VAE 實現過程,然後再看本文的內容和代碼,不然估計是一臉懵的。


MNIST


這裡是 MNIST 的實驗結果圖示,包括類內樣本圖示和按類採樣圖示。最後還簡單估算了一下,以每一類對應的數目最多的那個真實標簽為類標簽的話,最終的 test 準確率大約有 84.5%,對比這篇文章 Unsupervised Deep Embedding for Clustering Analysis [1] 的結果(最高也是 84% 左右),感覺應該很不錯了。 


聚類圖示


按類採樣



Fashion-MNIST


這裡是 Fashion-MNIST [2] 的實驗結果圖示,包括類內樣本圖示和按類採樣圖示,最終的 test 準確率大約有 60.6%。 


聚類圖示



按類採樣



總結

文章簡單地實現了一下基於 VAE 的聚類演算法,演算法的特點就是一步到位,結合“編碼”、“聚類”和“生成”三個任務同時完成,思想是對 VAE 的 loss 的一般化。


感覺還有一定的提升空間,比如式 (4) 只是式 (2) 的一個例子,還可以考慮更加一般的情況。代碼中的 encoder 和 decoder 也都沒有經過仔細調優,僅僅是驗證想法所用。


參考文獻

[1]. Unsupervised Deep Embedding for Clustering Analysis Junyuan Xie, Ross Girshick, and Ali Farhadi in International Conference on Machine Learning (ICML), 2016.

[2]. https://github.com/zalandoresearch/fashion-mnist


點擊以下標題查看更多相關文章: 


#投 稿 通 道#

 讓你的論文被更多人看到 


如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢? 答案就是:你不認識的人。


總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。 


PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是最新論文解讀,也可以是學習心得技術乾貨。我們的目的只有一個,讓知識真正流動起來。

來稿標準:

• 稿件確系個人原創作品,來稿需註明作者個人信息(姓名+學校/工作單位+學歷/職位+研究方向) 

• 如果文章並非首發,請在投稿時提醒並附上所有已發佈鏈接 

• PaperWeekly 預設每篇文章都是首發,均會添加“原創”標誌


? 投稿郵箱:

• 投稿郵箱:[email protected] 

• 所有文章配圖,請單獨在附件中發送 

• 請留下即時聯繫方式(微信或手機),以便我們在編輯發佈時和作者溝通



?


現在,在「知乎」也能找到我們了

進入知乎首頁搜索「PaperWeekly」

點擊「關註」訂閱我們的專欄吧

關於PaperWeekly


PaperWeekly 是一個推薦、解讀、討論、報道人工智慧前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號後臺點擊「交流群」,小助手將把你帶入 PaperWeekly 的交流群里。

▽ 點擊 | 閱讀原文 | 查看作者博客

赞(0)

分享創造快樂