From 403fa63a7b2ab015b86e35e577218d045d1ea6ca Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Thu, 4 Jun 2020 00:09:56 -0700 Subject: more type annotations and fixes --- fatcat_scholar/grobid2json.py | 38 +++++++++++++++++++++----------------- fatcat_scholar/sandcrawler.py | 4 ++-- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/fatcat_scholar/grobid2json.py b/fatcat_scholar/grobid2json.py index 4019363..979a794 100755 --- a/fatcat_scholar/grobid2json.py +++ b/fatcat_scholar/grobid2json.py @@ -29,12 +29,15 @@ import io import json import argparse import xml.etree.ElementTree as ET +from typing import List, Any, Dict, AnyStr, Optional xml_ns = "http://www.w3.org/XML/1998/namespace" ns = "http://www.tei-c.org/ns/1.0" -def all_authors(elem): +def all_authors(elem: Optional[ET.Element]) -> List[Dict[str, Any]]: + if not elem: + return [] names = [] for author in elem.findall(".//{%s}author" % ns): pn = author.find("./{%s}persName" % ns) @@ -43,16 +46,18 @@ def all_authors(elem): given_name = pn.findtext("./{%s}forename" % ns) or None surname = pn.findtext("./{%s}surname" % ns) or None full_name = " ".join(pn.itertext()) - obj = dict(name=full_name) + obj: Dict[str, Any] = dict(name=full_name) if given_name: obj["given_name"] = given_name if surname: obj["surname"] = surname ae = author.find("./{%s}affiliation" % ns) if ae: - affiliation = dict() + affiliation: Dict[str, Any] = dict() for on in ae.findall("./{%s}orgName" % ns): - affiliation[on.get("type")] = on.text + on_type = on.get("type") + if on_type: + affiliation[on_type] = on.text addr_e = ae.find("./{%s}address" % ns) if addr_e: address = dict() @@ -70,7 +75,7 @@ def all_authors(elem): return names -def journal_info(elem): +def journal_info(elem: ET.Element) -> Dict[str, Any]: journal = dict() journal["name"] = elem.findtext(".//{%s}monogr/{%s}title" % (ns, ns)) journal["publisher"] = elem.findtext( @@ -91,8 +96,8 @@ def journal_info(elem): return journal -def biblio_info(elem): - ref = dict() +def biblio_info(elem: ET.Element) -> Dict[str, Any]: + ref: Dict[str, Any] = dict() ref["id"] = elem.attrib.get("{http://www.w3.org/XML/1998/namespace}id") # Title stuff is messy in references... ref["title"] = elem.findtext(".//{%s}analytic/{%s}title" % (ns, ns)) @@ -122,18 +127,17 @@ def biblio_info(elem): return ref -def teixml2json(content, encumbered=True): +def teixml2json(content: AnyStr, encumbered: bool = True) -> Dict[str, Any]: - if type(content) == str: - content = io.StringIO(content) - elif type(content) == bytes: - content = io.BytesIO(content) + if isinstance(content, str): + tree = ET.parse(io.StringIO(content)) + elif isinstance(content, bytes): + tree = ET.parse(io.BytesIO(content)) - info = dict() + info: Dict[str, Any] = dict() # print(content) # print(content.getvalue()) - tree = ET.parse(content) tei = tree.getroot() header = tei.find(".//{%s}teiHeader" % ns) @@ -163,7 +167,7 @@ def teixml2json(content, encumbered=True): text = tei.find(".//{%s}text" % (ns)) # print(text.attrib) - if text.attrib.get("{%s}lang" % xml_ns): + if text and text.attrib.get("{%s}lang" % xml_ns): info["language_code"] = text.attrib["{%s}lang" % xml_ns] # xml:lang if encumbered: @@ -184,7 +188,7 @@ def teixml2json(content, encumbered=True): return info -def main(): # pragma no cover +def main() -> None: # pragma no cover parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="GROBID TEI XML to JSON", @@ -200,7 +204,7 @@ def main(): # pragma no cover args = parser.parse_args() for filename in args.teifiles: - content = open(filename, "r") + content = open(filename, "r").read() print( json.dumps( teixml2json(content, encumbered=(not args.no_encumbered)), diff --git a/fatcat_scholar/sandcrawler.py b/fatcat_scholar/sandcrawler.py index 347364f..66d5686 100644 --- a/fatcat_scholar/sandcrawler.py +++ b/fatcat_scholar/sandcrawler.py @@ -41,7 +41,7 @@ class SandcrawlerMinioClient(object): ) self.default_bucket = default_bucket - def _blob_path(self, folder, sha1hex, extension, prefix): + def _blob_path(self, folder: str, sha1hex: str, extension: str, prefix: str) -> str: if not extension: extension = "" if not prefix: @@ -52,7 +52,7 @@ class SandcrawlerMinioClient(object): ) return obj_path - def get_blob(self, folder, sha1hex, extension="", prefix="", bucket=None): + def get_blob(self, folder: str, sha1hex: str, extension: str ="", prefix: str ="", bucket: Optional[str] = None) -> bytes: """ sha1hex is sha1 of the blob itself -- cgit v1.2.3