Saint, Saint_custom (only the interaction feature is provided to the decoder) #79

Open
MignonDeveloper opened this issue Jun 7, 2021 · 0 comments
Labels
enhancement New feature or request


import math

import numpy as np
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.scale = nn.Parameter(torch.ones(1))

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(
            0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.scale * self.pe[:x.size(0), :]
        return self.dropout(x)
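
# A minimal shape check for PositionalEncoding (illustration only; the sizes
# below are assumptions, not taken from the issue). The module expects input
# in (seq_len, batch, d_model) order, which is why the embeddings are permuted
# before the positional encoders are applied in the models below.
_pe = PositionalEncoding(d_model=64, dropout=0.1, max_len=1000)
_x = torch.zeros(20, 4, 64)            # (seq_len, batch, d_model)
assert _pe(_x).shape == (20, 4, 64)    # shape is preserved
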

class Saint(nn.Module):
    
    def __init__(self, args):
        super(Saint, self).__init__()
        self.args = args
        self.device = args.device

        self.hidden_dim = self.args.hidden_dim
        self.dropout = self.args.drop_out
        
        ### Embedding 
        # ENCODER embedding
        # categorical features
        # ========================== Adjust the number of entries in self.embedding_dims passed to nn.Embedding ==========================
        self.embedding_dims = [self.hidden_dim // 3] * self.args.n_cates  # must match the number of categorical features
        assert len(self.embedding_dims) == self.args.n_cates
        self.enc_embedding_dims = sum(self.embedding_dims[:-1])
        self.dec_embedding_dims = sum(self.embedding_dims)
        # =========================================================================================================================
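        # Example (illustration only, not from the issue): with hidden_dim=64 and
        # n_cates=4, embedding_dims = [21, 21, 21, 21]; the encoder concatenates
        # 3*21 = 63 dims (every feature except the interaction) and the decoder
        # concatenates 4*21 = 84 dims, each then projected to hidden_dim // 2.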
        
        # categorical feature embedding for both encoder & decoder
        self.cate_emb = nn.ModuleList([nn.Embedding(x, self.embedding_dims[idx]) for idx, x in enumerate(self.args.cate_embs)])

        # ENCODER embedding
        # encoder combination projection
        self.enc_cate_comb_proj = nn.Sequential(
            nn.Linear(self.enc_embedding_dims, self.hidden_dim // 2),
            nn.LayerNorm(self.hidden_dim // 2)
        )

        # continuous features
        self.enc_cont_comb_proj = nn.Sequential(
            nn.Linear(self.args.n_conts, self.hidden_dim // 2),
            nn.LayerNorm(self.hidden_dim // 2)
        )


        # DECODER embedding
        # decoder combination projection
        self.dec_cate_comb_proj = nn.Sequential(
            nn.Linear(self.dec_embedding_dims, self.hidden_dim // 2),
            nn.LayerNorm(self.hidden_dim // 2)
        )

        # continuous features
        self.dec_cont_comb_proj = nn.Sequential(
            nn.Linear(self.args.n_conts, self.hidden_dim // 2),
            nn.LayerNorm(self.hidden_dim // 2)
        )


        # Positional encoding
        self.pos_encoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)
        self.pos_decoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)
        

        self.transformer = nn.Transformer(
            d_model=self.hidden_dim, 
            nhead=self.args.n_heads,
            num_encoder_layers=self.args.n_layers, 
            num_decoder_layers=self.args.n_layers, 
            dim_feedforward=self.hidden_dim, 
            dropout=self.dropout, 
            activation='relu')

        self.fc = nn.Linear(self.hidden_dim, 1)
        self.activation = nn.Sigmoid()

        self.enc_mask = None
        self.dec_mask = None
        self.enc_dec_mask = None
    

    def get_mask(self, seq_len):
        # Upper-triangular (causal) additive mask: future positions get -inf,
        # allowed positions contribute 0. Cast to float32 so the mask's dtype
        # matches the attention weights inside nn.Transformer.
        mask = torch.from_numpy(np.triu(np.ones((seq_len, seq_len)), k=1)).float()

        return mask.masked_fill(mask == 1, float('-inf'))
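
    # For example, get_mask(3) produces the additive mask
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # i.e. each position may attend only to itself and to earlier positions.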


    def forward(self, input):

        categorical, continuous, mask, __ = input
        batch_size = categorical[0].size(0)
        seq_len = categorical[0].size(1)

        # Encoder Embedding
        # categorical
        enc_x_cat = [emb_layer(categorical[i]) for i, emb_layer in enumerate(self.cate_emb[:-1])]
        enc_x_cat = torch.cat(enc_x_cat, -1)
        enc_x_cat = self.enc_cate_comb_proj(enc_x_cat)
        
        # continuous
        enc_x_cont = torch.cat([c.unsqueeze(-1) for c in continuous], -1).to(torch.float32)
        enc_x_cont = self.enc_cont_comb_proj(enc_x_cont)

        # concat
        embed_enc = torch.cat([enc_x_cat, enc_x_cont], -1)


        # DECODER Embedding
        # categorical
        dec_x_cat = [emb_layer(categorical[i]) for i, emb_layer in enumerate(self.cate_emb)]
        dec_x_cat = torch.cat(dec_x_cat, -1)
        dec_x_cat = self.dec_cate_comb_proj(dec_x_cat)

        # continuous
        dec_x_cont = torch.cat([c.unsqueeze(-1) for c in continuous], -1).to(torch.float32)
        dec_x_cont = self.dec_cont_comb_proj(dec_x_cont)

        # concat
        embed_dec = torch.cat([dec_x_cat, dec_x_cont], -1)
        
        
        # Build the attention masks.
        # The encoder and decoder masks all share the same height and width,
        # so there is really no need to keep three separate masks.
        if self.enc_mask is None or self.enc_mask.size(0) != seq_len:
            self.enc_mask = self.get_mask(seq_len).to(self.device)
            
        if self.dec_mask is None or self.dec_mask.size(0) != seq_len:
            self.dec_mask = self.get_mask(seq_len).to(self.device)
            
        if self.enc_dec_mask is None or self.enc_dec_mask.size(0) != seq_len:
            self.enc_dec_mask = self.get_mask(seq_len).to(self.device)
            
  
        embed_enc = embed_enc.permute(1, 0, 2)
        embed_dec = embed_dec.permute(1, 0, 2)
        
        # Positional encoding
        embed_enc = self.pos_encoder(embed_enc)
        embed_dec = self.pos_decoder(embed_dec)
        
        out = self.transformer(embed_enc, embed_dec,
                               src_mask=self.enc_mask,
                               tgt_mask=self.dec_mask,
                               memory_mask=self.enc_dec_mask)

        out = out.permute(1, 0, 2)
        out = out.contiguous().view(batch_size, -1, self.hidden_dim)
        out = self.fc(out)

        preds = self.activation(out).view(batch_size, -1)

        return preds
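

# Usage sketch for Saint (not from the issue): the fields and sizes given to
# the args namespace below are assumptions about what the training config
# provides, and the dummy batch mirrors the (categorical, continuous, mask,
# label) tuple that forward() unpacks. The last entry of cate_embs is the
# interaction feature's vocabulary size.
from argparse import Namespace

_args = Namespace(device="cpu", hidden_dim=64, drop_out=0.2, n_heads=2,
                  n_layers=2, max_seq_len=20, n_cates=4, n_conts=3,
                  cate_embs=[10, 10, 10, 3])
_model = Saint(_args)

_bs, _seq = 8, _args.max_seq_len
_categorical = [torch.randint(0, n, (_bs, _seq)) for n in _args.cate_embs]
_continuous = [torch.rand(_bs, _seq) for _ in range(_args.n_conts)]
_pad_mask = torch.ones(_bs, _seq)
_label = torch.randint(0, 2, (_bs, _seq))

_preds = _model((_categorical, _continuous, _pad_mask, _label))
assert _preds.shape == (_bs, _seq)     # one probability per sequence position
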


class Saint_custom(nn.Module):
    
    def __init__(self, args):
        super(Saint_custom, self).__init__()
        self.args = args
        self.device = args.device

        self.hidden_dim = self.args.hidden_dim
        # self.dropout = self.args.drop_out
        self.dropout = 0.  # dropout is disabled in this variant
        
        ### Embedding 
        # ENCODER embedding
        # categorical features
        # ========================== Adjust the number of entries in self.embedding_dims passed to nn.Embedding ==========================
        self.embedding_dims = [self.hidden_dim // 3] * self.args.n_cates  # must match the number of categorical features
        assert len(self.embedding_dims) == self.args.n_cates
        self.enc_embedding_dims = sum(self.embedding_dims[:-1])
        self.dec_embedding_dims = self.embedding_dims[-1]
        # =========================================================================================================================
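        # Example (illustration only, not from the issue): with hidden_dim=64 and
        # n_cates=4, the encoder again concatenates 3*21 = 63 dims, while the
        # decoder uses only the 21-dim interaction embedding, projected directly
        # to hidden_dim instead of hidden_dim // 2.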
        
        # categorical feature embedding for both encoder & decoder
        self.cate_emb = nn.ModuleList([nn.Embedding(x, self.embedding_dims[idx]) for idx, x in enumerate(self.args.cate_embs)])

        # ENCODER embedding
        # encoder combination projection
        self.enc_cate_comb_proj = nn.Sequential(
            nn.Linear(self.enc_embedding_dims, self.hidden_dim // 2),
            nn.LayerNorm(self.hidden_dim // 2)
        )

        # continuous features
        self.enc_cont_comb_proj = nn.Sequential(
            nn.Linear(self.args.n_conts, self.hidden_dim // 2),
            nn.LayerNorm(self.hidden_dim // 2)
        )


        # DECODER embedding
        # decoder combination projection
        self.dec_cate_comb_proj = nn.Sequential(
            nn.Linear(self.dec_embedding_dims, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim)
        )


        # Positional encoding
        self.pos_encoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)
        self.pos_decoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)
        

        self.transformer = nn.Transformer(
            d_model=self.hidden_dim, 
            nhead=self.args.n_heads,
            num_encoder_layers=self.args.n_layers, 
            num_decoder_layers=self.args.n_layers, 
            dim_feedforward=self.hidden_dim, 
            dropout=self.dropout, 
            activation='relu')

        self.fc = nn.Linear(self.hidden_dim, 1)
        self.activation = nn.Sigmoid()

        self.enc_mask = None
        self.dec_mask = None
        self.enc_dec_mask = None
    

    def get_mask(self, seq_len):
        # Upper-triangular (causal) additive mask: future positions get -inf,
        # allowed positions contribute 0. Cast to float32 so the mask's dtype
        # matches the attention weights inside nn.Transformer.
        mask = torch.from_numpy(np.triu(np.ones((seq_len, seq_len)), k=1)).float()

        return mask.masked_fill(mask == 1, float('-inf'))


    def forward(self, input):

        categorical, continuous, mask, __ = input
        batch_size = categorical[0].size(0)
        seq_len = categorical[0].size(1)

        # Encoder Embedding (Exercise)
        # categorical
        enc_x_cat = [emb_layer(categorical[i]) for i, emb_layer in enumerate(self.cate_emb[:-1])]
        enc_x_cat = torch.cat(enc_x_cat, -1)
        enc_x_cat = self.enc_cate_comb_proj(enc_x_cat)
        
        # continuous
        enc_x_cont = torch.cat([c.unsqueeze(-1) for c in continuous], -1).to(torch.float32)
        enc_x_cont = self.enc_cont_comb_proj(enc_x_cont)

        # concat
        embed_enc = torch.cat([enc_x_cat, enc_x_cont], -1)


        # DECODER Embedding (Response = interaction)
        dec_interaction = self.cate_emb[-1](categorical[-1])
        dec_interaction = self.dec_cate_comb_proj(dec_interaction)

        embed_dec = dec_interaction
        
        
        # Build the attention masks.
        # The encoder and decoder masks all share the same height and width,
        # so there is really no need to keep three separate masks.
        if self.enc_mask is None or self.enc_mask.size(0) != seq_len:
            self.enc_mask = self.get_mask(seq_len).to(self.device)
            
        if self.dec_mask is None or self.dec_mask.size(0) != seq_len:
            self.dec_mask = self.get_mask(seq_len).to(self.device)
            
        if self.enc_dec_mask is None or self.enc_dec_mask.size(0) != seq_len:
            self.enc_dec_mask = self.get_mask(seq_len).to(self.device)
            
  
        embed_enc = embed_enc.permute(1, 0, 2)
        embed_dec = embed_dec.permute(1, 0, 2)
        
        # Positional encoding
        embed_enc = self.pos_encoder(embed_enc)
        embed_dec = self.pos_decoder(embed_dec)
        
        out = self.transformer(embed_enc, embed_dec,
                               src_mask=self.enc_mask,
                               tgt_mask=self.dec_mask,
                               memory_mask=self.enc_dec_mask)

        out = out.permute(1, 0, 2)
        out = out.contiguous().view(batch_size, -1, self.hidden_dim)
        out = self.fc(out)

        preds = self.activation(out).view(batch_size, -1)

        return preds
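

# A possible training step for these models (a sketch, not from the issue):
# forward() ignores the padding mask it unpacks, so one common choice is to
# apply that mask when averaging the per-position BCE loss. The names below
# reuse the hypothetical batch from the Saint usage sketch above; Saint_custom
# accepts exactly the same input tuple, only its decoder consumes just the
# interaction embedding.
_criterion = nn.BCELoss(reduction="none")
_optimizer = torch.optim.Adam(_model.parameters(), lr=1e-3)

_preds = _model((_categorical, _continuous, _pad_mask, _label))
_loss = _criterion(_preds, _label.float())             # per-position BCE
_loss = (_loss * _pad_mask).sum() / _pad_mask.sum()    # average over valid positions

_optimizer.zero_grad()
_loss.backward()
_optimizer.step()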