Introduction to Grain

Grain official documentation

  • Users can bring arbitrary Python transformations.
  • Grain is designed to be modular; if needed, users can easily override Grain components with their own implementations.
  • Multiple runs of the same pipeline produce the same output.
  • Grain is designed so that checkpoints have minimal size. After preemption, Grain can resume from where it left off and produce the same output as if it had never been preempted (see the checkpointing sketch after this list).
  • Grain was designed carefully to ensure good performance (see the "Behind the Scenes" section of the documentation). It has also been tested across multiple data modalities (e.g. text/audio/images/video).
  • Grain keeps its set of dependencies as small as possible; for example, it does not depend on TensorFlow.
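
    As a concrete illustration of the checkpointing point above, here is a minimal sketch. It assumes the iterator returned by pygrain.DataLoader exposes get_state()/set_state() as described in Grain's checkpointing documentation; the toy list source and parameters are for illustration only.

    import grain.python as pygrain

    # Toy in-memory data source: any object with __len__/__getitem__ works.
    source = list(range(10))

    loader = pygrain.DataLoader(
        data_source=source,
        sampler=pygrain.IndexSampler(
            num_records=len(source),
            shuffle=True,
            seed=0,
            shard_options=pygrain.NoSharding(),
            num_epochs=1,
        ),
        worker_count=0,  # run in-process for this small example
    )

    it = iter(loader)
    _ = next(it)
    state = it.get_state()  # compact, serializable iterator state

    # After a preemption: rebuild the same loader and restore the state.
    it_restored = iter(loader)
    it_restored.set_state(state)
    print(next(it_restored))  # continues where the saved iterator stopped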

    Installing Grain

    pip install grain

    Import all libraries

    from functools import partial
    from pathlib import Path

    import albumentations as A
    import cv2
    import grain
    import grain.python as pygrain
    import numpy as np
    import pandas as pd
    from PIL import Image

    Creating a Grain data source

    As required by the official Grain docs, a data source needs to implement two magic methods:

    class RandomAccessDataSource(Protocol, Generic[T]):
      """Interface for datasources where storage supports efficient random access."""

      def __len__(self) -> int:
        """Number of records in the dataset."""

      def __getitem__(self, record_key: SupportsIndex) -> T:
        """Retrieves record for the given record_key."""

    A data source for my own image-classification dataset

    from pathlib import Path
    
    
    def get_image_extensions() -> tuple[str, ...]:
      """Returns a tuple of common image file extensions.
    
      This function provides a centralized list of supported image file types,
      making it easy to manage and update.
    
      Returns:
          A tuple of strings, where each string is an image file extension
          (e.g., ".jpg", ".png").
    
      """
      return (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp")
    
    
    class ImageFolderDataSource:
      """A data source class that mimics the structure of torchvision.datasets.ImageFolder.
    
      It expects the root directory to contain subdirectories, where each subdirectory
      represents a class and contains images belonging to that class.
    
      Example directory structure:
      root_dir/
      ├── class_a/
      │   ├── image1.jpg
      │   └── image2.png
      └── class_b/
          ├── image3.jpeg
          └── image4.bmp
      """
    
      def __init__(self, root_dir: str):
          """Initializes the ImageFolderDataSource.
    
          Args:
              root_dir: The path to the root directory containing class subdirectories.
    
          """
          self.root_dir = Path(root_dir)
          self.samples: list[tuple[str, int]] = []
          self.classes: list[str] = []
          self._load_samples()
    
      def _load_samples(self) -> None:
          """Loads image file paths and their corresponding class indices from the root directory.
    
          This method iterates through subdirectories, treats each subdirectory name as a class,
          and collects all valid image files within them.
          """
          if not self.root_dir.exists():
              msg = f"Root directory {self.root_dir} does not exist."
              raise FileNotFoundError(msg)
    
          class_to_idx = {}
          valid_extensions = get_image_extensions()
    
          class_dirs = [d for d in self.root_dir.iterdir() if d.is_dir()]
          class_dirs.sort(key=lambda x: x.name)
    
          for class_dir in class_dirs:
              class_name = class_dir.name
              class_idx = len(class_to_idx)
              class_to_idx[class_name] = class_idx
              self.classes.append(class_name)
    
              for ext in valid_extensions:
                  for img_path in class_dir.glob(f"*{ext}"):
                      self.samples.append((str(img_path), class_idx))
    
          if not self.samples:
              msg = f"No valid images found in directory '{self.root_dir}'"
              raise RuntimeError(msg)

      # Implement the magic method __len__
      def __len__(self) -> int:
          """Returns the total number of samples (images) in the dataset.
          This allows the dataset object to be used with `len()`.
          """
          return len(self.samples)

      # Implement the magic method __getitem__
      def __getitem__(self, index: int) -> tuple[str, int]:
          """Returns a sample (image path and its class index) at the given index.
          This allows the dataset object to be indexed like a list (e.g., `dataset[0]`).
          """
          return self.samples[index]
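
    A quick sanity check of the data source (the directory path below is hypothetical):

    ds = ImageFolderDataSource("data/train")  # hypothetical root with class subfolders
    print(len(ds))      # total number of images found
    print(ds.classes)   # class names, sorted alphabetically
    print(ds[0])        # (image_path, class_index) of the first sample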

    DataSource for CLIP image/text pairs

    class CLIPDataSource:
      def __init__(self, csv_file):
          df = pd.read_csv(csv_file)
          img_paths = df["image_path"].tolist()
          texts = df["txt"].tolist()
          # Keep the (image_path, text) pairs so __len__/__getitem__ can use them.
          self.samples = list(zip(img_paths, texts))

      def __len__(self) -> int:
          return len(self.samples)

      def __getitem__(self, idx: int) -> tuple[str, str]:
          return self.samples[idx]
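
    The CSV is expected to contain an image_path column and a txt column; a tiny hypothetical example:

    # Build a toy CSV matching the expected schema (paths are hypothetical).
    pd.DataFrame(
        {
            "image_path": ["images/cat.jpg", "images/dog.jpg"],
            "txt": ["a photo of a cat", "a photo of a dog"],
        }
    ).to_csv("clip_pairs.csv", index=False)

    clip_ds = CLIPDataSource("clip_pairs.csv")
    print(len(clip_ds))  # 2
    print(clip_ds[0])    # ('images/cat.jpg', 'a photo of a cat')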

    Creating an IndexSampler (similar to PyTorch's Sampler)

    # Instantiate the data source
    train_dataset = ImageFolderDataSource(params.train_data_path)
    train_sampler = pygrain.IndexSampler(
        num_records=len(train_dataset),
        shuffle=True,
        seed=params.seed,
        shard_options=pygrain.NoSharding(),
        num_epochs=1,
    )
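
    For an evaluation split, shuffling is usually disabled; and for multi-process training, NoSharding() can be replaced by explicit shard options. A sketch (it assumes pygrain.ShardOptions with shard_index/shard_count fields, and val_dataset is a data source like the ones above):

    # Deterministic, unshuffled sampler for evaluation.
    val_sampler = pygrain.IndexSampler(
        num_records=len(val_dataset),
        shuffle=False,
        seed=params.seed,
        shard_options=pygrain.NoSharding(),
        num_epochs=1,
    )

    # Sketch of explicit sharding across processes (assumed ShardOptions fields).
    sharded_sampler = pygrain.IndexSampler(
        num_records=len(train_dataset),
        shuffle=True,
        seed=params.seed,
        shard_options=pygrain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
        num_epochs=1,
    )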

    Creating the DataLoader

    In Grain, you write the data-loading pipeline (the chain of operations) yourself.

    OpenCVLoadImageMap

    Read images with OpenCV.

    class OpenCVLoadImageMap(grain.transforms.Map):
      def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
          img_path, label = element
          img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
          return img, label
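
    Note that cv2.IMREAD_COLOR_RGB is only available in recent OpenCV releases; on older versions, an equivalent variant reads in BGR and converts explicitly (a sketch):

    class OpenCVLoadImageMapBGR(grain.transforms.Map):
      """Variant for OpenCV versions that lack cv2.IMREAD_COLOR_RGB."""

      def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
          img_path, label = element
          img = cv2.imread(img_path, cv2.IMREAD_COLOR)  # OpenCV loads BGR by default
          return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), label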

    CLIPOpenCVLoadImageMap

    Load the image/text pairs for CLIP.

    class CLIPOpenCVLoadImageMap(grain.transforms.Map):
      def map(self, element: tuple[str, str]) -> tuple[np.ndarray, str]:
          img_path, text = element
          img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
          return img, text

    PILoadImageMap

    Read images with PIL.

    class PILoadImageMap(grain.transforms.Map):
      def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
          img_path, label = element
          img = np.asarray(Image.open(img_path).convert(mode="RGB"))
          return img, label

    AlbumentationsTransform

    Apply image data augmentation with Albumentations.

    class AlbumentationsTransform(grain.transforms.Map):
      def __init__(self, transforms):
          self.transforms = transforms
    
      def map(self, element: tuple[np.ndarray, int]) -> tuple[np.ndarray, int]:
          image, label = element
          transformed_image = self.transforms(image=image)["image"]
          return transformed_image, label
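
    Albumentations draws its own random numbers here, so these augmentations are not governed by Grain's seed. If you need randomness fully controlled by Grain, the per-record RNG of the RandomMap interface can be used instead. A sketch (it assumes grain.transforms.RandomMap passes a numpy Generator to random_map; the noise transform itself is just an illustration):

    class RandomNoiseMap(grain.transforms.RandomMap):
      """Illustrative seeded transform: adds Gaussian noise via Grain's per-record RNG."""

      def random_map(
          self, element: tuple[np.ndarray, int], rng: np.random.Generator
      ) -> tuple[np.ndarray, int]:
          image, label = element
          noise = rng.normal(0.0, 0.01, size=image.shape).astype(np.float32)
          return image.astype(np.float32) + noise, label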

    CLIPAlbumentationsTransform

    class CLIPAlbumentationsTransform(grain.transforms.Map):
      def __init__(self, transforms):
          self.transforms = transforms
    
      def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, str]:
          image, text = element
          transformed_image = self.transforms(image=image)["image"]
          return transformed_image, text

    TokenizerMap

    Tokenize the text field of the dataset.

    class TokenizerMap(grain.transforms.Map):
      def __init__(self, tokenizer, context_length=77):
          self.tokenizer = partial(tokenizer, context_length=context_length)
    
      def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, np.ndarray]:
          img, txt = element
          text = self.tokenizer(txt)[0]
          return img, text
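
    TokenizerMap only assumes the tokenizer is a callable that accepts a context_length keyword and returns a batch of token arrays (the [0] picks out the single sequence). A self-contained check with a dummy tokenizer (for illustration only):

    def dummy_tokenizer(text, context_length=77):
      """Stand-in tokenizer: byte values padded/truncated to context_length."""
      ids = np.zeros((1, context_length), dtype=np.int32)
      tokens = list(text.encode("utf-8"))[:context_length]
      ids[0, : len(tokens)] = tokens
      return ids


    tok_map = TokenizerMap(dummy_tokenizer, context_length=77)
    img = np.zeros((224, 224, 3), dtype=np.float32)
    _, token_ids = tok_map.map((img, "a photo of a cat"))
    print(token_ids.shape)  # (77,)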

    create_transforms

    Create the image augmentation pipeline.

    def create_transforms(target_size, *, is_training=True) -> A.Compose:
      """Create image augmentation and normalization transformations.
    
      Args:
          target_size: The desired height and width for the images.
          is_training: A boolean indicating whether to apply training-specific augmentations.
    
      Returns:
          An `A.Compose` object containing the sequence of transformations.
    
      """
      transforms_list = [
          A.Resize(height=target_size, width=target_size, p=1.0),
      ]
    
      if is_training:
          transforms_list.extend(
              [
                  A.HorizontalFlip(p=0.5),
                  A.VerticalFlip(p=0.5),
                  A.Rotate(limit=30, p=0.5),
                  A.ColorJitter(
                      brightness=0.2,
                      contrast=0.2,
                      saturation=0.2,
                      hue=0.05,
                      p=0.5,
                  ),
                  A.RandomResizedCrop(
                      size=(target_size, target_size),
                      scale=(0.8, 1.0),
                      ratio=(0.75, 1.33),
                      p=0.5,
                  ),
              ],
          )
    
      transforms_list.append(
          A.Normalize(
              mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225],
              max_pixel_value=255.0,
          ),
      )
    
      return A.Compose(transforms_list)

    Finally, combine the data source, the sampler, and the transform operations into a DataLoader:

    train_loader = pygrain.DataLoader(
        data_source=train_dataset,
        sampler=train_sampler,
        worker_count=params.num_workers,
        worker_buffer_size=2,
        operations=[
            OpenCVLoadImageMap(),
            AlbumentationsTransform(
                create_transforms(
                    target_size=params.target_size,
                    is_training=True,
                ),
            ),
            pygrain.Batch(
                params.batch_size,
                drop_remainder=True,
            ),
        ],
    )
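
    The DataLoader is a regular iterable, and pygrain.Batch stacks each field of the (image, label) tuples, so a typical epoch loop looks like:

    for images, labels in train_loader:
        # images: (batch_size, target_size, target_size, 3), normalized float32
        # labels: (batch_size,) integer class indices
        ...  # run the training step on this batch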
