Source code for gokart.task

import hashlib
import inspect
import os
import types
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Set, Union

import luigi
import pandas as pd

import gokart
from gokart.file_processor import FileProcessor
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.redis_lock import make_redis_params
from gokart.target import TargetOnKart

logger = getLogger(__name__)


class TaskOnKart(luigi.Task):
    """
    This is a wrapper class of luigi.Task.

    The key methods of a TaskOnKart are:

    * :py:meth:`make_target` - this makes output target with a relative file path.
    * :py:meth:`make_model_target` - this makes output target for models which generate multiple files to save.
    * :py:meth:`load` - this loads input files of this task.
    * :py:meth:`dump` - this save a object as output of this task.
    """

    workspace_directory = luigi.Parameter(default='./resources/',
                                          description='A directory to set outputs on. Please use a path starts with s3:// when you use s3.',
                                          significant=False)  # type: str
    local_temporary_directory = luigi.Parameter(default='./resources/tmp/', description='A directory to save temporary files.', significant=False)  # type: str
    rerun = luigi.BoolParameter(default=False, description='If this is true, this task will run even if all output files exist.', significant=False)
    strict_check = luigi.BoolParameter(default=False, description='If this is true, this task will not run only if all input and output files exist.', significant=False)
    modification_time_check = luigi.BoolParameter(default=False,
                                                  description='If this is true, this task will not run only if all input and output files exist,'
                                                  ' and all input files are modified before output file are modified.',
                                                  significant=False)
    serialized_task_definition_check = luigi.BoolParameter(default=False,
                                                           description='If this is true, even if all outputs are present,'
                                                           'this task will be executed if any changes have been made to the code.',
                                                           significant=False)
    delete_unnecessary_output_files = luigi.BoolParameter(default=False, description='If this is true, delete unnecessary output files.', significant=False)
    significant = luigi.BoolParameter(default=True,
                                      description='If this is false, this task is not treated as a part of dependent tasks for the unique id.',
                                      significant=False)
    fix_random_seed_methods = luigi.ListParameter(default=['random.seed', 'numpy.random.seed'], description='Fix random seed method list.', significant=False)
    # Sentinel meaning "no user-fixed seed"; see _get_random_seed.
    FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER = -42497368
    fix_random_seed_value = luigi.IntParameter(default=FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER,
                                               description='Fix random seed method value.',
                                               significant=False)
    # FIXME: should fix with OptionalIntParameter after newer luigi (https://github.com/spotify/luigi/pull/3079) will be released
    redis_host = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False)
    redis_port = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False)
    redis_timeout = luigi.IntParameter(default=180, description='Redis lock will be released after `redis_timeout` seconds', significant=False)
    redis_fail_on_collision: bool = luigi.BoolParameter(
        default=False,
        description='True for failing the task immediately when the cache is locked, instead of waiting for the lock to be released',
        significant=False)
    fail_on_empty_dump: bool = ExplicitBoolParameter(default=False, description='Fail when task dumps empty DF', significant=False)
    # NOTE: fixed typo in the help text below ('Wether' -> 'Whether').
    store_index_in_feather: bool = ExplicitBoolParameter(default=True,
                                                         description='Whether to store index when using feather as a output object.',
                                                         significant=False)
    cache_unique_id: bool = ExplicitBoolParameter(default=True, description='Cache unique id during runtime', significant=False)
    should_dump_supplementary_log_files: bool = ExplicitBoolParameter(
        default=True,
        description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \
Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
        significant=False)

    def __init__(self, *args, **kwargs):
        # Merge values from the [TaskOnKart] config section into kwargs before luigi sees them.
        self._add_configuration(kwargs, 'TaskOnKart')
        # 'This parameter is dumped into "workspace_directory/log/task_log/" when this task finishes with success.'
        self.task_log: Dict[str, Any] = dict()
        # Cached unique id; filled lazily by make_unique_id when cache_unique_id is True.
        self.task_unique_id: Optional[str] = None
        super(TaskOnKart, self).__init__(*args, **kwargs)
        self._rerun_state = self.rerun
        self._lock_at_dump = True
