Diffstat (limited to 'python/fatcat_tools/importers/common.py')
-rw-r--r-- python/fatcat_tools/importers/common.py | 146
1 file changed, 99 insertions(+), 47 deletions(-)
diff --git a/python/fatcat_tools/importers/common.py b/python/fatcat_tools/importers/common.py
index 0b68e5fe..fd472d11 100644
--- a/python/fatcat_tools/importers/common.py
+++ b/python/fatcat_tools/importers/common.py
@@ -7,7 +7,7 @@ import subprocess
import sys
import xml.etree.ElementTree as ET
from collections import Counter
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Sequence, Tuple
import elasticsearch
import fatcat_openapi_client
@@ -16,7 +16,14 @@ import fuzzycat.verify
import lxml
from bs4 import BeautifulSoup
from confluent_kafka import Consumer, KafkaException
-from fatcat_openapi_client import ReleaseEntity
+from fatcat_openapi_client import (
+ ApiClient,
+ ContainerEntity,
+ EntityEdit,
+ FileEntity,
+ FilesetEntity,
+ ReleaseEntity,
+)
from fatcat_openapi_client.rest import ApiException
from fuzzycat.matching import match_release_fuzzy
@@ -90,7 +97,7 @@ DOMAIN_REL_MAP: Dict[str, str] = {
}
-def make_rel_url(raw_url: str, default_link_rel: str = "web"):
+def make_rel_url(raw_url: str, default_link_rel: str = "web") -> Tuple[str, str]:
# this is where we map specific domains to rel types, and also filter out
# bad domains, invalid URLs, etc
rel = default_link_rel
@@ -101,7 +108,7 @@ def make_rel_url(raw_url: str, default_link_rel: str = "web"):
return (rel, raw_url)
-def test_make_rel_url():
+def test_make_rel_url() -> None:
assert make_rel_url("http://example.com/thing.pdf")[0] == "web"
assert make_rel_url("http://example.com/thing.pdf", default_link_rel="jeans")[0] == "jeans"
assert (
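For reference, a quick usage sketch of make_rel_url(), mirroring the assertions above (values are illustrative):

    # illustrative only; same inputs as the test assertions above
    (rel, url) = make_rel_url("http://example.com/thing.pdf")
    assert rel == "web"
    (rel, url) = make_rel_url("http://example.com/thing.pdf", default_link_rel="jeans")
    assert rel == "jeans"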
@@ -145,7 +152,7 @@ class EntityImporter:
implementors must write insert_batch appropriately
"""
- def __init__(self, api, **kwargs):
+ def __init__(self, api: ApiClient, **kwargs) -> None:
eg_extra = kwargs.get("editgroup_extra", dict())
eg_extra["git_rev"] = eg_extra.get(
@@ -212,7 +219,7 @@ class EntityImporter:
# implementations should fill this in
raise NotImplementedError
- def finish(self):
+ def finish(self) -> Counter:
"""
Gets called as cleanup at the end of imports, but can also be called at
any time to "snip off" current editgroup progress. In other words, safe
@@ -238,7 +245,7 @@ class EntityImporter:
return self.counts
- def get_editgroup_id(self, edits=1):
+ def get_editgroup_id(self, edits: int = 1) -> str:
if self._edit_count >= self.edit_batch_size:
if self.submit_mode:
self.api.submit_editgroup(self._editgroup_id)
@@ -257,30 +264,31 @@ class EntityImporter:
self._editgroup_id = eg.editgroup_id
self._edit_count += edits
+ assert self._editgroup_id
return self._editgroup_id
- def create_container(self, entity):
+ def create_container(self, entity: ContainerEntity) -> EntityEdit:
eg_id = self.get_editgroup_id()
self.counts["inserted.container"] += 1
return self.api.create_container(eg_id, entity)
- def create_release(self, entity):
+ def create_release(self, entity: ReleaseEntity) -> EntityEdit:
eg_id = self.get_editgroup_id()
self.counts["inserted.release"] += 1
return self.api.create_release(eg_id, entity)
- def create_file(self, entity):
+ def create_file(self, entity: FileEntity) -> EntityEdit:
eg_id = self.get_editgroup_id()
self.counts["inserted.file"] += 1
return self.api.create_file(eg_id, entity)
- def updated(self):
+ def updated(self) -> None:
"""
Implementations should call this from try_update() if the update was successful
"""
self.counts["update"] += 1
- def push_entity(self, entity):
+ def push_entity(self, entity: Any) -> None:
self._entity_queue.append(entity)
if len(self._entity_queue) >= self.edit_batch_size:
self.insert_batch(self._entity_queue)
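push_entity() and insert_batch() together implement simple client-side batching: entities accumulate in a queue, and a full queue triggers one batched insert. A minimal sketch of how a subclass might satisfy the contract, using only the create_release() helper shown above (the class name is hypothetical):

    # hypothetical subclass; the simplest insert_batch() just creates
    # entities one at a time, letting create_release() manage editgroups
    class MyReleaseImporter(EntityImporter):
        def insert_batch(self, batch: List[ReleaseEntity]) -> None:
            for entity in batch:
                self.create_release(entity)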
@@ -294,7 +302,7 @@ class EntityImporter:
"""
return True
- def try_update(self, raw_record):
+ def try_update(self, raw_record: Any) -> Optional[bool]:
"""
Passed the output of parse_record(). Should try to find an existing
entity and update it (PUT), decide we should do nothing (based on the
@@ -307,15 +315,17 @@ class EntityImporter:
"""
raise NotImplementedError
- def insert_batch(self, raw_records: List[Any]):
+ def insert_batch(self, raw_records: List[Any]) -> None:
raise NotImplementedError
def is_orcid(self, orcid: str) -> bool:
# TODO: replace with clean_orcid() from fatcat_tools.normal
return self._orcid_regex.match(orcid) is not None
- def lookup_orcid(self, orcid: str):
- """Caches calls to the Orcid lookup API endpoint in a local dict"""
+ def lookup_orcid(self, orcid: str) -> Optional[str]:
+ """Caches calls to the Orcid lookup API endpoint in a local dict.
+
+ Returns a creator fatcat ident if found, else None"""
if not self.is_orcid(orcid):
return None
if orcid in self._orcid_id_map:
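All of the lookup_*() helpers share this cache-then-API shape; the rest of lookup_orcid() falls outside this hunk, but a hedged sketch of the full pattern (the lookup_creator() endpoint call is an assumption here):

    # hedged sketch (cache hit, API lookup, cache store); lookup_creator()
    # is assumed, since the full body is truncated in this hunk
    if orcid in self._orcid_id_map:
        return self._orcid_id_map[orcid]
    creator_id = None
    try:
        creator_id = self.api.lookup_creator(orcid=orcid).ident
    except ApiException as ae:
        if ae.status != 404:
            raise ae
    self._orcid_id_map[orcid] = creator_id  # might be None
    return creator_id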
@@ -335,7 +345,7 @@ class EntityImporter:
# TODO: replace with clean_doi() from fatcat_tools.normal
return doi.startswith("10.") and doi.count("/") >= 1
- def lookup_doi(self, doi: str):
+ def lookup_doi(self, doi: str) -> Optional[str]:
"""Caches calls to the doi lookup API endpoint in a local dict
For identifier lookups only (not full object fetches)"""
@@ -354,7 +364,7 @@ class EntityImporter:
self._doi_id_map[doi] = release_id # might be None
return release_id
- def lookup_pmid(self, pmid: str):
+ def lookup_pmid(self, pmid: str) -> Optional[str]:
"""Caches calls to the pmid lookup API endpoint in a local dict
For identifier lookups only (not full object fetches)"""
@@ -374,7 +384,7 @@ class EntityImporter:
def is_issnl(self, issnl: str) -> bool:
return len(issnl) == 9 and issnl[4] == "-"
- def lookup_issnl(self, issnl: str):
+ def lookup_issnl(self, issnl: str) -> Optional[str]:
"""Caches calls to the ISSN-L lookup API endpoint in a local dict"""
if issnl in self._issnl_id_map:
return self._issnl_id_map[issnl]
@@ -389,7 +399,7 @@ class EntityImporter:
self._issnl_id_map[issnl] = container_id # might be None
return container_id
- def read_issn_map_file(self, issn_map_file):
+ def read_issn_map_file(self, issn_map_file: Sequence) -> None:
print("Loading ISSN map file...", file=sys.stderr)
self._issn_issnl_map = dict()
for line in issn_map_file:
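The loop body is truncated here; assuming the usual two-column, whitespace-separated "ISSN ISSN-L" mapping format, each iteration presumably looks something like:

    # assumed parsing; the file format is not shown in this hunk
    (issn, issnl) = line.split()[0:2]
    self._issn_issnl_map[issn] = issnl
    # self-mapping, so lookups by ISSN-L also resolve
    self._issn_issnl_map[issnl] = issnl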
@@ -407,7 +417,7 @@ class EntityImporter:
return self._issn_issnl_map.get(issn)
@staticmethod
- def generic_file_cleanups(existing):
+ def generic_file_cleanups(existing: FileEntity) -> FileEntity:
"""
Conservative cleanup of existing file entities.
@@ -453,7 +463,7 @@ class EntityImporter:
return existing
@staticmethod
- def generic_fileset_cleanups(existing):
+ def generic_fileset_cleanups(existing: FilesetEntity) -> FilesetEntity:
return existing
def match_existing_release_fuzzy(
@@ -520,10 +530,10 @@ class RecordPusher:
wraps an importer and pushes records in to it.
"""
- def __init__(self, importer, **kwargs):
+ def __init__(self, importer: EntityImporter, **kwargs) -> None:
self.importer = importer
- def run(self):
+ def run(self) -> Counter:
"""
This will look something like:
@@ -536,11 +546,11 @@ class RecordPusher:
class JsonLinePusher(RecordPusher):
- def __init__(self, importer, json_file, **kwargs):
+ def __init__(self, importer: EntityImporter, json_file: Sequence, **kwargs) -> None:
self.importer = importer
self.json_file = json_file
- def run(self):
+ def run(self) -> Counter:
for line in self.json_file:
if not line:
continue
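Typical driver code for the line-oriented pushers (the file name and importer instance are hypothetical):

    # hypothetical usage; any EntityImporter subclass works here
    with open("records.jsonl", "r") as f:
        counts = JsonLinePusher(my_importer, f).run()
    print(counts)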
@@ -552,11 +562,11 @@ class JsonLinePusher(RecordPusher):
class CsvPusher(RecordPusher):
- def __init__(self, importer, csv_file, **kwargs):
+ def __init__(self, importer: EntityImporter, csv_file: Any, **kwargs) -> None:
self.importer = importer
self.reader = csv.DictReader(csv_file, delimiter=kwargs.get("delimiter", ","))
- def run(self):
+ def run(self) -> Counter:
for line in self.reader:
if not line:
continue
@@ -567,11 +577,11 @@ class CsvPusher(RecordPusher):
class LinePusher(RecordPusher):
- def __init__(self, importer, text_file, **kwargs):
+ def __init__(self, importer: EntityImporter, text_file: Sequence, **kwargs) -> None:
self.importer = importer
self.text_file = text_file
- def run(self):
+ def run(self) -> Counter:
for line in self.text_file:
if not line:
continue
@@ -582,14 +592,21 @@ class LinePusher(RecordPusher):
class SqlitePusher(RecordPusher):
- def __init__(self, importer, db_file, table_name, where_clause="", **kwargs):
+ def __init__(
+ self,
+ importer: EntityImporter,
+ db_file: str,
+ table_name: str,
+ where_clause: str = "",
+ **kwargs
+ ) -> None:
self.importer = importer
self.db = sqlite3.connect(db_file, isolation_level="EXCLUSIVE")
self.db.row_factory = sqlite3.Row
self.table_name = table_name
self.where_clause = where_clause
- def run(self):
+ def run(self) -> Counter:
cur = self.db.execute("SELECT * FROM {} {};".format(self.table_name, self.where_clause))
for row in cur:
self.importer.push_record(row)
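Note that run() splices where_clause verbatim into the SELECT, so callers must include the WHERE keyword themselves. A usage sketch (database path and table name are hypothetical):

    # hypothetical usage; where_clause must carry its own "WHERE"
    pusher = SqlitePusher(my_importer, "dump.sqlite", "records", "WHERE year >= 2010")
    counts = pusher.run()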
@@ -599,12 +616,18 @@ class SqlitePusher(RecordPusher):
class Bs4XmlLinesPusher(RecordPusher):
- def __init__(self, importer, xml_file, prefix_filter=None, **kwargs):
+ def __init__(
+ self,
+ importer: EntityImporter,
+ xml_file: Sequence,
+ prefix_filter: Optional[str] = None,
+ **kwargs
+ ) -> None:
self.importer = importer
self.xml_file = xml_file
self.prefix_filter = prefix_filter
- def run(self):
+ def run(self) -> Counter:
for line in self.xml_file:
if not line:
continue
@@ -619,12 +642,14 @@ class Bs4XmlLinesPusher(RecordPusher):
class Bs4XmlFilePusher(RecordPusher):
- def __init__(self, importer, xml_file, record_tag, **kwargs):
+ def __init__(
+ self, importer: EntityImporter, xml_file: Any, record_tag: str, **kwargs
+ ) -> None:
self.importer = importer
self.xml_file = xml_file
self.record_tag = record_tag
- def run(self):
+ def run(self) -> Counter:
soup = BeautifulSoup(self.xml_file, "xml")
for record in soup.find_all(self.record_tag):
self.importer.push_record(record)
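Since BeautifulSoup parses the whole document up front, this pusher suits files that fit in memory; see Bs4XmlLargeFilePusher below for the streaming alternative. Hypothetical usage:

    # hypothetical usage; the whole XML file is parsed into memory
    with open("records.xml", "r") as f:
        counts = Bs4XmlFilePusher(my_importer, f, "record").run()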
@@ -654,13 +679,20 @@ class Bs4XmlLargeFilePusher(RecordPusher):
by inner container/release API lookup caches.
"""
- def __init__(self, importer, xml_file, record_tags, use_lxml=False, **kwargs):
+ def __init__(
+ self,
+ importer: EntityImporter,
+ xml_file: Any,
+ record_tags: List[str],
+ use_lxml: bool = False,
+ **kwargs
+ ) -> None:
self.importer = importer
self.xml_file = xml_file
self.record_tags = record_tags
self.use_lxml = use_lxml
- def run(self):
+ def run(self) -> Counter:
if self.use_lxml:
elem_iter = lxml.etree.iterparse(self.xml_file, ["start", "end"], load_dtd=True)
else:
@@ -691,12 +723,14 @@ class Bs4XmlLargeFilePusher(RecordPusher):
class Bs4XmlFileListPusher(RecordPusher):
- def __init__(self, importer, list_file, record_tag, **kwargs):
+ def __init__(
+ self, importer: EntityImporter, list_file: Sequence, record_tag: str, **kwargs
+ ) -> None:
self.importer = importer
self.list_file = list_file
self.record_tag = record_tag
- def run(self):
+ def run(self) -> Counter:
for xml_path in self.list_file:
xml_path = xml_path.strip()
if not xml_path or xml_path.startswith("#"):
@@ -717,7 +751,15 @@ class KafkaBs4XmlPusher(RecordPusher):
Fetch XML for an article from Kafka, parse via Bs4.
"""
- def __init__(self, importer, kafka_hosts, kafka_env, topic_suffix, group, **kwargs):
+ def __init__(
+ self,
+ importer: EntityImporter,
+ kafka_hosts: str,
+ kafka_env: str,
+ topic_suffix: str,
+ group: str,
+ **kwargs
+ ) -> None:
self.importer = importer
self.consumer = make_kafka_consumer(
kafka_hosts,
@@ -729,7 +771,7 @@ class KafkaBs4XmlPusher(RecordPusher):
self.poll_interval = kwargs.get("poll_interval", 5.0)
self.consume_batch_size = kwargs.get("consume_batch_size", 25)
- def run(self):
+ def run(self) -> Counter:
count = 0
last_push = datetime.datetime.now()
while True:
@@ -784,7 +826,15 @@ class KafkaBs4XmlPusher(RecordPusher):
class KafkaJsonPusher(RecordPusher):
- def __init__(self, importer, kafka_hosts, kafka_env, topic_suffix, group, **kwargs):
+ def __init__(
+ self,
+ importer: EntityImporter,
+ kafka_hosts: str,
+ kafka_env: str,
+ topic_suffix: str,
+ group: str,
+ **kwargs
+ ) -> None:
self.importer = importer
self.consumer = make_kafka_consumer(
kafka_hosts,
@@ -797,7 +847,7 @@ class KafkaJsonPusher(RecordPusher):
self.consume_batch_size = kwargs.get("consume_batch_size", 100)
self.force_flush = kwargs.get("force_flush", False)
- def run(self):
+ def run(self) -> Counter:
count = 0
last_push = datetime.datetime.now()
last_force_flush = datetime.datetime.now()
@@ -862,10 +912,12 @@ class KafkaJsonPusher(RecordPusher):
return counts
-def make_kafka_consumer(hosts, env, topic_suffix, group, kafka_namespace="fatcat"):
+def make_kafka_consumer(
+ hosts: str, env: str, topic_suffix: str, group: str, kafka_namespace: str = "fatcat"
+) -> Consumer:
topic_name = "{}-{}.{}".format(kafka_namespace, env, topic_suffix)
- def fail_fast(err, partitions):
+ def fail_fast(err: Any, partitions: List[Any]) -> None:
if err is not None:
print("Kafka consumer commit error: {}".format(err))
print("Bailing out...")
@@ -900,7 +952,7 @@ def make_kafka_consumer(hosts, env, topic_suffix, group, kafka_namespace="fatcat
},
}
- def on_rebalance(consumer, partitions):
+ def on_rebalance(consumer: Consumer, partitions: List[Any]) -> None:
for p in partitions:
if p.error:
raise KafkaException(p.error)
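Finally, a hedged sketch of wiring a Kafka pusher together; the broker host, environment, topic suffix, and consumer group below are all placeholders:

    # hypothetical wiring; all broker/topic/group values are placeholders
    pusher = KafkaJsonPusher(
        my_importer,
        "kafka-broker.example.org:9092",
        "prod",
        "api-datacite",
        "import-datacite",
    )
    counts = pusher.run()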