We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
class PositionalEncoding(nn.Module):
    """Add learnably-scaled sinusoidal positional encodings, then apply dropout.

    Args:
        d_model: embedding size of the inputs (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence covered by the precomputed table.

    ``forward`` slices the table along dim 0 and the table has shape
    ``(max_len, 1, d_model)``, so inputs are expected sequence-first:
    ``(seq_len, batch, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Learnable scalar weighting the positional signal.
        self.scale = nn.Parameter(torch.ones(1))

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies feed the cosine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) to broadcast over batch.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.scale * self.pe[:x.size(0), :]
        return self.dropout(x)


def _causal_mask(seq_len, device=None, dtype=torch.float32):
    """Return an additive causal attention mask of shape ``(seq_len, seq_len)``.

    Entries strictly above the main diagonal are ``-inf`` (a position may not
    attend to later positions); all other entries are ``0``.
    """
    # Built directly in torch with an explicit dtype: a numpy-derived float64
    # mask can be rejected by PyTorch attention when the inputs are float32.
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=dtype), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))


def _refresh_causal_masks(model, seq_len, like):
    """Ensure ``model.enc_mask``/``dec_mask``/``enc_dec_mask`` are valid causal masks.

    Encoder and decoder sequences have the same length, so the three masks are
    identical and one shared tensor is used. It is rebuilt only when a cached
    mask is missing, or its length, device or dtype no longer matches ``like``.
    """
    cached = (model.enc_mask, model.dec_mask, model.enc_dec_mask)
    stale = any(
        m is None
        or m.size(0) != seq_len
        or m.device != like.device
        or m.dtype != like.dtype
        for m in cached
    )
    if stale:
        mask = model.get_mask(seq_len, device=like.device, dtype=like.dtype)
        model.enc_mask = model.dec_mask = model.enc_dec_mask = mask


class Saint(nn.Module):
    """SAINT-style encoder-decoder transformer giving one probability per position.

    Encoder input: every categorical feature except the last, plus all
    continuous features. Decoder input: every categorical feature (including
    the last), plus all continuous features. Causal masks are applied to
    encoder self-attention, decoder self-attention and encoder-decoder
    attention.

    ``args`` must provide: ``device``, ``hidden_dim``, ``drop_out``,
    ``n_cates``, ``cate_embs`` (one ``nn.Embedding`` size per categorical
    feature), ``n_conts``, ``max_seq_len``, ``n_heads`` and ``n_layers``.
    """

    def __init__(self, args):
        super(Saint, self).__init__()
        self.args = args
        self.device = args.device
        self.hidden_dim = self.args.hidden_dim
        self.dropout = self.args.drop_out

        # The categorical and continuous projections are concatenated back to
        # hidden_dim, so use two halves that always sum to hidden_dim (they are
        # both hidden_dim // 2 when hidden_dim is even).
        cate_half = self.hidden_dim // 2
        cont_half = self.hidden_dim - cate_half

        # ======================== Embedding sizes ========================
        # Adjust the per-feature sizes passed to nn.Embedding here; there must
        # be exactly one entry per categorical feature.
        self.embedding_dims = [self.hidden_dim // 3] * self.args.n_cates
        assert len(self.embedding_dims) == self.args.n_cates
        self.enc_embedding_dims = sum(self.embedding_dims[:-1])
        self.dec_embedding_dims = sum(self.embedding_dims)
        # =================================================================

        # Categorical embeddings shared by encoder and decoder.
        self.cate_emb = nn.ModuleList([
            nn.Embedding(vocab_size, self.embedding_dims[idx])
            for idx, vocab_size in enumerate(self.args.cate_embs)
        ])

        # ENCODER projections (categorical combination, continuous features).
        self.enc_cate_comb_proj = nn.Sequential(
            nn.Linear(self.enc_embedding_dims, cate_half),
            nn.LayerNorm(cate_half),
        )
        self.enc_cont_comb_proj = nn.Sequential(
            nn.Linear(self.args.n_conts, cont_half),
            nn.LayerNorm(cont_half),
        )

        # DECODER projections (categorical combination, continuous features).
        self.dec_cate_comb_proj = nn.Sequential(
            nn.Linear(self.dec_embedding_dims, cate_half),
            nn.LayerNorm(cate_half),
        )
        self.dec_cont_comb_proj = nn.Sequential(
            nn.Linear(self.args.n_conts, cont_half),
            nn.LayerNorm(cont_half),
        )

        # Positional encoding
        self.pos_encoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)
        self.pos_decoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)

        self.transformer = nn.Transformer(
            d_model=self.hidden_dim,
            nhead=self.args.n_heads,
            num_encoder_layers=self.args.n_layers,
            num_decoder_layers=self.args.n_layers,
            dim_feedforward=self.hidden_dim,
            dropout=self.dropout,
            activation='relu')

        self.fc = nn.Linear(self.hidden_dim, 1)
        self.activation = nn.Sigmoid()

        # Cached causal masks, (re)built lazily in forward().
        self.enc_mask = None
        self.dec_mask = None
        self.enc_dec_mask = None

    def get_mask(self, seq_len, device=None, dtype=torch.float32):
        """Return a ``(seq_len, seq_len)`` causal mask (``-inf`` above the diagonal)."""
        return _causal_mask(seq_len, device=device, dtype=dtype)

    def forward(self, input):
        """Return sigmoid probabilities of shape ``(batch, seq_len)``.

        ``input`` is ``(categorical, continuous, mask, _)``. ``categorical``
        and ``continuous`` are sequences of per-feature tensors, each indexed
        as ``(batch, seq_len)``.
        """
        categorical, continuous, mask, __ = input
        # NOTE(review): `mask` is unpacked but never used, so padded positions
        # are not excluded from attention. Confirm whether a key padding mask
        # is intended.
        batch_size = categorical[0].size(0)
        seq_len = categorical[0].size(1)

        # Continuous features stacked on the last axis; shared by both sides.
        cont = torch.cat([c.unsqueeze(-1) for c in continuous], -1).to(torch.float32)

        # ENCODER embedding: all categorical features except the last.
        enc_x_cat = torch.cat(
            [emb(categorical[i]) for i, emb in enumerate(self.cate_emb[:-1])], -1)
        embed_enc = torch.cat(
            [self.enc_cate_comb_proj(enc_x_cat), self.enc_cont_comb_proj(cont)], -1)

        # DECODER embedding: every categorical feature.
        dec_x_cat = torch.cat(
            [emb(categorical[i]) for i, emb in enumerate(self.cate_emb)], -1)
        embed_dec = torch.cat(
            [self.dec_cate_comb_proj(dec_x_cat), self.dec_cont_comb_proj(cont)], -1)

        # Causal masks on the inputs' device/dtype.
        _refresh_causal_masks(self, seq_len, embed_enc)

        # nn.Transformer is built with the default batch_first=False, so it
        # takes (seq_len, batch, hidden_dim).
        embed_enc = self.pos_encoder(embed_enc.permute(1, 0, 2))
        embed_dec = self.pos_decoder(embed_dec.permute(1, 0, 2))

        out = self.transformer(embed_enc, embed_dec,
                               src_mask=self.enc_mask,
                               tgt_mask=self.dec_mask,
                               memory_mask=self.enc_dec_mask)

        # Back to (batch, seq_len, hidden_dim), then one logit per position.
        out = out.permute(1, 0, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        out = self.fc(out)
        return self.activation(out).view(batch_size, -1)


class Saint_custom(nn.Module):
    """SAINT variant whose decoder consumes only the last categorical feature.

    The encoder input ("exercise") is built exactly as in ``Saint``: every
    categorical feature except the last, plus all continuous features. The
    decoder input ("response / interaction") is the embedding of
    ``categorical[-1]`` alone, projected to ``hidden_dim``. Dropout is
    hard-coded to 0 for this variant; ``args`` is not consulted for it.

    ``args`` must provide: ``device``, ``hidden_dim``, ``n_cates``,
    ``cate_embs``, ``n_conts``, ``max_seq_len``, ``n_heads`` and ``n_layers``.
    """

    def __init__(self, args):
        super(Saint_custom, self).__init__()
        self.args = args
        self.device = args.device
        self.hidden_dim = self.args.hidden_dim
        self.dropout = 0.

        # Encoder halves always sum to hidden_dim (equal halves when even).
        cate_half = self.hidden_dim // 2
        cont_half = self.hidden_dim - cate_half

        # ======================== Embedding sizes ========================
        # Adjust the per-feature sizes passed to nn.Embedding here; there must
        # be exactly one entry per categorical feature.
        self.embedding_dims = [self.hidden_dim // 3] * self.args.n_cates
        assert len(self.embedding_dims) == self.args.n_cates
        self.enc_embedding_dims = sum(self.embedding_dims[:-1])
        self.dec_embedding_dims = self.embedding_dims[-1]
        # =================================================================

        # Categorical embeddings shared by encoder and decoder.
        self.cate_emb = nn.ModuleList([
            nn.Embedding(vocab_size, self.embedding_dims[idx])
            for idx, vocab_size in enumerate(self.args.cate_embs)
        ])

        # ENCODER projections (categorical combination, continuous features).
        self.enc_cate_comb_proj = nn.Sequential(
            nn.Linear(self.enc_embedding_dims, cate_half),
            nn.LayerNorm(cate_half),
        )
        self.enc_cont_comb_proj = nn.Sequential(
            nn.Linear(self.args.n_conts, cont_half),
            nn.LayerNorm(cont_half),
        )

        # DECODER projection: last categorical feature -> full hidden_dim.
        self.dec_cate_comb_proj = nn.Sequential(
            nn.Linear(self.dec_embedding_dims, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim),
        )

        # Positional encoding
        self.pos_encoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)
        self.pos_decoder = PositionalEncoding(self.hidden_dim, self.dropout, self.args.max_seq_len)

        self.transformer = nn.Transformer(
            d_model=self.hidden_dim,
            nhead=self.args.n_heads,
            num_encoder_layers=self.args.n_layers,
            num_decoder_layers=self.args.n_layers,
            dim_feedforward=self.hidden_dim,
            dropout=self.dropout,
            activation='relu')

        self.fc = nn.Linear(self.hidden_dim, 1)
        self.activation = nn.Sigmoid()

        # Cached causal masks, (re)built lazily in forward().
        self.enc_mask = None
        self.dec_mask = None
        self.enc_dec_mask = None

    def get_mask(self, seq_len, device=None, dtype=torch.float32):
        """Return a ``(seq_len, seq_len)`` causal mask (``-inf`` above the diagonal)."""
        return _causal_mask(seq_len, device=device, dtype=dtype)

    def forward(self, input):
        """Return sigmoid probabilities of shape ``(batch, seq_len)``.

        ``input`` is ``(categorical, continuous, mask, _)``. ``categorical``
        and ``continuous`` are sequences of per-feature tensors, each indexed
        as ``(batch, seq_len)``.
        """
        categorical, continuous, mask, __ = input
        # NOTE(review): `mask` is unpacked but never used, so padded positions
        # are not excluded from attention. Confirm whether this is intended.
        batch_size = categorical[0].size(0)
        seq_len = categorical[0].size(1)

        # ENCODER embedding (exercise): all categorical features except the last.
        enc_x_cat = torch.cat(
            [emb(categorical[i]) for i, emb in enumerate(self.cate_emb[:-1])], -1)
        enc_x_cont = torch.cat([c.unsqueeze(-1) for c in continuous], -1).to(torch.float32)
        embed_enc = torch.cat(
            [self.enc_cate_comb_proj(enc_x_cat), self.enc_cont_comb_proj(enc_x_cont)], -1)

        # DECODER embedding (response = interaction): last categorical feature only.
        embed_dec = self.dec_cate_comb_proj(self.cate_emb[-1](categorical[-1]))

        # Causal masks on the inputs' device/dtype.
        _refresh_causal_masks(self, seq_len, embed_enc)

        # nn.Transformer is built with the default batch_first=False, so it
        # takes (seq_len, batch, hidden_dim).
        embed_enc = self.pos_encoder(embed_enc.permute(1, 0, 2))
        embed_dec = self.pos_decoder(embed_dec.permute(1, 0, 2))

        out = self.transformer(embed_enc, embed_dec,
                               src_mask=self.enc_mask,
                               tgt_mask=self.dec_mask,
                               memory_mask=self.enc_dec_mask)

        # Back to (batch, seq_len, hidden_dim), then one logit per position.
        out = out.permute(1, 0, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        out = self.fc(out)
        return self.activation(out).view(batch_size, -1)
The text was updated successfully, but these errors were encountered:
No branches or pull requests
The text was updated successfully, but these errors were encountered: