1、分布式多任务学习(Multi-task Learning, MTL)简介

我们在上一篇文章《分布式多任务学习论文阅读(一)多任务学习速览》(链接:https://www.cnblogs.com/orion-orion/p/15481054.html)中提到,实现多任务学习的一种传统的(非神经网络的)方法为增加一个正则项[1][2][3]

[begin{aligned} underset{textbf{W}}{min} & sum_{t=1}^{T} [frac{1}{m_t}sum_{i=1}^{m_t}L(y_{ti}, f(bm{x}_{ti}; bm{w}_t))]+lambda g(textbf{W})\ & = sum_{t=1}^{T} mathcal{l}_t(bm{w}_t)+lambda g(textbf{W})\ & = f(textbf{W}) + lambda g(textbf{W}) end{aligned} tag{1} ]

其中(g(textbf{W}))编码了任务的相关性(多任务学习的假定)并结合了(T)个任务;(lambda)是一个正则化参数,用于控制有多少知识在任务间共享。在许多论文中,都假设了损失函数(f(textbf{W}))是凸的,且是(Ltext{-Lipschitz})可导的(对(L>0)),然而正则项(g(textbf{W}))虽然常常不满足凸性(比如采用矩阵的核范数),但是我们认为其实接近凸的,因此对于式((1))可以采用近端梯度算法(proximal gradient methods)[4]来求解(在标准近端梯度法中,默认(g(mathbf{W}))是不可微的凸函数)。

不过现实情况比较复杂。当任务数量很大时,多任务学习的计算复杂度很高,这可能要求我们用多CPU/多GPU对学习算法进行加速;又或者(也是更为常见的情况,尤其在联邦学习中),当数据量很大时,数据经常会分片存储在不同的计算机甚至是不同的计算中心。比如,学习任务经常会涉及到不同的样本集(用于学习不同的任务):(mathcal{D}_1,..., mathcal{D}_T),这些样本集进场会存储在不同的地方。比如如果我想用不同医院的病例样本集进行多任务学习,那么不同医院的数据肯定各自存储在不同地方。不管是出于网络传输带宽考虑还是数据隐私考虑,想要将所有场所的数据集中在一起然后跑最优化算法显然不太现实(即使数据已经脱敏,大规模转移病人数据仍然是个很有争议的问题)。

以上的两种需求要求我们尽量使(T)个任务的梯度的计算分摊到(T)个不同的工作节点(worker)上。这样对MTL设计分布式算法就显得非常重要了,分布式算法旨在尽量将耗时的计算放在各分界点本地进行,然后再通过网络传输到中心节点。现实中的训练数据会非常庞大且位于不同的数据中心,我们不能将数据收集起来再训练 ,必须要将数据存放在各节点就地使用。

实际上由于正则项的存在和损失函数的复杂性,我们需要非常仔细地设计分布式多任务学习算法,在保证任务得到划分的同时而尽量不影响优化算法最终的收敛。

2、MTL的同步(synchronized)分布式数值优化算法

我们将会从MTL的单机数值优化方法开始,逐步说明分布式数值优化的必要性并介绍它的一种主要实现手段——同步分布式优化算法。
我们先来看单机数值优化,由于(g(mathbf{W}))正则项的不光滑性,MTL的目标函数常采用基于近端梯度的一阶优化方法进行求解,包括FISTA[5](近端梯度下降法的一个典型变种), SpaRSA[6]以及最近提出的二阶优化方法PNOPT[7]。下面我们简要回顾一下在这些方法中涉及到的两个关键计算步骤:

(1) 梯度计算(gradient computing) 设第(k)迭代步的参数矩阵为(mathbf{W}^{(k)}),目标函数光滑部分(f(mathbf{W}^{(k)}))的梯度由每个任务的损失函数单独计算梯度后拼接而得:

[nabla f(mathbf{W}^{(k)}) = (nabla mathcal{l_1}(bm{w}_{1}^{(k)}), ..., nabla mathcal{l}_T(bm{w}_T^{(k)})) tag{2} ]

(2) 近端映射(proximal mapping) 在梯度更新后,我们会计算

[hat{mathbf{W}} = mathbf{W}^{(k)} - eta nabla f(mathbf{W}^{(k)}) tag{3} ]

此处(eta)是迭代步长。不过请注意,此处只为完成了(f(mathbf{W}) + g(mathbf{W})) 中可微部分(f(mathbf{W}))的求导,(hat{mathbf{W}})尚不能做为我们下一步的搜索点,下一步的搜索点会经过近端映射(text{Prox}(hat{mathbf{W}}; eta, lambda, g))获得,该近端映射等价于求解下列优化问题:

[begin{aligned} mathbf{W}^{(k+1)} & = text{Prox}(hat{mathbf{W}};eta, lambda, g) \ & = underset{mathbf{W}}{text{argmin}} left( frac{1}{2 eta} ||mathbf{W} - hat{mathbf{W}}||_F^2 + lambda g(mathbf{W}) right) end{aligned} tag{4} ]

这样,我们就得到了下一步的搜索点(mathbf{W}^{(k+1)})

这里我们不失一般性,我们假设我们的数据集(mathcal{D}_1,..., mathcal{D}_T)分散存储在一个用星形网络连接的计算机集群中。每个单独的计算机系统我们称之为节点(node),工作节点(worker)或主体(agent)。第(t)个节点对于任务(t)的数据(mathcal{D}_t)拥有完全访问权,并能够进行数值计算(比如计算第(t)个任务的梯度(nabla mathcal{l_t}(bm{w}_t))。我们假定有一个中心节点(central server)能够收集所有任务节点(task agents)的数据,并进行近端映射操作。

我们接下来看如何分布式并行。因为(T)个任务的独立性,可以让第(t)个任务节点存储(bm{w}_t^{(k)}),然后负责计算梯度(nabla mathcal{l_t}(bm{w}_t^{(k)})),这样就很容易地并行化了。然后我们收集每个任务的梯度向量(nabla mathcal{l_t}(bm{w}_t^{(k)}))到中心节点并拼接得到(nabla f(mathbf{W}^{(k)})),然后计算(hat{mathbf{W}}),最后经过近端映射操作得到(mathbf{W}^{(k+1)})。然后再将(mathbf{W}^{(k+1)})拆分为(bm{w}_1^{(k+1)},...,bm{w}_T^{(k+1)})分别发送到(T)个任务节点,进行下一轮的迭代。整个分布式并行算法如下图所示:

同步迭代框架

因为必须要所有任务节点的梯度计算并收集完毕后,主节点才能进行下一步操作,所以上面这种方法被称为同步的(synchronized)。同步方法的最大弊端就是如果有一个或多个任务节点网络传输带宽过低,或者直接down掉,其他任务节点都会停下来等待(因为拿不到下一轮的数据)。因为多数一阶优化算法都需要经过很多轮迭代才能够收敛到一个特定的精度,在同步数值优化算法中的等待会造成不能容忍的算法运行时间和运算资源的极大浪费。

2、MTL的异步(asynchronized)分布式数值优化算法

上面提到的同步数值优化算法可能让一些读者想到MapReduce计算架构,这种架构很少用于迭代算法。比如我们在深度学习的训练中多采用参数服务器(Parameter Server)架构,这是一种异步数值优化的架构。在多任务学习的领域,也有学者提出了异步数值优化算法,接下来我们以《Asynchronous Multi-Task Learning》[8](IM Baytas等,2016)这篇论文为例,来介绍MTL的异步数值优化算法。

在本篇论文的异步数值优化算法中,中心节点只要收到了来自一个任务节点的已经算好的梯度,就会马上对模型的参数矩阵(mathbf{W})进行更新,而不用等待其他任务节点完成计算。中心节点和任务节点都会在内存中维护一份(mathbf{W})的拷贝,任务节点之间的拷贝可能会各不相同。AMTL(Asynchronized Multi-task learning)的收敛率分析可以参照另外两篇介绍ARock[9]计算框架介绍Tmac[10]计算框架的论文(这两篇论文对Krasnosel’skii-Mann
(KM) 迭代方法进行了改造,增加了异步并行坐标更新(asynchronous parallel coordinate update)的特性)。我们称一个任务节点被激活,当它进行(梯度)计算并与中心节点通信以进行更新。《Asynchronous Multi-Task Learning》这篇论文也提出了一个异步并行的框架,该框架基于以下关于激活率(activation rate)的假设:

假设1: 所有任务节点服从独立的泊松过程并且有相同的激活率。

该假设可以得到一个有用的结论,如果不同的任务节点的激活率不同,我们理论上可以调整迭代步长(eta)来调整迭代结果:如果任务节点的激活率很大,那么该任务节点被激活的可能性就会很大,从而我们应该降低
该任务节点对应的迭代步长(eta)(注意:因为是异步算法,每个任务节点都有其对应的迭代步长)。该论文提出了的一个动态迭代步长策略,具体细节在此略过。


接下来的推导会用到基于算子做优化的思想,我们这里做一下简要介绍。
很多优化问题可以转换为一个求映射的零点问题,即求一个(bm{x})使得映射(A(bm{x})=0)满足:

[A(bm{x}) = 0 tag{5} ]

比如我们求无约束优化问题,其最优解等价于求解梯度等于0,这里(A)就为求梯度;对于约束优化问题,我们可以转化为一个无约束对偶问题,这时的(A)就是求对偶问题的梯度。
而求解问题((5))的方法就是不动点迭代,也就是找到一个特定的算子(T),迭代地寻找解:

[bm{x}^{k+1} = T(bm{x}^k) tag{6} ]

问题((5))的解是算子(T)的稳定点(bm{x}^*),满足(bm{x}^* = T(bm{x}^*))
接下来我们讨论算子(T)该怎么定。目前已经提出了一下几种最常见和实用的算子:

(1) 前向算子:(T = I - eta A)

考虑凸问题(underset{bm{x}}{text{min}}f(bm{x})),该问题可转化为求解(nabla f(bm{x})=0),令(A(bm{x}) = nabla f(bm{x})),我们应用前向算子去求解该方程

[bm{x}^{k+1} = T(bm{x}^k) = x^k - eta nabla f(bm{x}^k) ]

聪明的你应该发现,这就是梯度迭代。

(2) 后向算子: (T = (I + eta A)^{-1})

还是该问题,我们运用该算子有:

[bm{x}^{k+1} = T(bm{x}^k) = ( I + eta nabla f)^{-1}(bm{x}^k) ]

这对应的就是我们近端迭代步骤:

[bm{x}^{k+1} = underset{bm{x}}{text{argmin}} left( frac{1}{2 eta} ||bm{x} - bm{x}^{k}||^2 + f(bm{x}^k) right) ]

(3) 前向后向分类分裂:(T =(I + eta B)^{-1}(I - eta A))(即前面这两个算子的结合)

考虑优化问题(underset{bm{x}}{text{min}} f(bm{x}) + g(bm{x})),其中(f(bm{x}))光滑,梯度为(f(bm{x})),而(g(bm{x}))不光滑,次梯度为(partial g(bm{x}))。该优化问题等价于找到(bm{x})满足(0 in nabla f(bm{x}) + partial f(bm{x}))。我们令(A(bm{x}) = nabla (bm{x}))(B(bm{x}) = partial g(bm{x})),这要就可以运用前向后向算子:

[bm{x}^{k+1} = T(bm{x}^k) = ( I + eta partial g)^{-1}(I - eta nabla f)(bm{x}^k) ]

由前面的讨论知,(bm{x}^{k + 1/2} = (I - eta nabla f)(bm{x}^k))是一个梯度迭代,(bm{x}^{k+1} = (I + eta partial g)^{-1}(bm{x}^{k+1/2}))是一个近端迭代。所以结合起来就能得到:

[bm{x}^{k+1} = underset{bm{x}}{text{argmin}} left( frac{1}{2 eta} ||bm{x} - (bm{x}^{k}-etanabla f(bm{x}^k) )||^2 + g(bm{x}^k) right) ]

这就是近端梯度迭代算法。
关于算子做优化我们就介绍到这里,其他详见算子优化相关书籍[13]和知乎文章[14])。


该论文就使用的是前向-后向算子分裂方法[11][12]来求解目标函数((1))。前向后向迭代如下:

[mathbf{W}^{k+1} = (I + eta lambda partial g)^{-1}(I - eta nabla f)(mathbf{W}^k) tag{6} ]

该迭代对(eta in (0, 2/L))会收敛到解。前面我们提到(nabla f(mathbf{W}))是可分的,比如可以写成(nabla f(mathbf{W}) = (nabla l_1(bm{w}_1),..., nabla l_T(bm{w}_T))),且此处的前向算子(I- eta nabla f)也是可分的。不过这里的后向算子((I + eta lambda partial g)^{-1})是不可分的,将导致后面无法并行。因此我们不能直接在前向-后向迭代上应用论文[9]提到的KM坐标更新法。不过,如果我们转换前向和后向的顺序,我们可以得到下列的后向-前向迭代:

[mathbf{V}^{k+1} = (I - eta nabla f)(I + eta lambda partial g)^{-1}(mathbf{V}^k) tag{7} ]

这里我们使用一个辅助矩阵(mathbf{V} in mathbb{R}^{d times T})来在更新中替代(mathbf{W}),这是因为前向-后向迭代和后向-前向迭代中的更新变量是不一样的。因此,由(mathbf{V}^*)得到(mathbf{W}^*)需要一个额外的后向迭代步骤。之后我们就可以在后向-前向迭代的基础上,按照论文[9]提出的坐标下降策略,从({1,2,...,T})中随机采样一个任务索引(t),来对任务块(bm{v}_t)(代指和任务(t)有关的变量)进行坐标更新了。更新步骤如下所示:

[bm{v}_t^{k+1} = (I - eta nabla mathcal{l}_t)((I + eta lambda partial g)^{-1}(mathbf{V}^k))_t tag{8} ]

这里(bm{v}_t in mathbb{R}^d)是任务(t)(bm{w}_t)相应的辅助变量。注意想要更新一个任务块(bm{v}_t)只需要在一个任务块上进行一个完整的后向步骤和前向步骤。再给出整体的AMTL算法之前,我们先给出算法的迭代步骤。该算法的迭代步遵循的为在论文[9]中讨论的基于坐标更新的KM迭代。KM迭代是一种代替不动点迭代(bm{x}^{k+1}= T(bm{x}^{k}))的方法,它的形式为:(bm{x}^{k+1} = bm{x}^k + eta_r (T bm{x}^k - bm{x}^k))。而论文[9]则更进一步,采用其坐标下降形式:

[bm{x}^{k+1}_i = bm{x}^k_i + eta_r (T bm{x}^k - bm{x}^k_i) tag{9} ]

这里(bm{eta})是随机向量。(i)为从({1,2....n})中随机采样的特征索引(采到(1,2,..,n)的概率可以不同,不过我们一般取均等概率),(n)(bm{x})的维度。至于该迭代式的推导过程具体可参照论文[9]
在这个问题中,我们设(text{Prox}_{eta lambda g}(hat{bm{v}}^k))为后向映射,我们有:

[begin{aligned} bm{v}_t^{k+1} &= bm{v}_t^k + eta_r left( (I - eta nabla mathcal{l}_t)((I + eta lambda partial g)^{-1}(mathbf{V}^k))_t - bm{v}_t^k right) \ & = bm{v}_t^k + eta_r left( left(text{Prox}_{eta lambda g}({mathbf{V}}^k)right)_t -eta nabla_{l_t}left( left( text{Prox}_{eta lambda g}(mathbf{V}^k) right)_t right) -bm{v}^k_t right) end{aligned} tag{10} ]

这就是本篇论文提到的迭代式。注意,(text{Prox}_{eta lambda g}(mathbf{V}^k))需要在主节点计算完毕,然后将(left( text{Prox}_{eta lambda g}(mathbf{V}^k) right)_t)发送给任务节点(t)

值得一提的是,机器学习中一般根据子问题的复杂性来选择前向-后向迭代和后向-前向的迭代。如果数据集((bm{x}_t, y_t))很大,此时后向步骤相比前向步骤更容易计算,我们会使用后向-前向迭代来进行坐标更新。具体到MTL的应用中,后向迭代步骤由式((4))的近端映射给出,而这通常有解析解(比如迹范数的奇异值的软阈值),适合先进行。另一方面,式((2))的梯度计算是典型的耗时瓶颈(尤其是数据集很大时),适合后进行。因此后向-前向迭代为分布式MTL提供了一个更为高效的优化框架。最后,我们注意到后向前向迭代当(eta in (0, 2/L))时是一个non-expansive 算子,因为前向和后向步骤都是non-expansive的。

整个异步多任务学习框架如下图所示:

异步多任务学习伪代码

在本篇AMTL论文中,任务节点并不共享内存(注意,在文献[9]中,各任务节点可访问一个共享内存),任务节点间不能通信,但它们都各自与主节点连接并能与之通信。任务节点和主节点之间的通信只有向量(bm{v}_t),相较各任务节点上存储的本地数据(D_t)很小。每个任务节点负责计算前向步骤;而主节点负责计算后向步骤,一旦更新的梯度从任务节点传来,就进行近端映射(近端映射也能够在多轮梯度更新后才进行,取决于梯度更新的速度)。因为每个任务节点只需要和该任务节点相关的任务块,故本篇论文进一步减少了任务节点和主节点之间的通信代价。

异步更新示意图

上面这幅图进一步描述了AMTL中的异步更新机制,包括中心节点和任务节点分别执行后向和前向步骤的顺序。在(t_1)时刻,任务节点2从中心节点接收了已完成近端映射的参数((text{Prox}_{eta lambda g}(textbf{V}^k))_2),之后在任务节点2上的前向(梯度)计算步骤就会马上启动。在伪代码步骤第6行所示的任务梯度下降更新完成后,任务2的参数(bm{v}_2)会被送回中心节点。当中心节点收到参数后,它会开始对整个(即包括所有任务的)参数矩阵进行近端映射。

然而,这个算法却会有潜在的不一致性(inconsistency)问题。如图所示,当(t_2)(t_3)之间,即任务节点2在执行计算时,任务节点4已经将其计算好的参数发送到中心节点并触发了近端映射。因此,中心节点的参数矩阵在任务节点2计算梯度时,就因响应任务节点4而更新。之后当任务节点2将算好的参数送到中心节点时,近端映射只能在不一致的数据上进行计算(数据由来自任务节点2的参数和之前已更新的参数混合而成)。 同样,任务节点4在(t_3)时刻收到参数并完成计算后,此时中心节点的参数已经被更新(因为任务节点2在(t_4和t_5)之间已触发近端映射),后面也会产生同样的问题。

为什么会有这种不一致性呢?这是因为AMTL中的数据读取是没有加内存锁的。因此,对于异步坐标更新模式,从中心节点读参数向量时会有不一致性的问题。这种由于后向迭代步骤产生的不一致性已经被论文考虑在了收敛率分析中,具体细节大家可以参见论文。
最后,这篇论文在异构网络环境下的版本代码已开源在Github上(链接: https://github.com/illidanlab/AMTL ),大家可前往学习。

参考文献

  • [1] Evgeniou T, Pontil M. Regularized multi--task learning[C]//Proceedings of the tenth ACM SIGKDD international conference on Knowledge discovery and data mining. 2004: 109-117.
  • [2] Zhou J, Chen J, Ye J. Malsar: Multi-task learning via structural regularization[J]. Arizona State University, 2011, 21.
  • [3] Zhou J, Chen J, Ye J. Clustered multi-task learning via alternating structure optimization[J]. Advances in neural information processing systems, 2011, 2011: 702.
  • [4] Ji S, Ye J. An accelerated gradient method for trace norm minimization[C]//Proceedings of the 26th annual international conference on machine learning. 2009: 457-464.
  • [5] A. Beck and M. Teboulle, “A fast iterative shrinkage-thresholding algorithm for linear inverse problems,” SIAM Journal on Imaging Sciences, vol. 2, no. 1, pp. 183–202, 2009.
  • [6] S. J. Wright, R. D. Nowak, and M. A. Figueiredo, “Sparse reconstruction by separable approximation,” IEEE Transactions on Signal Processing, vol. 57, no. 7, pp. 2479–2493, 2009.
  • [7] J. D. Lee, Y. Sun, and M. A. Saunders, “Proximal newton-type methods for minimizing composite functions,” SIAM Journal on Optimization, vol. 24, no. 3, pp. 1420–1443, 2014.
  • [8] Baytas I M, Yan M, Jain A K, et al. Asynchronous multi-task learning[C]//2016 IEEE 16th International Conference on Data Mining (ICDM). IEEE, 2016: 11-20.
  • [9] Z. Peng, Y. Xu, M. Yan, and W. Yin, “ARock: An algorithmic framework for asynchronous parallel coordinate updates,” SIAM Journal on Scientific Computing, vol. 38, no. 5, pp. A2851–A2879, 2016.
  • [10] B. Edmunds, Z. Peng, and W. Yin, “Tmac: A toolbox of modern asyncparallel, coordinate, splitting, and stochastic methods,” UCLA CAM Report 16-38, 2016.
  • [11] P. L. Combettes and V. R. Wajs, “Signal recovery by proximal forwardbackward splitting,” Multiscale Modeling & Simulation, vol. 4, no. 4, pp. 1168–1200, 2005.
  • [12] Z. Peng, T. Wu, Y. Xu, M. Yan, and W. Yin, “Coordinate-friendly structures, algorithms and applications,” Annals of Mathematical Sciences and Applications, vol. 1, pp. 57–119, 2016.
  • [13] Bauschke H H, Combettes P L. Convex analysis and monotone operator theory in Hilbert spaces[M]. New York: Springer, 2011.
  • [14] https://zhuanlan.zhihu.com/p/150605754
  • [15] 杨强等. 迁移学习[M].机械工业出版社, 2020.
内容来源于网络如有侵权请私信删除

文章来源: 博客园

原文链接: https://www.cnblogs.com/orion-orion/p/15487700.html

你还没有登录,请先登录注册
  • 还没有人评论,欢迎说说您的想法!