1 导引
1.1 域泛化
域泛化(domain generalization, DG) [1][2]旨在从多个源域中学习一个能够泛化到未知目标域的模型。形式化地说,给定(K)个训练的源域数据集(mathcal{S}=left{mathcal{S}^k mid k=1, cdots, Kright}),其中第(k)个域的数据被表示为(mathcal{S}^k = left{left(x_i^k, y_i^kright)right}_{i=1}^{n^k})。这些源域的数据分布各不相同:(P_{X Y}^k neq P_{X Y}^l, 1 leq k neq l leq K)。域泛化的目标是从这(K)个源域的数据中学习一个具有强泛化能力的模型:(f: mathcal{X}rightarrow mathcal{Y}),使其在一个未知的测试数据集(mathcal{T})(即(mathcal{T})在训练过程中不可访问且(P_{X Y}^{mathcal{T}} neq P_{X Y}^i text { for } i in{1, cdots, K}))上具有最小的误差:
这里(mathbb{E})和(ell(cdot, cdot))分别为期望和损失函数。域泛化示意图如下图所示:
目前为了解决域泛化中的域漂移(domain shift) 问题,已经提出了许多方法,大致以分为下列三类:
-
数据操作(data manipulation) 这种方法旨在通过数据增强(data augmentation)或数据生成(data generation)方法来丰富数据的多样性,从而辅助学习更有泛化能力的表征。其中数据增强方法常利用数据变换、对抗数据增强(adversarial data augmentation)[3]等手段来增强数据;数据生成方法则通过Mixup(也即对数据进行两两线性插值)[4]等手段来生成一些辅助样本。
-
表征学习(representation learning) 这种方法旨在通过学习领域不变表征(domain-invariant representations),或者对领域共享(domain-shared)和领域特异(domain-specific)的特征进行特征解耦(feature disentangle),从而增强模型的泛化性能。该类方法我们在往期博客《寻找领域不变量:从生成模型到因果表征 》和《跨域推荐:嵌入映射、联合训练和解耦表征》中亦有详细的论述。其中领域不变表征的学习手段包括了对抗学习[5]、显式表征对齐(如优化分布间的MMD距离)[6]等等,而特征解耦则常常通过优化含有互信息(信息瓶颈的思想)或KL散度[7]的损失项来达成,其中大多数会利用VAE等生成模型。
-
学习策略(learning stategy) 这种方法包括了集成学习[8]、元学习[9]等学习范式。其中,以元学习为基础的方法则利用元学习自发地从构造的任务中学习元知识,这里的构造具体而言是指将源域数据集(mathcal{S})按照域为单位来拆分成元训练(meta-train)部分(bar{mathcal{S}})和元测试(meta-test)部分(breve{mathcal{S}})以便对分布漂移进行模拟,最终能够在目标域(mathcal{T})的final-test中取得良好的泛化表现。
1.2 联邦域泛化
然而,目前大多数域泛化方法需要将不同领域的数据进行集中收集。然而在现实场景下,由于隐私性的考虑,数据常常是分布式收集的。因此我们需要考虑联邦域泛化(federated domain generalization, FedDG) 方法。形式化的说,设(mathcal{S}=left{mathcal{S}^1, mathcal{S}^2, ldots, mathcal{S}^Kright})表示在联邦场景下的(K)个分布式的源域数据集,每个源域数据集包含数据和标签对(mathcal{S}^k=left{left(x_i^k, y_i^kright)right}_{i=1}^{n^k}),采样自域分布(P_{X Y}^k)。联邦域泛化的目标是利用(K)个分布式的源域学习模型(f_theta: mathcal{X} rightarrow mathcal{Y}),该模型能够泛化到未知的测试域(mathcal{T})。联邦域泛化的架构如下图所示:
这里需要注意的是,传统的域泛化方法常常要求直接对齐表征或操作数据,这在联邦场景下是违反数据隐私性的。此外对于跨域的联邦学习,由于客户端异构的数据分布/领域漂移(如不同的图像风格)所导致的模型偏差(bias),直接聚合本地模型的参数也会导致次优(sub-optimal)的全局模型,从而更难泛化到新的目标域。因此,许多传统域泛化方法在联邦场景下都不太可行,需要因地制宜进行修改,下面试举几例:
-
对于数据操作的方法,我们常常需要用其它领域的数据来对某个领域的数据进行增强(或进行新数据的插值生成),而这显然违反了数据隐私。目前论文的解决方案是不直接传数据,而传数据的统计量来对数据进行增强[10],这里的统计量指图片的style(即图片逐通道计算的均值和方差)等等。
-
对于表征学习的方法,也需要在对不同域的表征进行共享/对比的条件下获得领域不变表征(或对表征进行分解),而传送表征事实上也违反了数据隐私。目前论文采用的解决方案包括不显式对齐表征,而是使得所有领域的表征显式/隐式地对齐一个参考分布(reference distribution)[11][12],这个参考分布可以是高斯,也可以由GAN来自适应地生成。
-
基于学习策略的方法,如元学习也需要利用多个域的数据来构建meta-train和meta-test,并进行元更新(meta-update),而这也违反了数据隐私性。目前论文的解决方案是使用来自其它域的变换后数据来为当前域构造元学习数据集[13],这里的变换后数据指图像的幅度谱等等。
2 论文阅读
2.1 CVPR21《FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space》[13]
本篇论文是联邦域泛化的第一篇工作。这篇论文属于基于学习策略(采用元学习)的域泛化方法,并通过传图像的幅度谱(amplitude spectrum),而非图像数据本身来构建本地的元学习任务,从而保证联邦场景下的数据隐私性。本文方法的框架示意图如下:
这里(K)为领域/客户端的个数。该方法使图像的低级特征——幅度谱在不同客户端间共享,而使高级语义特征——相位谱留在本地。这里再不同客户端间共享的幅度谱就可以作为多领域/多源数据分布供本地元学习训练使用。
接下来我们看本地的元学习部分。元学习的基本思想是通过模拟训练/测试数据集的领域漂移来学得具有泛化性的模型参数。而在本文中,本地客户端的领域漂移来自不同分布的频率空间。具体而言,对每轮迭代,我们考虑本地的原输入图片(x_{i}^k)做为meta-train,它的训练搭档(mathcal{T}_i^{k})则由来自其它客户端的频域产生,做为meta-test来表示分布漂移。
设客户端(k)中的图片(x^k_i)由正向傅里叶变换(mathcal{F})得到的幅度谱为(mathcal{A}_i^k in mathbb{R}^{H times W times C}),相位谱为(mathcal{P}_i^k in mathbb{R}^{H times W times C})((C)为图片通道数)。本文欲在客户端之间交换低级分布也即幅度谱信息,因此需要先构建一个供所有客户端共享的distribution bank (mathcal{A} = [mathcal{A}^1, cdots, mathcal{A}^K]),这里(A^k = {{mathcal{A}^k_i}}^{n^k}_{i=1})包含了来自第(k)个客户端所有图片的幅度谱信息,可视为代表了(mathcal{X}^k)的分布。
之后,作者通过在频域进行连续插值的手段,将distribution bank中的多源分布信息送到本地客户端。如上图所示,对于第(k)个客户端的图片幅度谱(mathcal{A}_i^{k}),我们会将其与另外(K-1)个客户端的幅度谱进行插值,其中与第(l(lneq k))个外部客户端的图片幅度谱(mathcal{A}_j)插值的结果表示为:
这里(mathcal{M})是一个控制幅度谱内低频成分比例的二值掩码,(lambda)是插值率。然后以此通过反向傅里叶变换生成变换后的图片:
就这样,对于第(k)个客户端的输入图片(x^k_i),我们就得到了属于不同分布的(K-1)个变换后的图片数据(mathcal{T}^k_i = {x^{krightarrow l}_i}_{lneq k}),这些图片和(x^k_i)共享了相同的语义标签。
接下来在元学习的每轮迭代中,我们将原始数据(x^k_i)做为meta-train,并将其对应的(K-1)个由频域产生的新数据(mathcal{T}^k_i)做为meta-test来表示分布漂移,从而完成在当前客户端的inner-loop的参数更新。
具体而言,元学习范式可以被分解为两步:
第一步 模型参数(theta^k)在meta-train上通过segmentaion Dice loss (mathcal{L}_{seg})来更新:
这里参数(beta)表示内层更新的学习率。
第二步 在meta-test数据集(mathcal{T}^k_i)上使用元目标函数(meta objective)(mathcal{L}_{meta})对已更新的参数(hat{theta}^k)进行进一步元更新。
这里特别重要的是,第二步所要优化的目标函数由在第一部中所更新的参数(hat{theta}^k)计算,最终的优化结果覆盖掉原来的参数(theta^k)。
如果我们将一二步合在一起看,则可以视为通过下面目标函数来一起优化关于参数(theta^k)的内层目标函数和元目标函数:
最后,一旦本地训练完成,则来自所有客户端的本地参数(theta^k)会被服务器聚合并更新全局模型。
2.2 Arxiv21《Federated Learning with Domain Generalization 》[12]
本篇论文属于基于学习领域不变表征的域泛化方法,并通过使所有客户端的表征对齐一个由GAN自适应生成的参考分布,而非使客户端之间的表征互相对齐,来保证联邦场景下的数据隐私性。本文方法整体的架构如下图所示:
注意,这里所有客户端共享一个参考分布,而这通过共享同一个分布生成器(distribution generator)来实现。在训练过程一边使每个域(客户端)的数据分布会和参考分布对齐,一边最小化分布生成器的损失函数,使其产生的参考分布接近所有源数据分布的“中心”(这也就是”自适应“的体现)。一旦判别器很难区分从特征提取器中提取的特征和从分布生成器中所生成的特征,此时所提取的特征就被认为是跨多个源域不变的。这里的特征分布生成器的输入为噪声样本和标签的one-hot向量,它会按照一定的分布(即参考分布)生成特征。最后,作者还采用了随机投影层来使得判别器更难区分实际提取的特征和生成器生成的特征,使得对抗网络更稳定。在训练完成之后,参考分布和所有源域的数据分布会对齐,此时学得的特征表征被认为是通用(universal)的,能够泛化到未知的领域。
接下来我们来看GAN部分具体的细节。设(F(cdot))为特征提取器,(G(cdot))为分布生成器,(D(cdot))为判别器。设由特征提取器所提取的特征(mathbf{h} = F(mathbf{x}))(数据(mathbf{x})的生成分布为(p(mathbf{h}))),而由分布生成器所产生的特征为(mathbf{h}'= G(mathbf{z}))(噪声(mathbf{z})的生成分布为(p(mathbf{h}'))。我们设特征提取器所提取的特征为负例,生成器所生成的特征为正例。
于是,我们可以将判别器的优化目标定义为使将特征提取器所生成的特征(mathbf{h})判为正类的概率(D(mathbf{h}|mathbf{y}))更小,而使将生成器所生成的特征(mathbf{h}')判为正类的概率(D(mathbf{h}'|mathbf{y}))更大。
生成器尽量使判别器(D(cdot))将其生成特征(mathbf{h}')判别为正类的概率(Dleft(mathbf{h}^{prime} mid mathbf{y}right))更大,以求以假乱真:
特征提取器也需要尽量使得其所生成的特征(mathbf{h})能够以假乱真:
再加上图像分类本身的交叉熵损失(mathcal{L}_{err}),则总的损失定义为:
论文的最后,作者还对一个问题进行了探讨:关于这里的参考分布,我们为什么不用一个预先选好的确定的分布,要用一个自适应生成的分布呢?那是因为自适应生成的分布有一个重要的好处,那就是少对齐期间的失真(distortion)。作者对多个域/客户端的分布和参考分布进行了可视化,如下图所示:
(a)中为参考分布选择为固定的分布后,与各域特征对比的示意图,图(b)为参考分布选择为自适应生成的分布后,和各域特征对比的示意图。在这两幅图中,红色五角星表示参考分布的特征,除了五角星之外的每种形状代表一个域,每种颜色代表一个类别的样本。可以看到自适应生成的分布和多个源域数据分布的距离,相比固定参考分布和多个源域数据分布的距离更小,因此自适应生成的分布能够减少对齐期间提取特征表征的失真。而更好的失真也就意味着源域数据的关键信息被最大程度的保留,这让本文的方法所得到的表征拥有更好的泛化表现。
2.3 NIPS22 《FedSR: A Simple and Effective Domain Generalization Method for Federated Learning》[11]
本篇论文属于基于学习领域不变表征的域泛化方法,并通过使所有客户端的表征对齐一个高斯参考分布,而非使客户端之间的表征互相对齐,来保证联邦场景下的数据隐私性。本文的动机源于经典机器学习算法的思想,旨在学习一个“简单”(simple)的表征从而获得更好的泛化性能。
首先,作者以生成模型的视角,将表征(z)建模为从(p(z|x))中的采样,然后在此基础上定义领域(k)的分类目标函数以学得表征:
这里领域(k)的样本表征(z_j^{(i)})通过编码器+重参数化从(p(z|x_k^{(i)}))中采样产生。
接下来我们来看怎么使得表征更“简单”。本文采用了两个正则项,一个是关于表征的(L2)正则项来限制表征中所包含的信息;一个是在给定(y)的条件下,(x)与(z)的条件互信息(I(x, zmid y))(的上界)来使表征只学习重要的信息,而忽视诸如图片背景之类的伪相关性(spurious correlations)。
关于表征(z)的(L2)正则项定义如下:
于是,上式的微妙之处在于可以和领域不变表征联系起来,事实上我们有(mathcal{L}_k^{L 2 R}=mathbb{E}_{p_k(x)}left[mathbb{E}_{p(z mid x)}left[|z|_2^2right]right]=mathbb{E}_{p_k(x, z)}left[|z|_2^2right]=mathbb{E}_{p_k(z)}left[|z|_2^2right]=2 sigma^2 mathbb{E}_{p_k(z)}[-log q(z)]=2 sigma^2 Hleft(p_k(z), q(z)right)),这里(Hleft(p_k(z), q(z)right)=Hleft(p_k(z)right)+ D_{text{KL}} left[p_k(z) Vert q(z)right]),参考分布(q(z)=mathcal{N}left(0, sigma^2 Iright))。如果(H(p_i(z)))在训练中并未发生大的改变,那么最小化(l_k^{L2R})也就是在最小化(D_{text{KL}}[p_k(z) Vert q(z)]),也即在隐式地对齐一个参考的边缘分布(q(z)),而这就使得标准的边缘分布(p_k(z))是跨域不变的。注意该对齐是不需要显式地比较不同客户端分布的。
接下来我们来看条件互信息项。在信息瓶颈理论中,常对(x)和表征(z)之间的互信息项(I(x, z))进行最小化以对(z)中所包含的信息进行加以正则,但是这样的约束在实践中如果系数没调整好,就很可能过于严格了,毕竟它迫使表征不包含数据的信息。因此,在这篇论文中,作者选择最小化给定(y)时(x)和(z)之间的条件互信息。领域(k)的条件互信息被计算为:
直观地看,(bar{f}_k)和(I_k(x, zmid y))共同作用,迫使表征(z)仅仅拥有预测标签(y)使所包含的信息,而没有关于(x)的额外(即和标签无关的)信息。
然而,这个互信息项是难解(tractable)的,这是由于计算(p_k(z|y))很难计算(由于需要对(x)进行积分将其边缘化消掉)。因此,作者导出了一个上界来对齐进行最小化:
这里(r(z|y))可以是一个输入(y)输出分布(r(z|y))的神经网络,作者将其设置为高斯(mathcal{N}left(z ; mu_y, sigma_y^2right)),这里(u_y),(sigma^2_y)((y=1, 2, cdots, C))是需要优化的神经网络参数,(C)是类别数量。
事实上,该正则项和域泛化中的条件分布对齐亦有着理论上的联系,这是因为( mathcal{L}_k^{C M I}=mathbb{E}_{p_k(x, y)}[D_{text{KL}}[p(z mid x) Vert r(z mid y)]] geq mathbb{E}_{p_k(y)}left[D_{text{KL}}left[p_k(z mid y) Vert r(z mid y)right]right] )。因此,最小化(mathcal{L}_k^{CMI})我们必然就能够最小化(D_{text{KL}}left[p_k(z mid y) Vert r(z mid y)right])(因为(mathcal{L}^{CMI}_k)是其上界),使得(p_k(z|y))和(r(z|y))互相接近,即:(p_k(z|y)approx r(z|y))。因此,模型会尝试迫使(p_k(z mid y) approx p_l(z mid y)(approx r(z mid y)))(对任意客户端/领域(k, l))。这也就是说,我们是在做给定标签(y)时表征(z)的条件分布的隐式对齐,这在传统的领域泛化中是一种很常见与有效的技术,区别就是这里不需要显式地比较不同客户端的分布。
最后,每个客户端的总体目标函数可以表示为:
总结一下,这里(L2)范数正则项(mathcal{L}_k^{L2R})和给定标签时数据和表征的条件互信息(mathcal{L}_k^{CMI})(的上界)用于限制表征中所包含的信息。此外,(mathcal{L}_k^{L2R})将边缘分布(p_k(z))对齐到一个聚集在0周围的高斯分布,而(mathcal{L}_i^{CMI})则将条件分布(p_k(z|y))对齐到一个参考分布(在实验环节作者亦将其选择为高斯)。
2.4 WACV23 《Federated Domain Generalization for Image Recognition via Cross-客户端 Style Transfer》[10]
本篇论文属于基于数据操作的域泛化方法,并通过构造一个style bank供所有客户端共享(类似CVPR21那篇),以使客户端在不共享数据的条件下基于风格(style)来进行数据增强,从而保证联邦场景下的数据隐私性。本文方法整体的架构如下图所示:
如图所示,每个client的数据集都有自己的风格。且对于每个客户端而言,都会接受其余客户端的风格来进行数据增强。事实上,这样就可以使得分布式的客户端在不泄露数据的情况下拥有相似的数据分布 。在本方法中,所有客户端的本地模型都拥有一致的学习目标——那就是拟合来自于所有源域的styles,而这种一致性就避免了本地模型之间的模型偏差,从而避免了影响全局模型的效果。此外,本方法可和其它DG的方法结合使用,从而使得其它中心化的DG方法均能得到精度的提升。
关于本文采用的风格迁移模型,有下列要求:1、所有客户端共享的style不能够被用来对数据集进行重构,从而保证数据隐私性;2、用于风格迁移的方法需要是一个实时的任意风格迁移模型,以允许高效和直接的风格迁移。本文最终选择了AdaIN做为本地的风格迁移模型。整个跨客户端/领域风格迁移流程如下图所示:
可以看到,整个跨客户端/领域风格迁移流程分为了三个阶段:
1. Local style Computation
每个客户端需要计算它们的风格并上传到全局服务器。其中可选择单张图片风格(single image style)和整体领域风格(overall domain style )这两种风格来进行计算。
- 单张图片风格 单张图片风格是图片VGG特征的像素级逐通道(channel-wise)均值和方差。比如我们设在第(k)个客户端上,随机选取的图片索引为(i),其对应的VGG特征(F_k^{(i)}=Phi(I^{(i)}_k))(这里的(I^{(i)}_k)表示图像内容,(Phi)为VGG的编码器),单张图片风格可以被计算为:
如果单张图片风格被用于风格迁移,那么就需要将该客户端不同图片对应的多种风格都上传到服务器,从而避免单张图片的偏差并增加多样性。而这就需要建立本地图片的style bank (mathcal{S}_k^{single})并将其上传到服务器。这里作者随机选择(J)张图像的style加入了本地style bank:
- 整体领域风格 整体领域风格是领域层次的逐通道均值和方差,其中考虑了一个客户端中的所有图片。比如我们假设客户端(k)拥有(N_k)个训练图片和对应的VGG特征({F_k^{(1)}, F_k^{(2)}, ldots, F_{k}^{(N_k)}})。则该客户端的整体领域风格(S_k^{overall})为:
相比单张图片风格,整体领域风格的计算代价非常高。不过,由于每个客户端/领域只有一个领域风格(S_k^{overall}),选择上传整体领域风格到服务器的通信效率会更高。
2. Style Bank on Server
当服务器接收到来自各个客户端的风格时,它会将所有风格汇总为一个style bank (mathcal{B}) 并将其广播回所有客户端。在两种不同的风格共享模式下,style bank亦会有所不同。
- 单图像风格的style bank (mathcal{B})为:
- 整体领域风格的style bank (mathcal{B})为:
(mathcal{B}_{single})比(mathcal{B}_{overall})会消耗更多存储空间,因此后者会更加通信友好。
3. Local Style Transfer
当客户端(k)收到style bank (mathcal{B})后,本地数据会通过迁移(mathcal{B})中的风格来进行增强,而这就将其它领域的风格引入了当前客户端。作者设置了超参数(L in{1,2, ldots, K})做为增强级别,意为从style bank (mathcal{B})中随机选择(L)个域所对应的风格来对每个图片进行增强,因此(L)表明了增强数据集的多样性。设第(k)个客户端数据集大小为(N_k),则在进行跨客户端的领域迁移之后,增强后数据集的大小会变为(N_k times L)。其中对客户端(k)中的每张图片(I^{(i)}_k),其对应的每个被选中的域都会拥有一个style vector(S)被作为图像生成器(G)的输入。这里关于style vector的获取有个细节需要注意:假设我们选了域(k),如果迁移的是整体领域风格,则(S^{overall}_k)直接即可做为style vector;如果迁移的是单图片风格,则还会进一步从选中(mathcal{S}^{single}_k)中随机选择一个风格(S_k^{(i)})做为域(k)的style vector。对以上两种风格模式而言,如果一个域被选中,则其对应的风格化图片就会被直接加入增强后的数据集中。
参考
- [1] Wang J, Lan C, Liu C, et al. Generalizing to unseen domains: A survey on domain generalization[J]. IEEE Transactions on Knowledge and Data Engineering, 2022.
- [2] 王晋东,陈益强. 迁移学习导论(第2版)[M]. 电子工业出版社, 2022.
- [3] Volpi R, Namkoong H, Sener O, et al. Generalizing to unseen domains via adversarial data augmentation[C]. Advances in neural information processing systems, 2018, 31.
- [4] Zhou K, Yang Y, Qiao Y, et al. Domain generalization with mixstyle[C]. ICLR, 2021.
- [5] Li H, Pan S J, Wang S, et al. Domain generalization with adversarial feature learning[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 5400-5409.
- [6] Li Y, Gong M, Tian X, et al. Domain generalization via conditional invariant representations[C]//Proceedings of the AAAI conference on artificial intelligence. 2018, 32(1).
- [7] Ilse M, Tomczak J M, Louizos C, et al. Diva: Domain invariant variational autoencoders[C]//Medical Imaging with Deep Learning. PMLR, 2020: 322-348.
- [8] Qin X, Wang J, Chen Y, et al. Domain Generalization for Activity Recognition via Adaptive Feature Fusion[J]. ACM Transactions on Intelligent Systems and Technology, 2022, 14(1): 1-21.
- [9] Li D, Yang Y, Song Y Z, et al. Learning to generalize: Meta-learning for domain generalization[C]//Proceedings of the AAAI conference on artificial intelligence. 2018, 32(1).
- [10] Chen J, Jiang M, Dou Q, et al. Federated Domain Generalization for Image Recognition via Cross-Client Style Transfer[C]//Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision. 2023: 361-370.
- [11] Nguyen A T, Torr P, Lim S N. Fedsr: A simple and effective domain generalization method for federated learning[J]. Advances in Neural Information Processing Systems, 2022, 35: 38831-38843.
- [12] Zhang L, Lei X, Shi Y, et al. Federated learning with domain generalization[J]. arXiv preprint arXiv:2111.10487, 2021.
- [13] Liu Q, Chen C, Qin J, et al. Feddg: Federated domain generalization on medical image segmentation via episodic learning in continuous frequency space[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021: 1013-1023.
文章来源: 博客园
- 还没有人评论,欢迎说说您的想法!