The paper is titled "Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction", by Yiqun Lin, Hualiang Wang, Jixiang Chen, and Xiaomeng Li, from The Hong Kong University of Science and Technology and the HKUST Shenzhen-Hong Kong Collaborative Innovation Research Institute.
It proposes a new cone-beam computed tomography (CBCT) reconstruction framework called DIF-Gaussian, which aims to reduce radiation dose by using far fewer projections while improving reconstruction quality.
The released code is only a skeleton; forcing a full reproduction would take a lot of time, and at my level I'd risk misleading people, so here I simply read the code against the paper. Anyone interested is welcome to discuss.
Project:
GitHub - xmed-lab/DIF-Gaussian: MICCAI 2024: Learning 3D Gaussians for Extremely Sparse-View Cone-Beam CT Reconstruction
Data preprocessing:
https://github.com/xmed-lab/C2RV-CBCT/tree/main/data
1. Download the code and the data-preprocessing scripts, and put the data under data/.
2. The code turns out to be incomplete, so I fill in the gaps as I read.
train.py
Make it compatible with different versions of DDP:
    if args.dist:
        args.local_rank = int(os.environ["LOCAL_RANK"]) # Make it compatible with different versions of DDP
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(args.local_rank) 
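Side note: newer PyTorch removed the --local_rank argument that the old torch.distributed.launch injected; torchrun passes LOCAL_RANK through the environment instead. A minimal fallback that accepts both (my sketch, not from the repo):

# Sketch: support both the legacy launcher (--local_rank arg) and torchrun (env var).
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # legacy torch.distributed.launch
args, _ = parser.parse_known_args()
args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))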
Load the cfg. The project only ships a single default.yaml; copy it and rename it:
    cfg = load_config(args.cfg_path)
    if args.local_rank == 0:
        print(args)
        print(cfg)
        # save config
        save_dir = f'./logs/{args.name}'
        os.makedirs(save_dir, exist_ok=True)
        if os.path.exists(os.path.join(save_dir, 'config.yaml')):
            time_str = datetime.now().strftime('%d-%m-%Y_%H-%M-%S')
            shutil.copyfile(
                os.path.join(save_dir, 'config.yaml'), 
                os.path.join(save_dir, f'config_{time_str}.yaml')
            )
        shutil.copyfile(args.cfg_path, os.path.join(save_dir, 'config.yaml')) 
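load_config itself is not in the snippet. Since the config is read with attribute access (cfg.dataset, cfg.model), a minimal sketch assuming YAML plus easydict would be:

# Sketch of the missing load_config, assuming yaml + easydict for attribute access.
import yaml
from easydict import EasyDict

def load_config(cfg_path):
    with open(cfg_path) as f:
        return EasyDict(yaml.safe_load(f))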
Initialize the training dataset/loader:
    train_dst = CBCT_dataset_gs(
        dst_name=args.dst_name,
        cfg=cfg.dataset,
        split='train', 
        num_views=args.num_views, 
        npoint=args.num_points,
        out_res_scale=args.out_res_scale,
        random_views=args.random_views
    ) 
The catch is that no data ships with the repo, so you have to prepare it yourself. The relevant config section:
dataset:
  root_dir: ../../datasets
  gs_res: 12 # the resolution of GS points (12^3 points in total) 
Step into the dataset to see how it is built:
class CBCT_dataset_gs(CBCT_dataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        gs_res = self.cfg.gs_res
        points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res
        self.points_gs = points_gs.reshape(3, -1).transpose(1, 0) # [K, 3] rows of (x, y, z), K = gs_res^3, values in ~[0, 1]

    def __getitem__(self, index):
        data_dict = super().__getitem__(index)
        # projections of GS points (initial center xyz)
        points_gs = deepcopy(self.points_gs)
        points_gs_proj = self.project_points(points_gs, data_dict['angles'])
        data_dict.update({
            'points_gs': points_gs,          # [K, 3]
            'points_gs_proj': points_gs_proj # [M, K, 2]
        })
        return data_dict
 
np.mgrid is a NumPy routine that returns a dense coordinate grid. The line points_gs = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res builds a 3D grid and normalizes every coordinate into [0, 1). Note the layout: np.mgrid puts the coordinate axis first, so points_gs has shape (3, gs_res, gs_res, gs_res) rather than (gs_res, gs_res, gs_res, 3); the subsequent reshape(3, -1).transpose(1, 0) turns it into (gs_res^3, 3), one (x, y, z) row per grid point.
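A quick shape check with gs_res = 2:

import numpy as np

gs_res = 2
g = np.mgrid[:gs_res, :gs_res, :gs_res] / gs_res
print(g.shape)                          # (3, 2, 2, 2): coordinate axis comes first
pts = g.reshape(3, -1).transpose(1, 0)
print(pts.shape)                        # (8, 3): one (x, y, z) row per grid point
print(pts[:2])                          # [[0. 0. 0.] [0. 0. 0.5]]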
Now __getitem__:
points_gs_proj = self.project_points(points_gs, data_dict['angles'])
 
points_gs is the grid of 3D points serving as the initial centers of the 3D Gaussians, and points_gs_proj holds their projections onto each of the M 2D detector planes.
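project_points is inherited from the base CBCT_dataset and isn't shown here. As a rough idea of what it does, here is a parallel-beam simplification (hypothetical; the real preprocessing models the full cone-beam geometry with source/detector distances and calibration):

# Simplified parallel-beam sketch of point projection (not the repo's version).
import numpy as np

def project_points(points, angles):
    # points: [K, 3] in [0, 1]; angles: [M] gantry angles in degrees
    pts = points - 0.5                        # center the volume at the origin
    projs = []
    for a in np.deg2rad(angles):
        c, s = np.cos(a), np.sin(a)
        u = c * pts[:, 0] + s * pts[:, 1]     # detector u-axis (in-plane)
        v = pts[:, 2]                         # detector v-axis (along rotation axis)
        projs.append(np.stack([u, v], axis=-1))
    return np.stack(projs)                    # [M, K, 2]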
The code is incomplete; I'll check later whether the repo gets updated. The LUNA16 data-preprocessing config contains the dataset parameters, with a total scan angle of 180 degrees. So __getitem__ returns a 3D Gaussian grid plus its 2D projections.
The loaders are as follows:
    train_sampler = None
    if args.dist:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dst)
    train_loader = DataLoader(
        train_dst, 
        batch_size=args.batch_size, 
        sampler=train_sampler, 
        shuffle=(train_sampler is None),
        num_workers=0, # args.num_workers,
        pin_memory=True,
        worker_init_fn=worker_init_fn
    )
    # -- initialize evaluation dataset/loader
    eval_loader = DataLoader(
        CBCT_dataset_gs(
            dst_name=args.dst_name,
            cfg=cfg.dataset,
            split='eval',
            num_views=args.num_views,
            out_res_scale=0.5, # low-res for faster evaluation,
        ), 
        batch_size=1, 
        shuffle=False,
        pin_memory=True
    ) 
Load the model (the model itself comes later):
    # -- initialize model
    model = DIF_Gaussian(cfg.model)
    if args.resume:
        print(f'resume model from epoch {args.resume}')
        ckpt = torch.load(
            os.path.join(f'./logs/{args.name}/ep_{args.resume}.pth'),
            map_location=torch.device('cpu')
        )
        model.load_state_dict(ckpt)
    
    model = model.cuda()
    if args.dist:
        model = nn.parallel.DistributedDataParallel(
            model, 
            find_unused_parameters=False,
            device_ids=[args.local_rank]
        ) 
Optimizer and LR scheduler; the loss is a single MSE:
    # -- initialize optimizer, lr scheduler, and loss function
    optimizer = torch.optim.SGD(
        model.parameters(), 
        lr=args.lr, 
        momentum=0.98, 
        weight_decay=args.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=1, 
        gamma=np.power(args.lr_decay, 1 / args.epoch)
    )
    loss_func = nn.MSELoss() 
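The gamma is chosen so that one scheduler step per epoch decays the learning rate by exactly args.lr_decay over the whole run: lr_final = lr * (lr_decay^(1/epoch))^epoch = lr * lr_decay. With hypothetical values lr = 0.01, lr_decay = 0.05, epoch = 400:

import numpy as np

gamma = np.power(0.05, 1 / 400)
print(gamma)                  # ~0.99254 per-epoch decay factor
print(0.01 * gamma ** 400)    # ~0.0005, i.e. 0.01 * 0.05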
Training starts:
    # -- training starts
    for epoch in range(start_epoch, args.epoch + 1):
        if args.dist:
            train_loader.sampler.set_epoch(epoch)
        loss_list = []
        model.train()
        optimizer.zero_grad() 
One epoch. Seen from the outside there are no fancy losses, just the one loss all the way through:
        for k, item in enumerate(train_loader):
            item = convert_cuda(item)
            pred = model(item)
            loss = loss_func(pred['points_pred'], item['points_gt'])
            loss_list.append(loss.item())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad() 
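convert_cuda is another helper not shown in the snippet; presumably it just moves every tensor in the batch dict onto the GPU, something like:

# Sketch of the missing convert_cuda helper (assumption, not the repo's code).
import torch

def convert_cuda(item):
    return {k: v.cuda(non_blocking=True) if torch.is_tensor(v) else v
            for k, v in item.items()}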
Logging, checkpointing, and evaluation:
        if args.local_rank == 0:
            if epoch % 10 == 0:
                loss = np.mean(loss_list)
                print('epoch: {}, loss: {:.4}'.format(epoch, loss))
            
            if epoch % 100 == 0 or (epoch >= (args.epoch - 100) and epoch % 10 == 0):
                if isinstance(model, torch.nn.DataParallel) or isinstance(model, torch.nn.parallel.DistributedDataParallel):
                    model_state = model.module.state_dict()
                else:
                    model_state = model.state_dict()
                torch.save(
                    model_state,
                    os.path.join(save_dir, f'ep_{epoch}.pth')
                )
            if epoch % 50 == 0 or (epoch >= (args.epoch - 100) and epoch % 20 == 0):
                metrics, _ = eval_one_epoch(
                    model, 
                    eval_loader, 
                    args.eval_npoint,
                    ignore_msg=True,
                )
                msg = f' --- epoch {epoch}'
                for dst_name in metrics.keys():
                    msg += f', {dst_name}'
                    met = metrics[dst_name]
                    for key, val in met.items():
                        msg += ', {}: {:.4}'.format(key, val)
                print(msg)
        
        if lr_scheduler is not None:
            lr_scheduler.step()
 
model.py
Let's see what the initialization defines:
class DIF_Gaussian(Recon_base):
    def __init__(self, cfg):
        super().__init__(cfg)

    def init(self):
        self.init_encoder()
        
        # gaussians-related modules
        mid_ch = self.image_encoder.out_ch
        ds_ch = self.image_encoder.ds_ch
        self.gs_feats_mlp = MLP_1d([ds_ch, ds_ch // 4, mid_ch], use_bn=True, last_bn=True, last_act=False)
        self.gs_params_mlp = MLP_1d([ds_ch, ds_ch // 4, 3 + 4 + 3], use_bn=True, last_bn=False, last_act=False) # 3d: offsets, 4d: rotation, 3d: scaling
        self.gs_act = nn.LeakyReLU(inplace=True)
        self.init_decoder(mid_ch * 2)
        self.registered_point_keys = ['points', 'points_proj'] 
Initialization does three things: self.init_encoder() builds the encoder; self.gs_feats_mlp and self.gs_params_mlp are the Gaussian feature and parameter MLPs, with a LeakyReLU activation self.gs_act; and self.init_decoder(mid_ch * 2) builds the decoder, whose input width is doubled because two feature sets are concatenated later.
Although the code isn't complete, it's easy to guess that the encoder and decoder are UNet-style modules.
Now look at the forward pass, which obtains the per-point predictions in three steps (a sketch of how they compose follows the list):
1. multi-view pixel-aligned features + max-pooling;
2. a Gaussian-based interpolation function;
3. point-wise prediction (the PointDecoder below).
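The complete forward isn't in the repo, but from the modules defined above the three steps plausibly compose like this (my sketch: the two encoder outputs, the point_decoder name, and the interp_gaussian helper, fleshed out further below, are all assumptions):

# Pseudocode sketch of DIF_Gaussian.forward, not the repo's implementation.
def forward(self, item):
    # encode the M projections into full-res and downsampled 2D feature maps
    view_feats, ds_feats = self.image_encoder(item['projs'])
    # (1) pixel-aligned features of the initial GS centers, max-pooled over views
    g = query_view_feats(ds_feats, item['points_gs_proj'])        # [B, ds_ch, K]
    feats_gs = self.gs_act(self.gs_feats_mlp(g))                  # per-Gaussian features
    offset, rot, scale = self.gs_params_mlp(g).split([3, 4, 3], dim=1)
    mu = item['points_gs'].transpose(1, 2) + offset               # refined centers
    # (2) Gaussian-based interpolation of feats_gs at the query points
    f_gs = interp_gaussian(item['points'], mu, rot, scale, feats_gs)
    # pixel-aligned features of the query points themselves
    f_pix = query_view_feats(view_feats, item['points_proj'])     # [B, mid_ch, N]
    # (3) point-wise prediction; concatenation explains the mid_ch * 2 decoder input
    return {'points_pred': self.point_decoder(torch.cat([f_pix, f_gs], dim=1))}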
class PointDecoder(nn.Module):
    def __init__(self, channels, residual=True, use_bn=True):
        super().__init__()
        self.residual = residual
        self.mlps = nn.ModuleList()
        for i in range(len(channels) - 1):
            modules = []
            if i == 0 or not self.residual:
                modules.append(nn.Conv1d(channels[i], channels[i + 1], kernel_size=1))
            else:
                # residual: the raw input features are re-concatenated at every layer
                modules.append(nn.Conv1d(channels[i] + channels[0], channels[i + 1], kernel_size=1))
            if i != len(channels) - 1: # note: i stops at len(channels) - 2, so this is always true
                if use_bn:
                    modules.append(nn.BatchNorm1d(channels[i + 1]))
                modules.append(nn.LeakyReLU(inplace=True))
            self.mlps.append(nn.Sequential(*modules))

    def forward(self, x):
        x_ = x
        for i, m in enumerate(self.mlps):
            if i != 0 and self.residual:
                x_ = torch.cat([x_, x], dim=1)
            x_ = m(x_)
        return x_
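A quick shape check with a hypothetical channel configuration:

import torch

dec = PointDecoder([256, 128, 64, 1])
x = torch.randn(2, 256, 1024)    # B=2, C=256 feature channels, N=1024 query points
print(dec(x).shape)              # torch.Size([2, 1, 1024]): one value per point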
 
query_view_feats should correspond to the pixel-aligned querying formula in the paper: each 3D point is projected into every one of the M views, the 2D feature maps are sampled at the projected locations, and the per-view features are fused by max-pooling.
def query_view_feats(view_feats, points_proj, fusion='max'):
    # view_feats: [B, M, C, H, W]
    # points_proj: [B, M, N, 2]
    # output: [B, C, N, M] (or [B, C, N] after max fusion)
    n_view = view_feats.shape[1]
    p_feats_list = []
    for i in range(n_view):
        feat = view_feats[:, i, ...] # B, C, H, W
        p = points_proj[:, i, ...] # B, N, 2
        p_feats = index_2d(feat, p) # B, C, N
        p_feats_list.append(p_feats)
    p_feats = torch.stack(p_feats_list, dim=-1) # B, C, N, M
    if fusion == 'max':
        p_feats = F.max_pool2d(p_feats, (1, p_feats.shape[-1]))
        p_feats = p_feats.squeeze(-1) # [B, C, N]
    elif fusion is not None:
        raise NotImplementedError
    return p_feats
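index_2d isn't shown either; presumably it is a bilinear lookup of the 2D feature maps at the projected coordinates, e.g. via grid_sample (a sketch assuming coordinates already normalized to [-1, 1]):

# Sketch of the missing index_2d, assuming p is normalized to [-1, 1].
import torch.nn.functional as F

def index_2d(feat, p):
    # feat: [B, C, H, W]; p: [B, N, 2] projected (x, y) coordinates
    grid = p.unsqueeze(2)                                # [B, N, 1, 2]
    out = F.grid_sample(feat, grid, align_corners=True)  # [B, C, N, 1]
    return out.squeeze(-1)                               # [B, C, N]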
 

The interpolation is defined by a formula in the paper (the figure is missing from this post).
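Given that gs_params_mlp predicts offsets (3), a rotation quaternion (4), and scalings (3), the interpolation presumably follows the standard 3DGS covariance parameterization Σ_k = R_k S_k S_kᵀ R_kᵀ and weights each Gaussian's features by its density w_k(x) = exp(-½ (x − μ_k)ᵀ Σ_k⁻¹ (x − μ_k)) at the query point. A sketch of the hypothetical interp_gaussian helper used in the forward sketch above (batch dimension omitted; the paper's exact normalization may differ):

# Hypothetical Gaussian-weighted interpolation under the 3DGS parameterization.
import torch
import torch.nn.functional as F

def quat_to_rotmat(q):
    # q: [K, 4] quaternions (w, x, y, z); normalize before converting
    w, x, y, z = F.normalize(q, dim=-1).unbind(-1)
    return torch.stack([
        1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y),
        2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x),
        2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y),
    ], dim=-1).reshape(-1, 3, 3)

def interp_gaussian(query, mu, rot, scale, feats):
    # query: [N, 3], mu: [K, 3], rot: [K, 4], scale: [K, 3], feats: [K, C]
    R = quat_to_rotmat(rot)
    S = torch.diag_embed(scale)   # a real model would keep scales positive (exp/softplus)
    cov = R @ S @ S.transpose(-1, -2) @ R.transpose(-1, -2)   # Sigma = R S S^T R^T
    d = query[:, None, :] - mu[None, :, :]                    # [N, K, 3]
    maha = torch.einsum('nki,kij,nkj->nk', d, torch.inverse(cov), d)
    w = torch.exp(-0.5 * maha)                                # Gaussian densities
    w = w / (w.sum(dim=1, keepdim=True) + 1e-8)               # normalize over Gaussians
    return w @ feats                                          # [N, C] interpolated feats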
To recap, the final per-point prediction runs through the residual PointDecoder shown earlier.