aboutsummaryrefslogtreecommitdiffstats
path: root/python/fatcat_tools/importers/common.py
blob: e39ec6c9a79af36a02882de422ccd89218fcf711 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180

import re
import sys
import csv
import json
import itertools
import subprocess
from collections import Counter
import pykafka

import fatcat_client
from fatcat_client.rest import ApiException


# Recipe adapted from the itertools documentation:
# https://docs.python.org/3/library/itertools.html
def grouper(iterable, n, fillvalue=None):
    """Split *iterable* into consecutive tuples of length *n*.

    The last tuple is padded with *fillvalue* when the input length is not a
    multiple of *n*. The result is a lazy iterator.
    """
    source = iter(iterable)
    # Passing the very same iterator n times makes zip_longest pull n
    # consecutive items into each output tuple.
    same_iter_n_times = [source for _ in range(n)]
    return itertools.zip_longest(*same_iter_n_times, fillvalue=fillvalue)

def make_kafka_consumer(hosts, env, topic_suffix, group):
    """Create a pykafka balanced consumer for a fatcat topic.

    The topic name is ``fatcat-<env>.<topic_suffix>`` (UTF-8 encoded) and the
    consumer joins consumer group *group* with auto-commit enabled at a
    30 second interval. Returns the consumer object from pykafka.
    """
    topic_name = "fatcat-{}.{}".format(env, topic_suffix).encode('utf-8')
    kafka_client = pykafka.KafkaClient(hosts=hosts, broker_version="1.0.0")
    topic = kafka_client.topics[topic_name]
    # NOTE(review): topic_name is bytes here, so the message shows b'...'.
    print("Consuming from kafka topic {}, group {}".format(topic_name, group))
    return topic.get_balanced_consumer(
        consumer_group=group.encode('utf-8'),
        managed=True,
        auto_commit_enable=True,
        auto_commit_interval_ms=30 * 1000,  # every 30 seconds
        compacted_topic=True,
    )

