# Source code for gokart.file_processor.base

from __future__ import annotations

import xml.etree.ElementTree as ET
from abc import abstractmethod
from io import BytesIO
from logging import getLogger
from typing import Any, Literal, cast

import dill
import luigi
import luigi.format
import numpy as np

from gokart.utils import load_dill_with_pandas_backward_compatibility

# Module-level logger; used in this module to report progress while reading and
# writing large pickle payloads in batches.
logger = getLogger(__name__)

# Type alias for DataFrame library return type: names the DataFrame library a loader
# should produce. 'polars-lazy' presumably means a polars LazyFrame — confirm against
# the processors that consume this alias (not visible in this section).
DataFrameType = Literal['pandas', 'polars', 'polars-lazy']


class FileProcessor:
    """Common interface for objects that read task outputs from, and write them to, files.

    A processor reports the luigi format the target should be opened with and knows how
    to ``load`` an object from an open file and ``dump`` an object into one.

    Note: this class does not use ``ABCMeta``, so ``@abstractmethod`` only documents
    intent here; it does not prevent instantiating a subclass that omits a method.
    """

    @abstractmethod
    def format(self) -> Any:
        """Return the luigi format used to open the target file (``None`` if none is needed)."""

    @abstractmethod
    def load(self, file: Any) -> Any:
        """Read and return an object from the already-open ``file``."""

    @abstractmethod
    def dump(self, obj: Any, file: Any) -> None:
        """Write ``obj`` into the already-open ``file``."""
class BinaryFileProcessor(FileProcessor):
    """Pass bytes to this processor; they are written and read back unchanged.

    ``load`` returns the result of ``file.read()`` and ``dump`` writes ``obj`` to the
    file as-is, so ``obj`` must be a bytes-like object. Example::

        figure_binary = io.BytesIO()
        plt.savefig(figure_binary)
        figure_binary.seek(0)
        BinaryFileProcessor().dump(figure_binary.read(), file)  # file: an open binary file
    """

    def format(self):
        # luigi's no-op format: data is passed through without encoding or compression.
        return luigi.format.Nop

    def load(self, file):
        """Return the entire remaining content of ``file``."""
        return file.read()

    def dump(self, obj, file):
        """Write the bytes-like ``obj`` to ``file`` unchanged."""
        file.write(obj)
class _ChunkedLargeFileReader: def __init__(self, file: Any) -> None: self._file = file def __getattr__(self, item): return getattr(self._file, item) def read(self, n: int) -> bytes: if n >= (1 << 31): logger.info(f'reading a large file with total_bytes={n}.') buffer = bytearray(n) idx = 0 while idx < n: batch_size = min(n - idx, (1 << 31) - 1) logger.info(f'reading bytes [{idx}, {idx + batch_size})...') buffer[idx : idx + batch_size] = self._file.read(batch_size) idx += batch_size logger.info('done.') return bytes(buffer) return cast(bytes, self._file.read(n)) def readline(self) -> bytes: return cast(bytes, self._file.readline()) def seek(self, offset: int) -> None: self._file.seek(offset) def seekable(self) -> bool: return cast(bool, self._file.seekable())
class PickleFileProcessor(FileProcessor):
    """Serialize arbitrary Python objects with dill using pickle protocol 4.

    NOTE: unpickling can execute arbitrary code; only load files from trusted sources.
    """

    def format(self):
        # luigi's no-op format: the pickle payload is stored as raw bytes.
        return luigi.format.Nop

    def load(self, file):
        """Unpickle and return the object stored in ``file``."""
        if not file.seekable():
            # load_dill_with_pandas_backward_compatibility() requires file with seek() and readlines() implemented.
            # Therefore, we need to wrap with BytesIO which makes file seekable and readlinesable.
            # For example, ReadableS3File is not a seekable file.
            return load_dill_with_pandas_backward_compatibility(BytesIO(file.read()))
        # Seekable files are read in place through a wrapper that splits reads of
        # 2**31 bytes or more into smaller batches.
        return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file))

    def dump(self, obj, file):
        """Pickle ``obj`` with dill (protocol 4) and write the payload to ``file``."""
        self._write(dill.dumps(obj, protocol=4), file)

    @staticmethod
    def _write(buffer, file, max_batch_size: int = (1 << 31) - 1):
        """Write ``buffer`` to ``file`` in slices of at most ``max_batch_size`` bytes.

        The default keeps every single write below 2**31 bytes.
        """
        n = len(buffer)
        # Log the total once, not once per batch.
        logger.info(f'writing a file with total_bytes={n}...')
        for idx in range(0, n, max_batch_size):
            batch_size = min(n - idx, max_batch_size)
            logger.info(f'writing bytes [{idx}, {idx + batch_size})')
            file.write(buffer[idx : idx + batch_size])
        logger.info('done')
