aboutsummaryrefslogtreecommitdiffstats
path: root/python/refcat/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/refcat/base.py')
-rw-r--r--python/refcat/base.py337
1 files changed, 337 insertions, 0 deletions
diff --git a/python/refcat/base.py b/python/refcat/base.py
new file mode 100644
index 0000000..240e4a0
--- /dev/null
+++ b/python/refcat/base.py
@@ -0,0 +1,337 @@
+"""
+Default task
+============
+
+A default task, that covers file system layout.
+"""
+
import collections
import datetime
import functools
import hashlib
import logging
import os
import random
import re
import string
import subprocess
import tempfile

import luigi
+
# Public names of this module (what `from refcat.base import *` exposes).
__all__ = ['BaseTask', 'ClosestDateParameter', 'Gzip', 'TSV', 'Zstd',
           'random_string', 'shellout']

# Logger shared by the helpers defined in this module.
logger = logging.getLogger('refcat')
+
+
class ClosestDateParameter(luigi.DateParameter):
    """
    A date parameter flagged for "closest date" substitution.

    The class adds nothing but the `use_closest_date` marker attribute.
    Code that finds this marker on a task's `date` parameter uses the value
    of the task's `closest()` method instead of the raw parameter value.
    """
    # Marker attribute; only its presence is checked.
    use_closest_date = True
+
+
def is_closest_date_parameter(task, param_name):
    """
    Tell whether parameter `param_name` of `task` is a closest-date marker.

    The first entry of `task.get_params()` whose name equals `param_name`
    decides the result: True if that parameter object has a
    `use_closest_date` attribute, False otherwise. False is also returned
    when the task has no parameter of that name.
    """
    verdicts = (hasattr(param, 'use_closest_date')
                for key, param in task.get_params()
                if key == param_name)
    return next(verdicts, False)
+
+
def delistify(x):
    """
    Turn a list parameter into a short slug; pass other values through.

    For a list, single quotes are removed from every element and the
    elements are then joined with '-' in sorted order. Any value that is
    not a list is returned unchanged.
    """
    if not isinstance(x, list):
        return x
    cleaned = sorted(item.replace("'", "") for item in x)
    return '-'.join(cleaned)
+
+
class BaseTask(luigi.Task):
    """
    A base task with a `path` method. BASE should be set to the root
    directory of all tasks. TAG is a shard for a group of related tasks.

    Artefacts of a task are placed below ``BASE/TAG/<task_family>/``.
    """
    # Root directory for all task outputs; defaults to the system temp dir.
    BASE = tempfile.gettempdir()
    # Directory name that groups related tasks below BASE.
    TAG = 'default'

    def closest(self):
        """
        Return the closest date for the task's date.

        Defaults to the date itself. Raises AttributeError if the task has
        no `date` attribute.
        """
        if not hasattr(self, 'date'):
            raise AttributeError('Task has no date attribute.')
        return self.date

    def effective_task_id(self):
        """
        Return the task id with `date` replaced by the closest date.

        If the task has no `date` parameter, or `date` is not a closest-date
        parameter, the regular `task_id` is returned.
        """
        if not ('date' in self.param_kwargs and is_closest_date_parameter(self, 'date')):
            return self.task_id
        # Work on a copy: writing into `self.param_kwargs` directly would
        # silently change the parameters stored on the task instance.
        params = dict(self.param_kwargs)
        params['date'] = self.closest()
        task_id_parts = sorted('%s=%s' % (k, str(v)) for k, v in params.items())
        return '%s(%s)' % (self.task_family, ', '.join(task_id_parts))

    def taskdir(self):
        """ Return the directory under which all artefacts are stored. """
        return os.path.join(self.BASE, self.TAG, self.task_family)

    def path(self, filename=None, ext='tsv', digest=False, shard=False, encoding='utf-8'):
        """
        Return the output path of this task for its current parameters.

        If `filename` is given, it is used as is. Otherwise the filename is
        derived from the significant parameters (sorted ``name-value``
        pairs joined by '-', or ``output`` if there are none):

        * `ext` sets the extension of the file; a falsy value omits it.
        * If `digest` is true, the filename (w/o extension) is replaced by
          its SHA1 hex digest.
        * If `shard` is true, the file is placed in a subdirectory named
          after the first two hex chars of the SHA1 of the filename.
        * `encoding` is used to encode names before hashing.

        Raises RuntimeError if BASE is NotImplemented.
        """
        if self.BASE is NotImplemented:
            raise RuntimeError('BASE directory must be set.')

        if filename is None:
            parts = []

            for name, param in self.get_params():
                if not param.significant:
                    continue
                if name == 'date' and is_closest_date_parameter(self, 'date'):
                    # Closest-date parameters contribute the resolved date.
                    parts.append('date-%s' % self.closest())
                    continue
                if getattr(param, 'is_list', False):
                    es = '-'.join([str(v) for v in getattr(self, name)])
                    parts.append('%s-%s' % (name, es))
                    continue

                val = getattr(self, name)

                # datetime must be tested first, it is a subclass of date.
                if isinstance(val, datetime.datetime):
                    val = val.strftime('%Y-%m-%dT%H%M%S')
                elif isinstance(val, datetime.date):
                    val = val.strftime('%Y-%m-%d')

                parts.append('%s-%s' % (name, val))

            stem = '-'.join(sorted(parts)) or 'output'
            if digest:
                stem = hashlib.sha1(stem.encode(encoding)).hexdigest()
            if ext:
                filename = '{fn}.{ext}'.format(ext=ext, fn=stem)
            else:
                filename = stem
            if shard:
                prefix = hashlib.sha1(filename.encode(encoding)).hexdigest()[:2]
                return os.path.join(self.BASE, self.TAG, self.task_family, prefix, filename)

        return os.path.join(self.BASE, self.TAG, self.task_family, filename)
