Source code for gokart.task

from __future__ import annotations

import functools
import hashlib
import inspect
import os
import random
import types
from collections.abc import Callable, Generator, Iterable
from importlib import import_module
from logging import getLogger
from typing import Any, Generic, TypeVar, cast, overload

import luigi
import pandas as pd
from luigi.parameter import ParameterVisibility

import gokart
import gokart.target
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper
from gokart.utils import FlattenableItems, flatten, get_dataframe_type_from_task, map_flattenable_items

logger = getLogger(__name__)


T = TypeVar('T')
K = TypeVar('K')


# NOTE: inherited from AssertionError for backward compatibility (Formerly, Gokart raises that exception when a task dumps an empty DataFrame).
[docs] class EmptyDumpError(AssertionError): """Raised when the task attempts to dump an empty DataFrame even though it is prohibited (``fail_on_empty_dump`` is set to True)"""
[docs] class TaskOnKart(luigi.Task, Generic[T]): """ This is a wrapper class of luigi.Task. The key methods of a TaskOnKart are: * :py:meth:`make_target` - this makes output target with a relative file path. * :py:meth:`make_model_target` - this makes output target for models which generate multiple files to save. * :py:meth:`load` - this loads input files of this task. * :py:meth:`dump` - this save a object as output of this task. """ workspace_directory: luigi.Parameter[str] = luigi.Parameter( default='./resources/', description='A directory to set outputs on. Please use a path starts with s3:// when you use s3.', significant=False ) local_temporary_directory: luigi.Parameter[str] = luigi.Parameter( default='./resources/tmp/', description='A directory to save temporary files.', significant=False ) rerun: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, this task will run even if all output files exist.', significant=False ) strict_check: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, this task will not run only if all input and output files exist.', significant=False ) modification_time_check: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, this task will not run only if all input and output files exist,' ' and all input files are modified before output file are modified.', significant=False, ) serialized_task_definition_check: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, even if all outputs are present,this task will be executed if any changes have been made to the code.', significant=False, ) delete_unnecessary_output_files: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, delete unnecessary output files.', significant=False ) significant: luigi.BoolParameter = luigi.BoolParameter( default=True, description='If this is false, this task is not treated as a part of dependent tasks for the unique id.', significant=False ) fix_random_seed_methods: luigi.Parameter[tuple[str, ...]] = luigi.ListParameter( default=('random.seed', 'numpy.random.seed'), description='Fix random seed method list.', significant=False ) FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER = -42497368 fix_random_seed_value: luigi.Parameter[int] = luigi.IntParameter( default=FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER, description='Fix random seed method value.', significant=False ) # FIXME: should fix with OptionalIntParameter after newer luigi (https://github.com/spotify/luigi/pull/3079) will be released redis_host: luigi.Parameter[str | None] = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False) redis_port: luigi.OptionalIntParameter = luigi.OptionalIntParameter( default=None, description='Task lock check is deactivated, when None.', significant=False, ) redis_timeout: luigi.IntParameter = luigi.IntParameter( default=180, description='Redis lock will be released after `redis_timeout` seconds', significant=False ) fail_on_empty_dump: luigi.Parameter[bool] = ExplicitBoolParameter(default=False, description='Fail when task dumps empty DF', significant=False) store_index_in_feather: luigi.Parameter[bool] = ExplicitBoolParameter( default=True, description='Wether to store index when using feather as a output object.', significant=False ) cache_unique_id: luigi.Parameter[bool] = ExplicitBoolParameter(default=True, description='Cache unique id during runtime', significant=False) should_dump_supplementary_log_files: luigi.Parameter[bool] = ExplicitBoolParameter( default=True, description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \ Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.', significant=False, ) complete_check_at_run: luigi.Parameter[bool] = ExplicitBoolParameter( default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False ) should_lock_run: luigi.Parameter[bool] = ExplicitBoolParameter( default=False, significant=False, description='Whether to use redis lock or not at task run.' ) @property def priority(self): return random.Random().random() # seed is fixed, so we need to use random.Random().random() instead f random.random() def __init__(self, *args, **kwargs): self._add_configuration(kwargs, 'TaskOnKart') # 'This parameter is dumped into "workspace_directory/log/task_log/" when this task finishes with success.' self.task_log = dict() self.task_unique_id = None super().__init__(*args, **kwargs) self._rerun_state = self.rerun self._lock_at_dump = True # Cache to_str_params to avoid slow task creation in a deep task tree. # For example, gokart.build(RecursiveTask(dep=RecursiveTask(dep=RecursiveTask(dep=HelloWorldTask())))) results in O(n^2) calls to to_str_params. # However, @lru_cache cannot be used as a decorator because luigi.Task employs metaclass tricks. self.to_str_params = functools.lru_cache(maxsize=None)(self.to_str_params) # type: ignore[method-assign] if self.complete_check_at_run: self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete) # type: ignore if self.should_lock_run: self._lock_at_dump = False assert self.redis_host is not None, 'redis_host must be set when should_lock_run is True.' assert self.redis_port is not None, 'redis_port must be set when should_lock_run is True.' task_lock_params = make_task_lock_params_for_run(task_self=self) self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore
[docs] def input(self) -> FlattenableItems[TargetOnKart]: return cast(FlattenableItems[TargetOnKart], super().input())
[docs] def output(self) -> FlattenableItems[TargetOnKart]: return self.make_target()
[docs] def requires(self) -> FlattenableItems[TaskOnKart[Any]]: tasks = self.make_task_instance_dictionary() if tasks: return cast(FlattenableItems[TaskOnKart[Any]], tasks) return [] # when tasks is empty dict, then this returns empty list.
[docs] def make_task_instance_dictionary(self) -> dict[str, TaskOnKart[Any]]: return {key: var for key, var in vars(self).items() if self.is_task_on_kart(var)}
[docs] @staticmethod def is_task_on_kart(value): return isinstance(value, TaskOnKart) or (isinstance(value, tuple) and bool(value) and all([isinstance(v, TaskOnKart) for v in value]))
@classmethod def _add_configuration(cls, kwargs, section): config = luigi.configuration.get_config() class_variables = dict(TaskOnKart.__dict__) class_variables.update(dict(cls.__dict__)) if section not in config: return for key, value in dict(config[section]).items(): if key not in kwargs and key in class_variables: kwargs[key] = class_variables[key].parse(value)
[docs] def complete(self) -> bool: if self._rerun_state: for target in flatten(self.output()): target.remove() self._rerun_state = False return False is_completed = all([t.exists() for t in flatten(self.output())]) if self.strict_check or self.modification_time_check: requirements = flatten(self.requires()) inputs = flatten(self.input()) is_completed = is_completed and all([task.complete() for task in requirements]) and all([i.exists() for i in inputs]) if not self.modification_time_check or not is_completed or not self.input(): return is_completed return self._check_modification_time()
def _check_modification_time(self) -> bool: common_path = set(t.path() for t in flatten(self.input())) & set(t.path() for t in flatten(self.output())) input_tasks = [t for t in flatten(self.input()) if t.path() not in common_path] output_tasks = [t for t in flatten(self.output()) if t.path() not in common_path] input_modification_time = max([target.last_modification_time() for target in input_tasks]) if input_tasks else None output_modification_time = min([target.last_modification_time() for target in output_tasks]) if output_tasks else None if input_modification_time is None or output_modification_time is None: return True # "=" must be required in the following statements, because some tasks use input targets as output targets. return input_modification_time <= output_modification_time
[docs] def clone(self, cls=None, **kwargs): _SPECIAL_PARAMS = {'rerun', 'strict_check', 'modification_time_check'} if cls is None: cls = self.__class__ new_k = {} for param_name, _ in cls.get_params(): if param_name in kwargs: new_k[param_name] = kwargs[param_name] elif hasattr(self, param_name) and (param_name not in _SPECIAL_PARAMS): new_k[param_name] = getattr(self, param_name) return cls(**new_k)
[docs] def make_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, processor: FileProcessor | None = None) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None # Auto-select processor based on type parameter if not provided if processor is None and relative_file_path is not None: processor = self._create_processor_for_dataframe_type(file_path) task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, redis_port=self.redis_port, redis_timeout=self.redis_timeout, raise_task_lock_exception_on_collision=False, ) return gokart.target.make_target( file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather )
def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor: df_type = get_dataframe_type_from_task(self) return make_file_processor(file_path, dataframe_type=df_type, store_index_in_feather=self.store_index_in_feather)
[docs] def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte: int = int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip') ) file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, redis_port=self.redis_port, redis_timeout=self.redis_timeout, raise_task_lock_exception_on_collision=False, ) return gokart.target.make_model_target( file_path=file_path, temporary_directory=self.local_temporary_directory, unique_id=unique_id, save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save, load_function=gokart.target.LargeDataFrameProcessor.load, task_lock_params=task_lock_params, )
[docs] def make_model_target( self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True ) -> TargetOnKart: """ Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on. :param relative_file_path: A file path to save. :param save_function: A function to save a model. This takes a model object and a file path. :param load_function: A function to load a model. This takes a file path and returns a model object. :param use_unique_id: If this is true, add an unique id to a file base name. """ file_path = os.path.join(self.workspace_directory, relative_file_path) assert relative_file_path[-3:] == 'zip', f'extension must be zip, but {relative_file_path} is passed.' unique_id = self.make_unique_id() if use_unique_id else None task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, redis_host=self.redis_host, redis_port=self.redis_port, redis_timeout=self.redis_timeout, raise_task_lock_exception_on_collision=False, ) return gokart.target.make_model_target( file_path=file_path, temporary_directory=self.local_temporary_directory, unique_id=unique_id, save_function=save_function, load_function=load_function, task_lock_params=task_lock_params, )
@overload def load(self, target: None | str | TargetOnKart = None) -> Any: ... @overload def load(self, target: TaskOnKart[K]) -> K: ... @overload def load(self, target: list[TaskOnKart[K]]) -> list[K]: ...
[docs] def load(self, target: None | str | TargetOnKart | TaskOnKart[K] | list[TaskOnKart[K]] = None) -> Any: def _load(targets): if isinstance(targets, list) or isinstance(targets, tuple): return [_load(t) for t in targets] if isinstance(targets, dict): return {k: _load(t) for k, t in targets.items()} return targets.load() return _load(self._get_input_targets(target))
@overload def load_generator(self, target: None | str | TargetOnKart = None) -> Generator[Any, None, None]: ... @overload def load_generator(self, target: list[TaskOnKart[K]]) -> Generator[K, None, None]: ...
[docs] def load_generator(self, target: None | str | TargetOnKart | list[TaskOnKart[K]] = None) -> Generator[Any, None, None]: def _load(targets): if isinstance(targets, list) or isinstance(targets, tuple): for t in targets: yield from _load(t) elif isinstance(targets, dict): for k, t in targets.items(): yield from {k: _load(t)} else: yield targets.load() return cast(Generator[Any, None, None], _load(self._get_input_targets(target)))
@overload def dump(self, obj: T, target: None = None, custom_labels: dict[Any, Any] | None = None) -> None: ... @overload def dump(self, obj: Any, target: str | TargetOnKart, custom_labels: dict[Any, Any] | None = None) -> None: ...
[docs] def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels: dict[str, Any] | None = None) -> None: PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace) if self.fail_on_empty_dump: if isinstance(obj, pd.DataFrame) and obj.empty: raise EmptyDumpError() required_task_outputs = cast( FlattenableItems[RequiredTaskOutput], map_flattenable_items( lambda task: map_flattenable_items( lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output() ), self.requires(), ), ) self._get_output_target(target).dump( obj, lock_at_dump=self._lock_at_dump, task_params=super().to_str_params(only_significant=True, only_public=True), custom_labels=custom_labels, required_task_outputs=required_task_outputs, )
[docs] @staticmethod def get_code(target_class: Any) -> set[str]: def has_sourcecode(obj): return inspect.ismethod(obj) or inspect.isfunction(obj) or inspect.isframe(obj) or inspect.iscode(obj) return {inspect.getsource(t) for _, t in inspect.getmembers(target_class, has_sourcecode)}
[docs] def get_own_code(self): gokart_codes = self.get_code(TaskOnKart) own_codes = self.get_code(self) return ''.join(sorted(list(own_codes - gokart_codes)))
[docs] def make_unique_id(self) -> str: unique_id = self.task_unique_id or self._make_hash_id() if self.cache_unique_id: self.task_unique_id = unique_id return unique_id
def _make_hash_id(self) -> str: def _to_str_params(task): if isinstance(task, TaskOnKart): return str(task.make_unique_id()) if task.significant else None if not isinstance(task, luigi.Task): raise ValueError(f'Task.requires method returns {type(task)}. You should return luigi.Task.') return task.to_str_params(only_significant=True) dependencies: list[Any] = [_to_str_params(task) for task in flatten(self.requires())] dependencies = [d for d in dependencies if d is not None] dependencies.append(self.to_str_params(only_significant=True)) dependencies.append(self.__class__.__name__) if self.serialized_task_definition_check: dependencies.append(self.get_own_code()) return hashlib.md5(str(dependencies).encode()).hexdigest() def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart[Any] | list[TaskOnKart[Any]]) -> FlattenableItems[TargetOnKart]: if target is None: return self.input() if isinstance(target, str): input = self.input() assert isinstance(input, dict), f'input must be dict[str, TargetOnKart], but {type(input)} is passed.' result: FlattenableItems[TargetOnKart] = input[target] return result if isinstance(target, Iterable): return [self._get_input_targets(t) for t in target] if isinstance(target, TaskOnKart): requires_unique_ids = [task.make_unique_id() for task in flatten(self.requires())] assert target.make_unique_id() in requires_unique_ids, f'{target} should be in requires method' return target.output() return target def _get_output_target(self, target: None | str | TargetOnKart) -> TargetOnKart: if target is None: output = self.output() assert isinstance(output, TargetOnKart), f'output must be TargetOnKart, but {type(output)} is passed.' return output if isinstance(target, str): output = self.output() assert isinstance(output, dict), f'output must be dict[str, TargetOnKart], but {type(output)} is passed.' result = output[target] assert isinstance(result, TargetOnKart), f'output must be dict[str, TargetOnKart], but {type(output)} is passed.' return result return target
[docs] def get_info(self, only_significant=False): params_str = {} params = dict(self.get_params()) for param_name, param_value in self.param_kwargs.items(): if (not only_significant) or params[param_name].significant: if isinstance(params[param_name], gokart.TaskInstanceParameter): params_str[param_name] = type(param_value).__name__ + '-' + cast(TaskOnKart[Any], param_value).make_unique_id() else: params_str[param_name] = params[param_name].serialize(param_value) return params_str
def _get_task_log_target(self): return self.make_target(f'log/task_log/{type(self).__name__}.pkl')
[docs] def get_task_log(self) -> dict[str, Any]: target = self._get_task_log_target() if self.task_log: return self.task_log if target.exists(): return cast(dict[Any, Any], self.load(target)) return dict()
@luigi.Task.event_handler(luigi.Event.SUCCESS) def _dump_task_log(self): self.task_log['file_path'] = [target.path() for target in flatten(self.output())] if self.should_dump_supplementary_log_files: self.dump(self.task_log, self._get_task_log_target()) def _get_task_params_target(self): return self.make_target(f'log/task_params/{type(self).__name__}.pkl')
[docs] def get_task_params(self) -> dict[str, Any]: target = self._get_task_log_target() if target.exists(): return cast(dict[Any, Any], self.load(target)) return dict()
@luigi.Task.event_handler(luigi.Event.START) def _set_random_seed(self): if self.should_dump_supplementary_log_files: random_seed = self._get_random_seed() seed_methods = self.try_set_seed(list(self.fix_random_seed_methods), random_seed) self.dump({'seed': random_seed, 'seed_methods': seed_methods}, self._get_random_seeds_target()) def _get_random_seeds_target(self): return self.make_target(f'log/random_seed/{type(self).__name__}.pkl')
[docs] @staticmethod def try_set_seed(methods: list[str], random_seed: int) -> list[str]: success_methods: list[str] = [] for method_name in methods: try: parts = method_name.split('.') m: Any = import_module(parts[0]) for x in parts[1:]: m = getattr(m, x) m(random_seed) success_methods.append(method_name) except ModuleNotFoundError: pass except AttributeError: pass return success_methods
def _get_random_seed(self): if self.fix_random_seed_value and (not self.fix_random_seed_value == self.FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER): return self.fix_random_seed_value return int(self.make_unique_id(), 16) % (2**32 - 1) # maximum numpy.random.seed @luigi.Task.event_handler(luigi.Event.START) def _dump_task_params(self): if self.should_dump_supplementary_log_files: self.dump(self.to_str_params(only_significant=True), self._get_task_params_target()) def _get_processing_time_target(self): return self.make_target(f'log/processing_time/{type(self).__name__}.pkl')
[docs] def get_processing_time(self) -> str: target = self._get_processing_time_target() if target.exists(): return cast(str, self.load(target)) return 'unknown'
@luigi.Task.event_handler(luigi.Event.PROCESSING_TIME) def _dump_processing_time(self, processing_time): if self.should_dump_supplementary_log_files: self.dump(processing_time, self._get_processing_time_target())
[docs] @classmethod def restore(cls, unique_id): params = TaskOnKart().make_target(f'log/task_params/{cls.__name__}_{unique_id}.pkl', use_unique_id=False).load() return cls.from_str_params(params)
@luigi.Task.event_handler(luigi.Event.FAILURE) def _log_unique_id(self, exception): logger.info(f'FAILURE:\n task name={type(self).__name__}\n unique id={self.make_unique_id()}') @luigi.Task.event_handler(luigi.Event.START) def _dump_module_versions(self): if self.should_dump_supplementary_log_files: self.dump(self._get_module_versions(), self._get_module_versions_target()) def _get_module_versions_target(self): return self.make_target(f'log/module_versions/{type(self).__name__}.txt') def _get_module_versions(self) -> str: module_versions = [] for x in set([x.split('.')[0] for x in globals().keys() if isinstance(x, types.ModuleType) and '_' not in x]): module = import_module(x) if '__version__' in dir(module): if isinstance(module.__version__, str): version = module.__version__.split(' ')[0] else: version = '.'.join([str(v) for v in module.__version__]) module_versions.append(f'{x}=={version}') return '\n'.join(module_versions) def __repr__(self): """ Build a task representation like `MyTask[aca2f28555dadd0f1e3dee3d4b973651](param1=1.5, param2='5', data_task=DataTask(c1f5d06aa580c5761c55bd83b18b0b4e))` """ return self._get_task_string() def __str__(self): """ Build a human-readable task representation like `MyTask[aca2f28555dadd0f1e3dee3d4b973651](param1=1.5, param2='5', data_task=DataTask(c1f5d06aa580c5761c55bd83b18b0b4e))` This includes only public parameters """ return self._get_task_string(only_public=True) def _get_task_string(self, only_public=False): """ Convert a task representation like `MyTask(param1=1.5, param2='5', data_task=DataTask(id=35tyi))` """ params = self.get_params() param_values = self.get_param_values(params, [], self.param_kwargs) # Build up task id repr_parts = [] param_objs = dict(params) for param_name, param_value in param_values: param_obj = param_objs[param_name] if param_obj.significant and ((not only_public) or param_obj.visibility == ParameterVisibility.PUBLIC): repr_parts.append(f'{param_name}={self._make_representation(param_obj, param_value)}') task_str = f'{self.get_task_family()}[{self.make_unique_id()}]({", ".join(repr_parts)})' return task_str def _make_representation(self, param_obj: luigi.Parameter, param_value: Any) -> str: if isinstance(param_obj, TaskInstanceParameter): return f'{param_value.get_task_family()}({param_value.make_unique_id()})' if isinstance(param_obj, ListTaskInstanceParameter): return f'[{", ".join(f"{v.get_task_family()}({v.make_unique_id()})" for v in param_value)}]' return str(param_obj.serialize(param_value))