class TextFileProcessor(FileProcessor):
    """Read and write plain text, one item per line."""

    def format(self):
        # No explicit luigi format is requested for plain text.
        return None

    def load(self, file):
        """Return the lines of ``file`` with trailing whitespace (newlines included) stripped."""
        lines = file.readlines()
        return [line.rstrip() for line in lines]

    def dump(self, obj, file):
        """Write ``obj`` to ``file``.

        A list is written one element per line, each followed by ``'\\n'``; any other
        object is written once as ``str(obj)`` without a trailing newline.
        """
        if not isinstance(obj, list):
            file.write(str(obj))
            return
        for item in obj:
            file.write(str(item) + '\n')
class GzipFileProcessor(FileProcessor):
    """Read and write gzip-compressed text, one item per line."""

    def format(self):
        # luigi's gzip format: bytes are compressed on write and decompressed on read.
        return luigi.format.Gzip

    def load(self, file):
        """Return the lines of ``file``, stripped of trailing whitespace and decoded with ``bytes.decode()``."""
        raw_lines = file.readlines()
        return [raw.rstrip().decode() for raw in raw_lines]

    def dump(self, obj, file):
        """Encode ``obj`` with ``str.encode()`` and write it to ``file``.

        A list is written one element per line, each followed by ``'\\n'``; any other
        object is written once as ``str(obj)`` without a trailing newline.
        """
        if not isinstance(obj, list):
            file.write(str(obj).encode())
            return
        for item in obj:
            line = str(item) + '\n'
            file.write(line.encode())
class XmlFileProcessor(FileProcessor):
    """Read and write XML documents as ``xml.etree.ElementTree.ElementTree`` objects."""

    def format(self):
        # No explicit luigi format is requested for XML.
        return None

    def load(self, file):
        """Parse ``file`` as XML, falling back to an empty ``ElementTree`` on a parse error."""
        try:
            tree = ET.parse(file)
        except ET.ParseError:
            # Best-effort: unparsable content (e.g. an empty file) yields an empty tree
            # instead of raising.
            tree = ET.ElementTree()
        return tree

    def dump(self, obj, file):
        """Serialize the ``ElementTree`` ``obj`` into ``file``."""
        assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.'
        tree = obj
        tree.write(file)
class NpzFileProcessor(FileProcessor):
    """Store a single ``np.ndarray`` in a compressed ``.npz`` archive under the key ``'data'``."""

    def format(self):
        # luigi's no-op format: the npz archive is stored as raw bytes.
        return luigi.format.Nop

    def load(self, file):
        """Return the array stored under ``'data'`` in the npz archive ``file``."""
        # np.load returns an NpzFile that keeps its zip archive open until closed; use it as
        # a context manager so the handle is released. Indexing reads the array fully into
        # memory, so the returned array stays valid after the archive is closed.
        with np.load(file) as npz:
            return npz['data']

    def dump(self, obj, file):
        """Write ``obj`` to ``file`` as a compressed npz archive under the key ``'data'``."""
        assert isinstance(obj, np.ndarray), f'requires np.ndarray, but {type(obj)} is passed.'
        np.savez_compressed(file, data=obj)