Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/nn transformer block #367

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft

Fix/nn transformer block #367

wants to merge 30 commits into from

Conversation

cxzhang4
Copy link
Collaborator

A sketch of the FT-Transformer graph.

)
mlr3pipelines::mlr_pipeops$add("transformer_layer", PipeOpTorchTransformerLayer)

# TODO: remove deofault values from here
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since defaults should be handled by the PipeOp wrapper class


# TODO: remove layer_idx and ask about how we want to handle this condition
# layer_idx = -1
if (!is_first_layer || !prenormalization || first_prenormalization) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should match the official implementation now, but it is still confusing that !prenormalization is here...

d_embedding = 32

# # TODO: access x[, -1] first
# # TODO: sometimes there is no normalization, i.e. nn_identity instead of nn_layer_norm, figure out how to handle this
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could parameterize like the pre-implemented nn_ft_head, i.e. with an activation parameter

self$last_layer_query_idx = param_vals$last_layer_query_idx
}
),
private = list(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confirm the ordering of the dimensions of the tensors. It may be the case that here, the first dimension is the sequence dimension (NOT the batch dimension)

@sebffischer
Copy link
Member

sebffischer commented Mar 21, 2025

#' @title Custom Function
#' @inherit torch::nnf_linear description
#' @section nn_module:
#' Calls [`torch::nn_linear()`] when trained where the parameter `in_features` is inferred as the second
#' to last dimension of the input tensor.
#' @section Parameters:
#' * `out_features` :: `integer(1)`\cr
#'   The output features of the linear layer.
#' * `bias` :: `logical(1)`\cr
#'   Whether to use a bias.
#'   Default is `TRUE`.
#'
#' @templateVar id nn_linear
#' @template pipeop_torch_channels_default
#' @templateVar param_vals out_features = 10
#' @template pipeop_torch
#' @template pipeop_torch_example
#'
#'
#' @export
PipeOpTorchFn = R6Class("PipeOpTorchFn",
  inherit = PipeOpTorch,
  public = list(
    #' @description Creates a new instance of this [R6][R6::R6Class] class.
    #' @template params_pipelines
    initialize = function(id = "nn_fn", param_vals = list()) {
      param_set = ps(fn = p_uty(...))
      super$initialize(
        id = id,
        param_set = param_set,
        param_vals = param_vals,
        module_generator = nn_linear
      )
    }
  ),
  private = list(
    .shapes_out = function(shapes_in, param_vals, task) {
      # Implement this.
      # 1. Generate a tensor of shape shapes_in (fill NA with something)
      # 2. Apply function private$.f
      # 3. Meausre shapes and fill dimensions with NA again

      # Should also be possible to implement shapes_out properly

      # Also take inspiration from pipeop_preproc_torch
    },
    .make_module = function(shapes_in, param_vals, task) {
      self$param_set$values$fn
    },
    .fn = NULL
  )
)

#' @include aaa.R
register_po("nn_fn", PipeOpTorchFn)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants