aboutsummaryrefslogtreecommitdiffstats
path: root/fatcat_scholar/kafka.py
diff options
context:
space:
mode:
authorBryan Newbold <bnewbold@archive.org>2020-10-16 19:59:14 -0700
committerBryan Newbold <bnewbold@archive.org>2020-10-16 19:59:14 -0700
commitc284c0f3aad801dbd8b477fb62b50ab63c990890 (patch)
tree440ef06a88f1468da7931af1dd21cbae9fb22619 /fatcat_scholar/kafka.py
parentda121979c8481a5e1f6cf103e2d77363b31018c9 (diff)
downloadfatcat-scholar-c284c0f3aad801dbd8b477fb62b50ab63c990890.tar.gz
fatcat-scholar-c284c0f3aad801dbd8b477fb62b50ab63c990890.zip
first implementation of Kafka worker base class
Diffstat (limited to 'fatcat_scholar/kafka.py')
-rw-r--r--fatcat_scholar/kafka.py219
1 files changed, 219 insertions, 0 deletions
diff --git a/fatcat_scholar/kafka.py b/fatcat_scholar/kafka.py
new file mode 100644
index 0000000..5faff9a
--- /dev/null
+++ b/fatcat_scholar/kafka.py
@@ -0,0 +1,219 @@
+import sys
+import json
+import signal
+from collections import Counter
+from typing import List, Any
+
+from confluent_kafka import Consumer, Producer, KafkaException
+
+
+class KafkaWorker(object):
+ """
+ Base class for Scholar workers which consume from Kafka topics.
+
+ Configuration (passed to __init__):
+
+ kafka_brokers (List[str]): brokers to connect to
+
+ consume_topics (List[str]): topics to consume from
+
+ consumer_group (str): kafka consumer group
+
+ batch_size (int): number of records to consume and process at a time
+
+ batch_timeout_sec (int): max seconds for each batch to process. set to 0 to disable
+
+ API:
+
+ __init__()
+
+ run()
+ starts consuming, calling process_batch() for each message batch
+
+ process_batch(batch: List[dict]) -> None
+ implemented by sub-class
+
+ process_msg(msg: dict) -> None
+ implemented by sub-class
+
+ Example of producing (in a worker):
+
+ producer = self.create_kafka_producer(...)
+
+ producer.produce(
+ topic,
+ some_obj.json(exclude_none=True).encode('UTF-8'),
+ key=key,
+ on_delivery=self._fail_fast_produce)
+
+ # check for errors etc
+ producer.poll(0)
+ """
+
+ def __init__(
+ self,
+ kafka_brokers: List[str],
+ consume_topics: List[str],
+ consumer_group: str,
+ **kwargs: Any,
+ ):
+ self.counts: Counter = Counter()
+ self.kafka_brokers = kafka_brokers
+ self.batch_size = kwargs.get("batch_size", 1)
+ self.batch_timeout_sec = kwargs.get("batch_timeout_sec", 30)
+ self.poll_interval_sec = kwargs.get("poll_interval_sec", 5.0)
+ self.consumer = self.create_kafka_consumer(
+ kafka_brokers, consume_topics, consumer_group
+ )
+
+ @staticmethod
+ def _fail_fast_produce(err: Any, msg: Any) -> None:
+ if err is not None:
+ print("Kafka producer delivery error: {}".format(err), file=sys.stderr)
+ raise KafkaException(err)
+
+ @staticmethod
+ def _timeout_handler(signum: Any, frame: Any) -> None:
+ raise TimeoutError("timeout processing record")
+
+ @staticmethod
+ def create_kafka_consumer(
+ kafka_brokers: List[str], consume_topics: List[str], consumer_group: str
+ ) -> Consumer:
+ """
+ NOTE: it is important that consume_topics be str, *not* bytes
+ """
+
+ def _on_rebalance(consumer: Any, partitions: Any) -> None:
+
+ for p in partitions:
+ if p.error:
+ raise KafkaException(p.error)
+
+ print(
+ f"Kafka partitions rebalanced: {consumer} / {partitions}",
+ file=sys.stderr,
+ )
+
+ def _fail_fast_consume(err: Any, partitions: Any) -> None:
+ if err is not None:
+ print("Kafka consumer commit error: {}".format(err), file=sys.stderr)
+ raise KafkaException(err)
+ for p in partitions:
+ # check for partition-specific commit errors
+ if p.error:
+ print(
+ "Kafka consumer commit error: {}".format(p.error),
+ file=sys.stderr,
+ )
+ raise KafkaException(p.error)
+
+ config = {
+ "bootstrap.servers": ",".join(kafka_brokers),
+ "group.id": consumer_group,
+ "on_commit": _fail_fast_consume,
+ # messages don't have offset marked as stored until processed,
+ # but we do auto-commit stored offsets to broker
+ "enable.auto.offset.store": False,
+ "enable.auto.commit": True,
+ # user code timeout; if no poll after this long, assume user code
+ # hung and rebalance (default: 6min)
+ "max.poll.interval.ms": 360000,
+ "default.topic.config": {"auto.offset.reset": "latest",},
+ }
+
+ consumer = Consumer(config)
+ consumer.subscribe(
+ consume_topics, on_assign=_on_rebalance, on_revoke=_on_rebalance,
+ )
+ print(
+ f"Consuming from kafka topics {consume_topics}, group {consumer_group}",
+ file=sys.stderr,
+ )
+ return consumer
+
+ @staticmethod
+ def create_kafka_producer(kafka_brokers: List[str]) -> Producer:
+ """
+ This configuration is for large compressed messages.
+ """
+
+ config = {
+ "bootstrap.servers": ",".join(kafka_brokers),
+ "message.max.bytes": 30000000, # ~30 MBytes; broker is ~50 MBytes
+ "api.version.request": True,
+ "api.version.fallback.ms": 0,
+ "compression.codec": "gzip",
+ "retry.backoff.ms": 250,
+ "linger.ms": 1000,
+ "batch.num.messages": 50,
+ "delivery.report.only.error": True,
+ "default.topic.config": {
+ "message.timeout.ms": 30000,
+ "request.required.acks": -1, # all brokers must confirm
+ },
+ }
+ return Producer(config)
+
+ def run(self) -> Counter:
+
+ if self.batch_timeout_sec:
+ signal.signal(signal.SIGALRM, self._timeout_handler)
+
+ while True:
+ batch = self.consumer.consume(
+ num_messages=self.batch_size, timeout=self.poll_interval_sec,
+ )
+
+ print(
+ f"... got {len(batch)} kafka messages ({self.poll_interval_sec}sec poll interval). stats: {self.counts}",
+ file=sys.stderr,
+ )
+
+ if not batch:
+ continue
+
+ # first check errors on entire batch...
+ for msg in batch:
+ if msg.error():
+ raise KafkaException(msg.error())
+
+ # ... then process, with optional timeout
+ self.counts["total"] += len(batch)
+ records = [json.loads(msg.value().decode("utf-8")) for msg in batch]
+
+ if self.batch_timeout_sec:
+ signal.alarm(int(self.batch_timeout_sec))
+ try:
+ self.process_batch(records)
+ except TimeoutError as te:
+ raise te
+ finally:
+ signal.alarm(0)
+ else:
+ self.process_batch(records)
+
+ self.counts["processed"] += len(batch)
+
+ # ... then record progress
+ for msg in batch:
+ # will be auto-commited by librdkafka from this "stored" value
+ self.consumer.store_offsets(message=msg)
+
+ # Note: never actually get here, but including as documentation on how to clean up
+ self.consumer.close()
+ return self.counts
+
+ def process_batch(self, batch: List[dict]) -> None:
+ """
+ Workers can override this method for batch processing. By default it
+ calls process_msg() for each message in the batch.
+ """
+ for msg in batch:
+ self.process_msg(msg)
+
+ def process_msg(self, msg: dict) -> None:
+ """
+ Workers can override this method for individual record processing.
+ """
+ raise NotImplementedError("implementation required")