diff --git a/kafka/client_async.py b/kafka/client_async.py
index 19508b242..30258b7bd 100644
--- a/kafka/client_async.py
+++ b/kafka/client_async.py
@@ -27,7 +27,7 @@
 from kafka.metrics.stats.rate import TimeUnit
 from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
 from kafka.protocol.metadata import MetadataRequest
-from kafka.util import Dict, WeakMethod, ensure_valid_topic_name
+from kafka.util import Dict, WeakMethod, ensure_valid_topic_name, timeout_ms_fn
 # Although this looks unused, it actually monkey-patches socket.socketpair()
 # and should be left in as long as we're using socket.socketpair() in this file
 from kafka.vendor import socketpair  # noqa: F401
@@ -400,6 +400,11 @@ def maybe_connect(self, node_id, wakeup=True):
             return True
         return False
 
+    def connection_failed(self, node_id):
+        if node_id not in self._conns:
+            return False
+        return self._conns[node_id].connect_failed()
+
     def _should_recycle_connection(self, conn):
         # Never recycle unless disconnected
         if not conn.disconnected():
@@ -1157,6 +1162,39 @@ def bootstrap_connected(self):
         else:
             return False
 
+    def await_ready(self, node_id, timeout_ms=30000):
+        """
+        Invokes `poll` to discard pending disconnects, followed by `maybe_connect` and zero or more further
+        `poll` invocations until the connection to `node_id` is ready, the timeout expires, or the connection fails.
+
+        Returns True if the call completes normally, or False if the timeout expires. If the connection fails,
+        a `KafkaConnectionError` is raised instead. Note that if the client has been configured with a positive
+        connection timeout, this method may raise for a previous connection to the node which has recently
+        disconnected.
+
+        This method is useful for implementing blocking behaviour on top of the non-blocking `KafkaClient`;
+        use it with care.
+        """
+        inner_timeout_ms = timeout_ms_fn(timeout_ms, None)
+        self.poll(timeout_ms=0)
+        if self.is_ready(node_id):
+            return True
+
+        while not self.is_ready(node_id) and inner_timeout_ms() > 0:
+            if self.connection_failed(node_id):
+                raise Errors.KafkaConnectionError("Connection to %s failed." % (node_id,))
+            self.maybe_connect(node_id)
+            self.poll(timeout_ms=inner_timeout_ms())
+        return self.is_ready(node_id)
+
+    def send_and_receive(self, node_id, request):
+        future = self.send(node_id, request)
+        self.poll(future=future)
+        assert future.is_done
+        if future.failed():
+            raise future.exception
+        return future.value
+
 # OrderedDict requires python2.7+
 try:
diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py
index f0eb37a8f..320a1657f 100644
--- a/kafka/producer/kafka.py
+++ b/kafka/producer/kafka.py
@@ -19,6 +19,7 @@
 from kafka.producer.future import FutureRecordMetadata, FutureProduceResult
 from kafka.producer.record_accumulator import AtomicInteger, RecordAccumulator
 from kafka.producer.sender import Sender
+from kafka.producer.transaction_state import TransactionState
 from kafka.record.default_records import DefaultRecordBatchBuilder
 from kafka.record.legacy_records import LegacyRecordBatchBuilder
 from kafka.serializer import Serializer
@@ -93,6 +94,19 @@ class KafkaProducer(object):
         value_serializer (callable): used to convert user-supplied message
             values to bytes. If not None, called as f(value), should return
             bytes. Default: None.
+        enable_idempotence (bool): When set to True, the producer will ensure
+            that exactly one copy of each message is written in the stream.
+            If False, producer retries due to broker failures, etc., may write
+            duplicates of the retried message in the stream. Default: False.
+
+            Note that enabling idempotence requires
+            `max_in_flight_requests_per_connection` to be set to 1 and `retries`
+            cannot be zero. Additionally, `acks` must be set to 'all'. If these
+            values are left at their defaults, the producer will override the
+            defaults to be suitable. If the values are set to something
+            incompatible with the idempotent producer, a KafkaConfigurationError
+            will be raised.
+
         acks (0, 1, 'all'): The number of acknowledgments the producer requires
             the leader to have received before considering a request complete.
             This controls the durability of records that are sent. The
@@ -303,6 +317,7 @@ class KafkaProducer(object):
         'client_id': None,
         'key_serializer': None,
         'value_serializer': None,
+        'enable_idempotence': False,
         'acks': 1,
         'bootstrap_topics_filter': set(),
         'compression_type': None,
@@ -365,6 +380,7 @@ class KafkaProducer(object):
     def __init__(self, **configs):
         log.debug("Starting the Kafka producer")  # trace
         self.config = copy.copy(self.DEFAULT_CONFIG)
+        user_provided_configs = set(configs.keys())
         for key in self.config:
             if key in configs:
                 self.config[key] = configs.pop(key)
@@ -428,13 +444,41 @@ def __init__(self, **configs):
             assert checker(), "Libraries for {} compression codec not found".format(ct)
             self.config['compression_attrs'] = compression_attrs
 
-        message_version = self._max_usable_produce_magic()
-        self._accumulator = RecordAccumulator(message_version=message_version, **self.config)
+        self._transaction_state = None
+        if self.config['enable_idempotence']:
+            self._transaction_state = TransactionState()
+            if 'retries' not in user_provided_configs:
+                log.info("Overriding the default 'retries' config to 3 since the idempotent producer is enabled.")
+                self.config['retries'] = 3
+            elif self.config['retries'] == 0:
+                raise Errors.KafkaConfigurationError("Must set 'retries' to non-zero when using the idempotent producer.")
+
+            if 'max_in_flight_requests_per_connection' not in user_provided_configs:
+                log.info("Overriding the default 'max_in_flight_requests_per_connection' to 1 since idempotence is enabled.")
+                self.config['max_in_flight_requests_per_connection'] = 1
+            elif self.config['max_in_flight_requests_per_connection'] != 1:
+                raise Errors.KafkaConfigurationError("Must set 'max_in_flight_requests_per_connection' to 1 in order"
+                                                     " to use the idempotent producer."
+                                                     " Otherwise we cannot guarantee idempotence.")
+
+            if 'acks' not in user_provided_configs:
+                log.info("Overriding the default 'acks' config to 'all' since idempotence is enabled.")
+                self.config['acks'] = -1
+            elif self.config['acks'] != -1:
+                raise Errors.KafkaConfigurationError("Must set 'acks' config to 'all' in order to use the idempotent"
+                                                     " producer. Otherwise we cannot guarantee idempotence.")
+
+        message_version = self.max_usable_produce_magic(self.config['api_version'])
+        self._accumulator = RecordAccumulator(
+            transaction_state=self._transaction_state,
+            message_version=message_version,
+            **self.config)
         self._metadata = client.cluster
         guarantee_message_order = bool(self.config['max_in_flight_requests_per_connection'] == 1)
         self._sender = Sender(client, self._metadata, self._accumulator,
                               metrics=self._metrics,
+                              transaction_state=self._transaction_state,
                               guarantee_message_order=guarantee_message_order,
                               **self.config)
         self._sender.daemon = True
@@ -548,16 +592,17 @@ def partitions_for(self, topic):
         max_wait = self.config['max_block_ms'] / 1000
         return self._wait_on_metadata(topic, max_wait)
 
-    def _max_usable_produce_magic(self):
-        if self.config['api_version'] >= (0, 11):
+    @classmethod
+    def max_usable_produce_magic(cls, api_version):
+        if api_version >= (0, 11):
             return 2
-        elif self.config['api_version'] >= (0, 10, 0):
+        elif api_version >= (0, 10, 0):
             return 1
         else:
             return 0
 
     def _estimate_size_in_bytes(self, key, value, headers=[]):
-        magic = self._max_usable_produce_magic()
+        magic = self.max_usable_produce_magic(self.config['api_version'])
         if magic == 2:
             return DefaultRecordBatchBuilder.estimate_size_in_bytes(
                 key, value, headers)
diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py
index ba823500d..60fa0a323 100644
--- a/kafka/producer/record_accumulator.py
+++ b/kafka/producer/record_accumulator.py
@@ -35,9 +35,9 @@ def get(self):
 
 
 class ProducerBatch(object):
-    def __init__(self, tp, records):
+    def __init__(self, tp, records, now=None):
         self.max_record_size = 0
-        now = time.time()
+        now = time.time() if now is None else now
         self.created = now
         self.drained = None
         self.attempts = 0
@@ -52,13 +52,18 @@ def __init__(self, tp, records):
     def record_count(self):
         return self.records.next_offset()
 
-    def try_append(self, timestamp_ms, key, value, headers):
+    @property
+    def producer_id(self):
+        return self.records.producer_id if self.records else None
+
+    def try_append(self, timestamp_ms, key, value, headers, now=None):
         metadata = self.records.append(timestamp_ms, key, value, headers)
         if metadata is None:
             return None
 
+        now = time.time() if now is None else now
         self.max_record_size = max(self.max_record_size, metadata.size)
-        self.last_append = time.time()
+        self.last_append = now
         future = FutureRecordMetadata(self.produce_future, metadata.offset,
                                       metadata.timestamp, metadata.crc,
                                       len(key) if key is not None else -1,
@@ -81,7 +86,7 @@ def done(self, base_offset=None, timestamp_ms=None, exception=None, log_start_of
                       log_start_offset, exception)  # trace
             self.produce_future.failure(exception)
 
-    def maybe_expire(self, request_timeout_ms, retry_backoff_ms, linger_ms, is_full):
+    def maybe_expire(self, request_timeout_ms, retry_backoff_ms, linger_ms, is_full, now=None):
         """Expire batches if metadata is not available
 
         A batch whose metadata is not available should be expired if one
@@ -93,7 +98,7 @@ def __init__(self, tp, records):
         * the batch is in retry AND request timeout has elapsed after the
           backoff period ended.
""" - now = time.time() + now = time.time() if now is None else now since_append = now - self.last_append since_ready = now - (self.created + linger_ms / 1000.0) since_backoff = now - (self.last_attempt + retry_backoff_ms / 1000.0) @@ -121,6 +126,10 @@ def in_retry(self): def set_retry(self): self._retry = True + @property + def is_done(self): + return self.produce_future.is_done + def __str__(self): return 'ProducerBatch(topic_partition=%s, record_count=%d)' % ( self.topic_partition, self.records.next_offset()) @@ -161,6 +170,7 @@ class RecordAccumulator(object): 'compression_attrs': 0, 'linger_ms': 0, 'retry_backoff_ms': 100, + 'transaction_state': None, 'message_version': 0, } @@ -171,6 +181,7 @@ def __init__(self, **configs): self.config[key] = configs.pop(key) self._closed = False + self._transaction_state = self.config['transaction_state'] self._flushes_in_progress = AtomicInteger() self._appends_in_progress = AtomicInteger() self._batches = collections.defaultdict(collections.deque) # TopicPartition: [ProducerBatch] @@ -233,6 +244,10 @@ def append(self, tp, timestamp_ms, key, value, headers): batch_is_full = len(dq) > 1 or last.records.is_full() return future, batch_is_full, False + if self._transaction_state and self.config['message_version'] < 2: + raise Errors.UnsupportedVersionError("Attempting to use idempotence with a broker which" + " does not support the required message format (v2)." + " The broker must be version 0.11 or later.") records = MemoryRecordsBuilder( self.config['message_version'], self.config['compression_attrs'], @@ -310,9 +325,9 @@ def abort_expired_batches(self, request_timeout_ms, cluster): return expired_batches - def reenqueue(self, batch): + def reenqueue(self, batch, now=None): """Re-enqueue the given record batch in the accumulator to retry.""" - now = time.time() + now = time.time() if now is None else now batch.attempts += 1 batch.last_attempt = now batch.last_append = now @@ -323,7 +338,7 @@ def reenqueue(self, batch): with self._tp_locks[batch.topic_partition]: dq.appendleft(batch) - def ready(self, cluster): + def ready(self, cluster, now=None): """ Get a list of nodes whose partitions are ready to be sent, and the earliest time at which any non-sendable partition will be ready; @@ -357,7 +372,7 @@ def ready(self, cluster): ready_nodes = set() next_ready_check = 9999999.99 unknown_leaders_exist = False - now = time.time() + now = time.time() if now is None else now # several threads are accessing self._batches -- to simplify # concurrent access, we iterate over a snapshot of partitions @@ -412,7 +427,7 @@ def has_unsent(self): return True return False - def drain(self, cluster, nodes, max_size): + def drain(self, cluster, nodes, max_size, now=None): """ Drain all the data for the given nodes and collate them into a list of batches that will fit within the specified size on a per-node basis. 
@@ -430,7 +445,7 @@ def drain(self, cluster, nodes, max_size):
         if not nodes:
             return {}
 
-        now = time.time()
+        now = time.time() if now is None else now
         batches = {}
         for node_id in nodes:
             size = 0
@@ -463,7 +478,26 @@ def drain(self, cluster, nodes, max_size):
                                     # single request
                                     break
                                 else:
+                                    producer_id_and_epoch = None
+                                    if self._transaction_state:
+                                        producer_id_and_epoch = self._transaction_state.producer_id_and_epoch
+                                        if not producer_id_and_epoch.is_valid:
+                                            # we cannot send the batch until we have refreshed the PID
+                                            log.debug("Waiting to send ready batches because transaction producer id is not valid")
+                                            break
+
                                     batch = dq.popleft()
+                                    if producer_id_and_epoch and not batch.in_retry():
+                                        # If the batch is in retry, then we should not change the pid and
+                                        # sequence number, since this may introduce duplicates. In particular,
+                                        # the previous attempt may actually have been accepted, and if we change
+                                        # the pid and sequence here, this attempt will also be accepted, causing
+                                        # a duplicate.
+                                        sequence_number = self._transaction_state.sequence_number(batch.topic_partition)
+                                        log.debug("Dest: %s: %s producer_id=%s epoch=%s sequence=%s",
+                                                  node_id, batch.topic_partition, producer_id_and_epoch.producer_id, producer_id_and_epoch.epoch,
+                                                  sequence_number)
+                                        batch.records.set_producer_state(producer_id_and_epoch.producer_id, producer_id_and_epoch.epoch, sequence_number)
                                     batch.records.close()
                                     size += batch.records.size_in_bytes()
                                     ready.append(batch)
diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py
index 20af28d07..24b84a9b1 100644
--- a/kafka/producer/sender.py
+++ b/kafka/producer/sender.py
@@ -11,6 +11,7 @@
 from kafka import errors as Errors
 from kafka.metrics.measurable import AnonMeasurable
 from kafka.metrics.stats import Avg, Max, Rate
+from kafka.protocol.init_producer_id import InitProducerIdRequest
 from kafka.protocol.produce import ProduceRequest
 from kafka.structs import TopicPartition
 from kafka.version import __version__
@@ -29,8 +30,12 @@ class Sender(threading.Thread):
         'acks': 1,
         'retries': 0,
         'request_timeout_ms': 30000,
+        'retry_backoff_ms': 100,
         'metrics': None,
         'guarantee_message_order': False,
+        'transaction_state': None,
+        'transactional_id': None,
+        'transaction_timeout_ms': 60000,
         'client_id': 'kafka-python-' + __version__,
     }
 
@@ -52,6 +57,7 @@ def __init__(self, client, metadata, accumulator, **configs):
             self._sensors = SenderMetrics(self.config['metrics'], self._client, self._metadata)
         else:
             self._sensors = None
+        self._transaction_state = self.config['transaction_state']
 
     def run(self):
         """The main run loop for the sender thread."""
@@ -95,6 +101,8 @@ def run_once(self):
         while self._topics_to_add:
             self._client.add_topic(self._topics_to_add.pop())
 
+        self._maybe_wait_for_producer_id()
+
         # get the list of partitions with data ready to send
         result = self._accumulator.ready(self._metadata)
         ready_nodes, next_ready_check_delay, unknown_leaders_exist = result
@@ -128,6 +136,13 @@ def run_once(self):
         expired_batches = self._accumulator.abort_expired_batches(
             self.config['request_timeout_ms'], self._metadata)
 
+        # Reset the producer_id if an expired batch has previously been sent to the broker.
+        # See the documentation of `TransactionState.reset_producer_id` to understand why
+        # we need to reset the producer id here.
+        if self._transaction_state and any([batch.in_retry() for batch in expired_batches]):
+            self._transaction_state.reset_producer_id()
+            return
+
         if self._sensors:
             for expired_batch in expired_batches:
                 self._sensors.record_errors(expired_batch.topic_partition.topic, expired_batch.record_count)
@@ -185,6 +200,41 @@ def add_topic(self, topic):
             self._topics_to_add.add(topic)
             self.wakeup()
 
+    def _maybe_wait_for_producer_id(self):
+        if not self._transaction_state:
+            return
+
+        while not self._transaction_state.has_pid():
+            try:
+                node_id = self._client.least_loaded_node()
+                if node_id is None or not self._client.await_ready(node_id):
+                    log.debug("Could not find an available broker to send InitProducerIdRequest to." +
+                              " Will back off and try again.")
+                    time.sleep(self._client.least_loaded_node_refresh_ms() / 1000)
+                    continue
+                version = self._client.api_version(InitProducerIdRequest, max_version=1)
+                request = InitProducerIdRequest[version](
+                    transactional_id=self.config['transactional_id'],
+                    transaction_timeout_ms=self.config['transaction_timeout_ms'],
+                )
+                response = self._client.send_and_receive(node_id, request)
+                error_type = Errors.for_code(response.error_code)
+                if error_type is Errors.NoError:
+                    self._transaction_state.set_producer_id_and_epoch(response.producer_id, response.producer_epoch)
+                    return
+                elif getattr(error_type, 'retriable', False):
+                    log.debug("Retriable error from InitProducerId response: %s", error_type.__name__)
+                    if getattr(error_type, 'invalid_metadata', False):
+                        self._metadata.request_update()
+                else:
+                    log.error("Received a non-retriable error from InitProducerId response: %s", error_type.__name__)
+                    break
+            except Errors.KafkaConnectionError:
+                log.debug("Broker %s disconnected while awaiting InitProducerId response", node_id)
+            except Errors.RequestTimedOutError:
+                log.debug("InitProducerId request to node %s timed out", node_id)
+            time.sleep(self.config['retry_backoff_ms'] / 1000)
+
     def _failed_produce(self, batches, node_id, error):
         log.error("Error sending produce request to node %d: %s", node_id, error)  # trace
         for batch in batches:
@@ -221,6 +271,17 @@ def _handle_produce_response(self, node_id, send_time, batches, response):
             for batch in batches:
                 self._complete_batch(batch, None, -1)
 
+    def _fail_batch(self, batch, *args, **kwargs):
+        if self._transaction_state and self._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id:
+            # Reset the transaction state since we have hit an irrecoverable exception and cannot make any guarantees
+            # about the previously committed message. Note that this will discard the producer id and sequence
+            # numbers for all existing partitions.
+            self._transaction_state.reset_producer_id()
+        batch.done(*args, **kwargs)
+        self._accumulator.deallocate(batch)
+        if self._sensors:
+            self._sensors.record_errors(batch.topic_partition.topic, batch.record_count)
+
     def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_start_offset=None):
         """Complete or retry the given batch of records.
 
@@ -235,28 +296,55 @@ def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_star
         if error is Errors.NoError:
             error = None
 
-        if error is not None and self._can_retry(batch, error):
-            # retry
-            log.warning("Got error produce response on topic-partition %s,"
-                        " retrying (%d attempts left). Error: %s",
Error: %s", - batch.topic_partition, - self.config['retries'] - batch.attempts - 1, - error) - self._accumulator.reenqueue(batch) - if self._sensors: - self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) - else: - if error is Errors.TopicAuthorizationFailedError: - error = error(batch.topic_partition.topic) + if error is not None: + if self._can_retry(batch, error): + # retry + log.warning("Got error produce response on topic-partition %s," + " retrying (%d attempts left). Error: %s", + batch.topic_partition, + self.config['retries'] - batch.attempts - 1, + error) + + # If idempotence is enabled only retry the request if the current PID is the same as the pid of the batch. + if not self._transaction_state or self._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id: + log.debug("Retrying batch to topic-partition %s. Sequence number: %s", + batch.topic_partition, + self._transaction_state.sequence_number(batch.topic_partition) if self._transaction_state else None) + self._accumulator.reenqueue(batch) + if self._sensors: + self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) + else: + log.warning("Attempted to retry sending a batch but the producer id changed from %s to %s. This batch will be dropped" % ( + batch.producer_id, self._transaction_state.producer_id_and_epoch.producer_id)) + self._fail_batch(batch, base_offset=base_offset, timestamp_ms=timestamp_ms, exception=error, log_start_offset=log_start_offset) + else: + if error is Errors.OutOfOrderSequenceNumberError and batch.producer_id == self._transaction_state.producer_id_and_epoch.producer_id: + log.error("The broker received an out of order sequence number error for produer_id %s, topic-partition %s" + " at offset %s. This indicates data loss on the broker, and should be investigated.", + batch.producer_id, batch.topic_partition, base_offset) + + if error is Errors.TopicAuthorizationFailedError: + error = error(batch.topic_partition.topic) + + # tell the user the result of their request + self._fail_batch(batch, base_offset=base_offset, timestamp_ms=timestamp_ms, exception=error, log_start_offset=log_start_offset) + + if error is Errors.UnknownTopicOrPartitionError: + log.warning("Received unknown topic or partition error in produce request on partition %s." + " The topic/partition may not exist or the user may not have Describe access to it", + batch.topic_partition) + + if getattr(error, 'invalid_metadata', False): + self._metadata.request_update() - # tell the user the result of their request - batch.done(base_offset, timestamp_ms, error, log_start_offset) + else: + batch.done(base_offset=base_offset, timestamp_ms=timestamp_ms, log_start_offset=log_start_offset) self._accumulator.deallocate(batch) - if error is not None and self._sensors: - self._sensors.record_errors(batch.topic_partition.topic, batch.record_count) - if getattr(error, 'invalid_metadata', False): - self._metadata.request_update() + if self._transaction_state and self._transaction_state.producer_id_and_epoch.producer_id == batch.producer_id: + self._transaction_state.increment_sequence_number(batch.topic_partition, batch.record_count) + log.debug("Incremented sequence number for topic-partition %s to %s", batch.topic_partition, + self._transaction_state.sequence_number(batch.topic_partition)) # Unmute the completed partition. 
+        if self.config['guarantee_message_order']:
diff --git a/kafka/producer/transaction_state.py b/kafka/producer/transaction_state.py
new file mode 100644
index 000000000..05cdc5766
--- /dev/null
+++ b/kafka/producer/transaction_state.py
@@ -0,0 +1,96 @@
+from __future__ import absolute_import, division
+
+import collections
+import threading
+import time
+
+from kafka.errors import IllegalStateError
+
+
+NO_PRODUCER_ID = -1
+NO_PRODUCER_EPOCH = -1
+
+
+class ProducerIdAndEpoch(object):
+    __slots__ = ('producer_id', 'epoch')
+
+    def __init__(self, producer_id, epoch):
+        self.producer_id = producer_id
+        self.epoch = epoch
+
+    @property
+    def is_valid(self):
+        return NO_PRODUCER_ID < self.producer_id
+
+    def __str__(self):
+        return "ProducerIdAndEpoch(producer_id={}, epoch={})".format(self.producer_id, self.epoch)
+
+
+class TransactionState(object):
+    __slots__ = ('producer_id_and_epoch', '_sequence_numbers', '_lock')
+
+    def __init__(self):
+        self.producer_id_and_epoch = ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH)
+        self._sequence_numbers = collections.defaultdict(lambda: 0)
+        self._lock = threading.Condition()
+
+    def has_pid(self):
+        return self.producer_id_and_epoch.is_valid
+
+    def await_producer_id_and_epoch(self, max_wait_time_ms):
+        """
+        A blocking call to get the pid and epoch for the producer. If the PID and epoch have not been set, this
+        method will block for at most max_wait_time_ms. It is expected that this method be called from application
+        thread contexts (i.e., through Producer.send). The PID itself will be retrieved in the background thread.
+
+        Arguments:
+            max_wait_time_ms (numeric): The maximum time to block.
+
+        Returns:
+            ProducerIdAndEpoch object. Callers must check the 'is_valid' property of the returned object to ensure
+            that a valid pid and epoch is actually returned.
+        """
+        with self._lock:
+            start = time.time()
+            elapsed_ms = 0
+            while not self.has_pid() and elapsed_ms < max_wait_time_ms:
+                self._lock.wait((max_wait_time_ms - elapsed_ms) / 1000)
+                elapsed_ms = (time.time() - start) * 1000
+            return self.producer_id_and_epoch
+
+    def set_producer_id_and_epoch(self, producer_id, epoch):
+        """
+        Set the pid and epoch atomically. This method will signal any callers blocked in `await_producer_id_and_epoch`
+        once the pid is set. This method will be called on the background thread when the broker responds with the pid.
+        """
+        with self._lock:
+            self.producer_id_and_epoch = ProducerIdAndEpoch(producer_id, epoch)
+            if self.producer_id_and_epoch.is_valid:
+                self._lock.notify_all()
+
+    def reset_producer_id(self):
+        """
+        This method is used when the producer needs to reset its internal state because of an irrecoverable exception
+        from the broker.
+
+        We need to reset the producer id and associated state when we have sent a batch to the broker, but we either get
+        a non-retriable exception or we run out of retries, or the batch expired in the producer queue after it was
+        already sent to the broker.
+
+        In all of these cases, we don't know whether the batch was actually committed on the broker, and hence whether
+        the sequence number was actually updated. If we don't reset the producer state, we risk that all future
+        sends will fail with an OutOfOrderSequenceNumberError.
+ """ + with self._lock: + self.producer_id_and_epoch = ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH) + self._sequence_numbers.clear() + + def sequence_number(self, tp): + with self._lock: + return self._sequence_numbers[tp] + + def increment_sequence_number(self, tp, increment): + with self._lock: + if tp not in self._sequence_numbers: + raise IllegalStateError("Attempt to increment sequence number for a partition with no current sequence.") + self._sequence_numbers[tp] += increment diff --git a/kafka/protocol/init_producer_id.py b/kafka/protocol/init_producer_id.py new file mode 100644 index 000000000..8426fe00b --- /dev/null +++ b/kafka/protocol/init_producer_id.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.types import Int16, Int32, Int64, Schema, String + + +class InitProducerIdResponse_v0(Response): + API_KEY = 22 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('producer_id', Int64), + ('producer_epoch', Int16), + ) + + +class InitProducerIdResponse_v1(Response): + API_KEY = 22 + API_VERSION = 1 + SCHEMA = InitProducerIdResponse_v0.SCHEMA + + +class InitProducerIdRequest_v0(Request): + API_KEY = 22 + API_VERSION = 0 + RESPONSE_TYPE = InitProducerIdResponse_v0 + SCHEMA = Schema( + ('transactional_id', String('utf-8')), + ('transaction_timeout_ms', Int32), + ) + + +class InitProducerIdRequest_v1(Request): + API_KEY = 22 + API_VERSION = 1 + RESPONSE_TYPE = InitProducerIdResponse_v1 + SCHEMA = InitProducerIdRequest_v0.SCHEMA + + +InitProducerIdRequest = [ + InitProducerIdRequest_v0, InitProducerIdRequest_v1, +] +InitProducerIdResponse = [ + InitProducerIdResponse_v0, InitProducerIdResponse_v1, +] diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py index 0d69d72a2..855306bbd 100644 --- a/kafka/record/default_records.py +++ b/kafka/record/default_records.py @@ -448,6 +448,15 @@ def __init__( self._buffer = bytearray(self.HEADER_STRUCT.size) + def set_producer_state(self, producer_id, producer_epoch, base_sequence): + self._producer_id = producer_id + self._producer_epoch = producer_epoch + self._base_sequence = base_sequence + + @property + def producer_id(self): + return self._producer_id + def _get_attributes(self, include_compression_type=True): attrs = 0 if include_compression_type: diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py index 72baea547..a803047ea 100644 --- a/kafka/record/memory_records.py +++ b/kafka/record/memory_records.py @@ -22,7 +22,7 @@ import struct -from kafka.errors import CorruptRecordException +from kafka.errors import CorruptRecordException, IllegalStateError, UnsupportedVersionError from kafka.record.abc import ABCRecords from kafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder from kafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder @@ -113,7 +113,7 @@ def next_batch(self, _min_slice=MIN_SLICE, class MemoryRecordsBuilder(object): __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", - "_bytes_written") + "_magic", "_bytes_written", "_producer_id") def __init__(self, magic, compression_type, batch_size, offset=0): assert magic in [0, 1, 2], "Not supported magic" @@ -123,15 +123,18 @@ def __init__(self, magic, compression_type, batch_size, offset=0): magic=magic, compression_type=compression_type, is_transactional=False, producer_id=-1, producer_epoch=-1, base_sequence=-1, 
+            self._producer_id = -1
         else:
             self._builder = LegacyRecordBatchBuilder(
                 magic=magic, compression_type=compression_type,
                 batch_size=batch_size)
+            self._producer_id = None
         self._batch_size = batch_size
         self._buffer = None
 
         self._next_offset = offset
         self._closed = False
+        self._magic = magic
         self._bytes_written = 0
 
     def skip(self, offsets_to_skip):
@@ -155,6 +158,24 @@ def append(self, timestamp, key, value, headers=[]):
         self._next_offset += 1
         return metadata
 
+    def set_producer_state(self, producer_id, producer_epoch, base_sequence):
+        if self._magic < 2:
+            raise UnsupportedVersionError('Producer State requires Message format v2+')
+        elif self._closed:
+            # Sequence numbers are assigned when the batch is closed while the accumulator is being drained.
+            # If the resulting ProduceRequest to the partition leader failed for a retriable error, the batch
+            # will be re-queued. In this case, we should not attempt to set the state again, since changing
+            # the pid and sequence once a batch has been sent to the broker risks introducing duplicates.
+            raise IllegalStateError("Trying to set producer state of an already closed batch. This indicates a bug on the client.")
+        self._builder.set_producer_state(producer_id, producer_epoch, base_sequence)
+        self._producer_id = producer_id
+
+    @property
+    def producer_id(self):
+        if self._magic < 2:
+            raise UnsupportedVersionError('Producer State requires Message format v2+')
+        return self._producer_id
+
     def close(self):
         # This method may be called multiple times on the same batch
         # i.e., on retries
@@ -164,6 +185,8 @@ def close(self):
         if not self._closed:
             self._bytes_written = self._builder.size()
             self._buffer = bytes(self._builder.build())
+            if self._magic == 2:
+                self._producer_id = self._builder.producer_id
             self._builder = None
             self._closed = True
diff --git a/test/test_producer.py b/test/test_producer.py
index 069362f26..303832b9f 100644
--- a/test/test_producer.py
+++ b/test/test_producer.py
@@ -100,7 +100,7 @@ def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
                        retries=5,
                        max_block_ms=30000,
                        compression_type=compression) as producer:
-        magic = producer._max_usable_produce_magic()
+        magic = producer.max_usable_produce_magic(producer.config['api_version'])
 
     # record headers are supported in 0.11.0
     if env_kafka_version() < (0, 11, 0):
diff --git a/test/test_record_accumulator.py b/test/test_record_accumulator.py
new file mode 100644
index 000000000..babff5617
--- /dev/null
+++ b/test/test_record_accumulator.py
@@ -0,0 +1,75 @@
+# pylint: skip-file
+from __future__ import absolute_import
+
+import pytest
+import io
+
+from kafka.errors import KafkaTimeoutError
+from kafka.producer.future import FutureRecordMetadata, RecordMetadata
+from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch
+from kafka.record.memory_records import MemoryRecordsBuilder
+from kafka.structs import TopicPartition
+
+
+def test_producer_batch_producer_id():
+    tp = TopicPartition('foo', 0)
+    records = MemoryRecordsBuilder(
+        magic=2, compression_type=0, batch_size=100000)
+    batch = ProducerBatch(tp, records)
+    assert batch.producer_id == -1
+    batch.records.set_producer_state(123, 456, 789)
+    assert batch.producer_id == 123
+    records.close()
+    assert batch.producer_id == 123
+
+@pytest.mark.parametrize("magic", [0, 1, 2])
+def test_producer_batch_try_append(magic):
+    tp = TopicPartition('foo', 0)
+    records = MemoryRecordsBuilder(
+        magic=magic, compression_type=0, batch_size=100000)
+    batch = ProducerBatch(tp, records)
+    assert batch.record_count == 0
+    future = batch.try_append(0, b'key', b'value', [])
+    assert isinstance(future, FutureRecordMetadata)
+    assert not future.is_done
+    batch.done(base_offset=123, timestamp_ms=456, log_start_offset=0)
+    assert future.is_done
+    # a record-level checksum is only provided in the v0/v1 formats; the
+    # checksummed payload includes the magic byte, so the value depends on magic
+    if magic == 0:
+        checksum = 592888119
+    elif magic == 1:
+        checksum = 213653215
+    else:
+        checksum = None
+
+    expected_metadata = RecordMetadata(
+        topic=tp[0], partition=tp[1], topic_partition=tp,
+        offset=123, timestamp=456, log_start_offset=0,
+        checksum=checksum,
+        serialized_key_size=3, serialized_value_size=5, serialized_header_size=-1)
+    assert future.value == expected_metadata
+
+def test_producer_batch_retry():
+    tp = TopicPartition('foo', 0)
+    records = MemoryRecordsBuilder(
+        magic=2, compression_type=0, batch_size=100000)
+    batch = ProducerBatch(tp, records)
+    assert not batch.in_retry()
+    batch.set_retry()
+    assert batch.in_retry()
+
+def test_producer_batch_maybe_expire():
+    tp = TopicPartition('foo', 0)
+    records = MemoryRecordsBuilder(
+        magic=2, compression_type=0, batch_size=100000)
+    batch = ProducerBatch(tp, records, now=1)
+    future = batch.try_append(0, b'key', b'value', [], now=2)
+    request_timeout_ms = 5000
+    retry_backoff_ms = 200
+    linger_ms = 1000
+    is_full = True
+    batch.maybe_expire(request_timeout_ms, retry_backoff_ms, linger_ms, is_full, now=20)
+    assert batch.is_done
+    assert future.is_done
+    assert future.failed()
+    assert isinstance(future.exception, KafkaTimeoutError)
diff --git a/test/test_sender.py b/test/test_sender.py
index b037d2b48..eedc43d25 100644
--- a/test/test_sender.py
+++ b/test/test_sender.py
@@ -6,6 +6,7 @@
 
 from kafka.client_async import KafkaClient
 from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS
+from kafka.producer.kafka import KafkaProducer
 from kafka.protocol.produce import ProduceRequest
 from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch
 from kafka.producer.sender import Sender
@@ -24,6 +25,7 @@ def sender(client, accumulator, metrics, mocker):
 
 @pytest.mark.parametrize(("api_version", "produce_version"), [
+    ((2, 1), 7),
     ((0, 10, 0), 2),
     ((0, 9), 1),
     ((0, 8, 0), 0)
@@ -31,6 +33,7 @@ def sender(client, accumulator, metrics, mocker):
 def test_produce_request(sender, mocker, api_version, produce_version):
     sender._client._api_versions = BROKER_API_VERSIONS[api_version]
     tp = TopicPartition('foo', 0)
+    magic = KafkaProducer.max_usable_produce_magic(api_version)
     records = MemoryRecordsBuilder(
-        magic=1, compression_type=0, batch_size=100000)
+        magic=magic, compression_type=0, batch_size=100000)
     batch = ProducerBatch(tp, records)
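Review note: end to end, the feature in this patch is switched on with a single config flag; the sender thread then obtains a producer id via `InitProducerIdRequest` and stamps each batch with sequence numbers as shown above. A minimal usage sketch, assuming a 0.11+ broker; the bootstrap server and topic name below are placeholders:

```python
from kafka import KafkaProducer

# enable_idempotence=True defaults 'retries' to 3, 'acks' to 'all' (-1), and
# 'max_in_flight_requests_per_connection' to 1 unless they are set explicitly;
# explicitly incompatible values raise KafkaConfigurationError.
producer = KafkaProducer(
    bootstrap_servers='localhost:9092',  # placeholder
    enable_idempotence=True,
)
future = producer.send('my-topic', b'payload')  # placeholder topic
metadata = future.get(timeout=10)
print(metadata.topic, metadata.partition, metadata.offset)
producer.close()
```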