from __future__ import annotations
import enum
import io
import logging
from dataclasses import dataclass
from functools import partial
from logging import getLogger
from typing import Any, Literal, Protocol, TypeVar, cast, overload
import backoff
import luigi
from luigi import LuigiStatusCode, rpc, scheduler, task_register
from luigi.execution_summary import LuigiRunResult
import gokart
import gokart.tree.task_info
from gokart import worker
from gokart.conflict_prevention_lock.task_lock import TaskLockException
from gokart.target import TargetOnKart
from gokart.task import TaskOnKart
# Type variable for the value produced by a TaskOnKart; it links the task passed to
# ``build`` / ``_get_output`` with the type of the value they return.
T = TypeVar('T')
# Module-level logger; ``process_task_info`` writes the task tree / table to it.
logger: logging.Logger = logging.getLogger(__name__)
class LoggerConfig:
    """Context manager that temporarily raises the logging threshold.

    On entry, records below ``level`` are suppressed process-wide through
    ``logging.disable`` and this module's logger is set to ``level``.
    On exit, both the process-wide threshold and the module logger level are
    put back to the values they had before.

    Args:
        level: logging level (e.g. ``logging.ERROR``) that stays enabled while
            the context is active.
    """

    def __init__(self, level: int) -> None:
        self.logger = getLogger(__name__)
        self.default_level = self.logger.level
        self.level = level
        # Process-wide ``logging.disable`` threshold to restore on exit.
        # Refreshed in ``__enter__``; initialised here so ``__exit__`` is safe
        # even if it is called without a matching ``__enter__``.
        self._previous_disable: int = logging.root.manager.disable

    def __enter__(self) -> LoggerConfig:
        self._previous_disable = logging.root.manager.disable
        # ``logging.disable(n)`` suppresses level n and everything below it,
        # so subtract 10 to keep ``self.level`` itself enabled.
        logging.disable(self.level - 10)
        self.logger.setLevel(self.level)
        return self

    def __exit__(self, exception_type, exception_value, traceback) -> None:
        # Restore the threshold that was active before entering, instead of
        # deriving one from the logger level: that would overwrite whatever
        # the caller had configured with ``logging.disable``.
        logging.disable(self._previous_disable)
        self.logger.setLevel(self.default_level)
class GokartBuildError(Exception):
    """Error signalling that ``gokart.build`` did not finish successfully.

    Attributes:
        raised_exceptions: the exceptions collected during task execution,
            as a mapping from a string key to the list of exceptions recorded
            under that key.
    """

    def __init__(self, message: str, raised_exceptions: dict[str, list[Exception]]) -> None:
        # Keep the collected exceptions on the instance so callers can inspect them.
        self.raised_exceptions = raised_exceptions
        super().__init__(message)
class HasLockedTaskException(Exception):
    """Signals that a task could not acquire its lock while the build was executing tasks."""
class TaskLockExceptionRaisedFlag:
    """Small mutable holder for a single boolean.

    Lets a callback record, through the shared object, that a task lock
    exception was seen. ``flag`` starts out as ``False``.
    """

    def __init__(self) -> None:
        self.flag: bool = False
class WorkerProtocol(Protocol):
    """Structural type describing a worker.

    The shape of this protocol is dictated by luigi.worker.Worker.
    """

    def add(self, task: TaskOnKart[Any]) -> bool:
        ...

    def run(self) -> bool:
        ...

    def __enter__(self) -> WorkerProtocol:
        ...

    def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]:
        ...
class WorkerSchedulerFactory:
    """Factory for the scheduler and worker objects used by a build.

    ``build`` passes an instance of this class to ``luigi.build`` as
    ``worker_scheduler_factory``.
    """

    def create_local_scheduler(self) -> scheduler.Scheduler:
        """Return an in-process scheduler that prunes on ``get_work`` and records no task history."""
        return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False)

    def create_remote_scheduler(self, url: str) -> rpc.RemoteScheduler:
        """Return a client for the remote scheduler reachable at ``url``."""
        return rpc.RemoteScheduler(url)

    def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant: bool = False) -> WorkerProtocol:
        """Return a gokart worker bound to ``scheduler``.

        Args:
            scheduler: scheduler the worker talks to.
            worker_processes: forwarded to the worker as ``worker_processes``.
            assistant: forwarded to the worker as ``assistant``.
        """
        return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant)
