Grain：用于读取和处理 JAX 模型训练与评估数据的库
Grain介绍
- 用户可以用任意 Python 代码编写数据转换（transformation）。
- Grain 采用模块化设计，用户可以按需用自己的实现轻松替换 Grain 的组件。
- 同一管道的多次运行将产生相同的输出。
- Grain 的检查点（checkpoint）被设计得尽可能小。被抢占后，Grain 可以从中断处恢复，并产生与从未被抢占时相同的输出。
- 我们在设计 Grain 时十分注重性能（详见官方文档的“幕后”（Behind the Scenes）部分），并针对多种数据模态（例如文本、音频、图像、视频）进行了测试。
Grain 会尽可能精简自身的依赖，例如它不应该依赖 TensorFlow。
Grain安装
pip install grain

导入所有库
from functools import partial
from pathlib import Path

import albumentations as A
import cv2
import grain
import grain.python as pygrain
import numpy as np
import pandas as pd

# NOTE: PILoadImageMap additionally needs Pillow in scope: `from PIL import Image`.

# 创建Grain数据源 (Create the Grain data source)
根据 Grain 官方文档的要求，数据源需要实现 `__len__` 和 `__getitem__` 两个魔法方法：
# Protocol that Grain data sources follow, quoted here for reference.
# NOTE(review): to run this snippet on its own it needs
# `from typing import Generic, Protocol, SupportsIndex, TypeVar` and
# `T = TypeVar("T")`, none of which appear in this file's visible imports.
class RandomAccessDataSource(Protocol, Generic[T]):
    """Interface for datasources where storage supports efficient random access.

    Because this is a ``typing.Protocol``, a data source does not have to
    inherit from it: a class that structurally provides ``__len__`` and
    ``__getitem__`` with these signatures matches it.
    """

    def __len__(self) -> int:
        """Number of records in the dataset."""

    def __getitem__(self, record_key: SupportsIndex) -> T:
        """Retrieves record for the given record_key.

        ``record_key`` is any object supporting ``__index__`` (e.g. ``int``).
        """


# 自用数据分类的数据源 (image-classification data source for a custom dataset)
from pathlib import Path


