"""
Shared helpers for fatcat importers: string cleaning, URL "rel" mapping, the
EntityImporter base class, and the RecordPusher family which feeds records
from various sources (JSON lines, CSV, sqlite, Kafka) into an importer.
"""

import re
import sys
import csv
import json
import ftfy
import sqlite3
import itertools
import subprocess
from collections import Counter
import pykafka

import fatcat_client
from fatcat_client.rest import ApiException


def clean(thing, force_xml=False):
    """
    This function is appropriate to be called on any random, non-markup
    string, such as author names, titles, etc.

    It will try to clean up common unicode mangles, HTML characters, etc.

    This will detect XML/HTML and "do the right thing" (aka, not remove
    entities like '&amp;' if there are tags in the string), unless you pass
    the 'force_xml' parameter, which might be appropriate for, eg, names and
    titles, which generally should be projected down to plain text.

    Also strips extra whitespace.

    Returns the cleaned string, or None if the input is falsy or if one
    character or less remains after cleaning.
    """
    if not thing:
        return None
    # 'auto' lets ftfy decide whether to decode entities; True forces decoding
    fix_entities = True if force_xml else 'auto'
    fixed = ftfy.fix_text(thing, fix_entities=fix_entities).strip()
    if not fixed or len(fixed) <= 1:
        # wasn't zero-length before, but is now (or is a single character);
        # return None
        return None
    return fixed


def test_clean():
    assert clean(None) is None
    assert clean('') is None
    assert clean('1') is None
    assert clean('123') == '123'
    assert clean('a&b') == 'a&b'
    assert clean('a&b', force_xml=True) == 'a&b'
    # entities are decoded in plain text, but left alone when markup is present
    assert clean('a&amp;b') == 'a&b'
    assert clean('<b>a&amp;b</b>') == '<b>a&amp;b</b>'
    assert clean('<b>a&amp;b</b>', force_xml=True) == '<b>a&b</b>'


# Maps a URL's host to the "rel" type recorded for links to that host.
# Insertion order matters only in that the first matching host wins.
DOMAIN_REL_MAP = {
    "archive.org": "repository",
    "arxiv.org": "repository",
    "babel.hathitrust.org": "repository",
    "cds.cern.ch": "repository",
    "citeseerx.ist.psu.edu": "repository",
    "deepblue.lib.umich.edu": "repository",
    "europepmc.org": "repository",
    "hal.inria.fr": "repository",
    "scielo.isciii.es": "repository",
    "www.dtic.mil": "repository",
    "www.jstage.jst.go.jp": "repository",
    "www.jstor.org": "repository",
    "www.ncbi.nlm.nih.gov": "repository",
    "www.scielo.br": "repository",
    "www.scielo.cl": "repository",
    "www.scielo.org.mx": "repository",
    "zenodo.org": "repository",

    "academic.oup.com": "publisher",
    "cdn.elifesciences.org": "publisher",
    "cell.com": "publisher",
    "dl.acm.org": "publisher",
    "downloads.hindawi.com": "publisher",
    "elifesciences.org": "publisher",
    "iopscience.iop.org": "publisher",
    "journals.plos.org": "publisher",
    "link.springer.com": "publisher",
    "onlinelibrary.wiley.com": "publisher",
    "works.bepress.com": "publisher",
    "www.biomedcentral.com": "publisher",
    "www.cell.com": "publisher",
    "www.nature.com": "publisher",
    "www.pnas.org": "publisher",
    "www.tandfonline.com": "publisher",

    "www.researchgate.net": "social",
    "academia.edu": "social",

    "wayback.archive-it.org": "webarchive",
    "web.archive.org": "webarchive",
    "archive.is": "webarchive",
}


def make_rel_url(raw_url, default_link_rel="web"):
    """
    Maps a URL to a (rel, url) tuple based on DOMAIN_REL_MAP.

    The host must appear in the URL as "//<host>/"; if no known host matches,
    'default_link_rel' is used. The URL itself is returned unchanged.
    """
    # this is where we map specific domains to rel types, and also filter out
    # bad domains, invalid URLs, etc
    rel = default_link_rel
    for domain, domain_rel in DOMAIN_REL_MAP.items():
        if "//{}/".format(domain) in raw_url:
            rel = domain_rel
            break
    return (rel, raw_url)


def test_make_rel_url():
    assert make_rel_url("http://example.com/thing.pdf")[0] == "web"
    assert make_rel_url("http://example.com/thing.pdf", default_link_rel="jeans")[0] == "jeans"
    assert make_rel_url("https://web.archive.org/web/*/http://example.com/thing.pdf")[0] == "webarchive"
    assert make_rel_url("http://cell.com/thing.pdf")[0] == "publisher"


