summary refs log tree commit diff stats
path: root/python/fatcat_tools/importers/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/fatcat_tools/importers/common.py')
-rw-r--r--  python/fatcat_tools/importers/common.py | 264
1 files changed, 90 insertions, 174 deletions
diff --git a/python/fatcat_tools/importers/common.py b/python/fatcat_tools/importers/common.py
index 604aa78b..e7fe2305 100644
--- a/python/fatcat_tools/importers/common.py
+++ b/python/fatcat_tools/importers/common.py
@@ -49,9 +49,11 @@ class EntityImporter:
self.api = api
self.bezerk_mode = kwargs.get('bezerk_mode', False)
- self._editgroup_description = kwargs.get('editgroup_description')
- self._editgroup_extra = kwargs.get('editgroup_extra')
- self.counts = Counter({'skip': 0, 'insert': 0, 'update': 0, 'exists': 0, 'sub.container': 0})
+ self.serial_mode = kwargs.get('serial_mode', False)
+ self.edit_batch_size = kwargs.get('edit_batch_size', 100)
+ self.editgroup_description = kwargs.get('editgroup_description')
+ self.editgroup_extra = kwargs.get('editgroup_extra')
+ self.counts = Counter({'skip': 0, 'insert': 0, 'update': 0, 'exists': 0})
self._edit_count = 0
self._editgroup_id = None
self._entity_queue = []
@@ -69,7 +71,9 @@ class EntityImporter:
self.counts['skip'] += 1
return
entity = self.parse_record(raw_record)
- assert entity
+ if not entity:
+ self.counts['skip'] += 1
+ return
if self.bezerk_mode:
self.push_entity(entity)
return
@@ -85,8 +89,8 @@ class EntityImporter:
if self._entity_queue:
self.insert_batch(self._entity_queue)
- self.counts['insert'] += len(_entity_queue)
- self._entity_queue = 0
+ self.counts['insert'] += len(self._entity_queue)
+ self._entity_queue = []
self.counts['total'] = 0
for key in ('skip', 'insert', 'update', 'exists'):
@@ -101,17 +105,28 @@ class EntityImporter:
if not self._editgroup_id:
eg = self.api.create_editgroup(
- description=self._editgroup_description,
- extra=self._editgroup_extra)
+ fatcat_client.Editgroup(
+ description=self.editgroup_description,
+ extra=self.editgroup_extra))
self._editgroup_id = eg.editgroup_id
self._edit_count += edits
return self._editgroup_id
def create_container(self, entity):
- eg = self._get_editgroup()
- self.counts['sub.container'] += 1
- return self.api.create_container(entity, editgroup_id=eg.editgroup_id)
+ eg_id = self._get_editgroup()
+ self.counts['inserted.container'] += 1
+ return self.api.create_container(entity, editgroup_id=eg_id)
+
+ def create_release(self, entity):
+ eg_id = self._get_editgroup()
+ self.counts['inserted.release'] += 1
+ return self.api.create_release(entity, editgroup_id=eg_id)
+
+ def create_file(self, entity):
+ eg_id = self._get_editgroup()
+ self.counts['inserted.file'] += 1
+ return self.api.create_file(entity, editgroup_id=eg_id)
def updated(self):
"""
@@ -135,10 +150,10 @@ class EntityImporter:
def parse(self, raw_record):
"""
- Returns an entity class type. expected to always succeed, or raise an
- exception (`want()` should be used to filter out bad records). May have
- side-effects (eg, create related entities), but shouldn't update/mutate
- the actual entity.
+ Returns an entity class type, or None if we should skip this one.
+
+ May have side-effects (eg, create related entities), but shouldn't
+ update/mutate the actual entity.
"""
raise NotImplementedError
@@ -148,9 +163,10 @@ class EntityImporter:
update it (PUT), decide we should do nothing (based on the existing
record), or create a new one.
- Implementations should update exists/updated counts appropriately.
+ Implementations must update the exists/updated/skip counts
+ appropriately in this method.
- Returns boolean: True if
+ Returns boolean: True if the entity should still be inserted, False otherwise
"""
raise NotImplementedError
@@ -230,7 +246,6 @@ class EntityImporter:
return self._issn_issnl_map.get(issn)
-
class RecordPusher:
"""
Base class for different importer sources. Pretty trivial interface, just
@@ -252,14 +267,14 @@ class RecordPusher:
raise NotImplementedError
-class JsonLinePusher:
+class JsonLinePusher(RecordPusher):
- def __init__(self, importer, in_file, **kwargs):
+ def __init__(self, importer, json_file, **kwargs):
self.importer = importer
- self.in_file = in_file
+ self.json_file = json_file
def run(self):
- for line in self.in_file:
+ for line in self.json_file:
if not line:
continue
record = json.loads(line)
@@ -267,11 +282,59 @@ class JsonLinePusher:
print(self.importer.finish())
-# from: https://docs.python.org/3/library/itertools.html
-def grouper(iterable, n, fillvalue=None):
- "Collect data into fixed-length chunks or blocks"
- args = [iter(iterable)] * n
- return itertools.zip_longest(*args, fillvalue=fillvalue)
+class CsvPusher(RecordPusher):
+
+ def __init__(self, importer, csv_file, **kwargs):
+ self.importer = importer
+ self.reader = csv.DictReader(csv_file, delimiter=kwargs.get('delimiter', ','))
+
+ def run(self):
+ for line in self.reader:
+ if not line:
+ continue
+ self.importer.push_record(line)
+ print(self.importer.finish())
+
+
+class LinePusher(RecordPusher):
+
+ def __init__(self, importer, text_file, **kwargs):
+ self.importer = importer
+ self.text_file = text_file
+
+ def run(self):
+ for line in self.text_file:
+ if not line:
+ continue
+ self.importer.push_record(line)
+ print(self.importer.finish())
+
+
+class KafkaJsonPusher(RecordPusher):
+
+ def __init__(self, importer, kafka_hosts, kafka_env, topic_suffix, group, **kwargs):
+ self.importer = importer
+ self.consumer = make_kafka_consumer(
+ kafka_hosts,
+ kafka_env,
+ topic_suffix,
+ group,
+ )
+
+ def run(self):
+ count = 0
+ for msg in self.consumer:
+ if not msg:
+ continue
+ record = json.loads(msg.value.decode('utf-8'))
+ self.importer.push_record(record)
+ count += 1
+ if count % 500 == 0:
+ print("Import counts: {}".format(self.importer.counts))
+ # TODO: should catch UNIX signals (HUP?) to shutdown cleanly, and/or
+ # commit the current batch if it has been lingering
+ print(self.importer.finish())
+
def make_kafka_consumer(hosts, env, topic_suffix, group):
topic_name = "fatcat-{}.{}".format(env, topic_suffix).encode('utf-8')
@@ -288,150 +351,3 @@ def make_kafka_consumer(hosts, env, topic_suffix, group):
)
return consumer
-class FatcatImporter:
- """
- Base class for fatcat importers
- """
-
- def __init__(self, api, **kwargs):
-
- eg_extra = kwargs.get('editgroup_extra', dict())
- eg_extra['git_rev'] = eg_extra.get('git_rev',
- subprocess.check_output(["git", "describe", "--always"]).strip()).decode('utf-8')
- eg_extra['agent'] = eg_extra.get('agent', 'fatcat_tools.FatcatImporter')
-
- self.api = api
- self._editgroup_description = kwargs.get('editgroup_description')
- self._editgroup_extra = kwargs.get('editgroup_extra')
- issn_map_file = kwargs.get('issn_map_file')
-
- self._issnl_id_map = dict()
- self._orcid_id_map = dict()
- self._doi_id_map = dict()
- if issn_map_file:
- self.read_issn_map_file(issn_map_file)
- self._orcid_regex = re.compile("^\\d{4}-\\d{4}-\\d{4}-\\d{3}[\\dX]$")
- self.counts = Counter({'insert': 0, 'update': 0, 'processed_lines': 0})
-
- def _editgroup(self):
- eg = fatcat_client.Editgroup(
- description=self._editgroup_description,
- extra=self._editgroup_extra,
- )
- return self.api.create_editgroup(eg)
-
- def describe_run(self):
- print("Processed {} lines, inserted {}, updated {}.".format(
- self.counts['processed_lines'], self.counts['insert'], self.counts['update']))
-
- def create_row(self, row, editgroup_id=None):
- # sub-classes expected to implement this
- raise NotImplementedError
-
- def create_batch(self, rows, editgroup_id=None):
- # sub-classes expected to implement this
- raise NotImplementedError
-
- def process_source(self, source, group_size=100):
- """Creates and auto-accepts editgroup every group_size rows"""
- eg = self._editgroup()
- i = 0
- for i, row in enumerate(source):
- self.create_row(row, editgroup_id=eg.editgroup_id)
- if i > 0 and (i % group_size) == 0:
- self.api.accept_editgroup(eg.editgroup_id)
- eg = self._editgroup()
- self.counts['processed_lines'] += 1
- if i == 0 or (i % group_size) != 0:
- self.api.accept_editgroup(eg.editgroup_id)
-
- def process_batch(self, source, size=50, decode_kafka=False):
- """Reads and processes in batches (not API-call-per-)"""
- for rows in grouper(source, size):
- if decode_kafka:
- rows = [msg.value.decode('utf-8') for msg in rows]
- self.counts['processed_lines'] += len(rows)
- #eg = self._editgroup()
- #self.create_batch(rows, editgroup_id=eg.editgroup_id)
- self.create_batch(rows)
-
- def process_csv_source(self, source, group_size=100, delimiter=','):
- reader = csv.DictReader(source, delimiter=delimiter)
- self.process_source(reader, group_size)
-
- def process_csv_batch(self, source, size=50, delimiter=','):
- reader = csv.DictReader(source, delimiter=delimiter)
- self.process_batch(reader, size)
-
- def is_issnl(self, issnl):
- return len(issnl) == 9 and issnl[4] == '-'
-
- def lookup_issnl(self, issnl):
- """Caches calls to the ISSN-L lookup API endpoint in a local dict"""
- if issnl in self._issnl_id_map:
- return self._issnl_id_map[issnl]
- container_id = None
- try:
- rv = self.api.lookup_container(issnl=issnl)
- container_id = rv.ident
- except ApiException as ae:
- # If anything other than a 404 (not found), something is wrong
- assert ae.status == 404
- self._issnl_id_map[issnl] = container_id # might be None
- return container_id
-
- def is_orcid(self, orcid):
- return self._orcid_regex.match(orcid) is not None
-
- def lookup_orcid(self, orcid):
- """Caches calls to the Orcid lookup API endpoint in a local dict"""
- if not self.is_orcid(orcid):
- return None
- if orcid in self._orcid_id_map:
- return self._orcid_id_map[orcid]
- creator_id = None
- try:
- rv = self.api.lookup_creator(orcid=orcid)
- creator_id = rv.ident
- except ApiException as ae:
- # If anything other than a 404 (not found), something is wrong
- assert ae.status == 404
- self._orcid_id_map[orcid] = creator_id # might be None
- return creator_id
-
- def is_doi(self, doi):
- return doi.startswith("10.") and doi.count("/") >= 1
-
- def lookup_doi(self, doi):
- """Caches calls to the doi lookup API endpoint in a local dict"""
- assert self.is_doi(doi)
- doi = doi.lower()
- if doi in self._doi_id_map:
- return self._doi_id_map[doi]
- release_id = None
- try:
- rv = self.api.lookup_release(doi=doi)
- release_id = rv.ident
- except ApiException as ae:
- # If anything other than a 404 (not found), something is wrong
- assert ae.status == 404
- self._doi_id_map[doi] = release_id # might be None
- return release_id
-
- def read_issn_map_file(self, issn_map_file):
- print("Loading ISSN map file...")
- self._issn_issnl_map = dict()
- for line in issn_map_file:
- if line.startswith("ISSN") or len(line) == 0:
- continue
- (issn, issnl) = line.split()[0:2]
- self._issn_issnl_map[issn] = issnl
- # double mapping makes lookups easy
- self._issn_issnl_map[issnl] = issnl
- print("Got {} ISSN-L mappings.".format(len(self._issn_issnl_map)))
-
- def issn2issnl(self, issn):
- if issn is None:
- return None
- return self._issn_issnl_map.get(issn)
-