author    Martin Czygan <martin.czygan@gmail.com>  2020-11-12 01:14:34 +0100
committer Martin Czygan <martin.czygan@gmail.com>  2020-11-12 01:14:34 +0100
commit    0f9a45f3c4657eb4372b96e2a6c550c4a01226e5 (patch)
tree      d217b6e82de697b30c8b5b7f0de438e170291f1e
parent    037b077efba7ee351fe2c6b5d0217c6d426261e0 (diff)
move fileinput.input out of the Cluster class

The Cluster class should work with any iterable, which makes testing easier.
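With the iterable factored out, a test can drive clustering from in-memory data instead of files. A minimal sketch of the new interface (the documents and key function below are illustrative, not part of this commit):

    import io
    from fuzzycat.cluster import Cluster

    lines = [
        '{"ident": "r1", "title": "Database Systems"}',
        '{"ident": "r2", "title": "DATABASE SYSTEMS"}',
    ]
    buf = io.StringIO()
    cluster = Cluster(iterable=lines,
                      key=lambda doc: (doc["ident"], doc["title"].lower()),
                      output=buf)
    stats = cluster.run()  # shells out to sort(1), so this needs a Unix environment
    # buf now holds one cluster document: {"k": "database systems", "v": ["r1", "r2"]}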
-rw-r--r--  fuzzycat/__main__.py  |   6
-rw-r--r--  fuzzycat/cluster.py   | 149
2 files changed, 77 insertions(+), 78 deletions(-)
diff --git a/fuzzycat/__main__.py b/fuzzycat/__main__.py
index 900d5c0..3845245 100644
--- a/fuzzycat/__main__.py
+++ b/fuzzycat/__main__.py
@@ -13,8 +13,8 @@ Run, e.g. fuzzycat cluster --help for more options. Example:
import argparse
import cProfile as profile
import fileinput
-import json
import io
+import json
import logging
import pstats
import sys
@@ -33,8 +33,8 @@ def run_cluster(args):
        'tnysi': release_key_title_nysiis,
        'tss': release_key_title_ngram,
    }
-    cluster = Cluster(files=args.files,
-                      keyfunc=types.get(args.type),
+    cluster = Cluster(iterable=fileinput.input(args.files),
+                      key=types.get(args.type),
                       tmpdir=args.tmpdir,
                       prefix=args.prefix)
    stats = cluster.run()
diff --git a/fuzzycat/cluster.py b/fuzzycat/cluster.py
index dd55a24..cc74deb 100644
--- a/fuzzycat/cluster.py
+++ b/fuzzycat/cluster.py
@@ -70,7 +70,7 @@ import subprocess
import sys
import tempfile
from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
+from typing import IO, Any, Callable, Dict, Generator, List, Optional, Tuple
import fuzzy
@@ -179,47 +179,6 @@ def release_key_title_ngram(doc: KeyDoc, n=3) -> Tuple[str, str]:
    return (ident, key)
-def sort_by_column(filename: str,
-                   opts: str = "-k 2",
-                   fast: bool = True,
-                   mode: str = "w",
-                   prefix: str = "fuzzycat-",
-                   tmpdir: Optional[str] = None):
-    """
-    Sort tabular file with sort(1), returns the filename of the sorted file.
-    TODO: use separate /fast/tmp for sort.
-    """
-    with tempfile.NamedTemporaryFile(delete=False, mode=mode, prefix=prefix) as tf:
-        env = os.environ.copy()
-        if tmpdir is not None:
-            env["TMPDIR"] = tmpdir
-        if fast:
-            env["LC_ALL"] = "C"
-        subprocess.run(["sort"] + opts.split() + [filename], stdout=tf, env=env, check=True)
-
-    return tf.name
-
-
-def group_by(seq: collections.abc.Iterable,
-             key: Callable[[Any], str] = None,
-             value: Callable[[Any], str] = None,
-             comment: str = "") -> Generator[Any, None, None]:
-    """
-    Iterate over lines in filename, group by key (a callable deriving the key
-    from the line), then apply value callable on the same value to emit a
-    minimal document, containing the key and identifiers belonging to a
-    cluster.
-    """
-    for k, g in itertools.groupby(seq, key=key):
-        doc = {
-            "k": k.strip(),
-            "v": [value(v) for v in g],
-        }
-        if comment:
-            doc["c"] = comment
-        yield doc
-
-
def cut(f: int = 0, sep: str = '\t', ignore_missing_column: bool = True):
"""
Return a callable, that extracts a given column from a file with a specific
@@ -238,48 +197,88 @@ def cut(f: int = 0, sep: str = '\t', ignore_missing_column: bool = True):
class Cluster:
"""
- Cluster scaffold for release entities. XXX: move IO/files out, allow any iterable.
+ Runs clustering over a potentially large number of records.
"""
def __init__(self,
- files="-",
- output=sys.stdout,
- keyfunc=lambda v: v,
- prefix='fuzzycat-',
- tmpdir=None):
+ iterable: collections.abc.Iterable,
+ key: Callable[[Any], Tuple[str, str]],
+ output: IO[str] = sys.stdout,
+ prefix: str = "fuzzycat-",
+ tmpdir: str = tempfile.gettempdir(),
+ strict: bool = False):
"""
- Files can be a list of files or "-" for stdin.
+ Setup a clusterer, using a custom key function.
"""
- self.files = files
- self.keyfunc = keyfunc
- self.output = output
- self.prefix = prefix
- self.tmpdir = tmpdir
- self.logger = logging.getLogger('fuzzycat.cluster')
+ self.iterable: collections.abc.Iterable = iterable
+ self.key: Callable[[Any], Tuple[str, str]] = key
+ self.output: IO[str] = output
+ self.prefix: str = prefix
+ self.tmpdir: str = tmpdir
+ self.counter: Dict[str, int] = collections.Counter({
+ "key_err": 0,
+ "key_ok": 0,
+ "num_clusters": 0,
+ })
    def run(self):
        """
-        Run clustering and write output to given stream or file.
+        First map documents to keys, then group by keys.
        """
-        keyfunc = self.keyfunc  # Save a lookup in loop.
-        counter: Dict[str, int] = collections.Counter()
        with tempfile.NamedTemporaryFile(delete=False, mode="w", prefix=self.prefix) as tf:
-            for line in fileinput.input(files=self.files):
+            for line in self.iterable:
                try:
-                    id, key = keyfunc(json.loads(line))
-                    print("{}\t{}".format(id, key), file=tf)
+                    doc = json.loads(line)
+                    id, key = self.key(doc)
+                    # XXX: if the line itself contains tabs, we need to remove
+                    # them here; maybe offer TSV and JSON output and an extra flag
+                    print("{}\t{}\t{}".format(id, key, line.replace("\t", " ")), file=tf)
                except (KeyError, ValueError):
-                    counter["key_extraction_failed"] += 1
+                    if self.strict:
+                        raise
+                    self.counter["key_err"] += 1
                else:
-                    counter["key_ok"] += 1
-        sbc = sort_by_column(tf.name, opts='-k 2', prefix=self.prefix, tmpdir=self.tmpdir)
-        with open(sbc) as f:
-            comment = keyfunc.__name__
-            for doc in group_by(f, key=cut(f=1), value=cut(f=0), comment=comment):
-                counter["groups"] += 1
-                json.dump(doc, self.output)
-                self.output.write("\n")
-
-        os.remove(sbc)
-        os.remove(tf.name)
-
-        return counter
+                    self.counter["key_ok"] += 1
+
+        sf = self.sort(tf.name, opts='-k 2')
+        try:
+            with open(sf) as f:
+                for doc in self.group_by(f, key=cut(f=1), value=cut(f=0)):
+                    self.counter["num_clusters"] += 1
+                    json.dump(doc, self.output)
+                    self.output.write("\n")
+        finally:
+            os.remove(sf)
+            os.remove(tf.name)
+
+        return self.counter
+
+    def sort(self, filename: str, opts: str = "-k 2", fast: bool = True, mode: str = "w"):
+        """
+        Sort a tabular file with sort(1) and return the filename of the sorted file.
+        TODO: use separate /fast/tmp for sort.
+        """
+        with tempfile.NamedTemporaryFile(delete=False, mode=mode, prefix=self.prefix) as tf:
+            env = os.environ.copy()
+            env["TMPDIR"] = self.tmpdir
+            if fast:
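+                # LC_ALL=C makes sort(1) compare raw bytes instead of applying
+                # locale collation, which is considerably faster.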
+ env["LC_ALL"] = "C"
+ subprocess.run(["sort"] + opts.split() + [filename], stdout=tf, env=env, check=True)
+
+ return tf.name
+
+    def group_by(self,
+                 seq: collections.abc.Iterable,
+                 key: Callable[[Any], str] = None,
+                 value: Callable[[Any], str] = None) -> Generator[Any, None, None]:
+        """
+        Extract a key from elements of an iterable and group them. Just as
+        with uniq(1), the iterable must be ordered for this to work.
+        """
+        for k, g in itertools.groupby(seq, key=key):
+            doc = {
+                "k": k.strip(),
+                "v": [value(v) for v in g],
+            }
+            yield doc
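
Like uniq(1), group_by collapses only adjacent runs, so its input must already be sorted by key. A small sketch of that contract, reusing the cluster instance from the example above and cut from fuzzycat.cluster (the input lines are made up):

    from fuzzycat.cluster import cut

    lines = ["r1\tkey1", "r2\tkey1", "r3\tkey2"]  # pre-sorted by the key column
    docs = list(cluster.group_by(lines, key=cut(f=1), value=cut(f=0)))
    # docs == [{"k": "key1", "v": ["r1", "r2"]},
    #          {"k": "key2", "v": ["r3"]}]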