diff options
Diffstat (limited to 'python/sandcrawler/persist.py')
-rw-r--r-- | python/sandcrawler/persist.py | 701 |
1 files changed, 453 insertions, 248 deletions
diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py index 0fd54a4..f682572 100644 --- a/python/sandcrawler/persist.py +++ b/python/sandcrawler/persist.py @@ -1,4 +1,3 @@ - """ cdx - read raw CDX, filter @@ -20,106 +19,112 @@ grobid """ import os -from typing import Optional, AnyStr +import time import xml.etree.ElementTree +from typing import Any, Dict, List, Optional + +import psycopg2 +import requests -from sandcrawler.workers import SandcrawlerWorker from sandcrawler.db import SandcrawlerPostgresClient -from sandcrawler.minio import SandcrawlerMinioClient from sandcrawler.grobid import GrobidClient +from sandcrawler.ingest_html import HtmlMetaRow +from sandcrawler.minio import SandcrawlerMinioClient from sandcrawler.pdfextract import PdfExtractResult -from sandcrawler.html_ingest import HtmlMetaRow +from sandcrawler.workers import SandcrawlerWorker class PersistCdxWorker(SandcrawlerWorker): - - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.db = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): - self.counts['total'] += len(batch) + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) # filter to full CDX lines, no liveweb - cdx_batch = [r for r in batch if r.get('warc_path') and ("/" in r['warc_path'])] + cdx_batch = [r for r in batch if r.get("warc_path") and ("/" in r["warc_path"])] resp = self.db.insert_cdx(self.cur, cdx_batch) if len(cdx_batch) < len(batch): - self.counts['skip'] += len(batch) - len(cdx_batch) - self.counts['insert-cdx'] += resp[0] - self.counts['update-cdx'] += resp[1] + self.counts["skip"] += len(batch) - len(cdx_batch) + self.counts["insert-cdx"] += resp[0] + self.counts["update-cdx"] += resp[1] self.db.commit() return [] -class PersistIngestFileResultWorker(SandcrawlerWorker): - def __init__(self, db_url, **kwargs): +class PersistIngestFileResultWorker(SandcrawlerWorker): + def __init__(self, db_url: str, **kwargs): super().__init__() self.db = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def request_to_row(self, raw): + def request_to_row(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ Converts ingest-request JSON schema (eg, from Kafka) to SQL ingest_request schema if there is a problem with conversion, return None """ # backwards compat hacks; transform request to look like current schema - if raw.get('ingest_type') == 'file': - raw['ingest_type'] = 'pdf' - if (not raw.get('link_source') - and raw.get('base_url') - and raw.get('ext_ids', {}).get('doi') - and raw['base_url'] == "https://doi.org/{}".format(raw['ext_ids']['doi'])): + if raw.get("ingest_type") == "file": + raw["ingest_type"] = "pdf" + if ( + not raw.get("link_source") + and raw.get("base_url") + and raw.get("ext_ids", {}).get("doi") + and raw["base_url"] == "https://doi.org/{}".format(raw["ext_ids"]["doi"]) + ): # set link_source(_id) for old ingest requests - raw['link_source'] = 'doi' - raw['link_source_id'] = raw['ext_ids']['doi'] - if (not raw.get('link_source') - and raw.get('ingest_request_source', '').startswith('savepapernow') - and raw.get('fatcat', {}).get('release_ident')): + raw["link_source"] = "doi" + raw["link_source_id"] = raw["ext_ids"]["doi"] + if ( + not raw.get("link_source") + and raw.get("ingest_request_source", "").startswith("savepapernow") + and raw.get("fatcat", {}).get("release_ident") + ): # set link_source(_id) for old ingest requests - raw['link_source'] = 'spn' - raw['link_source_id'] = raw['fatcat']['release_ident'] + raw["link_source"] = "spn" + raw["link_source_id"] = raw["fatcat"]["release_ident"] - for k in ('ingest_type', 'base_url', 'link_source', 'link_source_id'): - if not k in raw: - self.counts['skip-request-fields'] += 1 + for k in ("ingest_type", "base_url", "link_source", "link_source_id"): + if k not in raw: + self.counts["skip-request-fields"] += 1 return None - if raw['ingest_type'] not in ('pdf', 'xml', 'html'): - self.counts['skip-ingest-type'] += 1 + if raw["ingest_type"] not in ("pdf", "xml", "html"): + self.counts["skip-ingest-type"] += 1 + return None + # limit on base_url length + if len(raw["base_url"]) > 1500: + self.counts["skip-url-too-long"] += 1 return None request = { - 'ingest_type': raw['ingest_type'], - 'base_url': raw['base_url'], - 'link_source': raw['link_source'], - 'link_source_id': raw['link_source_id'], - 'ingest_request_source': raw.get('ingest_request_source'), - 'request': {}, + "ingest_type": raw["ingest_type"], + "base_url": raw["base_url"], + "link_source": raw["link_source"], + "link_source_id": raw["link_source_id"], + "ingest_request_source": raw.get("ingest_request_source"), + "request": {}, } # extra/optional fields - if raw.get('release_stage'): - request['release_stage'] = raw['release_stage'] - if raw.get('fatcat', {}).get('release_ident'): - request['request']['release_ident'] = raw['fatcat']['release_ident'] - for k in ('ext_ids', 'edit_extra', 'rel'): + if raw.get("release_stage"): + request["release_stage"] = raw["release_stage"] + if raw.get("fatcat", {}).get("release_ident"): + request["request"]["release_ident"] = raw["fatcat"]["release_ident"] + for k in ("ext_ids", "edit_extra", "rel"): if raw.get(k): - request['request'][k] = raw[k] + request["request"][k] = raw[k] # if this dict is empty, trim it to save DB space - if not request['request']: - request['request'] = None + if not request["request"]: + request["request"] = None return request - def file_result_to_row(self, raw: dict) -> Optional[dict]: """ @@ -127,208 +132,302 @@ class PersistIngestFileResultWorker(SandcrawlerWorker): if there is a problem with conversion, return None and set skip count """ - for k in ('request', 'hit', 'status'): - if not k in raw: - self.counts['skip-result-fields'] += 1 + for k in ("request", "hit", "status"): + if k not in raw: + self.counts["skip-result-fields"] += 1 return None - if not 'base_url' in raw['request']: - self.counts['skip-result-fields'] += 1 + if "base_url" not in raw["request"]: + self.counts["skip-result-fields"] += 1 return None - ingest_type = raw['request'].get('ingest_type') - if ingest_type == 'file': - ingest_type = 'pdf' - if ingest_type not in ('pdf', 'xml', 'html'): - self.counts['skip-ingest-type'] += 1 + ingest_type = raw["request"].get("ingest_type") + if ingest_type == "file": + ingest_type = "pdf" + if ingest_type not in ( + "pdf", + "xml", + "html", + "component", + "src", + "dataset", + "dataset-file", + ): + self.counts["skip-ingest-type"] += 1 return None - if raw['status'] in ("existing", ): - self.counts['skip-existing'] += 1 + if raw["status"] in ("existing",): + self.counts["skip-existing"] += 1 return None result = { - 'ingest_type': ingest_type, - 'base_url': raw['request']['base_url'], - 'hit': raw['hit'], - 'status': raw['status'], + "ingest_type": ingest_type, + "base_url": raw["request"]["base_url"], + "hit": raw["hit"], + "status": raw["status"], } - terminal = raw.get('terminal') + terminal = raw.get("terminal") if terminal: - result['terminal_url'] = terminal.get('terminal_url') or terminal.get('url') - result['terminal_dt'] = terminal.get('terminal_dt') - result['terminal_status_code'] = terminal.get('terminal_status_code') or terminal.get('status_code') or terminal.get('http_code') - if result['terminal_status_code']: - result['terminal_status_code'] = int(result['terminal_status_code']) - result['terminal_sha1hex'] = terminal.get('terminal_sha1hex') + result["terminal_url"] = terminal.get("terminal_url") or terminal.get("url") + result["terminal_dt"] = terminal.get("terminal_dt") + result["terminal_status_code"] = ( + terminal.get("terminal_status_code") + or terminal.get("status_code") + or terminal.get("http_code") + ) + if result["terminal_status_code"]: + result["terminal_status_code"] = int(result["terminal_status_code"]) + result["terminal_sha1hex"] = terminal.get("terminal_sha1hex") + if len(result["terminal_url"]) > 2048: + # postgresql13 doesn't like extremely large URLs in b-tree index + self.counts["skip-huge-url"] += 1 + return None return result def result_to_html_meta(self, record: dict) -> Optional[HtmlMetaRow]: - html_body = record.get('html_body') - file_meta = record.get('file_meta') + html_body = record.get("html_body") + file_meta = record.get("file_meta") if not (file_meta and html_body): return None return HtmlMetaRow( sha1hex=file_meta["sha1hex"], - status=record.get('status'), - scope=record.get('scope'), - has_teixml=bool(html_body and html_body['status'] == 'success'), + status=record.get("status"), + scope=record.get("scope"), + has_teixml=bool(html_body and html_body["status"] == "success"), has_thumbnail=False, # TODO - word_count=(html_body and html_body.get('word_count')) or None, - biblio=record.get('html_biblio'), - resources=record.get('html_resources'), + word_count=(html_body and html_body.get("word_count")) or None, + biblio=record.get("html_biblio"), + resources=record.get("html_resources"), ) - def push_batch(self, batch): - self.counts['total'] += len(batch) + def result_to_platform_row(self, raw: dict) -> Optional[dict]: + """ + Converts fileset ingest-result JSON schema (eg, from Kafka) to SQL ingest_fileset_platform schema + + if there is a problem with conversion, return None and set skip count + """ + for k in ("request", "hit", "status"): + if k not in raw: + return None + if "base_url" not in raw["request"]: + return None + ingest_type = raw["request"].get("ingest_type") + if ingest_type not in ("dataset"): + return None + if raw["status"] in ("existing",): + return None + if not raw.get("platform_name"): + return None + result = { + "ingest_type": ingest_type, + "base_url": raw["request"]["base_url"], + "hit": raw["hit"], + "status": raw["status"], + "platform_name": raw.get("platform_name"), + "platform_domain": raw.get("platform_domain"), + "platform_id": raw.get("platform_id"), + "ingest_strategy": raw.get("ingest_strategy"), + "total_size": raw.get("total_size"), + "file_count": raw.get("file_count"), + "archiveorg_item_name": raw.get("archiveorg_item_name"), + "archiveorg_item_bundle_path": None, + "web_bundle_url": None, + "web_bundle_dt": None, + "manifest": raw.get("manifest"), + } + if result.get("fileset_bundle"): + result["archiveorg_item_bundle_path"] = result["fileset_bundle"].get( + "archiveorg_item_bundle_path" + ) + result["web_bundle_url"] = ( + result["fileset_bundle"].get("terminal", {}).get("terminal_url") + ) + result["web_bundle_dt"] = ( + result["fileset_bundle"].get("terminal", {}).get("terminal_dt") + ) + return result + + def push_batch(self, batch: List[Any]) -> List[Any]: + self.counts["total"] += len(batch) if not batch: return [] - results = [self.file_result_to_row(raw) for raw in batch] - results = [r for r in results if r] + results_unfiltered = [self.file_result_to_row(raw) for raw in batch] + results = [r for r in results_unfiltered if r] - requests = [self.request_to_row(raw['request']) for raw in batch if raw.get('request')] - requests = [r for r in requests if r] + irequests_unfiltered = [ + self.request_to_row(raw["request"]) for raw in batch if raw.get("request") + ] + irequests = [ + r for r in irequests_unfiltered if r and r["ingest_type"] != "dataset-file" + ] - if requests: - resp = self.db.insert_ingest_request(self.cur, requests) - self.counts['insert-requests'] += resp[0] - self.counts['update-requests'] += resp[1] + if irequests: + resp = self.db.insert_ingest_request(self.cur, irequests) + self.counts["insert-requests"] += resp[0] + self.counts["update-requests"] += resp[1] if results: resp = self.db.insert_ingest_file_result(self.cur, results, on_conflict="update") - self.counts['insert-results'] += resp[0] - self.counts['update-results'] += resp[1] + self.counts["insert-results"] += resp[0] + self.counts["update-results"] += resp[1] # these schemas match, so can just pass through - cdx_batch = [r['cdx'] for r in batch if r.get('hit') and r.get('cdx')] - revisit_cdx_batch = [r['revisit_cdx'] for r in batch if r.get('hit') and r.get('revisit_cdx')] + cdx_batch = [r["cdx"] for r in batch if r.get("hit") and r.get("cdx")] + revisit_cdx_batch = [ + r["revisit_cdx"] for r in batch if r.get("hit") and r.get("revisit_cdx") + ] cdx_batch.extend(revisit_cdx_batch) # filter to full CDX lines, with full warc_paths (not liveweb) - cdx_batch = [r for r in cdx_batch if r.get('warc_path') and ("/" in r['warc_path'])] + cdx_batch = [r for r in cdx_batch if r.get("warc_path") and ("/" in r["warc_path"])] if cdx_batch: resp = self.db.insert_cdx(self.cur, cdx_batch) - self.counts['insert-cdx'] += resp[0] - self.counts['update-cdx'] += resp[1] + self.counts["insert-cdx"] += resp[0] + self.counts["update-cdx"] += resp[1] - file_meta_batch = [r['file_meta'] for r in batch if r.get('hit') and r.get('file_meta')] + file_meta_batch = [r["file_meta"] for r in batch if r.get("hit") and r.get("file_meta")] if file_meta_batch: resp = self.db.insert_file_meta(self.cur, file_meta_batch, on_conflict="nothing") - self.counts['insert-file_meta'] += resp[0] - self.counts['update-file_meta'] += resp[1] + self.counts["insert-file_meta"] += resp[0] + self.counts["update-file_meta"] += resp[1] - html_meta_batch = [self.result_to_html_meta(r) for r in batch if r.get('hit') and r.get('html_body')] + html_meta_batch = [ + self.result_to_html_meta(r) for r in batch if r.get("hit") and r.get("html_body") + ] if html_meta_batch: - resp = self.db.insert_html_meta(self.cur, html_meta_batch, on_conflict="update") - self.counts['insert-html_meta'] += resp[0] - self.counts['update-html_meta'] += resp[1] + rows = [d.to_sql_tuple() for d in html_meta_batch if d] + resp = self.db.insert_html_meta(self.cur, rows, on_conflict="update") + self.counts["insert-html_meta"] += resp[0] + self.counts["update-html_meta"] += resp[1] + + fileset_platform_batch_all = [ + self.result_to_platform_row(raw) + for raw in batch + if raw.get("request", {}).get("ingest_type") == "dataset" + and raw.get("platform_name") + ] + fileset_platform_batch: List[Dict] = [p for p in fileset_platform_batch_all if p] + if fileset_platform_batch: + resp = self.db.insert_ingest_fileset_platform( + self.cur, fileset_platform_batch, on_conflict="update" + ) + self.counts["insert-fileset_platform"] += resp[0] + self.counts["update-fileset_platform"] += resp[1] self.db.commit() return [] -class PersistIngestRequestWorker(PersistIngestFileResultWorker): - def __init__(self, db_url, **kwargs): +class PersistIngestFilesetWorker(SandcrawlerWorker): + def __init__(self, db_url: str, **kwargs): + super().__init__() + self.db = SandcrawlerPostgresClient(db_url) + self.cur = self.db.conn.cursor() + + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" + raise NotImplementedError + + +class PersistIngestRequestWorker(PersistIngestFileResultWorker): + def __init__(self, db_url: str, **kwargs): super().__init__(db_url=db_url) - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): - self.counts['total'] += len(batch) + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) if not batch: return [] - requests = [self.request_to_row(raw) for raw in batch] - requests = [r for r in requests if r] + irequests_all = [self.request_to_row(raw) for raw in batch] + irequests: List[Dict] = [r for r in irequests_all if r] - if requests: - resp = self.db.insert_ingest_request(self.cur, requests) - self.counts['insert-requests'] += resp[0] - self.counts['update-requests'] += resp[1] + if irequests: + resp = self.db.insert_ingest_request(self.cur, irequests) + self.counts["insert-requests"] += resp[0] + self.counts["update-requests"] += resp[1] self.db.commit() return [] -class PersistGrobidWorker(SandcrawlerWorker): - def __init__(self, db_url, **kwargs): +class PersistGrobidWorker(SandcrawlerWorker): + def __init__(self, db_url: str, **kwargs): super().__init__() self.grobid = GrobidClient() self.s3 = SandcrawlerMinioClient( - host_url=kwargs.get('s3_url', 'localhost:9000'), - access_key=kwargs['s3_access_key'], - secret_key=kwargs['s3_secret_key'], - default_bucket=kwargs['s3_bucket'], + host_url=kwargs.get("s3_url", "localhost:9000"), + access_key=kwargs["s3_access_key"], + secret_key=kwargs["s3_secret_key"], + default_bucket=kwargs["s3_bucket"], ) - self.s3_only = kwargs.get('s3_only', False) - self.db_only = kwargs.get('db_only', False) + self.s3_only = kwargs.get("s3_only", False) + self.db_only = kwargs.get("db_only", False) assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed" if not self.s3_only: - self.db = SandcrawlerPostgresClient(db_url) - self.cur = self.db.conn.cursor() + self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url) + self.cur: Optional[psycopg2.extensions.cursor] = self.db.conn.cursor() else: self.db = None self.cur = None - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): - self.counts['total'] += len(batch) + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) # filter out bad "missing status_code" timeout rows - missing = [r for r in batch if not r.get('status_code')] + missing = [r for r in batch if not r.get("status_code")] if missing: - self.counts['skip-missing-status'] += len(missing) - batch = [r for r in batch if r.get('status_code')] + self.counts["skip-missing-status"] += len(missing) + batch = [r for r in batch if r.get("status_code")] for r in batch: - if r['status_code'] != 200 or not r.get('tei_xml'): - self.counts['s3-skip-status'] += 1 - if r.get('error_msg'): - r['metadata'] = {'error_msg': r['error_msg'][:500]} + if r["status_code"] != 200 or not r.get("tei_xml"): + self.counts["s3-skip-status"] += 1 + if r.get("error_msg"): + r["metadata"] = {"error_msg": r["error_msg"][:500]} continue - assert len(r['key']) == 40 + assert len(r["key"]) == 40 if not self.db_only: - resp = self.s3.put_blob( + self.s3.put_blob( folder="grobid", - blob=r['tei_xml'], - sha1hex=r['key'], + blob=r["tei_xml"], + sha1hex=r["key"], extension=".tei.xml", ) - self.counts['s3-put'] += 1 + self.counts["s3-put"] += 1 - # enhance with teixml2json metadata, if available + # enhance with GROBID TEI-XML metadata, if available try: metadata = self.grobid.metadata(r) except xml.etree.ElementTree.ParseError as xml_e: - r['status'] = 'bad-grobid-xml' - r['metadata'] = {'error_msg': str(xml_e)[:1024]} + r["status"] = "bad-grobid-xml" + r["metadata"] = {"error_msg": str(xml_e)[:1024]} continue if not metadata: continue - for k in ('fatcat_release', 'grobid_version'): + for k in ("fatcat_release", "grobid_version"): r[k] = metadata.pop(k, None) - if r.get('fatcat_release'): - r['fatcat_release'] = r['fatcat_release'].replace('release_', '') - if metadata.get('grobid_timestamp'): - r['updated'] = metadata['grobid_timestamp'] - r['metadata'] = metadata + if r.get("fatcat_release"): + r["fatcat_release"] = r["fatcat_release"].replace("release_", "") + if metadata.get("grobid_timestamp"): + r["updated"] = metadata["grobid_timestamp"] + r["metadata"] = metadata if not self.s3_only: + assert self.db and self.cur resp = self.db.insert_grobid(self.cur, batch, on_conflict="update") - self.counts['insert-grobid'] += resp[0] - self.counts['update-grobid'] += resp[1] + self.counts["insert-grobid"] += resp[0] + self.counts["update-grobid"] += resp[1] - file_meta_batch = [r['file_meta'] for r in batch if r.get('file_meta')] + file_meta_batch = [r["file_meta"] for r in batch if r.get("file_meta")] resp = self.db.insert_file_meta(self.cur, file_meta_batch, on_conflict="update") - self.counts['insert-file-meta'] += resp[0] - self.counts['update-file-meta'] += resp[1] + self.counts["insert-file-meta"] += resp[0] + self.counts["update-file-meta"] += resp[1] self.db.commit() @@ -342,11 +441,11 @@ class PersistGrobidDiskWorker(SandcrawlerWorker): This could be refactored into a "Sink" type with an even thinner wrapper. """ - def __init__(self, output_dir): + def __init__(self, output_dir: str): super().__init__() self.output_dir = output_dir - def _blob_path(self, sha1hex, extension=".tei.xml"): + def _blob_path(self, sha1hex: str, extension: str = ".tei.xml") -> str: obj_path = "{}/{}/{}{}".format( sha1hex[0:2], sha1hex[2:4], @@ -355,48 +454,49 @@ class PersistGrobidDiskWorker(SandcrawlerWorker): ) return obj_path - def process(self, record, key=None): + def process(self, record: Any, key: Optional[str] = None) -> Any: - if record.get('status_code') != 200 or not record.get('tei_xml'): + if record.get("status_code") != 200 or not record.get("tei_xml"): return False - assert(len(record['key'])) == 40 - p = "{}/{}".format(self.output_dir, self._blob_path(record['key'])) + assert (len(record["key"])) == 40 + p = "{}/{}".format(self.output_dir, self._blob_path(record["key"])) os.makedirs(os.path.dirname(p), exist_ok=True) - with open(p, 'w') as f: - f.write(record.pop('tei_xml')) - self.counts['written'] += 1 + with open(p, "w") as f: + f.write(record.pop("tei_xml")) + self.counts["written"] += 1 return record class PersistPdfTrioWorker(SandcrawlerWorker): - - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.db = SandcrawlerPostgresClient(db_url) self.cur = self.db.conn.cursor() - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): - self.counts['total'] += len(batch) + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) - batch = [r for r in batch if 'pdf_trio' in r and r['pdf_trio'].get('status_code')] + batch = [r for r in batch if "pdf_trio" in r and r["pdf_trio"].get("status_code")] for r in batch: # copy key (sha1hex) into sub-object - r['pdf_trio']['key'] = r['key'] - pdftrio_batch = [r['pdf_trio'] for r in batch] + r["pdf_trio"]["key"] = r["key"] + pdftrio_batch = [r["pdf_trio"] for r in batch] resp = self.db.insert_pdftrio(self.cur, pdftrio_batch, on_conflict="update") - self.counts['insert-pdftrio'] += resp[0] - self.counts['update-pdftrio'] += resp[1] - - file_meta_batch = [r['file_meta'] for r in batch if r['pdf_trio']['status'] == "success" and r.get('file_meta')] + self.counts["insert-pdftrio"] += resp[0] + self.counts["update-pdftrio"] += resp[1] + + file_meta_batch = [ + r["file_meta"] + for r in batch + if r["pdf_trio"]["status"] == "success" and r.get("file_meta") + ] resp = self.db.insert_file_meta(self.cur, file_meta_batch) - self.counts['insert-file-meta'] += resp[0] - self.counts['update-file-meta'] += resp[1] + self.counts["insert-file-meta"] += resp[0] + self.counts["update-file-meta"] += resp[1] self.db.commit() return [] @@ -409,63 +509,63 @@ class PersistPdfTextWorker(SandcrawlerWorker): Should keep batch sizes small. """ - def __init__(self, db_url, **kwargs): + def __init__(self, db_url: str, **kwargs): super().__init__() self.s3 = SandcrawlerMinioClient( - host_url=kwargs.get('s3_url', 'localhost:9000'), - access_key=kwargs['s3_access_key'], - secret_key=kwargs['s3_secret_key'], - default_bucket=kwargs['s3_bucket'], + host_url=kwargs.get("s3_url", "localhost:9000"), + access_key=kwargs["s3_access_key"], + secret_key=kwargs["s3_secret_key"], + default_bucket=kwargs["s3_bucket"], ) - self.s3_only = kwargs.get('s3_only', False) - self.db_only = kwargs.get('db_only', False) + self.s3_only = kwargs.get("s3_only", False) + self.db_only = kwargs.get("db_only", False) assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed" if not self.s3_only: - self.db = SandcrawlerPostgresClient(db_url) - self.cur = self.db.conn.cursor() + self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url) + self.cur: Optional[psycopg2.extensions.cursor] = self.db.conn.cursor() else: self.db = None self.cur = None - def process(self, record, key=None): - """ - Only do batches (as transactions) - """ + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" raise NotImplementedError - def push_batch(self, batch): - self.counts['total'] += len(batch) + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) parsed_batch = [] for r in batch: parsed_batch.append(PdfExtractResult.from_pdftext_dict(r)) for r in parsed_batch: - if r.status != 'success' or not r.text: - self.counts['s3-skip-status'] += 1 + if r.status != "success" or not r.text: + self.counts["s3-skip-status"] += 1 if r.error_msg: - r.metadata = {'error_msg': r.error_msg[:500]} + r.metadata = {"error_msg": r.error_msg[:500]} continue assert len(r.sha1hex) == 40 if not self.db_only: - resp = self.s3.put_blob( + self.s3.put_blob( folder="text", blob=r.text, sha1hex=r.sha1hex, extension=".txt", ) - self.counts['s3-put'] += 1 + self.counts["s3-put"] += 1 if not self.s3_only: - resp = self.db.insert_pdf_meta(self.cur, parsed_batch, on_conflict="update") - self.counts['insert-pdf-meta'] += resp[0] - self.counts['update-pdf-meta'] += resp[1] + assert self.db and self.cur + rows = [r.to_sql_tuple() for r in parsed_batch] + resp = self.db.insert_pdf_meta(self.cur, rows, on_conflict="update") + self.counts["insert-pdf-meta"] += resp[0] + self.counts["update-pdf-meta"] += resp[1] file_meta_batch = [r.file_meta for r in parsed_batch if r.file_meta] resp = self.db.insert_file_meta(self.cur, file_meta_batch, on_conflict="update") - self.counts['insert-file-meta'] += resp[0] - self.counts['update-file-meta'] += resp[1] + self.counts["insert-file-meta"] += resp[0] + self.counts["update-file-meta"] += resp[1] self.db.commit() @@ -484,32 +584,33 @@ class PersistThumbnailWorker(SandcrawlerWorker): def __init__(self, **kwargs): super().__init__() self.s3 = SandcrawlerMinioClient( - host_url=kwargs.get('s3_url', 'localhost:9000'), - access_key=kwargs['s3_access_key'], - secret_key=kwargs['s3_secret_key'], - default_bucket=kwargs['s3_bucket'], + host_url=kwargs.get("s3_url", "localhost:9000"), + access_key=kwargs["s3_access_key"], + secret_key=kwargs["s3_secret_key"], + default_bucket=kwargs["s3_bucket"], ) - self.s3_extension = kwargs.get('s3_extension', ".jpg") - self.s3_folder = kwargs.get('s3_folder', "pdf") + self.s3_extension = kwargs.get("s3_extension", ".jpg") + self.s3_folder = kwargs.get("s3_folder", "pdf") - def process(self, blob: bytes, key: Optional[str] = None): + def process(self, record: Any, key: Optional[str] = None) -> Any: """ Processing raw messages, not decoded JSON objects """ + assert isinstance(record, bytes) + blob: bytes = record if isinstance(key, bytes): - key = key.decode('utf-8') + key = key.decode("utf-8") assert key is not None and len(key) == 40 and isinstance(key, str) - assert isinstance(blob, bytes) assert len(blob) >= 50 - resp = self.s3.put_blob( + self.s3.put_blob( folder=self.s3_folder, blob=blob, sha1hex=key, extension=self.s3_extension, ) - self.counts['s3-put'] += 1 + self.counts["s3-put"] += 1 class GenericPersistDocWorker(SandcrawlerWorker): @@ -522,36 +623,36 @@ class GenericPersistDocWorker(SandcrawlerWorker): def __init__(self, **kwargs): super().__init__() self.s3 = SandcrawlerMinioClient( - host_url=kwargs.get('s3_url', 'localhost:9000'), - access_key=kwargs['s3_access_key'], - secret_key=kwargs['s3_secret_key'], - default_bucket=kwargs['s3_bucket'], + host_url=kwargs.get("s3_url", "localhost:9000"), + access_key=kwargs["s3_access_key"], + secret_key=kwargs["s3_secret_key"], + default_bucket=kwargs["s3_bucket"], ) - self.s3_extension = kwargs.get('s3_extension', ".unknown") - self.s3_folder = kwargs.get('s3_folder', "unknown") + self.s3_extension = kwargs.get("s3_extension", ".unknown") + self.s3_folder = kwargs.get("s3_folder", "unknown") self.doc_key = "unknown" - def process(self, record: dict, key: Optional[AnyStr] = None) -> None: + def process(self, record: Any, key: Optional[str] = None) -> Any: - if record.get('status') != 'success' or not record.get(self.doc_key): + if record.get("status") != "success" or not record.get(self.doc_key): return assert key is not None if isinstance(key, bytes): - key_str = key.decode('utf-8') + key_str = key.decode("utf-8") elif isinstance(key, str): key_str = key assert len(key_str) == 40 - if 'sha1hex' in record: - assert key_str == record['sha1hex'] + if "sha1hex" in record: + assert key_str == record["sha1hex"] - resp = self.s3.put_blob( + self.s3.put_blob( folder=self.s3_folder, - blob=record[self.doc_key].encode('utf-8'), + blob=record[self.doc_key].encode("utf-8"), sha1hex=key_str, extension=self.s3_extension, ) - self.counts['s3-put'] += 1 + self.counts["s3-put"] += 1 class PersistXmlDocWorker(GenericPersistDocWorker): @@ -562,8 +663,8 @@ class PersistXmlDocWorker(GenericPersistDocWorker): def __init__(self, **kwargs): super().__init__(**kwargs) - self.s3_extension = kwargs.get('s3_extension', ".jats.xml") - self.s3_folder = kwargs.get('s3_folder', "xml_doc") + self.s3_extension = kwargs.get("s3_extension", ".jats.xml") + self.s3_folder = kwargs.get("s3_folder", "xml_doc") self.doc_key = "jats_xml" @@ -575,6 +676,110 @@ class PersistHtmlTeiXmlWorker(GenericPersistDocWorker): def __init__(self, **kwargs): super().__init__(**kwargs) - self.s3_extension = kwargs.get('s3_extension', ".tei.xml") - self.s3_folder = kwargs.get('s3_folder', "html_body") + self.s3_extension = kwargs.get("s3_extension", ".tei.xml") + self.s3_folder = kwargs.get("s3_folder", "html_body") self.doc_key = "tei_xml" + + +class PersistCrossrefWorker(SandcrawlerWorker): + """ + Pushes Crossref API JSON records into postgresql. Can also talk to GROBID, + parsed 'unstructured' references, and push the results in to postgresql at + the same time. + """ + + def __init__( + self, + db_url: str, + grobid_client: Optional[GrobidClient], + parse_refs: bool = True, + **kwargs + ): + super().__init__(**kwargs) + self.db = SandcrawlerPostgresClient(db_url) + self.cur = self.db.conn.cursor() + if grobid_client: + self.grobid_client = grobid_client + else: + self.grobid_client = GrobidClient() + self.parse_refs = parse_refs + + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" + raise NotImplementedError + + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) + + crossref_batch = [] + refs_batch = [] + for record in batch: + crossref_batch.append( + dict( + doi=record["DOI"].lower().strip(), + indexed=record["indexed"]["date-time"], + record=record, + ) + ) + if self.parse_refs: + try: + parsed_refs = self.grobid_client.crossref_refs(record) + refs_batch.append(parsed_refs) + except ( + xml.etree.ElementTree.ParseError, + requests.exceptions.HTTPError, + requests.exceptions.ReadTimeout, + ): + print("GROBID crossref refs parsing error, skipping with a sleep") + time.sleep(3) + pass + + resp = self.db.insert_crossref(self.cur, crossref_batch) + if len(crossref_batch) < len(batch): + self.counts["skip"] += len(batch) - len(crossref_batch) + self.counts["insert-crossref"] += resp[0] + self.counts["update-crossref"] += resp[1] + + if refs_batch: + resp = self.db.insert_grobid_refs(self.cur, refs_batch) + if len(refs_batch) < len(batch): + self.counts["skip"] += len(batch) - len(refs_batch) + self.counts["insert-grobid_refs"] += resp[0] + self.counts["update-grobid_refs"] += resp[1] + + self.db.commit() + return [] + + +class PersistGrobidRefsWorker(SandcrawlerWorker): + """ + Simple persist worker to backfill GROBID references in to postgresql + locally. Consumes the JSON output from GROBID CrossrefRefsWorker. + """ + + def __init__(self, db_url: str, **kwargs): + super().__init__(**kwargs) + self.db = SandcrawlerPostgresClient(db_url) + self.cur = self.db.conn.cursor() + + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" + raise NotImplementedError + + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) + + refs_batch = [] + for record in batch: + assert record["source"] + assert record["source_id"] + refs_batch.append(record) + + resp = self.db.insert_grobid_refs(self.cur, refs_batch) + if len(refs_batch) < len(batch): + self.counts["skip"] += len(batch) - len(refs_batch) + self.counts["insert-grobid_refs"] += resp[0] + self.counts["update-grobid_refs"] += resp[1] + + self.db.commit() + return [] |