class EntityImporter:
    """
    Base class for fatcat entity importers.

    The API exposed to record iterator is:

        push_record(raw_record)
        finish()

    The API that implementations are expected to fill in are:

        want(raw_record) -> boolean
        parse(raw_record) -> entity
            (or parse_record(raw_record), which is what push_record() calls
            and which delegates to parse() by default)
        try_update(entity) -> boolean
        insert_batch([entity]) -> None

    This class exposes helpers for implementations:

        self.api
        self.get_editgroup_id()
        self.create_<entity>(entity) -> EntityEdit
            for related entity types
        self.push_entity(entity)
        self.counts['exists'] += 1
            (if didn't update or insert because of existing)
        self.counts['update'] += 1
            (if updated an entity)
    """

    def __init__(self, api, **kwargs):
        """
        Recognized keyword arguments: 'editgroup_extra' (dict),
        'editgroup_description', 'bezerk_mode' (default False) and
        'edit_batch_size' (default 100).
        """
        # Work on a copy so the caller's dict is not mutated; tolerate None.
        eg_extra = dict(kwargs.get('editgroup_extra') or {})
        git_rev = eg_extra.get('git_rev')
        if git_rev is None:
            # only shell out to git when the caller did not supply a revision
            git_rev = subprocess.check_output(
                ["git", "describe", "--always"]).strip()
        if isinstance(git_rev, bytes):
            # check_output() returns bytes; a caller-supplied value may
            # already be str
            git_rev = git_rev.decode('utf-8')
        eg_extra['git_rev'] = git_rev
        eg_extra.setdefault('agent', 'fatcat_tools.EntityImporter')

        self.api = api
        self.bezerk_mode = kwargs.get('bezerk_mode', False)
        self.edit_batch_size = kwargs.get('edit_batch_size', 100)
        self.editgroup_description = kwargs.get('editgroup_description')
        # store the populated dict, so git_rev/agent are kept even when no
        # 'editgroup_extra' was passed in
        self.editgroup_extra = eg_extra
        self.reset()

        # per-instance caches for the lookup_*() helpers
        self._issnl_id_map = dict()
        self._orcid_id_map = dict()
        self._orcid_regex = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$")
        self._doi_id_map = dict()

    def reset(self):
        """Resets the counters, the current editgroup, and the entity queue."""
        self.counts = Counter({'skip': 0, 'insert': 0, 'update': 0, 'exists': 0})
        self._edit_count = 0
        self._editgroup_id = None
        self._entity_queue = []

    def push_record(self, raw_record):
        """
        Filters (want), parses (parse_record), and then either queues the
        resulting entity for batch insert or hands it to try_update().

        Returns nothing.
        """
        if (not raw_record) or (not self.want(raw_record)):
            self.counts['skip'] += 1
            return
        entity = self.parse_record(raw_record)
        if not entity:
            self.counts['skip'] += 1
            return
        if self.bezerk_mode:
            # bezerk mode: queue for insert without consulting try_update()
            self.push_entity(entity)
            return
        if self.try_update(entity):
            self.push_entity(entity)
        return

    def parse_record(self, raw_record):
        """
        Called by push_record(). Implementations may override this directly;
        by default it delegates to parse(), which is the hook named in the
        class docstring.
        """
        return self.parse(raw_record)

    def _flush_entity_queue(self):
        """Inserts any queued entities as one batch and counts them."""
        if self._entity_queue:
            self.insert_batch(self._entity_queue)
            self.counts['insert'] += len(self._entity_queue)
            self._entity_queue = []

    def finish(self):
        """
        Accepts any open editgroup, flushes the entity queue, and returns the
        counts (with a 'total' over skip/insert/update/exists added).
        """
        if self._edit_count > 0:
            self.api.accept_editgroup(self._editgroup_id)
            self._editgroup_id = None
            self._edit_count = 0

        self._flush_entity_queue()

        self.counts['total'] = sum(
            self.counts[key] for key in ('skip', 'insert', 'update', 'exists'))
        return self.counts

    def get_editgroup_id(self, edits=1):
        """
        Returns the id of the current editgroup, creating one if needed.

        If the current editgroup has already reached edit_batch_size edits it
        is accepted first and a new one is created. 'edits' is the number of
        edits the caller is about to make in the returned editgroup.
        """
        if self._edit_count >= self.edit_batch_size:
            self.api.accept_editgroup(self._editgroup_id)
            self._editgroup_id = None
            self._edit_count = 0

        if not self._editgroup_id:
            eg = self.api.create_editgroup(
                fatcat_client.Editgroup(
                    description=self.editgroup_description,
                    extra=self.editgroup_extra))
            self._editgroup_id = eg.editgroup_id

        self._edit_count += edits
        return self._editgroup_id

    def create_container(self, entity):
        """Creates a container in the current editgroup; returns the API result."""
        eg_id = self.get_editgroup_id()
        self.counts['inserted.container'] += 1
        return self.api.create_container(entity, editgroup_id=eg_id)

    def create_release(self, entity):
        """Creates a release in the current editgroup; returns the API result."""
        eg_id = self.get_editgroup_id()
        self.counts['inserted.release'] += 1
        return self.api.create_release(entity, editgroup_id=eg_id)

    def create_file(self, entity):
        """Creates a file in the current editgroup; returns the API result."""
        eg_id = self.get_editgroup_id()
        self.counts['inserted.file'] += 1
        return self.api.create_file(entity, editgroup_id=eg_id)

    def updated(self):
        """
        Implementations should call this from try_update() if the update was successful
        """
        self.counts['update'] += 1

    def push_entity(self, entity):
        """Queues an entity; the queue is inserted once it reaches edit_batch_size."""
        self._entity_queue.append(entity)
        if len(self._entity_queue) >= self.edit_batch_size:
            self._flush_entity_queue()

    def want(self, raw_record):
        """
        Implementations can override for optional fast-path to drop a record.
        Must have no side-effects; returns bool.
        """
        return True

    def parse(self, raw_record):
        """
        Returns an entity class type, or None if we should skip this one.

        May have side-effects (eg, create related entities), but shouldn't
        update/mutate the actual entity.
        """
        raise NotImplementedError

    def try_update(self, raw_record):
        """
        Passed the output of parse(). Should try to find an existing entity and
        update it (PUT), decide we should do nothing (based on the existing
        record), or create a new one.

        Implementations must update the exists/updated/skip counts
        appropriately in this method.

        Returns boolean: True if the entity should still be inserted, False otherwise
        """
        raise NotImplementedError

    def insert_batch(self, raw_record):
        """
        Passed the list of queued entities (despite the parameter name);
        implementations should insert all of them.
        """
        raise NotImplementedError

    def _cached_lookup(self, cache, key, lookup):
        """
        Returns the ident for 'key', consulting 'cache' before calling
        lookup(key). A 404 from the API is cached as None; any other
        ApiException is re-raised.
        """
        if key in cache:
            return cache[key]
        ident = None
        try:
            ident = lookup(key).ident
        except ApiException as ae:
            # If anything other than a 404 (not found), something is wrong.
            # Raise explicitly: an assert would be stripped under -O.
            if ae.status != 404:
                raise
        cache[key] = ident  # might be None
        return ident

    def is_orcid(self, orcid):
        """True if 'orcid' matches the dddd-dddd-dddd-ddd[dX] pattern."""
        return self._orcid_regex.match(orcid) is not None

    def lookup_orcid(self, orcid):
        """Caches calls to the Orcid lookup API endpoint in a local dict"""
        if not self.is_orcid(orcid):
            return None
        return self._cached_lookup(
            self._orcid_id_map, orcid,
            lambda key: self.api.lookup_creator(orcid=key))

    def is_doi(self, doi):
        """Cheap syntactic check: starts with "10." and contains a slash."""
        return doi.startswith("10.") and doi.count("/") >= 1

    def lookup_doi(self, doi):
        """Caches calls to the doi lookup API endpoint in a local dict"""
        assert self.is_doi(doi)
        # lookups (and the cache) are keyed on the lower-cased DOI
        doi = doi.lower()
        return self._cached_lookup(
            self._doi_id_map, doi,
            lambda key: self.api.lookup_release(doi=key))

    def is_issnl(self, issnl):
        """Cheap syntactic check: nine characters with a dash in the middle."""
        return len(issnl) == 9 and issnl[4] == '-'

    def lookup_issnl(self, issnl):
        """Caches calls to the ISSN-L lookup API endpoint in a local dict"""
        return self._cached_lookup(
            self._issnl_id_map, issnl,
            lambda key: self.api.lookup_container(issnl=key))

    def read_issn_map_file(self, issn_map_file):
        """
        Loads an ISSN to ISSN-L mapping from an iterable of text lines.

        The first two whitespace-separated fields of each line are taken as
        (issn, issnl). Lines starting with "ISSN" and blank lines are skipped.
        """
        print("Loading ISSN map file...")
        self._issn_issnl_map = dict()
        for line in issn_map_file:
            # lines read from a file keep their newline, so test the stripped
            # text to actually catch blank lines
            if line.startswith("ISSN") or not line.strip():
                continue
            (issn, issnl) = line.split()[0:2]
            self._issn_issnl_map[issn] = issnl
            # double mapping makes lookups easy
            self._issn_issnl_map[issnl] = issnl
        print("Got {} ISSN-L mappings.".format(len(self._issn_issnl_map)))

    def issn2issnl(self, issn):
        """
        Returns the ISSN-L for an ISSN, or None if unknown. Requires that
        read_issn_map_file() has been called first.
        """
        if issn is None:
            return None
        return self._issn_issnl_map.get(issn)


