Source code for gokart.target

from __future__ import annotations

import hashlib
import os
import shutil
from abc import abstractmethod
from datetime import datetime
from glob import glob
from logging import getLogger
from typing import Any, cast

import luigi
import numpy as np
import pandas as pd

from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
from gokart.object_storage import ObjectStorage
from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import FlattenableItems
from gokart.zip_client_util import make_zip_client

logger = getLogger(__name__)


[docs] class TargetOnKart(luigi.Target):
[docs] def exists(self) -> bool: return self._exists()
[docs] def load(self) -> Any: return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()
[docs] def dump( self, obj: Any, lock_at_dump: bool = True, task_params: dict[str, str] | None = None, custom_labels: dict[str, str] | None = None, required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, ) -> None: if lock_at_dump: wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)( obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs, ) else: self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs)
[docs] def remove(self) -> None: if self.exists(): wrap_remove_with_lock(self._remove, task_lock_params=self._get_task_lock_params())()
[docs] def last_modification_time(self) -> datetime: return self._last_modification_time()
[docs] def path(self) -> str: return self._path()
@abstractmethod def _exists(self) -> bool: pass @abstractmethod def _get_task_lock_params(self) -> TaskLockParams: pass @abstractmethod def _load(self) -> Any: pass @abstractmethod def _dump( self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, str] | None = None, required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, ) -> None: pass @abstractmethod def _remove(self) -> None: pass @abstractmethod def _last_modification_time(self) -> datetime: pass @abstractmethod def _path(self) -> str: pass
[docs] class SingleFileTarget(TargetOnKart): def __init__( self, target: luigi.target.FileSystemTarget, processor: FileProcessor, task_lock_params: TaskLockParams, ) -> None: self._target = target self._processor = processor self._task_lock_params = task_lock_params def _exists(self) -> bool: return cast(bool, self._target.exists()) def _get_task_lock_params(self) -> TaskLockParams: return self._task_lock_params def _load(self) -> Any: with self._target.open('r') as f: return self._processor.load(f) def _dump( self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, str] | None = None, required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, ) -> None: with self._target.open('w') as f: self._processor.dump(obj, f) if self.path().startswith('gs://'): GCSObjectMetadataClient.add_task_state_labels( path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs ) def _remove(self) -> None: self._target.remove() def _last_modification_time(self) -> datetime: return _get_last_modification_time(self._target.path) def _path(self) -> str: return self._target.path
[docs] class ModelTarget(TargetOnKart): def __init__( self, file_path: str, temporary_directory: str, load_function: Any, save_function: Any, task_lock_params: TaskLockParams, ) -> None: self._zip_client = make_zip_client(file_path, temporary_directory) self._temporary_directory = temporary_directory self._save_function = save_function self._load_function = load_function self._task_lock_params = task_lock_params def _exists(self) -> bool: return self._zip_client.exists() def _get_task_lock_params(self) -> TaskLockParams: return self._task_lock_params def _load(self) -> Any: self._zip_client.unpack_archive() self._load_function = self._load_function or make_target(self._load_function_path()).load() model = self._load_function(self._model_path()) self._remove_temporary_directory() return model def _dump( self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, str] | None = None, required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, ) -> None: self._make_temporary_directory() self._save_function(obj, self._model_path()) make_target(self._load_function_path()).dump( self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs ) self._zip_client.make_archive() self._remove_temporary_directory() def _remove(self) -> None: self._zip_client.remove() def _last_modification_time(self) -> datetime: return _get_last_modification_time(self._zip_client.path) def _path(self) -> str: return self._zip_client.path def _model_path(self): return os.path.join(self._temporary_directory, 'model.pkl') def _load_function_path(self): return os.path.join(self._temporary_directory, 'load_function.pkl') def _remove_temporary_directory(self): shutil.rmtree(self._temporary_directory) def _make_temporary_directory(self): os.makedirs(self._temporary_directory, exist_ok=True)
[docs] class LargeDataFrameProcessor: def __init__(self, max_byte: int): self.max_byte = int(max_byte)
[docs] def save(self, df: pd.DataFrame, file_path: str) -> None: dir_path = os.path.dirname(file_path) os.makedirs(dir_path, exist_ok=True) if df.empty: df.to_pickle(os.path.join(dir_path, 'data_0.pkl')) return split_size = df.values.nbytes // self.max_byte + 1 logger.info(f'saving a large pdDataFrame with split_size={split_size}') for i, idx in list(enumerate(np.array_split(range(df.shape[0]), split_size))): df.iloc[idx[0] : idx[-1] + 1].to_pickle(os.path.join(dir_path, f'data_{i}.pkl'))
[docs] @staticmethod def load(file_path: str) -> pd.DataFrame: dir_path = os.path.dirname(file_path) return cast(pd.DataFrame, pd.concat([pd.read_pickle(fp) for fp in glob(os.path.join(dir_path, 'data_*.pkl'))]))
def _make_file_system_target(file_path: str, processor: FileProcessor | None = None, store_index_in_feather: bool = True) -> luigi.target.FileSystemTarget: processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) if ObjectStorage.if_object_storage_path(file_path): return ObjectStorage.get_object_storage_target(file_path, processor.format()) return luigi.LocalTarget(file_path, format=processor.format()) def _make_file_path(original_path: str, unique_id: str | None = None) -> str: if unique_id is not None: [base, extension] = os.path.splitext(original_path) return base + '_' + unique_id + extension return original_path def _get_last_modification_time(path: str) -> datetime: if ObjectStorage.if_object_storage_path(path): if ObjectStorage.exists(path): return ObjectStorage.get_timestamp(path) raise FileNotFoundError(f'No such file or directory: {path}') return datetime.fromtimestamp(os.path.getmtime(path))
[docs] def make_target( file_path: str, unique_id: str | None = None, processor: FileProcessor | None = None, task_lock_params: TaskLockParams | None = None, store_index_in_feather: bool = True, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather) return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params)
[docs] def make_model_target( file_path: str, temporary_directory: str, save_function: Any, load_function: Any, unique_id: str | None = None, task_lock_params: TaskLockParams | None = None, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) temporary_directory = os.path.join(temporary_directory, hashlib.md5(file_path.encode()).hexdigest()) return ModelTarget( file_path=file_path, temporary_directory=temporary_directory, save_function=save_function, load_function=load_function, task_lock_params=_task_lock_params, )