summaryrefslogtreecommitdiffstats
path: root/python/fatcat_tools/importers/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/fatcat_tools/importers/common.py')
-rw-r--r--python/fatcat_tools/importers/common.py390
1 files changed, 303 insertions, 87 deletions
diff --git a/python/fatcat_tools/importers/common.py b/python/fatcat_tools/importers/common.py
index 06897bee..89203a4f 100644
--- a/python/fatcat_tools/importers/common.py
+++ b/python/fatcat_tools/importers/common.py
@@ -3,6 +3,7 @@ import re
import sys
import csv
import json
+import ftfy
import itertools
import subprocess
from collections import Counter
@@ -12,30 +13,66 @@ import fatcat_client
from fatcat_client.rest import ApiException
-# from: https://docs.python.org/3/library/itertools.html
-def grouper(iterable, n, fillvalue=None):
- "Collect data into fixed-length chunks or blocks"
- args = [iter(iterable)] * n
- return itertools.zip_longest(*args, fillvalue=fillvalue)
+def clean(thing, force_xml=False):
+ """
+ This function is appropriate to be called on any random, non-markup string,
+ such as author names, titles, etc.
-def make_kafka_consumer(hosts, env, topic_suffix, group):
- topic_name = "fatcat-{}.{}".format(env, topic_suffix).encode('utf-8')
- client = pykafka.KafkaClient(hosts=hosts, broker_version="1.0.0")
- consume_topic = client.topics[topic_name]
- print("Consuming from kafka topic {}, group {}".format(topic_name, group))
+ It will try to clean up commong unicode mangles, HTML characters, etc.
- consumer = consume_topic.get_balanced_consumer(
- consumer_group=group.encode('utf-8'),
- managed=True,
- auto_commit_enable=True,
- auto_commit_interval_ms=30000, # 30 seconds
- compacted_topic=True,
- )
- return consumer
+ This will detect XML/HTML and "do the right thing" (aka, not remove
+ entities like '&amp' if there are tags in the string), unless you pass the
+ 'force_xml' parameter, which might be appropriate for, eg, names and
+ titles, which generally should be projected down to plain text.
+
+ Also strips extra whitespace.
+ """
+ if not thing:
+ return thing
+ fix_entities = 'auto'
+ if force_xml:
+ fix_entities = True
+ fixed = ftfy.fix_text(thing, fix_entities=fix_entities).strip()
+ if not fixed:
+ # wasn't zero-length before, but is now; return None
+ return None
+ return fixed
+
+def test_clean():
-class FatcatImporter:
+ assert clean(None) == None
+ assert clean('') == ''
+ assert clean('123') == '123'
+ assert clean('a&b') == 'a&b'
+ assert clean('<b>a&amp;b</b>') == '<b>a&amp;b</b>'
+ assert clean('<b>a&amp;b</b>', force_xml=True) == '<b>a&b</b>'
+
+class EntityImporter:
"""
- Base class for fatcat importers
+ Base class for fatcat entity importers.
+
+ The API exposed to record iterator is:
+
+ push_record(raw_record)
+ finish()
+
+ The API that implementations are expected to fill in are:
+
+ want(raw_record) -> boolean
+ parse(raw_record) -> entity
+ try_update(entity) -> boolean
+ insert_batch([entity]) -> None
+
+ This class exposes helpers for implementations:
+
+ self.api
+ self.create_<entity>(entity) -> EntityEdit
+ for related entity types
+ self.push_entity(entity)
+ self.counts['exists'] += 1
+ if didn't update or insert because of existing)
+ self.counts['update'] += 1
+ if updated an entity
"""
def __init__(self, api, **kwargs):
@@ -43,87 +80,135 @@ class FatcatImporter:
eg_extra = kwargs.get('editgroup_extra', dict())
eg_extra['git_rev'] = eg_extra.get('git_rev',
subprocess.check_output(["git", "describe", "--always"]).strip()).decode('utf-8')
- eg_extra['agent'] = eg_extra.get('agent', 'fatcat_tools.FatcatImporter')
+ eg_extra['agent'] = eg_extra.get('agent', 'fatcat_tools.EntityImporter')
self.api = api
- self._editgroup_description = kwargs.get('editgroup_description')
- self._editgroup_extra = kwargs.get('editgroup_extra')
- issn_map_file = kwargs.get('issn_map_file')
+ self.bezerk_mode = kwargs.get('bezerk_mode', False)
+ self.edit_batch_size = kwargs.get('edit_batch_size', 100)
+ self.editgroup_description = kwargs.get('editgroup_description')
+ self.editgroup_extra = kwargs.get('editgroup_extra')
+ self.reset()
self._issnl_id_map = dict()
self._orcid_id_map = dict()
- self._doi_id_map = dict()
- if issn_map_file:
- self.read_issn_map_file(issn_map_file)
self._orcid_regex = re.compile("^\\d{4}-\\d{4}-\\d{4}-\\d{3}[\\dX]$")
- self.counts = Counter({'insert': 0, 'update': 0, 'processed_lines': 0})
+ self._doi_id_map = dict()
- def _editgroup(self):
- eg = fatcat_client.Editgroup(
- description=self._editgroup_description,
- extra=self._editgroup_extra,
- )
- return self.api.create_editgroup(eg)
+ def reset(self):
+ self.counts = Counter({'skip': 0, 'insert': 0, 'update': 0, 'exists': 0})
+ self._edit_count = 0
+ self._editgroup_id = None
+ self._entity_queue = []
- def describe_run(self):
- print("Processed {} lines, inserted {}, updated {}.".format(
- self.counts['processed_lines'], self.counts['insert'], self.counts['update']))
+ def push_record(self, raw_record):
+ """
+ Returns nothing.
+ """
+ if (not raw_record) or (not self.want(raw_record)):
+ self.counts['skip'] += 1
+ return
+ entity = self.parse_record(raw_record)
+ if not entity:
+ self.counts['skip'] += 1
+ return
+ if self.bezerk_mode:
+ self.push_entity(entity)
+ return
+ if self.try_update(entity):
+ self.push_entity(entity)
+ return
- def create_row(self, row, editgroup_id=None):
- # sub-classes expected to implement this
- raise NotImplementedError
+ def finish(self):
+ if self._edit_count > 0:
+ self.api.accept_editgroup(self._editgroup_id)
+ self._editgroup_id = None
+ self._edit_count = 0
+
+ if self._entity_queue:
+ self.insert_batch(self._entity_queue)
+ self.counts['insert'] += len(self._entity_queue)
+ self._entity_queue = []
+
+ self.counts['total'] = 0
+ for key in ('skip', 'insert', 'update', 'exists'):
+ self.counts['total'] += self.counts[key]
+ return self.counts
+
+ def _get_editgroup(self, edits=1):
+ if self._edit_count >= self.edit_batch_size:
+ self.api.accept_editgroup(self._editgroup_id)
+ self._editgroup_id = None
+ self._edit_count = 0
- def create_batch(self, rows, editgroup_id=None):
- # sub-classes expected to implement this
+ if not self._editgroup_id:
+ eg = self.api.create_editgroup(
+ fatcat_client.Editgroup(
+ description=self.editgroup_description,
+ extra=self.editgroup_extra))
+ self._editgroup_id = eg.editgroup_id
+
+ self._edit_count += edits
+ return self._editgroup_id
+
+ def create_container(self, entity):
+ eg_id = self._get_editgroup()
+ self.counts['inserted.container'] += 1
+ return self.api.create_container(entity, editgroup_id=eg_id)
+
+ def create_release(self, entity):
+ eg_id = self._get_editgroup()
+ self.counts['inserted.release'] += 1
+ return self.api.create_release(entity, editgroup_id=eg_id)
+
+ def create_file(self, entity):
+ eg_id = self._get_editgroup()
+ self.counts['inserted.file'] += 1
+ return self.api.create_file(entity, editgroup_id=eg_id)
+
+ def updated(self):
+ """
+ Implementations should call this from try_update() if the update was successful
+ """
+ self.counts['update'] += 1
+
+ def push_entity(self, entity):
+ self._entity_queue.append(entity)
+ if len(self._entity_queue) >= self.edit_batch_size:
+ self.insert_batch(self._entity_queue)
+ self.counts['insert'] += len(_entity_queue)
+ self._entity_queue = 0
+
+ def want(self, raw_record):
+ """
+ Implementations can override for optional fast-path to drop a record.
+ Must have no side-effects; returns bool.
+ """
+ return True
+
+ def parse(self, raw_record):
+ """
+ Returns an entity class type, or None if we should skip this one.
+
+ May have side-effects (eg, create related entities), but shouldn't
+ update/mutate the actual entity.
+ """
raise NotImplementedError
- def process_source(self, source, group_size=100):
- """Creates and auto-accepts editgroup every group_size rows"""
- eg = self._editgroup()
- i = 0
- for i, row in enumerate(source):
- self.create_row(row, editgroup_id=eg.editgroup_id)
- if i > 0 and (i % group_size) == 0:
- self.api.accept_editgroup(eg.editgroup_id)
- eg = self._editgroup()
- self.counts['processed_lines'] += 1
- if i == 0 or (i % group_size) != 0:
- self.api.accept_editgroup(eg.editgroup_id)
-
- def process_batch(self, source, size=50, decode_kafka=False):
- """Reads and processes in batches (not API-call-per-)"""
- for rows in grouper(source, size):
- if decode_kafka:
- rows = [msg.value.decode('utf-8') for msg in rows]
- self.counts['processed_lines'] += len(rows)
- #eg = self._editgroup()
- #self.create_batch(rows, editgroup_id=eg.editgroup_id)
- self.create_batch(rows)
-
- def process_csv_source(self, source, group_size=100, delimiter=','):
- reader = csv.DictReader(source, delimiter=delimiter)
- self.process_source(reader, group_size)
-
- def process_csv_batch(self, source, size=50, delimiter=','):
- reader = csv.DictReader(source, delimiter=delimiter)
- self.process_batch(reader, size)
+ def try_update(self, raw_record):
+ """
+ Passed the output of parse(). Should try to find an existing entity and
+ update it (PUT), decide we should do nothing (based on the existing
+ record), or create a new one.
- def is_issnl(self, issnl):
- return len(issnl) == 9 and issnl[4] == '-'
+ Implementations must update the exists/updated/skip counts
+ appropriately in this method.
- def lookup_issnl(self, issnl):
- """Caches calls to the ISSN-L lookup API endpoint in a local dict"""
- if issnl in self._issnl_id_map:
- return self._issnl_id_map[issnl]
- container_id = None
- try:
- rv = self.api.lookup_container(issnl=issnl)
- container_id = rv.ident
- except ApiException as ae:
- # If anything other than a 404 (not found), something is wrong
- assert ae.status == 404
- self._issnl_id_map[issnl] = container_id # might be None
- return container_id
+ Returns boolean: True if the entity should still be inserted, False otherwise
+ """
+ raise NotImplementedError
+
+ def insert_batch(self, raw_record):
+ raise NotImplementedError
def is_orcid(self, orcid):
return self._orcid_regex.match(orcid) is not None
@@ -163,6 +248,23 @@ class FatcatImporter:
self._doi_id_map[doi] = release_id # might be None
return release_id
+ def is_issnl(self, issnl):
+ return len(issnl) == 9 and issnl[4] == '-'
+
+ def lookup_issnl(self, issnl):
+ """Caches calls to the ISSN-L lookup API endpoint in a local dict"""
+ if issnl in self._issnl_id_map:
+ return self._issnl_id_map[issnl]
+ container_id = None
+ try:
+ rv = self.api.lookup_container(issnl=issnl)
+ container_id = rv.ident
+ except ApiException as ae:
+ # If anything other than a 404 (not found), something is wrong
+ assert ae.status == 404
+ self._issnl_id_map[issnl] = container_id # might be None
+ return container_id
+
def read_issn_map_file(self, issn_map_file):
print("Loading ISSN map file...")
self._issn_issnl_map = dict()
@@ -179,3 +281,117 @@ class FatcatImporter:
if issn is None:
return None
return self._issn_issnl_map.get(issn)
+
+
+class RecordPusher:
+ """
+ Base class for different importer sources. Pretty trivial interface, just
+ wraps an importer and pushes records in to it.
+ """
+
+ def __init__(self, importer, **kwargs):
+ self.importer = importer
+
+ def run(self):
+ """
+ This will look something like:
+
+ for line in sys.stdin:
+ record = json.loads(line)
+ self.importer.push_record(record)
+ print(self.importer.finish())
+ """
+ raise NotImplementedError
+
+
+class JsonLinePusher(RecordPusher):
+
+ def __init__(self, importer, json_file, **kwargs):
+ self.importer = importer
+ self.json_file = json_file
+
+ def run(self):
+ for line in self.json_file:
+ if not line:
+ continue
+ record = json.loads(line)
+ self.importer.push_record(record)
+ counts = self.importer.finish()
+ print(counts)
+ return counts
+
+
+class CsvPusher(RecordPusher):
+
+ def __init__(self, importer, csv_file, **kwargs):
+ self.importer = importer
+ self.reader = csv.DictReader(csv_file, delimiter=kwargs.get('delimiter', ','))
+
+ def run(self):
+ for line in self.reader:
+ if not line:
+ continue
+ self.importer.push_record(line)
+ counts = self.importer.finish()
+ print(counts)
+ return counts
+
+
+class LinePusher(RecordPusher):
+
+ def __init__(self, importer, text_file, **kwargs):
+ self.importer = importer
+ self.text_file = text_file
+
+ def run(self):
+ for line in self.text_file:
+ if not line:
+ continue
+ self.importer.push_record(line)
+ counts = self.importer.finish()
+ print(counts)
+ return counts
+
+
+class KafkaJsonPusher(RecordPusher):
+
+ def __init__(self, importer, kafka_hosts, kafka_env, topic_suffix, group, **kwargs):
+ self.importer = importer
+ self.consumer = make_kafka_consumer(
+ kafka_hosts,
+ kafka_env,
+ topic_suffix,
+ group,
+ )
+
+ def run(self):
+ count = 0
+ for msg in self.consumer:
+ if not msg:
+ continue
+ record = json.loads(msg.value.decode('utf-8'))
+ self.importer.push_record(record)
+ count += 1
+ if count % 500 == 0:
+ print("Import counts: {}".format(self.importer.counts))
+ # TODO: should catch UNIX signals (HUP?) to shutdown cleanly, and/or
+ # commit the current batch if it has been lingering
+ counts = self.importer.finish()
+ print(counts)
+ return counts
+
+
+def make_kafka_consumer(hosts, env, topic_suffix, group):
+ topic_name = "fatcat-{}.{}".format(env, topic_suffix).encode('utf-8')
+ client = pykafka.KafkaClient(hosts=hosts, broker_version="1.0.0")
+ consume_topic = client.topics[topic_name]
+ print("Consuming from kafka topic {}, group {}".format(topic_name, group))
+
+ consumer = consume_topic.get_balanced_consumer(
+ consumer_group=group.encode('utf-8'),
+ managed=True,
+ auto_commit_enable=True,
+ auto_commit_interval_ms=30000, # 30 seconds
+ compacted_topic=True,
+ )
+ return consumer