From 1088205653eb22f285dd34f2e058216ecfd2abed Mon Sep 17 00:00:00 2001
From: Bryan Newbold <bnewbold@archive.org>
Date: Tue, 26 Oct 2021 17:59:10 -0700
Subject: type annotations for persist workers; required some work

Had to re-structure and filter things a bit, Should be better behavior,
but might be some small changes.
---
 python/sandcrawler/persist.py | 125 ++++++++++++++++++++----------------------
 1 file changed, 59 insertions(+), 66 deletions(-)

(limited to 'python')

diff --git a/python/sandcrawler/persist.py b/python/sandcrawler/persist.py
index bb76e54..d47a8cb 100644
--- a/python/sandcrawler/persist.py
+++ b/python/sandcrawler/persist.py
@@ -20,7 +20,7 @@ grobid
 
 import os
 import xml.etree.ElementTree
-from typing import AnyStr, Optional
+from typing import Any, Dict, List, Optional
 
 from sandcrawler.db import SandcrawlerPostgresClient
 from sandcrawler.grobid import GrobidClient
@@ -31,18 +31,16 @@ from sandcrawler.workers import SandcrawlerWorker
 
 
 class PersistCdxWorker(SandcrawlerWorker):
-    def __init__(self, db_url, **kwargs):
+    def __init__(self, db_url: str, **kwargs):
         super().__init__()
         self.db = SandcrawlerPostgresClient(db_url)
         self.cur = self.db.conn.cursor()
 
-    def process(self, record, key=None):
-        """
-        Only do batches (as transactions)
-        """
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
+        """Only do batches (as transactions)"""
         raise NotImplementedError
 
-    def push_batch(self, batch):
+    def push_batch(self, batch: list) -> list:
         self.counts['total'] += len(batch)
         # filter to full CDX lines, no liveweb
         cdx_batch = [r for r in batch if r.get('warc_path') and ("/" in r['warc_path'])]
@@ -56,18 +54,16 @@ class PersistCdxWorker(SandcrawlerWorker):
 
 
 class PersistIngestFileResultWorker(SandcrawlerWorker):
-    def __init__(self, db_url, **kwargs):
+    def __init__(self, db_url: str, **kwargs):
         super().__init__()
         self.db = SandcrawlerPostgresClient(db_url)
         self.cur = self.db.conn.cursor()
 
-    def process(self, record, key=None):
-        """
-        Only do batches (as transactions)
-        """
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
+        """Only do batches (as transactions)"""
         raise NotImplementedError
 
