import os
import pickle
import xml.etree.ElementTree as ET
from abc import abstractmethod
from logging import getLogger
import luigi
import luigi.contrib.s3
import luigi.format
import numpy as np
import pandas as pd
import pandas.errors
from gokart.object_storage import ObjectStorage
logger = getLogger(__name__)
class FileProcessor(object):
    """Interface for objects that serialize to and deserialize from a file handle."""

    @abstractmethod
    def load(self, file):
        """Read and deserialize the contents of ``file``."""
        pass

    @abstractmethod
    def dump(self, obj, file):
        """Serialize ``obj`` and write it into ``file``."""
        pass
class BinaryFileProcessor(FileProcessor):
    """
    Pass bytes to this processor
    ```
    figure_binary = io.BytesIO()
    plt.savefig(figure_binary)
    figure_binary.seek(0)
    BinaryFileProcessor().dump(figure_binary.read())
    ```
    """

    def load(self, file):
        # Hand back the raw bytes without any decoding.
        return file.read()

    def dump(self, obj, file):
        # ``obj`` must already be a bytes-like object; it is written verbatim.
        file.write(obj)
class _ChunkedLargeFileReader(object):
def __init__(self, file) -> None:
self._file = file
def __getattr__(self, item):
return getattr(self._file, item)
def read(self, n):
if n >= (1 << 31):
logger.info(f'reading a large file with total_bytes={n}.')
buffer = bytearray(n)
idx = 0
while idx < n:
batch_size = min(n - idx, 1 << 31 - 1)
logger.info(f'reading bytes [{idx}, {idx + batch_size})...')
buffer[idx:idx + batch_size] = self._file.read(batch_size)
idx += batch_size
logger.info('done.')
return buffer
return self._file.read(n)
class PickleFileProcessor(FileProcessor):
    """Serializes and deserializes arbitrary objects with ``pickle``."""

    def load(self, file):
        """Unpickle the contents of ``file``.

        Object-storage streams are read eagerly; local buffered readers are
        wrapped in ``_ChunkedLargeFileReader`` so payloads of 2GB or more can
        be read despite per-call read-size limits.
        """
        if not ObjectStorage.is_buffered_reader(file):
            return pickle.loads(file.read())
        return pickle.load(_ChunkedLargeFileReader(file))

    def dump(self, obj, file):
        """Pickle ``obj`` into ``file`` with protocol 4 (supports objects over 4GB)."""
        self._write(pickle.dumps(obj, protocol=4), file)

    @staticmethod
    def _write(buffer, file):
        """Write ``buffer`` to ``file`` in chunks of at most INT32_MAX bytes."""
        n = len(buffer)
        # Log the total once, not on every chunk as the original did.
        logger.info(f'writing a file with total_bytes={n}...')
        idx = 0
        while idx < n:
            # BUG FIX: the original computed ``1 << 31 - 1`` which parses as
            # ``1 << 30``; the intended per-write cap is ``(1 << 31) - 1``.
            batch_size = min(n - idx, (1 << 31) - 1)
            logger.info(f'writing bytes [{idx}, {idx + batch_size})')
            file.write(buffer[idx:idx + batch_size])
            idx += batch_size
        logger.info('done')
class TextFileProcessor(FileProcessor):
    """Reads and writes plain-text files, one item per line."""

    def format(self):
        # No special luigi format is required for plain text.
        return None

    def load(self, file):
        """Return the file contents as a list of lines, trailing whitespace stripped."""
        return [line.rstrip() for line in file.readlines()]

    def dump(self, obj, file):
        """Write ``obj`` to ``file``; a list becomes one line per element."""
        if not isinstance(obj, list):
            file.write(str(obj))
            return
        for element in obj:
            file.write(str(element) + '\n')
class CsvFileProcessor(FileProcessor):
    """Reads and writes delimiter-separated files (CSV/TSV) as pandas objects."""

    def __init__(self, sep=','):
        super(CsvFileProcessor, self).__init__()
        self._sep = sep  # column delimiter, e.g. ',' or '\t'

    def load(self, file):
        """Parse ``file`` into a ``pd.DataFrame``; an empty file yields an empty frame."""
        try:
            return pd.read_csv(file, sep=self._sep)
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    def dump(self, obj, file):
        """Write a ``pd.DataFrame`` or ``pd.Series`` with a header row and no index."""
        assert isinstance(obj, (pd.DataFrame, pd.Series)), \
            f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.'
        obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True)
class GzipFileProcessor(FileProcessor):
    """Reads and writes gzip-compressed text, one item per line.

    The file handle is binary, so lines are decoded on load and encoded on dump.
    """

    def load(self, file):
        """Return the decompressed contents as a list of decoded, right-stripped lines."""
        return [line.rstrip().decode() for line in file.readlines()]

    def dump(self, obj, file):
        """Encode and write ``obj``; a list becomes one line per element."""
        if not isinstance(obj, list):
            file.write(str(obj).encode())
            return
        for element in obj:
            file.write((str(element) + '\n').encode())
class JsonFileProcessor(FileProcessor):
    """Reads and writes JSON files as pandas objects."""

    def load(self, file):
        """Parse ``file`` into a ``pd.DataFrame``; unparsable or empty input yields an empty frame.

        BUG FIX: ``pd.read_json`` raises a plain ``ValueError`` on empty
        input, never ``pd.errors.EmptyDataError``, so the original handler
        was dead code. ``EmptyDataError`` subclasses ``ValueError``, so
        catching ``ValueError`` keeps the original intent and is
        backward-compatible (mirrors ``XmlFileProcessor``'s fallback style).
        """
        try:
            return pd.read_json(file)
        except ValueError:
            return pd.DataFrame()

    def dump(self, obj, file):
        """Write a DataFrame, Series, or dict (converted to a DataFrame) as JSON."""
        assert isinstance(obj, (pd.DataFrame, pd.Series, dict)), \
            f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.'
        if isinstance(obj, dict):
            obj = pd.DataFrame.from_dict(obj)
        obj.to_json(file)
