阅读谷歌工作:优化器SAM的论文和代码

根据论文《When Vision Transformers Outperform ResNets》,可以知道,优化器SAM对过拟合抑制,对泛化性的提交很有帮助。很让人有所启发。对我以往的工作有很多rethinking。所以阅读了SAM的论文和代码。

论文:https://openreview.net/pdf?id=6Tm1mposlrM。发表于iclr 2021。短短时间引用26次 代码:https://github.com/google-research/sam/blob/dae9904c4cf3a57a304f7b04cecffe371679c702/sam_jax/training_utils/flax_training.py#L466

论文里的核心部分是:

先前向->反向,求一个\epsilon。计算公式为:



假装以w+\epsilon为参数,再前向->反向,求一个\delta w,计算公式如下:



回到w点,\delta w更新w即可。

算法框架为:



主要实现的代码如下:

第一次前向+反向:
(_, (inner_state, _)), grad = jax.value_and_grad( lambda m: forward_and_loss(m, true_gradient=True), has_aux=True)(model)
noised_model = jax.tree_multimap(lambda a, b: a + rho * b, model, grad)
第二次前向+反向:

(_, (_, logits)), grad = jax.value_and_grad( forward_and_loss, has_aux=True)(noised_model)
#clip一下梯度且更新
grad = clip_by_global_norm(grad) new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
# Compute some norms to log on tensorboard.
gradient_norm = jnp.sqrt(sum( [jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(grad)]))
param_norm = jnp.sqrt(sum( [jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves( new_optimizer.target)]))



-----------------------------------

大家好,我来自fast lab。我开始不定时公开写作。这些写作主要通过两个渠道公布:一是FAST LAB官方网站;一是印象识堂(微信可访问)。欢迎大家订阅。谢谢!

FAST Lab的官方网址为:https://wanggrun.github.io/projects/fast

除此外,还可以关注我的小伙伴王广润:https://wanggrun.github.io/

王广聪: https://wanggcong.github.io/

石阳:https://www.linkedin.com/in/%E9%98%B3-%E7%9F%B3-381b521a4/

有时候这些网站打不开,请耐心多点几次。

多谢大家关注。

返回博客目录Return to all Blogs
返回主页Return to homepage