[docs] def output(self): return self.make_target()
[docs] def requires(self): tasks = self.make_task_instance_dictionary() return tasks or [] # when tasks is empty dict, then this returns empty list.
[docs] def make_task_instance_dictionary(self) -> Dict[str, 'TaskOnKart']: return {key: var for key, var in vars(self).items() if self.is_task_on_kart(var)}
[docs] @staticmethod def is_task_on_kart(value): return isinstance(value, TaskOnKart) or (isinstance(value, tuple) and bool(value) and all([isinstance(v, TaskOnKart) for v in value]))
@classmethod def _add_configuration(cls, kwargs, section): config = luigi.configuration.get_config() class_variables = dict(TaskOnKart.__dict__) class_variables.update(dict(cls.__dict__)) if section not in config: return for key, value in dict(config[section]).items(): if key not in kwargs and key in class_variables: kwargs[key] = class_variables[key].parse(value)
    def complete(self) -> bool:
        """Return True when this task's outputs are up to date.

        Extends luigi's completion check with gokart's ``rerun``,
        ``strict_check`` and ``modification_time_check`` switches.
        """
        if self._rerun_state:
            # rerun requested: remove existing outputs once and report
            # incomplete so the scheduler re-runs the task; reset the flag so
            # this only happens on the first complete() call.
            for target in luigi.task.flatten(self.output()):
                target.remove()
            self._rerun_state = False
            return False

        is_completed = all([t.exists() for t in luigi.task.flatten(self.output())])

        if self.strict_check or self.modification_time_check:
            # strict mode: additionally require every dependency to be complete
            # and every input file to exist.
            requirements = luigi.task.flatten(self.requires())
            inputs = luigi.task.flatten(self.input())
            is_completed = is_completed and all([task.complete() for task in requirements]) and all([i.exists() for i in inputs])

        # Only fall through to the timestamp comparison when the check is
        # enabled, outputs exist, and there is at least one input to compare.
        if not self.modification_time_check or not is_completed or not self.input():
            return is_completed

        return self._check_modification_time()
