aboutsummaryrefslogtreecommitdiffstats
path: root/fatcat_scholar
diff options
context:
space:
mode:
Diffstat (limited to 'fatcat_scholar')
-rwxr-xr-xfatcat_scholar/grobid2json.py38
-rw-r--r--fatcat_scholar/sandcrawler.py4
2 files changed, 23 insertions, 19 deletions
diff --git a/fatcat_scholar/grobid2json.py b/fatcat_scholar/grobid2json.py
index 4019363..979a794 100755
--- a/fatcat_scholar/grobid2json.py
+++ b/fatcat_scholar/grobid2json.py
@@ -29,12 +29,15 @@ import io
import json
import argparse
import xml.etree.ElementTree as ET
+from typing import List, Any, Dict, AnyStr, Optional
xml_ns = "http://www.w3.org/XML/1998/namespace"
ns = "http://www.tei-c.org/ns/1.0"
-def all_authors(elem):
+def all_authors(elem: Optional[ET.Element]) -> List[Dict[str, Any]]:
+ if not elem:
+ return []
names = []
for author in elem.findall(".//{%s}author" % ns):
pn = author.find("./{%s}persName" % ns)
@@ -43,16 +46,18 @@ def all_authors(elem):
given_name = pn.findtext("./{%s}forename" % ns) or None
surname = pn.findtext("./{%s}surname" % ns) or None
full_name = " ".join(pn.itertext())
- obj = dict(name=full_name)
+ obj: Dict[str, Any] = dict(name=full_name)
if given_name:
obj["given_name"] = given_name
if surname:
obj["surname"] = surname
ae = author.find("./{%s}affiliation" % ns)
if ae:
- affiliation = dict()
+ affiliation: Dict[str, Any] = dict()
for on in ae.findall("./{%s}orgName" % ns):
- affiliation[on.get("type")] = on.text
+ on_type = on.get("type")
+ if on_type:
+ affiliation[on_type] = on.text
addr_e = ae.find("./{%s}address" % ns)
if addr_e:
address = dict()
@@ -70,7 +75,7 @@ def all_authors(elem):
return names
-def journal_info(elem):
+def journal_info(elem: ET.Element) -> Dict[str, Any]:
journal = dict()
journal["name"] = elem.findtext(".//{%s}monogr/{%s}title" % (ns, ns))
journal["publisher"] = elem.findtext(
@@ -91,8 +96,8 @@ def journal_info(elem):
return journal
-def biblio_info(elem):
- ref = dict()
+def biblio_info(elem: ET.Element) -> Dict[str, Any]:
+ ref: Dict[str, Any] = dict()
ref["id"] = elem.attrib.get("{http://www.w3.org/XML/1998/namespace}id")
# Title stuff is messy in references...
ref["title"] = elem.findtext(".//{%s}analytic/{%s}title" % (ns, ns))
@@ -122,18 +127,17 @@ def biblio_info(elem):
return ref
-def teixml2json(content, encumbered=True):
+def teixml2json(content: AnyStr, encumbered: bool = True) -> Dict[str, Any]:
- if type(content) == str:
- content = io.StringIO(content)
- elif type(content) == bytes:
- content = io.BytesIO(content)
+ if isinstance(content, str):
+ tree = ET.parse(io.StringIO(content))
+ elif isinstance(content, bytes):
+ tree = ET.parse(io.BytesIO(content))
- info = dict()
+ info: Dict[str, Any] = dict()
# print(content)
# print(content.getvalue())
- tree = ET.parse(content)
tei = tree.getroot()
header = tei.find(".//{%s}teiHeader" % ns)
@@ -163,7 +167,7 @@ def teixml2json(content, encumbered=True):
text = tei.find(".//{%s}text" % (ns))
# print(text.attrib)
- if text.attrib.get("{%s}lang" % xml_ns):
+ if text and text.attrib.get("{%s}lang" % xml_ns):
info["language_code"] = text.attrib["{%s}lang" % xml_ns] # xml:lang
if encumbered:
@@ -184,7 +188,7 @@ def teixml2json(content, encumbered=True):
return info
-def main(): # pragma no cover
+def main() -> None: # pragma no cover
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description="GROBID TEI XML to JSON",
@@ -200,7 +204,7 @@ def main(): # pragma no cover
args = parser.parse_args()
for filename in args.teifiles:
- content = open(filename, "r")
+ content = open(filename, "r").read()
print(
json.dumps(
teixml2json(content, encumbered=(not args.no_encumbered)),
diff --git a/fatcat_scholar/sandcrawler.py b/fatcat_scholar/sandcrawler.py
index 347364f..66d5686 100644
--- a/fatcat_scholar/sandcrawler.py
+++ b/fatcat_scholar/sandcrawler.py
@@ -41,7 +41,7 @@ class SandcrawlerMinioClient(object):
)
self.default_bucket = default_bucket
- def _blob_path(self, folder, sha1hex, extension, prefix):
+ def _blob_path(self, folder: str, sha1hex: str, extension: str, prefix: str) -> str:
if not extension:
extension = ""
if not prefix:
@@ -52,7 +52,7 @@ class SandcrawlerMinioClient(object):
)
return obj_path
- def get_blob(self, folder, sha1hex, extension="", prefix="", bucket=None):
+ def get_blob(self, folder: str, sha1hex: str, extension: str ="", prefix: str ="", bucket: Optional[str] = None) -> bytes:
"""
sha1hex is sha1 of the blob itself