class RecordPusher:
    """
    Base class for different importer sources. Pretty trivial interface, just
    wraps an importer and pushes records in to it.
    """

    def __init__(self, importer, **kwargs):
        self.importer = importer

    def run(self):
        """
        This will look something like:

            for line in sys.stdin:
                record = json.loads(line)
                self.importer.push_record(record)
            print(self.importer.finish())
        """
        raise NotImplementedError


class JsonLinePusher(RecordPusher):
    """Pushes one parsed JSON object per non-blank line of 'json_file'."""

    def __init__(self, importer, json_file, **kwargs):
        super().__init__(importer, **kwargs)
        self.json_file = json_file

    def run(self):
        for line in self.json_file:
            # lines keep their trailing newline; skip whitespace-only lines
            # instead of passing them to json.loads()
            if not line.strip():
                continue
            record = json.loads(line)
            self.importer.push_record(record)
        counts = self.importer.finish()
        print(counts)
        return counts


class CsvPusher(RecordPusher):
    """
    Pushes each row of 'csv_file' as a dict (via csv.DictReader). Accepts an
    optional 'delimiter' keyword argument (default ',').
    """

    def __init__(self, importer, csv_file, **kwargs):
        super().__init__(importer, **kwargs)
        self.reader = csv.DictReader(csv_file, delimiter=kwargs.get('delimiter', ','))

    def run(self):
        for line in self.reader:
            if not line:
                continue
            self.importer.push_record(line)
        counts = self.importer.finish()
        print(counts)
        return counts