def _get_output(task: TaskOnKart[T]) -> T:
    """Load and return the data behind ``task.output()``.

    A list or tuple output yields a list of loaded values, a dict output
    yields a dict of loaded values, and a single ``TargetOnKart`` yields its
    loaded value. Elements that are not ``TargetOnKart`` instances are
    skipped.

    Raises:
        ValueError: if the output is none of the supported container or target types.
    """
    target = task.output()
    # FIXME: currently, nested output is not supported
    if isinstance(target, (list, tuple)):
        loaded_items = []
        for element in target:
            if isinstance(element, TargetOnKart):
                loaded_items.append(element.load())
        return cast(T, loaded_items)
    if isinstance(target, dict):
        loaded_by_key = {}
        for key, element in target.items():
            if isinstance(element, TargetOnKart):
                loaded_by_key[key] = element.load()
        return cast(T, loaded_by_key)
    if isinstance(target, TargetOnKart):
        return cast(T, target.load())
    raise ValueError(f'output type is not supported: {type(target)}')
def _reset_register(keep: set[str] | frozenset[str] = frozenset({'gokart', 'luigi'})) -> None:
    """Reset ``task_register.Register._reg`` every time ``gokart.build`` is called to avoid TaskClassAmbigiousException.

    Args:
        keep: top-level package names whose registered classes stay in the
            register. The default is an immutable set so it cannot be
            altered between calls.
    """
    task_register.Register._reg = [
        registered
        for registered in task_register.Register._reg
        if (
            registered.__module__.split('.')[0] in keep  # keep luigi and gokart
            or issubclass(registered, gokart.PandasTypeConfig)  # PandasTypeConfig should be kept
        )
    ]
class TaskDumpMode(enum.Enum):
    """Representation of the task information handled by ``process_task_info``."""

    # Task information rendered as a tree.
    TREE = 'tree'
    # Task information rendered as a table.
    TABLE = 'table'
    # No task information is produced.
    NONE = 'none'
class TaskDumpOutputType(enum.Enum):
    """Destination of the task information handled by ``process_task_info``."""

    # Task information is written to the module logger.
    PRINT = 'print'
    # Task information is dumped to a target.
    DUMP = 'dump'
    # Task information is not emitted anywhere.
    NONE = 'none'
@dataclass
class TaskDumpConfig:
    """Pairs a task-information representation with the place it is sent to.

    Both fields default to ``NONE``, which makes ``process_task_info`` a no-op.
    """

    # Which representation of the task information to build.
    mode: TaskDumpMode = TaskDumpMode.NONE
    # Where that representation goes.
    output_type: TaskDumpOutputType = TaskDumpOutputType.NONE
def process_task_info(task: TaskOnKart[Any], task_dump_config: TaskDumpConfig = TaskDumpConfig()) -> None:
    """Emit the task information of ``task`` as requested by ``task_dump_config``.

    Supported combinations of ``(mode, output_type)``:

    * ``(NONE, NONE)``: do nothing.
    * ``(TREE, PRINT)``: log the task tree at INFO level.
    * ``(TABLE, PRINT)``: log the task table as tab-separated text at INFO level.
    * ``(TREE, DUMP)``: dump the tree to ``log/task_info/<TaskClass>.txt``.
    * ``(TABLE, DUMP)``: dump the table to ``log/task_info/<TaskClass>.pkl``.

    Raises:
        ValueError: for any other configuration.
    """
    if not isinstance(task_dump_config, TaskDumpConfig):
        raise ValueError(f'Unsupported TaskDumpConfig: {task_dump_config}')

    selection = (task_dump_config.mode, task_dump_config.output_type)

    if selection == (TaskDumpMode.NONE, TaskDumpOutputType.NONE):
        return

    if selection == (TaskDumpMode.TREE, TaskDumpOutputType.PRINT):
        logger.info(gokart.make_tree_info(task))
        return

    if selection == (TaskDumpMode.TABLE, TaskDumpOutputType.PRINT):
        task_table = gokart.tree.task_info.make_task_info_as_table(task)
        buffer = io.StringIO()
        task_table.to_csv(buffer, index=False, sep='\t')
        logger.info(buffer.getvalue())
        return

    if selection == (TaskDumpMode.TREE, TaskDumpOutputType.DUMP):
        tree_text = gokart.make_tree_info(task)
        tree_target = gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.txt')
        tree_target.dump(tree_text)
        return

    if selection == (TaskDumpMode.TABLE, TaskDumpOutputType.DUMP):
        task_table = gokart.tree.task_info.make_task_info_as_table(task)
        table_target = gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.pkl')
        table_target.dump(task_table)
        return

    raise ValueError(f'Unsupported TaskDumpConfig: {task_dump_config}')