-    def request_to_row(self, raw):
+    def request_to_row(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         """
         Converts ingest-request JSON schema (eg, from Kafka) to SQL ingest_request schema
 
@@ -222,20 +218,24 @@ class PersistIngestFileResultWorker(SandcrawlerWorker):
                                                                    {}).get('terminal_dt')
         return result
 
-    def push_batch(self, batch):
+    def push_batch(self, batch: List[Any]) -> List[Any]:
         self.counts['total'] += len(batch)
 
         if not batch:
             return []
 
-        results = [self.file_result_to_row(raw) for raw in batch]
-        results = [r for r in results if r]
+        results_unfiltered = [self.file_result_to_row(raw) for raw in batch]
+        results = [r for r in results_unfiltered if r]
 
-        requests = [self.request_to_row(raw['request']) for raw in batch if raw.get('request')]
-        requests = [r for r in requests if r and r['ingest_type'] != 'dataset-file']
+        irequests_unfiltered = [
+            self.request_to_row(raw['request']) for raw in batch if raw.get('request')
+        ]
+        irequests = [
+            r for r in irequests_unfiltered if r and r['ingest_type'] != 'dataset-file'
+        ]
 
-        if requests:
-            resp = self.db.insert_ingest_request(self.cur, requests)
+        if irequests:
+            resp = self.db.insert_ingest_request(self.cur, irequests)
             self.counts['insert-requests'] += resp[0]
             self.counts['update-requests'] += resp[1]
         if results:
@@ -266,16 +266,16 @@ class PersistIngestFileResultWorker(SandcrawlerWorker):
             self.result_to_html_meta(r) for r in batch if r.get('hit') and r.get('html_body')
         ]
         if html_meta_batch:
-            rows = [d.to_sql_tuple() for d in html_meta_batch]
+            rows = [d.to_sql_tuple() for d in html_meta_batch if d]
             resp = self.db.insert_html_meta(self.cur, rows, on_conflict="update")
             self.counts['insert-html_meta'] += resp[0]
             self.counts['update-html_meta'] += resp[1]
 
-        fileset_platform_batch = [
+        fileset_platform_batch_all = [
             self.result_to_platform_row(raw) for raw in batch if
             raw.get('request', {}).get('ingest_type') == 'dataset' and raw.get('platform_name')
         ]
-        fileset_platform_batch = [p for p in fileset_platform_batch if p]
+        fileset_platform_batch: List[Dict] = [p for p in fileset_platform_batch_all if p]
         if fileset_platform_batch:
             resp = self.db.insert_ingest_fileset_platform(self.cur,
                                                           fileset_platform_batch,
@@ -288,39 +288,35 @@ class PersistIngestFileResultWorker(SandcrawlerWorker):
 
 
 class PersistIngestFilesetWorker(SandcrawlerWorker):
-    def __init__(self, db_url, **kwargs):
+    def __init__(self, db_url: str, **kwargs):
         super().__init__()
         self.db = SandcrawlerPostgresClient(db_url)
         self.cur = self.db.conn.cursor()
 
-    def process(self, record, key=None):
-        """
-        Only do batches (as transactions)
-        """
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
+        """Only do batches (as transactions)"""
         raise NotImplementedError
 
 
 class PersistIngestRequestWorker(PersistIngestFileResultWorker):
-    def __init__(self, db_url, **kwargs):
+    def __init__(self, db_url: str, **kwargs):
         super().__init__(db_url=db_url)
 
-    def process(self, record, key=None):
-        """
-        Only do batches (as transactions)
-        """
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
+        """Only do batches (as transactions)"""
         raise NotImplementedError
 
-    def push_batch(self, batch):
+    def push_batch(self, batch: list) -> list:
         self.counts['total'] += len(batch)
 
         if not batch:
             return []
 
-        requests = [self.request_to_row(raw) for raw in batch]
-        requests = [r for r in requests if r]
+        irequests_all = [self.request_to_row(raw) for raw in batch]
+        irequests: List[Dict] = [r for r in irequests_all if r]
 
-        if requests:
-            resp = self.db.insert_ingest_request(self.cur, requests)
+        if irequests:
+            resp = self.db.insert_ingest_request(self.cur, irequests)
             self.counts['insert-requests'] += resp[0]
             self.counts['update-requests'] += resp[1]
 
@@ -329,7 +325,7 @@ class PersistIngestRequestWorker(PersistIngestFileResultWorker):
 
 
 class PersistGrobidWorker(SandcrawlerWorker):
-    def __init__(self, db_url, **kwargs):
+    def __init__(self, db_url: str, **kwargs):
         super().__init__()
         self.grobid = GrobidClient()
         self.s3 = SandcrawlerMinioClient(
@@ -342,19 +338,17 @@ class PersistGrobidWorker(SandcrawlerWorker):
         self.db_only = kwargs.get('db_only', False)
         assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed"
         if not self.s3_only:
-            self.db = SandcrawlerPostgresClient(db_url)
+            self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url)
             self.cur = self.db.conn.cursor()
         else:
             self.db = None
             self.cur = None
 
-    def process(self, record, key=None):
-        """
-        Only do batches (as transactions)
-        """
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
+        """Only do batches (as transactions)"""
         raise NotImplementedError
 
-    def push_batch(self, batch):
+    def push_batch(self, batch: list) -> list:
         self.counts['total'] += len(batch)
 
         # filter out bad "missing status_code" timeout rows
@@ -372,7 +366,7 @@ class PersistGrobidWorker(SandcrawlerWorker):
 
             assert len(r['key']) == 40
             if not self.db_only:
-                resp = self.s3.put_blob(
+                self.s3.put_blob(
                     folder="grobid",
                     blob=r['tei_xml'],
                     sha1hex=r['key'],
@@ -398,6 +392,7 @@ class PersistGrobidWorker(SandcrawlerWorker):
             r['metadata'] = metadata
 
         if not self.s3_only:
+            assert self.db and self.cur
             resp = self.db.insert_grobid(self.cur, batch, on_conflict="update")
             self.counts['insert-grobid'] += resp[0]
             self.counts['update-grobid'] += resp[1]
@@ -418,11 +413,11 @@ class PersistGrobidDiskWorker(SandcrawlerWorker):
 
     This could be refactored into a "Sink" type with an even thinner wrapper.
     """
-    def __init__(self, output_dir):
+    def __init__(self, output_dir: str):
         super().__init__()
         self.output_dir = output_dir
 
-    def _blob_path(self, sha1hex, extension=".tei.xml"):
+    def _blob_path(self, sha1hex: str, extension: str = ".tei.xml") -> str:
         obj_path = "{}/{}/{}{}".format(
             sha1hex[0:2],
             sha1hex[2:4],
@@ -431,7 +426,7 @@ class PersistGrobidDiskWorker(SandcrawlerWorker):
         )
         return obj_path
 
-    def process(self, record, key=None):
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
 
         if record.get('status_code') != 200 or not record.get('tei_xml'):
             return False
@@ -445,18 +440,16 @@ class PersistGrobidDiskWorker(SandcrawlerWorker):
 
 
 class PersistPdfTrioWorker(SandcrawlerWorker):
-    def __init__(self, db_url, **kwargs):
+    def __init__(self, db_url: str, **kwargs):
         super().__init__()
         self.db = SandcrawlerPostgresClient(db_url)
         self.cur = self.db.conn.cursor()
 
-    def process(self, record, key=None):
-        """
-        Only do batches (as transactions)
-        """
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
+        """Only do batches (as transactions)"""
         raise NotImplementedError
 
-    def push_batch(self, batch):
+    def push_batch(self, batch: list) -> list:
         self.counts['total'] += len(batch)
 
         batch = [r for r in batch if 'pdf_trio' in r and r['pdf_trio'].get('status_code')]
@@ -486,7 +479,7 @@ class PersistPdfTextWorker(SandcrawlerWorker):
 
     Should keep batch sizes small.
     """
-    def __init__(self, db_url, **kwargs):
+    def __init__(self, db_url: str, **kwargs):
         super().__init__()
         self.s3 = SandcrawlerMinioClient(
             host_url=kwargs.get('s3_url', 'localhost:9000'),
@@ -498,19 +491,17 @@ class PersistPdfTextWorker(SandcrawlerWorker):
         self.db_only = kwargs.get('db_only', False)
         assert not (self.s3_only and self.db_only), "Only one of s3_only and db_only allowed"
         if not self.s3_only:
-            self.db = SandcrawlerPostgresClient(db_url)
+            self.db: Optional[SandcrawlerPostgresClient] = SandcrawlerPostgresClient(db_url)
             self.cur = self.db.conn.cursor()
         else:
             self.db = None
             self.cur = None
 
-    def process(self, record, key=None):
-        """
-        Only do batches (as transactions)
-        """
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
+        """Only do batches (as transactions)"""
         raise NotImplementedError
 
-    def push_batch(self, batch):
+    def push_batch(self, batch: list) -> list:
         self.counts['total'] += len(batch)
 
         parsed_batch = []
@@ -526,7 +517,7 @@ class PersistPdfTextWorker(SandcrawlerWorker):
 
             assert len(r.sha1hex) == 40
             if not self.db_only:
-                resp = self.s3.put_blob(
+                self.s3.put_blob(
                     folder="text",
                     blob=r.text,
                     sha1hex=r.sha1hex,
@@ -535,6 +526,7 @@ class PersistPdfTextWorker(SandcrawlerWorker):
                 self.counts['s3-put'] += 1
 
         if not self.s3_only:
+            assert self.db and self.cur
             rows = [r.to_sql_tuple() for r in parsed_batch]
             resp = self.db.insert_pdf_meta(self.cur, rows, on_conflict="update")
             self.counts['insert-pdf-meta'] += resp[0]
@@ -569,15 +561,16 @@ class PersistThumbnailWorker(SandcrawlerWorker):
         self.s3_extension = kwargs.get('s3_extension', ".jpg")
         self.s3_folder = kwargs.get('s3_folder', "pdf")
 
-    def process(self, blob: bytes, key: Optional[str] = None):
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
         """
         Processing raw messages, not decoded JSON objects
         """
 
+        assert isinstance(record, bytes)
+        blob: bytes = record
         if isinstance(key, bytes):
             key = key.decode('utf-8')
         assert key is not None and len(key) == 40 and isinstance(key, str)
-        assert isinstance(blob, bytes)
         assert len(blob) >= 50
 
         self.s3.put_blob(
@@ -607,7 +600,7 @@ class GenericPersistDocWorker(SandcrawlerWorker):
         self.s3_folder = kwargs.get('s3_folder', "unknown")
         self.doc_key = "unknown"
 
-    def process(self, record: dict, key: Optional[AnyStr] = None) -> None:
+    def process(self, record: Any, key: Optional[str] = None) -> Any:
 
         if record.get('status') != 'success' or not record.get(self.doc_key):
             return
-- 
cgit v1.2.3