阅读代码:OpenAI的DALL.E源代码

DALL.E是一个极其重要的工作。其博客为:https://openai.com/blog/dall-e/
其论文为:https://arxiv.org/abs/2102.12092
其代码在:https://github.com/openai/DALL-E

DALL.E的tokenize可以视为一个极其简单的AE:

首先进行编码,使用编码网络enc()。这个网络就是若干个卷积层,大约经历了8倍下采样。注意,编码后的输出,并非变成一个向量,而是一组大约为cx32x32的feature maps。
z_logits = enc(x)

然后,使用一个argmax()函数,把上述编码的cx32x32的feature maps变成一个1x32x32的map了。
z = torch.argmax(z_logits, axis=1)

然后把上述1x32x32的feature map又用one-hot函数变回cx32x32的feature map。
z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

上述两步是只在测试阶段,而在训练阶段怎么做呢?两步并做一步,直接用一个GumbelSoftmax就可以得到这个cx32x32的map啦。

然后解码回原图
x_stats = dec(z).float()

原本直接全名用sigmoid就好啦。这里加了一个带阀值的unmap_pixels()函数来调一下。

x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
x_rec = T.ToPILImage(mode='RGB')(x_rec[0])

这个unmap_pixels()函数很直观,定义为:torch.clamp((x - 0.1) / (1 - 2 * 0.1), 0, 1)

display_markdown('Reconstructed image:')
display(x_rec)



-----------------------------------

大家好,我来自fast lab。我开始不定时公开写作。这些写作主要通过两个渠道公布:一是FAST LAB官方网站;一是印象识堂(微信可访问)。欢迎大家订阅。谢谢!

FAST Lab的官方网址为:https://wanggrun.github.io/projects/fast

除此外,还可以关注我的小伙伴王广润:https://wanggrun.github.io/

王广聪: https://wanggcong.github.io/

石阳:https://www.linkedin.com/in/%E9%98%B3-%E7%9F%B3-381b521a4/

有时候这些网站打不开,请耐心多点几次。

多谢大家关注。

返回博客目录Return to all Blogs
返回主页Return to homepage