import bz2
import json
from logging import getLogger
import luigi
from luigi import task_register
# Module-level logger named after this module.
logger = getLogger(__name__)
class TaskInstanceParameter(luigi.Parameter):
    """Parameter whose value is a task instance.

    The serialized form is a ``luigi.DictParameter`` string holding two keys:
    ``type`` (the task family) and ``params`` (the task's significant string
    parameters, JSON-encoded, bz2-compressed and hex-encoded).
    """

    @staticmethod
    def _recursive(param_dict):
        """Build a task instance from a ``{'type': ..., 'params': ...}`` mapping.

        The task class is looked up in the luigi task register by
        ``param_dict['type']``. Every entry of ``param_dict['params']`` that
        matches one of the class's declared parameters is replaced, in place,
        by the result of that parameter's ``parse``.
        """
        raw_params = param_dict['params']
        task_cls = task_register.Register.get_task_cls(param_dict['type'])
        for name, param_obj in task_cls.get_params():
            if name not in raw_params:
                continue
            raw_params[name] = param_obj.parse(raw_params[name])
        return task_cls(**raw_params)

    @staticmethod
    def _recursive_decompress(s):
        """Parse ``s`` as a dict and expand its compressed ``params`` entry.

        When the parsed dict has a ``params`` key, its value is hex-decoded,
        bz2-decompressed, decoded to text and run through this function again.
        """
        decoded = dict(luigi.DictParameter().parse(s))
        if 'params' not in decoded:
            return decoded
        packed = bytes.fromhex(decoded['params'])
        inner_text = bz2.decompress(packed).decode()
        decoded['params'] = TaskInstanceParameter._recursive_decompress(inner_text)
        return decoded

    def parse(self, s):
        """Return the task instance described by ``s``.

        A ``str`` is first expanded with ``_recursive_decompress``; any other
        value is passed to ``_recursive`` unchanged.
        """
        task_spec = self._recursive_decompress(s) if isinstance(s, str) else s
        return self._recursive(task_spec)

    def serialize(self, x):
        """Serialize task ``x`` to a string via ``luigi.DictParameter``.

        Only the significant parameters (``to_str_params(only_significant=True)``)
        are stored.
        """
        significant = x.to_str_params(only_significant=True)
        packed_hex = bz2.compress(json.dumps(significant).encode()).hex()
        payload = {'type': x.get_task_family(), 'params': packed_hex}
        return luigi.DictParameter().serialize(payload)
class _TaskInstanceEncoder(json.JSONEncoder):
    """JSON encoder that serializes ``luigi.Task`` objects with ``TaskInstanceParameter``."""

    def default(self, obj):
        """Return the serialized form of a task; defer everything else to the base encoder."""
        if not isinstance(obj, luigi.Task):
            # The base implementation raises TypeError for unsupported objects.
            return super().default(obj)
        return TaskInstanceParameter().serialize(obj)
class ListTaskInstanceParameter(luigi.Parameter):
    """Parameter whose value is a list of task instances, stored as a JSON array."""

    def parse(self, s):
        """Decode the JSON text ``s`` and parse each element with ``TaskInstanceParameter``."""
        tasks = []
        for item in json.loads(s):
            tasks.append(TaskInstanceParameter().parse(item))
        return tasks

    def serialize(self, x):
        """Dump ``x`` to JSON, encoding task objects through ``_TaskInstanceEncoder``."""
        return json.dumps(x, cls=_TaskInstanceEncoder)
class ExplicitBoolParameter(luigi.BoolParameter):
    """Boolean parameter that uses ``luigi.Parameter``'s constructor and parser kwargs.

    Both overrides deliberately skip the ``luigi.BoolParameter`` implementations
    and call the ``luigi.Parameter`` ones directly.
    NOTE(review): presumably this makes the value something that must be passed
    explicitly on the command line rather than an implicit flag -- confirm
    against ``luigi.BoolParameter``.
    """

    def __init__(self, *args, **kwargs):
        """Initialize through ``luigi.Parameter.__init__``, bypassing ``BoolParameter.__init__``."""
        luigi.Parameter.__init__(self, *args, **kwargs)

    def _parser_kwargs(self, *args, **kwargs):
        """Return the argument-parser kwargs produced by ``luigi.Parameter._parser_kwargs``.

        ``self`` is intentionally not forwarded; the base implementation is
        called with the remaining arguments only.
        """
        # Keyword arguments must be forwarded with ``**``; a single ``*`` would
        # pass the keyword *names* as extra positional arguments.
        return luigi.Parameter._parser_kwargs(*args, **kwargs)