diff options
author | Bryan Newbold <bnewbold@archive.org> | 2021-10-26 17:10:10 -0700 |
---|---|---|
committer | Bryan Newbold <bnewbold@archive.org> | 2021-10-26 17:10:11 -0700 |
commit | b081969a52df631d9ebdc91740322f75246957ee (patch) | |
tree | 04bf4326878cb6049d120b36a846b6bc35600eb0 | |
parent | 4a46f166f8514b5620d2bcb13a5c5f3e6cee66c8 (diff) | |
download | sandcrawler-b081969a52df631d9ebdc91740322f75246957ee.tar.gz sandcrawler-b081969a52df631d9ebdc91740322f75246957ee.zip |
type annotations on SandcrawlerWorker
These annoations have a broad impact! Being conservative to start:
Any-to-Any for process(), etc.
-rw-r--r-- | python/sandcrawler/workers.py | 103 |
1 files changed, 57 insertions, 46 deletions
diff --git a/python/sandcrawler/workers.py b/python/sandcrawler/workers.py index 6b08f03..1b132ed 100644 --- a/python/sandcrawler/workers.py +++ b/python/sandcrawler/workers.py @@ -5,11 +5,13 @@ import sys import time import zipfile from collections import Counter +from typing import Any, Dict, List, Optional, Sequence import requests from confluent_kafka import Consumer, KafkaException, Producer -from .ia import PetaboxError, SandcrawlerBackoffError, WaybackContentError, WaybackError +from .ia import (PetaboxError, SandcrawlerBackoffError, WaybackClient, WaybackContentError, + WaybackError) from .misc import parse_cdx_line @@ -20,12 +22,11 @@ class SandcrawlerWorker(object): Usually these get "pushed" into by a RecordPusher. Output goes to another worker (pipeline-style), or defaults to stdout. """ - def __init__(self): - self.counts = Counter() - self.sink = None - # TODO: self.counters + def __init__(self, sink: Optional['SandcrawlerWorker'] = None): + self.counts: Counter = Counter() + self.sink: Optional[SandcrawlerWorker] = sink - def push_record(self, task, key=None): + def push_record(self, task: Any, key: Optional[str] = None) -> Any: self.counts['total'] += 1 if not self.want(task): self.counts['skip'] += 1 @@ -44,7 +45,7 @@ class SandcrawlerWorker(object): print(json.dumps(result)) return result - def timeout_response(self, task): + def timeout_response(self, task: Any) -> Any: """ This should be overridden by workers that want to return something meaningful when there is a processing timeout. Eg, JSON vs some other @@ -52,7 +53,10 @@ class SandcrawlerWorker(object): """ return None - def push_record_timeout(self, task, key=None, timeout=300): + def push_record_timeout(self, + task: Any, + key: Optional[str] = None, + timeout: int = 300) -> Any: """ A wrapper around self.push_record which sets a timeout. @@ -60,7 +64,7 @@ class SandcrawlerWorker(object): multithreading or if signal-based timeouts are used elsewhere in the same process. """ - def timeout_handler(signum, frame): + def timeout_handler(signum: int, frame: Any) -> None: raise TimeoutError("timeout processing record") signal.signal(signal.SIGALRM, timeout_handler) @@ -81,27 +85,29 @@ class SandcrawlerWorker(object): signal.alarm(0) return resp - def push_batch(self, tasks): + def push_batch(self, tasks: List[Any]) -> List[Any]: results = [] for task in tasks: results.append(self.push_record(task)) return results - def finish(self): + def finish(self) -> Counter: if self.sink: self.sink.finish() print("Worker: {}".format(self.counts), file=sys.stderr) return self.counts - def want(self, task): + def want(self, task: Any) -> bool: """ Optionally override this as a filter in implementations. """ return True - def process(self, task, key=None): + def process(self, task: Any, key: str = None) -> Any: """ Derived workers need to implement business logic here. + + TODO: should derived workers explicitly type-check the 'task' object? """ raise NotImplementedError('implementation required') @@ -111,11 +117,11 @@ class SandcrawlerFetchWorker(SandcrawlerWorker): Wrapper of SandcrawlerWorker that adds a helper method to fetch blobs (eg, PDFs) from wayback, archive.org, or other sources. """ - def __init__(self, wayback_client, **kwargs): + def __init__(self, wayback_client: WaybackClient, **kwargs): super().__init__(**kwargs) self.wayback_client = wayback_client - def fetch_blob(self, record): + def fetch_blob(self, record: Dict[str, Any]) -> Dict[str, Any]: default_key = record['sha1hex'] wayback_sec = None petabox_sec = None @@ -123,7 +129,7 @@ class SandcrawlerFetchWorker(SandcrawlerWorker): if record.get('warc_path') and record.get('warc_offset'): # it's a full CDX dict. fetch using WaybackClient if not self.wayback_client: - raise Exception("wayback client not configured for this PdfTrioWorker") + raise Exception("wayback client not configured for this SandcrawlerFetchWorker") try: start = time.time() blob = self.wayback_client.fetch_petabox_body( @@ -142,7 +148,7 @@ class SandcrawlerFetchWorker(SandcrawlerWorker): elif record.get('url') and record.get('datetime'): # it's a partial CDX dict or something? fetch using WaybackClient if not self.wayback_client: - raise Exception("wayback client not configured for this PdfTrioWorker") + raise Exception("wayback client not configured for this SandcrawlerFetchWorker") try: start = time.time() blob = self.wayback_client.fetch_replay_body( @@ -195,20 +201,23 @@ class SandcrawlerFetchWorker(SandcrawlerWorker): class MultiprocessWrapper(SandcrawlerWorker): - def __init__(self, worker, sink, jobs=None): + def __init__(self, + worker: SandcrawlerWorker, + sink: Optional[SandcrawlerWorker] = None, + jobs: Optional[int] = None): self.counts = Counter() self.worker = worker self.sink = sink self.pool = multiprocessing.pool.Pool(jobs) - def push_batch(self, tasks): + def push_batch(self, tasks: List[Any]) -> List[Any]: self.counts['total'] += len(tasks) print("... processing batch of: {}".format(len(tasks)), file=sys.stderr) results = self.pool.map(self.worker.process, tasks) for result in results: if not result: self.counts['failed'] += 1 - return + return [] elif type(result) == dict and 'status' in result and len(result['status']) < 32: self.counts[result['status']] += 1 @@ -219,7 +228,7 @@ class MultiprocessWrapper(SandcrawlerWorker): print(json.dumps(result)) return results - def finish(self): + def finish(self) -> Counter: self.pool.terminate() if self.sink: self.sink.finish() @@ -234,15 +243,15 @@ class BlackholeSink(SandcrawlerWorker): Useful for tests. """ - def push_record(self, task, key=None): + def push_record(self, task: Any, key: Optional[str] = None) -> Any: return - def push_batch(self, tasks): - return + def push_batch(self, tasks: List[Any]) -> List[Any]: + return [] class KafkaSink(SandcrawlerWorker): - def __init__(self, kafka_hosts, produce_topic, **kwargs): + def __init__(self, kafka_hosts: str, produce_topic: str, **kwargs): self.sink = None self.counts = Counter() self.produce_topic = produce_topic @@ -257,14 +266,14 @@ class KafkaSink(SandcrawlerWorker): self.producer = Producer(config) @staticmethod - def _fail_fast(err, msg): + def _fail_fast(err: Any, msg: Any) -> None: if err is not None: print("Kafka producer delivery error: {}".format(err), file=sys.stderr) print("Bailing out...", file=sys.stderr) # TODO: should it be sys.exit(-1)? raise KafkaException(err) - def producer_config(self, kafka_config): + def producer_config(self, kafka_config: dict) -> dict: config = kafka_config.copy() config.update({ 'delivery.report.only.error': True, @@ -275,7 +284,7 @@ class KafkaSink(SandcrawlerWorker): }) return config - def push_record(self, msg, key=None): + def push_record(self, msg: Any, key: Optional[str] = None) -> Any: self.counts['total'] += 1 if type(msg) == dict: if not key and 'key' in msg: @@ -291,11 +300,12 @@ class KafkaSink(SandcrawlerWorker): # check for errors etc self.producer.poll(0) - def push_batch(self, msgs): + def push_batch(self, msgs: List[Any]) -> List[Any]: for m in msgs: self.push_record(m) + return [] - def finish(self): + def finish(self) -> Counter: self.producer.flush() return self.counts @@ -304,7 +314,7 @@ class KafkaCompressSink(KafkaSink): """ Variant of KafkaSink for large documents. Used for, eg, GROBID output. """ - def producer_config(self, kafka_config): + def producer_config(self, kafka_config: Dict[str, Any]) -> Dict[str, Any]: config = kafka_config.copy() config.update({ 'compression.codec': 'gzip', @@ -325,11 +335,11 @@ class RecordPusher: Base class for different record sources to be pushed into workers. Pretty trivial interface, just wraps an importer and pushes records in to it. """ - def __init__(self, worker, **kwargs): - self.counts = Counter() - self.worker = worker + def __init__(self, worker: SandcrawlerWorker, **kwargs): + self.counts: Counter = Counter() + self.worker: SandcrawlerWorker = worker - def run(self): + def run(self) -> Counter: """ This will look something like: @@ -342,7 +352,7 @@ class RecordPusher: class JsonLinePusher(RecordPusher): - def __init__(self, worker, json_file, **kwargs): + def __init__(self, worker: SandcrawlerWorker, json_file: Sequence, **kwargs): self.counts = Counter() self.worker = worker self.json_file = json_file @@ -350,7 +360,7 @@ class JsonLinePusher(RecordPusher): if self.batch_size in (0, 1): self.batch_size = None - def run(self): + def run(self) -> Counter: batch = [] for line in self.json_file: if not line: @@ -380,7 +390,7 @@ class JsonLinePusher(RecordPusher): class CdxLinePusher(RecordPusher): - def __init__(self, worker, cdx_file, **kwargs): + def __init__(self, worker: SandcrawlerWorker, cdx_file: Sequence, **kwargs): self.counts = Counter() self.worker = worker self.cdx_file = cdx_file @@ -391,7 +401,7 @@ class CdxLinePusher(RecordPusher): if self.batch_size in (0, 1): self.batch_size = None - def run(self): + def run(self) -> Counter: batch = [] for line in self.cdx_file: if not line: @@ -427,7 +437,7 @@ class CdxLinePusher(RecordPusher): class ZipfilePusher(RecordPusher): - def __init__(self, worker, zipfile_path, **kwargs): + def __init__(self, worker: SandcrawlerWorker, zipfile_path: str, **kwargs): self.counts = Counter() self.worker = worker self.filter_suffix = ".pdf" @@ -436,7 +446,7 @@ class ZipfilePusher(RecordPusher): if self.batch_size in (0, 1): self.batch_size = None - def run(self): + def run(self) -> Counter: batch = [] with zipfile.ZipFile(self.zipfile_path, 'r') as archive: for zipinfo in archive.infolist(): @@ -466,7 +476,8 @@ class ZipfilePusher(RecordPusher): class KafkaJsonPusher(RecordPusher): - def __init__(self, worker, kafka_hosts, consume_topic, group, **kwargs): + def __init__(self, worker: SandcrawlerWorker, kafka_hosts: str, consume_topic: str, + group: str, **kwargs): self.counts = Counter() self.worker = worker self.consumer = make_kafka_consumer( @@ -483,7 +494,7 @@ class KafkaJsonPusher(RecordPusher): self.batch_worker = kwargs.get('batch_worker', False) self.process_timeout_sec = kwargs.get('process_timeout_sec', 300) - def run(self): + def run(self) -> Counter: while True: # TODO: this is batch-oriented, because underlying worker is # often batch-oriented, but this doesn't confirm that entire batch @@ -562,10 +573,10 @@ class KafkaJsonPusher(RecordPusher): return self.counts -def make_kafka_consumer(hosts, consume_topic, group): +def make_kafka_consumer(hosts: str, consume_topic: str, group: str) -> Consumer: topic_name = consume_topic - def fail_fast(err, partitions): + def fail_fast(err: Any, partitions: List[Any]) -> None: if err is not None: print("Kafka consumer commit error: {}".format(err), file=sys.stderr) print("Bailing out...", file=sys.stderr) @@ -600,7 +611,7 @@ def make_kafka_consumer(hosts, consume_topic, group): }, } - def on_rebalance(consumer, partitions): + def on_rebalance(consumer: Any, partitions: List[Any]) -> None: for p in partitions: if p.error: raise KafkaException(p.error) |