class LinePusher(RecordPusher):
    """Pushes each raw line of 'text_file' (newline included) as a record."""

    def __init__(self, importer, text_file, **kwargs):
        super().__init__(importer, **kwargs)
        self.text_file = text_file

    def run(self):
        for line in self.text_file:
            if not line:
                continue
            self.importer.push_record(line)
        counts = self.importer.finish()
        print(counts)
        return counts


class SqlitePusher(RecordPusher):
    """Pushes each row (as sqlite3.Row) of a table in a sqlite database file."""

    def __init__(self, importer, db_file, table_name, where_clause="", **kwargs):
        super().__init__(importer, **kwargs)
        self.db = sqlite3.connect(db_file, isolation_level='EXCLUSIVE')
        self.db.row_factory = sqlite3.Row
        self.table_name = table_name
        self.where_clause = where_clause

    def run(self):
        # NOTE(review): the statement is built by string formatting (table
        # names can not be bound as SQL parameters), so table_name and
        # where_clause must only ever come from trusted operator input.
        cur = self.db.execute("SELECT * FROM {} {};".format(
            self.table_name, self.where_clause))
        for row in cur:
            self.importer.push_record(row)
        counts = self.importer.finish()
        print(counts)
        return counts


class KafkaJsonPusher(RecordPusher):
    """Pushes JSON records consumed from a Kafka topic (see make_kafka_consumer)."""

    def __init__(self, importer, kafka_hosts, kafka_env, topic_suffix, group, **kwargs):
        super().__init__(importer, **kwargs)
        self.consumer = make_kafka_consumer(
            kafka_hosts,
            kafka_env,
            topic_suffix,
            group,
        )

    def run(self):
        count = 0
        for msg in self.consumer:
            if not msg:
                continue
            record = json.loads(msg.value.decode('utf-8'))
            self.importer.push_record(record)
            count += 1
            # periodic progress report
            if count % 500 == 0:
                print("Import counts: {}".format(self.importer.counts))
        # TODO: should catch UNIX signals (HUP?) to shutdown cleanly, and/or
        # commit the current batch if it has been lingering
        counts = self.importer.finish()
        print(counts)
        return counts


def make_kafka_consumer(hosts, env, topic_suffix, group):
    """
    Returns a pykafka balanced consumer for the topic
    "fatcat-<env>.<topic_suffix>" in consumer group 'group'.
    """
    topic_name = "fatcat-{}.{}".format(env, topic_suffix).encode('utf-8')
    client = pykafka.KafkaClient(hosts=hosts, broker_version="1.0.0")
    consume_topic = client.topics[topic_name]
    print("Consuming from kafka topic {}, group {}".format(topic_name, group))

    consumer = consume_topic.get_balanced_consumer(
        consumer_group=group.encode('utf-8'),
        managed=True,
        auto_commit_enable=True,
        auto_commit_interval_ms=30000,  # 30 seconds
        compacted_topic=True,
    )
    return consumer