def _check_modification_time(self): common_path = set(t.path() for t in luigi.task.flatten(self.input())) & set(t.path() for t in luigi.task.flatten(self.output())) input_tasks = [t for t in luigi.task.flatten(self.input()) if t.path() not in common_path] output_tasks = [t for t in luigi.task.flatten(self.output()) if t.path() not in common_path] input_modification_time = max([target.last_modification_time() for target in input_tasks]) if input_tasks else None output_modification_time = min([target.last_modification_time() for target in output_tasks]) if output_tasks else None if input_modification_time is None or output_modification_time is None: return True # "=" must be required in the following statements, because some tasks use input targets as output targets. return input_modification_time <= output_modification_time
[docs] def clone(self, cls=None, **kwargs): _SPECIAL_PARAMS = {'rerun', 'strict_check', 'modification_time_check'} if cls is None: cls = self.__class__ new_k = {} for param_name, param_class in cls.get_params(): if param_name in kwargs: new_k[param_name] = kwargs[param_name] elif hasattr(self, param_name) and (param_name not in _SPECIAL_PARAMS): new_k[param_name] = getattr(self, param_name) return cls(**new_k)
[docs] def make_target(self, relative_file_path: str = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart: formatted_relative_file_path = relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace(".", "/"), f"{type(self).__name__}.pkl") file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None redis_params = make_redis_params(file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, redis_port=self.redis_port, redis_timeout=self.redis_timeout, redis_fail_on_collision=self.redis_fail_on_collision) return gokart.target.make_target(file_path=file_path, unique_id=unique_id, processor=processor, redis_params=redis_params, store_index_in_feather=self.store_index_in_feather)
[docs] def make_large_data_frame_target(self, relative_file_path: str = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace(".", "/"), f"{type(self).__name__}.zip") file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None redis_params = make_redis_params(file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, redis_port=self.redis_port, redis_timeout=self.redis_timeout, redis_fail_on_collision=self.redis_fail_on_collision) return gokart.target.make_model_target(file_path=file_path, temporary_directory=self.local_temporary_directory, unique_id=unique_id, save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save, load_function=gokart.target.LargeDataFrameProcessor.load, redis_params=redis_params)
[docs] def make_model_target(self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True): """ Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on. :param relative_file_path: A file path to save. :param save_function: A function to save a model. This takes a model object and a file path. :param load_function: A function to load a model. This takes a file path and returns a model object. :param use_unique_id: If this is true, add an unique id to a file base name. """ file_path = os.path.join(self.workspace_directory, relative_file_path) assert relative_file_path[-3:] == 'zip', f'extension must be zip, but {relative_file_path} is passed.' unique_id = self.make_unique_id() if use_unique_id else None redis_params = make_redis_params(file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, redis_port=self.redis_port, redis_timeout=self.redis_timeout, redis_fail_on_collision=self.redis_fail_on_collision) return gokart.target.make_model_target(file_path=file_path, temporary_directory=self.local_temporary_directory, unique_id=unique_id, save_function=save_function, load_function=load_function, redis_params=redis_params)
[docs] def load(self, target: Union[None, str, TargetOnKart] = None) -> Any: def _load(targets): if isinstance(targets, list) or isinstance(targets, tuple): return [_load(t) for t in targets] if isinstance(targets, dict): return {k: _load(t) for k, t in targets.items()} return targets.load() data = _load(self._get_input_targets(target)) if target is None and isinstance(data, dict) and len(data) == 1: return list(data.values())[0] return data
[docs] def load_generator(self, target: Union[None, str, TargetOnKart] = None) -> Any: def _load(targets): if isinstance(targets, list) or isinstance(targets, tuple): for t in targets: yield from _load(t) elif isinstance(targets, dict): for k, t in targets.items(): yield from {k: _load(t)} else: yield targets.load() return _load(self._get_input_targets(target))
[docs] def load_data_frame(self, target: Union[None, str, TargetOnKart] = None, required_columns: Optional[Set[str]] = None, drop_columns: bool = False) -> pd.DataFrame: def _flatten_recursively(dfs): if isinstance(dfs, list): return pd.concat([_flatten_recursively(df) for df in dfs]) else: return dfs dfs = self.load(target=target) if isinstance(dfs, dict) and len(dfs) == 1: dfs = list(dfs.values())[0] data = _flatten_recursively(dfs) required_columns = required_columns or set() if data.empty and len(data.index) == 0 and len(required_columns - set(data.columns)) > 0: return pd.DataFrame(columns=required_columns) assert required_columns.issubset(set(data.columns)), f'data must have columns {required_columns}, but actually have only {data.columns}.' if drop_columns: data = data[list(required_columns)] return data
[docs] def dump(self, obj, target: Union[None, str, TargetOnKart] = None) -> None: PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace) if self.fail_on_empty_dump and isinstance(obj, pd.DataFrame): assert not obj.empty self._get_output_target(target).dump(obj, lock_at_dump=self._lock_at_dump)
[docs] @staticmethod def get_code(target_class) -> Set[str]: def has_sourcecode(obj): return inspect.ismethod(obj) or inspect.isfunction(obj) or inspect.isframe(obj) or inspect.iscode(obj) return {inspect.getsource(t) for _, t in inspect.getmembers(target_class, has_sourcecode)}
[docs] def get_own_code(self): gokart_codes = self.get_code(TaskOnKart) own_codes = self.get_code(self) return ''.join(sorted(list(own_codes - gokart_codes)))
[docs] def make_unique_id(self): unique_id = self.task_unique_id or self._make_hash_id() if self.cache_unique_id: self.task_unique_id = unique_id return unique_id
def _make_hash_id(self): def _to_str_params(task): if isinstance(task, TaskOnKart): return str(task.make_unique_id()) if task.significant else None if not isinstance(task, luigi.Task): raise ValueError(f"Task.requires method returns {type(task)}. You should return luigi.Task.") return task.to_str_params(only_significant=True) dependencies = [_to_str_params(task) for task in luigi.task.flatten(self.requires())] dependencies = [d for d in dependencies if d is not None] dependencies.append(self.to_str_params(only_significant=True)) dependencies.append(self.__class__.__name__) if self.serialized_task_definition_check: dependencies.append(self.get_own_code()) return hashlib.md5(str(dependencies).encode()).hexdigest() def _get_input_targets(self, target: Union[None, str, TargetOnKart]) -> Union[TargetOnKart, List[TargetOnKart]]: if target is None: return self.input() if isinstance(target, str): return self.input()[target] return target def _get_output_target(self, target: Union[None, str, TargetOnKart]) -> TargetOnKart: if target is None: return self.output() if isinstance(target, str): return self.output()[target] return target
[docs] def get_info(self, only_significant=False): params_str = {} params = dict(self.get_params()) for param_name, param_value in self.param_kwargs.items(): if (not only_significant) or params[param_name].significant: if type(params[param_name]) == gokart.TaskInstanceParameter: params_str[param_name] = type(param_value).__name__ + '-' + param_value.make_unique_id() else: params_str[param_name] = params[param_name].serialize(param_value) return params_str
def _get_task_log_target(self): return self.make_target(f'log/task_log/{type(self).__name__}.pkl')
[docs] def get_task_log(self) -> Dict: target = self._get_task_log_target() if self.task_log: return self.task_log if target.exists(): return self.load(target) return dict()
@luigi.Task.event_handler(luigi.Event.SUCCESS) def _dump_task_log(self): self.task_log['file_path'] = [target.path() for target in luigi.task.flatten(self.output())] if self.should_dump_supplementary_log_files: self.dump(self.task_log, self._get_task_log_target()) def _get_task_params_target(self): return self.make_target(f'log/task_params/{type(self).__name__}.pkl')
[docs] def get_task_params(self) -> Dict: target = self._get_task_log_target() if target.exists(): return self.load(target) return dict()
@luigi.Task.event_handler(luigi.Event.START) def _set_random_seed(self): if self.should_dump_supplementary_log_files: random_seed = self._get_random_seed() seed_methods = self.try_set_seed(self.fix_random_seed_methods, random_seed) self.dump({'seed': random_seed, 'seed_methods': seed_methods}, self._get_random_seeds_target()) def _get_random_seeds_target(self): return self.make_target(f'log/random_seed/{type(self).__name__}.pkl')
[docs] @staticmethod def try_set_seed(methods: List[str], random_seed: int) -> List[str]: success_methods = [] for method_name in methods: try: for i, x in enumerate(method_name.split('.')): if i == 0: m = import_module(x) else: m = getattr(m, x) m(random_seed) success_methods.append(method_name) except ModuleNotFoundError: pass except AttributeError: pass return success_methods
def _get_random_seed(self): if self.fix_random_seed_value and (not self.fix_random_seed_value == self.FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER): return self.fix_random_seed_value return int(self.make_unique_id(), 16) % (2**32 - 1) # maximum numpy.random.seed @luigi.Task.event_handler(luigi.Event.START) def _dump_task_params(self): if self.should_dump_supplementary_log_files: self.dump(self.to_str_params(only_significant=True), self._get_task_params_target()) def _get_processing_time_target(self): return self.make_target(f'log/processing_time/{type(self).__name__}.pkl')
[docs] def get_processing_time(self) -> str: target = self._get_processing_time_target() if target.exists(): return self.load(target) return 'unknown'
@luigi.Task.event_handler(luigi.Event.PROCESSING_TIME) def _dump_processing_time(self, processing_time): if self.should_dump_supplementary_log_files: self.dump(processing_time, self._get_processing_time_target())
[docs] @classmethod def restore(cls, unique_id): params = TaskOnKart().make_target(f'log/task_params/{cls.__name__}_{unique_id}.pkl', use_unique_id=False).load() return cls.from_str_params(params)
@luigi.Task.event_handler(luigi.Event.FAILURE) def _log_unique_id(self, exception): logger.info(f'FAILURE:\n task name={type(self).__name__}\n unique id={self.make_unique_id()}') @luigi.Task.event_handler(luigi.Event.START) def _dump_module_versions(self): if self.should_dump_supplementary_log_files: self.dump(self._get_module_versions(), self._get_module_versions_target()) def _get_module_versions_target(self): return self.make_target(f'log/module_versions/{type(self).__name__}.txt') def _get_module_versions(self) -> str: module_versions = [] for x in set([x.split('.')[0] for x in globals().keys() if isinstance(x, types.ModuleType) and '_' not in x]): module = import_module(x) if '__version__' in dir(module): if type(module.__version__) == str: version = module.__version__.split(" ")[0] else: version = '.'.join([str(v) for v in module.__version__]) module_versions.append(f'{x}=={version}') return '\n'.join(module_versions) def __repr__(self): """ Build a task representation like `MyTask(param1=1.5, param2='5', data_task=DataTask(id=35tyi))` """ params = self.get_params() param_values = self.get_param_values(params, [], self.param_kwargs) # Build up task id repr_parts = [] param_objs = dict(params) for param_name, param_value in param_values: param_obj = param_objs[param_name] if param_obj.significant: repr_parts.append(f'{param_name}={self._make_representation(param_obj, param_value)}') task_str = f'{self.get_task_family()}({", ".join(repr_parts)})' return task_str def _make_representation(self, param_obj: luigi.Parameter, param_value): if isinstance(param_obj, TaskInstanceParameter): return f'{param_value.get_task_family()}({param_value.make_unique_id()})' if isinstance(param_obj, ListTaskInstanceParameter): return f"[{', '.join(f'{v.get_task_family()}({v.make_unique_id()})' for v in param_value)}]" return param_obj.serialize(param_value)