腾讯QQ浏览器2021AI算法大赛,北大冠军团队经验分享,附详细代码( 四 )


过于激进的早停策略在比赛中仍然存在问题 。 如果使用贝叶斯优化只对全量验证数据建模 , 由于总体优化预算时间很少 , 早停会减少可用于建模的数据量 , 使得模型不能得到充分训练 。 为解决这一问题 , 我们引入插值方法 , 增加模型可训练数据 。
基于以上考量 , 最终我们的决赛算法在初赛贝叶斯优化算法的基础上 , 前期执行完整贝叶斯优化使模型得到较为充分的拟合 , 后期使用早停技术与插值法 , 加速超参数验证与搜索过程 。 下面将对早停模块做详细介绍 。
算法核心技术——早停模块介绍
早停方法
由于超参数配置之间的部分验证轮次均值大小关系与最终均值大小关系存在一定的相关性 , 我们受异步多阶段早停算法ASHA[5]的启发 , 设计了基于排名的早停算法:一个超参数如果到达需要判断早停的轮次 , 就计算其性能均值处于历史中同一轮次的超参数性能均值的排名 , 如果位于前1/eta , 则继续验证 , 否则执行早停 。
依据95%置信区间的含义 , 我们还设计了另一种早停方法 , 即使用置信区间判断当前超参数配置是否仍有验证价值 。 如果某一时刻 , 当前验证超参数的置信区间上界差于已完全验证的性能前10名配置的均值 , 则代表至少有95%的可能其最终均值差于前10名的配置 , 故进行早停 。 使用本地数据验证 , 以空间中前50名的配置对前1000名的配置使用该方法进行早停 , 早停准确率在99%以上 。
经过测试 , 结合贝叶斯优化时两种方法效果近似 , 我们最终选择使用基于排名的早停方法 。 无论是哪种方法 , 都需要设计执行早停的轮次 。 早停越早越激进 , 节省的验证时间越多 , 但是得到的数据置信度越低 , 后续执行插值时训练的模型就越不准确 。 为了权衡早停带来的时间收益和高精度验证带来的数据收益 , 我们选择只在第7轮(总共14轮)时判断每个配置是否应当早停 。 早停判断准则依据eta=2的ASHA算法 , 即如果当前配置均值性能处于已验证配置第7轮的后50% , 就进行早停 。
以下代码展示了基于排名的早停方法 。 首先统计各个早停轮次下已验证配置的性能并进行排序(比赛中我们使用早停轮次为第7轮) , 然后判断当前配置是否处于前1/eta(比赛中为前1/2) , 否则执行早停:
#基于排名的早停方法 , prune_eta=2 , prune_iters=[7]defprune_mean_rank(self,iteration_number,running_suggestions,suggestion_history):#统计早停阶段上已验证配置的性能并排序bracket=dict()forn_iterationinself.hps['prune_iters']:bracket[n_iteration]=list()forsuggestioninrunning_suggestions+suggestion_history:n_history=len(suggestion['reward'])forn_iterationinself.hps['prune_iters']:ifn_history>=n_iteration:bracket[n_iteration].append(suggestion['reward'][n_iteration-1]['value'])forn_iterationinself.hps['prune_iters']:bracket[n_iteration].sort(reverse=True)#maximize#依据当前配置性能排名 , 决定是否早停stop_list=[False]*len(running_suggestions)fori,suggestioninenumerate(running_suggestions):n_history=len(suggestion['reward'])ifn_history==CONFIDENCE_N_ITERATION:#当前配置已完整验证 , 无需早停print('fullobservation.pass',i)continueifn_historynotinself.hps['prune_iters']:#当前配置不处于需要早停的阶段print('n_history:%dnotinprune_iters:%s.pass%d.'%(n_history,self.hps['prune_iters'],i))continuerank=bracket[n_history].index(suggestion['reward'][-1]['value'])total_cnt=len(bracket[n_history])#判断当前配置性能是否处于前1/eta , 否则早停ifrank/total_cnt>=1/self.hps['prune_eta']:print('n_history:%d,rank:%d/%d,eta:1/%s.PRUNE%d!'%(n_history,rank,total_cnt,self.hps['prune_eta'],i))stop_list[i]=Trueelse:print('n_history:%d,rank:%d/%d,eta:1/%s.continue%d.'%(n_history,rank,total_cnt,self.hps['prune_eta'],i))returnstop_list