越来越多的手机和平板电脑成为许多人的主要计算设备。这些设备上强大的传感器(包括摄像头、麦克风和GPS),加上它们经常被携带的事实,意味着它们可以访问前所未有的大量数据,其中大部分本质上是私人的。根据这些数据学习的模型持有承诺通过支持更智能的应用程序来大大提高可用性,但数据的敏感性意味着将其存储在集中位置存在风险和责任。
本文的主要贡献是
将来自移动设备的分散数据的训练问题(联邦学习)确定为一个重要的研究方向;
选择可以应用于该设置的简单实用的算法FedAvg;
对所提出的方法进行广泛的实证评估。
更具体地说,本文介绍了FedAvg算法,它将每个客户端上的局部随机梯度下降(SGD)与执行模型平均的服务器相结合。本文对该算法进行了广泛的实验,证明了它对不平衡和非IID数据分布具有鲁棒性,并且可以将在分散数据上训练深度网络所需的通信轮次减少几个数量级。
举例:
对于机器学习问题,对于样本((x_i,y_i))的损失为(f_i(w)),那么全局损失定义为:
在联邦学习问题中,假设有(K)个客户端,第(k)个客户端的数据集为(P_k),数据集大小(n_k=|P_k|)。那么对于客户端(k),该客户端数据的损失函数为:
[F_k(w)=frac{1}{n_k}sumlimits_{iin P_k}f_i(w) ]
全局的损失函数定义为客户端损失的加权平均:
[f(w)=overset{K}{sum}limits_{k=1}frac{n_k}{n}F_k(w) ]
对于数据集中到中心的情况,由于数据量较大,通信成本相对较小,计算成本较大。
通信成本指客户端与中央服务器之间传输数据所需的成本。联邦学习中,会受到移动设备带宽限制,同时客户端通常仅在有电源和有WiFi等情况下愿意参与优化,因此通信成本较大。而设备数据量小、手机有GPU等特性使得计算成本较小。
为了减小通信成本,方法:
以往工作没有考虑不平衡和非独立同分布数据,以及客户端数量少。
根据当前的模型(w_t)计算梯度(g_k=nabla F_k(w_t))。由于:
[nabla f(w_t)=nabla[overset{K}{sum}limits_{k=1}frac{n_k}{n}F_k(w_t)]=overset{K}{sum}limits_{k=1}frac{n_k}{n}g_k ]
那么中心服务器聚合梯度并进行更新的结果为:
[w_{t+1}leftarrow w_t-etanabla f(w_t)=w_t-etaoverset{K}{sum}limits_{k=1}frac{n_k}{n}g_k ]
上式也等价于客户端先在本地做一次梯度更新,中心服务器再对模型进行加权平均:
[w^k_{t+1}leftarrow w_t-eta g_k ]
[w_{t+1}leftarrow overset{K}{sum}limits_{k=1}frac{n_k}{n}w^k_{t+1} ]
写成上述第二种形式后,可以在做平均之前,多次迭代本地更新:
[w^kleftarrow w^k-etanabla F_k(w^k) ]
每个客户端可以多次计算上式得到本地在第(t)轮的最终模型,最后中心服务器将这些本地模型进行聚合得到(w^{t+1})。
这就是FedAvg的思想,该算法主要有三个超参数:
当(B=infty,E=1)时,FedAvg和FedSGD等价
这里还定义了每轮的本地更新次数:(u_k=Efrac{n_k}{B}),由该公式也可以算出,FedSGD每轮本地更新次数为1。
完整的伪代码:
至此我们可以简单比较FedSGD和FedAvg:
算法 | local | server |
---|---|---|
FedSGD | 计算本轮梯度 | 收集local的梯度,加权平均后作为server要下降的梯度 |
FedAvg | 多次梯度下降,得到本轮的本地模型 | 收集local的模型,加权平均后作为本轮得到的模型 |
聚合参数(theta):以(theta w+(1-theta)w^{'})对两个模型进行聚合,得到最终模型。
左图是使用两个初始模型(w,w^{'})训练不同数据得到的损失,右图是两模型使用同一个(w)初始化训练不同数据,可以看出右边损失较小,且当(theta=0.5)效果最好。因此在联邦学习实验中,每个客户端需要共享相同的初始化模型。
选取大小适中的数据集,以便研究超参数。
第一个任务是MNIST数字识别,使用两个模型:
多层感知机。2个隐藏层,每个隐藏层有200个单元,使用ReLU激活。
199210个参数:图像为(28times 28),转为一维后是784。第一层(784*200+偏置200),第二层(200*200+偏置200),第三层(200*10+偏置10)
(32*5*5)卷积+(2*2)最大池化+(64*5*5)卷积+(2*2)最大池化+512单元全连接+ReLU+Softmax
数据集划分:
分出来的数据集有iid和非iid,但都是平衡的。
第二个任务是字符预测,使用LSTM,读取一行字符预测下一个字符。
数据集是莎士比亚全集,每个说话角色为一个客户端,共1146个。每个客户端,前80%的行是训练集,后20%行是测试集。
数据集划分:
学习率设置在(10^{frac{1}{3}})到(10^{frac{1}{6}})区间。
(C)控制并行量,因此先改变(C)。
实验记录了MLP达到97测试集准确率和CNN达到99测试准确率所需要的通信轮数。
使用小批量,当(C=0.1)时效果就已经较好。为平衡计算效率和收敛速度,之后实验固定(C=0.1)。
在FedAvg算法部分,我们已经指出,每轮本地更新次数为(u_k=Efrac{n_k}{B})。在实验中设置独立同分布的更新次数为期望更新次数,即(u=Efrac{n}{kB})。
首先,对于两种任务,增加(B)都减少了通信轮数。
对于MNIST任务,iid效果比非iid更显著。实际生活我们设备上的数字也不会是规律性的,因此这种情况是该方法鲁棒的论证。
对于莎士比亚数据集,非iid效果很好,而这代表了我们在现实生活的数据分布(不同的人说话数量相差很大)。推测是某些客户端有较大的数据集,使本地训练更具有价值。
可以看出,FedAvg不仅减少通信轮数,还提高了测试精度(蓝色实线是FedSGD)。推测是模型平均会产生类似dropout正则化的收益。
对于非常大的本地迭代次数,FedAvg可能会停滞或发散。这一结果表明,对于某些模型,尤其是在收敛的后期阶段,减少每轮的本地计算量(即减小E或增大B)可能是有益的,就像衰减学习率一样。
数据集包含50000个训练数据和10000个测试数据,将其平均划分给100个客户端,每个客户端包含500个训练数据和100个测试数据。
使用的模型为两个卷积层+两个全连接层+一个线性变换层。
图像会经过裁剪为(24*24)、左右反转、调整对比度、亮度等预处理。
单机的SGD对比10个客户端的FedSGD和FedAvg:
现有的模型,对CIFAR分类任务的测试精度已经很高,但这里只要达到80%左右即可,原因是本文的目标是评估FedAvg方法,而非提高CIFAR测试精度。
不同学习率的影响:
为了证明方法在现实世界问题上有效,还在大规模的预测下一个单词任务上进行了实验。
训练数据集由来自大型社交网络的1000万个公开帖子组成。按作者对帖子进行了分组,总共有超过500,000名客户。文中将每个客户端的数据集限制为最多5000个单词,并对10000个作者的数据进行了测试。
参与评论
手机查看
返回顶部