aboutsummaryrefslogtreecommitdiffstats
path: root/python/sandcrawler/persist.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/sandcrawler/persist.py')
-rw-r--r--python/sandcrawler/persist.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py
index f50b9d1..4c9d9d7 100644
--- a/python/sandcrawler/persist.py
+++ b/python/sandcrawler/persist.py
@@ -673,3 +673,48 @@ class PersistHtmlTeiXmlWorker(GenericPersistDocWorker):
self.s3_extension = kwargs.get("s3_extension", ".tei.xml")
self.s3_folder = kwargs.get("s3_folder", "html_body")
self.doc_key = "tei_xml"
+
+
+class PersistCrossrefWorker(SandcrawlerWorker):
+ def __init__(self, db_url: str, grobid_client: Optional[GrobidClient], **kwargs):
+ super().__init__(**kwargs)
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+ if grobid_client:
+ self.grobid_client = grobid_client
+ else:
+ self.grobid_client = GrobidClient()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ crossref_batch = []
+ refs_batch = []
+ for record in batch:
+ crossref_batch.append(
+ dict(
+ doi=record["DOI"].lower().strip(),
+ indexed=record["indexed"]["date-time"],
+ record=record,
+ )
+ )
+ refs_batch.append(self.grobid_client.crossref_refs(record))
+
+ resp = self.db.insert_crossref(self.cur, crossref_batch)
+ if len(crossref_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(crossref_batch)
+ self.counts["insert-crossref"] += resp[0]
+ self.counts["update-crossref"] += resp[1]
+
+ resp = self.db.insert_grobid_refs(self.cur, refs_batch)
+ if len(refs_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(refs_batch)
+ self.counts["insert-grobid_refs"] += resp[0]
+ self.counts["update-grobid_refs"] += resp[1]
+
+ self.db.commit()
+ return []