diff options
Diffstat (limited to 'python')
-rwxr-xr-x | python/persist_tool.py | 22 | ||||
-rw-r--r-- | python/sandcrawler/persist.py | 40 |
2 files changed, 62 insertions, 0 deletions
diff --git a/python/persist_tool.py b/python/persist_tool.py index 069bef7..e08d66c 100755 --- a/python/persist_tool.py +++ b/python/persist_tool.py @@ -139,6 +139,18 @@ def run_crossref(args): pusher.run() +def run_grobid_refs(args): + worker = PersistGrobidRefsWorker( + db_url=args.db_url, + ) + pusher = JsonLinePusher( + worker, + args.json_file, + batch_size=100, + ) + pusher.run() + + def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( @@ -277,6 +289,16 @@ def main(): help="use GROBID to parse any unstructured references (default is to not)", ) + sub_grobid_refs = subparsers.add_parser( + "grobid-refs", help="backfill a grobid_refs JSON dump into postgresql" + ) + sub_grobid_refs.set_defaults(func=run_grobid_refs) + sub_grobid_refs.add_argument( + "json_file", + help="grobid_refs to import from (or '-' for stdin)", + type=argparse.FileType("r"), + ) + args = parser.parse_args() if not args.__dict__.get("func"): print("Tell me what to do!", file=sys.stderr) diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py index 6847e2e..d753380 100644 --- a/python/sandcrawler/persist.py +++ b/python/sandcrawler/persist.py @@ -678,6 +678,12 @@ class PersistHtmlTeiXmlWorker(GenericPersistDocWorker): class PersistCrossrefWorker(SandcrawlerWorker): + """ + Pushes Crossref API JSON records into postgresql. Can also talk to GROBID, + parsed 'unstructured' references, and push the results in to postgresql at + the same time. + """ + def __init__( self, db_url: str, @@ -739,3 +745,37 @@ class PersistCrossrefWorker(SandcrawlerWorker): self.db.commit() return [] + + +class PersistGrobidRefsWorker(SandcrawlerWorker): + """ + Simple persist worker to backfill GROBID references in to postgresql + locally. Consumes the JSON output from GROBID CrossrefRefsWorker. + """ + + def __init__(self, db_url: str, **kwargs): + super().__init__(**kwargs) + self.db = SandcrawlerPostgresClient(db_url) + self.cur = self.db.conn.cursor() + + def process(self, record: Any, key: Optional[str] = None) -> Any: + """Only do batches (as transactions)""" + raise NotImplementedError + + def push_batch(self, batch: list) -> list: + self.counts["total"] += len(batch) + + refs_batch = [] + for record in batch: + assert record["source"] + assert record["source_id"] + refs_batch.append(record) + + resp = self.db.insert_grobid_refs(self.cur, refs_batch) + if len(refs_batch) < len(batch): + self.counts["skip"] += len(batch) - len(refs_batch) + self.counts["insert-grobid_refs"] += resp[0] + self.counts["update-grobid_refs"] += resp[1] + + self.db.commit() + return [] |