Diffstat (limited to 'python')
-rwxr-xr-x  python/persist_tool.py        |  30
-rw-r--r--  python/sandcrawler/db.py      | 109
-rw-r--r--  python/sandcrawler/persist.py |  45
-rwxr-xr-x  python/sandcrawler_worker.py  |  33
4 files changed, 212 insertions(+), 5 deletions(-)
diff --git a/python/persist_tool.py b/python/persist_tool.py
index b124ddc..a4f9812 100755
--- a/python/persist_tool.py
+++ b/python/persist_tool.py
@@ -119,6 +119,22 @@ def run_ingest_request(args):
pusher.run()
+def run_crossref(args):
+ grobid_client = GrobidClient(
+ host_url=args.grobid_host,
+ )
+ worker = PersistCrossrefWorker(
+ db_url=args.db_url,
+ grobid_client=grobid_client,
+ )
+ pusher = JsonLinePusher(
+ worker,
+ args.json_file,
+ batch_size=10,
+ )
+ pusher.run()
+
+
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
@@ -238,6 +254,20 @@ def main():
type=argparse.FileType("r"),
)
+ sub_crossref = subparsers.add_parser(
+ "crossref",
+        help="backfill a Crossref JSON dump into PostgreSQL, and extract references at the same time",
+ )
+ sub_crossref.set_defaults(func=run_crossref)
+ sub_crossref.add_argument(
+ "json_file",
+ help="crossref file to import from (or '-' for stdin)",
+ type=argparse.FileType("r"),
+ )
+ sub_crossref.add_argument(
+ "--grobid-host", default="https://grobid.qa.fatcat.wiki", help="GROBID API host/port"
+ )
+
args = parser.parse_args()
if not args.__dict__.get("func"):
print("Tell me what to do!", file=sys.stderr)
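For context, the new `crossref` subcommand reads one Crossref work record per line (JSON Lines) from the given file or stdin. A minimal sketch of such an input line, with hypothetical values; only the `DOI` and nested `indexed.date-time` fields are required by the persist worker further down in this diff, and real Crossref API records carry many more fields:

    # hypothetical minimal input line for `persist_tool.py crossref`
    import json

    record = {
        "DOI": "10.1000/example.doi",
        "indexed": {"date-time": "2021-11-01T00:00:00Z"},
        "title": ["An Example Work"],
    }
    print(json.dumps(record))
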
diff --git a/python/sandcrawler/db.py b/python/sandcrawler/db.py
index 69d2116..101d419 100644
--- a/python/sandcrawler/db.py
+++ b/python/sandcrawler/db.py
@@ -89,7 +89,28 @@ class SandcrawlerPostgrestClient:
return None
def get_crossref(self, doi: str) -> Optional[dict]:
- resp = requests.get(self.api_url + "/crossref", params=dict(doi="eq." + doi))
+ resp = requests.get(self.api_url + "/crossref", params=dict(doi=f"eq.{doi}"))
+ resp.raise_for_status()
+ resp_json = resp.json()
+ if resp_json:
+ return resp_json[0]
+ else:
+ return None
+
+ def get_crossref_with_refs(self, doi: str) -> Optional[dict]:
+ resp = requests.get(self.api_url + "/crossref_with_refs", params=dict(doi=f"eq.{doi}"))
+ resp.raise_for_status()
+ resp_json = resp.json()
+ if resp_json:
+ return resp_json[0]
+ else:
+ return None
+
+ def get_grobid_refs(self, source: str, source_id: str) -> Optional[dict]:
+ resp = requests.get(
+ self.api_url + "/grobid_refs",
+ params=dict(source=f"eq.{source}", source_id=f"eq.{source_id}"),
+ )
resp.raise_for_status()
resp_json = resp.json()
if resp_json:
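All three read helpers follow the same PostgREST pattern: an `eq.` filter on the query parameter, `raise_for_status()`, then the first matching row as a dict (or None if nothing matched). A minimal usage sketch, assuming the client is constructed with the PostgREST base URL; the `api_url` and DOI below are hypothetical:

    from sandcrawler.db import SandcrawlerPostgrestClient

    # assumes a PostgREST service running at this (hypothetical) URL
    client = SandcrawlerPostgrestClient(api_url="http://localhost:3030")

    # crossref record with GROBID-parsed references joined in, or None
    row = client.get_crossref_with_refs("10.1000/example.doi")

    # parsed references for an arbitrary (source, source_id) pair
    refs = client.get_grobid_refs("crossref", "10.1000/example.doi")
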
@@ -230,6 +251,7 @@ class SandcrawlerPostgresClient:
r[k] = r["metadata"].get(k)
r["metadata"].pop(k, None)
r["metadata"] = json.dumps(r["metadata"], sort_keys=True)
+ now = datetime.datetime.now()
rows = [
(
d["key"],
@@ -237,7 +259,7 @@ class SandcrawlerPostgresClient:
d["status_code"],
d["status"],
d.get("fatcat_release") or None,
- d.get("updated") or datetime.datetime.now(),
+ d.get("updated") or now,
d.get("metadata") or None,
)
for d in batch
@@ -356,10 +378,11 @@ class SandcrawlerPostgresClient:
else:
raise NotImplementedError("on_conflict: {}".format(on_conflict))
sql += " RETURNING xmax;"
+ now = datetime.datetime.now()
rows = [
(
d["key"],
- d.get("updated") or datetime.datetime.now(),
+ d.get("updated") or now,
d["status_code"],
d["status"],
d.get("versions", {}).get("pdftrio_version") or None,
@@ -533,3 +556,83 @@ class SandcrawlerPostgresClient:
rows = list(row_dict.values())
resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True)
return self._inserts_and_updates(resp, on_conflict)
+
+ def insert_crossref(
+ self,
+ cur: psycopg2.extensions.cursor,
+ batch: List[Dict[str, Any]],
+ on_conflict: str = "update",
+ ) -> Tuple[int, int]:
+ sql = """
+ INSERT INTO
+ crossref (doi, indexed, record)
+ VALUES %s
+ ON CONFLICT (doi) DO
+ """
+ if on_conflict.lower() == "nothing":
+ sql += " NOTHING"
+ elif on_conflict.lower() == "update":
+ sql += """ UPDATE SET
+ indexed=EXCLUDED.indexed,
+ record=EXCLUDED.record
+ """
+ else:
+ raise NotImplementedError("on_conflict: {}".format(on_conflict))
+ sql += " RETURNING xmax;"
+ rows = [
+ (
+ d["doi"],
+ d.get("indexed") or None,
+ json.dumps(d["record"], sort_keys=True),
+ )
+ for d in batch
+ ]
+        # filter out duplicate rows by key (doi)
+ row_dict = dict()
+ for b in rows:
+ row_dict[b[0]] = b
+ rows = list(row_dict.values())
+ resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True)
+ return self._inserts_and_updates(resp, on_conflict)
+
+ def insert_grobid_refs(
+ self,
+ cur: psycopg2.extensions.cursor,
+ batch: List[Dict[str, Any]],
+ on_conflict: str = "update",
+ ) -> Tuple[int, int]:
+ sql = """
+ INSERT INTO
+ grobid_refs (source, source_id, source_ts, updated, refs_json)
+ VALUES %s
+ ON CONFLICT (source, source_id) DO
+ """
+ if on_conflict.lower() == "nothing":
+ sql += " NOTHING"
+ elif on_conflict.lower() == "update":
+ sql += """ UPDATE SET
+ source_ts=EXCLUDED.source_ts,
+ updated=EXCLUDED.updated,
+ refs_json=EXCLUDED.refs_json
+ """
+ else:
+ raise NotImplementedError("on_conflict: {}".format(on_conflict))
+ sql += " RETURNING xmax;"
+ now = datetime.datetime.now()
+ rows = [
+ (
+ d["source"],
+ d["source_id"],
+ d.get("source_ts") or None,
+ d.get("updated") or now,
+ json.dumps(d["refs_json"], sort_keys=True),
+ )
+ for d in batch
+ ]
+        # filter out duplicate rows by key (source, source_id)
+ row_dict = dict()
+ for b in rows:
+ row_dict[(b[0], b[1])] = b
+ rows = list(row_dict.values())
+ resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True)
+ return self._inserts_and_updates(resp, on_conflict)
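Both new write helpers take an open cursor plus a batch of plain dicts, de-duplicate rows within the batch on the conflict key, and return an (inserts, updates) tuple. A sketch of the expected row shapes, mirroring how the persist worker below drives them; the `db_url` and field values are hypothetical:

    from sandcrawler.db import SandcrawlerPostgresClient

    db = SandcrawlerPostgresClient("postgres:///sandcrawler")
    cur = db.conn.cursor()

    # crossref rows: doi is the conflict key; indexed is optional; record is the full Crossref dict
    db.insert_crossref(
        cur,
        [dict(doi="10.1000/example.doi",
              indexed="2021-11-01T00:00:00Z",
              record={"DOI": "10.1000/example.doi"})],
    )

    # grobid_refs rows: (source, source_id) is the conflict key
    db.insert_grobid_refs(
        cur,
        [dict(source="crossref",
              source_id="10.1000/example.doi",
              source_ts="2021-11-01T00:00:00Z",
              refs_json={"refs": []})],
    )
    db.commit()
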
diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py
index f50b9d1..4c9d9d7 100644
--- a/python/sandcrawler/persist.py
+++ b/python/sandcrawler/persist.py
@@ -673,3 +673,48 @@ class PersistHtmlTeiXmlWorker(GenericPersistDocWorker):
self.s3_extension = kwargs.get("s3_extension", ".tei.xml")
self.s3_folder = kwargs.get("s3_folder", "html_body")
self.doc_key = "tei_xml"
+
+
+class PersistCrossrefWorker(SandcrawlerWorker):
+ def __init__(self, db_url: str, grobid_client: Optional[GrobidClient], **kwargs):
+ super().__init__(**kwargs)
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+ if grobid_client:
+ self.grobid_client = grobid_client
+ else:
+ self.grobid_client = GrobidClient()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ crossref_batch = []
+ refs_batch = []
+ for record in batch:
+ crossref_batch.append(
+ dict(
+ doi=record["DOI"].lower().strip(),
+ indexed=record["indexed"]["date-time"],
+ record=record,
+ )
+ )
+ refs_batch.append(self.grobid_client.crossref_refs(record))
+
+ resp = self.db.insert_crossref(self.cur, crossref_batch)
+ if len(crossref_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(crossref_batch)
+ self.counts["insert-crossref"] += resp[0]
+ self.counts["update-crossref"] += resp[1]
+
+ resp = self.db.insert_grobid_refs(self.cur, refs_batch)
+ if len(refs_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(refs_batch)
+ self.counts["insert-grobid_refs"] += resp[0]
+ self.counts["update-grobid_refs"] += resp[1]
+
+ self.db.commit()
+ return []
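PersistCrossrefWorker delegates reference extraction to GrobidClient.crossref_refs(), which is not part of this diff. Whatever that method returns is passed straight to insert_grobid_refs(), so it must at least carry the grobid_refs key columns. A hedged sketch of the assumed return shape, with hypothetical values; the exact contents of refs_json are an implementation detail of the GROBID transform:

    # assumed shape of grobid_client.crossref_refs(record)
    refs_row = {
        "source": "crossref",
        "source_id": "10.1000/example.doi",   # the record's DOI
        "source_ts": "2021-11-01T00:00:00Z",  # Crossref 'indexed' timestamp
        "refs_json": [
            {"index": 0, "biblio": {"title": "An Example Citation"}},
        ],
    }
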
diff --git a/python/sandcrawler_worker.py b/python/sandcrawler_worker.py
index d42cd8c..73bd444 100755
--- a/python/sandcrawler_worker.py
+++ b/python/sandcrawler_worker.py
@@ -12,7 +12,11 @@ import sys
import raven
from sandcrawler import *
-from sandcrawler.persist import PersistHtmlTeiXmlWorker, PersistXmlDocWorker
+from sandcrawler.persist import (
+ PersistCrossrefWorker,
+ PersistHtmlTeiXmlWorker,
+ PersistXmlDocWorker,
+)
# Yep, a global. Gets DSN from `SENTRY_DSN` environment variable
try:
@@ -291,6 +295,22 @@ def run_persist_ingest_file(args):
pusher.run()
+def run_persist_crossref(args):
+ grobid_client = GrobidClient(host_url=args.grobid_host)
+ consume_topic = "fatcat-{}.api-crossref".format(args.env)
+ worker = PersistCrossrefWorker(db_url=args.db_url, grobid_client=grobid_client)
+ pusher = KafkaJsonPusher(
+ worker=worker,
+ kafka_hosts=args.kafka_hosts,
+ consume_topic=consume_topic,
+        group="persist-crossref",
+        push_batches=True,
+        # small batch size because GROBID ref processing is slow
+ batch_size=20,
+ )
+ pusher.run()
+
+
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
@@ -302,7 +322,7 @@ def main():
"--env", default="dev", help="Kafka topic namespace to use (eg, prod, qa, dev)"
)
parser.add_argument(
- "--grobid-host", default="http://grobid.qa.fatcat.wiki", help="GROBID API host/port"
+ "--grobid-host", default="https://grobid.qa.fatcat.wiki", help="GROBID API host/port"
)
parser.add_argument(
"--db-url",
@@ -417,6 +437,15 @@ def main():
)
sub_persist_ingest_file.set_defaults(func=run_persist_ingest_file)
+ sub_persist_crossref = subparsers.add_parser(
+ "persist-crossref",
+        help="daemon that persists Crossref records to Postgres; also does the GROBID reference transform",
+ )
+ sub_persist_crossref.add_argument(
+ "--grobid-host", default="https://grobid.qa.fatcat.wiki", help="GROBID API host/port"
+ )
+ sub_persist_crossref.set_defaults(func=run_persist_crossref)
+
args = parser.parse_args()
if not args.__dict__.get("func"):
parser.print_help(file=sys.stderr)