From 8e2fd41d10725b65c787ae56cb9320fbcc182288 Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Thu, 16 Apr 2020 18:34:45 -0700 Subject: try some type annotations --- python/fatcat_tools/importers/common.py | 67 ++++++++++++++++--------------- python/fatcat_tools/importers/crossref.py | 51 +++++++++++++---------- 2 files changed, 63 insertions(+), 55 deletions(-) (limited to 'python/fatcat_tools/importers') diff --git a/python/fatcat_tools/importers/common.py b/python/fatcat_tools/importers/common.py index 680b4f9c..2c4dd496 100644 --- a/python/fatcat_tools/importers/common.py +++ b/python/fatcat_tools/importers/common.py @@ -7,7 +7,7 @@ import sqlite3 import datetime import subprocess from collections import Counter -from typing import Optional, Tuple +from typing import Dict, Any, List, Optional import lxml import xml.etree.ElementTree as ET @@ -26,11 +26,12 @@ import fuzzycat.verify from fatcat_tools.normal import (clean_str as clean, is_cjk, b32_hex, LANG_MAP_MARC) # noqa: F401 from fatcat_tools.transforms import entity_to_dict -DATE_FMT = "%Y-%m-%d" -SANE_MAX_RELEASES = 200 -SANE_MAX_URLS = 100 -DOMAIN_REL_MAP = { +DATE_FMT: str = "%Y-%m-%d" +SANE_MAX_RELEASES: int = 200 +SANE_MAX_URLS: int = 100 + +DOMAIN_REL_MAP: Dict[str, str] = { "archive.org": "archive", # LOCKSS, Portico, DuraSpace, etc would also be "archive" @@ -94,7 +95,7 @@ DOMAIN_REL_MAP = { "archive.is": "webarchive", } -def make_rel_url(raw_url, default_link_rel="web"): +def make_rel_url(raw_url: str, default_link_rel: str = "web"): # this is where we map specific domains to rel types, and also filter out # bad domains, invalid URLs, etc rel = default_link_rel @@ -153,33 +154,33 @@ class EntityImporter: self.api = api self.do_updates = bool(kwargs.get('do_updates', True)) - self.do_fuzzy_match = kwargs.get('do_fuzzy_match', True) - self.bezerk_mode = kwargs.get('bezerk_mode', False) - self.submit_mode = kwargs.get('submit_mode', False) - self.edit_batch_size = kwargs.get('edit_batch_size', 100) - self.editgroup_description = kwargs.get('editgroup_description') - self.editgroup_extra = eg_extra + self.do_fuzzy_match: bool = kwargs.get('do_fuzzy_match', True) + self.bezerk_mode: bool = kwargs.get('bezerk_mode', False) + self.submit_mode: bool = kwargs.get('submit_mode', False) + self.edit_batch_size: int = kwargs.get('edit_batch_size', 100) + self.editgroup_description: Optional[str] = kwargs.get('editgroup_description') + self.editgroup_extra: Optional[Any] = eg_extra self.es_client = kwargs.get('es_client') if not self.es_client: self.es_client = elasticsearch.Elasticsearch("https://search.fatcat.wiki", timeout=120) - self._issnl_id_map = dict() - self._orcid_id_map = dict() + self._issnl_id_map: Dict[str, Any] = dict() + self._orcid_id_map: Dict[str, Any] = dict() self._orcid_regex = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$") - self._doi_id_map = dict() - self._pmid_id_map = dict() + self._doi_id_map: Dict[str, Any] = dict() + self._pmid_id_map: Dict[str, Any] = dict() self.reset() - def reset(self): + def reset(self) -> None: self.counts = Counter({'total': 0, 'skip': 0, 'insert': 0, 'update': 0, 'exists': 0}) - self._edit_count = 0 - self._editgroup_id = None - self._entity_queue = [] - self._edits_inflight = [] + self._edit_count: int = 0 + self._editgroup_id: Optional[str] = None + self._entity_queue: List[Any] = [] + self._edits_inflight: List[Any] = [] - def push_record(self, raw_record): + def push_record(self, raw_record: Any) -> None: """ Returns nothing. """ @@ -198,7 +199,7 @@ class EntityImporter: self.push_entity(entity) return - def parse_record(self, raw_record): + def parse_record(self, raw_record: Any) -> Optional[Any]: """ Returns an entity class type, or None if we should skip this one. @@ -282,7 +283,7 @@ class EntityImporter: self.counts['insert'] += len(self._entity_queue) self._entity_queue = [] - def want(self, raw_record): + def want(self, raw_record: Any) -> bool: """ Implementations can override for optional fast-path to drop a record. Must have no side-effects; returns bool. @@ -302,14 +303,14 @@ class EntityImporter: """ raise NotImplementedError - def insert_batch(self, raw_record): + def insert_batch(self, raw_records: List[Any]): raise NotImplementedError - def is_orcid(self, orcid): + def is_orcid(self, orcid: str) -> bool: # TODO: replace with clean_orcid() from fatcat_tools.normal return self._orcid_regex.match(orcid) is not None - def lookup_orcid(self, orcid): + def lookup_orcid(self, orcid: str): """Caches calls to the Orcid lookup API endpoint in a local dict""" if not self.is_orcid(orcid): return None @@ -326,11 +327,11 @@ class EntityImporter: self._orcid_id_map[orcid] = creator_id # might be None return creator_id - def is_doi(self, doi): + def is_doi(self, doi: str) -> bool: # TODO: replace with clean_doi() from fatcat_tools.normal return doi.startswith("10.") and doi.count("/") >= 1 - def lookup_doi(self, doi): + def lookup_doi(self, doi: str): """Caches calls to the doi lookup API endpoint in a local dict For identifier lookups only (not full object fetches)""" @@ -349,7 +350,7 @@ class EntityImporter: self._doi_id_map[doi] = release_id # might be None return release_id - def lookup_pmid(self, pmid): + def lookup_pmid(self, pmid: str): """Caches calls to the pmid lookup API endpoint in a local dict For identifier lookups only (not full object fetches)""" @@ -366,10 +367,10 @@ class EntityImporter: self._pmid_id_map[pmid] = release_id # might be None return release_id - def is_issnl(self, issnl): + def is_issnl(self, issnl: str) -> bool: return len(issnl) == 9 and issnl[4] == '-' - def lookup_issnl(self, issnl): + def lookup_issnl(self, issnl: str): """Caches calls to the ISSN-L lookup API endpoint in a local dict""" if issnl in self._issnl_id_map: return self._issnl_id_map[issnl] @@ -396,7 +397,7 @@ class EntityImporter: self._issn_issnl_map[issnl] = issnl print("Got {} ISSN-L mappings.".format(len(self._issn_issnl_map)), file=sys.stderr) - def issn2issnl(self, issn): + def issn2issnl(self, issn: str) -> Optional[str]: if issn is None: return None return self._issn_issnl_map.get(issn) diff --git a/python/fatcat_tools/importers/crossref.py b/python/fatcat_tools/importers/crossref.py index e77fa65e..d4b4a4c7 100644 --- a/python/fatcat_tools/importers/crossref.py +++ b/python/fatcat_tools/importers/crossref.py @@ -9,7 +9,7 @@ from .common import EntityImporter, clean # first # Can get a list of Crossref types (with counts) via API: # https://api.crossref.org/works?rows=0&facet=type-name:* -CROSSREF_TYPE_MAP = { +CROSSREF_TYPE_MAP: Dict[str, Optional[str]] = { 'book': 'book', 'book-chapter': 'chapter', 'book-part': 'chapter', @@ -30,7 +30,7 @@ CROSSREF_TYPE_MAP = { 'standard': 'standard', } -CONTAINER_TYPE_MAP = { +CONTAINER_TYPE_MAP: Dict[str, str] = { 'article-journal': 'journal', 'paper-conference': 'conference', 'book': 'book-series', @@ -41,7 +41,7 @@ CONTAINER_TYPE_MAP = { # popular are here; many were variants of the CC URLs. Would be useful to # normalize CC licenses better. # The current norm is to only add license slugs that are at least partially OA. -LICENSE_SLUG_MAP = { +LICENSE_SLUG_MAP: Dict[str, str] = { "//creativecommons.org/publicdomain/mark/1.0": "CC-0", "//creativecommons.org/publicdomain/mark/1.0/": "CC-0", "//creativecommons.org/publicdomain/mark/1.0/deed.de": "CC-0", @@ -87,7 +87,7 @@ LICENSE_SLUG_MAP = { "//arxiv.org/licenses/nonexclusive-distrib/1.0/": "ARXIV-1.0", } -def lookup_license_slug(raw): +def lookup_license_slug(raw: str) -> Optional[str]: if not raw: return None raw = raw.strip().replace('http://', '//').replace('https://', '//') @@ -121,9 +121,9 @@ class CrossrefImporter(EntityImporter): def __init__(self, api, issn_map_file, **kwargs): - eg_desc = kwargs.get('editgroup_description', + eg_desc: Optional[str] = kwargs.get('editgroup_description', "Automated import of Crossref DOI metadata, harvested from REST API") - eg_extra = kwargs.get('editgroup_extra', dict()) + eg_extra: Optional[dict] = kwargs.get('editgroup_extra', dict()) eg_extra['agent'] = eg_extra.get('agent', 'fatcat_tools.CrossrefImporter') super().__init__(api, issn_map_file=issn_map_file, @@ -131,9 +131,9 @@ class CrossrefImporter(EntityImporter): editgroup_extra=eg_extra, **kwargs) - self.create_containers = kwargs.get('create_containers', True) + self.create_containers: bool = kwargs.get('create_containers', True) extid_map_file = kwargs.get('extid_map_file') - self.extid_map_db = None + self.extid_map_db: Optional[Any] = None if extid_map_file: db_uri = "file:{}?mode=ro".format(extid_map_file) print("Using external ID map: {}".format(db_uri)) @@ -143,7 +143,7 @@ class CrossrefImporter(EntityImporter): self.read_issn_map_file(issn_map_file) - def lookup_ext_ids(self, doi): + def lookup_ext_ids(self, doi: str) -> Optional[Any]: if self.extid_map_db is None: return dict(core_id=None, pmid=None, pmcid=None, wikidata_qid=None, arxiv_id=None, jstor_id=None) row = self.extid_map_db.execute("SELECT core, pmid, pmcid, wikidata FROM ids WHERE doi=? LIMIT 1", @@ -161,20 +161,23 @@ class CrossrefImporter(EntityImporter): jstor_id=None, ) - def map_release_type(self, crossref_type): + def map_release_type(self, crossref_type: str) -> Optional[str]: return CROSSREF_TYPE_MAP.get(crossref_type) - def map_container_type(self, crossref_type): + def map_container_type(self, crossref_type: Optional[str]) -> Optional[str]: + if not crossref_type: + return None return CONTAINER_TYPE_MAP.get(crossref_type) - def want(self, obj): + def want(self, obj: Dict[str, Any]) -> bool: if not obj.get('title'): self.counts['skip-blank-title'] += 1 return False # these are pre-registered DOIs before the actual record is ready # title is a list of titles - if obj.get('title')[0].strip().lower() in [ + titles = obj.get('title') + if titles is not None and titles[0].strip().lower() in [ "OUP accepted manuscript".lower(), ]: self.counts['skip-stub-title'] += 1 @@ -183,7 +186,7 @@ class CrossrefImporter(EntityImporter): # do most of these checks in-line below return True - def parse_record(self, obj): + def parse_record(self, obj: Dict[str, Any]) -> Optional[ReleaseEntity]: """ obj is a python dict (parsed from json). returns a ReleaseEntity @@ -292,14 +295,15 @@ class CrossrefImporter(EntityImporter): refs = [] for i, rm in enumerate(obj.get('reference', [])): try: - year = int(rm.get('year')) + year: Optional[int] = int(rm.get('year')) # TODO: will need to update/config in the future! # NOTE: are there crossref works with year < 100? - if year > 2025 or year < 100: - year = None + if year is not None: + if year > 2025 or year < 100: + year = None except (TypeError, ValueError): year = None - ref_extra = dict() + ref_extra: Dict[str, Any] = dict() key = rm.get('key') if key and key.startswith(obj['DOI'].upper()): key = key.replace(obj['DOI'].upper() + "-", '') @@ -394,7 +398,7 @@ class CrossrefImporter(EntityImporter): release_stage = None # external identifiers - extids = self.lookup_ext_ids(doi=obj['DOI'].lower()) + extids: Dict[str, Any] = self.lookup_ext_ids(doi=obj['DOI'].lower()) # filter out unreasonably huge releases if len(abstracts) > 100: @@ -421,11 +425,14 @@ class CrossrefImporter(EntityImporter): release_year = raw_date[0] release_date = None - original_title = None + + original_title: Optional[str] = None if obj.get('original-title'): - original_title = clean(obj.get('original-title')[0], force_xml=True) + ot = obj.get('original-title') + if ot is not None: + original_title = clean(ot[0], force_xml=True) - title = None + title: Optional[str] = None if obj.get('title'): title = clean(obj.get('title')[0], force_xml=True) if not title or len(title) <= 1: -- cgit v1.2.3