输入到 BartForConditionalGeneration 类的各个参数是什么意思?
decoder_input_ids
是必须要以 <s> 开头的。这个参数可以自己生成然后传入到模型中,也可以交由代码自己生成(一般会根据label右移一位再补0)
-
case 1: 直接传入

此时的decoder_input_ids 如下:

-
case 2: 由labels 右移一位生成


decoder_start_token_id的值为2(一般需要指定),对应的token是</s>。最后返回shifted_input_ids作为decoder_input_ids

-
需要注意
labels的起始是没有<s>token的。
细心的读者会发现这两种方法得到的 decoder_input_ids 是不同的(就是因为这个 decoder_start_token_id 值的不同)。
为啥下面两种方法计算的loss值不相同?
就是因为上述说的 这个 decoder_input_ids 值的原因,以及add_special_tokens参数的原因。





















