联邦学习:算法详解与系统实现
上QQ阅读APP看书,第一时间看更新

5.2.1 基于BFGS的二阶优化方法

在本节中,我们基于二阶优化方法BFGS求解目标函数的局部极小值。区别于常用的机器学习中的分布式优化算法,联邦学习在训练时的时间和空间开销更大,主要原因在于:①多方通信交互时需要同态加密,单次乘法/加法的时间消耗可达普通操作的百倍以上,在同态加密下的运算成为整个系统的时间瓶颈;②同态加密下的数据量呈指数级别增长,给网络传输带来了挑战。因此,我们希望对传统的二阶优化方法进行改进,力求使尽量多的计算可以在不加密的情况下利用本地数据在本地完成。由于各个参与方数据集的重叠模式不同,很多样本所对应的ID有时只存在于一个参与方中,对于这部分数据相关的计算往往也不需要与其他参与方交互。因此,在计算中我们可以将这些样本的相关计算集中在本地完成。

1. BFGS算法

BFGS是一种高效的拟牛顿算法,利用目标函数的一阶和二阶信息来确定搜索方向。与牛顿法不同的是,它不直接计算Hessian矩阵,而是对它的逆矩阵进行近似。令近似的Hessian逆矩阵为H,在BFGS中H的迭代公式为:

其中,

2. H和g在多个参与方上的计算

如上文所述,为了进一步减小计算开销,我们希望将每个样本的一阶、二阶导数分为需要共同参与计算和仅需在本地完成两部分。因此,我们首先将参与方的每个样本按照重叠与否分成私有样本非私有样本

定义5.3(私有样本和非私有样本) 对任意参与方来说,若其某一样本与其他参与方均无重叠,则称此样本为这个参与方的私有样本。参与方j的私有样本记为,则非私有样本则定义为

(1)对于私有样本的计算:

令全空间S的维数为m,参与方pi的样本子空间Si的维数为mi,X为全体样本的集合,该集合的大小为n。方便起见,将损失函数l(f(x),y)=(y-f(x))2(在逻辑回归中为逻辑函数)简写为l(f),定义Wi为第i个参与方样本空间所对应的参数,按照定义5.2,则损失函数可以写成如下形式:

令Z=Wix,对于上式求梯度,应用链式法则我们可以得到:

注意到,在上式中,加号左边的项在计算时不需要其他参与方的任何信息,可以完全在本地完成计算并通过Uj映射到全空间S。

进一步求Hessian矩阵,

同理可知,在计算二阶导数时,私有样本的部分也可以完全在本地完成计算。在使用拟牛顿法近似L对W的Hessian矩阵时,可在每个参与方本地直接对进行近似,然后将近似结果映射到全空间。这样既降低了计算复杂度,又减小了近似稀疏高维矩阵时产生的误差。

(2)对于非私有样本:x∈Xj\

与私有样本的计算不同,非私有样本相关的计算无法由单一参与方完成,需要涉及数据交互。因此,为了保护参与方数据的安全性,我们需要遵照安全多方计算中的方法对计算过程加以保护。本节先给出明文的计算公式。在5.2.2节中,我们会详细介绍一种基于秘密共享和同态加密的计算协议。

■梯度计算

对于∀x∈Xj,令f(j)(x,W,α)为x在参与方j上的预测值,简写为f(j)(x)。

首先计算损失函数对训练参数wi的导数:

计算∂f/∂wi需要将来自各个参与方的预测值f(j)相加得到全局预测值f:

计算出f后接着基于可以得到∂L/∂wi

由于每种重叠模式都对应一个不同的α,因此参数α的个数取决于参与方数据集重叠模式的种类。给定一种重叠模式o,令符合这种重叠模式的样本集合为Xo,则L对αo的导数为:

由上式可总结出梯度的计算过程为:参与方在本地计算,然后多方聚合得到;然后再计算;最后再一次聚合各方结果得到一阶导数

■Hessian矩阵的近似

根据式(5.3),由于计算H时需要计算w、g在每一次迭代中的变化s和d,且由于s、d都无法由单独的参与方计算得到,因此需要在得到每个参与方在本轮的g、w之后再收集处理求得最终结果。