From 1088205653eb22f285dd34f2e058216ecfd2abed Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Tue, 26 Oct 2021 17:59:10 -0700 Subject: type annotations for persist workers; required some work Had to re-structure and filter things a bit, Should be better behavior, but might be some small changes. --- python/sandcrawler/persist.py | 125 ++++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 66 deletions(-) (limited to 'python/sandcrawler/persist.py') diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py index bb76e54..d47a8cb 100644 --- a/python/sandcrawler/persist.py +++ b/python/sandcrawler/persist.py @@ -20,7 +20,7 @@ grobid import os import xml.etree.ElementTree -from typing import AnyStr, Optional +from typing import Any, Dict, List, Optional from sandcrawler.db import SandcrawlerPostgresClient from sandcrawler.grobid import GrobidClient @@ -31,18 +31,16 @@ from sandcrawler.workers import SandcrawlerWorker class PersistCdxWorker(SandcrawlerWorker): - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.db = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): + def push_batch(self, batch: list) -> list: self.counts['total'] += len(batch) # filter to full CDX lines, no liveweb cdx_batch = [r for r in batch if r.get('warc_path') and ("/" in r['warc_path'])] @@ -56,18 +54,16 @@ class PersistCdxWorker(SandcrawlerWorker): class PersistIngestFileResultWorker(SandcrawlerWorker): - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.db = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def request_to_row(self, raw): + def request_to_row(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ Converts ingest-request JSON schema (eg, from Kafka) to SQL ingest_request schema @@ -222,20 +218,24 @@ class PersistIngestFileResultWorker(SandcrawlerWorker): {}).get('terminal_dt') return result - def push_batch(self, batch): + def push_batch(self, batch: List[Any]) -> List[Any]: self.counts['total'] += len(batch) if not batch: return [] - results = [self.file_result_to_row(raw) for raw in batch] - results = [r for r in results if r] + results_unfiltered = [self.file_result_to_row(raw) for raw in batch] + results = [r for r in results_unfiltered if r] - requests = [self.request_to_row(raw['request']) for raw in batch if raw.get('request')] - requests = [r for r in requests if r and r['ingest_type'] != 'dataset-file'] + irequests_unfiltered = [ + self.request_to_row(raw['request']) for raw in batch if raw.get('request') + ] + irequests = [ + r for r in irequests_unfiltered if r and r['ingest_type'] != 'dataset-file' + ] - if requests: - resp = self.db.insert_ingest_request(self.cur, requests) + if irequests: + resp = self.db.insert_ingest_request(self.cur, irequests) self.counts['insert-requests'] += resp[0] self.counts['update-requests'] += resp[1] if results: @@ -266,16 +266,16 @@ class PersistIngestFileResultWorker(SandcrawlerWorker): self.result_to_html_meta(r) for r in batch if r.get('hit') and r.get('html_body') ] if html_meta_batch: - rows = [d.to_sql_tuple() for d in html_meta_batch] + rows = [d.to_sql_tuple() for d in html_meta_batch if d] resp = self.db.insert_html_meta(self.cur, rows, on_conflict="update") self.counts['insert-html_meta'] += resp[0] self.counts['update-html_meta'] += resp[1] - fileset_platform_batch = [ + fileset_platform_batch_all = [ self.result_to_platform_row(raw) for raw in batch if raw.get('request', {}).get('ingest_type') == 'dataset' and raw.get('platform_name') ] - fileset_platform_batch = [p for p in fileset_platform_batch if p] + fileset_platform_batch: List[Dict] = [p for p in fileset_platform_batch_all if p] if fileset_platform_batch: resp = self.db.insert_ingest_fileset_platform(self.cur, fileset_platform_batch, @@ -288,39 +288,35 @@ class PersistIngestFileResultWorker(SandcrawlerWorker): class PersistIngestFilesetWorker(SandcrawlerWorker): - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.db = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError class PersistIngestRequestWorker(PersistIngestFileResultWorker): - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__(db_url=db_url) - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): + def push_batch(self, batch: list) -> list: self.counts['total'] += len(batch) if not batch: return [] - requests = [self.request_to_row(raw) for raw in batch] - requests = [r for r in requests if r] + irequests_all = [self.request_to_row(raw) for raw in batch] + irequests: List[Dict] = [r for r in irequests_all if r] - if requests: - resp = self.db.insert_ingest_request(self.cur, requests) + if irequests: + resp = self.db.insert_ingest_request(self.cur, irequests) self.counts['insert-requests'] += resp[0] self.counts['update-requests'] += resp[1] @@ -329,7 +325,7 @@ class PersistIngestRequestWorker(PersistIngestFileResultWorker): class PersistGrobidWorker(SandcrawlerWorker): - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.grobid = GrobidClient() self.s3 = SandcrawlerMinioClient( @@ -342,19 +338,17 @@ class PersistGrobidWorker(SandcrawlerWorker): self.db_only = kwargs.get('db_only', False) assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed" if not self.s3_only: - self.db = SandcrawlerPostgresClient(db_url) + self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() else: self.db = None self.cur = None - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): + def push_batch(self, batch: list) -> list: self.counts['total'] += len(batch) # filter out bad "missing status_code" timeout rows @@ -372,7 +366,7 @@ class PersistGrobidWorker(SandcrawlerWorker): assert len(r['key']) == 40 if not self.db_only: - resp = self.s3.put_blob( + self.s3.put_blob( folder="grobid", blob=r['tei_xml'], sha1hex=r['key'], @@ -398,6 +392,7 @@ class PersistGrobidWorker(SandcrawlerWorker): r['metadata'] = metadata if not self.s3_only: + assert self.db and self.cur resp = self.db.insert_grobid(self.cur, batch, on_conflict="update") self.counts['insert-grobid'] += resp[0] self.counts['update-grobid'] += resp[1] @@ -418,11 +413,11 @@ class PersistGrobidDiskWorker(SandcrawlerWorker): This could be refactored into a "Sink" type with an even thinner wrapper. """ - def __init__(self, output_dir): + def __init__(self, output_dir: str): super().__init__() self.output_dir = output_dir - def _blob_path(self, sha1hex, extension=".tei.xml"): + def _blob_path(self, sha1hex: str, extension: str = ".tei.xml") -> str: obj_path = "{}/{}/{}{}".format( sha1hex[0:2], sha1hex[2:4], @@ -431,7 +426,7 @@ class PersistGrobidDiskWorker(SandcrawlerWorker): ) return obj_path - def process(self, record, key=None): + def process(self, record: Any, key: Optional[str] = None) -> Any: if record.get('status_code') != 200 or not record.get('tei_xml'): return False @@ -445,18 +440,16 @@ class PersistGrobidDiskWorker(SandcrawlerWorker): class PersistPdfTrioWorker(SandcrawlerWorker): - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.db = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): + def push_batch(self, batch: list) -> list: self.counts['total'] += len(batch) batch = [r for r in batch if 'pdf_trio' in r and r['pdf_trio'].get('status_code')] @@ -486,7 +479,7 @@ class PersistPdfTextWorker(SandcrawlerWorker): Should keep batch sizes small. """ - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.s3 = SandcrawlerMinioClient( host_url=kwargs.get('s3_url', 'localhost:9000'), @@ -498,19 +491,17 @@ class PersistPdfTextWorker(SandcrawlerWorker): self.db_only = kwargs.get('db_only', False) assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed" if not self.s3_only: - self.db = SandcrawlerPostgresClient(db_url) + self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() else: self.db = None self.cur = None - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): + def push_batch(self, batch: list) -> list: self.counts['total'] += len(batch) parsed_batch = [] @@ -526,7 +517,7 @@ class PersistPdfTextWorker(SandcrawlerWorker): assert len(r.sha1hex) == 40 if not self.db_only: - resp = self.s3.put_blob( + self.s3.put_blob( folder="text", blob=r.text, sha1hex=r.sha1hex, @@ -535,6 +526,7 @@ class PersistPdfTextWorker(SandcrawlerWorker): self.counts['s3-put'] += 1 if not self.s3_only: + assert self.db and self.cur rows = [r.to_sql_tuple() for r in parsed_batch] resp = self.db.insert_pdf_meta(self.cur, rows, on_conflict="update") self.counts['insert-pdf-meta'] += resp[0] @@ -569,15 +561,16 @@ class PersistThumbnailWorker(SandcrawlerWorker): self.s3_extension = kwargs.get('s3_extension', ".jpg") self.s3_folder = kwargs.get('s3_folder', "pdf") - def process(self, blob: bytes, key: Optional[str] = None): + def process(self, record: Any, key: Optional[str] = None) -> Any: """ Processing raw messages, not decoded JSON objects """ + assert isinstance(record, bytes) + blob: bytes = record if isinstance(key, bytes): key = key.decode('utf-8') assert key is not None and len(key) == 40 and isinstance(key, str) - assert isinstance(blob, bytes) assert len(blob) >= 50 self.s3.put_blob( @@ -607,7 +600,7 @@ class GenericPersistDocWorker(SandcrawlerWorker): self.s3_folder = kwargs.get('s3_folder', "unknown") self.doc_key = "unknown" - def process(self, record: dict, key: Optional[AnyStr] = None) -> None: + def process(self, record: Any, key: Optional[str] = None) -> Any: if record.get('status') != 'success' or not record.get(self.doc_key): return -- cgit v1.2.3