Merge branch 'master' into clickhouse_datasource
jecsand838 authored Dec 3, 2024
2 parents 5be7b33 + a00a788 commit 29b20e3
Showing 19 changed files with 404 additions and 204 deletions.
91 changes: 5 additions & 86 deletions python/ray/data/_internal/datasource/sql_datasource.py
@@ -1,8 +1,7 @@
import math
from contextlib import contextmanager
from typing import Any, Callable, Iterable, Iterator, List, Optional

from ray.data.block import Block, BlockAccessor, BlockMetadata
from ray.data.block import Block, BlockMetadata
from ray.data.datasource.datasource import Datasource, ReadTask

Connection = Any # A Python DB API2-compliant `Connection` object.
@@ -72,99 +71,19 @@ def _connect(connection_factory: Callable[[], Connection]) -> Iterator[Cursor]:


class SQLDatasource(Datasource):

NUM_SAMPLE_ROWS = 100
MIN_ROWS_PER_READ_TASK = 50

def __init__(self, sql: str, connection_factory: Callable[[], Connection]):
self.sql = sql
self.connection_factory = connection_factory

def estimate_inmemory_data_size(self) -> Optional[int]:
pass
return None

def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
def fallback_read_fn() -> Iterable[Block]:
with _connect(self.connection_factory) as cursor:
cursor.execute(self.sql)
block = _cursor_to_block(cursor)
return [block]

# If `parallelism` is 1, directly fetch all rows. This avoids unnecessary
# queries to fetch a sample block and compute the total number of rows.
if parallelism == 1:
metadata = BlockMetadata(None, None, None, None, None)
return [ReadTask(fallback_read_fn, metadata)]

# Databases like DB2, Oracle, and MS SQL Server don't support `LIMIT`.
try:
with _connect(self.connection_factory) as cursor:
cursor.execute(f"SELECT * FROM ({self.sql}) as T LIMIT 1 OFFSET 0")
is_limit_supported = True
except Exception:
is_limit_supported = False

if not is_limit_supported:
metadata = BlockMetadata(None, None, None, None, None)
return [ReadTask(fallback_read_fn, metadata)]

num_rows_total = self._get_num_rows()

if num_rows_total == 0:
return []

parallelism = min(
parallelism, math.ceil(num_rows_total / self.MIN_ROWS_PER_READ_TASK)
)
num_rows_per_block = num_rows_total // parallelism
num_blocks_with_extra_row = num_rows_total % parallelism

sample_block_accessor = BlockAccessor.for_block(self._get_sample_block())
estimated_size_bytes_per_row = math.ceil(
sample_block_accessor.size_bytes() / sample_block_accessor.num_rows()
)
sample_block_schema = sample_block_accessor.schema()

tasks = []
offset = 0
for i in range(parallelism):
num_rows = num_rows_per_block
if i < num_blocks_with_extra_row:
num_rows += 1

read_fn = self._create_read_fn(num_rows, offset)
metadata = BlockMetadata(
num_rows,
estimated_size_bytes_per_row * num_rows,
sample_block_schema,
None,
None,
)
tasks.append(ReadTask(read_fn, metadata))

offset += num_rows

return tasks

def _get_num_rows(self) -> int:
with _connect(self.connection_factory) as cursor:
cursor.execute(f"SELECT COUNT(*) FROM ({self.sql}) as T")
return cursor.fetchone()[0]

def _get_sample_block(self) -> Block:
with _connect(self.connection_factory) as cursor:
cursor.execute(
f"SELECT * FROM ({self.sql}) as T LIMIT {self.NUM_SAMPLE_ROWS}"
)
return _cursor_to_block(cursor)

def _create_read_fn(self, num_rows: int, offset: int):
def read_fn() -> Iterable[Block]:
with _connect(self.connection_factory) as cursor:
cursor.execute(
f"SELECT * FROM ({self.sql}) as T LIMIT {num_rows} OFFSET {offset}"
)
cursor.execute(self.sql)
block = _cursor_to_block(cursor)
return [block]

return read_fn
metadata = BlockMetadata(None, None, None, None, None)
return [ReadTask(read_fn, metadata)]
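
The simplified datasource always fetches the full query result in a single read task; all of the sampling, row-counting, and LIMIT/OFFSET partitioning logic is removed in favor of one query per dataset. For reference, a rough standalone sketch of that single-task behavior, assuming a DB API 2.0 connection and pyarrow as the block representation (the real `_cursor_to_block` helper is not part of this diff, and `example.db` is an illustrative path):

```python
import sqlite3

import pyarrow as pa


def connection_factory() -> sqlite3.Connection:
    # Illustrative database path; any DB API 2.0 connection works here.
    return sqlite3.connect("example.db")


def read_all(sql: str) -> pa.Table:
    """Execute `sql` and return the entire result set as one Arrow table."""
    connection = connection_factory()
    try:
        cursor = connection.cursor()
        cursor.execute(sql)
        rows = cursor.fetchall()
        column_names = [column[0] for column in cursor.description]
        # One block for the whole result set, mirroring the single-task read.
        return pa.table(
            {name: [row[i] for row in rows] for i, name in enumerate(column_names)}
        )
    finally:
        connection.close()
```
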
6 changes: 6 additions & 0 deletions python/ray/data/read_api.py
@@ -2159,6 +2159,12 @@ def create_connection():
Returns:
A :class:`Dataset` containing the queried data.
"""
if parallelism != -1 and parallelism != 1:
raise ValueError(
"To ensure correctness, 'read_sql' always launches one task. The "
"'parallelism' argument you specified can't be used."
)

datasource = SQLDatasource(sql=sql, connection_factory=connection_factory)
return read_datasource(
datasource,
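
With the guard above, callers leave `parallelism` at its default; any value other than -1 or 1 now raises the new ValueError. A small end-to-end usage sketch, assuming a local SQLite file (the table name, schema, and row are illustrative, modeled on the test below):

```python
import sqlite3

import ray

# Build an illustrative SQLite database; the schema mirrors the test below.
connection = sqlite3.connect("example.db")
connection.execute("CREATE TABLE movie(title, year, score)")
connection.execute(
    "INSERT INTO movie VALUES (?, ?, ?)",
    ("Monty Python and the Holy Grail", 1975, 8.2),
)
connection.commit()
connection.close()

# `parallelism` is left at its default; `read_sql` now always runs one task.
dataset = ray.data.read_sql(
    "SELECT * FROM movie",
    lambda: sqlite3.connect("example.db"),
)
print(dataset.take_all())
```
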
11 changes: 8 additions & 3 deletions python/ray/data/tests/test_sql.py
@@ -22,8 +22,14 @@ def temp_database_fixture() -> Generator[str, None, None]:
yield file.name


@pytest.mark.parametrize("parallelism", [-1, 1])
def test_read_sql(temp_database: str, parallelism: int):
def test_read_sql_with_parallelism_warns(temp_database):
with pytest.raises(ValueError):
ray.data.read_sql(
"SELECT * FROM movie", lambda: sqlite3.connect(temp_database), parallelism=2
)


def test_read_sql(temp_database: str):
connection = sqlite3.connect(temp_database)
connection.execute("CREATE TABLE movie(title, year, score)")
expected_values = [
@@ -37,7 +43,6 @@ def test_read_sql(temp_database: str, parallelism: int):
dataset = ray.data.read_sql(
"SELECT * FROM movie",
lambda: sqlite3.connect(temp_database),
override_num_blocks=parallelism,
)
actual_values = [tuple(record.values()) for record in dataset.take_all()]

5 changes: 5 additions & 0 deletions src/mock/ray/raylet_client/raylet_client.h
@@ -48,6 +48,11 @@ class MockRayletClientInterface : public RayletClientInterface {
(const TaskID &task_id,
const rpc::ClientCallback<rpc::GetTaskFailureCauseReply> &callback),
(override));
MOCK_METHOD(void,
PrestartWorkers,
(const rpc::PrestartWorkersRequest &request,
const rpc::ClientCallback<ray::rpc::PrestartWorkersReply> &callback),
(override));
MOCK_METHOD(void,
ReleaseUnusedActorWorkers,
(const std::vector<WorkerID> &workers_in_use,
20 changes: 20 additions & 0 deletions src/ray/core_worker/core_worker.cc
@@ -14,6 +14,8 @@

#include "ray/core_worker/core_worker.h"

#include <future>

#ifndef _WIN32
#include <unistd.h>
#endif
@@ -2293,6 +2295,24 @@ void CoreWorker::BuildCommonTaskSpec(
}
}

void CoreWorker::PrestartWorkers(const std::string &serialized_runtime_env_info,
uint64_t keep_alive_duration_secs,
size_t num_workers) {
rpc::PrestartWorkersRequest request;
request.set_language(GetLanguage());
request.set_job_id(GetCurrentJobId().Binary());
*request.mutable_runtime_env_info() =
*OverrideTaskOrActorRuntimeEnvInfo(serialized_runtime_env_info);
request.set_keep_alive_duration_secs(keep_alive_duration_secs);
request.set_num_workers(num_workers);
local_raylet_client_->PrestartWorkers(
request, [](const Status &status, const rpc::PrestartWorkersReply &reply) {
if (!status.ok()) {
RAY_LOG(INFO) << "Failed to prestart workers: " << status.ToString();
}
});
}

std::vector<rpc::ObjectReference> CoreWorker::SubmitTask(
const RayFunction &function,
const std::vector<std::unique_ptr<TaskArg>> &args,
13 changes: 12 additions & 1 deletion src/ray/core_worker/core_worker.h
@@ -44,7 +44,6 @@
#include "ray/pubsub/subscriber.h"
#include "ray/raylet_client/raylet_client.h"
#include "ray/rpc/node_manager/node_manager_client.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "ray/rpc/worker/core_worker_server.h"
#include "ray/util/process.h"
#include "src/ray/protobuf/pubsub.pb.h"
@@ -825,6 +824,18 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
const std::string &error_message,
double timestamp);

// Prestart workers. The workers:
// - use the current language.
// - use the current JobID.
// - do NOT support root_detached_actor_id.
// - use the provided runtime_env_info applied to the job runtime env, as if it were a
// task request.
//
// This API is async. It provides no guarantee that the workers are actually started.
void PrestartWorkers(const std::string &serialized_runtime_env_info,
uint64_t keep_alive_duration_secs,
size_t num_workers);

/// Submit a normal task.
///
/// \param[in] function The remote function to execute.
5 changes: 5 additions & 0 deletions src/ray/core_worker/test/normal_task_submitter_test.cc
@@ -248,6 +248,11 @@ class MockRayletClient : public WorkerLeaseInterface {
}
callbacks.push_back(callback);
}
void PrestartWorkers(
const rpc::PrestartWorkersRequest &request,
const rpc::ClientCallback<ray::rpc::PrestartWorkersReply> &callback) override {
RAY_LOG(FATAL) << "Not implemented";
}

void ReleaseUnusedActorWorkers(
const std::vector<WorkerID> &workers_in_use,
6 changes: 6 additions & 0 deletions src/ray/gcs/gcs_server/test/gcs_server_test_util.h
@@ -99,6 +99,12 @@ struct GcsServerMocker {
callbacks.push_back(callback);
}

void PrestartWorkers(
const rpc::PrestartWorkersRequest &request,
const rpc::ClientCallback<ray::rpc::PrestartWorkersReply> &callback) override {
RAY_LOG(FATAL) << "Not implemented";
}

/// WorkerLeaseInterface
void ReleaseUnusedActorWorkers(
const std::vector<WorkerID> &workers_in_use,
3 changes: 2 additions & 1 deletion src/ray/protobuf/BUILD
@@ -118,6 +118,7 @@ proto_library(
":autoscaler_proto",
":common_proto",
":gcs_proto",
":runtime_env_common_proto",
],
)

@@ -249,7 +250,7 @@ proto_library(
name = "export_event_proto",
srcs = ["export_api/export_event.proto"],
deps = [
":export_task_event_proto",
":export_task_event_proto",
":export_node_event_proto",
":export_actor_event_proto",
":export_driver_job_event_proto",
17 changes: 17 additions & 0 deletions src/ray/protobuf/node_manager.proto
@@ -20,6 +20,7 @@ package ray.rpc;
import "src/ray/protobuf/common.proto";
import "src/ray/protobuf/gcs.proto";
import "src/ray/protobuf/autoscaler.proto";
import "src/ray/protobuf/runtime_env_common.proto";

message WorkerBacklogReport {
// TaskSpec indicating the scheduling class.
@@ -94,6 +95,20 @@ message RequestWorkerLeaseReply {
string scheduling_failure_message = 10;
}

// Request to prestart workers. At this time we don't yet know the resources or the task type.
message PrestartWorkersRequest {
Language language = 1;
// Job ID for the workers. Note: root_detached_actor_id is not supported.
optional bytes job_id = 2;
RuntimeEnvInfo runtime_env_info = 3;
// Started idle workers will be kept alive for this duration. Reset on task assignment.
uint64 keep_alive_duration_secs = 4;
// Raylet will try to start `num_workers` workers.
uint64 num_workers = 5;
}

message PrestartWorkersReply {}

message PrepareBundleResourcesRequest {
// Bundles containing the requested resources.
repeated Bundle bundle_specs = 1;
@@ -385,6 +400,8 @@ service NodeManagerService {
rpc GetResourceLoad(GetResourceLoadRequest) returns (GetResourceLoadReply);
// Request a worker from the raylet.
rpc RequestWorkerLease(RequestWorkerLeaseRequest) returns (RequestWorkerLeaseReply);
// Request to prestart workers.
rpc PrestartWorkers(PrestartWorkersRequest) returns (PrestartWorkersReply);
// Report task backlog information from a worker to the raylet
rpc ReportWorkerBacklog(ReportWorkerBacklogRequest) returns (ReportWorkerBacklogReply);
// Release a worker back to its raylet.
37 changes: 37 additions & 0 deletions src/ray/raylet/node_manager.cc
@@ -18,7 +18,10 @@
#include <csignal>
#include <fstream>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/functional/bind_front.h"
#include "absl/time/clock.h"
@@ -40,6 +43,7 @@
#include "ray/util/event.h"
#include "ray/util/event_label.h"
#include "ray/util/util.h"
#include "src/ray/raylet/worker_pool.h"

namespace {

@@ -1867,6 +1871,39 @@ void NodeManager::HandleRequestWorkerLease(rpc::RequestWorkerLeaseRequest request,
send_reply_callback_wrapper);
}

void NodeManager::HandlePrestartWorkers(rpc::PrestartWorkersRequest request,
rpc::PrestartWorkersReply *reply,
rpc::SendReplyCallback send_reply_callback) {
auto pop_worker_request = std::make_shared<PopWorkerRequest>(
request.language(),
rpc::WorkerType::WORKER,
request.has_job_id() ? JobID::FromBinary(request.job_id()) : JobID::Nil(),
/*root_detached_actor_id=*/ActorID::Nil(),
/*gpu=*/std::nullopt,
/*actor_worker=*/std::nullopt,
request.runtime_env_info(),
/*runtime_env_hash=*/
CalculateRuntimeEnvHash(request.runtime_env_info().serialized_runtime_env()),
/*options=*/std::vector<std::string>{},
absl::Seconds(request.keep_alive_duration_secs()),
/*callback=*/
[request](const std::shared_ptr<WorkerInterface> &worker,
PopWorkerStatus status,
const std::string &runtime_env_setup_error_message) {
// This callback does not use the worker.
RAY_LOG(DEBUG).WithField(worker->WorkerId())
<< "Prestart worker started! token " << worker->GetStartupToken()
<< ", status " << status << ", runtime_env_setup_error_message "
<< runtime_env_setup_error_message;
return false;
});

for (uint64_t i = 0; i < request.num_workers(); i++) {
worker_pool_.StartNewWorker(pop_worker_request);
}
send_reply_callback(Status::OK(), nullptr, nullptr);
}

void NodeManager::HandlePrepareBundleResources(
rpc::PrepareBundleResourcesRequest request,
rpc::PrepareBundleResourcesReply *reply,
4 changes: 4 additions & 0 deletions src/ray/raylet/node_manager.h
@@ -523,6 +523,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

void HandlePrestartWorkers(rpc::PrestartWorkersRequest request,
rpc::PrestartWorkersReply *reply,
rpc::SendReplyCallback send_reply_callback) override;

/// Handle a `ReportWorkerBacklog` request.
void HandleReportWorkerBacklog(rpc::ReportWorkerBacklogRequest request,
rpc::ReportWorkerBacklogReply *reply,
Expand Down