# Typing-only overloads of ``build``: the declared return type depends on ``return_value``.
# ``task_dump_config`` is listed explicitly so type checkers validate it
# instead of letting it fall through ``**env_params: Any``.
@overload
def build(
    task: TaskOnKart[T],
    return_value: Literal[True] = True,
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    task_dump_config: TaskDumpConfig = ...,
    **env_params: Any,
) -> T: ...


@overload
def build(
    task: TaskOnKart[T],
    return_value: Literal[False],
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    task_dump_config: TaskDumpConfig = ...,
    **env_params: Any,
) -> None: ...
def build(
    task: TaskOnKart[T],
    return_value: bool = True,
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    task_dump_config: TaskDumpConfig = TaskDumpConfig(),
    **env_params: Any,
) -> T | None:
    """
    Run gokart task for local interpreter.
    Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build)

    Args:
        task: the task to build.
        return_value: when true, load and return the output of ``task``; otherwise return ``None``.
        reset_register: when true, call ``_reset_register()`` before building.
        log_level: logging level applied while building and forwarded to ``luigi.build`` by name.
        task_lock_exception_max_tries: maximum number of build attempts when a task lock exception occurs.
        task_lock_exception_max_wait_seconds: upper bound of the exponential wait between attempts.
        task_dump_config: what task information to emit before the build, see ``process_task_info``.
        **env_params: forwarded to ``luigi.build`` unchanged.

    Raises:
        GokartBuildError: if luigi reports a failed or scheduling-failed status.
        HasLockedTaskException: if a task lock exception is still seen on the last allowed attempt.
        ValueError: if ``task_dump_config`` is not a supported combination.
    """
    if reset_register:
        _reset_register()

    with LoggerConfig(level=log_level):
        log_handler_before_run = logging.StreamHandler()
        logger.addHandler(log_handler_before_run)
        try:
            process_task_info(task, task_dump_config)
        finally:
            # Always detach the temporary handler: if process_task_info raised, it would
            # otherwise stay on the module logger and pile up across calls.
            logger.removeHandler(log_handler_before_run)
            log_handler_before_run.close()

        task_lock_exception_raised = TaskLockExceptionRaisedFlag()
        raised_exceptions: dict[str, list[Exception]] = {}

        @TaskOnKart.event_handler(luigi.Event.FAILURE)
        def when_failure(failed_task, exception):
            # Lock conflicts are retried by _build_task; every other failure is collected
            # per task id and reported through GokartBuildError.
            if isinstance(exception, TaskLockException):
                task_lock_exception_raised.flag = True
            else:
                raised_exceptions.setdefault(failed_task.make_unique_id(), []).append(exception)

        @backoff.on_exception(
            partial(backoff.expo, max_value=task_lock_exception_max_wait_seconds), HasLockedTaskException, max_tries=task_lock_exception_max_tries
        )
        def _build_task():
            # Reset the flag so each attempt only reacts to lock exceptions raised during that attempt.
            task_lock_exception_raised.flag = False
            result = luigi.build(
                [task],
                worker_scheduler_factory=WorkerSchedulerFactory(),
                local_scheduler=True,
                detailed_summary=True,
                log_level=logging.getLevelName(log_level),
                **env_params,
            )
            if task_lock_exception_raised.flag:
                raise HasLockedTaskException()
            run_result = cast(LuigiRunResult, result)
            if run_result.status in (LuigiStatusCode.FAILED, LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED, LuigiStatusCode.SCHEDULING_FAILED):
                raise GokartBuildError(run_result.summary_text, raised_exceptions=raised_exceptions)
            return _get_output(task) if return_value else None

        return cast(T | None, _build_task())