1 DeepMove
1.1 构造函数

1.2 初始化权重

1.3 forward

1.4 predict
def predict(self, batch):
score = self.forward(batch)
if self.evaluate_method == 'sample':
# build pos_neg_inedx
pos_neg_index = torch.cat((batch['target'].unsqueeze(1), batch['neg_loc']), dim=1)
score = torch.gather(score, 1, pos_neg_index)
return score
- 如果评估方法是
'sample',则执行以下步骤:- 构建正负样本索引 (
pos_neg_index): 使用torch.cat函数将批次中的目标位置 (batch['target']) 与负样本位置 (batch['neg_loc']) 结合。这里,目标位置通过unsqueeze(1)方法添加一个维度以匹配负样本位置的维度,使其成为batch_size x (1 + num_negatives)的形状。 - 选择得分: 使用
torch.gather方法根据pos_neg_index从得分张量中选择相关的得分。这一步骤的目的是从模型输出的所有可能位置的得分中,仅提取出与正样本和负样本对应的得分。
- 构建正负样本索引 (
1.5 calculate_loss
def calculate_loss(self, batch):
criterion = nn.NLLLoss().to(self.device)
scores = self.forward(batch)
return criterion(scores, batch['target'])
调用 criterion(scores, batch['target']) 来计算模型输出得分和批次中的目标标签 (batch['target']) 之间的损失。





![Siemens-NXUG二次开发-创建块(长方体)特征、圆柱特征、圆锥或圆台特征、球体特征、管道特征[Python UF][20240504]](https://img-blog.csdnimg.cn/direct/37a83b7c8228408e9a4408e9d64428bf.png#pic_center)













