summaryrefslogtreecommitdiffstats
path: root/fatcat_scholar/sandcrawler.py
blob: 207f240a63dbd2bf2fa256c580cdf0cb731666f3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from typing import Any, Dict, List, Optional

import minio
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry  # pylint: disable=import-error


def requests_retry_session(
    retries: int = 2,
    backoff_factor: float = 3,
    status_forcelist: Optional[List[int]] = None,
) -> requests.Session:
    """
    Build a ``requests.Session`` that automatically retries failed requests.

    From: https://www.peterbe.com/plog/best-practice-with-retries-with-requests

    retries: maximum retry count, applied to the total, read and connect
        counters of the urllib3 ``Retry`` policy alike.
    backoff_factor: passed through to ``Retry`` to space out retry attempts.
    status_forcelist: HTTP status codes that should trigger a retry; when
        None (the default) this is [500, 502, 504].

    Returns a new session with the retrying adapter mounted for both
    ``http://`` and ``https://`` URLs.
    """
    if status_forcelist is None:
        # Build the default per call: a list literal in the signature would be
        # a single mutable object shared across every call of this function.
        status_forcelist = [500, 502, 504]
    session = requests.Session()
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount("http://", adapter)
    session.mount("https://", adapter)
    return session


class SandcrawlerPostgrestClient:
    """
    Read-only client for the sandcrawler PostgREST API.

    Each ``get_*`` method looks up rows by an equality filter on one column
    and returns the first matching row as a dict, or None when nothing
    matched. Non-2xx responses are raised via ``resp.raise_for_status()``.
    """

    def __init__(self, api_url: str, timeout: Optional[float] = 30.0):
        """
        api_url: base URL of the PostgREST endpoint; table names are appended
            as "/<table>".
        timeout: seconds passed as ``timeout=`` to every request, so that a
            stalled server cannot block the caller forever. Pass None to
            restore the previous (unbounded) behaviour.
        """
        self.api_url = api_url
        self.timeout = timeout
        self.session = requests_retry_session()

    @staticmethod
    def _first_or_none(rows: Optional[List[Dict[str, Any]]]) -> Optional[Dict[str, Any]]:
        """Return the first row of a decoded result list, or None if it is empty."""
        if rows:
            return rows[0]
        return None

    def _get_one(self, table: str, column: str, value: str) -> Optional[Dict[str, Any]]:
        """Fetch the first row of ``table`` whose ``column`` equals ``value``."""
        resp = self.session.get(
            self.api_url + "/" + table,
            # PostgREST filter syntax: "<column>=eq.<value>"
            params={column: "eq." + value},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return self._first_or_none(resp.json())

    def get_grobid(self, sha1: str) -> Optional[Dict[str, Any]]:
        """Return the ``grobid`` row for the given sha1hex, or None."""
        return self._get_one("grobid", "sha1hex", sha1)

    def get_pdf_meta(self, sha1: str) -> Optional[Dict[str, Any]]:
        """Return the ``pdf_meta`` row for the given sha1hex, or None."""
        return self._get_one("pdf_meta", "sha1hex", sha1)

    def get_html_meta(self, sha1: str) -> Optional[Dict[str, Any]]:
        """Return the ``html_meta`` row for the given sha1hex, or None."""
        return self._get_one("html_meta", "sha1hex", sha1)

    def get_crossref_with_refs(self, doi: str) -> Optional[Dict[str, Any]]:
        """Return the ``crossref_with_refs`` row for the given DOI, or None."""
        return self._get_one("crossref_with_refs", "doi", doi)


class SandcrawlerMinioClient:
    def __init__(
        self,
        host_url: str,
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
        default_bucket: Optional[str] = "sandcrawler",
        secure: bool = False,
    ):
        """
        host is minio connection string (host:port)
        access and secret key are as expected
        default_bucket can be supplied so that it doesn't need to be repeated for each function call
        secure selects TLS for the minio connection (default False, i.e. plain HTTP)

        Example config:

            host="localhost:9000",
            access_key=os.environ['MINIO_ACCESS_KEY'],
            secret_key=os.environ['MINIO_SECRET_KEY'],
        """
        self.mc = minio.Minio(
            host_url,
            access_key=access_key,
            secret_key=secret_key,
            secure=secure,
        )
        self.default_bucket = default_bucket

    def _blob_path(self, folder: str, sha1hex: str, extension: str, prefix: str) -> str:
        """
        Build the object path for a blob using the sandcrawler SHA-1 layout:

            {prefix}{folder}/{sha1hex[0:2]}/{sha1hex[2:4]}/{sha1hex}{extension}

        A falsy extension or prefix (e.g. None or "") is treated as empty.
        """
        extension = extension or ""
        prefix = prefix or ""
        assert len(sha1hex) == 40
        return "{}{}/{}/{}/{}{}".format(
            prefix,
            folder,
            sha1hex[0:2],
            sha1hex[2:4],
            sha1hex,
            extension,
        )

    def get_blob(
        self,
        folder: str,
        sha1hex: str,
        extension: str = "",
        prefix: str = "",
        bucket: Optional[str] = None,
    ) -> bytes:
        """
        sha1hex is sha1 of the blob itself

        Fetches the blob from the given bucket/folder (falling back to
        default_bucket), using the sandcrawler SHA1 path convention, and
        returns its full contents as bytes.
        """
        obj_path = self._blob_path(folder, sha1hex, extension, prefix)
        if not bucket:
            bucket = self.default_bucket
        assert bucket
        resp = self.mc.get_object(
            bucket,
            obj_path,
        )
        try:
            data = resp.data
        finally:
            # get_object returns a streaming urllib3 response; per the minio
            # docs it must be closed and its connection released back to the
            # pool, otherwise pooled connections leak over many calls.
            resp.close()
            resp.release_conn()
        # TODO: optionally verify SHA-1?
        return data