aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBryan Newbold <bnewbold@archive.org>2021-10-26 17:59:10 -0700
committerBryan Newbold <bnewbold@archive.org>2021-10-26 17:59:12 -0700
commit1088205653eb22f285dd34f2e058216ecfd2abed (patch)
tree3c0fc815d4a8a4885a463c3f07fdaceb490d9568
parent85a264afa2b0cb844d56ccdccc0b8f6bf16f7621 (diff)
downloadsandcrawler-1088205653eb22f285dd34f2e058216ecfd2abed.tar.gz
sandcrawler-1088205653eb22f285dd34f2e058216ecfd2abed.zip
type annotations for persist workers; required some work
Had to re-structure and filter things a bit, Should be better behavior, but might be some small changes.
-rw-r--r--python/sandcrawler/persist.py125
1 files changed, 59 insertions, 66 deletions
diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py
index bb76e54..d47a8cb 100644
--- a/python/sandcrawler/persist.py
+++ b/python/sandcrawler/persist.py
@@ -20,7 +20,7 @@ grobid
import os
import xml.etree.ElementTree
-from typing import AnyStr, Optional
+from typing import Any, Dict, List, Optional
from sandcrawler.db import SandcrawlerPostgresClient
from sandcrawler.grobid import GrobidClient
@@ -31,18 +31,16 @@ from sandcrawler.workers import SandcrawlerWorker
class PersistCdxWorker(SandcrawlerWorker):
- def __init__(self, db_url, **kwargs):
+ def __init__(self, db_url: str, **kwargs):
super().__init__()
self.db = SandcrawlerPostgresClient(db_url)
self.cur = self.db.conn.cursor()
- def process(self, record, key=None):
- """
- Only do batches (as transactions)
- """
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
raise NotImplementedError
- def push_batch(self, batch):
+ def push_batch(self, batch: list) -> list:
self.counts['total'] += len(batch)
# filter to full CDX lines, no liveweb
cdx_batch = [r for r in batch if r.get('warc_path') and ("/" in r['warc_path'])]
@@ -56,18 +54,16 @@ class PersistCdxWorker(SandcrawlerWorker):
class PersistIngestFileResultWorker(SandcrawlerWorker):
- def __init__(self, db_url, **kwargs):
+ def __init__(self, db_url: str, **kwargs):
super().__init__()
self.db = SandcrawlerPostgresClient(db_url)
self.cur = self.db.conn.cursor()
- def process(self, record, key=None):
- """
- Only do batches (as transactions)
- """
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
raise NotImplementedError
- def request_to_row(self, raw):
+ def request_to_row(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Converts ingest-request JSON schema (eg, from Kafka) to SQL ingest_request schema
@@ -222,20 +218,24 @@ class PersistIngestFileResultWorker(SandcrawlerWorker):
{}).get('terminal_dt')
return result
- def push_batch(self, batch):
+ def push_batch(self, batch: List[Any]) -> List[Any]:
self.counts['total'] += len(batch)
if not batch:
return []
- results = [self.file_result_to_row(raw) for raw in batch]
- results = [r for r in results if r]
+ results_unfiltered = [self.file_result_to_row(raw) for raw in batch]
+ results = [r for r in results_unfiltered if r]
- requests = [self.request_to_row(raw['request']) for raw in batch if raw.get('request')]
- requests = [r for r in requests if r and r['ingest_type'] != 'dataset-file']
+ irequests_unfiltered = [
+ self.request_to_row(raw['request']) for raw in batch if raw.get('request')
+ ]
+ irequests = [
+ r for r in irequests_unfiltered if r and r['ingest_type'] != 'dataset-file'
+ ]
- if requests:
- resp = self.db.insert_ingest_request(self.cur, requests)
+ if irequests:
+ resp = self.db.insert_ingest_request(self.cur, irequests)
self.counts['insert-requests'] += resp[0]
self.counts['update-requests'] += resp[1]
if results:
@@ -266,16 +266,16 @@ class PersistIngestFileResultWorker(SandcrawlerWorker):
self.result_to_html_meta(r) for r in batch if r.get('hit') and r.get('html_body')
]
if html_meta_batch:
- rows = [d.to_sql_tuple() for d in html_meta_batch]
+ rows = [d.to_sql_tuple() for d in html_meta_batch if d]
resp = self.db.insert_html_meta(self.cur, rows, on_conflict="update")
self.counts['insert-html_meta'] += resp[0]
self.counts['update-html_meta'] += resp[1]
- fileset_platform_batch = [
+ fileset_platform_batch_all = [
self.result_to_platform_row(raw) for raw in batch if
raw.get('request', {}).get('ingest_type') == 'dataset' and raw.get('platform_name')
]
- fileset_platform_batch = [p for p in fileset_platform_batch if p]
+ fileset_platform_batch: List[Dict] = [p for p in fileset_platform_batch_all if p]
if fileset_platform_batch:
resp = self.db.insert_ingest_fileset_platform(self.cur,
fileset_platform_batch,
@@ -288,39 +288,35 @@ class PersistIngestFileResultWorker(SandcrawlerWorker):
class PersistIngestFilesetWorker(SandcrawlerWorker):
- def __init__(self, db_url, **kwargs):
+ def __init__(self, db_url: str, **kwargs):
super().__init__()
self.db = SandcrawlerPostgresClient(db_url)
self.cur = self.db.conn.cursor()
- def process(self, record, key=None):
- """
- Only do batches (as transactions)
- """
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
raise NotImplementedError
class PersistIngestRequestWorker(PersistIngestFileResultWorker):
- def __init__(self, db_url, **kwargs):
+ def __init__(self, db_url: str, **kwargs):
super().__init__(db_url=db_url)
- def process(self, record, key=None):
- """
- Only do batches (as transactions)
- """
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
raise NotImplementedError
- def push_batch(self, batch):
+ def push_batch(self, batch: list) -> list:
self.counts['total'] += len(batch)
if not batch:
return []
- requests = [self.request_to_row(raw) for raw in batch]
- requests = [r for r in requests if r]
+ irequests_all = [self.request_to_row(raw) for raw in batch]
+ irequests: List[Dict] = [r for r in irequests_all if r]
- if requests:
- resp = self.db.insert_ingest_request(self.cur, requests)
+ if irequests:
+ resp = self.db.insert_ingest_request(self.cur, irequests)
self.counts['insert-requests'] += resp[0]
self.counts['update-requests'] += resp[1]
@@ -329,7 +325,7 @@ class PersistIngestRequestWorker(PersistIngestFileResultWorker):
class PersistGrobidWorker(SandcrawlerWorker):
- def __init__(self, db_url, **kwargs):
+ def __init__(self, db_url: str, **kwargs):
super().__init__()
self.grobid = GrobidClient()
self.s3 = SandcrawlerMinioClient(
@@ -342,19 +338,17 @@ class PersistGrobidWorker(SandcrawlerWorker):
self.db_only = kwargs.get('db_only', False)
assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed"
if not self.s3_only:
- self.db = SandcrawlerPostgresClient(db_url)
+ self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url)
self.cur = self.db.conn.cursor()
else:
self.db = None
self.cur = None
- def process(self, record, key=None):
- """
- Only do batches (as transactions)
- """
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
raise NotImplementedError
- def push_batch(self, batch):
+ def push_batch(self, batch: list) -> list:
self.counts['total'] += len(batch)
# filter out bad "missing status_code" timeout rows
@@ -372,7 +366,7 @@ class PersistGrobidWorker(SandcrawlerWorker):
assert len(r['key']) == 40
if not self.db_only:
- resp = self.s3.put_blob(
+ self.s3.put_blob(
folder="grobid",
blob=r['tei_xml'],
sha1hex=r['key'],
@@ -398,6 +392,7 @@ class PersistGrobidWorker(SandcrawlerWorker):
r['metadata'] = metadata
if not self.s3_only:
+ assert self.db and self.cur
resp = self.db.insert_grobid(self.cur, batch, on_conflict="update")
self.counts['insert-grobid'] += resp[0]
self.counts['update-grobid'] += resp[1]
@@ -418,11 +413,11 @@ class PersistGrobidDiskWorker(SandcrawlerWorker):
This could be refactored into a "Sink" type with an even thinner wrapper.
"""
- def __init__(self, output_dir):
+ def __init__(self, output_dir: str):
super().__init__()
self.output_dir = output_dir
- def _blob_path(self, sha1hex, extension=".tei.xml"):
+ def _blob_path(self, sha1hex: str, extension: str = ".tei.xml") -> str:
obj_path = "{}/{}/{}{}".format(
sha1hex[0:2],
sha1hex[2:4],
@@ -431,7 +426,7 @@ class PersistGrobidDiskWorker(SandcrawlerWorker):
)
return obj_path
- def process(self, record, key=None):
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
if record.get('status_code') != 200 or not record.get('tei_xml'):
return False
@@ -445,18 +440,16 @@ class PersistGrobidDiskWorker(SandcrawlerWorker):
class PersistPdfTrioWorker(SandcrawlerWorker):
- def __init__(self, db_url, **kwargs):
+ def __init__(self, db_url: str, **kwargs):
super().__init__()
self.db = SandcrawlerPostgresClient(db_url)
self.cur = self.db.conn.cursor()
- def process(self, record, key=None):
- """
- Only do batches (as transactions)
- """
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
raise NotImplementedError
- def push_batch(self, batch):
+ def push_batch(self, batch: list) -> list:
self.counts['total'] += len(batch)
batch = [r for r in batch if 'pdf_trio' in r and r['pdf_trio'].get('status_code')]
@@ -486,7 +479,7 @@ class PersistPdfTextWorker(SandcrawlerWorker):
Should keep batch sizes small.
"""
- def __init__(self, db_url, **kwargs):
+ def __init__(self, db_url: str, **kwargs):
super().__init__()
self.s3 = SandcrawlerMinioClient(
host_url=kwargs.get('s3_url', 'localhost:9000'),
@@ -498,19 +491,17 @@ class PersistPdfTextWorker(SandcrawlerWorker):
self.db_only = kwargs.get('db_only', False)
assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed"
if not self.s3_only:
- self.db = SandcrawlerPostgresClient(db_url)
+ self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url)
self.cur = self.db.conn.cursor()
else:
self.db = None
self.cur = None
- def process(self, record, key=None):
- """
- Only do batches (as transactions)
- """
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
raise NotImplementedError
- def push_batch(self, batch):
+ def push_batch(self, batch: list) -> list:
self.counts['total'] += len(batch)
parsed_batch = []
@@ -526,7 +517,7 @@ class PersistPdfTextWorker(SandcrawlerWorker):
assert len(r.sha1hex) == 40
if not self.db_only:
- resp = self.s3.put_blob(
+ self.s3.put_blob(
folder="text",
blob=r.text,
sha1hex=r.sha1hex,
@@ -535,6 +526,7 @@ class PersistPdfTextWorker(SandcrawlerWorker):
self.counts['s3-put'] += 1
if not self.s3_only:
+ assert self.db and self.cur
rows = [r.to_sql_tuple() for r in parsed_batch]
resp = self.db.insert_pdf_meta(self.cur, rows, on_conflict="update")
self.counts['insert-pdf-meta'] += resp[0]
@@ -569,15 +561,16 @@ class PersistThumbnailWorker(SandcrawlerWorker):
self.s3_extension = kwargs.get('s3_extension', ".jpg")
self.s3_folder = kwargs.get('s3_folder', "pdf")
- def process(self, blob: bytes, key: Optional[str] = None):
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
"""
Processing raw messages, not decoded JSON objects
"""
+ assert isinstance(record, bytes)
+ blob: bytes = record
if isinstance(key, bytes):
key = key.decode('utf-8')
assert key is not None and len(key) == 40 and isinstance(key, str)
- assert isinstance(blob, bytes)
assert len(blob) >= 50
self.s3.put_blob(
@@ -607,7 +600,7 @@ class GenericPersistDocWorker(SandcrawlerWorker):
self.s3_folder = kwargs.get('s3_folder', "unknown")
self.doc_key = "unknown"
- def process(self, record: dict, key: Optional[AnyStr] = None) -> None:
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
if record.get('status') != 'success' or not record.get(self.doc_key):
return