Diffstat (limited to 'python/sandcrawler/db.py')
-rw-r--r--  python/sandcrawler/db.py  109
1 file changed, 106 insertions(+), 3 deletions(-)
diff --git a/python/sandcrawler/db.py b/python/sandcrawler/db.py
index 69d2116..101d419 100644
--- a/python/sandcrawler/db.py
+++ b/python/sandcrawler/db.py
@@ -89,7 +89,28 @@ class SandcrawlerPostgrestClient:
return None
def get_crossref(self, doi: str) -> Optional[dict]:
- resp = requests.get(self.api_url + "/crossref", params=dict(doi="eq." + doi))
+ resp = requests.get(self.api_url + "/crossref", params=dict(doi=f"eq.{doi}"))
+ resp.raise_for_status()
+ resp_json = resp.json()
+ if resp_json:
+ return resp_json[0]
+ else:
+ return None
+
+ def get_crossref_with_refs(self, doi: str) -> Optional[dict]:
+ resp = requests.get(self.api_url + "/crossref_with_refs", params=dict(doi=f"eq.{doi}"))
+ resp.raise_for_status()
+ resp_json = resp.json()
+ if resp_json:
+ return resp_json[0]
+ else:
+ return None
+
+ def get_grobid_refs(self, source: str, source_id: str) -> Optional[dict]:
+ resp = requests.get(
+ self.api_url + "/grobid_refs",
+ params=dict(source=f"eq.{source}", source_id=f"eq.{source_id}"),
+ )
resp.raise_for_status()
resp_json = resp.json()
if resp_json:
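All three getters in this hunk follow the same PostgREST convention: column filters are passed as `eq.`-prefixed query parameters, and a match comes back as a one-element JSON array (empty array means no row). A minimal usage sketch, assuming the client is constructed with the PostgREST base URL and stores it as `api_url` (the constructor and URL here are assumptions, not shown in this diff):

    # hypothetical base URL for a local PostgREST instance
    client = SandcrawlerPostgrestClient("http://localhost:3030")

    # equality filter on one column: GET /crossref_with_refs?doi=eq.<doi>
    record = client.get_crossref_with_refs("10.1000/xyz123")
    if record is None:
        print("no crossref row for that DOI")

    # composite lookup: GET /grobid_refs?source=eq.<s>&source_id=eq.<id>
    refs = client.get_grobid_refs("crossref", "10.1000/xyz123")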
@@ -230,6 +251,7 @@ class SandcrawlerPostgresClient:
r[k] = r["metadata"].get(k)
r["metadata"].pop(k, None)
r["metadata"] = json.dumps(r["metadata"], sort_keys=True)
+ now = datetime.datetime.now()
rows = [
(
d["key"],
@@ -237,7 +259,7 @@ class SandcrawlerPostgresClient:
d["status_code"],
d["status"],
d.get("fatcat_release") or None,
- d.get("updated") or datetime.datetime.now(),
+ d.get("updated") or now,
d.get("metadata") or None,
)
for d in batch
@@ -356,10 +378,11 @@ class SandcrawlerPostgresClient:
else:
raise NotImplementedError("on_conflict: {}".format(on_conflict))
sql += " RETURNING xmax;"
+ now = datetime.datetime.now()
rows = [
(
d["key"],
- d.get("updated") or datetime.datetime.now(),
+ d.get("updated") or now,
d["status_code"],
d["status"],
d.get("versions", {}).get("pdftrio_version") or None,
@@ -533,3 +556,83 @@ class SandcrawlerPostgresClient:
rows = list(row_dict.values())
resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True)
return self._inserts_and_updates(resp, on_conflict)
+
+ def insert_crossref(
+ self,
+ cur: psycopg2.extensions.cursor,
+ batch: List[Dict[str, Any]],
+ on_conflict: str = "update",
+ ) -> Tuple[int, int]:
+ sql = """
+ INSERT INTO
+ crossref (doi, indexed, record)
+ VALUES %s
+ ON CONFLICT (doi) DO
+ """
+ if on_conflict.lower() == "nothing":
+ sql += " NOTHING"
+ elif on_conflict.lower() == "update":
+ sql += """ UPDATE SET
+ indexed=EXCLUDED.indexed,
+ record=EXCLUDED.record
+ """
+ else:
+ raise NotImplementedError("on_conflict: {}".format(on_conflict))
+ sql += " RETURNING xmax;"
+ rows = [
+ (
+ d["doi"],
+ d.get("indexed") or None,
+ json.dumps(d["record"], sort_keys=True),
+ )
+ for d in batch
+ ]
+ # filter out duplicate rows by key (doi)
+ row_dict = dict()
+ for b in rows:
+ row_dict[b[0]] = b
+ rows = list(row_dict.values())
+ resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True)
+ return self._inserts_and_updates(resp, on_conflict)
+
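A usage sketch for the new crossref batch upsert; the connection setup is an assumption (the client constructor and its `conn` attribute are not shown in this diff):

    import datetime

    db = SandcrawlerPostgresClient("postgres:///sandcrawler")  # hypothetical DSN
    with db.conn.cursor() as cur:
        inserts, updates = db.insert_crossref(
            cur,
            [
                {
                    "doi": "10.1000/xyz123",
                    "indexed": datetime.datetime(2021, 11, 1),
                    "record": {"DOI": "10.1000/xyz123", "title": ["Example"]},
                },
            ],
            on_conflict="update",
        )
    db.conn.commit()
    print(f"{inserts} inserted, {updates} updated")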
+ def insert_grobid_refs(
+ self,
+ cur: psycopg2.extensions.cursor,
+ batch: List[Dict[str, Any]],
+ on_conflict: str = "update",
+ ) -> Tuple[int, int]:
+ sql = """
+ INSERT INTO
+ grobid_refs (source, source_id, source_ts, updated, refs_json)
+ VALUES %s
+ ON CONFLICT (source, source_id) DO
+ """
+ if on_conflict.lower() == "nothing":
+ sql += " NOTHING"
+ elif on_conflict.lower() == "update":
+ sql += """ UPDATE SET
+ source_ts=EXCLUDED.source_ts,
+ updated=EXCLUDED.updated,
+ refs_json=EXCLUDED.refs_json
+ """
+ else:
+ raise NotImplementedError("on_conflict: {}".format(on_conflict))
+ sql += " RETURNING xmax;"
+ now = datetime.datetime.now()
+ rows = [
+ (
+ d["source"],
+ d["source_id"],
+ d.get("source_ts") or None,
+ d.get("updated") or now,
+ json.dumps(d["refs_json"], sort_keys=True),
+ )
+ for d in batch
+ ]
+ # filter out duplicate rows by key (source, source_id)
+ row_dict = dict()
+ for b in rows:
+ row_dict[(b[0], b[1])] = b
+ rows = list(row_dict.values())
+ resp = psycopg2.extras.execute_values(cur, sql, rows, page_size=250, fetch=True)
+ return self._inserts_and_updates(resp, on_conflict)
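Note the in-batch dedupe in both insert helpers: rows are keyed into a dict before the INSERT, so the last occurrence of a key wins. This matters because Postgres rejects an INSERT ... ON CONFLICT DO UPDATE that touches the same row twice in one statement. A self-contained illustration of the (source, source_id) case:

    rows = [
        ("grobid", "release_1", "refs-v1"),
        ("grobid", "release_1", "refs-v2"),  # duplicate composite key
        ("grobid", "release_2", "refs-v1"),
    ]
    row_dict = dict()
    for b in rows:
        row_dict[(b[0], b[1])] = b  # later rows overwrite earlier ones
    rows = list(row_dict.values())
    assert rows == [
        ("grobid", "release_1", "refs-v2"),
        ("grobid", "release_2", "refs-v1"),
    ]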