diff --git a/examples/sampling/graphbolt/temporal_link_prediction.py b/examples/sampling/graphbolt/temporal_link_prediction.py index 06d6863727b2..caf7e8aaa0c4 100644 --- a/examples/sampling/graphbolt/temporal_link_prediction.py +++ b/examples/sampling/graphbolt/temporal_link_prediction.py @@ -30,6 +30,7 @@ │ └───> Test set evaluation """ + import argparse import os import time @@ -282,7 +283,7 @@ def main(args): print("Loading data") # TODO: Add the datasets to built-in. dataset_path = download_datasets(args.dataset) - dataset = gb.OnDiskDataset(dataset_path).load() + dataset = gb.OnDiskDataset(dataset_path, force_preprocess=True).load() # Move the dataset to the selected storage. graph = dataset.graph.to(args.storage_device) diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index 2ba51d84a693..206d0090e9ad 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -24,22 +24,18 @@ constexpr bool is_labor(SamplerType S) { return S == SamplerType::LABOR || S == SamplerType::LABOR_DEPENDENT; } -template -struct SamplerArgs; +template struct SamplerArgs; -template <> -struct SamplerArgs {}; +template <> struct SamplerArgs {}; -template <> -struct SamplerArgs { - const torch::Tensor& indices; +template <> struct SamplerArgs { + const torch::Tensor &indices; single_seed random_seed; int64_t num_nodes; }; -template <> -struct SamplerArgs { - const torch::Tensor& indices; +template <> struct SamplerArgs { + const torch::Tensor &indices; continuous_seed random_seed; int64_t num_nodes; }; @@ -60,7 +56,7 @@ struct SamplerArgs { * id of each edge. */ class FusedCSCSamplingGraph : public torch::CustomClassHolder { - public: +public: using NodeTypeToIDMap = torch::Dict; using EdgeTypeToIDMap = torch::Dict; using NodeAttrMap = torch::Dict; @@ -85,13 +81,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * */ FusedCSCSamplingGraph( - const torch::Tensor& indptr, const torch::Tensor& indices, - const torch::optional& node_type_offset = torch::nullopt, - const torch::optional& type_per_edge = torch::nullopt, - const torch::optional& node_type_to_id = torch::nullopt, - const torch::optional& edge_type_to_id = torch::nullopt, - const torch::optional& node_attributes = torch::nullopt, - const torch::optional& edge_attributes = torch::nullopt); + const torch::Tensor &indptr, const torch::Tensor &indices, + const torch::optional &node_type_offset = torch::nullopt, + const torch::optional &type_per_edge = torch::nullopt, + const torch::optional &node_type_to_id = torch::nullopt, + const torch::optional &edge_type_to_id = torch::nullopt, + const torch::optional &node_attributes = torch::nullopt, + const torch::optional &edge_attributes = torch::nullopt); /** * @brief Create a fused CSC graph from tensors of CSC format. @@ -110,14 +106,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * * @return FusedCSCSamplingGraph */ - static c10::intrusive_ptr Create( - const torch::Tensor& indptr, const torch::Tensor& indices, - const torch::optional& node_type_offset, - const torch::optional& type_per_edge, - const torch::optional& node_type_to_id, - const torch::optional& edge_type_to_id, - const torch::optional& node_attributes, - const torch::optional& edge_attributes); + static c10::intrusive_ptr + Create(const torch::Tensor &indptr, const torch::Tensor &indices, + const torch::optional &node_type_offset, + const torch::optional &type_per_edge, + const torch::optional &node_type_to_id, + const torch::optional &edge_type_to_id, + const torch::optional &node_attributes, + const torch::optional &edge_attributes); /** @brief Get the number of nodes. */ int64_t NumNodes() const { return indptr_.size(0) - 1; } @@ -173,15 +169,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * If the input name is empty, return nullopt. Otherwise, return the node * attribute tensor by name. */ - inline torch::optional NodeAttribute( - torch::optional name) const { + inline torch::optional + NodeAttribute(torch::optional name) const { if (!name.has_value()) { return torch::nullopt; } - TORCH_CHECK( - node_attributes_.has_value() && - node_attributes_.value().contains(name.value()), - "Node attribute ", name.value(), " does not exist."); + TORCH_CHECK(node_attributes_.has_value() && + node_attributes_.value().contains(name.value()), + "Node attribute ", name.value(), " does not exist."); return torch::optional( node_attributes_.value().at(name.value())); } @@ -192,34 +187,33 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * If the input name is empty, return nullopt. Otherwise, return the edge * attribute tensor by name. */ - inline torch::optional EdgeAttribute( - torch::optional name) const { + inline torch::optional + EdgeAttribute(torch::optional name) const { if (!name.has_value()) { return torch::nullopt; } - TORCH_CHECK( - edge_attributes_.has_value() && - edge_attributes_.value().contains(name.value()), - "Edge attribute ", name.value(), " does not exist."); + TORCH_CHECK(edge_attributes_.has_value() && + edge_attributes_.value().contains(name.value()), + "Edge attribute ", name.value(), " does not exist."); return torch::optional( edge_attributes_.value().at(name.value())); } /** @brief Set the csc index pointer tensor. */ - inline void SetCSCIndptr(const torch::Tensor& indptr) { indptr_ = indptr; } + inline void SetCSCIndptr(const torch::Tensor &indptr) { indptr_ = indptr; } /** @brief Set the index tensor. */ - inline void SetIndices(const torch::Tensor& indices) { indices_ = indices; } + inline void SetIndices(const torch::Tensor &indices) { indices_ = indices; } /** @brief Set the node type offset tensor for a heterogeneous graph. */ - inline void SetNodeTypeOffset( - const torch::optional& node_type_offset) { + inline void + SetNodeTypeOffset(const torch::optional &node_type_offset) { node_type_offset_ = node_type_offset; } /** @brief Set the edge type tensor for a heterogeneous graph. */ - inline void SetTypePerEdge( - const torch::optional& type_per_edge) { + inline void + SetTypePerEdge(const torch::optional &type_per_edge) { type_per_edge_ = type_per_edge; } @@ -227,8 +221,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * @brief Set the node type to id map for a heterogeneous graph. * @note The map is a dictionary mapping node type names to type IDs. */ - inline void SetNodeTypeToID( - const torch::optional& node_type_to_id) { + inline void + SetNodeTypeToID(const torch::optional &node_type_to_id) { node_type_to_id_ = node_type_to_id; } @@ -236,20 +230,20 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * @brief Set the edge type to id map for a heterogeneous graph. * @note The map is a dictionary mapping edge type names to type IDs. */ - inline void SetEdgeTypeToID( - const torch::optional& edge_type_to_id) { + inline void + SetEdgeTypeToID(const torch::optional &edge_type_to_id) { edge_type_to_id_ = edge_type_to_id; } /** @brief Set the node attributes dictionary. */ - inline void SetNodeAttributes( - const torch::optional& node_attributes) { + inline void + SetNodeAttributes(const torch::optional &node_attributes) { node_attributes_ = node_attributes; } /** @brief Set the edge attributes dictionary. */ - inline void SetEdgeAttributes( - const torch::optional& edge_attributes) { + inline void + SetEdgeAttributes(const torch::optional &edge_attributes) { edge_attributes_ = edge_attributes; } @@ -263,28 +257,28 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * @brief Load graph from stream. * @param archive Input stream for deserializing. */ - void Load(torch::serialize::InputArchive& archive); + void Load(torch::serialize::InputArchive &archive); /** * @brief Save graph to stream. * @param archive Output stream for serializing. */ - void Save(torch::serialize::OutputArchive& archive) const; + void Save(torch::serialize::OutputArchive &archive) const; /** * @brief Pickle method for deserializing. * @param state The state of serialized FusedCSCSamplingGraph. */ void SetState( - const torch::Dict>& - state); + const torch::Dict> + &state); /** * @brief Pickle method for serializing. * @returns The state of this FusedCSCSamplingGraph. */ - torch::Dict> GetState() - const; + torch::Dict> + GetState() const; /** * @brief Return the subgraph induced on the inbound edges of the given nodes. @@ -292,8 +286,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * * @return FusedSampledSubgraph. */ - c10::intrusive_ptr InSubgraph( - const torch::Tensor& nodes) const; + c10::intrusive_ptr + InSubgraph(const torch::Tensor &nodes) const; /** * @brief Sample neighboring edges of the given nodes and return the induced @@ -335,13 +329,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * @return An intrusive pointer to a FusedSampledSubgraph object containing * the sampled graph's information. */ - c10::intrusive_ptr SampleNeighbors( - torch::optional seeds, - torch::optional> seed_offsets, - const std::vector& fanouts, bool replace, bool layer, - bool return_eids, torch::optional probs_or_mask, - torch::optional random_seed, - double seed2_contribution) const; + c10::intrusive_ptr + SampleNeighbors(torch::optional seeds, + torch::optional> seed_offsets, + const std::vector &fanouts, bool replace, bool layer, + bool return_eids, + torch::optional probs_or_mask, + torch::optional random_seed, + double seed2_contribution) const; /** * @brief Sample neighboring edges of the given nodes with a temporal @@ -373,13 +368,15 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * the sampled graph's information. * */ - c10::intrusive_ptr TemporalSampleNeighbors( - const torch::Tensor& input_nodes, - const torch::Tensor& input_nodes_timestamp, - const std::vector& fanouts, bool replace, bool layer, - bool return_eids, torch::optional probs_or_mask, - torch::optional node_timestamp_attr_name, - torch::optional edge_timestamp_attr_name) const; + c10::intrusive_ptr + TemporalSampleNeighbors(const torch::Tensor &input_nodes, + const torch::Tensor &input_nodes_timestamp, + const std::vector &fanouts, bool replace, + bool layer, bool return_eids, + torch::optional probs_or_mask, + torch::optional node_timestamp_attr_name, + torch::optional edge_timestamp_attr_name, + torch::optional time_window) const; /** * @brief Copy the graph to shared memory. @@ -387,8 +384,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * * @return A new FusedCSCSamplingGraph object on shared memory. */ - c10::intrusive_ptr CopyToSharedMemory( - const std::string& shared_memory_name); + c10::intrusive_ptr + CopyToSharedMemory(const std::string &shared_memory_name); /** * @brief Load the graph from shared memory. @@ -396,8 +393,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * * @return A new FusedCSCSamplingGraph object on shared memory. */ - static c10::intrusive_ptr LoadFromSharedMemory( - const std::string& shared_memory_name); + static c10::intrusive_ptr + LoadFromSharedMemory(const std::string &shared_memory_name); /** * @brief Hold the shared memory objects of the the tensor metadata and data. @@ -410,16 +407,16 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * @param tensor_metadata_shm The shared memory objects of tensor metadata. * @param tensor_data_shm The shared memory objects of tensor data. */ - void HoldSharedMemoryObject( - SharedMemoryPtr tensor_metadata_shm, SharedMemoryPtr tensor_data_shm); + void HoldSharedMemoryObject(SharedMemoryPtr tensor_metadata_shm, + SharedMemoryPtr tensor_data_shm); - private: +private: template - c10::intrusive_ptr SampleNeighborsImpl( - const torch::Tensor& seeds, - torch::optional>& seed_offsets, - const std::vector& fanouts, bool return_eids, - NumPickFn num_pick_fn, PickFn pick_fn) const; + c10::intrusive_ptr + SampleNeighborsImpl(const torch::Tensor &seeds, + torch::optional> &seed_offsets, + const std::vector &fanouts, bool return_eids, + NumPickFn num_pick_fn, PickFn pick_fn) const; /** @brief CSC format index pointer array. */ torch::Tensor indptr_; @@ -505,34 +502,38 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * numbers. */ template -void NumPick( - int64_t fanout, bool replace, - const torch::optional& probs_or_mask, int64_t offset, - int64_t num_neighbors, PickedNumType* num_picked_ptr); - -int64_t TemporalNumPick( - torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout, - bool replace, const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, int64_t seed_offset, - int64_t offset, int64_t num_neighbors); +void NumPick(int64_t fanout, bool replace, + const torch::optional &probs_or_mask, + int64_t offset, int64_t num_neighbors, + PickedNumType *num_picked_ptr); + +int64_t TemporalNumPick(torch::Tensor seed_timestamp, torch::Tensor csc_indics, + int64_t fanout, bool replace, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + int64_t seed_offset, int64_t offset, + int64_t num_neighbors); template -void NumPickByEtype( - bool with_seed_offsets, const std::vector& fanouts, bool replace, - const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, int64_t offset, - int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index, - const std::vector& etype_id_to_num_picked_offset); - -int64_t TemporalNumPickByEtype( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, - const std::vector& fanouts, bool replace, - const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, int64_t seed_offset, - int64_t offset, int64_t num_neighbors); +void NumPickByEtype(bool with_seed_offsets, const std::vector &fanouts, + bool replace, const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + int64_t offset, int64_t num_neighbors, + PickedNumType *num_picked_ptr, int64_t seed_index, + const std::vector &etype_id_to_num_picked_offset); + +int64_t +TemporalNumPickByEtype(torch::Tensor seed_timestamp, torch::Tensor csc_indices, + const std::vector &fanouts, bool replace, + const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + int64_t seed_offset, int64_t offset, + int64_t num_neighbors); /** * @brief Picks a specified number of neighbors for a node, starting from the @@ -569,28 +570,29 @@ int64_t TemporalNumPickByEtype( * should be put. Enough memory space should be allocated in advance. */ template -int64_t Pick( - int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, - const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - SamplerArgs args, PickedType* picked_data_ptr); +int64_t Pick(int64_t offset, int64_t num_neighbors, int64_t fanout, + bool replace, const torch::TensorOptions &options, + const torch::optional &probs_or_mask, + SamplerArgs args, + PickedType *picked_data_ptr); template -std::enable_if_t Pick( - int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, - const torch::TensorOptions& options, - const torch::optional& probs_or_mask, SamplerArgs args, - PickedType* picked_data_ptr); +std::enable_if_t +Pick(int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, + const torch::TensorOptions &options, + const torch::optional &probs_or_mask, SamplerArgs args, + PickedType *picked_data_ptr); template -int64_t TemporalPick( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, - int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout, - bool replace, const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, - PickedType* picked_data_ptr); +int64_t TemporalPick(torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t seed_offset, int64_t offset, int64_t num_neighbors, + int64_t fanout, bool replace, + const torch::TensorOptions &options, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + PickedType *picked_data_ptr); /** * @brief Picks a specified number of neighbors for a node per edge type, @@ -626,36 +628,35 @@ int64_t TemporalPick( * etype_id to the offset of its pick numbers in the tensor. */ template -int64_t PickByEtype( - bool with_seed_offsets, int64_t offset, int64_t num_neighbors, - const std::vector& fanouts, bool replace, - const torch::TensorOptions& options, const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, SamplerArgs args, - PickedType* picked_data_ptr, int64_t seed_offset, - PickedType* subgraph_indptr_ptr, - const std::vector& etype_id_to_num_picked_offset); +int64_t PickByEtype(bool with_seed_offsets, int64_t offset, + int64_t num_neighbors, const std::vector &fanouts, + bool replace, const torch::TensorOptions &options, + const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + SamplerArgs args, PickedType *picked_data_ptr, + int64_t seed_offset, PickedType *subgraph_indptr_ptr, + const std::vector &etype_id_to_num_picked_offset); template int64_t TemporalPickByEtype( torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t seed_offset, int64_t offset, int64_t num_neighbors, - const std::vector& fanouts, bool replace, - const torch::TensorOptions& options, const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, - PickedType* picked_data_ptr); - -template < - bool NonUniform, bool Replace, typename ProbsType, SamplerType S, - typename PickedType, int StackSize = 1024> -std::enable_if_t LaborPick( - int64_t offset, int64_t num_neighbors, int64_t fanout, - const torch::TensorOptions& options, - const torch::optional& probs_or_mask, SamplerArgs args, - PickedType* picked_data_ptr); - -} // namespace sampling -} // namespace graphbolt - -#endif // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ + const std::vector &fanouts, bool replace, + const torch::TensorOptions &options, const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, PickedType *picked_data_ptr); + +template +std::enable_if_t +LaborPick(int64_t offset, int64_t num_neighbors, int64_t fanout, + const torch::TensorOptions &options, + const torch::optional &probs_or_mask, + SamplerArgs args, PickedType *picked_data_ptr); + +} // namespace sampling +} // namespace graphbolt + +#endif // GRAPHBOLT_CSC_SAMPLING_GRAPH_H_ diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index a36404632103..14833c9f2918 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -25,30 +25,30 @@ #include "./utils.h" namespace { -torch::optional> TensorizeDict( - const torch::optional>& dict) { +torch::optional> +TensorizeDict(const torch::optional> &dict) { if (!dict.has_value()) { return torch::nullopt; } torch::Dict result; - for (const auto& pair : dict.value()) { + for (const auto &pair : dict.value()) { result.insert(pair.key(), torch::tensor(pair.value(), torch::kInt64)); } return result; } torch::optional> DetensorizeDict( - const torch::optional>& dict) { + const torch::optional> &dict) { if (!dict.has_value()) { return torch::nullopt; } torch::Dict result; - for (const auto& pair : dict.value()) { + for (const auto &pair : dict.value()) { result.insert(pair.key(), pair.value().item()); } return result; } -} // namespace +} // namespace namespace graphbolt { namespace sampling { @@ -56,20 +56,16 @@ namespace sampling { static const int kPickleVersion = 6199; FusedCSCSamplingGraph::FusedCSCSamplingGraph( - const torch::Tensor& indptr, const torch::Tensor& indices, - const torch::optional& node_type_offset, - const torch::optional& type_per_edge, - const torch::optional& node_type_to_id, - const torch::optional& edge_type_to_id, - const torch::optional& node_attributes, - const torch::optional& edge_attributes) - : indptr_(indptr), - indices_(indices), - node_type_offset_(node_type_offset), - type_per_edge_(type_per_edge), - node_type_to_id_(node_type_to_id), - edge_type_to_id_(edge_type_to_id), - node_attributes_(node_attributes), + const torch::Tensor &indptr, const torch::Tensor &indices, + const torch::optional &node_type_offset, + const torch::optional &type_per_edge, + const torch::optional &node_type_to_id, + const torch::optional &edge_type_to_id, + const torch::optional &node_attributes, + const torch::optional &edge_attributes) + : indptr_(indptr), indices_(indices), node_type_offset_(node_type_offset), + type_per_edge_(type_per_edge), node_type_to_id_(node_type_to_id), + edge_type_to_id_(edge_type_to_id), node_attributes_(node_attributes), edge_attributes_(edge_attributes) { TORCH_CHECK(indptr.dim() == 1); TORCH_CHECK(indices.dim() == 1); @@ -77,20 +73,19 @@ FusedCSCSamplingGraph::FusedCSCSamplingGraph( } c10::intrusive_ptr FusedCSCSamplingGraph::Create( - const torch::Tensor& indptr, const torch::Tensor& indices, - const torch::optional& node_type_offset, - const torch::optional& type_per_edge, - const torch::optional& node_type_to_id, - const torch::optional& edge_type_to_id, - const torch::optional& node_attributes, - const torch::optional& edge_attributes) { + const torch::Tensor &indptr, const torch::Tensor &indices, + const torch::optional &node_type_offset, + const torch::optional &type_per_edge, + const torch::optional &node_type_to_id, + const torch::optional &edge_type_to_id, + const torch::optional &node_attributes, + const torch::optional &edge_attributes) { if (node_type_offset.has_value()) { - auto& offset = node_type_offset.value(); + auto &offset = node_type_offset.value(); TORCH_CHECK(offset.dim() == 1); TORCH_CHECK(node_type_to_id.has_value()); - TORCH_CHECK( - offset.size(0) == - static_cast(node_type_to_id.value().size() + 1)); + TORCH_CHECK(offset.size(0) == + static_cast(node_type_to_id.value().size() + 1)); } if (type_per_edge.has_value()) { TORCH_CHECK(type_per_edge.value().dim() == 1); @@ -98,22 +93,21 @@ c10::intrusive_ptr FusedCSCSamplingGraph::Create( TORCH_CHECK(edge_type_to_id.has_value()); } if (node_attributes.has_value()) { - for (const auto& pair : node_attributes.value()) { - TORCH_CHECK( - pair.value().size(0) == indptr.size(0) - 1, - "Expected node_attribute.size(0) and num_nodes to be equal, " - "but node_attribute.size(0) was ", - pair.value().size(0), ", and num_nodes was ", indptr.size(0) - 1, - "."); + for (const auto &pair : node_attributes.value()) { + TORCH_CHECK(pair.value().size(0) == indptr.size(0) - 1, + "Expected node_attribute.size(0) and num_nodes to be equal, " + "but node_attribute.size(0) was ", + pair.value().size(0), ", and num_nodes was ", + indptr.size(0) - 1, "."); } } if (edge_attributes.has_value()) { - for (const auto& pair : edge_attributes.value()) { - TORCH_CHECK( - pair.value().size(0) == indices.size(0), - "Expected edge_attribute.size(0) and num_edges to be equal, " - "but edge_attribute.size(0) was ", - pair.value().size(0), ", and num_edges was ", indices.size(0), "."); + for (const auto &pair : edge_attributes.value()) { + TORCH_CHECK(pair.value().size(0) == indices.size(0), + "Expected edge_attribute.size(0) and num_edges to be equal, " + "but edge_attribute.size(0) was ", + pair.value().size(0), ", and num_edges was ", indices.size(0), + "."); } } return c10::make_intrusive( @@ -121,107 +115,101 @@ c10::intrusive_ptr FusedCSCSamplingGraph::Create( edge_type_to_id, node_attributes, edge_attributes); } -void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive& archive) { +void FusedCSCSamplingGraph::Load(torch::serialize::InputArchive &archive) { const int64_t magic_num = read_from_archive(archive, "FusedCSCSamplingGraph/magic_num"); - TORCH_CHECK( - magic_num == kCSCSamplingGraphSerializeMagic, - "Magic numbers mismatch when loading FusedCSCSamplingGraph."); + TORCH_CHECK(magic_num == kCSCSamplingGraphSerializeMagic, + "Magic numbers mismatch when loading FusedCSCSamplingGraph."); indptr_ = read_from_archive(archive, "FusedCSCSamplingGraph/indptr"); - indices_ = read_from_archive( - archive, "FusedCSCSamplingGraph/indices"); - if (read_from_archive( - archive, "FusedCSCSamplingGraph/has_node_type_offset")) { + indices_ = read_from_archive(archive, + "FusedCSCSamplingGraph/indices"); + if (read_from_archive(archive, + "FusedCSCSamplingGraph/has_node_type_offset")) { node_type_offset_ = read_from_archive( archive, "FusedCSCSamplingGraph/node_type_offset"); } - if (read_from_archive( - archive, "FusedCSCSamplingGraph/has_type_per_edge")) { + if (read_from_archive(archive, + "FusedCSCSamplingGraph/has_type_per_edge")) { type_per_edge_ = read_from_archive( archive, "FusedCSCSamplingGraph/type_per_edge"); } - if (read_from_archive( - archive, "FusedCSCSamplingGraph/has_node_type_to_id")) { + if (read_from_archive(archive, + "FusedCSCSamplingGraph/has_node_type_to_id")) { node_type_to_id_ = read_from_archive( archive, "FusedCSCSamplingGraph/node_type_to_id"); } - if (read_from_archive( - archive, "FusedCSCSamplingGraph/has_edge_type_to_id")) { + if (read_from_archive(archive, + "FusedCSCSamplingGraph/has_edge_type_to_id")) { edge_type_to_id_ = read_from_archive( archive, "FusedCSCSamplingGraph/edge_type_to_id"); } - if (read_from_archive( - archive, "FusedCSCSamplingGraph/has_node_attributes")) { + if (read_from_archive(archive, + "FusedCSCSamplingGraph/has_node_attributes")) { node_attributes_ = read_from_archive( archive, "FusedCSCSamplingGraph/node_attributes"); } - if (read_from_archive( - archive, "FusedCSCSamplingGraph/has_edge_attributes")) { + if (read_from_archive(archive, + "FusedCSCSamplingGraph/has_edge_attributes")) { edge_attributes_ = read_from_archive( archive, "FusedCSCSamplingGraph/edge_attributes"); } } void FusedCSCSamplingGraph::Save( - torch::serialize::OutputArchive& archive) const { - archive.write( - "FusedCSCSamplingGraph/magic_num", kCSCSamplingGraphSerializeMagic); + torch::serialize::OutputArchive &archive) const { + archive.write("FusedCSCSamplingGraph/magic_num", + kCSCSamplingGraphSerializeMagic); archive.write("FusedCSCSamplingGraph/indptr", indptr_); archive.write("FusedCSCSamplingGraph/indices", indices_); - archive.write( - "FusedCSCSamplingGraph/has_node_type_offset", - node_type_offset_.has_value()); + archive.write("FusedCSCSamplingGraph/has_node_type_offset", + node_type_offset_.has_value()); if (node_type_offset_) { - archive.write( - "FusedCSCSamplingGraph/node_type_offset", node_type_offset_.value()); + archive.write("FusedCSCSamplingGraph/node_type_offset", + node_type_offset_.value()); } - archive.write( - "FusedCSCSamplingGraph/has_type_per_edge", type_per_edge_.has_value()); + archive.write("FusedCSCSamplingGraph/has_type_per_edge", + type_per_edge_.has_value()); if (type_per_edge_) { - archive.write( - "FusedCSCSamplingGraph/type_per_edge", type_per_edge_.value()); + archive.write("FusedCSCSamplingGraph/type_per_edge", + type_per_edge_.value()); } - archive.write( - "FusedCSCSamplingGraph/has_node_type_to_id", - node_type_to_id_.has_value()); + archive.write("FusedCSCSamplingGraph/has_node_type_to_id", + node_type_to_id_.has_value()); if (node_type_to_id_) { - archive.write( - "FusedCSCSamplingGraph/node_type_to_id", node_type_to_id_.value()); + archive.write("FusedCSCSamplingGraph/node_type_to_id", + node_type_to_id_.value()); } - archive.write( - "FusedCSCSamplingGraph/has_edge_type_to_id", - edge_type_to_id_.has_value()); + archive.write("FusedCSCSamplingGraph/has_edge_type_to_id", + edge_type_to_id_.has_value()); if (edge_type_to_id_) { - archive.write( - "FusedCSCSamplingGraph/edge_type_to_id", edge_type_to_id_.value()); + archive.write("FusedCSCSamplingGraph/edge_type_to_id", + edge_type_to_id_.value()); } - archive.write( - "FusedCSCSamplingGraph/has_node_attributes", - node_attributes_.has_value()); + archive.write("FusedCSCSamplingGraph/has_node_attributes", + node_attributes_.has_value()); if (node_attributes_) { - archive.write( - "FusedCSCSamplingGraph/node_attributes", node_attributes_.value()); + archive.write("FusedCSCSamplingGraph/node_attributes", + node_attributes_.value()); } - archive.write( - "FusedCSCSamplingGraph/has_edge_attributes", - edge_attributes_.has_value()); + archive.write("FusedCSCSamplingGraph/has_edge_attributes", + edge_attributes_.has_value()); if (edge_attributes_) { - archive.write( - "FusedCSCSamplingGraph/edge_attributes", edge_attributes_.value()); + archive.write("FusedCSCSamplingGraph/edge_attributes", + edge_attributes_.value()); } } void FusedCSCSamplingGraph::SetState( - const torch::Dict>& - state) { + const torch::Dict> + &state) { // State is a dict of dicts. The tensor-type attributes are stored in the dict // with key "independent_tensors". The dict-type attributes (edge_attributes) // are stored directly with the their name as the key. - const auto& independent_tensors = state.at("independent_tensors"); + const auto &independent_tensors = state.at("independent_tensors"); TORCH_CHECK( independent_tensors.at("version_number") .equal(torch::tensor({kPickleVersion})), @@ -283,8 +271,8 @@ FusedCSCSamplingGraph::GetState() const { return state; } -c10::intrusive_ptr FusedCSCSamplingGraph::InSubgraph( - const torch::Tensor& nodes) const { +c10::intrusive_ptr +FusedCSCSamplingGraph::InSubgraph(const torch::Tensor &nodes) const { if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr_) && utils::is_accessible_from_gpu(indices_) && (!type_per_edge_.has_value() || @@ -353,51 +341,49 @@ c10::intrusive_ptr FusedCSCSamplingGraph::InSubgraph( * num_neighbors (number of neighbors) as params and returns the pick number of * the given node. */ -auto GetNumPickFn( - const std::vector& fanouts, bool replace, - const torch::optional& type_per_edge, - const torch::optional& probs_or_mask, - bool with_seed_offsets) { +auto GetNumPickFn(const std::vector &fanouts, bool replace, + const torch::optional &type_per_edge, + const torch::optional &probs_or_mask, + bool with_seed_offsets) { // If fanouts.size() > 1, returns the total number of all edge types of the // given node. return [&fanouts, replace, &probs_or_mask, &type_per_edge, with_seed_offsets]( int64_t offset, int64_t num_neighbors, auto num_picked_ptr, int64_t seed_index, - const std::vector& etype_id_to_num_picked_offset) { + const std::vector &etype_id_to_num_picked_offset) { if (fanouts.size() > 1) { - NumPickByEtype( - with_seed_offsets, fanouts, replace, type_per_edge.value(), - probs_or_mask, offset, num_neighbors, num_picked_ptr, seed_index, - etype_id_to_num_picked_offset); + NumPickByEtype(with_seed_offsets, fanouts, replace, type_per_edge.value(), + probs_or_mask, offset, num_neighbors, num_picked_ptr, + seed_index, etype_id_to_num_picked_offset); } else { - NumPick( - fanouts[0], replace, probs_or_mask, offset, num_neighbors, - num_picked_ptr + seed_index); + NumPick(fanouts[0], replace, probs_or_mask, offset, num_neighbors, + num_picked_ptr + seed_index); } }; } -auto GetTemporalNumPickFn( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, - const std::vector& fanouts, bool replace, - const torch::optional& type_per_edge, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp) { +auto GetTemporalNumPickFn(torch::Tensor seed_timestamp, + torch::Tensor csc_indices, + const std::vector &fanouts, bool replace, + const torch::optional &type_per_edge, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window) { // If fanouts.size() > 1, returns the total number of all edge types of the // given node. return [&seed_timestamp, &csc_indices, &fanouts, replace, &probs_or_mask, - &type_per_edge, &node_timestamp, &edge_timestamp]( + &type_per_edge, &node_timestamp, &edge_timestamp, &time_window]( int64_t seed_offset, int64_t offset, int64_t num_neighbors) { if (fanouts.size() > 1) { return TemporalNumPickByEtype( seed_timestamp, csc_indices, fanouts, replace, type_per_edge.value(), - probs_or_mask, node_timestamp, edge_timestamp, seed_offset, offset, - num_neighbors); + probs_or_mask, node_timestamp, edge_timestamp, time_window, + seed_offset, offset, num_neighbors); } else { - return TemporalNumPick( - seed_timestamp, csc_indices, fanouts[0], replace, probs_or_mask, - node_timestamp, edge_timestamp, seed_offset, offset, num_neighbors); + return TemporalNumPick(seed_timestamp, csc_indices, fanouts[0], replace, + probs_or_mask, node_timestamp, edge_timestamp, + time_window, seed_offset, offset, num_neighbors); } }; } @@ -426,25 +412,24 @@ auto GetTemporalNumPickFn( * the picked neighbors at the address specified by picked_data_ptr. */ template -auto GetPickFn( - const std::vector& fanouts, bool replace, - const torch::TensorOptions& options, - const torch::optional& type_per_edge, - const torch::optional& probs_or_mask, bool with_seed_offsets, - SamplerArgs args) { +auto GetPickFn(const std::vector &fanouts, bool replace, + const torch::TensorOptions &options, + const torch::optional &type_per_edge, + const torch::optional &probs_or_mask, + bool with_seed_offsets, SamplerArgs args) { return [&fanouts, replace, &options, &type_per_edge, &probs_or_mask, args, with_seed_offsets]( int64_t offset, int64_t num_neighbors, auto picked_data_ptr, int64_t seed_offset, auto subgraph_indptr_ptr, - const std::vector& etype_id_to_num_picked_offset) { + const std::vector &etype_id_to_num_picked_offset) { // If fanouts.size() > 1, perform sampling for each edge type of each // node; otherwise just sample once for each node with no regard of edge // types. if (fanouts.size() > 1) { - return PickByEtype( - with_seed_offsets, offset, num_neighbors, fanouts, replace, options, - type_per_edge.value(), probs_or_mask, args, picked_data_ptr, - seed_offset, subgraph_indptr_ptr, etype_id_to_num_picked_offset); + return PickByEtype(with_seed_offsets, offset, num_neighbors, fanouts, + replace, options, type_per_edge.value(), probs_or_mask, + args, picked_data_ptr, seed_offset, + subgraph_indptr_ptr, etype_id_to_num_picked_offset); } else { int64_t num_sampled = Pick( offset, num_neighbors, fanouts[0], replace, options, probs_or_mask, @@ -458,46 +443,46 @@ auto GetPickFn( } template -auto GetTemporalPickFn( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, - const std::vector& fanouts, bool replace, - const torch::TensorOptions& options, - const torch::optional& type_per_edge, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, SamplerArgs args) { - return - [&seed_timestamp, &csc_indices, &fanouts, replace, &options, - &type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp, args]( - int64_t seed_offset, int64_t offset, int64_t num_neighbors, - auto picked_data_ptr) { - // If fanouts.size() > 1, perform sampling for each edge type of each - // node; otherwise just sample once for each node with no regard of edge - // types. - if (fanouts.size() > 1) { - return TemporalPickByEtype( - seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, - fanouts, replace, options, type_per_edge.value(), probs_or_mask, - node_timestamp, edge_timestamp, args, picked_data_ptr); - } else { - int64_t num_sampled = TemporalPick( - seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, - fanouts[0], replace, options, probs_or_mask, node_timestamp, - edge_timestamp, args, picked_data_ptr); - if (type_per_edge.has_value()) { - std::sort(picked_data_ptr, picked_data_ptr + num_sampled); - } - return num_sampled; - } - }; +auto GetTemporalPickFn(torch::Tensor seed_timestamp, torch::Tensor csc_indices, + const std::vector &fanouts, bool replace, + const torch::TensorOptions &options, + const torch::optional &type_per_edge, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + torch::optional time_window, + SamplerArgs args) { + return [&seed_timestamp, &csc_indices, &fanouts, replace, &options, + &type_per_edge, &probs_or_mask, &node_timestamp, &edge_timestamp, + &time_window, args](int64_t seed_offset, int64_t offset, + int64_t num_neighbors, auto picked_data_ptr) { + // If fanouts.size() > 1, perform sampling for each edge type of each + // node; otherwise just sample once for each node with no regard of edge + // types. + if (fanouts.size() > 1) { + return TemporalPickByEtype( + seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, + fanouts, replace, options, type_per_edge.value(), probs_or_mask, + node_timestamp, edge_timestamp, time_window, args, picked_data_ptr); + } else { + int64_t num_sampled = TemporalPick( + seed_timestamp, csc_indices, seed_offset, offset, num_neighbors, + fanouts[0], replace, options, probs_or_mask, node_timestamp, + edge_timestamp, time_window, args, picked_data_ptr); + if (type_per_edge.has_value()) { + std::sort(picked_data_ptr, picked_data_ptr + num_sampled); + } + return num_sampled; + } + }; } template c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighborsImpl( - const torch::Tensor& seeds, - torch::optional>& seed_offsets, - const std::vector& fanouts, bool return_eids, + const torch::Tensor &seeds, + torch::optional> &seed_offsets, + const std::vector &fanouts, bool return_eids, NumPickFn num_pick_fn, PickFn pick_fn) const { const int64_t num_seeds = seeds.size(0); const auto indptr_options = indptr_.options(); @@ -533,7 +518,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( // number tensor. std::vector etype_id_to_num_picked_offset(num_etypes + 1); if (hetero_with_seed_offsets) { - for (auto& etype_and_id : edge_type_to_id_.value()) { + for (auto &etype_and_id : edge_type_to_id_.value()) { auto etype = etype_and_id.key(); auto id = etype_and_id.value(); auto [src_type, dst_type] = utils::parse_src_dst_ntype_from_etype(etype); @@ -544,10 +529,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( seed_offsets->at(dst_ntype_id + 1) - seed_offsets->at(dst_ntype_id) + 1; } - std::partial_sum( - etype_id_to_num_picked_offset.begin(), - etype_id_to_num_picked_offset.end(), - etype_id_to_num_picked_offset.begin()); + std::partial_sum(etype_id_to_num_picked_offset.begin(), + etype_id_to_num_picked_offset.end(), + etype_id_to_num_picked_offset.begin()); } else { etype_id_to_dst_ntype_id[0] = 0; etype_id_to_num_picked_offset[1] = num_seeds + 1; @@ -593,9 +577,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( } else { const auto seed_type_id = (hetero_with_seed_offsets) - ? std::upper_bound( - seed_offsets->begin(), - seed_offsets->end(), i) - + ? std::upper_bound(seed_offsets->begin(), + seed_offsets->end(), i) - seed_offsets->begin() - 1 : 0; // `seed_index` indicates the index of the current @@ -605,10 +588,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( (hetero_with_seed_offsets) ? i - seed_offsets->at(seed_type_id) : i; - num_pick_fn( - offset, num_neighbors, - num_picked_neighbors_data_ptr + 1, seed_index, - etype_id_to_num_picked_offset); + num_pick_fn(offset, num_neighbors, + num_picked_neighbors_data_ptr + 1, + seed_index, etype_id_to_num_picked_offset); } } }); @@ -625,10 +607,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( torch::empty({num_etypes + 1}, indptr_options); const auto num_picked_offset_data_ptr = num_picked_offset_tensor.data_ptr(); - std::copy( - etype_id_to_num_picked_offset.begin(), - etype_id_to_num_picked_offset.end(), - num_picked_offset_data_ptr); + std::copy(etype_id_to_num_picked_offset.begin(), + etype_id_to_num_picked_offset.end(), + num_picked_offset_data_ptr); torch::Tensor substract_offset = torch::empty({num_etypes}, indptr_options); const auto substract_offset_data_ptr = @@ -638,9 +619,9 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( substract_offset_data_ptr[i] = subgraph_indptr_data_ptr [etype_id_to_num_picked_offset[i]]; } - subgraph_indptr_substract = ops::ExpandIndptr( - num_picked_offset_tensor, indptr_.scalar_type(), - substract_offset); + subgraph_indptr_substract = + ops::ExpandIndptr(num_picked_offset_tensor, + indptr_.scalar_type(), substract_offset); } // When doing non-temporal hetero sampling, we generate an @@ -681,9 +662,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( auto picked_number = 0; const auto seed_type_id = (hetero_with_seed_offsets) - ? std::upper_bound( - seed_offsets->begin(), seed_offsets->end(), - i) - + ? std::upper_bound(seed_offsets->begin(), + seed_offsets->end(), i) - seed_offsets->begin() - 1 : 0; const auto seed_index = @@ -696,19 +676,19 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( picked_number = num_picked_neighbors_data_ptr[i + 1]; auto picked_offset = subgraph_indptr_data_ptr[i]; if (picked_number > 0) { - auto actual_picked_count = pick_fn( - i, offset, num_neighbors, - picked_eids_data_ptr + picked_offset); + auto actual_picked_count = + pick_fn(i, offset, num_neighbors, + picked_eids_data_ptr + picked_offset); TORCH_CHECK( actual_picked_count == picked_number, "Actual picked count doesn't match the calculated" " pick number."); } } else { - picked_number = pick_fn( - offset, num_neighbors, picked_eids_data_ptr, - seed_index, subgraph_indptr_data_ptr, - etype_id_to_num_picked_offset); + picked_number = + pick_fn(offset, num_neighbors, picked_eids_data_ptr, + seed_index, subgraph_indptr_data_ptr, + etype_id_to_num_picked_offset); if (!hetero_with_seed_offsets) { TORCH_CHECK( num_picked_neighbors_data_ptr[i + 1] == @@ -790,7 +770,8 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( })); torch::optional subgraph_reverse_edge_ids = torch::nullopt; - if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids); + if (return_eids) + subgraph_reverse_edge_ids = std::move(picked_eids); if (subgraph_indptr_substract.has_value()) { subgraph_indptr -= subgraph_indptr_substract.value(); @@ -804,7 +785,7 @@ FusedCSCSamplingGraph::SampleNeighborsImpl( c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( torch::optional seeds, torch::optional> seed_offsets, - const std::vector& fanouts, bool replace, bool layer, + const std::vector &fanouts, bool replace, bool layer, bool return_eids, torch::optional probs_or_mask, torch::optional random_seed, double seed2_contribution) const { @@ -856,17 +837,15 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( NumNodes()}; return SampleNeighborsImpl( seeds.value(), seed_offsets, fanouts, return_eids, - GetNumPickFn( - fanouts, replace, type_per_edge_, probs_or_mask, - with_seed_offsets), - GetPickFn( - fanouts, replace, indptr_.options(), type_per_edge_, - probs_or_mask, with_seed_offsets, args)); + GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask, + with_seed_offsets), + GetPickFn(fanouts, replace, indptr_.options(), type_per_edge_, + probs_or_mask, with_seed_offsets, args)); } else { auto args = [&] { if (random_seed.has_value() && random_seed->numel() == 1) { - return SamplerArgs{ - indices_, random_seed.value(), NumNodes()}; + return SamplerArgs{indices_, random_seed.value(), + NumNodes()}; } else { return SamplerArgs{ indices_, @@ -877,33 +856,31 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( }(); return SampleNeighborsImpl( seeds.value(), seed_offsets, fanouts, return_eids, - GetNumPickFn( - fanouts, replace, type_per_edge_, probs_or_mask, - with_seed_offsets), - GetPickFn( - fanouts, replace, indptr_.options(), type_per_edge_, - probs_or_mask, with_seed_offsets, args)); + GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask, + with_seed_offsets), + GetPickFn(fanouts, replace, indptr_.options(), type_per_edge_, + probs_or_mask, with_seed_offsets, args)); } } else { SamplerArgs args; return SampleNeighborsImpl( seeds.value(), seed_offsets, fanouts, return_eids, - GetNumPickFn( - fanouts, replace, type_per_edge_, probs_or_mask, with_seed_offsets), - GetPickFn( - fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask, - with_seed_offsets, args)); + GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask, + with_seed_offsets), + GetPickFn(fanouts, replace, indptr_.options(), type_per_edge_, + probs_or_mask, with_seed_offsets, args)); } } c10::intrusive_ptr FusedCSCSamplingGraph::TemporalSampleNeighbors( - const torch::Tensor& input_nodes, - const torch::Tensor& input_nodes_timestamp, - const std::vector& fanouts, bool replace, bool layer, + const torch::Tensor &input_nodes, + const torch::Tensor &input_nodes_timestamp, + const std::vector &fanouts, bool replace, bool layer, bool return_eids, torch::optional probs_or_mask, torch::optional node_timestamp_attr_name, - torch::optional edge_timestamp_attr_name) const { + torch::optional edge_timestamp_attr_name, + torch::optional time_window) const { torch::optional> seed_offsets = torch::nullopt; // 1. Get probs_or_mask. if (probs_or_mask.has_value()) { @@ -919,6 +896,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( auto node_timestamp = this->NodeAttribute(node_timestamp_attr_name); // 3. Get the timestamp attribute for edges of the graph auto edge_timestamp = this->EdgeAttribute(edge_timestamp_attr_name); + // 4. Call SampleNeighborsImpl if (layer) { const int64_t random_seed = RandomEngine::ThreadLocal()->RandInt( @@ -926,29 +904,29 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors( SamplerArgs args{indices_, random_seed, NumNodes()}; return SampleNeighborsImpl( input_nodes, seed_offsets, fanouts, return_eids, - GetTemporalNumPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp), - GetTemporalPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp, - edge_timestamp, args)); + GetTemporalNumPickFn(input_nodes_timestamp, this->indices_, fanouts, + replace, type_per_edge_, probs_or_mask, + node_timestamp, edge_timestamp, time_window), + GetTemporalPickFn(input_nodes_timestamp, this->indices_, fanouts, + replace, indptr_.options(), type_per_edge_, + probs_or_mask, node_timestamp, edge_timestamp, + time_window, args)); } else { SamplerArgs args; return SampleNeighborsImpl( input_nodes, seed_offsets, fanouts, return_eids, - GetTemporalNumPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - type_per_edge_, probs_or_mask, node_timestamp, edge_timestamp), - GetTemporalPickFn( - input_nodes_timestamp, this->indices_, fanouts, replace, - indptr_.options(), type_per_edge_, probs_or_mask, node_timestamp, - edge_timestamp, args)); + GetTemporalNumPickFn(input_nodes_timestamp, this->indices_, fanouts, + replace, type_per_edge_, probs_or_mask, + node_timestamp, edge_timestamp, time_window), + GetTemporalPickFn(input_nodes_timestamp, this->indices_, fanouts, + replace, indptr_.options(), type_per_edge_, + probs_or_mask, node_timestamp, edge_timestamp, + time_window, args)); } } static c10::intrusive_ptr -BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) { +BuildGraphFromSharedMemoryHelper(SharedMemoryHelper &&helper) { helper.InitializeRead(); auto indptr = helper.ReadTorchTensor(); auto indices = helper.ReadTorchTensor(); @@ -962,14 +940,14 @@ BuildGraphFromSharedMemoryHelper(SharedMemoryHelper&& helper) { indptr.value(), indices.value(), node_type_offset, type_per_edge, node_type_to_id, edge_type_to_id, node_attributes, edge_attributes); auto shared_memory = helper.ReleaseSharedMemory(); - graph->HoldSharedMemoryObject( - std::move(shared_memory.first), std::move(shared_memory.second)); + graph->HoldSharedMemoryObject(std::move(shared_memory.first), + std::move(shared_memory.second)); return graph; } c10::intrusive_ptr FusedCSCSamplingGraph::CopyToSharedMemory( - const std::string& shared_memory_name) { + const std::string &shared_memory_name) { SharedMemoryHelper helper(shared_memory_name); helper.WriteTorchTensor(indptr_); helper.WriteTorchTensor(indices_); @@ -985,7 +963,7 @@ FusedCSCSamplingGraph::CopyToSharedMemory( c10::intrusive_ptr FusedCSCSamplingGraph::LoadFromSharedMemory( - const std::string& shared_memory_name) { + const std::string &shared_memory_name) { SharedMemoryHelper helper(shared_memory_name); return BuildGraphFromSharedMemoryHelper(std::move(helper)); } @@ -997,19 +975,19 @@ void FusedCSCSamplingGraph::HoldSharedMemoryObject( } template -void NumPick( - int64_t fanout, bool replace, - const torch::optional& probs_or_mask, int64_t offset, - int64_t num_neighbors, PickedNumType* picked_num_ptr) { +void NumPick(int64_t fanout, bool replace, + const torch::optional &probs_or_mask, + int64_t offset, int64_t num_neighbors, + PickedNumType *picked_num_ptr) { int64_t num_valid_neighbors = num_neighbors; if (probs_or_mask.has_value() && num_neighbors > 0) { // Subtract the count of zeros in probs_or_mask. AT_DISPATCH_ALL_TYPES( probs_or_mask.value().scalar_type(), "CountZero", ([&] { - scalar_t* probs_data_ptr = probs_or_mask.value().data_ptr(); - num_valid_neighbors -= std::count( - probs_data_ptr + offset, probs_data_ptr + offset + num_neighbors, - 0); + scalar_t *probs_data_ptr = probs_or_mask.value().data_ptr(); + num_valid_neighbors -= + std::count(probs_data_ptr + offset, + probs_data_ptr + offset + num_neighbors, 0); })); } if (num_valid_neighbors == 0 || fanout == -1) { @@ -1019,21 +997,26 @@ void NumPick( } } -torch::Tensor TemporalMask( - int64_t seed_timestamp, torch::Tensor csc_indices, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, - std::pair edge_range) { +torch::Tensor TemporalMask(int64_t seed_timestamp, torch::Tensor csc_indices, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + std::pair edge_range) { auto [l, r] = edge_range; torch::Tensor mask = torch::ones({r - l}, torch::kBool); if (node_timestamp.has_value()) { auto neighbor_timestamp = node_timestamp.value().index_select(0, csc_indices.slice(0, l, r)); mask &= neighbor_timestamp < seed_timestamp; + if (time_window.has_value()) + mask &= neighbor_timestamp > seed_timestamp - time_window.value(); } if (edge_timestamp.has_value()) { - mask &= edge_timestamp.value().slice(0, l, r) < seed_timestamp; + auto edge_ts = edge_timestamp.value().slice(0, l, r); + mask &= edge_ts < seed_timestamp; + if (time_window.has_value()) + mask &= edge_ts > seed_timestamp - time_window.value(); } if (probs_or_mask.has_value()) { mask &= probs_or_mask.value().slice(0, l, r) != 0; @@ -1047,11 +1030,13 @@ torch::Tensor TemporalMask( * the timestamp of the neighbors. It is successful if the number of sampled * neighbors in kTriedThreshold trials is equal to the fanout. */ -std::pair> FastTemporalPick( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, int64_t fanout, - bool replace, const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, int64_t seed_offset, - int64_t offset, int64_t num_neighbors) { +std::pair> +FastTemporalPick(torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t fanout, bool replace, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + int64_t seed_offset, int64_t offset, int64_t num_neighbors) { constexpr int64_t kTriedThreshold = 1000; auto timestamp = utils::GetValueByIndex(seed_timestamp, seed_offset); std::vector sampled_edges; @@ -1072,16 +1057,23 @@ std::pair> FastTemporalPick( csc_indices.scalar_type(), "CheckNodeTimeStamp", ([&] { int64_t neighbor_id = utils::GetValueByIndex(csc_indices, edge_id); - if (utils::GetValueByIndex( - node_timestamp.value(), neighbor_id) >= timestamp) + auto neighbor_ts = utils::GetValueByIndex( + node_timestamp.value(), neighbor_id); + if (neighbor_ts >= timestamp || + (time_window.has_value() && + neighbor_ts <= (timestamp - time_window.value()))) flag = false; })); - if (!flag) continue; + if (!flag) + continue; } - if (edge_timestamp.has_value() && - utils::GetValueByIndex(edge_timestamp.value(), edge_id) >= - timestamp) { - continue; + if (edge_timestamp.has_value()) { + auto edge_ts = + utils::GetValueByIndex(edge_timestamp.value(), edge_id); + if (edge_ts >= timestamp || + (time_window.has_value() && + edge_ts <= (timestamp - time_window.value()))) + continue; } if (!replace) { sampled_edge_set.insert(edge_id); @@ -1095,12 +1087,14 @@ std::pair> FastTemporalPick( return {true, sampled_edges}; } -int64_t TemporalNumPick( - torch::Tensor seed_timestamp, torch::Tensor csc_indics, int64_t fanout, - bool replace, const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, int64_t seed_offset, - int64_t offset, int64_t num_neighbors) { +int64_t TemporalNumPick(torch::Tensor seed_timestamp, torch::Tensor csc_indics, + int64_t fanout, bool replace, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + int64_t seed_offset, int64_t offset, + int64_t num_neighbors) { constexpr int64_t kFastPathThreshold = 1000; if (num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) { // TODO: Currently we use the fast path both in TemporalNumPick and @@ -1108,39 +1102,39 @@ int64_t TemporalNumPick( // sampled edges in TemporalPick to avoid sampling twice. auto [success, sampled_edges] = FastTemporalPick( seed_timestamp, csc_indics, fanout, replace, node_timestamp, - edge_timestamp, seed_offset, offset, num_neighbors); - if (success) return sampled_edges.size(); + edge_timestamp, time_window, seed_offset, offset, num_neighbors); + if (success) + return sampled_edges.size(); } - auto mask = TemporalMask( - utils::GetValueByIndex(seed_timestamp, seed_offset), csc_indics, - probs_or_mask, node_timestamp, edge_timestamp, - {offset, offset + num_neighbors}); + auto mask = + TemporalMask(utils::GetValueByIndex(seed_timestamp, seed_offset), + csc_indics, probs_or_mask, node_timestamp, edge_timestamp, + time_window, {offset, offset + num_neighbors}); int64_t num_valid_neighbors = utils::GetValueByIndex(mask.sum(), 0); - if (num_valid_neighbors == 0 || fanout == -1) return num_valid_neighbors; + if (num_valid_neighbors == 0 || fanout == -1) + return num_valid_neighbors; return replace ? fanout : std::min(fanout, num_valid_neighbors); } template -void NumPickByEtype( - bool with_seed_offsets, const std::vector& fanouts, bool replace, - const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, int64_t offset, - int64_t num_neighbors, PickedNumType* num_picked_ptr, int64_t seed_index, - const std::vector& etype_id_to_num_picked_offset) { +void NumPickByEtype(bool with_seed_offsets, const std::vector &fanouts, + bool replace, const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + int64_t offset, int64_t num_neighbors, + PickedNumType *num_picked_ptr, int64_t seed_index, + const std::vector &etype_id_to_num_picked_offset) { int64_t etype_begin = offset; const int64_t end = offset + num_neighbors; PickedNumType total_count = 0; AT_DISPATCH_INTEGRAL_TYPES( type_per_edge.scalar_type(), "NumPickFnByEtype", ([&] { - const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); + const scalar_t *type_per_edge_data = type_per_edge.data_ptr(); while (etype_begin < end) { scalar_t etype = type_per_edge_data[etype_begin]; - TORCH_CHECK( - etype >= 0 && etype < (int64_t)fanouts.size(), - "Etype values exceed the number of fanouts."); - auto etype_end_it = std::upper_bound( - type_per_edge_data + etype_begin, type_per_edge_data + end, - etype); + TORCH_CHECK(etype >= 0 && etype < (int64_t)fanouts.size(), + "Etype values exceed the number of fanouts."); + auto etype_end_it = std::upper_bound(type_per_edge_data + etype_begin, + type_per_edge_data + end, etype); int64_t etype_end = etype_end_it - type_per_edge_data; // Do sampling for one etype. if (with_seed_offsets) { @@ -1148,14 +1142,12 @@ void NumPickByEtype( // each different etype. const auto offset = etype_id_to_num_picked_offset[etype] + seed_index; - NumPick( - fanouts[etype], replace, probs_or_mask, etype_begin, - etype_end - etype_begin, num_picked_ptr + offset); + NumPick(fanouts[etype], replace, probs_or_mask, etype_begin, + etype_end - etype_begin, num_picked_ptr + offset); } else { PickedNumType picked_count = 0; - NumPick( - fanouts[etype], replace, probs_or_mask, etype_begin, - etype_end - etype_begin, &picked_count); + NumPick(fanouts[etype], replace, probs_or_mask, etype_begin, + etype_end - etype_begin, &picked_count); total_count += picked_count; } etype_begin = etype_end; @@ -1166,34 +1158,34 @@ void NumPickByEtype( } } -int64_t TemporalNumPickByEtype( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, - const std::vector& fanouts, bool replace, - const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, int64_t seed_offset, - int64_t offset, int64_t num_neighbors) { +int64_t +TemporalNumPickByEtype(torch::Tensor seed_timestamp, torch::Tensor csc_indices, + const std::vector &fanouts, bool replace, + const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + int64_t seed_offset, int64_t offset, + int64_t num_neighbors) { int64_t etype_begin = offset; const int64_t end = offset + num_neighbors; int64_t total_count = 0; AT_DISPATCH_INTEGRAL_TYPES( type_per_edge.scalar_type(), "TemporalNumPickFnByEtype", ([&] { - const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); + const scalar_t *type_per_edge_data = type_per_edge.data_ptr(); while (etype_begin < end) { scalar_t etype = type_per_edge_data[etype_begin]; - TORCH_CHECK( - etype >= 0 && etype < (int64_t)fanouts.size(), - "Etype values exceed the number of fanouts."); - auto etype_end_it = std::upper_bound( - type_per_edge_data + etype_begin, type_per_edge_data + end, - etype); + TORCH_CHECK(etype >= 0 && etype < (int64_t)fanouts.size(), + "Etype values exceed the number of fanouts."); + auto etype_end_it = std::upper_bound(type_per_edge_data + etype_begin, + type_per_edge_data + end, etype); int64_t etype_end = etype_end_it - type_per_edge_data; // Do sampling for one etype. total_count += TemporalNumPick( seed_timestamp, csc_indices, fanouts[etype], replace, - probs_or_mask, node_timestamp, edge_timestamp, seed_offset, - etype_begin, etype_end - etype_begin); + probs_or_mask, node_timestamp, edge_timestamp, time_window, + seed_offset, etype_begin, etype_end - etype_begin); etype_begin = etype_end; } })); @@ -1221,9 +1213,9 @@ int64_t TemporalNumPickByEtype( * should be put. Enough memory space should be allocated in advance. */ template -inline int64_t UniformPick( - int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, - const torch::TensorOptions& options, PickedType* picked_data_ptr) { +inline int64_t +UniformPick(int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, + const torch::TensorOptions &options, PickedType *picked_data_ptr) { if ((fanout == -1) || (num_neighbors <= fanout && !replace)) { std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset); return num_neighbors; @@ -1282,13 +1274,14 @@ inline int64_t UniformPick( while (begin != end) { // Put the new random number in the last position. - *begin = RandomEngine::ThreadLocal()->RandInt( - offset, offset + num_neighbors); + *begin = RandomEngine::ThreadLocal()->RandInt(offset, + offset + num_neighbors); // Check if a new value doesn't exist in current // range(picked_data_ptr, begin). Otherwise get a new // value until we haven't unique range of elements. auto it = std::find(picked_data_ptr, begin, *begin); - if (it == begin) ++begin; + if (it == begin) + ++begin; } return fanout; } else { @@ -1320,16 +1313,19 @@ inline int64_t UniformPick( } /** @brief An operator to perform non-uniform sampling. */ -static torch::Tensor NonUniformPickOp( - torch::Tensor probs, int64_t fanout, bool replace) { +static torch::Tensor NonUniformPickOp(torch::Tensor probs, int64_t fanout, + bool replace) { auto positive_probs_indices = probs.nonzero().squeeze(1); auto num_positive_probs = positive_probs_indices.size(0); - if (num_positive_probs == 0) return torch::empty({0}, torch::kLong); + if (num_positive_probs == 0) + return torch::empty({0}, torch::kLong); if ((fanout == -1) || (num_positive_probs <= fanout && !replace)) { return positive_probs_indices; } - if (!replace) fanout = std::min(fanout, num_positive_probs); - if (fanout == 0) return torch::empty({0}, torch::kLong); + if (!replace) + fanout = std::min(fanout, num_positive_probs); + if (fanout == 0) + return torch::empty({0}, torch::kLong); auto ret_tensor = torch::empty({fanout}, torch::kLong); auto ret_ptr = ret_tensor.data_ptr(); AT_DISPATCH_FLOATING_TYPES( @@ -1372,15 +1368,15 @@ static torch::Tensor NonUniformPickOp( } if (fanout < num_positive_probs / 64) { // Use partial_sort. - std::partial_sort( - q.begin(), q.begin() + fanout, q.end(), std::greater{}); + std::partial_sort(q.begin(), q.begin() + fanout, q.end(), + std::greater{}); for (auto i = 0; i < fanout; ++i) { ret_ptr[i] = q[i].second; } } else { // Use nth_element. - std::nth_element( - q.begin(), q.begin() + fanout - 1, q.end(), std::greater{}); + std::nth_element(q.begin(), q.begin() + fanout - 1, q.end(), + std::greater{}); for (auto i = 0; i < fanout; ++i) { ret_ptr[i] = q[i].second; } @@ -1405,10 +1401,10 @@ static torch::Tensor NonUniformPickOp( double uniform_sample = RandomEngine::ThreadLocal()->Uniform(0., 1.); // Use a binary search to find the index. - int sampled_index = std::lower_bound( - prefix_sum_probs.begin(), - prefix_sum_probs.end(), uniform_sample) - - prefix_sum_probs.begin(); + int sampled_index = + std::lower_bound(prefix_sum_probs.begin(), + prefix_sum_probs.end(), uniform_sample) - + prefix_sum_probs.begin(); ret_ptr[i] = positive_probs_indices_ptr[sampled_index]; } } @@ -1452,10 +1448,11 @@ static torch::Tensor NonUniformPickOp( * should be put. Enough memory space should be allocated in advance. */ template -inline int64_t NonUniformPick( - int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, - const torch::TensorOptions& options, const torch::Tensor& probs_or_mask, - PickedType* picked_data_ptr) { +inline int64_t NonUniformPick(int64_t offset, int64_t num_neighbors, + int64_t fanout, bool replace, + const torch::TensorOptions &options, + const torch::Tensor &probs_or_mask, + PickedType *picked_data_ptr) { auto local_probs = probs_or_mask.size(0) > num_neighbors ? probs_or_mask.slice(0, offset, offset + num_neighbors) @@ -1470,37 +1467,38 @@ inline int64_t NonUniformPick( } template -int64_t Pick( - int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, - const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - SamplerArgs args, PickedType* picked_data_ptr) { - if (fanout == 0 || num_neighbors == 0) return 0; +int64_t Pick(int64_t offset, int64_t num_neighbors, int64_t fanout, + bool replace, const torch::TensorOptions &options, + const torch::optional &probs_or_mask, + SamplerArgs args, + PickedType *picked_data_ptr) { + if (fanout == 0 || num_neighbors == 0) + return 0; if (probs_or_mask.has_value()) { - return NonUniformPick( - offset, num_neighbors, fanout, replace, options, probs_or_mask.value(), - picked_data_ptr); + return NonUniformPick(offset, num_neighbors, fanout, replace, options, + probs_or_mask.value(), picked_data_ptr); } else { - return UniformPick( - offset, num_neighbors, fanout, replace, options, picked_data_ptr); + return UniformPick(offset, num_neighbors, fanout, replace, options, + picked_data_ptr); } } template -int64_t TemporalPick( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, - int64_t seed_offset, int64_t offset, int64_t num_neighbors, int64_t fanout, - bool replace, const torch::TensorOptions& options, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, SamplerArgs args, - PickedType* picked_data_ptr) { +int64_t TemporalPick(torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t seed_offset, int64_t offset, int64_t num_neighbors, + int64_t fanout, bool replace, + const torch::TensorOptions &options, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + const torch::optional &time_window, + SamplerArgs args, PickedType *picked_data_ptr) { constexpr int64_t kFastPathThreshold = 1000; if (S == SamplerType::NEIGHBOR && num_neighbors > kFastPathThreshold && !probs_or_mask.has_value()) { auto [success, sampled_edges] = FastTemporalPick( seed_timestamp, csc_indices, fanout, replace, node_timestamp, - edge_timestamp, seed_offset, offset, num_neighbors); + edge_timestamp, time_window, seed_offset, offset, num_neighbors); if (success) { for (size_t i = 0; i < sampled_edges.size(); ++i) { picked_data_ptr[i] = static_cast(sampled_edges[i]); @@ -1508,10 +1506,10 @@ int64_t TemporalPick( return sampled_edges.size(); } } - auto mask = TemporalMask( - utils::GetValueByIndex(seed_timestamp, seed_offset), csc_indices, - probs_or_mask, node_timestamp, edge_timestamp, - {offset, offset + num_neighbors}); + auto mask = + TemporalMask(utils::GetValueByIndex(seed_timestamp, seed_offset), + csc_indices, probs_or_mask, node_timestamp, edge_timestamp, + time_window, {offset, offset + num_neighbors}); torch::Tensor masked_prob; if (probs_or_mask.has_value()) { masked_prob = @@ -1529,37 +1527,34 @@ int64_t TemporalPick( return picked_indices.numel(); } if constexpr (is_labor(S)) { - return Pick( - offset, num_neighbors, fanout, replace, options, masked_prob, args, - picked_data_ptr); + return Pick(offset, num_neighbors, fanout, replace, options, masked_prob, + args, picked_data_ptr); } } template -int64_t PickByEtype( - bool with_seed_offsets, int64_t offset, int64_t num_neighbors, - const std::vector& fanouts, bool replace, - const torch::TensorOptions& options, const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, SamplerArgs args, - PickedType* picked_data_ptr, int64_t seed_index, - PickedType* subgraph_indptr_ptr, - const std::vector& etype_id_to_num_picked_offset) { +int64_t PickByEtype(bool with_seed_offsets, int64_t offset, + int64_t num_neighbors, const std::vector &fanouts, + bool replace, const torch::TensorOptions &options, + const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + SamplerArgs args, PickedType *picked_data_ptr, + int64_t seed_index, PickedType *subgraph_indptr_ptr, + const std::vector &etype_id_to_num_picked_offset) { int64_t etype_begin = offset; int64_t etype_end = offset; int64_t picked_total_count = 0; AT_DISPATCH_INTEGRAL_TYPES( type_per_edge.scalar_type(), "PickByEtype", ([&] { - const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); + const scalar_t *type_per_edge_data = type_per_edge.data_ptr(); const auto end = offset + num_neighbors; while (etype_begin < end) { scalar_t etype = type_per_edge_data[etype_begin]; - TORCH_CHECK( - etype >= 0 && etype < (int64_t)fanouts.size(), - "Etype values exceed the number of fanouts."); + TORCH_CHECK(etype >= 0 && etype < (int64_t)fanouts.size(), + "Etype values exceed the number of fanouts."); int64_t fanout = fanouts[etype]; - auto etype_end_it = std::upper_bound( - type_per_edge_data + etype_begin, type_per_edge_data + end, - etype); + auto etype_end_it = std::upper_bound(type_per_edge_data + etype_begin, + type_per_edge_data + end, etype); etype_end = etype_end_it - type_per_edge_data; // Do sampling for one etype. The picked nodes aren't stored // continuously, but separately for each different etype. @@ -1568,22 +1563,21 @@ int64_t PickByEtype( if (with_seed_offsets) { const auto indptr_offset = etype_id_to_num_picked_offset[etype] + seed_index; - picked_count = Pick( - etype_begin, etype_end - etype_begin, fanout, replace, - options, probs_or_mask, args, - picked_data_ptr + subgraph_indptr_ptr[indptr_offset]); - TORCH_CHECK( - subgraph_indptr_ptr[indptr_offset + 1] - - subgraph_indptr_ptr[indptr_offset] == - picked_count, - "Actual picked count doesn't match the calculated " - "pick number."); + picked_count = + Pick(etype_begin, etype_end - etype_begin, fanout, replace, + options, probs_or_mask, args, + picked_data_ptr + subgraph_indptr_ptr[indptr_offset]); + TORCH_CHECK(subgraph_indptr_ptr[indptr_offset + 1] - + subgraph_indptr_ptr[indptr_offset] == + picked_count, + "Actual picked count doesn't match the calculated " + "pick number."); } else { - picked_count = Pick( - etype_begin, etype_end - etype_begin, fanout, replace, - options, probs_or_mask, args, - picked_data_ptr + subgraph_indptr_ptr[seed_index] + - picked_total_count); + picked_count = + Pick(etype_begin, etype_end - etype_begin, fanout, replace, + options, probs_or_mask, args, + picked_data_ptr + subgraph_indptr_ptr[seed_index] + + picked_total_count); } picked_total_count += picked_count; } @@ -1594,39 +1588,39 @@ int64_t PickByEtype( } template -int64_t TemporalPickByEtype( - torch::Tensor seed_timestamp, torch::Tensor csc_indices, - int64_t seed_offset, int64_t offset, int64_t num_neighbors, - const std::vector& fanouts, bool replace, - const torch::TensorOptions& options, const torch::Tensor& type_per_edge, - const torch::optional& probs_or_mask, - const torch::optional& node_timestamp, - const torch::optional& edge_timestamp, SamplerArgs args, - PickedType* picked_data_ptr) { +int64_t +TemporalPickByEtype(torch::Tensor seed_timestamp, torch::Tensor csc_indices, + int64_t seed_offset, int64_t offset, int64_t num_neighbors, + const std::vector &fanouts, bool replace, + const torch::TensorOptions &options, + const torch::Tensor &type_per_edge, + const torch::optional &probs_or_mask, + const torch::optional &node_timestamp, + const torch::optional &edge_timestamp, + torch::optional time_window, SamplerArgs args, + PickedType *picked_data_ptr) { int64_t etype_begin = offset; int64_t etype_end = offset; int64_t pick_offset = 0; AT_DISPATCH_INTEGRAL_TYPES( type_per_edge.scalar_type(), "TemporalPickByEtype", ([&] { - const scalar_t* type_per_edge_data = type_per_edge.data_ptr(); + const scalar_t *type_per_edge_data = type_per_edge.data_ptr(); const auto end = offset + num_neighbors; while (etype_begin < end) { scalar_t etype = type_per_edge_data[etype_begin]; - TORCH_CHECK( - etype >= 0 && etype < (int64_t)fanouts.size(), - "Etype values exceed the number of fanouts."); + TORCH_CHECK(etype >= 0 && etype < (int64_t)fanouts.size(), + "Etype values exceed the number of fanouts."); int64_t fanout = fanouts[etype]; - auto etype_end_it = std::upper_bound( - type_per_edge_data + etype_begin, type_per_edge_data + end, - etype); + auto etype_end_it = std::upper_bound(type_per_edge_data + etype_begin, + type_per_edge_data + end, etype); etype_end = etype_end_it - type_per_edge_data; // Do sampling for one etype. if (fanout != 0) { int64_t picked_count = TemporalPick( seed_timestamp, csc_indices, seed_offset, etype_begin, etype_end - etype_begin, fanout, replace, options, - probs_or_mask, node_timestamp, edge_timestamp, args, - picked_data_ptr + pick_offset); + probs_or_mask, node_timestamp, edge_timestamp, time_window, + args, picked_data_ptr + pick_offset); pick_offset += picked_count; } etype_begin = etype_end; @@ -1636,17 +1630,17 @@ int64_t TemporalPickByEtype( } template -std::enable_if_t Pick( - int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, - const torch::TensorOptions& options, - const torch::optional& probs_or_mask, SamplerArgs args, - PickedType* picked_data_ptr) { - if (fanout == 0 || num_neighbors == 0) return 0; +std::enable_if_t +Pick(int64_t offset, int64_t num_neighbors, int64_t fanout, bool replace, + const torch::TensorOptions &options, + const torch::optional &probs_or_mask, SamplerArgs args, + PickedType *picked_data_ptr) { + if (fanout == 0 || num_neighbors == 0) + return 0; if (probs_or_mask.has_value()) { if (fanout < 0) { - return NonUniformPick( - offset, num_neighbors, fanout, replace, options, - probs_or_mask.value(), picked_data_ptr); + return NonUniformPick(offset, num_neighbors, fanout, replace, options, + probs_or_mask.value(), picked_data_ptr); } else { int64_t picked_count; AT_DISPATCH_FLOATING_TYPES( @@ -1664,42 +1658,40 @@ std::enable_if_t Pick( return picked_count; } } else if (fanout < 0) { - return UniformPick( - offset, num_neighbors, fanout, replace, options, picked_data_ptr); + return UniformPick(offset, num_neighbors, fanout, replace, options, + picked_data_ptr); } else if (replace) { - return LaborPick( - offset, num_neighbors, fanout, options, - /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr); - } else { // replace = false + return LaborPick(offset, num_neighbors, fanout, options, + /* probs_or_mask= */ torch::nullopt, + args, picked_data_ptr); + } else { // replace = false return LaborPick( offset, num_neighbors, fanout, options, /* probs_or_mask= */ torch::nullopt, args, picked_data_ptr); } } -template -inline void safe_divide(T& a, U b) { +template inline void safe_divide(T &a, U b) { a = b > 0 ? (T)(a / b) : std::numeric_limits::infinity(); } namespace labor { -template -inline T invcdf(T u, int64_t n, T rem) { +template inline T invcdf(T u, int64_t n, T rem) { constexpr T one = 1; return rem * (one - std::pow(one - u, one / n)); } template -inline T jth_sorted_uniform_random( - seed_t seed, int64_t t, int64_t c, int64_t j, T& rem, int64_t n) { +inline T jth_sorted_uniform_random(seed_t seed, int64_t t, int64_t c, int64_t j, + T &rem, int64_t n) { const T u = seed.uniform(t + j * c); // https://mathematica.stackexchange.com/a/256707 rem -= invcdf(u, n, rem); return 1 - rem; } -}; // namespace labor +}; // namespace labor /** * @brief Perform uniform-nonuniform sampling of elements depending on the @@ -1725,14 +1717,13 @@ inline T jth_sorted_uniform_random( * @param picked_data_ptr The destination address where the picked neighbors * should be put. Enough memory space should be allocated in advance. */ -template < - bool NonUniform, bool Replace, typename ProbsType, SamplerType S, - typename PickedType, int StackSize> -inline std::enable_if_t LaborPick( - int64_t offset, int64_t num_neighbors, int64_t fanout, - const torch::TensorOptions& options, - const torch::optional& probs_or_mask, SamplerArgs args, - PickedType* picked_data_ptr) { +template +inline std::enable_if_t +LaborPick(int64_t offset, int64_t num_neighbors, int64_t fanout, + const torch::TensorOptions &options, + const torch::optional &probs_or_mask, + SamplerArgs args, PickedType *picked_data_ptr) { fanout = Replace ? fanout : std::min(fanout, num_neighbors); if (!NonUniform && !Replace && fanout >= num_neighbors) { std::iota(picked_data_ptr, picked_data_ptr + num_neighbors, offset); @@ -1745,10 +1736,10 @@ inline std::enable_if_t LaborPick( if (fanout > StackSize) { constexpr int factor = sizeof(heap_data[0]) / sizeof(int32_t); heap_tensor = torch::empty({fanout * factor}, torch::kInt32); - heap_data = reinterpret_cast*>( + heap_data = reinterpret_cast *>( heap_tensor.data_ptr()); } - const ProbsType* local_probs_data = + const ProbsType *local_probs_data = NonUniform ? probs_or_mask.value().data_ptr() + offset : nullptr; if (NonUniform && probs_or_mask.value().size(0) <= num_neighbors) { @@ -1757,7 +1748,7 @@ inline std::enable_if_t LaborPick( AT_DISPATCH_INDEX_TYPES( args.indices.scalar_type(), "LaborPickMain", ([&] { const auto local_indices_data = - reinterpret_cast(args.indices.data_ptr()) + offset; + reinterpret_cast(args.indices.data_ptr()) + offset; if constexpr (Replace) { // [Algorithm] @mfbalin // Use a max-heap to get rid of the big random numbers and filter the @@ -1794,10 +1785,10 @@ inline std::enable_if_t LaborPick( [&](index_t t, int64_t j, uint32_t i) { auto rnd = labor::jth_sorted_uniform_random( args.random_seed, t, args.num_nodes, j, remaining_data[i], - fanout - j); // r_t + fanout - j); // r_t if constexpr (NonUniform) { safe_divide(rnd, local_probs_data[i]); - } // r_t / \pi_t + } // r_t / \pi_t if (heap_end < heap_data + fanout) { heap_end[0] = std::make_pair(rnd, i); if (++heap_end >= heap_data + fanout) { @@ -1821,10 +1812,12 @@ inline std::enable_if_t LaborPick( } } for (uint32_t i = 0; i < num_neighbors; ++i) { - if (remaining_data[i] == -1) continue; + if (remaining_data[i] == -1) + continue; const auto t = local_indices_data[i]; for (int64_t j = init_count; j < fanout; ++j) { - if (sample_neighbor_i_with_index_t_jth_time(t, j, i)) break; + if (sample_neighbor_i_with_index_t_jth_time(t, j, i)) + break; } } } else { @@ -1847,10 +1840,10 @@ inline std::enable_if_t LaborPick( // O(num_neighbors). for (uint32_t i = 0; i < fanout; ++i) { const auto t = local_indices_data[i]; - auto rnd = args.random_seed.uniform(t); // r_t + auto rnd = args.random_seed.uniform(t); // r_t if constexpr (NonUniform) { safe_divide(rnd, local_probs_data[i]); - } // r_t / \pi_t + } // r_t / \pi_t heap_data[i] = std::make_pair(rnd, i); } if (!NonUniform || fanout < num_neighbors) { @@ -1858,10 +1851,10 @@ inline std::enable_if_t LaborPick( } for (uint32_t i = fanout; i < num_neighbors; ++i) { const auto t = local_indices_data[i]; - auto rnd = args.random_seed.uniform(t); // r_t + auto rnd = args.random_seed.uniform(t); // r_t if constexpr (NonUniform) { safe_divide(rnd, local_probs_data[i]); - } // r_t / \pi_t + } // r_t / \pi_t if (rnd < heap_data[0].first) { std::pop_heap(heap_data, heap_data + fanout); heap_data[fanout - 1] = std::make_pair(rnd, i); @@ -1880,5 +1873,5 @@ inline std::enable_if_t LaborPick( return num_sampled; } -} // namespace sampling -} // namespace graphbolt +} // namespace sampling +} // namespace graphbolt diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 16f2a145287f..cd15f82d6a7e 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -966,6 +966,7 @@ def temporal_sample_neighbors( probs_name: Optional[str] = None, node_timestamp_attr_name: Optional[str] = None, edge_timestamp_attr_name: Optional[str] = None, + time_window: Optional[int] = None, ) -> torch.ScriptObject: """Temporally Sample neighboring edges of the given nodes and return the induced subgraph. @@ -1039,6 +1040,7 @@ def temporal_sample_neighbors( probs_or_mask, node_timestamp_attr_name, edge_timestamp_attr_name, + time_window, ) return self._convert_to_sampled_subgraph(C_sampled_subgraph) diff --git a/python/dgl/graphbolt/impl/temporal_neighbor_sampler.py b/python/dgl/graphbolt/impl/temporal_neighbor_sampler.py index 4560a51c631b..1be1f3beaaa6 100644 --- a/python/dgl/graphbolt/impl/temporal_neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/temporal_neighbor_sampler.py @@ -1,4 +1,5 @@ """Temporal neighbor subgraph samplers for GraphBolt.""" + import torch from torch.utils.data import functional_datapipe @@ -59,6 +60,9 @@ class TemporalNeighborSampler(SubgraphSampler): The name of an edge attribute used as the timestamps of edges. It must be a 1D integer tensor, with the number of elements equalling the total number of edges. + time_window: int, optional + A duration before a seed timestamp, within which target nodes will be + filtered. Examples ------- @@ -74,6 +78,7 @@ def __init__( prob_name=None, node_timestamp_attr_name=None, edge_timestamp_attr_name=None, + time_window=None, ): super().__init__(datapipe) self.graph = graph @@ -87,6 +92,7 @@ def __init__( self.prob_name = prob_name self.node_timestamp_attr_name = node_timestamp_attr_name self.edge_timestamp_attr_name = edge_timestamp_attr_name + self.time_window = time_window self.sampler = graph.temporal_sample_neighbors def sample_subgraphs(self, seeds, seeds_timestamp): @@ -122,6 +128,7 @@ def sample_subgraphs(self, seeds, seeds_timestamp): self.prob_name, self.node_timestamp_attr_name, self.edge_timestamp_attr_name, + self.time_window, ) ( original_row_node_ids,