aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xpython/persist_tool.py22
-rw-r--r--python/sandcrawler/persist.py40
2 files changed, 62 insertions, 0 deletions
diff --git a/python/persist_tool.py b/python/persist_tool.py
index 069bef7..e08d66c 100755
--- a/python/persist_tool.py
+++ b/python/persist_tool.py
@@ -139,6 +139,18 @@ def run_crossref(args):
pusher.run()
+def run_grobid_refs(args):
+ worker = PersistGrobidRefsWorker(
+ db_url=args.db_url,
+ )
+ pusher = JsonLinePusher(
+ worker,
+ args.json_file,
+ batch_size=100,
+ )
+ pusher.run()
+
+
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
@@ -277,6 +289,16 @@ def main():
help="use GROBID to parse any unstructured references (default is to not)",
)
+ sub_grobid_refs = subparsers.add_parser(
+ "grobid-refs", help="backfill a grobid_refs JSON dump into postgresql"
+ )
+ sub_grobid_refs.set_defaults(func=run_grobid_refs)
+ sub_grobid_refs.add_argument(
+ "json_file",
+ help="grobid_refs to import from (or '-' for stdin)",
+ type=argparse.FileType("r"),
+ )
+
args = parser.parse_args()
if not args.__dict__.get("func"):
print("Tell me what to do!", file=sys.stderr)
diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py
index 6847e2e..d753380 100644
--- a/python/sandcrawler/persist.py
+++ b/python/sandcrawler/persist.py
@@ -678,6 +678,12 @@ class PersistHtmlTeiXmlWorker(GenericPersistDocWorker):
class PersistCrossrefWorker(SandcrawlerWorker):
+ """
+ Pushes Crossref API JSON records into postgresql. Can also talk to GROBID,
+    parse 'unstructured' references, and push the results into postgresql at
+ the same time.
+ """
+
def __init__(
self,
db_url: str,
@@ -739,3 +745,37 @@ class PersistCrossrefWorker(SandcrawlerWorker):
self.db.commit()
return []
+
+
+class PersistGrobidRefsWorker(SandcrawlerWorker):
+ """
+    Simple persist worker to backfill GROBID references into postgresql
+ locally. Consumes the JSON output from GROBID CrossrefRefsWorker.
+ """
+
+ def __init__(self, db_url: str, **kwargs):
+ super().__init__(**kwargs)
+ self.db = SandcrawlerPostgresClient(db_url)
+ self.cur = self.db.conn.cursor()
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ """Only do batches (as transactions)"""
+ raise NotImplementedError
+
+ def push_batch(self, batch: list) -> list:
+ self.counts["total"] += len(batch)
+
+ refs_batch = []
+ for record in batch:
+ assert record["source"]
+ assert record["source_id"]
+ refs_batch.append(record)
+
+ resp = self.db.insert_grobid_refs(self.cur, refs_batch)
+ if len(refs_batch) < len(batch):
+ self.counts["skip"] += len(batch) - len(refs_batch)
+ self.counts["insert-grobid_refs"] += resp[0]
+ self.counts["update-grobid_refs"] += resp[1]
+
+ self.db.commit()
+ return []