diff options
Diffstat (limited to 'python/fatcat_tools/importers/common.py')
-rw-r--r-- | python/fatcat_tools/importers/common.py | 146 |
1 files changed, 99 insertions, 47 deletions
diff --git a/python/fatcat_tools/importers/common.py b/python/fatcat_tools/importers/common.py index 0b68e5fe..fd472d11 100644 --- a/python/fatcat_tools/importers/common.py +++ b/python/fatcat_tools/importers/common.py @@ -7,7 +7,7 @@ import subprocess import sys import xml.etree.ElementTree as ET from collections import Counter -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import elasticsearch import fatcat_openapi_client @@ -16,7 +16,14 @@ import fuzzycat.verify import lxml from bs4 import BeautifulSoup from confluent_kafka import Consumer, KafkaException -from fatcat_openapi_client import ReleaseEntity +from fatcat_openapi_client import ( + ApiClient, + ContainerEntity, + EntityEdit, + FileEntity, + FilesetEntity, + ReleaseEntity, +) from fatcat_openapi_client.rest import ApiException from fuzzycat.matching import match_release_fuzzy @@ -90,7 +97,7 @@ DOMAIN_REL_MAP: Dict[str, str] = { } -def make_rel_url(raw_url: str, default_link_rel: str = "web"): +def make_rel_url(raw_url: str, default_link_rel: str = "web") -> Tuple[str, str]: # this is where we map specific domains to rel types, and also filter out # bad domains, invalid URLs, etc rel = default_link_rel @@ -101,7 +108,7 @@ def make_rel_url(raw_url: str, default_link_rel: str = "web"): return (rel, raw_url) -def test_make_rel_url(): +def test_make_rel_url() -> None: assert make_rel_url("http://example.com/thing.pdf")[0] == "web" assert make_rel_url("http://example.com/thing.pdf", default_link_rel="jeans")[0] == "jeans" assert ( @@ -145,7 +152,7 @@ class EntityImporter: implementors must write insert_batch appropriately """ - def __init__(self, api, **kwargs): + def __init__(self, api: ApiClient, **kwargs) -> None: eg_extra = kwargs.get("editgroup_extra", dict()) eg_extra["git_rev"] = eg_extra.get( @@ -212,7 +219,7 @@ class EntityImporter: # implementations should fill this in raise NotImplementedError - def finish(self): + def finish(self) -> Counter: """ Gets called as cleanup at the end of imports, but can also be called at any time to "snip off" current editgroup progress. In other words, safe @@ -238,7 +245,7 @@ class EntityImporter: return self.counts - def get_editgroup_id(self, edits=1): + def get_editgroup_id(self, edits: int = 1) -> str: if self._edit_count >= self.edit_batch_size: if self.submit_mode: self.api.submit_editgroup(self._editgroup_id) @@ -257,30 +264,31 @@ class EntityImporter: self._editgroup_id = eg.editgroup_id self._edit_count += edits + assert self._editgroup_id return self._editgroup_id - def create_container(self, entity): + def create_container(self, entity: ContainerEntity) -> EntityEdit: eg_id = self.get_editgroup_id() self.counts["inserted.container"] += 1 return self.api.create_container(eg_id, entity) - def create_release(self, entity): + def create_release(self, entity: ReleaseEntity) -> EntityEdit: eg_id = self.get_editgroup_id() self.counts["inserted.release"] += 1 return self.api.create_release(eg_id, entity) - def create_file(self, entity): + def create_file(self, entity: FileEntity) -> EntityEdit: eg_id = self.get_editgroup_id() self.counts["inserted.file"] += 1 return self.api.create_file(eg_id, entity) - def updated(self): + def updated(self) -> None: """ Implementations should call this from try_update() if the update was successful """ self.counts["update"] += 1 - def push_entity(self, entity): + def push_entity(self, entity: Any) -> None: self._entity_queue.append(entity) if len(self._entity_queue) >= self.edit_batch_size: self.insert_batch(self._entity_queue) @@ -294,7 +302,7 @@ class EntityImporter: """ return True - def try_update(self, raw_record): + def try_update(self, raw_record: Any) -> Optional[bool]: """ Passed the output of parse_record(). Should try to find an existing entity and update it (PUT), decide we should do nothing (based on the @@ -307,15 +315,17 @@ class EntityImporter: """ raise NotImplementedError - def insert_batch(self, raw_records: List[Any]): + def insert_batch(self, raw_records: List[Any]) -> None: raise NotImplementedError def is_orcid(self, orcid: str) -> bool: # TODO: replace with clean_orcid() from fatcat_tools.normal return self._orcid_regex.match(orcid) is not None - def lookup_orcid(self, orcid: str): - """Caches calls to the Orcid lookup API endpoint in a local dict""" + def lookup_orcid(self, orcid: str) -> Optional[str]: + """Caches calls to the Orcid lookup API endpoint in a local dict. + + Returns a creator fatcat ident if found, else None""" if not self.is_orcid(orcid): return None if orcid in self._orcid_id_map: @@ -335,7 +345,7 @@ class EntityImporter: # TODO: replace with clean_doi() from fatcat_tools.normal return doi.startswith("10.") and doi.count("/") >= 1 - def lookup_doi(self, doi: str): + def lookup_doi(self, doi: str) -> Optional[str]: """Caches calls to the doi lookup API endpoint in a local dict For identifier lookups only (not full object fetches)""" @@ -354,7 +364,7 @@ class EntityImporter: self._doi_id_map[doi] = release_id # might be None return release_id - def lookup_pmid(self, pmid: str): + def lookup_pmid(self, pmid: str) -> Optional[str]: """Caches calls to the pmid lookup API endpoint in a local dict For identifier lookups only (not full object fetches)""" @@ -374,7 +384,7 @@ class EntityImporter: def is_issnl(self, issnl: str) -> bool: return len(issnl) == 9 and issnl[4] == "-" - def lookup_issnl(self, issnl: str): + def lookup_issnl(self, issnl: str) -> Optional[str]: """Caches calls to the ISSN-L lookup API endpoint in a local dict""" if issnl in self._issnl_id_map: return self._issnl_id_map[issnl] @@ -389,7 +399,7 @@ class EntityImporter: self._issnl_id_map[issnl] = container_id # might be None return container_id - def read_issn_map_file(self, issn_map_file): + def read_issn_map_file(self, issn_map_file: Sequence) -> None: print("Loading ISSN map file...", file=sys.stderr) self._issn_issnl_map = dict() for line in issn_map_file: @@ -407,7 +417,7 @@ class EntityImporter: return self._issn_issnl_map.get(issn) @staticmethod - def generic_file_cleanups(existing): + def generic_file_cleanups(existing: FileEntity) -> FileEntity: """ Conservative cleanup of existing file entities. @@ -453,7 +463,7 @@ class EntityImporter: return existing @staticmethod - def generic_fileset_cleanups(existing): + def generic_fileset_cleanups(existing: FilesetEntity) -> FilesetEntity: return existing def match_existing_release_fuzzy( @@ -520,10 +530,10 @@ class RecordPusher: wraps an importer and pushes records in to it. """ - def __init__(self, importer, **kwargs): + def __init__(self, importer: EntityImporter, **kwargs) -> None: self.importer = importer - def run(self): + def run(self) -> Counter: """ This will look something like: @@ -536,11 +546,11 @@ class RecordPusher: class JsonLinePusher(RecordPusher): - def __init__(self, importer, json_file, **kwargs): + def __init__(self, importer: EntityImporter, json_file: Sequence, **kwargs) -> None: self.importer = importer self.json_file = json_file - def run(self): + def run(self) -> Counter: for line in self.json_file: if not line: continue @@ -552,11 +562,11 @@ class JsonLinePusher(RecordPusher): class CsvPusher(RecordPusher): - def __init__(self, importer, csv_file, **kwargs): + def __init__(self, importer: EntityImporter, csv_file: Any, **kwargs) -> None: self.importer = importer self.reader = csv.DictReader(csv_file, delimiter=kwargs.get("delimiter", ",")) - def run(self): + def run(self) -> Counter: for line in self.reader: if not line: continue @@ -567,11 +577,11 @@ class CsvPusher(RecordPusher): class LinePusher(RecordPusher): - def __init__(self, importer, text_file, **kwargs): + def __init__(self, importer: EntityImporter, text_file: Sequence, **kwargs) -> None: self.importer = importer self.text_file = text_file - def run(self): + def run(self) -> Counter: for line in self.text_file: if not line: continue @@ -582,14 +592,21 @@ class LinePusher(RecordPusher): class SqlitePusher(RecordPusher): - def __init__(self, importer, db_file, table_name, where_clause="", **kwargs): + def __init__( + self, + importer: EntityImporter, + db_file: str, + table_name: str, + where_clause: str = "", + **kwargs + ) -> None: self.importer = importer self.db = sqlite3.connect(db_file, isolation_level="EXCLUSIVE") self.db.row_factory = sqlite3.Row self.table_name = table_name self.where_clause = where_clause - def run(self): + def run(self) -> Counter: cur = self.db.execute("SELECT * FROM {} {};".format(self.table_name, self.where_clause)) for row in cur: self.importer.push_record(row) @@ -599,12 +616,18 @@ class SqlitePusher(RecordPusher): class Bs4XmlLinesPusher(RecordPusher): - def __init__(self, importer, xml_file, prefix_filter=None, **kwargs): + def __init__( + self, + importer: EntityImporter, + xml_file: Sequence, + prefix_filter: Optional[str] = None, + **kwargs + ) -> None: self.importer = importer self.xml_file = xml_file self.prefix_filter = prefix_filter - def run(self): + def run(self) -> Counter: for line in self.xml_file: if not line: continue @@ -619,12 +642,14 @@ class Bs4XmlLinesPusher(RecordPusher): class Bs4XmlFilePusher(RecordPusher): - def __init__(self, importer, xml_file, record_tag, **kwargs): + def __init__( + self, importer: EntityImporter, xml_file: Any, record_tag: str, **kwargs + ) -> None: self.importer = importer self.xml_file = xml_file self.record_tag = record_tag - def run(self): + def run(self) -> Counter: soup = BeautifulSoup(self.xml_file, "xml") for record in soup.find_all(self.record_tag): self.importer.push_record(record) @@ -654,13 +679,20 @@ class Bs4XmlLargeFilePusher(RecordPusher): by inner container/release API lookup caches. """ - def __init__(self, importer, xml_file, record_tags, use_lxml=False, **kwargs): + def __init__( + self, + importer: EntityImporter, + xml_file: Any, + record_tags: List[str], + use_lxml: bool = False, + **kwargs + ) -> None: self.importer = importer self.xml_file = xml_file self.record_tags = record_tags self.use_lxml = use_lxml - def run(self): + def run(self) -> Counter: if self.use_lxml: elem_iter = lxml.etree.iterparse(self.xml_file, ["start", "end"], load_dtd=True) else: @@ -691,12 +723,14 @@ class Bs4XmlLargeFilePusher(RecordPusher): class Bs4XmlFileListPusher(RecordPusher): - def __init__(self, importer, list_file, record_tag, **kwargs): + def __init__( + self, importer: EntityImporter, list_file: Sequence, record_tag: str, **kwargs + ) -> None: self.importer = importer self.list_file = list_file self.record_tag = record_tag - def run(self): + def run(self) -> Counter: for xml_path in self.list_file: xml_path = xml_path.strip() if not xml_path or xml_path.startswith("#"): @@ -717,7 +751,15 @@ class KafkaBs4XmlPusher(RecordPusher): Fetch XML for an article from Kafka, parse via Bs4. """ - def __init__(self, importer, kafka_hosts, kafka_env, topic_suffix, group, **kwargs): + def __init__( + self, + importer: EntityImporter, + kafka_hosts: str, + kafka_env: str, + topic_suffix: str, + group: str, + **kwargs + ) -> None: self.importer = importer self.consumer = make_kafka_consumer( kafka_hosts, @@ -729,7 +771,7 @@ class KafkaBs4XmlPusher(RecordPusher): self.poll_interval = kwargs.get("poll_interval", 5.0) self.consume_batch_size = kwargs.get("consume_batch_size", 25) - def run(self): + def run(self) -> Counter: count = 0 last_push = datetime.datetime.now() while True: @@ -784,7 +826,15 @@ class KafkaBs4XmlPusher(RecordPusher): class KafkaJsonPusher(RecordPusher): - def __init__(self, importer, kafka_hosts, kafka_env, topic_suffix, group, **kwargs): + def __init__( + self, + importer: EntityImporter, + kafka_hosts: str, + kafka_env: str, + topic_suffix: str, + group: str, + **kwargs + ) -> None: self.importer = importer self.consumer = make_kafka_consumer( kafka_hosts, @@ -797,7 +847,7 @@ class KafkaJsonPusher(RecordPusher): self.consume_batch_size = kwargs.get("consume_batch_size", 100) self.force_flush = kwargs.get("force_flush", False) - def run(self): + def run(self) -> Counter: count = 0 last_push = datetime.datetime.now() last_force_flush = datetime.datetime.now() @@ -862,10 +912,12 @@ class KafkaJsonPusher(RecordPusher): return counts -def make_kafka_consumer(hosts, env, topic_suffix, group, kafka_namespace="fatcat"): +def make_kafka_consumer( + hosts: str, env: str, topic_suffix: str, group: str, kafka_namespace: str = "fatcat" +) -> Consumer: topic_name = "{}-{}.{}".format(kafka_namespace, env, topic_suffix) - def fail_fast(err, partitions): + def fail_fast(err: Any, partitions: List[Any]) -> None: if err is not None: print("Kafka consumer commit error: {}".format(err)) print("Bailing out...") @@ -900,7 +952,7 @@ def make_kafka_consumer(hosts, env, topic_suffix, group, kafka_namespace="fatcat }, } - def on_rebalance(consumer, partitions): + def on_rebalance(consumer: Consumer, partitions: List[Any]) -> None: for p in partitions: if p.error: raise KafkaException(p.error) |