Source code for gokart.parameter

from __future__ import annotations

import bz2
import datetime
import json
import sys
from collections.abc import Callable, Iterable
from logging import getLogger
from typing import Any, Generic, Protocol, TypedDict, TypeVar

if sys.version_info >= (3, 11):
    from typing import Unpack
else:
    from typing_extensions import Unpack
from warnings import warn

import luigi
from luigi import task_register
from luigi.parameter import ConfigPath, ParameterVisibility, _no_value, _NoValueType

import gokart


[docs] class ParameterKwargs(TypedDict, total=False): significant: bool description: str | None config_path: ConfigPath | None positional: bool always_in_help: bool batch_method: Callable[[Iterable[Any]], Any] | None visibility: ParameterVisibility
logger = getLogger(__name__) TASK_ON_KART_TYPE = TypeVar('TASK_ON_KART_TYPE', bound='gokart.TaskOnKart[Any]')
[docs] class TaskInstanceParameter(luigi.Parameter[TASK_ON_KART_TYPE], Generic[TASK_ON_KART_TYPE]): def __init__( self, expected_type: type[TASK_ON_KART_TYPE] | None = None, default: TASK_ON_KART_TYPE | _NoValueType = _no_value, **kwargs: Unpack[ParameterKwargs], ): if expected_type is None: self.expected_type: type = gokart.TaskOnKart elif isinstance(expected_type, type): self.expected_type = expected_type else: raise TypeError(f'expected_type must be a type, not {type(expected_type)}') super().__init__(default=default, **kwargs) @staticmethod def _recursive(param_dict): params = param_dict['params'] task_cls = task_register.Register.get_task_cls(param_dict['type']) for key, value in task_cls.get_params(): if key in params: params[key] = value.parse(params[key]) return task_cls(**params) @staticmethod def _recursive_decompress(s): s = dict(luigi.DictParameter().parse(s)) if 'params' in s: s['params'] = TaskInstanceParameter._recursive_decompress(bz2.decompress(bytes.fromhex(s['params'])).decode()) return s
[docs] def parse(self, x): if isinstance(x, str): x = self._recursive_decompress(x) return self._recursive(x)
[docs] def serialize(self, x): params = bz2.compress(json.dumps(x.to_str_params(only_significant=True)).encode()).hex() values = dict(type=x.get_task_family(), params=params) return luigi.DictParameter().serialize(values)
def _warn_on_wrong_param_type(self, param_name, param_value): if not isinstance(param_value, self.expected_type): raise TypeError(f'{param_value} is not an instance of {self.expected_type}')
class _TaskInstanceEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, luigi.Task): return TaskInstanceParameter().serialize(o) # Let the base class default method raise the TypeError return json.JSONEncoder.default(self, o)
[docs] class ListTaskInstanceParameter(luigi.Parameter[list[TASK_ON_KART_TYPE]], Generic[TASK_ON_KART_TYPE]): def __init__( self, expected_elements_type: type[TASK_ON_KART_TYPE] | None = None, default: list[TASK_ON_KART_TYPE] | _NoValueType = _no_value, **kwargs: Unpack[ParameterKwargs], ): if expected_elements_type is None: self.expected_elements_type: type = gokart.TaskOnKart elif isinstance(expected_elements_type, type): self.expected_elements_type = expected_elements_type else: raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}') super().__init__(default=default, **kwargs)
[docs] def parse(self, x): return [TaskInstanceParameter().parse(item) for item in list(json.loads(x))]
[docs] def serialize(self, x): return json.dumps(x, cls=_TaskInstanceEncoder)
def _warn_on_wrong_param_type(self, param_name, param_value): for v in param_value: if not isinstance(v, self.expected_elements_type): raise TypeError(f'{v} is not an instance of {self.expected_elements_type}')
[docs] class ExplicitBoolParameter(luigi.BoolParameter): def __init__(self, *args, **kwargs): super(luigi.BoolParameter, self).__init__(*args, **kwargs) def _parser_kwargs(self, *args, **kwargs): # type: ignore return luigi.Parameter._parser_kwargs(*args, *kwargs)
T = TypeVar('T')
[docs] class Serializable(Protocol):
[docs] def gokart_serialize(self) -> str: """Implement this method to serialize the object as an parameter You can omit some fields from results of serialization if you want to ignore changes of them """ ...
[docs] @classmethod def gokart_deserialize(cls: type[T], s: str) -> T: """Implement this method to deserialize the object from a string""" ...
S = TypeVar('S', bound=Serializable)
[docs] class SerializableParameter(luigi.Parameter[S], Generic[S]): def __init__(self, object_type: type[S], *args: Any, **kwargs: Any) -> None: self._object_type = object_type super().__init__(*args, **kwargs)
[docs] def parse(self, x: str) -> S: return self._object_type.gokart_deserialize(x)
[docs] def serialize(self, x: S) -> str: return x.gokart_serialize()
[docs] class ZonedDateSecondParameter(luigi.Parameter[datetime.datetime]): """ ZonedDateSecondParameter supports a datetime.datetime object with timezone information. A ZonedDateSecondParameter is a `ISO 8601 <http://en.wikipedia.org/wiki/ISO_8601>`_ formatted date, time specified to the second and timezone. For example, ``2013-07-10T19:07:38+09:00`` specifies July 10, 2013 at 19:07:38 +09:00. The separator `:` can be omitted for Python3.11 and later. """ def __init__(self, **kwargs): super().__init__(**kwargs)
[docs] def parse(self, x): # special character 'Z' is replaced with '+00:00' # because Python 3.11 and later support fromisoformat with Z at the end of the string. if x.endswith('Z'): x = x[:-1] + '+00:00' dt = datetime.datetime.fromisoformat(x) if dt.tzinfo is None: warn('The input does not have timezone information. Please consider using luigi.DateSecondParameter instead.', stacklevel=1) return dt
[docs] def serialize(self, x): return x.isoformat()
[docs] def normalize(self, x): # override _DatetimeParameterBase.normalize to avoid do nothing to normalize except removing microsecond. # microsecond is removed because the number of digits of microsecond is not fixed. # See also luigi's implementation https://github.com/spotify/luigi/blob/v3.6.0/luigi/parameter.py#L612 return x.replace(microsecond=0)