+
+
def shellout(template,
             preserve_whitespace=False,
             executable='/bin/bash',
             ignoremap=None,
             encoding=None,
             pipefail=True,
             **kwargs):
    """
    Take a shell command template, fill it in and execute it.

    The template must use the new (2.6+) format mini language. `kwargs` must
    contain any defined placeholder, only `output` is optional and will be
    autofilled with a temporary file if it is used, but not specified
    explicitly.

    If `pipefail` is `True` (the default), the command is wrapped in a
    subshell with `set -o pipefail`, so a failure anywhere in a pipe is
    reported as an error; with `False` the command runs unwrapped. If
    `preserve_whitespace` is `True`, no whitespace normalization is
    performed. A custom shell executable name can be passed in `executable`
    and defaults to `/bin/bash`. If `encoding` is given and `template` is a
    byte string, the template is decoded with it first.

    Returns the value of `output` (the path of the temporary file, if it was
    autofilled).

    Raises RuntimeError on nonzero exit codes; the exit code is available as
    `code` attribute on the exception. To ignore certain errors, pass a
    dictionary in `ignoremap`, with the error code to ignore as key and a
    string message as value.

    Simple template:

        wc -l < {input} > {output}

    Quoted curly braces:

        ps ax|awk '{{print $1}}' > {output}

    Usage with luigi:

        ...
        tmp = shellout('wc -l < {input} > {output}', input=self.input().path)
        luigi.LocalTarget(tmp).move(self.output().path)
        ....

    """
    if 'output' not in kwargs:
        # mkstemp hands back an open descriptor; only the path is needed, so
        # close it right away instead of leaking one descriptor per call.
        fd, output = tempfile.mkstemp(prefix='refcat-')
        os.close(fd)
        kwargs['output'] = output
    if ignoremap is None:
        ignoremap = {}
    if encoding and isinstance(template, bytes):
        # Only byte strings can be decoded; text templates are used as is.
        template = template.decode(encoding)
    command = template.format(**kwargs)
    if not preserve_whitespace:
        command = re.sub('[ \t\n]+', ' ', command)
    if pipefail:
        command = '(set -o pipefail && %s)' % command
    logger.debug(command)
    code = subprocess.call(command, shell=True, executable=executable)
    if code != 0:
        if code in ignoremap:
            logger.info("Ignoring error via ignoremap: %s", ignoremap.get(code))
        else:
            logger.error('%s: %s', command, code)
            error = RuntimeError('%s exitcode: %s' % (command, code))
            error.code = code
            raise error
    return kwargs.get('output')
+
+
def random_string(length=16):
    """
    Build a random string made of `length` ASCII letters (upper and
    lowercase); `length` defaults to 16.
    """
    alphabet = string.ascii_letters
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
+
+
def which(program):
    """
    Search for an executable named `program`.

    If `program` contains a directory component, only that exact path is
    checked. Otherwise every directory listed in the PATH environment
    variable is tried in order; if PATH is not set, `os.defpath` is used.

    Returns the path of the first executable file found, or None.
    """
    def is_exe(fpath):
        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)

    directory, _ = os.path.split(program)
    if directory:
        return program if is_exe(program) else None

    # PATH can be missing in stripped-down environments; do not fail with a
    # KeyError there, search the platform default path instead.
    search_path = os.environ.get("PATH", os.defpath)
    for entry in search_path.split(os.pathsep):
        candidate = os.path.join(entry.strip('"'), program)
        if is_exe(candidate):
            return candidate

    return None
+
+
def write_tsv(output_stream, *tup, **kwargs):
    """
    Write the values in `tup` as one tab-separated row to the stream.

    The row is terminated by a newline and encoded to bytes before it is
    written, so `output_stream.write` must accept bytes. Values that are
    not strings are converted with `str()`.

    The only recognized keyword argument is `encoding`; it defaults to
    'utf-8' (also used when `encoding` is None or empty).
    """
    encoding = kwargs.get('encoding') or 'utf-8'
    row = '\t'.join(str(value) for value in tup) + '\n'
    output_stream.write(row.encode(encoding))
