aboutsummaryrefslogtreecommitdiffstats
path: root/python/sandcrawler/pdftrio.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/sandcrawler/pdftrio.py')
-rw-r--r--python/sandcrawler/pdftrio.py142
1 files changed, 142 insertions, 0 deletions
diff --git a/python/sandcrawler/pdftrio.py b/python/sandcrawler/pdftrio.py
new file mode 100644
index 0000000..112df6a
--- /dev/null
+++ b/python/sandcrawler/pdftrio.py
@@ -0,0 +1,142 @@
+import time
+from typing import Any, Dict, Optional
+
+import requests
+
+from .ia import WaybackClient
+from .misc import gen_file_metadata, requests_retry_session
+from .workers import SandcrawlerFetchWorker, SandcrawlerWorker
+
+
+class PdfTrioClient(object):
+ def __init__(self, host_url: str = "http://pdftrio.qa.fatcat.wiki", **kwargs):
+ self.host_url = host_url
+ self.http_session = requests_retry_session(retries=3, backoff_factor=3)
+
+ def classify_pdf(self, blob: bytes, mode: str = "auto") -> Dict[str, Any]:
+ """
+ Returns a dict with at least:
+
+ - status_code (int, always set)
+ - status (success, or error-*)
+
+ On success, the other remote API JSON response keys are also included.
+
+ On HTTP-level failures, the status_code and status field are set
+ appropriately; an optional `error_msg` may also be set. For some other
+ errors, like connection failure, an exception is raised.
+ """
+ assert blob and type(blob) == bytes
+
+ try:
+ pdftrio_response = self.http_session.post(
+ self.host_url + "/classify/research-pub/" + mode,
+ files={
+ "pdf_content": blob,
+ },
+ timeout=60.0,
+ )
+ except requests.Timeout:
+ return {
+ "status": "error-timeout",
+ "status_code": -4, # heritrix3 "HTTP timeout" code
+ "error_msg": "pdftrio request (HTTP POST) timeout",
+ }
+ except requests.exceptions.ConnectionError:
+ # crude back-off
+ time.sleep(2.0)
+ return {
+ "status": "error-connect",
+ "status_code": -2, # heritrix3 "HTTP connect" code
+ "error_msg": "pdftrio request connection timeout",
+ }
+
+ info: Dict[str, Any] = dict(status_code=pdftrio_response.status_code)
+ if pdftrio_response.status_code == 200:
+ resp_json = pdftrio_response.json()
+ assert "ensemble_score" in resp_json
+ assert "status" in resp_json
+ assert "versions" in resp_json
+ info.update(resp_json)
+ else:
+ info["status"] = "error"
+ # TODO: might return JSON with some info?
+
+ info["_total_sec"] = pdftrio_response.elapsed.total_seconds()
+ return info
+
+
+class PdfTrioWorker(SandcrawlerFetchWorker):
+ """
+ This class is basically copied directly from GrobidWorker
+ """
+
+ def __init__(
+ self,
+ pdftrio_client: PdfTrioClient,
+ wayback_client: Optional[WaybackClient] = None,
+ sink: Optional[SandcrawlerWorker] = None,
+ **kwargs
+ ):
+ super().__init__(wayback_client=wayback_client, **kwargs)
+ self.pdftrio_client = pdftrio_client
+ self.sink = sink
+
+ def process(self, record: Any, key: Optional[str] = None) -> Any:
+ start_process = time.time()
+ fetch_sec = None
+
+ start = time.time()
+ fetch_result = self.fetch_blob(record)
+ fetch_sec = time.time() - start
+ if fetch_result["status"] != "success":
+ return fetch_result
+ blob: bytes = fetch_result["blob"]
+ assert blob and isinstance(blob, bytes)
+
+ result = dict()
+ result["file_meta"] = gen_file_metadata(blob)
+ result["key"] = result["file_meta"]["sha1hex"]
+ result["pdf_trio"] = self.pdftrio_client.classify_pdf(blob)
+ result["source"] = record
+ result["timing"] = dict(
+ pdftrio_sec=result["pdf_trio"].pop("_total_sec", None),
+ total_sec=time.time() - start_process,
+ )
+ if fetch_sec:
+ result["timing"]["fetch_sec"] = fetch_sec
+ return result
+
+
+class PdfTrioBlobWorker(SandcrawlerWorker):
+ """
+ This is sort of like PdfTrioWorker, except it receives blobs directly,
+ instead of fetching blobs from some remote store.
+ """
+
+ def __init__(
+ self,
+ pdftrio_client: PdfTrioClient,
+ sink: Optional[SandcrawlerWorker] = None,
+ mode: str = "auto",
+ **kwargs
+ ):
+ super().__init__(**kwargs)
+ self.pdftrio_client = pdftrio_client
+ self.sink = sink
+ self.mode = mode
+
+ def process(self, blob: Any, key: Optional[str] = None) -> Any:
+ start_process = time.time()
+ if not blob:
+ return None
+ assert isinstance(blob, bytes)
+ result = dict()
+ result["file_meta"] = gen_file_metadata(blob)
+ result["key"] = result["file_meta"]["sha1hex"]
+ result["pdf_trio"] = self.pdftrio_client.classify_pdf(blob, mode=self.mode)
+ result["timing"] = dict(
+ pdftrio_sec=result["pdf_trio"].pop("_total_sec", None),
+ total_sec=time.time() - start_process,
+ )
+ return result