图像自回归生成(Auto-regressive image generation)实战学习(六)
相关项目下载链接基于 Transformer 的自回归图像生成模型完整的链路是1、先用 Patch AutoEncoder BSQ 量化器把原始图像压缩为离散的 token 序列每个整数 token 对应原图的一个小图像 patch2、训练这个自回归 Transformer 模型学习 token 之间的空间共现规律3、通过generate方法生成全新的token序列4、用 BSQ 量化器把 token 序列解码回可保存的 png 图片。本节内容主要介绍如何通过generate方法生成全新的 token 序列。定义主模型主模型对应的代码在autoregressive.py在上一节中我们并没有定义generation方法的具体实现本节对其逻辑进行补全。为了兼容补全后的generation方法还需要对前向传播算法进行维度匹配调整。补全generation方法torch.no_grad()defgenerate(self,B:int1,h:int20,w:int30,deviceNone)-torch.Tensor:ifdeviceisNone:deviceself.embedding.weight.device gen_seqtorch.zeros((B,h,w),dtypetorch.long,devicedevice)total_lenh*wforkinrange(total_len):# 把 1D 索引 k 转回 2D 坐标 (i,j)ik//w# 行号jk%w# 列号logits,_self.forward(gen_seq)next_token_logitslogits[:,i,j,:]/0.9next_tokentorch.multinomial(F.softmax(next_token_logits,dim-1),num_samples1).squeeze(1)gen_seq[:,i,j]next_tokenreturngen_seq调整前向传播算法defforward(self,x:torch.Tensor)-tuple[torch.Tensor,dict[str,torch.Tensor]]:# 对训练和推理进行维度匹配ifx.dim()4:xx.squeeze(1)B,h,wx.shape Lh*w# 展平成序列x_flatx.reshape(B,L)# 嵌入 位置编码token_embself.embedding(x_flat)pos_idxtorch.arange(L,devicex.device)pos_embself.pos_emb(pos_idx)x_embtoken_embpos_emb# 自回归右移关键x_embF.pad(x_emb,(0,0,1,0))[:,:-1]# 因果掩码maskself._generate_causal_mask(L,x.device)trans_outself.transformer(x_emb,maskmask)# 输出logitsself.fc_out(trans_out)logits_2dlogits.reshape(B,h,w,self.n_tokens)returnlogits_2d,{}模块测评下面进行图像生成的功能测试mkdir test python-m homework.generation checkpoints/BSQPatchAutoEncoder.pth checkpoints/AutoregressiveModel.pth8test所得的解码后的PNG图片如下所示将代码打包为压缩文件python bundle.py homework20260412进行评分自测python-m grader20260412.zip最终测试得分如下可选的优化方向更优的量化器更小的图像块、更高的码率缩小 patch 尺寸你当前patch_size5可改为 3或2。更小的图像块意味着更细的图像粒度大幅减少单 patch 的信息损失生成的图像细节更丰富、块效应更少。提升码本码率你当前codebook_bits10仅 1024 个码本可提升到 12或14。码本容量越大量化的精度越高单个 token 能表达的图像信息越丰富生成的画面连贯性更强。辅助优化提升 Patch AutoEncoder 的重建能力比如增加卷积层、调整 latent_dim降低量化器的基础重建 MSE从根源上提升 token 的质量。更大的 Transformer 模型参数量增加 Transformer 深度把 Encoder 层数从 2 层提升到 4/6 层更深的网络能拟合更复杂的 token 序列分布。提升隐层维度把d_latent从 128 提升到 256/512注意nhead必须能整除d_latent更高的维度能承载更丰富的图像语义信息。更优的训练策略增加训练轮次可提升到 10/20/50 轮配合学习率衰减策略让模型充分学习 patch 的空间分布和长距离依赖关系。优化学习率策略在 AdamW 优化器中加入「warmup 预热 余弦退火衰减」避免训练初期梯度爆炸同时让模型在训练后期更精细地拟合分布大幅提升生成效果。
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2511838.html
如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!