+
+
def iter_tsv(input_stream, cols=None, encoding='utf-8'):
    """
    Iterate over the rows of a tab-separated stream.

    Each line of `input_stream` is decoded with `encoding` if it is a byte
    string, stripped of trailing newlines and split on tabs. Without `cols`
    a plain tuple is yielded per line.

    If a tuple is given in cols, use the elements as names to construct
    a namedtuple. The number of names must match the number of fields in
    every line.
    Columns can be marked as ignored by using ``X`` or ``0`` as column name
    (``x`` and ``None`` work as well); they get a placeholder field name of
    the form ``ignored_<position>``.
    Example (ignore the first four columns of a five column TSV):
    ::
        def run(self):
            with self.input().open() as handle:
                for row in handle.iter_tsv(cols=('X', 'X', 'X', 'X', 'iln')):
                    print(row.iln)
    """
    def split_fields(line):
        if isinstance(line, bytes):
            line = line.decode(encoding)
        return line.rstrip('\n').split('\t')

    if not cols:
        for line in input_stream:
            yield tuple(split_fields(line))
        return

    # Deterministic placeholder names for ignored columns: random names could
    # collide with each other or happen to be a keyword, which namedtuple
    # rejects.
    taken = {c for c in cols if isinstance(c, str)}
    names = []
    for index, col in enumerate(cols):
        if col in ('x', 'X', 0, None):
            col = 'ignored_%d' % index
            while col in taken:
                col += '_'
            taken.add(col)
        names.append(col)

    Record = collections.namedtuple('Record', names)
    for line in input_stream:
        yield Record._make(split_fields(line))
+
+
class TSVFormat(luigi.format.Format):
    """
    A basic CSV/TSV format for luigi targets.

    Pipes opened through this format get a helper bound to them: `iter_tsv`
    on readers and `write_tsv` on writers. HDFS is not supported.
    Discussion: https://groups.google.com/forum/#!topic/luigi-user/F813st16xqw
    """
    def hdfs_reader(self, input_pipe):
        raise NotImplementedError()

    def hdfs_writer(self, output_pipe):
        raise NotImplementedError()

    def pipe_reader(self, input_pipe):
        # Bind the pipe as first argument, so callers write pipe.iter_tsv(...).
        bound_reader = functools.partial(iter_tsv, input_pipe)
        input_pipe.iter_tsv = bound_reader
        return input_pipe

    def pipe_writer(self, output_pipe):
        # Bind the pipe as first argument, so callers write pipe.write_tsv(...).
        bound_writer = functools.partial(write_tsv, output_pipe)
        output_pipe.write_tsv = bound_writer
        return output_pipe
+
+
class GzipFormat(luigi.format.Format):
    """
    A gzip format, that upgrades itself to pigz, if it's installed.

    Reading and writing is delegated to external commands wrapped around
    the pipe. `compression_level`, if given, is passed to the compressor as
    ``-<level>``.
    """
    input = 'bytes'
    output = 'bytes'

    def __init__(self, compression_level=None):
        self.compression_level = compression_level
        self.gzip = ["gzip"]
        self.gunzip = ["gunzip"]

        if which('pigz'):
            self.gzip = ["pigz"]
            self.gunzip = ["unpigz"]

    def pipe_reader(self, input_pipe):
        return luigi.format.InputPipeProcessWrapper(self.gunzip, input_pipe)

    def pipe_writer(self, output_pipe):
        # Copy the base command: appending to self.gzip directly would add
        # another level flag on every call.
        args = list(self.gzip)
        if self.compression_level is not None:
            args.append('-' + str(int(self.compression_level)))
        return luigi.format.OutputPipeProcessWrapper(args, output_pipe)
+
+
class ZstdFormat(luigi.format.Format):
    """
    The zstandard format.

    Reading and writing is delegated to the external `unzstd` and `zstd`
    commands wrapped around the pipe. `compression_level`, if given, is
    passed to `zstd` as ``-<level>``.
    """
    input = 'bytes'
    output = 'bytes'

    def __init__(self, compression_level=None):
        self.compression_level = compression_level
        self.zstd = ["zstd"]
        self.unzstd = ["unzstd"]

    def pipe_reader(self, input_pipe):
        return luigi.format.InputPipeProcessWrapper(self.unzstd, input_pipe)

    def pipe_writer(self, output_pipe):
        # Copy the base command: appending to self.zstd directly would add
        # another level flag on every call.
        args = list(self.zstd)
        if self.compression_level is not None:
            args.append('-' + str(int(self.compression_level)))
        return luigi.format.OutputPipeProcessWrapper(args, output_pipe)
+
+
# Ready-to-use format instances, created once at import time.
TSV, Gzip, Zstd = TSVFormat(), GzipFormat(), ZstdFormat()