Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[🐛BUG] SASREC的实现和官方代码似乎存在一定的差异? #2135

Open
chungewoqu2 opened this issue Jan 20, 2025 · 0 comments
Open
Labels
bug Something isn't working

Comments

@chungewoqu2
Copy link

描述这个 bug
您好,我在运行recbole的sasrec模型时发现和预期结果有所出入。经过比对SASRec官方GITHUB页面提供的代码,发现recbole的SASREC似乎和SASREC官方的代码有一定出入。具体差异见下方的 如何复现 部分及 屏幕截图 部分

如何复现
复现这个 bug 的步骤:

  1. 您引入的额外 yaml 文件——无需额外引入,任意即可
  2. 您的代码

SASREC官方代码(model.py line 17):

self.seq, item_emb_table = embedding(self.input_seq,
                                                 vocab_size=itemnum + 1,
                                                 num_units=args.hidden_units,
                                                 zero_pad=True,
                                                 scale=True,
                                                 l2_reg=args.l2_emb,
                                                 scope="input_embeddings",
                                                 with_t=True,
                                                 reuse=reuse
                                                 )

其中embedding的代码如下(modules.py line 51):

def embedding(inputs, 
              vocab_size, 
              num_units, 
              zero_pad=True, 
              scale=True,
              l2_reg=0.0,
              scope="embedding", 
              with_t=False,
              reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        lookup_table = tf.get_variable('lookup_table',
                                       dtype=tf.float32,
                                       shape=[vocab_size, num_units],
                                       #initializer=tf.contrib.layers.xavier_initializer(),
                                       regularizer=tf.contrib.layers.l2_regularizer(l2_reg))
        if zero_pad:
            lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
                                      lookup_table[1:, :]), 0)
        outputs = tf.nn.embedding_lookup(lookup_table, inputs)
        if scale:
            outputs = outputs * (num_units ** 0.5) 
    if with_t: return outputs,lookup_table
    else: return outputs

可以看到,SASREC在处理item emb时会对item emb进行scale,即outputs = outputs * (num_units ** 0.5) 这一行,再和position embedding相加。
但是在RECBOLE中,似乎缺少这一部分内容.

见recbole/model/sequential_recommender/sasrec.py,在forward部分,第102行至105行为:

        item_emb = self.item_embedding(item_seq)
        input_emb = item_emb + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

这部分缺少了上面提到的scale内容。

同时,在后续的multihead attention中,同样缺少了对query的normalize部分(SASREC的tensorflow实现中,key会对input进行normalize)(model.py line 54)。

          self.seq = multihead_attention(queries=normalize(self.seq),
                                         keys=self.seq,
                                         num_units=args.hidden_units,
                                         num_heads=args.num_heads,
                                         dropout_rate=args.dropout_rate,
                                         is_training=self.is_training,
                                         causality=True,
                                         scope="self_attention")

这部分由于recbole中涉及到较多引用,此处暂不进行recbole代码的详细展示。

预期
预期:SASREC官方代码得到的结果应当和Recbole中得到的结果一致。
但是我在修改SASREC官方页面提到的三方pytorch代码后得到的结果(此处修改主要是针对三方pytorch代码中使用了BCE loss,并在训练时未使用全量负样本),和recbole的结果不一致。pytorch的结果和recbole的结果分别如屏幕截图中的图1和图2.

通过将recbole代码修改至和tensorflow版代码逻辑一致(即修改了multihead attention的差异,以及item emb的scale差异),得到了新的结果,和pytorch版的差异较小,在可接受范围内(因为某些超参上有所出入)。见屏幕截图中的图3.

屏幕截图
图1——SASREC pytorch版结果:

Image

图2——Recbole 本身结果:

Image

图3——经过修改后,Recbole本身结果:

Image

图4——两次运行recbole的公用超参设置:

Image

链接
代码暂略。后续如有需要我修改后的recbole及sasrec pytorch版代码,会进行上传。

实验环境(请补全下列信息):

  • 操作系统: Linux & Windows
  • RecBole 版本:1.1.0
  • Python 版本:3.10
  • PyTorch 版本:2.0.1+cu118
  • cudatoolkit 版本:11.6
@chungewoqu2 chungewoqu2 added the bug Something isn't working label Jan 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

1 participant