class FatcatImporter:
    """
    Base class for fatcat importers.

    Sub-classes implement ``create_row`` (called once per input row) and/or
    ``create_batch`` (called once per chunk of rows). This base class provides
    the source-iteration loops, editgroup creation/acceptance, cached
    identifier lookups (ISSN-L, ORCID, DOI) and an ISSN -> ISSN-L mapping.

    Keyword arguments accepted by ``__init__``:
      editgroup_description: description attached to every new editgroup.
      editgroup_extra: dict of extra editgroup metadata. It is copied (the
          caller's dict is not modified) and 'git_rev' / 'agent' are filled in
          when absent.
      issn_map_file: iterable of text lines (e.g. an open file) mapping ISSN
          to ISSN-L; see ``read_issn_map_file``.
    """

    def __init__(self, api, **kwargs):

        # Copy so the caller's dict is not mutated; tolerate an explicit None.
        eg_extra = dict(kwargs.get('editgroup_extra') or {})
        if 'git_rev' not in eg_extra:
            # Only shell out to git when the caller did not supply a revision.
            eg_extra['git_rev'] = subprocess.check_output(
                ["git", "describe", "--always"]).strip().decode('utf-8')
        elif isinstance(eg_extra['git_rev'], bytes):
            eg_extra['git_rev'] = eg_extra['git_rev'].decode('utf-8')
        eg_extra.setdefault('agent', 'fatcat_tools.FatcatImporter')

        self.api = api
        self._editgroup_description = kwargs.get('editgroup_description')
        # Store the completed dict (not the raw kwarg) so the git_rev/agent
        # defaults always reach the editgroups created by _editgroup().
        self._editgroup_extra = eg_extra

        self._issnl_id_map = dict()
        self._orcid_id_map = dict()
        self._doi_id_map = dict()
        # Always defined so issn2issnl() works (returning None) without a map.
        self._issn_issnl_map = dict()
        issn_map_file = kwargs.get('issn_map_file')
        if issn_map_file:
            self.read_issn_map_file(issn_map_file)
        self._orcid_regex = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$")
        self.counts = Counter({'insert': 0, 'update': 0, 'processed_lines': 0})

    def _editgroup(self):
        """Create a new editgroup via the API, using the configured
        description and extra metadata, and return the API's response."""
        eg = fatcat_client.Editgroup(
            description=self._editgroup_description,
            extra=self._editgroup_extra,
        )
        return self.api.create_editgroup(eg)

    def describe_run(self):
        """Print a one-line summary of the processed/insert/update counters."""
        print("Processed {} lines, inserted {}, updated {}.".format(
            self.counts['processed_lines'], self.counts['insert'], self.counts['update']))

    def create_row(self, row, editgroup_id=None):
        """Import a single row into *editgroup_id*; sub-classes implement."""
        raise NotImplementedError

    def create_batch(self, rows, editgroup_id=None):
        """Import a list of rows into *editgroup_id*; sub-classes implement."""
        raise NotImplementedError

    def process_source(self, source, group_size=100):
        """Import *source* row by row, accepting an editgroup every
        *group_size* rows.

        Each editgroup holds at most *group_size* rows; the final, possibly
        partial, editgroup is accepted at the end. Editgroups are created
        lazily, so an empty source creates (and accepts) none.
        """
        eg = None
        pending = 0  # rows written into the current editgroup
        for row in source:
            if eg is None:
                eg = self._editgroup()
            self.create_row(row, editgroup_id=eg.editgroup_id)
            self.counts['processed_lines'] += 1
            pending += 1
            if pending >= group_size:
                self.api.accept_editgroup(eg.editgroup_id)
                eg = None
                pending = 0
        if eg is not None:
            self.api.accept_editgroup(eg.editgroup_id)

    def process_batch(self, source, size=50, decode_kafka=False):
        """Import *source* in chunks of up to *size* rows via create_batch()
        (one call per chunk rather than one API call per row).

        A new editgroup is created for every chunk; this method does not
        accept editgroups itself. With *decode_kafka*, each item is treated
        as a Kafka message and replaced by its ``value`` decoded as UTF-8.
        """
        it = iter(source)
        while True:
            # islice yields a short final chunk rather than padding it with
            # None, so no placeholder rows reach create_batch or the counters.
            rows = list(itertools.islice(it, size))
            if not rows:
                break
            if decode_kafka:
                rows = [msg.value.decode('utf-8') for msg in rows]
            self.counts['processed_lines'] += len(rows)
            eg = self._editgroup()
            self.create_batch(rows, editgroup_id=eg.editgroup_id)

    def process_csv_source(self, source, group_size=100, delimiter=','):
        """Parse *source* as CSV with a header row and run process_source()."""
        reader = csv.DictReader(source, delimiter=delimiter)
        self.process_source(reader, group_size)

    def process_csv_batch(self, source, size=50, delimiter=','):
        """Parse *source* as CSV with a header row and run process_batch()."""
        reader = csv.DictReader(source, delimiter=delimiter)
        self.process_batch(reader, size)

    def is_issnl(self, issnl):
        """Cheap shape check: 9 characters with a hyphen at index 4."""
        return len(issnl) == 9 and issnl[4] == '-'

    def _cached_lookup(self, cache, key, lookup, **query):
        """Return ``lookup(**query).ident``, memoized in *cache* under *key*.

        A 404 from the API is cached as None (entity not found). Any other
        ApiException is re-raised instead of being cached; an ``assert`` would
        be stripped under ``python -O`` and silently cache the failure.
        """
        if key in cache:
            return cache[key]
        ident = None
        try:
            ident = lookup(**query).ident
        except ApiException as ae:
            if ae.status != 404:
                raise
        cache[key] = ident  # might be None
        return ident

    def lookup_issnl(self, issnl):
        """Caches calls to the ISSN-L lookup API endpoint in a local dict.

        Returns the container ident, or None if the API reports 404.
        """
        return self._cached_lookup(
            self._issnl_id_map, issnl, self.api.lookup_container, issnl=issnl)

    def is_orcid(self, orcid):
        """True if *orcid* matches the 0000-0000-0000-000X ORCID format."""
        return self._orcid_regex.match(orcid) is not None

    def lookup_orcid(self, orcid):
        """Caches calls to the Orcid lookup API endpoint in a local dict.

        Returns None without an API call if *orcid* is not well-formed.
        """
        if not self.is_orcid(orcid):
            return None
        return self._cached_lookup(
            self._orcid_id_map, orcid, self.api.lookup_creator, orcid=orcid)

    def is_doi(self, doi):
        """Cheap shape check: starts with "10." and contains a slash."""
        return doi.startswith("10.") and doi.count("/") >= 1

    def lookup_doi(self, doi):
        """Caches calls to the doi lookup API endpoint in a local dict.

        The DOI is lower-cased before lookup and caching. Raises
        AssertionError (when asserts are enabled) if *doi* fails is_doi().
        """
        assert self.is_doi(doi)
        doi = doi.lower()
        return self._cached_lookup(
            self._doi_id_map, doi, self.api.lookup_release, doi=doi)

    def read_issn_map_file(self, issn_map_file):
        """Load an ISSN -> ISSN-L mapping from an iterable of text lines.

        Each data line holds whitespace-separated ``ISSN ISSN-L`` columns
        (further columns are ignored). Lines starting with "ISSN" (headers)
        and blank or single-column lines are skipped. Every ISSN-L is also
        mapped to itself. Any previously loaded mapping is replaced.
        """
        print("Loading ISSN map file...")
        self._issn_issnl_map = dict()
        for line in issn_map_file:
            fields = line.split()
            # File lines keep their "\n", so test the split fields rather
            # than len(line) to detect blank lines.
            if line.startswith("ISSN") or len(fields) < 2:
                continue
            issn, issnl = fields[0], fields[1]
            self._issn_issnl_map[issn] = issnl
            # double mapping makes lookups easy
            self._issn_issnl_map[issnl] = issnl
        print("Got {} ISSN-L mappings.".format(len(self._issn_issnl_map)))

    def issn2issnl(self, issn):
        """Map an ISSN (or ISSN-L) to its ISSN-L; None if unknown or None."""
        if issn is None:
            return None
        return self._issn_issnl_map.get(issn)