From 8cf887240bff6c0ccbffa9f9003f90bfa2c94b4f Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Tue, 26 Oct 2021 17:29:18 -0700 Subject: grobid: type annotations --- python/sandcrawler/grobid.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) (limited to 'python/sandcrawler') diff --git a/python/sandcrawler/grobid.py b/python/sandcrawler/grobid.py index d0b7f7e..a44861a 100644 --- a/python/sandcrawler/grobid.py +++ b/python/sandcrawler/grobid.py @@ -1,17 +1,20 @@ +from typing import Any, Dict, Optional + import requests from grobid2json import teixml2json +from .ia import WaybackClient from .misc import gen_file_metadata from .workers import SandcrawlerFetchWorker, SandcrawlerWorker class GrobidClient(object): - def __init__(self, host_url="http://grobid.qa.fatcat.wiki", **kwargs): + def __init__(self, host_url: str = "http://grobid.qa.fatcat.wiki", **kwargs): self.host_url = host_url self.consolidate_mode = int(kwargs.get('consolidate_mode', 0)) - def process_fulltext(self, blob, consolidate_mode=None): + def process_fulltext(self, blob: bytes, consolidate_mode: Optional[int] = None) -> Dict[str, Any]: """ Returns dict with keys: - status_code @@ -44,7 +47,7 @@ class GrobidClient(object): 'error_msg': 'GROBID request (HTTP POST) timeout', } - info = dict(status_code=grobid_response.status_code, ) + info: Dict[str, Any] = dict(status_code=grobid_response.status_code) if grobid_response.status_code == 200: info['status'] = 'success' info['tei_xml'] = grobid_response.text @@ -61,7 +64,7 @@ class GrobidClient(object): info['error_msg'] = grobid_response.text[:10000] return info - def metadata(self, result): + def metadata(self, result: Dict[str, Any]) -> Optional[Dict[str, Any]]: if result['status'] != 'success': return None tei_json = teixml2json(result['tei_xml'], encumbered=False) @@ -84,13 +87,17 @@ class GrobidClient(object): class GrobidWorker(SandcrawlerFetchWorker): - def __init__(self, grobid_client, wayback_client=None, sink=None, **kwargs): + def __init__(self, + grobid_client: GrobidClient, + wayback_client: Optional[WaybackClient] = None, + sink: Optional[SandcrawlerWorker] = None, + **kwargs): super().__init__(wayback_client=wayback_client) self.grobid_client = grobid_client self.sink = sink self.consolidate_mode = 0 - def timeout_response(self, task): + def timeout_response(self, task: Any) -> Any: default_key = task['sha1hex'] return dict( status="error-timeout", @@ -99,7 +106,7 @@ class GrobidWorker(SandcrawlerFetchWorker): key=default_key, ) - def process(self, record, key=None): + def process(self, record: Any, key: Optional[str] = None) -> Any: fetch_result = self.fetch_blob(record) if fetch_result['status'] != 'success': return fetch_result @@ -118,13 +125,16 @@ class GrobidBlobWorker(SandcrawlerWorker): This is sort of like GrobidWorker, except it receives blobs directly, instead of fetching blobs from some remote store. """ - def __init__(self, grobid_client, sink=None, **kwargs): + def __init__(self, + grobid_client: GrobidClient, + sink: Optional[SandcrawlerWorker] = None, + **kwargs): super().__init__() self.grobid_client = grobid_client self.sink = sink self.consolidate_mode = 0 - def process(self, blob, key=None): + def process(self, blob: Any, key: Optional[str] = None) -> Any: if not blob: return None result = self.grobid_client.process_fulltext(blob, -- cgit v1.2.3