aboutsummaryrefslogtreecommitdiffstats
path: root/python/sandcrawler/persist.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/sandcrawler/persist.py')
-rw-r--r--python/sandcrawler/persist.py785
1 files changed, 785 insertions, 0 deletions
diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py
new file mode 100644
index 0000000..f682572
--- /dev/null
+++ b/python/sandcrawler/persist.py
@@ -0,0 +1,785 @@
+"""
+cdx
+- read raw CDX, filter
+- push to SQL table
+
+ingest-file-result
+- read JSON format (batch)
+- cdx SQL push batch (on conflict skip)
+- file_meta SQL push batch (on conflict update)
+- ingest request push batch (on conflict skip)
+- ingest result push batch (on conflict update)
+
+grobid
+- reads JSON format (batch)
+- grobid2json
+- minio push (one-by-one)
+- grobid SQL push batch (on conflict update)
+- file_meta SQL push batch (on conflict update)
+"""
+
+import os
+import time
+import xml.etree.ElementTree
+from typing import Any, Dict, List, Optional
+
+import psycopg2
+import requests
+
+from sandcrawler.db import SandcrawlerPostgresClient
+from sandcrawler.grobid import GrobidClient
+from sandcrawler.ingest_html import HtmlMetaRow
+from sandcrawler.minio import SandcrawlerMinioClient
+from sandcrawler.pdfextract import PdfExtractResult
+from sandcrawler.workers import SandcrawlerWorker
+
+
+class PersistCdxWorker(SandcrawlerWorker):
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__()
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+ # filter to full CDX lines, no liveweb
+ cdx_batch = [r for r in batch if r.get("warc_path") and ("/" in r["warc_path"])]
+ resp = self.db.insert_cdx(self.cur, cdx_batch)
+ if len(cdx_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(cdx_batch)
+ self.counts["insert-cdx"] += resp[0]
+ self.counts["update-cdx"] += resp[1]
+ self.db.commit()
+ return []
+
+
+class PersistIngestFileResultWorker(SandcrawlerWorker):
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__()
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def request_to_row(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ """
+ Converts ingest-request JSON schema (eg, from Kafka) to SQL ingest_request schema
+
+ if there is a problem with conversion, return None
+ """
+ # backwards compat hacks; transform request to look like current schema
+ if raw.get("ingest_type") == "file":
+ raw["ingest_type"] = "pdf"
+ if (
+ not raw.get("link_source")
+ and raw.get("base_url")
+ and raw.get("ext_ids", {}).get("doi")
+ and raw["base_url"] == "https://doi.org/{}".format(raw["ext_ids"]["doi"])
+ ):
+ # set link_source(_id) for old ingest requests
+ raw["link_source"] = "doi"
+ raw["link_source_id"] = raw["ext_ids"]["doi"]
+ if (
+ not raw.get("link_source")
+ and raw.get("ingest_request_source", "").startswith("savepapernow")
+ and raw.get("fatcat", {}).get("release_ident")
+ ):
+ # set link_source(_id) for old ingest requests
+ raw["link_source"] = "spn"
+ raw["link_source_id"] = raw["fatcat"]["release_ident"]
+
+ for k in ("ingest_type", "base_url", "link_source", "link_source_id"):
+ if k not in raw:
+ self.counts["skip-request-fields"] += 1
+ return None
+ if raw["ingest_type"] not in ("pdf", "xml", "html"):
+ self.counts["skip-ingest-type"] += 1
+ return None
+ # limit on base_url length
+ if len(raw["base_url"]) > 1500:
+ self.counts["skip-url-too-long"] += 1
+ return None
+ request = {
+ "ingest_type": raw["ingest_type"],
+ "base_url": raw["base_url"],
+ "link_source": raw["link_source"],
+ "link_source_id": raw["link_source_id"],
+ "ingest_request_source": raw.get("ingest_request_source"),
+ "request": {},
+ }
+ # extra/optional fields
+ if raw.get("release_stage"):
+ request["release_stage"] = raw["release_stage"]
+ if raw.get("fatcat", {}).get("release_ident"):
+ request["request"]["release_ident"] = raw["fatcat"]["release_ident"]
+ for k in ("ext_ids", "edit_extra", "rel"):
+ if raw.get(k):
+ request["request"][k] = raw[k]
+ # if this dict is empty, trim it to save DB space
+ if not request["request"]:
+ request["request"] = None
+ return request
+
+ def file_result_to_row(self, raw: dict) -> Optional[dict]:
+ """
+ Converts ingest-result JSON schema (eg, from Kafka) to SQL ingest_file_result schema
+
+ if there is a problem with conversion, return None and set skip count
+ """
+ for k in ("request", "hit", "status"):
+ if k not in raw:
+ self.counts["skip-result-fields"] += 1
+ return None
+ if "base_url" not in raw["request"]:
+ self.counts["skip-result-fields"] += 1
+ return None
+ ingest_type = raw["request"].get("ingest_type")
+ if ingest_type == "file":
+ ingest_type = "pdf"
+ if ingest_type not in (
+ "pdf",
+ "xml",
+ "html",
+ "component",
+ "src",
+ "dataset",
+ "dataset-file",
+ ):
+ self.counts["skip-ingest-type"] += 1
+ return None
+ if raw["status"] in ("existing",):
+ self.counts["skip-existing"] += 1
+ return None
+ result = {
+ "ingest_type": ingest_type,
+ "base_url": raw["request"]["base_url"],
+ "hit": raw["hit"],
+ "status": raw["status"],
+ }
+ terminal = raw.get("terminal")
+ if terminal:
+ result["terminal_url"] = terminal.get("terminal_url") or terminal.get("url")
+ result["terminal_dt"] = terminal.get("terminal_dt")
+ result["terminal_status_code"] = (
+ terminal.get("terminal_status_code")
+ or terminal.get("status_code")
+ or terminal.get("http_code")
+ )
+ if result["terminal_status_code"]:
+ result["terminal_status_code"] = int(result["terminal_status_code"])
+ result["terminal_sha1hex"] = terminal.get("terminal_sha1hex")
+ if len(result["terminal_url"]) > 2048:
+ # postgresql13 doesn't like extremely large URLs in b-tree index
+ self.counts["skip-huge-url"] += 1
+ return None
+ return result
+
+ def result_to_html_meta(self, record: dict) -> Optional[HtmlMetaRow]:
+ html_body = record.get("html_body")
+ file_meta = record.get("file_meta")
+ if not (file_meta and html_body):
+ return None
+ return HtmlMetaRow(
+ sha1hex=file_meta["sha1hex"],
+ status=record.get("status"),
+ scope=record.get("scope"),
+ has_teixml=bool(html_body and html_body["status"] == "success"),
+ has_thumbnail=False, # TODO
+ word_count=(html_body and html_body.get("word_count")) or None,
+ biblio=record.get("html_biblio"),
+ resources=record.get("html_resources"),
+ )
+
+ def result_to_platform_row(self, raw: dict) -> Optional[dict]:
+ """
+ Converts fileset ingest-result JSON schema (eg, from Kafka) to SQL ingest_fileset_platform schema
+
+ if there is a problem with conversion, return None and set skip count
+ """
+ for k in ("request", "hit", "status"):
+ if k not in raw:
+ return None
+ if "base_url" not in raw["request"]:
+ return None
+ ingest_type = raw["request"].get("ingest_type")
+ if ingest_type not in ("dataset"):
+ return None
+ if raw["status"] in ("existing",):
+ return None
+ if not raw.get("platform_name"):
+ return None
+ result = {
+ "ingest_type": ingest_type,
+ "base_url": raw["request"]["base_url"],
+ "hit": raw["hit"],
+ "status": raw["status"],
+ "platform_name": raw.get("platform_name"),
+ "platform_domain": raw.get("platform_domain"),
+ "platform_id": raw.get("platform_id"),
+ "ingest_strategy": raw.get("ingest_strategy"),
+ "total_size": raw.get("total_size"),
+ "file_count": raw.get("file_count"),
+ "archiveorg_item_name": raw.get("archiveorg_item_name"),
+ "archiveorg_item_bundle_path": None,
+ "web_bundle_url": None,
+ "web_bundle_dt": None,
+ "manifest": raw.get("manifest"),
+ }
+ if result.get("fileset_bundle"):
+ result["archiveorg_item_bundle_path"] = result["fileset_bundle"].get(
+ "archiveorg_item_bundle_path"
+ )
+ result["web_bundle_url"] = (
+ result["fileset_bundle"].get("terminal", {}).get("terminal_url")
+ )
+ result["web_bundle_dt"] = (
+ result["fileset_bundle"].get("terminal", {}).get("terminal_dt")
+ )
+ return result
+
+ def push_batch(self, batch: List[Any]) -> List[Any]:
+ self.counts["total"] += len(batch)
+
+ if not batch:
+ return []
+
+ results_unfiltered = [self.file_result_to_row(raw) for raw in batch]
+ results = [r for r in results_unfiltered if r]
+
+ irequests_unfiltered = [
+ self.request_to_row(raw["request"]) for raw in batch if raw.get("request")
+ ]
+ irequests = [
+ r for r in irequests_unfiltered if r and r["ingest_type"] != "dataset-file"
+ ]
+
+ if irequests:
+ resp = self.db.insert_ingest_request(self.cur, irequests)
+ self.counts["insert-requests"] += resp[0]
+ self.counts["update-requests"] += resp[1]
+ if results:
+ resp = self.db.insert_ingest_file_result(self.cur, results, on_conflict="update")
+ self.counts["insert-results"] += resp[0]
+ self.counts["update-results"] += resp[1]
+
+ # these schemas match, so can just pass through
+ cdx_batch = [r["cdx"] for r in batch if r.get("hit") and r.get("cdx")]
+ revisit_cdx_batch = [
+ r["revisit_cdx"] for r in batch if r.get("hit") and r.get("revisit_cdx")
+ ]
+ cdx_batch.extend(revisit_cdx_batch)
+ # filter to full CDX lines, with full warc_paths (not liveweb)
+ cdx_batch = [r for r in cdx_batch if r.get("warc_path") and ("/" in r["warc_path"])]
+ if cdx_batch:
+ resp = self.db.insert_cdx(self.cur, cdx_batch)
+ self.counts["insert-cdx"] += resp[0]
+ self.counts["update-cdx"] += resp[1]
+
+ file_meta_batch = [r["file_meta"] for r in batch if r.get("hit") and r.get("file_meta")]
+ if file_meta_batch:
+ resp = self.db.insert_file_meta(self.cur, file_meta_batch, on_conflict="nothing")
+ self.counts["insert-file_meta"] += resp[0]
+ self.counts["update-file_meta"] += resp[1]
+
+ html_meta_batch = [
+ self.result_to_html_meta(r) for r in batch if r.get("hit") and r.get("html_body")
+ ]
+ if html_meta_batch:
+ rows = [d.to_sql_tuple() for d in html_meta_batch if d]
+ resp = self.db.insert_html_meta(self.cur, rows, on_conflict="update")
+ self.counts["insert-html_meta"] += resp[0]
+ self.counts["update-html_meta"] += resp[1]
+
+ fileset_platform_batch_all = [
+ self.result_to_platform_row(raw)
+ for raw in batch
+ if raw.get("request", {}).get("ingest_type") == "dataset"
+ and raw.get("platform_name")
+ ]
+ fileset_platform_batch: List[Dict] = [p for p in fileset_platform_batch_all if p]
+ if fileset_platform_batch:
+ resp = self.db.insert_ingest_fileset_platform(
+ self.cur, fileset_platform_batch, on_conflict="update"
+ )
+ self.counts["insert-fileset_platform"] += resp[0]
+ self.counts["update-fileset_platform"] += resp[1]
+
+ self.db.commit()
+ return []
+
+
+class PersistIngestFilesetWorker(SandcrawlerWorker):
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__()
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+
+class PersistIngestRequestWorker(PersistIngestFileResultWorker):
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__(db_url=db_url)
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ if not batch:
+ return []
+
+ irequests_all = [self.request_to_row(raw) for raw in batch]
+ irequests: List[Dict] = [r for r in irequests_all if r]
+
+ if irequests:
+ resp = self.db.insert_ingest_request(self.cur, irequests)
+ self.counts["insert-requests"] += resp[0]
+ self.counts["update-requests"] += resp[1]
+
+ self.db.commit()
+ return []
+
+
+class PersistGrobidWorker(SandcrawlerWorker):
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__()
+ self.grobid = GrobidClient()
+ self.s3 = SandcrawlerMinioClient(
+ host_url=kwargs.get("s3_url", "localhost:9000"),
+ access_key=kwargs["s3_access_key"],
+ secret_key=kwargs["s3_secret_key"],
+ default_bucket=kwargs["s3_bucket"],
+ )
+ self.s3_only = kwargs.get("s3_only", False)
+ self.db_only = kwargs.get("db_only", False)
+ assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed"
+ if not self.s3_only:
+ self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url)
+ self.cur: Optional[psycopg2.extensions.cursor] = self.db.conn.cursor()
+ else:
+ self.db = None
+ self.cur = None
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ # filter out bad "missing status_code" timeout rows
+ missing = [r for r in batch if not r.get("status_code")]
+ if missing:
+ self.counts["skip-missing-status"] += len(missing)
+ batch = [r for r in batch if r.get("status_code")]
+
+ for r in batch:
+ if r["status_code"] != 200 or not r.get("tei_xml"):
+ self.counts["s3-skip-status"] += 1
+ if r.get("error_msg"):
+ r["metadata"] = {"error_msg": r["error_msg"][:500]}
+ continue
+
+ assert len(r["key"]) == 40
+ if not self.db_only:
+ self.s3.put_blob(
+ folder="grobid",
+ blob=r["tei_xml"],
+ sha1hex=r["key"],
+ extension=".tei.xml",
+ )
+ self.counts["s3-put"] += 1
+
+ # enhance with GROBID TEI-XML metadata, if available
+ try:
+ metadata = self.grobid.metadata(r)
+ except xml.etree.ElementTree.ParseError as xml_e:
+ r["status"] = "bad-grobid-xml"
+ r["metadata"] = {"error_msg": str(xml_e)[:1024]}
+ continue
+ if not metadata:
+ continue
+ for k in ("fatcat_release", "grobid_version"):
+ r[k] = metadata.pop(k, None)
+ if r.get("fatcat_release"):
+ r["fatcat_release"] = r["fatcat_release"].replace("release_", "")
+ if metadata.get("grobid_timestamp"):
+ r["updated"] = metadata["grobid_timestamp"]
+ r["metadata"] = metadata
+
+ if not self.s3_only:
+ assert self.db and self.cur
+ resp = self.db.insert_grobid(self.cur, batch, on_conflict="update")
+ self.counts["insert-grobid"] += resp[0]
+ self.counts["update-grobid"] += resp[1]
+
+ file_meta_batch = [r["file_meta"] for r in batch if r.get("file_meta")]
+ resp = self.db.insert_file_meta(self.cur, file_meta_batch, on_conflict="update")
+ self.counts["insert-file-meta"] += resp[0]
+ self.counts["update-file-meta"] += resp[1]
+
+ self.db.commit()
+
+ return []
+
+
+class PersistGrobidDiskWorker(SandcrawlerWorker):
+ """
+ Writes blobs out to disk.
+
+ This could be refactored into a "Sink" type with an even thinner wrapper.
+ """
+
+ def __init__(self, output_dir: str):
+ super().__init__()
+ self.output_dir = output_dir
+
+ def _blob_path(self, sha1hex: str, extension: str = ".tei.xml") -> str:
+ obj_path = "{}/{}/{}{}".format(
+ sha1hex[0:2],
+ sha1hex[2:4],
+ sha1hex,
+ extension,
+ )
+ return obj_path
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+
+ if record.get("status_code") != 200 or not record.get("tei_xml"):
+ return False
+ assert (len(record["key"])) == 40
+ p = "{}/{}".format(self.output_dir, self._blob_path(record["key"]))
+ os.makedirs(os.path.dirname(p), exist_ok=True)
+ with open(p, "w") as f:
+ f.write(record.pop("tei_xml"))
+ self.counts["written"] += 1
+ return record
+
+
+class PersistPdfTrioWorker(SandcrawlerWorker):
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__()
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ batch = [r for r in batch if "pdf_trio" in r and r["pdf_trio"].get("status_code")]
+ for r in batch:
+ # copy key (sha1hex) into sub-object
+ r["pdf_trio"]["key"] = r["key"]
+ pdftrio_batch = [r["pdf_trio"] for r in batch]
+ resp = self.db.insert_pdftrio(self.cur, pdftrio_batch, on_conflict="update")
+ self.counts["insert-pdftrio"] += resp[0]
+ self.counts["update-pdftrio"] += resp[1]
+
+ file_meta_batch = [
+ r["file_meta"]
+ for r in batch
+ if r["pdf_trio"]["status"] == "success" and r.get("file_meta")
+ ]
+ resp = self.db.insert_file_meta(self.cur, file_meta_batch)
+ self.counts["insert-file-meta"] += resp[0]
+ self.counts["update-file-meta"] += resp[1]
+
+ self.db.commit()
+ return []
+
+
+class PersistPdfTextWorker(SandcrawlerWorker):
+ """
+ Pushes text file to blob store (S3/seaweed/minio) and PDF metadata to SQL table.
+
+ Should keep batch sizes small.
+ """
+
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__()
+ self.s3 = SandcrawlerMinioClient(
+ host_url=kwargs.get("s3_url", "localhost:9000"),
+ access_key=kwargs["s3_access_key"],
+ secret_key=kwargs["s3_secret_key"],
+ default_bucket=kwargs["s3_bucket"],
+ )
+ self.s3_only = kwargs.get("s3_only", False)
+ self.db_only = kwargs.get("db_only", False)
+ assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed"
+ if not self.s3_only:
+ self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url)
+ self.cur: Optional[psycopg2.extensions.cursor] = self.db.conn.cursor()
+ else:
+ self.db = None
+ self.cur = None
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ parsed_batch = []
+ for r in batch:
+ parsed_batch.append(PdfExtractResult.from_pdftext_dict(r))
+
+ for r in parsed_batch:
+ if r.status != "success" or not r.text:
+ self.counts["s3-skip-status"] += 1
+ if r.error_msg:
+ r.metadata = {"error_msg": r.error_msg[:500]}
+ continue
+
+ assert len(r.sha1hex) == 40
+ if not self.db_only:
+ self.s3.put_blob(
+ folder="text",
+ blob=r.text,
+ sha1hex=r.sha1hex,
+ extension=".txt",
+ )
+ self.counts["s3-put"] += 1
+
+ if not self.s3_only:
+ assert self.db and self.cur
+ rows = [r.to_sql_tuple() for r in parsed_batch]
+ resp = self.db.insert_pdf_meta(self.cur, rows, on_conflict="update")
+ self.counts["insert-pdf-meta"] += resp[0]
+ self.counts["update-pdf-meta"] += resp[1]
+
+ file_meta_batch = [r.file_meta for r in parsed_batch if r.file_meta]
+ resp = self.db.insert_file_meta(self.cur, file_meta_batch, on_conflict="update")
+ self.counts["insert-file-meta"] += resp[0]
+ self.counts["update-file-meta"] += resp[1]
+
+ self.db.commit()
+
+ return []
+
+
+class PersistThumbnailWorker(SandcrawlerWorker):
+ """
+ Pushes text file to blob store (S3/seaweed/minio) and PDF metadata to SQL
+ table.
+
+ This worker *must* be used with raw kakfa mode; thumbnails are *not*
+ wrapped in JSON like most sandcrawler kafka messages.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.s3 = SandcrawlerMinioClient(
+ host_url=kwargs.get("s3_url", "localhost:9000"),
+ access_key=kwargs["s3_access_key"],
+ secret_key=kwargs["s3_secret_key"],
+ default_bucket=kwargs["s3_bucket"],
+ )
+ self.s3_extension = kwargs.get("s3_extension", ".jpg")
+ self.s3_folder = kwargs.get("s3_folder", "pdf")
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """
+ Processing raw messages, not decoded JSON objects
+ """
+
+ assert isinstance(record, bytes)
+ blob: bytes = record
+ if isinstance(key, bytes):
+ key = key.decode("utf-8")
+ assert key is not None and len(key) == 40 and isinstance(key, str)
+ assert len(blob) >= 50
+
+ self.s3.put_blob(
+ folder=self.s3_folder,
+ blob=blob,
+ sha1hex=key,
+ extension=self.s3_extension,
+ )
+ self.counts["s3-put"] += 1
+
+
+class GenericPersistDocWorker(SandcrawlerWorker):
+ """
+ Pushes blobs from Kafka to S3.
+
+ Objects are assumed to be JSON-wrapped strings.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.s3 = SandcrawlerMinioClient(
+ host_url=kwargs.get("s3_url", "localhost:9000"),
+ access_key=kwargs["s3_access_key"],
+ secret_key=kwargs["s3_secret_key"],
+ default_bucket=kwargs["s3_bucket"],
+ )
+ self.s3_extension = kwargs.get("s3_extension", ".unknown")
+ self.s3_folder = kwargs.get("s3_folder", "unknown")
+ self.doc_key = "unknown"
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+
+ if record.get("status") != "success" or not record.get(self.doc_key):
+ return
+
+ assert key is not None
+ if isinstance(key, bytes):
+ key_str = key.decode("utf-8")
+ elif isinstance(key, str):
+ key_str = key
+ assert len(key_str) == 40
+ if "sha1hex" in record:
+ assert key_str == record["sha1hex"]
+
+ self.s3.put_blob(
+ folder=self.s3_folder,
+ blob=record[self.doc_key].encode("utf-8"),
+ sha1hex=key_str,
+ extension=self.s3_extension,
+ )
+ self.counts["s3-put"] += 1
+
+
+class PersistXmlDocWorker(GenericPersistDocWorker):
+ """
+ Pushes TEI-XML file to blob store (S3/seaweed/minio). Does not talk to
+ sandcrawler database (SQL).
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.s3_extension = kwargs.get("s3_extension", ".jats.xml")
+ self.s3_folder = kwargs.get("s3_folder", "xml_doc")
+ self.doc_key = "jats_xml"
+
+
+class PersistHtmlTeiXmlWorker(GenericPersistDocWorker):
+ """
+ Pushes TEI-XML file to blob store (S3/seaweed/minio). Does not talk to
+ sandcrawler database (SQL).
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.s3_extension = kwargs.get("s3_extension", ".tei.xml")
+ self.s3_folder = kwargs.get("s3_folder", "html_body")
+ self.doc_key = "tei_xml"
+
+
+class PersistCrossrefWorker(SandcrawlerWorker):
+ """
+ Pushes Crossref API JSON records into postgresql. Can also talk to GROBID,
+ parsed 'unstructured' references, and push the results in to postgresql at
+ the same time.
+ """
+
+ def __init__(
+ self,
+ db_url: str,
+ grobid_client: Optional[GrobidClient],
+ parse_refs: bool = True,
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+ if grobid_client:
+ self.grobid_client = grobid_client
+ else:
+ self.grobid_client = GrobidClient()
+ self.parse_refs = parse_refs
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ crossref_batch = []
+ refs_batch = []
+ for record in batch:
+ crossref_batch.append(
+ dict(
+ doi=record["DOI"].lower().strip(),
+ indexed=record["indexed"]["date-time"],
+ record=record,
+ )
+ )
+ if self.parse_refs:
+ try:
+ parsed_refs = self.grobid_client.crossref_refs(record)
+ refs_batch.append(parsed_refs)
+ except (
+ xml.etree.ElementTree.ParseError,
+ requests.exceptions.HTTPError,
+ requests.exceptions.ReadTimeout,
+ ):
+ print("GROBID crossref refs parsing error, skipping with a sleep")
+ time.sleep(3)
+ pass
+
+ resp = self.db.insert_crossref(self.cur, crossref_batch)
+ if len(crossref_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(crossref_batch)
+ self.counts["insert-crossref"] += resp[0]
+ self.counts["update-crossref"] += resp[1]
+
+ if refs_batch:
+ resp = self.db.insert_grobid_refs(self.cur, refs_batch)
+ if len(refs_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(refs_batch)
+ self.counts["insert-grobid_refs"] += resp[0]
+ self.counts["update-grobid_refs"] += resp[1]
+
+ self.db.commit()
+ return []
+
+
+class PersistGrobidRefsWorker(SandcrawlerWorker):
+ """
+ Simple persist worker to backfill GROBID references in to postgresql
+ locally. Consumes the JSON output from GROBID CrossrefRefsWorker.
+ """
+
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__(**kwargs)
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ refs_batch = []
+ for record in batch:
+ assert record["source"]
+ assert record["source_id"]
+ refs_batch.append(record)
+
+ resp = self.db.insert_grobid_refs(self.cur, refs_batch)
+ if len(refs_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(refs_batch)
+ self.counts["insert-grobid_refs"] += resp[0]
+ self.counts["update-grobid_refs"] += resp[1]
+
+ self.db.commit()
+ return []