author     Martin Czygan <martin.czygan@gmail.com>   2020-11-12 01:14:34 +0100
committer  Martin Czygan <martin.czygan@gmail.com>   2020-11-12 01:14:34 +0100
commit     0f9a45f3c4657eb4372b96e2a6c550c4a01226e5 (patch)
tree       d217b6e82de697b30c8b5b7f0de438e170291f1e
parent     037b077efba7ee351fe2c6b5d0217c6d426261e0 (diff)
download   fuzzycat-0f9a45f3c4657eb4372b96e2a6c550c4a01226e5.tar.gz
           fuzzycat-0f9a45f3c4657eb4372b96e2a6c550c4a01226e5.zip
move fileinput.input out of the cluster
The Cluster class should work with any iterable, so testing will be easier.
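With this change a test can feed records directly, no temporary files needed. A minimal sketch of the new call pattern, assuming this commit's constructor signature; title_key below is a hypothetical stand-in for the release_key_* functions shipped in fuzzycat.cluster, and run() still shells out to Unix sort(1):

    import io
    import json

    from fuzzycat.cluster import Cluster

    def title_key(doc):
        # Hypothetical key function: (identifier, lower-cased title).
        return (doc["ident"], doc.get("title", "").lower())

    lines = [
        json.dumps({"ident": "r1", "title": "Hello World"}),
        json.dumps({"ident": "r2", "title": "hello world"}),
    ]

    output = io.StringIO()
    stats = Cluster(iterable=lines, key=title_key, output=output).run()
    # stats: key_ok=2, num_clusters=1; output holds one cluster document:
    # {"k": "hello world", "v": ["r1", "r2"]}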
-rw-r--r--  fuzzycat/__main__.py |   6
-rw-r--r--  fuzzycat/cluster.py  | 148
2 files changed, 76 insertions, 78 deletions
diff --git a/fuzzycat/__main__.py b/fuzzycat/__main__.py
index 900d5c0..3845245 100644
--- a/fuzzycat/__main__.py
+++ b/fuzzycat/__main__.py
@@ -13,8 +13,8 @@ Run, e.g. fuzzycat cluster --help for more options. Example:
 import argparse
 import cProfile as profile
 import fileinput
-import json
 import io
+import json
 import logging
 import pstats
 import sys
@@ -33,8 +33,8 @@ def run_cluster(args):
         'tnysi': release_key_title_nysiis,
         'tss': release_key_title_ngram,
     }
-    cluster = Cluster(files=args.files,
-                      keyfunc=types.get(args.type),
+    cluster = Cluster(iterable=fileinput.input(args.files),
+                      key=types.get(args.type),
                       tmpdir=args.tmpdir,
                       prefix=args.prefix)
     stats = cluster.run()
diff --git a/fuzzycat/cluster.py b/fuzzycat/cluster.py
index dd55a24..cc74deb 100644
--- a/fuzzycat/cluster.py
+++ b/fuzzycat/cluster.py
@@ -70,7 +70,7 @@
 import subprocess
 import sys
 import tempfile
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import IO, Any, Callable, Dict, Generator, List, Optional, Tuple
 
 import fuzzy
@@ -179,47 +179,6 @@ def release_key_title_ngram(doc: KeyDoc, n=3) -> Tuple[str, str]:
     return (ident, key)
 
 
-def sort_by_column(filename: str,
-                   opts: str = "-k 2",
-                   fast: bool = True,
-                   mode: str = "w",
-                   prefix: str = "fuzzycat-",
-                   tmpdir: Optional[str] = None):
-    """
-    Sort tabular file with sort(1), returns the filename of the sorted file.
-    TODO: use separate /fast/tmp for sort.
-    """
-    with tempfile.NamedTemporaryFile(delete=False, mode=mode, prefix=prefix) as tf:
-        env = os.environ.copy()
-        if tmpdir is not None:
-            env["TMPDIR"] = tmpdir
-        if fast:
-            env["LC_ALL"] = "C"
-        subprocess.run(["sort"] + opts.split() + [filename], stdout=tf, env=env, check=True)
-
-    return tf.name
-
-
-def group_by(seq: collections.abc.Iterable,
-             key: Callable[[Any], str] = None,
-             value: Callable[[Any], str] = None,
-             comment: str = "") -> Generator[Any, None, None]:
-    """
-    Iterate over lines in filename, group by key (a callable deriving the key
-    from the line), then apply value callable on the same value to emit a
-    minimal document, containing the key and identifiers belonging to a
-    cluster.
-    """
-    for k, g in itertools.groupby(seq, key=key):
-        doc = {
-            "k": k.strip(),
-            "v": [value(v) for v in g],
-        }
-        if comment:
-            doc["c"] = comment
-        yield doc
-
-
 def cut(f: int = 0, sep: str = '\t', ignore_missing_column: bool = True):
     """
     Return a callable, that extracts a given column from a file with a specific
@@ -238,48 +197,87 @@ def cut(f: int = 0, sep: str = '\t', ignore_missing_column: bool = True):
 
 class Cluster:
     """
-    Cluster scaffold for release entities. XXX: move IO/files out, allow any iterable.
+    Runs clustering over a potentially large number of records.
     """
     def __init__(self,
-                 files="-",
-                 output=sys.stdout,
-                 keyfunc=lambda v: v,
-                 prefix='fuzzycat-',
-                 tmpdir=None):
+                 iterable: collections.abc.Iterable,
+                 key: Callable[[Any], Tuple[str, str]],
+                 output: IO[str] = sys.stdout,
+                 prefix: str = "fuzzycat-",
+                 tmpdir: str = tempfile.gettempdir(),
+                 strict: bool = False):
         """
-        Files can be a list of files or "-" for stdin.
+        Setup a clusterer, using a custom key function.
         """
-        self.files = files
-        self.keyfunc = keyfunc
-        self.output = output
-        self.prefix = prefix
-        self.tmpdir = tmpdir
-        self.logger = logging.getLogger('fuzzycat.cluster')
+        self.iterable: collections.abc.Iterable = iterable
+        self.key: Callable[[Any], Tuple[str, str]] = key
+        self.output: IO[str] = output
+        self.prefix: str = prefix
+        self.tmpdir: str = tmpdir
+        self.strict: bool = strict
+        self.counter: Dict[str, int] = collections.Counter({
+            "key_err": 0,
+            "key_ok": 0,
+            "num_clusters": 0,
+        })
 
     def run(self):
         """
-        Run clustering and write output to given stream or file.
+        First map documents to keys, then group by keys.
         """
-        keyfunc = self.keyfunc  # Save a lookup in loop.
-        counter: Dict[str, int] = collections.Counter()
         with tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=self.prefix) as tf:
-            for line in fileinput.input(files=self.files):
+            for line in self.iterable:
                 try:
-                    id, key = keyfunc(json.loads(line))
-                    print("{}\t{}".format(id, key), file=tf)
+                    doc = json.loads(line)
+                    id, key = self.key(doc)
+                    # XXX: if the line itself contains tabs, we need to remove
+                    # them here; maybe offer TSV and JSON output and extra flag
+                    print("{}\t{}\t{}".format(id, key, line.replace("\t", " ")), file=tf)
                 except (KeyError, ValueError):
-                    counter["key_extraction_failed"] += 1
+                    if self.strict:
+                        raise
+                    self.counter["key_err"] += 1
                 else:
-                    counter["key_ok"] += 1
-        sbc = sort_by_column(tf.name, opts='-k 2', prefix=self.prefix, tmpdir=self.tmpdir)
-        with open(sbc) as f:
-            comment = keyfunc.__name__
-            for doc in group_by(f, key=cut(f=1), value=cut(f=0), comment=comment):
-                counter["groups"] += 1
-                json.dump(doc, self.output)
-                self.output.write("\n")
-
-        os.remove(sbc)
-        os.remove(tf.name)
-
-        return counter
+                    self.counter["key_ok"] += 1
+
+        sf = self.sort(tf.name, opts='-k 2')
+        try:
+            with open(sf) as f:
+                for doc in self.group_by(f, key=cut(f=1), value=cut(f=0)):
+                    self.counter["num_clusters"] += 1
+                    json.dump(doc, self.output)
+                    self.output.write("\n")
+        finally:
+            os.remove(sf)
+            os.remove(tf.name)
+
+        return self.counter
+
+    def sort(self, filename: str, opts: str = "-k 2", fast: bool = True, mode: str = "w"):
+        """
+        Sort tabular file with sort(1), returns the filename of the sorted file.
+        TODO: use separate /fast/tmp for sort.
+        """
+        with tempfile.NamedTemporaryFile(delete=False, mode=mode, prefix=self.prefix) as tf:
+            env = os.environ.copy()
+            env["TMPDIR"] = self.tmpdir
+            if fast:
+                env["LC_ALL"] = "C"
+            subprocess.run(["sort"] + opts.split() + [filename], stdout=tf, env=env, check=True)
+
+        return tf.name
+
+    def group_by(self,
+                 seq: collections.abc.Iterable,
+                 key: Callable[[Any], str] = None,
+                 value: Callable[[Any], str] = None) -> Generator[Any, None, None]:
+        """
+        Extract a key from elements of an iterable and group them. Just as
+        uniq(1), the iterable must be ordered for this to work.
+        """
+        for k, g in itertools.groupby(seq, key=key):
+            doc = {
+                "k": k.strip(),
+                "v": [value(v) for v in g],
+            }
+            yield doc