aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBryan Newbold <bnewbold@archive.org>2021-10-26 17:10:10 -0700
committerBryan Newbold <bnewbold@archive.org>2021-10-26 17:10:11 -0700
commitb081969a52df631d9ebdc91740322f75246957ee (patch)
tree04bf4326878cb6049d120b36a846b6bc35600eb0
parent4a46f166f8514b5620d2bcb13a5c5f3e6cee66c8 (diff)
downloadsandcrawler-b081969a52df631d9ebdc91740322f75246957ee.tar.gz
sandcrawler-b081969a52df631d9ebdc91740322f75246957ee.zip
type annotations on SandcrawlerWorker
These annoations have a broad impact! Being conservative to start: Any-to-Any for process(), etc.
-rw-r--r--python/sandcrawler/workers.py103
1 files changed, 57 insertions, 46 deletions
diff --git a/python/sandcrawler/workers.py b/python/sandcrawler/workers.py
index 6b08f03..1b132ed 100644
--- a/python/sandcrawler/workers.py
+++ b/python/sandcrawler/workers.py
@@ -5,11 +5,13 @@ import sys
import time
import zipfile
from collections import Counter
+from typing import Any, Dict, List, Optional, Sequence
import requests
from confluent_kafka import Consumer, KafkaException, Producer
-from .ia import PetaboxError, SandcrawlerBackoffError, WaybackContentError, WaybackError
+from .ia import (PetaboxError, SandcrawlerBackoffError, WaybackClient, WaybackContentError,
+ WaybackError)
from .misc import parse_cdx_line
@@ -20,12 +22,11 @@ class SandcrawlerWorker(object):
Usually these get "pushed" into by a RecordPusher. Output goes to another
worker (pipeline-style), or defaults to stdout.
"""
- def __init__(self):
- self.counts = Counter()
- self.sink = None
- # TODO: self.counters
+ def __init__(self, sink: Optional['SandcrawlerWorker'] = None):
+ self.counts: Counter = Counter()
+ self.sink: Optional[SandcrawlerWorker] = sink
- def push_record(self, task, key=None):
+ def push_record(self, task: Any, key: Optional[str] = None) -> Any:
self.counts['total'] += 1
if not self.want(task):
self.counts['skip'] += 1
@@ -44,7 +45,7 @@ class SandcrawlerWorker(object):
print(json.dumps(result))
return result
- def timeout_response(self, task):
+ def timeout_response(self, task: Any) -> Any:
"""
This should be overridden by workers that want to return something
meaningful when there is a processing timeout. Eg, JSON vs some other
@@ -52,7 +53,10 @@ class SandcrawlerWorker(object):
"""
return None
- def push_record_timeout(self, task, key=None, timeout=300):
+ def push_record_timeout(self,
+ task: Any,
+ key: Optional[str] = None,
+ timeout: int = 300) -> Any:
"""
A wrapper around self.push_record which sets a timeout.
@@ -60,7 +64,7 @@ class SandcrawlerWorker(object):
multithreading or if signal-based timeouts are used elsewhere in the
same process.
"""
- def timeout_handler(signum, frame):
+ def timeout_handler(signum: int, frame: Any) -> None:
raise TimeoutError("timeout processing record")
signal.signal(signal.SIGALRM, timeout_handler)
@@ -81,27 +85,29 @@ class SandcrawlerWorker(object):
signal.alarm(0)
return resp
- def push_batch(self, tasks):
+ def push_batch(self, tasks: List[Any]) -> List[Any]:
results = []
for task in tasks:
results.append(self.push_record(task))
return results
- def finish(self):
+ def finish(self) -> Counter:
if self.sink:
self.sink.finish()
print("Worker: {}".format(self.counts), file=sys.stderr)
return self.counts
- def want(self, task):
+ def want(self, task: Any) -> bool:
"""
Optionally override this as a filter in implementations.
"""
return True
- def process(self, task, key=None):
+ def process(self, task: Any, key: str = None) -> Any:
"""
Derived workers need to implement business logic here.
+
+ TODO: should derived workers explicitly type-check the 'task' object?
"""
raise NotImplementedError('implementation required')
@@ -111,11 +117,11 @@ class SandcrawlerFetchWorker(SandcrawlerWorker):
Wrapper of SandcrawlerWorker that adds a helper method to fetch blobs (eg,
PDFs) from wayback, archive.org, or other sources.
"""
- def __init__(self, wayback_client, **kwargs):
+ def __init__(self, wayback_client: WaybackClient, **kwargs):
super().__init__(**kwargs)
self.wayback_client = wayback_client
- def fetch_blob(self, record):
+ def fetch_blob(self, record: Dict[str, Any]) -> Dict[str, Any]:
default_key = record['sha1hex']
wayback_sec = None
petabox_sec = None
@@ -123,7 +129,7 @@ class SandcrawlerFetchWorker(SandcrawlerWorker):
if record.get('warc_path') and record.get('warc_offset'):
# it's a full CDX dict. fetch using WaybackClient
if not self.wayback_client:
- raise Exception("wayback client not configured for this PdfTrioWorker")
+ raise Exception("wayback client not configured for this SandcrawlerFetchWorker")
try:
start = time.time()
blob = self.wayback_client.fetch_petabox_body(
@@ -142,7 +148,7 @@ class SandcrawlerFetchWorker(SandcrawlerWorker):
elif record.get('url') and record.get('datetime'):
# it's a partial CDX dict or something? fetch using WaybackClient
if not self.wayback_client:
- raise Exception("wayback client not configured for this PdfTrioWorker")
+ raise Exception("wayback client not configured for this SandcrawlerFetchWorker")
try:
start = time.time()
blob = self.wayback_client.fetch_replay_body(
@@ -195,20 +201,23 @@ class SandcrawlerFetchWorker(SandcrawlerWorker):
class MultiprocessWrapper(SandcrawlerWorker):
- def __init__(self, worker, sink, jobs=None):
+ def __init__(self,
+ worker: SandcrawlerWorker,
+ sink: Optional[SandcrawlerWorker] = None,
+ jobs: Optional[int] = None):
self.counts = Counter()
self.worker = worker
self.sink = sink
self.pool = multiprocessing.pool.Pool(jobs)
- def push_batch(self, tasks):
+ def push_batch(self, tasks: List[Any]) -> List[Any]:
self.counts['total'] += len(tasks)
print("... processing batch of: {}".format(len(tasks)), file=sys.stderr)
results = self.pool.map(self.worker.process, tasks)
for result in results:
if not result:
self.counts['failed'] += 1
- return
+ return []
elif type(result) == dict and 'status' in result and len(result['status']) < 32:
self.counts[result['status']] += 1
@@ -219,7 +228,7 @@ class MultiprocessWrapper(SandcrawlerWorker):
print(json.dumps(result))
return results
- def finish(self):
+ def finish(self) -> Counter:
self.pool.terminate()
if self.sink:
self.sink.finish()
@@ -234,15 +243,15 @@ class BlackholeSink(SandcrawlerWorker):
Useful for tests.
"""
- def push_record(self, task, key=None):
+ def push_record(self, task: Any, key: Optional[str] = None) -> Any:
return
- def push_batch(self, tasks):
- return
+ def push_batch(self, tasks: List[Any]) -> List[Any]:
+ return []
class KafkaSink(SandcrawlerWorker):
- def __init__(self, kafka_hosts, produce_topic, **kwargs):
+ def __init__(self, kafka_hosts: str, produce_topic: str, **kwargs):
self.sink = None
self.counts = Counter()
self.produce_topic = produce_topic
@@ -257,14 +266,14 @@ class KafkaSink(SandcrawlerWorker):
self.producer = Producer(config)
@staticmethod
- def _fail_fast(err, msg):
+ def _fail_fast(err: Any, msg: Any) -> None:
if err is not None:
print("Kafka producer delivery error: {}".format(err), file=sys.stderr)
print("Bailing out...", file=sys.stderr)
# TODO: should it be sys.exit(-1)?
raise KafkaException(err)
- def producer_config(self, kafka_config):
+ def producer_config(self, kafka_config: dict) -> dict:
config = kafka_config.copy()
config.update({
'delivery.report.only.error': True,
@@ -275,7 +284,7 @@ class KafkaSink(SandcrawlerWorker):
})
return config
- def push_record(self, msg, key=None):
+ def push_record(self, msg: Any, key: Optional[str] = None) -> Any:
self.counts['total'] += 1
if type(msg) == dict:
if not key and 'key' in msg:
@@ -291,11 +300,12 @@ class KafkaSink(SandcrawlerWorker):
# check for errors etc
self.producer.poll(0)
- def push_batch(self, msgs):
+ def push_batch(self, msgs: List[Any]) -> List[Any]:
for m in msgs:
self.push_record(m)
+ return []
- def finish(self):
+ def finish(self) -> Counter:
self.producer.flush()
return self.counts
@@ -304,7 +314,7 @@ class KafkaCompressSink(KafkaSink):
"""
Variant of KafkaSink for large documents. Used for, eg, GROBID output.
"""
- def producer_config(self, kafka_config):
+ def producer_config(self, kafka_config: Dict[str, Any]) -> Dict[str, Any]:
config = kafka_config.copy()
config.update({
'compression.codec': 'gzip',
@@ -325,11 +335,11 @@ class RecordPusher:
Base class for different record sources to be pushed into workers. Pretty
trivial interface, just wraps an importer and pushes records in to it.
"""
- def __init__(self, worker, **kwargs):
- self.counts = Counter()
- self.worker = worker
+ def __init__(self, worker: SandcrawlerWorker, **kwargs):
+ self.counts: Counter = Counter()
+ self.worker: SandcrawlerWorker = worker
- def run(self):
+ def run(self) -> Counter:
"""
This will look something like:
@@ -342,7 +352,7 @@ class RecordPusher:
class JsonLinePusher(RecordPusher):
- def __init__(self, worker, json_file, **kwargs):
+ def __init__(self, worker: SandcrawlerWorker, json_file: Sequence, **kwargs):
self.counts = Counter()
self.worker = worker
self.json_file = json_file
@@ -350,7 +360,7 @@ class JsonLinePusher(RecordPusher):
if self.batch_size in (0, 1):
self.batch_size = None
- def run(self):
+ def run(self) -> Counter:
batch = []
for line in self.json_file:
if not line:
@@ -380,7 +390,7 @@ class JsonLinePusher(RecordPusher):
class CdxLinePusher(RecordPusher):
- def __init__(self, worker, cdx_file, **kwargs):
+ def __init__(self, worker: SandcrawlerWorker, cdx_file: Sequence, **kwargs):
self.counts = Counter()
self.worker = worker
self.cdx_file = cdx_file
@@ -391,7 +401,7 @@ class CdxLinePusher(RecordPusher):
if self.batch_size in (0, 1):
self.batch_size = None
- def run(self):
+ def run(self) -> Counter:
batch = []
for line in self.cdx_file:
if not line:
@@ -427,7 +437,7 @@ class CdxLinePusher(RecordPusher):
class ZipfilePusher(RecordPusher):
- def __init__(self, worker, zipfile_path, **kwargs):
+ def __init__(self, worker: SandcrawlerWorker, zipfile_path: str, **kwargs):
self.counts = Counter()
self.worker = worker
self.filter_suffix = ".pdf"
@@ -436,7 +446,7 @@ class ZipfilePusher(RecordPusher):
if self.batch_size in (0, 1):
self.batch_size = None
- def run(self):
+ def run(self) -> Counter:
batch = []
with zipfile.ZipFile(self.zipfile_path, 'r') as archive:
for zipinfo in archive.infolist():
@@ -466,7 +476,8 @@ class ZipfilePusher(RecordPusher):
class KafkaJsonPusher(RecordPusher):
- def __init__(self, worker, kafka_hosts, consume_topic, group, **kwargs):
+ def __init__(self, worker: SandcrawlerWorker, kafka_hosts: str, consume_topic: str,
+ group: str, **kwargs):
self.counts = Counter()
self.worker = worker
self.consumer = make_kafka_consumer(
@@ -483,7 +494,7 @@ class KafkaJsonPusher(RecordPusher):
self.batch_worker = kwargs.get('batch_worker', False)
self.process_timeout_sec = kwargs.get('process_timeout_sec', 300)
- def run(self):
+ def run(self) -> Counter:
while True:
# TODO: this is batch-oriented, because underlying worker is
# often batch-oriented, but this doesn't confirm that entire batch
@@ -562,10 +573,10 @@ class KafkaJsonPusher(RecordPusher):
return self.counts
-def make_kafka_consumer(hosts, consume_topic, group):
+def make_kafka_consumer(hosts: str, consume_topic: str, group: str) -> Consumer:
topic_name = consume_topic
- def fail_fast(err, partitions):
+ def fail_fast(err: Any, partitions: List[Any]) -> None:
if err is not None:
print("Kafka consumer commit error: {}".format(err), file=sys.stderr)
print("Bailing out...", file=sys.stderr)
@@ -600,7 +611,7 @@ def make_kafka_consumer(hosts, consume_topic, group):
},
}
- def on_rebalance(consumer, partitions):
+ def on_rebalance(consumer: Any, partitions: List[Any]) -> None:
for p in partitions:
if p.error:
raise KafkaException(p.error)