def get_image_extensions() -> tuple[str, ...]:
    """Return the lowercase file extensions that are treated as images.

    Keeping the list in one place makes it easy to add or remove formats.

    Returns:
        A tuple of lowercase extensions including the leading dot
        (e.g. ".jpg", ".png").
    """
    return (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp")


class ImageFolderDataSource:
    """A data source that mimics the layout of torchvision.datasets.ImageFolder.

    The root directory must contain one subdirectory per class; each
    subdirectory holds the images of that class (only its direct children
    are scanned, not nested folders).

    Example directory structure:

        root_dir/
        ├── class_a/
        │   ├── image1.jpg
        │   └── image2.png
        └── class_b/
            ├── image3.jpeg
            └── image4.bmp

    Class indices follow the sorted class-directory names. Samples are
    ordered by class index and then by file name, so index ``i`` maps to the
    same file on every run and every machine.

    Attributes:
        root_dir: Root directory as a ``Path``.
        samples: ``(image_path, class_index)`` pairs.
        classes: Sorted class names; ``classes[i]`` is the name of class ``i``.
        class_to_idx: Mapping from class name to class index.
    """

    def __init__(self, root_dir: str):
        """Initialize the data source and scan ``root_dir``.

        Args:
            root_dir: Path to the root directory containing class subdirectories.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist.
            RuntimeError: If no image files are found in any class directory.
        """
        self.root_dir = Path(root_dir)
        self.samples: list[tuple[str, int]] = []
        self.classes: list[str] = []
        self.class_to_idx: dict[str, int] = {}
        self._load_samples()

    def _load_samples(self) -> None:
        """Populate ``classes``, ``class_to_idx`` and ``samples`` from disk."""
        if not self.root_dir.exists():
            msg = f"Root directory {self.root_dir} does not exist."
            raise FileNotFoundError(msg)

        valid_extensions = get_image_extensions()
        class_dirs = sorted(
            (d for d in self.root_dir.iterdir() if d.is_dir()),
            key=lambda d: d.name,
        )
        for class_idx, class_dir in enumerate(class_dirs):
            class_name = class_dir.name
            self.class_to_idx[class_name] = class_idx
            self.classes.append(class_name)
            # Directory listing order is filesystem-dependent. Sort by name so
            # the index -> file mapping (and therefore Grain's seeded shuffle and
            # checkpoint resume) is reproducible. Compare extensions
            # case-insensitively so files such as "IMG_001.JPG" are not skipped,
            # and ignore directories whose names happen to end in an image suffix.
            image_paths = sorted(
                (
                    p
                    for p in class_dir.iterdir()
                    if p.suffix.lower() in valid_extensions and p.is_file()
                ),
                key=lambda p: p.name,
            )
            self.samples.extend((str(p), class_idx) for p in image_paths)

        if not self.samples:
            msg = f"No valid images found in directory '{self.root_dir}'"
            raise RuntimeError(msg)

    def __len__(self) -> int:
        """Return the number of ``(image_path, class_index)`` samples."""
        return len(self.samples)

    def __getitem__(self, index: int) -> tuple[str, int]:
        """Return the ``(image_path, class_index)`` pair at ``index``.

        Only the path is returned; the image itself is decoded later by a
        map transform in the loading pipeline.
        """
        return self.samples[index]


# CLIP数据加载DataSource (CLIP data-loading DataSource)
class CLIPDataSource:
    """Random-access data source of ``(image_path, caption)`` pairs from a CSV file.

    The CSV must have an ``image_path`` column and a ``txt`` column; any other
    columns are ignored.
    """

    def __init__(self, csv_file):
        """Read the CSV and store one ``(image_path, caption)`` pair per row.

        Args:
            csv_file: Path or buffer accepted by ``pandas.read_csv``.

        Raises:
            KeyError: If the ``image_path`` or ``txt`` column is missing.
        """
        # dtype=str keeps captions such as "2024" as text instead of letting
        # pandas infer numbers, matching the tuple[str, str] contract below.
        # NOTE(review): empty cells still come back as NaN (float); consider
        # dropping or validating such rows.
        df = pd.read_csv(csv_file, dtype=str)
        img_paths = df["image_path"].tolist()
        texts = df["txt"].tolist()
        # __len__ and __getitem__ read from self.samples, so it must be set here.
        self.samples: list[tuple[str, str]] = list(zip(img_paths, texts))

    def __len__(self) -> int:
        """Return the number of rows (image/caption pairs)."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> tuple[str, str]:
        """Return the ``(image_path, caption)`` pair at ``idx``."""
        return self.samples[idx]


# 创建IndexSampler(类似PyTorch的Sampler) (Create the IndexSampler, similar to PyTorch's Sampler)
# Instantiate the training data source (实例化数据源).
train_dataset = ImageFolderDataSource(params.train_data_path)

# Shuffled index sampler over one epoch, without sharding; the seed fixes
# the shuffle order so repeated runs visit records in the same order.
train_sampler = pygrain.IndexSampler(
    num_records=len(train_dataset),
    num_epochs=1,
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    seed=params.seed,
)

# 创建DataLoader (Create the DataLoader)
在 Grain 中需要自己编写数据加载流程：下面的各个 transform 会作为 DataLoader 的 `operations` 按顺序依次执行。
OpenCVLoadImageMap
使用OpenCV读取图像
class OpenCVLoadImageMap(grain.transforms.Map):
    """Grain map that decodes an image file with OpenCV.

    Takes an ``(image_path, label)`` pair and returns ``(image, label)``. The
    image is read with ``cv2.IMREAD_COLOR_RGB``, i.e. channels in RGB order
    rather than OpenCV's default BGR.
    NOTE(review): that flag only exists in recent OpenCV releases — confirm the
    installed version provides it.
    """

    def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
        """Load the image at ``element[0]`` and pass the label through.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path, label = element
        img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
        # cv2.imread reports a missing/unreadable/corrupt file by returning None
        # instead of raising; fail here with the offending path rather than
        # letting a later transform crash on a None image.
        if img is None:
            msg = f"cv2.imread could not read image '{img_path}'"
            raise OSError(msg)
        return img, label


# CLIPOpenCVLoadImageMap
加载CLIP的图像和文本数据集
class CLIPOpenCVLoadImageMap(grain.transforms.Map):
    """Grain map that decodes the image of an ``(image_path, text)`` pair with OpenCV.

    Named distinctly from ``OpenCVLoadImageMap``: a second class with the same
    name would silently replace the classification loader at module level.

    The image is read with ``cv2.IMREAD_COLOR_RGB`` (RGB channel order); the
    caption text is passed through unchanged.
    """

    def map(self, element: tuple[str, str]) -> tuple[np.ndarray, str]:
        """Return ``(image, text)`` for one CLIP sample.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path, text = element
        img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
        # cv2.imread returns None on failure instead of raising.
        if img is None:
            msg = f"cv2.imread could not read image '{img_path}'"
            raise OSError(msg)
        return img, text


# PILoadImageMap
使用PIL读取图像
class PILoadImageMap(grain.transforms.Map):
    """Grain map that decodes an image file with Pillow.

    Takes an ``(image_path, label)`` pair and returns ``(image, label)``, where
    ``image`` is the RGB-converted picture as a numpy array.

    NOTE(review): ``Image`` (Pillow's ``PIL.Image``) does not appear in this
    file's visible imports; ``from PIL import Image`` must be in scope.
    """

    def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
        """Load the image at ``element[0]`` as RGB and pass the label through."""
        img_path, label = element
        # Open the file in a context manager so the handle is closed as soon as
        # the pixels are converted, instead of relying on Pillow's lazy-file
        # cleanup; long-running loader workers otherwise risk leaking handles.
        with Image.open(img_path) as pil_img:
            img = np.asarray(pil_img.convert(mode="RGB"))
        return img, label


# AlbumentationsTransform
使用Albumentations进行图像数据增强
class AlbumentationsTransform(grain.transforms.Map):
    """Apply an Albumentations pipeline to the image of an ``(image, label)`` pair.

    The label is forwarded untouched.
    """

    def __init__(self, transforms):
        # Any callable with the Albumentations calling convention works:
        # ``transforms(image=array)`` returning a mapping with an "image" entry.
        self.transforms = transforms

    def map(self, element: tuple[np.ndarray, int]) -> tuple[np.ndarray, int]:
        """Return ``(augmented_image, label)`` for one sample."""
        src_img, target = element
        augmented = self.transforms(image=src_img)
        return augmented["image"], target


# CLIPAlbumentationsTransform
class CLIPAlbumentationsTransform(grain.transforms.Map):
    """Apply an Albumentations pipeline to the image of an ``(image, text)`` pair.

    The caption text is forwarded untouched.
    """

    def __init__(self, transforms):
        # Callable invoked as ``transforms(image=array)``; its result must
        # contain an "image" entry.
        self.transforms = transforms

    def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, str]:
        """Return ``(augmented_image, text)`` for one CLIP sample."""
        pixels, caption = element
        result = self.transforms(image=pixels)
        new_pixels = result["image"]
        return new_pixels, caption


# TokenizerMap
对数据集中的Text进行分词。
class TokenizerMap(grain.transforms.Map):
    """Tokenize the caption of an ``(image, text)`` pair.

    Produces ``(image, token_ids)`` where ``token_ids`` is element ``[0]`` of
    ``tokenizer(text, context_length=context_length)``.
    """

    def __init__(self, tokenizer, context_length=77):
        # Bind context_length once so map() only has to pass the caption.
        # NOTE(review): `tokenizer` presumably returns a batch of token-id rows
        # (CLIP-style), hence the [0] in map(); confirm against the caller.
        self.tokenizer = partial(tokenizer, context_length=context_length)

    def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(image, token_ids)`` for one sample."""
        image, caption = element
        token_ids = self.tokenizer(caption)[0]
        return image, token_ids


# create_transforms
创建图像增强
def create_transforms(target_size, *, is_training=True) -> A.Compose:
    """Build the Albumentations pipeline used for training or evaluation.

    Every pipeline starts by resizing to ``target_size`` x ``target_size`` and
    ends with mean/std normalization. When ``is_training`` is true, random
    flips, rotation, color jitter and a random resized crop are inserted
    between those two steps.

    Args:
        target_size: Output height and width in pixels (square output).
        is_training: Whether to include the random augmentations.

    Returns:
        An ``A.Compose`` that applies the steps in order.
    """
    resize = A.Resize(height=target_size, width=target_size, p=1.0)

    random_augs = []
    if is_training:
        random_augs = [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=30, p=0.5),
            A.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.05,
                p=0.5,
            ),
            A.RandomResizedCrop(
                size=(target_size, target_size),
                scale=(0.8, 1.0),
                ratio=(0.75, 1.33),
                p=0.5,
            ),
        ]

    # The standard ImageNet channel mean/std, given on a 0-1 scale;
    # max_pixel_value=255.0 adapts them to 8-bit input.
    normalize = A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
    )
    return A.Compose([resize, *random_augs, normalize])
# Training loader: for each sampled record, decode the image with OpenCV,
# augment and normalize it, then group samples into batches of
# params.batch_size (the last incomplete batch is dropped).
train_loader = pygrain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    operations=[
        OpenCVLoadImageMap(),
        AlbumentationsTransform(
            create_transforms(params.target_size, is_training=True),
        ),
        pygrain.Batch(batch_size=params.batch_size, drop_remainder=True),
    ],
    worker_count=params.num_workers,
    worker_buffer_size=2,
)