From a859fddb227872ce52f06af1dd9fb80987f348c4 Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Fri, 29 Oct 2021 18:36:53 -0700 Subject: glue, utils, and worker code for crossref and grobid_refs --- python/persist_tool.py | 30 ++++++++++++ python/sandcrawler/db.py | 109 ++++++++++++++++++++++++++++++++++++++++-- python/sandcrawler/persist.py | 45 +++++++++++++++++ python/sandcrawler_worker.py | 33 ++++++++++++- 4 files changed, 212 insertions(+), 5 deletions(-) diff --git a/python/persist_tool.py b/python/persist_tool.py index b124ddc..a4f9812 100755 --- a/python/persist_tool.py +++ b/python/persist_tool.py @@ -119,6 +119,22 @@ def run_ingest_request(args): pusher.run() +def run_crossref(args): + grobid_client = GrobidClient( + host_url=args.grobid_host, + ) + worker = PersistCrossrefWorker( + db_url=args.db_url, + grobid_client=grobid_client, + ) + pusher = JsonLinePusher( + worker, + args.json_file, + batch_size=10, + ) + pusher.run() + + def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( @@ -238,6 +254,20 @@ def main(): type=argparse.FileType("r"), ) + sub_crossref = subparsers.add_parser( + "crossref", + help="backfill a crossref JSON dump into postgresql, and extract references at the same time", + ) + sub_crossref.set_defaults(func=run_crossref) + sub_crossref.add_argument( + "json_file", + help="crossref file to import from (or '-' for stdin)", + type=argparse.FileType("r"), + ) + sub_crossref.add_argument( + "--grobid-host", default="https://grobid.qa.fatcat.wiki", help="GROBID API host/port" + ) + args = parser.parse_args() if not args.__dict__.get("func"): print("Tell me what to do!", file=sys.stderr) diff --git a/python/sandcrawler/db.py b/python/sandcrawler/db.py index 69d2116..101d419 100644 --- a/python/sandcrawler/db.py +++ b/python/sandcrawler/db.py @@ -89,7 +89,28 @@ class SandcrawlerPostgrestClient: return None def get_crossref(self, doi: str) -> Optional[dict]: - resp = requests.get(self.api_url + "/crossref", params=dict(doi="eq." + doi)) + resp = requests.get(self.api_url + "/crossref", params=dict(doi=f"eq.{doi}")) + resp.raise_for_status() + resp_json = resp.json() + if resp_json: + return resp_json[0] + else: + return None + + def get_crossref_with_refs(self, doi: str) -> Optional[dict]: + resp = requests.get(self.api_url + "/crossref_with_refs", params=dict(doi=f"eq.{doi}")) + resp.raise_for_status() + resp_json = resp.json() + if resp_json: + return resp_json[0] + else: + return None + + def get_grobid_refs(self, source: str, source_id: str) -> Optional[dict]: + resp = requests.get( + self.api_url + "/grobid_refs", + params=dict(source=f"eq.{source}", source_id=f"eq.{source_id}"), + ) resp.raise_for_status() resp_json = resp.json() if resp_json: @@ -230,6 +251,7 @@ class SandcrawlerPostgresClient: r[k] = r["metadata"].get(k) r["metadata"].pop(k, None) r["metadata"] = json.dumps(r["metadata"], sort_keys=True) + now = datetime.datetime.now() rows = [ ( d["key"], @@ -237,7 +259,7 @@ class SandcrawlerPostgresClient: d["status_code"], d["status"], d.get("fatcat_release") or None, - d.get("updated") or datetime.datetime.now(), + d.get("updated") or now, d.get("metadata") or None, ) for d in batch @@ -356,10 +378,11 @@ class SandcrawlerPostgresClient: else: raise NotImplementedError("on_conflict: {}".format(on_conflict)) sql += " RETURNING xmax;" + now = datetime.datetime.now() rows = [ ( d["key"], - d.get("updated") or datetime.datetime.now(), + d.get("updated") or now, d["status_code"], d["status"], d.get("versions", {}).get("pdftrio_version") or None, @@ -533,3 +556,83 @@ class SandcrawlerPostgresClient: rows = list(row_dict.values()) resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True) return self._inserts_and_updates(resp, on_conflict) + + def insert_crossref( + self, + cur: psycopg2.extensions.cursor, + batch: List[Dict[str, Any]], + on_conflict: str = "update", + ) -> Tuple[int, int]: + sql = """ + INSERT INTO + crossref (doi, indexed, record) + VALUES %s + ON CONFLICT (doi) DO + """ + if on_conflict.lower() == "nothing": + sql += " NOTHING" + elif on_conflict.lower() == "update": + sql += """ UPDATE SET + indexed=EXCLUDED.indexed, + record=EXCLUDED.record + """ + else: + raise NotImplementedError("on_conflict: {}".format(on_conflict)) + sql += " RETURNING xmax;" + rows = [ + ( + d["doi"], + d.get("indexed") or None, + json.dumps(d["record"], sort_keys=True), + ) + for d in batch + ] + # filter out duplicate rows by key (sha1hex) + row_dict = dict() + for b in rows: + row_dict[b[0]] = b + rows = list(row_dict.values()) + resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True) + return self._inserts_and_updates(resp, on_conflict) + + def insert_grobid_refs( + self, + cur: psycopg2.extensions.cursor, + batch: List[Dict[str, Any]], + on_conflict: str = "update", + ) -> Tuple[int, int]: + sql = """ + INSERT INTO + grobid_refs (source, source_id, source_ts, updated, refs_json) + VALUES %s + ON CONFLICT (source, source_id) DO + """ + if on_conflict.lower() == "nothing": + sql += " NOTHING" + elif on_conflict.lower() == "update": + sql += """ UPDATE SET + source_ts=EXCLUDED.source_ts, + updated=EXCLUDED.updated, + refs_json=EXCLUDED.refs_json + """ + else: + raise NotImplementedError("on_conflict: {}".format(on_conflict)) + sql += " RETURNING xmax;" + now = datetime.datetime.now() + rows = [ + ( + d["source"], + d["source_id"], + d.get("source_ts") or None, + d.get("updated") or now, + json.dumps(d["refs_json"], sort_keys=True), + ) + for d in batch + ] + # filter out duplicate rows by key (sha1hex) + row_dict = dict() + for b in rows: + row_dict[(b[0], b[1])] = b + rows = list(row_dict.values()) + resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True) + return self._inserts_and_updates(resp, on_conflict) diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py index f50b9d1..4c9d9d7 100644 --- a/python/sandcrawler/persist.py +++ b/python/sandcrawler/persist.py @@ -673,3 +673,48 @@ class PersistHtmlTeiXmlWorker(GenericPersistDocWorker): self.s3_extension = kwargs.get("s3_extension", ".tei.xml") self.s3_folder = kwargs.get("s3_folder", "html_body") self.doc_key = "tei_xml" + + +class PersistCrossrefWorker(SandcrawlerWorker): + def __init__(self, db_url: str, grobid_client: Optional[GrobidClient], **kwargs): + super().__init__(**kwargs) + self.db = SandcrawlerPostgresClient(db_url) + self.cur = self.db.conn.cursor() + if grobid_client: + self.grobid_client = grobid_client + else: + self.grobid_client = GrobidClient() + + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" + raise NotImplementedError + + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) + + crossref_batch = [] + refs_batch = [] + for record in batch: + crossref_batch.append( + dict( + doi=record["DOI"].lower().strip(), + indexed=record["indexed"]["date-time"], + record=record, + ) + ) + refs_batch.append(self.grobid_client.crossref_refs(record)) + + resp = self.db.insert_crossref(self.cur, crossref_batch) + if len(crossref_batch) < len(batch): + self.counts["skip"] += len(batch) - len(crossref_batch) + self.counts["insert-crossref"] += resp[0] + self.counts["update-crossref"] += resp[1] + + resp = self.db.insert_grobid_refs(self.cur, refs_batch) + if len(refs_batch) < len(batch): + self.counts["skip"] += len(batch) - len(refs_batch) + self.counts["insert-grobid_refs"] += resp[0] + self.counts["update-grobid_refs"] += resp[1] + + self.db.commit() + return [] diff --git a/python/sandcrawler_worker.py b/python/sandcrawler_worker.py index d42cd8c..73bd444 100755 --- a/python/sandcrawler_worker.py +++ b/python/sandcrawler_worker.py @@ -12,7 +12,11 @@ import sys import raven from sandcrawler import * -from sandcrawler.persist import PersistHtmlTeiXmlWorker, PersistXmlDocWorker +from sandcrawler.persist import ( + PersistCrossrefWorker, + PersistHtmlTeiXmlWorker, + PersistXmlDocWorker, +) # Yep, a global. Gets DSN from `SENTRY_DSN` environment variable try: @@ -291,6 +295,22 @@ def run_persist_ingest_file(args): pusher.run() +def run_persist_crossref(args): + grobid_client = GrobidClient(host_url=args.grobid_host) + consume_topic = "fatcat-{}.api-crossref".format(args.env) + worker = PersistCrossrefWorker(db_url=args.db_url, grobid_client=grobid_client) + pusher = KafkaJsonPusher( + worker=worker, + kafka_hosts=args.kafka_hosts, + consume_topic=consume_topic, + group="persist-ingest", + push_batches=True, + # small batch size because doing GROBID processing + batch_size=20, + ) + pusher.run() + + def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( @@ -302,7 +322,7 @@ def main(): "--env", default="dev", help="Kafka topic namespace to use (eg, prod, qa, dev)" ) parser.add_argument( - "--grobid-host", default="http://grobid.qa.fatcat.wiki", help="GROBID API host/port" + "--grobid-host", default="https://grobid.qa.fatcat.wiki", help="GROBID API host/port" ) parser.add_argument( "--db-url", @@ -417,6 +437,15 @@ def main(): ) sub_persist_ingest_file.set_defaults(func=run_persist_ingest_file) + sub_persist_crossref = subparsers.add_parser( + "persist-crossref", + help="daemon that persists crossref to postgres; also does GROBID ref transform", + ) + sub_persist_crossref.add_argument( + "--grobid-host", default="https://grobid.qa.fatcat.wiki", help="GROBID API host/port" + ) + sub_persist_crossref.set_defaults(func=run_persist_crossref) + args = parser.parse_args() if not args.__dict__.get("func"): parser.print_help(file=sys.stderr) -- cgit v1.2.3