class XmlFileProcessor(FileProcessor):
    """Reads and writes XML files as ``xml.etree.ElementTree`` trees."""

    def load(self, file):
        """Parse ``file`` into an ``ElementTree``; unparsable input yields an empty tree."""
        try:
            return ET.parse(file)
        except ET.ParseError:
            # Best-effort fallback: an empty tree whose root is None.
            return ET.ElementTree()

    def dump(self, obj, file):
        """Serialize an ``ElementTree`` into ``file``."""
        assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.'
        obj.write(file)
class NpzFileProcessor(FileProcessor):
    """Stores a single ``np.ndarray`` in a compressed ``.npz`` archive under the key ``data``."""

    def load(self, file):
        """Return the array stored under the ``data`` key."""
        return np.load(file)['data']

    def dump(self, obj, file):
        """Save ``obj`` compressed, keyed as ``data`` so ``load`` can find it."""
        assert isinstance(obj, np.ndarray), f'requires np.ndarray, but {type(obj)} is passed.'
        np.savez_compressed(file, data=obj)
class ParquetFileProcessor(FileProcessor):
    """Reads and writes parquet files via pandas."""

    def __init__(self, engine='pyarrow', compression=None):
        super(ParquetFileProcessor, self).__init__()
        # NOTE(review): ``_engine`` is stored but never forwarded to
        # pandas below, so pandas' default engine resolution applies —
        # confirm whether it should be passed through.
        self._engine = engine
        self._compression = compression

    def load(self, file):
        """Read the parquet data located at ``file.name``.

        ``pd.read_parquet`` only supports a filepath as string (not a file
        handle), so the handle's ``name`` attribute is used.
        """
        return pd.read_parquet(file.name)

    def dump(self, obj, file):
        """Write a ``pd.DataFrame`` (index dropped) to the path ``file.name``.

        ``pd.DataFrame.to_parquet`` only supports a filepath as string (not
        a file handle), so the handle's ``name`` attribute is used.
        """
        assert isinstance(obj, (pd.DataFrame)), \
            f'requires pd.DataFrame, but {type(obj)} is passed.'
        obj.to_parquet(file.name, index=False, compression=self._compression)
class FeatherFileProcessor(FileProcessor):
    """Reads and writes feather files, optionally round-tripping the DataFrame index.

    Feather itself does not persist an index, so when
    ``store_index_in_feather`` is True the index is copied into a
    specially-prefixed column on dump and restored from it on load.
    """

    def __init__(self, store_index_in_feather: bool):
        super(FeatherFileProcessor, self).__init__()
        self._store_index_in_feather = store_index_in_feather
        # Marker prefix for the column holding the saved index; the index
        # name is appended directly after the prefix.
        self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__'

    def load(self, file):
        """Read the feather file at ``file.name``, restoring any stored index column."""
        loaded_df = pd.read_feather(file.name)
        if not self._store_index_in_feather:
            return loaded_df
        prefix = self.INDEX_COLUMN_PREFIX
        # Scan columns right-to-left; the right-most prefixed column wins.
        candidates = [col for col in loaded_df.columns[::-1] if col[:len(prefix)] == prefix]
        if candidates:
            index_column = candidates[0]
            index_name = index_column[len(prefix):]
            loaded_df.index = pd.Index(loaded_df[index_column], name=index_name)
            loaded_df = loaded_df.drop(columns=[index_column])
        return loaded_df

    def dump(self, obj, file):
        """Write a ``pd.DataFrame`` to ``file.name``, optionally saving its index as a column."""
        assert isinstance(obj, (pd.DataFrame)), \
            f'requires pd.DataFrame, but {type(obj)} is passed.'
        dump_obj = obj.copy()
        if self._store_index_in_feather:
            index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
            assert index_column_name not in dump_obj.columns, f'column name {index_column_name} already exists in dump_obj. \
                Consider not saving index by setting store_index_in_feather=False.'
            dump_obj[index_column_name] = dump_obj.index
            dump_obj = dump_obj.reset_index(drop=True)
        # to_feather supports "binary" file-like object, but file variable is text,
        # so write through the path instead.
        dump_obj.to_feather(file.name)
def make_file_processor(file_path: str, store_index_in_feather: bool) -> FileProcessor:
    """Return the ``FileProcessor`` matching the extension of ``file_path``.

    Args:
        file_path: Path whose extension selects the processor.
        store_index_in_feather: Forwarded to ``FeatherFileProcessor``.

    Raises:
        AssertionError: If the extension is not supported.
    """
    extension2processor = {
        '.txt': TextFileProcessor(),
        '.csv': CsvFileProcessor(sep=','),
        '.tsv': CsvFileProcessor(sep='\t'),
        '.pkl': PickleFileProcessor(),
        '.gz': GzipFileProcessor(),
        '.json': JsonFileProcessor(),
        '.xml': XmlFileProcessor(),
        '.npz': NpzFileProcessor(),
        '.parquet': ParquetFileProcessor(compression='gzip'),
        '.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather),
        '.png': BinaryFileProcessor(),
        '.jpg': BinaryFileProcessor(),
    }
    _, extension = os.path.splitext(file_path)
    assert extension in extension2processor, f'{extension} is not supported. The supported extensions are {list(extension2processor.keys())}.'
    return extension2processor[extension]