Grain介绍

Grain官方Docs

  • 用户可以带来任意的 Python 转换。
  • Grain 被设计为模块化。如果需要,用户可以使用自己的实现轻松覆盖 Grain 组件。
  • 同一管道的多次运行将产生相同的输出。
  • Grain 的设计使得检查点的大小最小。抢占后,Grain 可以从中断的地方恢复,并产生与从未被抢占相同的输出。
  • 我们在设计 Grain 时小心翼翼地确保其性能良好(请参阅文档的幕后部分。我们还针对多种数据模式(例如文本/音频/图像/视频)对其进行了测试。
  • Grain 会尽可能减少其依赖项集。例如,它不应该依赖于 TensorFlow。

    Grain安装

    pip install grain

    导入所有库

    from pathlib import Path
    import grain
    import grain.python as pygrain
    import cv2
    import albumentations as A
    import pandas as pd

    创建Grain数据源

    参考Grain官网要求,数据源需要实现两个魔法方法:

    class RandomAccessDataSource(Protocol, Generic[T]):
    """Interface for datasources where storage supports efficient random access."""
    
    def __len__(self) -> int:
      """Number of records in the dataset."""
    
    def __getitem__(self, record_key: SupportsIndex) -> T:
      """Retrieves record for the given record_key."""

    自用数据分类的数据源

    from pathlib import Path
    
    
    def get_image_extensions() -> tuple[str, ...]:
      """Returns a tuple of common image file extensions.
    
      This function provides a centralized list of supported image file types,
      making it easy to manage and update.
    
      Returns:
          A tuple of strings, where each string is an image file extension
          (e.g., ".jpg", ".png").
    
      """
      return (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp")
    
    
    class ImageFolderDataSource:
      """A data source class that mimics the structure of torchvision.datasets.ImageFolder.
    
      It expects the root directory to contain subdirectories, where each subdirectory
      represents a class and contains images belonging to that class.
    
      Example directory structure:
      root_dir/
      ├── class_a/
      │   ├── image1.jpg
      │   └── image2.png
      └── class_b/
          ├── image3.jpeg
          └── image4.bmp
      """
    
      def __init__(self, root_dir: str):
          """Initializes the ImageFolderDataSource.
    
          Args:
              root_dir: The path to the root directory containing class subdirectories.
    
          """
          self.root_dir = Path(root_dir)
          self.samples: list[tuple[str, int]] = []
          self.classes: list[str] = []
          self._load_samples()
    
      def _load_samples(self) -> None:
          """Loads image file paths and their corresponding class indices from the root directory.
    
          This method iterates through subdirectories, treats each subdirectory name as a class,
          and collects all valid image files within them.
          """
          if not self.root_dir.exists():
              msg = f"Root directory {self.root_dir} does not exist."
              raise FileNotFoundError(msg)
    
          class_to_idx = {}
          valid_extensions = get_image_extensions()
    
          class_dirs = [d for d in self.root_dir.iterdir() if d.is_dir()]
          class_dirs.sort(key=lambda x: x.name)
    
          for class_dir in class_dirs:
              class_name = class_dir.name
              class_idx = len(class_to_idx)
              class_to_idx[class_name] = class_idx
              self.classes.append(class_name)
    
              for ext in valid_extensions:
                  for img_path in class_dir.glob(f"*{ext}"):
                      self.samples.append((str(img_path), class_idx))
    
          if not self.samples:
              msg = f"No valid images found in directory '{self.root_dir}'"
              raise RuntimeError(msg)
      # 实现魔法方法__len__
      def __len__(self) -> int:
          """Returns the total number of samples (images) in the dataset.
          This allows the dataset object to be used with `len()`.
          """
          return len(self.samples)
      # 实现魔法方法__getitem__
      def __getitem__(self, index: int) -> tuple[str, int]:
          """Returns a sample (image path and its class index) at the given index.
          This allows the dataset object to be indexed like a list (e.g., `dataset[0]`).
          """
          return self.samples[index]

    CLIP数据加载DataSource

    class CLIPDataSource:
      def __init__(self, csv_file):
          df = pd.read_csv(csv_file)
          img_paths = df["image_path"].tolist()
          texts = df['txt'].tolist()
    
      def __len__(self) -> int:
          return len(self.samples)
    
      def __getitem__(self, idx: int) -> tuple[str, str]:
          return self.samples[idx]

    创建IndexSampler(类似PyTorch的Sampler)

    # 实例化数据源
    train_dataset = ImageFolderDataSource(params.train_data_path)
    train_sampler=pygrain.IndexSampler(
          num_records=len(train_dataset),
          shuffle=True,
          seed=params.seed,
          shard_options=pygrain.NoSharding(),
          num_epochs=1,
      )

    创建DataLoader

    在grain中需要自己写数据加载流。

    OpenCVLoadImageMap

    使用OpenCV读取图像

    class OpenCVLoadImageMap(grain.transforms.Map):
      def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
          img_path, label = element
          img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
          return img, label

    CLIPOpenCVLoadImageMap

    加载CLIP的图像和文本数据集

    class OpenCVLoadImageMap(grain.transforms.Map):
      def map(self, element: tuple[str, str]) -> tuple[np.ndarray, str]:
          img_path, text = element
          img = cv2.imread(img_path, cv2.IMREAD_COLOR_RGB)
          return img, text

    PILoadImageMap

    使用PIL读取图像

    class PILoadImageMap(grain.transforms.Map):
      def map(self, element: tuple[str, int]) -> tuple[np.ndarray, int]:
          img_path, label = element
          img = np.asarray(Image.open(img_path).convert(mode="RGB"))
          return img, label

    AlbumentationsTransform

    使用Albumentations进行图像数据增强

    class AlbumentationsTransform(grain.transforms.Map):
      def __init__(self, transforms):
          self.transforms = transforms
    
      def map(self, element: tuple[np.ndarray, int]) -> tuple[np.ndarray, int]:
          image, label = element
          transformed_image = self.transforms(image=image)["image"]
          return transformed_image, label

    CLIPAlbumentationsTransform

    class CLIPAlbumentationsTransform(grain.transforms.Map):
      def __init__(self, transforms):
          self.transforms = transforms
    
      def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, str]:
          image, text = element
          transformed_image = self.transforms(image=image)["image"]
          return transformed_image, text

    TokenizerMap

    对数据集中的Text进行分词。

    class TokenizerMap(grain.transforms.Map):
      def __init__(self, tokenizer, context_length=77):
          self.tokenizer = partial(tokenizer, context_length=context_length)
    
      def map(self, element: tuple[np.ndarray, str]) -> tuple[np.ndarray, np.ndarray]:
          img, txt = element
          text = self.tokenizer(txt)[0]
          return img, text

    create_transforms

    创建图像增强

    def create_transforms(target_size, *, is_training=True) -> A.Compose:
      """Create image augmentation and normalization transformations.
    
      Args:
          target_size: The desired height and width for the images.
          is_training: A boolean indicating whether to apply training-specific augmentations.
    
      Returns:
          An `A.Compose` object containing the sequence of transformations.
    
      """
      transforms_list = [
          A.Resize(height=target_size, width=target_size, p=1.0),
      ]
    
      if is_training:
          transforms_list.extend(
              [
                  A.HorizontalFlip(p=0.5),
                  A.VerticalFlip(p=0.5),
                  A.Rotate(limit=30, p=0.5),
                  A.ColorJitter(
                      brightness=0.2,
                      contrast=0.2,
                      saturation=0.2,
                      hue=0.05,
                      p=0.5,
                  ),
                  A.RandomResizedCrop(
                      size=(target_size, target_size),
                      scale=(0.8, 1.0),
                      ratio=(0.75, 1.33),
                      p=0.5,
                  ),
              ],
          )
    
      transforms_list.append(
          A.Normalize(
              mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225],
              max_pixel_value=255.0,
          ),
      )
    
      return A.Compose(transforms_list)
train_loader = pygrain.DataLoader(
        data_source=train_dataset,
        sampler=train_sampler,
        worker_count=params.num_workers,
        worker_buffer_size=2,
        operations=[
            OpenCVLoadImageMap(),
            AlbumentationsTransform(
                create_transforms(
                    target_size=params.target_size,
                    is_training=True,
                ),
            ),
            pygrain.Batch(
                params.batch_size,
                drop_remainder=True,
            ),
        ],
    )

数据集地址:https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset

from shutil import copyfile
import random
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import os
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import datetime
import torchvision
import torch
import torchvision.transforms as T

# import cv2

torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
torch.cuda.set_per_process_memory_fraction(0.95, 0)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('PyTorch版本:' + torch.__version__)
print('CUDA版本:' + torch.version.cuda)
print('CUDNN版本:' + str(torch.backends.cudnn.version()))
print('设备名称:' + torch.cuda.get_device_name(0))
PyTorch版本:1.11.0
CUDA版本:11.3
CUDNN版本:8200
设备名称:NVIDIA GeForce RTX 3060 Laptop GPU


def walk_through_dir(directory_name):  # 输出路径下的目录和目录中的文件
    for dirpaths, dirnames, filenames in os.walk(directory_name):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpaths}'")
input_data_dir = './PetImages'  # 当前数据目录
walk_through_dir(input_data_dir)  # 查看当前数据路径下的目录和目录中的文件
There are 2 directories and 0 images in './PetImages'
There are 0 directories and 12501 images in './PetImages\Cat'
There are 0 directories and 12501 images in './PetImages\Dog'


print("文件目录结构:")
!tree./ PetImages
文件目录结构:
卷 新加卷 的文件夹 PATH 列表
卷序列号为 44A6-CC9A
D:\DOCUMENTS\PYTORCH深度学习实战\PETIMAGES
├─Cat
└─Dog
os.mkdir('./data')  # 创建目录来存放训练数据和测试数据
os.mkdir('./data/train')
os.mkdir('./data/test')
for folder in os.listdir(input_data_dir):  # 遍历cat和dog目录
    files = os.listdir(os.path.join(input_data_dir, folder))  # 拼接得到图像文件的父目录
    images = []  # 创建列表来存储图像路径
    for f in files:  # 遍历Dog/Cat目录下的图像
        try:
            img = Image.open(os.path.join(input_data_dir, folder, f)).convert("RGB")  # 如果图像能够打开则添加到路径列表
            images.append(f)
        except IOError:  # 如果发生输入输出错误则输出发生错误的文件,并跳过该文件
            print(f'fail on {f}')
            pass

    random.shuffle(images)  # 将列表中的路径打乱
    count = len(images)  # Cat/Dog目录下的总图片数
    split = int(0.8 * count)  # 选取其中80%作为训练集
    os.mkdir(os.path.join('./data/train', folder))  # 创建训练集下的Dog/Cat目录
    os.mkdir(os.path.join('./data/test', folder))  # 创建测试集下的Dog/Cat目录

    for c in range(split):
        source_file = os.path.join(input_data_dir, folder, images[c])  # 得到训练集源文件的路径
        distination = os.path.join('./data/train', folder, images[c])  # 创建目标路径
        copyfile(source_file, distination)  # 将训练集路径下的文件放到训练集中
    for c in range(split, count):
        source_file = os.path.join(input_data_dir, folder, images[c])  # 得到测试集源文件的路径
        distination = os.path.join('./data/test', folder, images[c])  # 创建目标存放路径
        copyfile(source_file, distination)  # 将测试集路径下的文件放到测试集中
train_dir = './data/train'  # 生成的训练集文件目录
walk_through_dir(train_dir)  # 查看训练集目录
test_dir = './data/test'  # 生成的测试集文件目录
walk_through_dir(test_dir)  # 查看测试集目录
There are 2 directories and 0 images in './data/train'
There are 0 directories and 9999 images in './data/train\Cat'
There are 0 directories and 9999 images in './data/train\Dog'
There are 2 directories and 0 images in './data/test'
There are 0 directories and 2500 images in './data/test\Cat'
There are 0 directories and 2500 images in './data/test\Dog'
# 设置测试集的图像处理
train_transforms = T.Compose([
    T.RandomResizedCrop(224),  # 随机裁剪一个区域,然后重塑形状到(224,224)
    T.RandomHorizontalFlip(0.5),  # 50%的概率随机水平翻转
    T.ToTensor()  # 将像素值归一化到 [0.0,1.0]
])
valid_transform = T.Compose([
    T.CenterCrop(224),  # 从图像中心裁剪(224,224)的图像
    T.ToTensor()
])
train_dataset = torchvision.datasets.ImageFolder(root=train_dir, transform=train_transforms)  # 创造训练集
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)  # 创造训练数据加载器
valid_dataset = torchvision.datasets.ImageFolder(root=test_dir, transform=valid_transform)  # 创造测试集
valid_loader = torch.utils.data.DataLoader(valid_dataset)  # 创造测试集数据加载器
train_dataset.class_to_idx  # 查看训练集类映射
{'Cat': 0, 'Dog': 1}
# 使用BatchNormalization层的VGG16模型
model = torchvision.models.vgg16_bn(pretrained=True).to(device)
model.classifier[6] = nn.Linear(4096, 2, device=device)  # 修改classifier层最后一个输出2个特征
for name, module in model.named_modules():
    if name != 'classifier.6':  # 冻结除了最后一个层外所有层的参数
        # print(module)
        for parameter in module.parameters():
            parameter.required_grad = False
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)  # 使用AdamW优化器,设置学习率和权值衰减
criterion = nn.CrossEntropyLoss().to(device)  # 使用交叉熵损失函数(相当于softmax+NLLLoss)
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.
    for i, (inputs, labels) in enumerate(train_loader):  # 进行一个Epoch
        inputs, labels = inputs.to(device), labels.to(device)  # 将数据放到device上
        optimizer.zero_grad()  # 初始梯度置为0
        outputs = model(inputs)  # 得到输出
        loss = criterion(outputs, labels)  # 计算损失
        loss.backward()  # 损失反向传播,计算梯度
        optimizer.step()  # 使用优化器对权重进行更新
        running_loss += loss.item()  # 添加到运行损失中
        if i % 1000 == 999:  # 每1000批
            last_loss = running_loss / 1000  # 平均每批的损失
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1  # 作为全局步长来写入TensorBoard
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)  # 写入训练损失到TensorBoard
            running_loss = 0.  # 重置训练损失

    return last_loss
# Initializing in a separate cell so we can easily add more epochs to the same run
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # 获取当前时间来作为写入的路径
writer = SummaryWriter('runs/dogvscat_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5  # 训练轮次

best_vloss = 1000000.  # 最好的测试集损失

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))  # 输出轮次

    avg_loss = train_one_epoch(epoch_number, writer)  # 获得一轮结束后平均损失

    with torch.no_grad():  # 不求梯度
        running_vloss = 0.0
        for i, (vinputs, vlabels) in enumerate(valid_loader):
            vinputs, vlabels = vinputs.to(device), vlabels.to(device)
            voutputs = model(vinputs)
            vloss = criterion(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': avg_loss, 'Validation': avg_vloss},
                       epoch_number + 1)  # 写入训练集和测试集损失
    writer.flush()  # 清空缓冲区数据

    if avg_vloss < best_vloss:  # 如果平均损失小于最小的损失
        best_vloss = avg_vloss  # 更新最小损失为平均损失
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)  # 将模型保存到该路径

    epoch_number += 1
EPOCH 1:
  batch 1000 loss: 0.6323979058712721


C:\Users\reion\miniconda3\envs\torch\lib\site-packages\PIL\TiffImagePlugin.py:845: UserWarning: Truncated File Read
  warnings.warn(str(msg))


LOSS train 0.6323979058712721 valid 0.6807079315185547
EPOCH 2:
  batch 1000 loss: 0.4388493032604456
LOSS train 0.4388493032604456 valid 0.687056303024292
EPOCH 3:
  batch 1000 loss: 0.3774294305369258
LOSS train 0.3774294305369258 valid 0.740951418876648
EPOCH 4:
  batch 1000 loss: 0.355394969265908
LOSS train 0.355394969265908 valid 0.8710261583328247
EPOCH 5:
  batch 1000 loss: 0.5819370641112328
LOSS train 0.5819370641112328 valid 0.7211285829544067


model.load_state_dict(torch.load('./model_20220518_195539_0'))  # 加载保存下来的模型
model.eval()
with torch.no_grad():  # 计算准确率
    correct = 0
    total = 0

    for i, (vinputs, vlabels) in enumerate(valid_loader):
        total += len(vinputs)
        vinputs, vlabels = vinputs.to(device), vlabels.to(device)
        voutputs = model(vinputs)
        correct += (voutputs.argmax(1) == vlabels).type(torch.float).sum().item()
print('准确率:' + str(correct / total))
准确率:0.7476

import random
import os
import tensorflow as tf
from tensorflow import keras
from shutil import copyfile
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
import datetime

# from tensorflow_addons.optimizers import AdamW
print('TensorFlow版本:' + tf.__version__)
print('CUDA设备:' + tf.test.gpu_device_name())
TensorFlow版本:2.6.0
CUDA设备:/device:GPU:0
def walk_through_dir(directory_name):
    for dirpaths, dirnames, filenames in os.walk(directory_name):
        print(f"There are {len(dirnames)} directories and {len(filenames)} images in '{dirpaths}'")
input_data_dir = '../PetImages'
print('目录结构:')
!tree ../PetImages
walk_through_dir(input_data_dir)
目录结构:
卷 新加卷 的文件夹 PATH 列表
卷序列号为 44A6-CC9A
D:\DOCUMENTS\PYTORCH深度学习实战\PETIMAGES
├─Cat
└─Dog
There are 2 directories and 0 images in '../PetImages'
There are 0 directories and 12501 images in '../PetImages\Cat'
There are 0 directories and 12501 images in '../PetImages\Dog'
os.mkdir('./data')  # 创建目录来存放训练数据和测试数据
os.mkdir('./data/train')
os.mkdir('./data/test')
for folder in os.listdir(input_data_dir):  # 遍历cat和dog目录
    files = os.listdir(os.path.join(input_data_dir, folder))  # 拼接得到图像文件的父目录
    images = []  # 创建列表来存储图像路径
    for f in files:  # 遍历Dog/Cat目录下的图像
        try:
            img = Image.open(os.path.join(input_data_dir, folder, f)).convert("RGB")  # 如果图像能够打开则添加到路径列表
            images.append(f)
        except IOError:  # 如果发生输入输出错误则输出发生错误的文件,并跳过该文件
            print(f'fail on {f}')
            pass

    random.shuffle(images)  # 将列表中的路径打乱
    count = len(images)  # Cat/Dog目录下的总图片数
    split = int(0.8 * count)  # 选取其中80%作为训练集
    os.mkdir(os.path.join('./data/train', folder))  # 创建训练集下的Dog/Cat目录
    os.mkdir(os.path.join('./data/test', folder))  # 创建测试集下的Dog/Cat目录

    for c in range(split):
        source_file = os.path.join(input_data_dir, folder, images[c])  # 得到训练集源文件的路径
        distination = os.path.join('./data/train', folder, images[c])  # 创建目标路径
        copyfile(source_file, distination)  # 将训练集路径下的文件放到训练集中
    for c in range(split, count):
        source_file = os.path.join(input_data_dir, folder, images[c])  # 得到测试集源文件的路径
        distination = os.path.join('./data/test', folder, images[c])  # 创建目标存放路径
        copyfile(source_file, distination)  # 将测试集路径下的文件放到测试集中
train_dir = './data/train'  # 生成的训练集文件目录
walk_through_dir(train_dir)  # 查看训练集目录
test_dir = './data/test'  # 生成的测试集文件目录
walk_through_dir(test_dir)  # 查看测试集目录
There are 2 directories and 0 images in './data/train'
There are 0 directories and 9999 images in './data/train\Cat'
There are 0 directories and 9999 images in './data/train\Dog'
There are 2 directories and 0 images in './data/test'
There are 0 directories and 2500 images in './data/test\Cat'
There are 0 directories and 2500 images in './data/test\Dog'
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=0.5, vertical_flip=0.5)  # 构建训练集图像迭代器
test_datagen = ImageDataGenerator(rescale=1. / 255)  # 构建测试集图像迭代器
# 从目录获取图像,并设置类模式
train_dataset = train_datagen.flow_from_directory(train_dir, target_size=(224, 224), batch_size=128, class_mode='binary',
                                                  shuffle=True)
test_dataset = test_datagen.flow_from_directory(test_dir, target_size=(224, 224), batch_size=128, class_mode='binary')
Found 19998 images belonging to 2 classes.
Found 5000 images belonging to 2 classes.


# 查看类的映射
train_dataset.class_indices
{'Cat': 0, 'Dog': 1}
# 设置二分类的评价指标
METRICS = [
    keras.metrics.TruePositives(name='tp'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.TrueNegatives(name='tn'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.BinaryAccuracy(name='accuracy'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.AUC(name='auc'),
    keras.metrics.AUC(name='prc', curve='PR'),
]
def build_model():  # 设置函数来创建模型
    model = keras.Sequential()
    # 使用VGG16作为模型的基本框架,加载在ImageNet上训练的权重,并去除输出层
    base_model = tf.keras.applications.VGG16(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
    base_model.trainable = False # 冻结VGG16的权重,只训练输出层
    model.add(base_model)
    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))  # 二分类,使用sigmoid映射到[0.0,1.0]
    return model
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')  # 获取当前时间来作为写入的路径
# 设置回调函数来写入训练情况
tf_callback = tf.keras.callbacks.TensorBoard(log_dir='./runs/dogvscat_trainer_{}'.format(timestamp))
model = build_model()
model.compile(
    loss=keras.losses.binary_crossentropy,  # 二分类交叉熵损失
    optimizer=keras.optimizers.Adam(lr=1e-3),  # 使用Adam优化器
    metrics=METRICS  # 设置评价函数
)
print(model.summary())  # 输出模型的形状
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
vgg16 (Functional)           (None, 7, 7, 512)         14714688  
_________________________________________________________________
flatten (Flatten)            (None, 25088)             0         
_________________________________________________________________
dense (Dense)                (None, 1)                 25089     
=================================================================
Total params: 14,739,777
Trainable params: 25,089
Non-trainable params: 14,714,688
_________________________________________________________________
None
EPOCHS = 5 # 训练5轮
history = model.fit(
    train_dataset, epochs=5, validation_data=test_dataset, callbacks=[tf_callback]
)
Epoch 1/5
157/157 [==============================] - ETA: 0s - loss: 0.3822 - tp: 8125.0000 - fp: 1558.0000 - tn: 8441.0000 - fn: 1874.0000 - accuracy: 0.8284 - precision: 0.8391 - recall: 0.8126 - auc: 0.9095 - prc: 0.9141

C:\Users\reion\miniconda3\envs\tf2\lib\site-packages\PIL\TiffImagePlugin.py:845: UserWarning: Truncated File Read
  warnings.warn(str(msg))


157/157 [==============================] - 85s 459ms/step - loss: 0.3822 - tp: 8125.0000 - fp: 1558.0000 - tn: 8441.0000 - fn: 1874.0000 - accuracy: 0.8284 - precision: 0.8391 - recall: 0.8126 - auc: 0.9095 - prc: 0.9141 - val_loss: 0.2288 - val_tp: 2142.0000 - val_fp: 131.0000 - val_tn: 2369.0000 - val_fn: 358.0000 - val_accuracy: 0.9022 - val_precision: 0.9424 - val_recall: 0.8568 - val_auc: 0.9735 - val_prc: 0.9729
Epoch 2/5
157/157 [==============================] - 73s 466ms/step - loss: 0.2533 - tp: 8888.0000 - fp: 908.0000 - tn: 9091.0000 - fn: 1111.0000 - accuracy: 0.8990 - precision: 0.9073 - recall: 0.8889 - auc: 0.9628 - prc: 0.9643 - val_loss: 0.1962 - val_tp: 2300.0000 - val_fp: 191.0000 - val_tn: 2309.0000 - val_fn: 200.0000 - val_accuracy: 0.9218 - val_precision: 0.9233 - val_recall: 0.9200 - val_auc: 0.9782 - val_prc: 0.9779
Epoch 3/5
157/157 [==============================] - 75s 478ms/step - loss: 0.2302 - tp: 8969.0000 - fp: 826.0000 - tn: 9173.0000 - fn: 1030.0000 - accuracy: 0.9072 - precision: 0.9157 - recall: 0.8970 - auc: 0.9686 - prc: 0.9694 - val_loss: 0.2223 - val_tp: 2078.0000 - val_fp: 70.0000 - val_tn: 2430.0000 - val_fn: 422.0000 - val_accuracy: 0.9016 - val_precision: 0.9674 - val_recall: 0.8312 - val_auc: 0.9791 - val_prc: 0.9789
Epoch 4/5
157/157 [==============================] - 77s 487ms/step - loss: 0.2148 - tp: 9040.0000 - fp: 786.0000 - tn: 9213.0000 - fn: 959.0000 - accuracy: 0.9127 - precision: 0.9200 - recall: 0.9041 - auc: 0.9723 - prc: 0.9733 - val_loss: 0.1851 - val_tp: 2329.0000 - val_fp: 220.0000 - val_tn: 2280.0000 - val_fn: 171.0000 - val_accuracy: 0.9218 - val_precision: 0.9137 - val_recall: 0.9316 - val_auc: 0.9802 - val_prc: 0.9802
Epoch 5/5
157/157 [==============================] - 78s 493ms/step - loss: 0.1997 - tp: 9144.0000 - fp: 710.0000 - tn: 9289.0000 - fn: 855.0000 - accuracy: 0.9217 - precision: 0.9279 - recall: 0.9145 - auc: 0.9763 - prc: 0.9770 - val_loss: 0.1807 - val_tp: 2271.0000 - val_fp: 147.0000 - val_tn: 2353.0000 - val_fn: 229.0000 - val_accuracy: 0.9248 - val_precision: 0.9392 - val_recall: 0.9084 - val_auc: 0.9806 - val_prc: 0.9806

生成式对抗网络

Ian Goodfellow等人在2014年的论文中提出了生成式对抗网络,尽管这个想法立刻使研究人员们兴奋不已,但还是花了几年时间才克服了训练GAN的一些困难。就像许多伟大的想法一样,事后看起来似乎很简单:让神经网络竞争,希望这种竞争能够促使它们变得更好。GAN由两个神经网络组成

  • 生成器
    以随机分布作为输入(通常是高斯分布),并输出一些数据(通常是图像)。可以将随机输入视为要生成的图像的潜在表征(即编码)。因此可以看到,生成器提供的功能与变分自动编码器中的解码器相同,并且可以使用相同的方式来生成新图片(是需要馈入一些高斯噪声,就会输出一个新图片)
  • 判别器
    输入从生成器得到的伪图像或从训练集中得到的真实图像,并且必须猜测输入图像是伪图像还是真实图像

在训练过程中,生成器和判别器有相反的目标:判别器试图从真实图像中分别出虚假图像,而生成器则试图产生看起来足够真实的图像来欺骗判别器。由于GAN由不同目标的两个网络组成,因此无法像常规神经网络一样对其进行训练。每个训练迭代都分为两个阶段

  • 在第一阶段,训练判别器。从训练集中采样一批真实图像,再加上用生成器生成的相等数量的伪图像组成训练批次。对于伪图像,将标签设置为0;对于真实图像将标签设置为1,并使用二元交叉熵损失在该被标签的批次上对判别器进行训练。重要的是,在这个阶段反向传播只能优化判别器的权重
  • 在第二阶段,训练生成器。首先使用它来生成另一批伪图像,然后再次使用判别器来判断图像是伪图像还是真实图像。在这个批次中不添加真实图像,并且将所有标签都设置为1(真实):换句话说,希望生成器能产生判别器(骗过判别器,让判别器认为是真的)会认为是真实的图像。至关重要的是,在此步骤中,判别器的权重会被固定,因此反向传播只会影响生成器的权重

生成器实际上从未看到过任何真实的图像,但是它会逐渐学会令人信服的伪图像。它所得到的是流经判别器的回流梯度。幸运的是,判别器越好,这些二手梯度中包含的真实图像信息越多,因此生成器可以取得很大的进步

现在给Fashion MNIST构建一个简单的GAN

首先,需要构建生成器和判别器。生成器类似于自动编码器的解码器,判别器是常规的二元分类器(它以图像作为输入,包含单个神经元和使用sigmoid激活函数的Dense层)。对于每个训练迭代的第二阶段,还需要一个完整的GAN模型,其中包含生成器,后面跟随一个判别器:

import tensorflow as tf
from tensorflow import keras

codings_size = 30
generator = keras.models.Sequential([
    keras.layers.Dense(100, activation='gelu', input_shape=[codings_size]),
    keras.layers.Dense(150, activation='gelu'),
    keras.layers.Dense(28 * 28, activation='sigmoid'),
    keras.layers.Reshape([28, 28])
])
discriminator = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28]),
    keras.layers.Dense(150, activation='gelu'),
    keras.layers.Dense(100, activation='gelu'),
    keras.layers.Dense(1, activation='sigmoid')
])
gan = keras.models.Sequential([generator, discriminator])

接下来,需要编译模型。由于判别器是二元分类器,可以使用二元交叉熵损失。生成器仅通过gan模型进行训练,因此不需要对其进行编译。gan模型也是二元分类器,因此可以使用二元交叉熵损失。重要的是,判别器不应该在第二阶段进行训练,因此在训练gan模型之前,将其设为不可训练:

discriminator.compile(loss='binary_crossentropy', optimizer='rmsprop')
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='rmsprop')

Keras仅在编译模型时才考虑可训练属性,因此在运行此代码后,如果调用其fit()方法或其train_on_batch()方法,则判别器是可以训练的;当在gan模型上调用这些方法时,则判别器是不可训练的

由于训练循环不寻常,因此不能使用常规的fit()方法,相反,需要写一个自定义训练循环。为此,首先需要创建一个数据集来遍历数据集

fashion_mnist = keras.datasets.fashion_mnist
(X_train_all, y_train_all), (X_test, y_test) = fashion_mnist.load_data()
X_train_all, y_train_all, X_test, y_test = tf.cast(X_train_all, tf.float32), tf.cast(y_train_all, tf.float32), tf.cast(
    X_test, tf.float32), tf.cast(y_test, tf.float32)
batch_size = 32
dataset = tf.data.Dataset.from_tensor_slices(X_train_all / 255.).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

现在编写训练循环,将其包装在train_gan()函数中

def train_gan(gan, dataset, batch_size, codings_size, n_epochs=50):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        for X_batch in dataset:
            noise = tf.random.normal(shape=[batch_size, codings_size])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)


train_gan(gan, dataset, batch_size, codings_size)
coding = tf.random.normal(shape=[12, codings_size])
images = generator(coding).numpy()
import matplotlib.pyplot as plt


def plot_image(image):
    plt.imshow(image, cmap='binary')
    plt.axis('off')


fig = plt.figure(figsize=(12 * 1.5, 3))
for image_index in range(12):
    plt.subplot(3, 4, image_index + 1)
    plot_image(images[image_index])
  • 在第一阶段,将高斯噪声馈送到生成器以生成伪图像,然后通过合并相等数量的真实图像来组成这一批次,对于伪图像,目标y1被设置为0;对于真实图像,目标y1被设置为1。然后在这个批次上进行判别其训练。这里将判别器的1可训练属性设置为True:这只是为了消除Keras注意到在编译模型时可训练属性为True,但是现在为False时显示的警告
  • 在第二阶段,向GAN馈入一些高斯噪声。它的生成器首先会生成伪图像,然后判别器会试图猜测这些图像是伪图像还是真实图像。我们希望判别器相信伪图像是真实的,因此将目标y2设置为1。再次将可训练属性设置为False,以避免再次发出警告

GAN的训练难点

在训练过程中,生成器和判别器在零和游戏中不断地试图超越彼此。随着训练的进行,游戏可能会在一种博弈论者称为纳什均衡的状态中结束,这种均衡以数学家约翰·纳什的名字命名:在这种情况下,假设其他玩家都没有改变他们的策略,那么没有人会改变自己的策略使自己变得更好。例如,当每个人都在道路左侧行驶时,达到了纳什均衡:当么给人都在道路的右侧行驶时。不同的初始状态和动态发展可能会导致一个平衡或另一个平衡。在此示例中,一旦达到平衡(即与其他所有人在同一侧行驶),就存在一个最佳策略,但是一个纳什均衡可能涉及多种竞争策略(例如,捕食者追捕猎物,猎物试图逃避,改变它们的策略也不会变得更好)

那么如果适用于GAN,论文的作者证明了GAN只能达到单个纳什均衡:当生成器产生完美逼真的图像时,判别器只能被迫猜测(50%真实,50%伪造)。这个事实非常令人鼓舞:似乎只需要训练GAN足够长的时间,它最终就会达到均衡,从而给出一个完美的生成器。不幸的是,并不是那么简单:没有任何东西可以保证达到均衡

最大的困难被称为模式崩溃:就是当生成器的输入逐渐变得不太多样化时。这是怎么发生的?假设生成器在产生逼真的鞋子方面比其他任何类都更好。他就会用鞋子来更多地欺骗判别器,这会鼓励它生成更多地鞋子图像。逐渐地它会忘记如何产生其他任何东西。同时判别器看到的唯一伪造图像将是鞋子,因此它也会忘记如何辨别其他类别的伪造图像。最终当判别器在进行假鞋和真鞋区分时,生成器将被迫转移到另一类。这样以来,它可能会变得擅长于生成衬衫,而忘记了鞋子,然后判别器也会跟着生成器。GAN可能会逐渐在几个类别中循环,而不会擅长于生成任何一个类别

此外由于生成器和判别器不断地相互竞争,因此它们的参数可能最终会振荡并开始变得不稳定。训练开始时可能会很好,然后由于这些不稳定而突然发散,没有明显的原因。而且有许多因素会影响这些复杂的动态过程,因此GAN对超参数非常敏感:可能不得不花费大量的精力来微调它们

自2014年以来,研究人员忙于解决这些问题:针对该问题发表了许多论文,其中一些提出了新的成本函数(尽管Google研究人员在2018年发表了一篇论文质疑其效率)或技术来解决训练稳定性或避免模式崩溃的问题。例如,一种称为重播体验的流行技术包括将生成器在每次迭代中生成的图像存储在重播缓存区中(逐渐删除较早生成的图像),使用真实图像以及从该缓冲区中取出的伪图像来训练判别器(而不是由当前生成器生成的伪图像)。这减少了判别器过拟合最新生成输出的图像的机会。另一种常见的技术称为小批量判别:它可测量跨批次中相似图像的程度,并将此统计信息提供给判别器,因此判别器可以轻松拒绝缺乏多样性的一整个批次的伪图像。这会鼓励生成器生成更多样的图像,从而减少模式崩溃。其他论文只是提出了一些表现良好的特定网络架构

深度卷积GAN

2014年的GAN原始论文试验了卷积层,但只是生成了小图像。不久之后,许多研究人员试图基于更深的卷积网络为更大的图像构建GAN。由于训练非常不稳定,所以被证明是棘手的,但是Alec Radford等人在实验了许多不同的架构和超参数后,终于在2015年末取得了成功。他们称其架构为深度卷积GAN(DCGAN)。以下是他们为构建稳定的卷积GAN提出的主要指导

  • 用跨步卷积(在判别器中)和转置卷积(在生成器中)替换所有池化层
  • 除生成器的输出层和判别器的输入层,在生成器和判别器中都使用批量归一化
  • 删除全连接的隐藏层以获得更深的架构
  • 对生成器中所有层使用ReLU激活函数,除了输出层应该使用tanh
  • 对判别器中所有层使用leaky ReLU激活函数

这些准则在许多情况下都会起作用,但并非总是如此,因此可能需要试验不同的超参数(实际上,仅仅更改随机种子并再次训练相同的模型有时会起作用),下面是一个小型DCGAN,在Fashion MNIST数据集上可以很好地工作:

codings_size = 100
generator = keras.models.Sequential([
    keras.layers.Dense(7 * 7 * 128, input_shape=[codings_size]),
    keras.layers.Reshape([7, 7, 128]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding='same', activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding='same', activation='tanh')
])
discriminator = keras.models.Sequential([
    keras.layers.Conv2D(64, kernel_size=5, strides=2, padding='same', activation=keras.layers.LeakyReLU(.2),
                        input_shape=[28, 28, 1]),
    keras.layers.Dropout(.4),
    keras.layers.Conv2D(128, kernel_size=5, strides=2, padding='same', activation=keras.layers.LeakyReLU(.2)),
    keras.layers.Dropout(.4),
    keras.layers.Flatten(),
    keras.layers.Dense(1, activation='sigmoid')
])
gan = keras.models.Sequential([generator, discriminator])

生成器使用大小为100的编码,将其投影到6272维度(77128),对结果进行重构以获得77128张量。该张量被批量归一化后,馈入步幅为2的转置卷积层层,将其从77上采样至1414,将深度从128缩小至64.其结果在此被批量归一化,并馈入另一个步幅为2的转置卷积层,将其从1414上采样到2828,将深度从64减少到1.该层使用tanh激活函数,因此输入范围为-1。因此在训练GAN之前,需要将训练集重新按比例调整为相同的范围。还需要重构形状来添加通道维度:

X_train_all = tf.reshape(X_train_all, [-1, 28, 28, 1]) * 2. - 1.
discriminator.compile(loss='binary_crossentropy', optimizer='rmsprop')
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='rmsprop')
dataset = tf.data.Dataset.from_tensor_slices(X_train_all).shuffle(1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
train_gan(gan, dataset, batch_size, codings_size)
coding=tf.random.normal(shape=[batch_size, codings_size])
images = generator(coding).numpy()
fig = plt.figure(figsize=(12, 8))
for image_index in range(batch_size):
    plt.subplot(4, 8, image_index + 1)
    plot_image(images[image_index])

GAN的逐步增长

Nvidia的研究人员Tero Karras等人在2018年的一篇论文中提出了一项重要技术:他们建议在训练开始时生成小图像,然后向生成器和判别器中逐渐添加卷积层以生成越来越大的图像。这种方法类似于堆叠式自动编码器的贪婪分层训练。额外的层添加在生成的末尾和判别器的开始处,并且先前训练过的层仍是可训练的

例如,当生成器的输出从44增长到88时,一个上采样层(使用最近邻滤波)被添加到现有的卷积层中,因此它输出88特征图,然后将其馈送到新的卷积层(使用‘same’填充和1的步幅,因此其输出也为88)。这个新层之后是新的输出卷积层:这是内核为1的常规卷积层,它将输出向下投影到所需数量的颜色通道。为了避免在添加新卷积层时破坏第一个卷积层的训练权重,最终输出是原始输出层和新输出层的加权和。新输出的权重为a,而原始输出的权重是1-a,并且a从0缓慢增加到1.换句话说,新的卷积层逐渐增强,而原始输出层逐渐减弱。在新的卷积层添加到判别器时,使用类似的增强/减弱技术

这篇论文还介绍了其他几种旨在增加输出多样性(避免模式崩溃)和使训练稳定的技术:

  • 小批次标准化层
    在判别器末尾附近添加。对于输入中的每个位置,它计算批次中所有通道和所有实例的标准差(S=tf.math.reduce_std(inputs,axis=[0,-1]))。然后,将这些标准差在所有点上取平均值,得到一个单一值(v=tf.reduce_mean(S))。最后,向批量处理中的每个实例中添加一个额外的特征图,并用计算值填充(tf.concat([inputs,tf.fill([batch_size,height,width,1],v)],axis=-1))。如果生成器生成的图像变化不大,那么判别器中特张图上的标准差就会小很多。多亏了这一层,判别器就可以轻松访问此统计信息,从而减少被生成器(生成很少的多样性)欺骗的可能性。这会鼓励生成器产生不同的输出,从而降低模式崩溃的风险。
  • 均衡的学习率
    使用均值为0和标准差为1的简单高斯分布而不是使用He初始化来初始化所有权重。但是,权重会在运行时(即每次执行层时)按与He初始化相同的系数来缩小:权重除以$\sqrt{2/n_{inputs}}$,其中$n_{inputs}$是层输入的数量。论文表明,当使用RMSProp、Adam或其他自适应梯度优化器时,该技术显著提高了GAN的性能。实际上,这些优化器通过估计的标准差对梯度更新进行归一化,因此动态范围较大的参数需要花费较长的训练时间,而动态范围较小的参数可能更新得太快,从而导致不稳定。通过将权重调整作为模型本身的一部分,而不仅仅是在初始化时进行权重调整,这种方法可确保在整个训练过程中,所有参数的动态范围都相同,因此它们都以相同的速度学习。这加快和稳定了训练过程
  • 像素归一化层
    在生成器中的每个卷积层之后添加。它基于同一图像中和同一位置但所有通道之间的所有激活对每个激活进行归一化(除以均方激活的平方根),在TensorFlow代码中,这是inputs/tf.sqrt(tf.reduce_mean(tf.square(X),axis=-1,keepdims=True)+1e-8)(需要平滑项来避免被零除)。该技术避免了由于生成器和判别器之间过度竞争而导致的激活爆炸

StyleGAN

相同的Nvidia团队在2018年发表的一篇论文中提出了高分辨率图像生成的最新技术,该论文介绍了流行的StyleGAN架构。作者在生成器中使用了风格转换技术,以确保生成的图像在各个尺度上都具有与训练图像相同的局部结构,从而极大地提高了所生成的图像质量。判别器和损失函数没有被修改,仅仅修改了生成器

  • 映射网络
    一个8层的MLP把潜在表示$z$(即编码)映射到向量$w$。然后,通过多个仿射变换(没有激活的Dense层)发送此向量,从而生成多个向量。这些向量从细粒度的纹理(例如头发的颜色)到高级特征(例如成人或儿童)在不同级别上控制生成图像的风格。简而言之,映射网络映射到多个风格向量
  • 合成网络
    负责生成图像。它具有恒定的学习输入(需要明确的是,此输入在训练后将保持不变,但是在训练过程中,它会经过反向传播不断调整)。如前所述,它通过多个卷积和上采样层处理此输入,但是有两处调整:首先,在卷积层的输入和所有输出中添加了一些噪声(在激活函数之前)。其次,每个噪声层后面是一个自适应实例归一化层(AdaIN):它独立地归一化每个特征图(通过减去特征图的均值并处以其标准差),然后使用风格向量确定每个向量图的比例和偏移量(风格向量为每个特征图包含一个比例和一个偏置项)

独立于编码来增加噪声的想法非常重要。图像的某些部分非常随机,例如每个雀斑或头发的确切位置。在较早的GAN中,这种随机性要么来自编码,要么是生成器自身产生的一些伪随机噪声。如果它来自编码,则意味着生成器使用了编码的表征力的很大一部分来存储噪声:这非常浪费。而且,噪声必须能够流经网络并达到生成器的最后一层:这似乎是不必要的约束,可能会减慢训练速度。最后,可能会出现一些人工视觉,因为在不同层次使用了相同的噪声。相反,如果生成器试图自己产生自己的伪随机噪声,该噪声可能看起来不那么令人信服,从而导致更多人工视觉。另外,生成器的权重的一部分用于产生伪随机噪声,这似乎又是浪费的。通过增加额外的噪声输入,可以避免所有这些问题。GAN能够使用所提供的噪声为图像的每个部分添加适当数量的随机性

每个级别增加的噪声都不相同。每个噪声输入由一个充满了高斯噪声的单个特征图组成,该噪声会广播到所有(给定级别的)的特征图,并在添加之前使用学习到的每个特征图的比例因子进行缩放

最后,StyleGAN使用一种称为混和正则化(或风格混合)的技术,使用两种不同的编码生成一定百分比的生成图像。具体来说,编码$c_1$和$c_2$通过映射网络发送,给出两个风格向量$w_1$和$w_2$。然后,合成网络基于第一个级别的风格$w_1$和其余级别的样式$w_2$生成图像。阶段级别是随机选择的。这可以防止网络假设相邻级别的风格是相关联的,这反过来又鼓励了GAN中的局部性,意味着每个风格向量仅影响所生成图像中有限数量的特征

变分自动编码器

Diederik Kingma和Max Welling于2013年推出了自动编码器的另一个重要类别,并迅速成为最受欢迎的自动编码器类型之一:变分自动编码器

它们与目前为止的自动编码器有很大的不同,它们具有以下特殊的地方:

  • 它们是概率自动编码器,这意味着即使在训练后,它们的输出会部分由概率决定(与仅在训练期间使用随机性的去噪自动编码器相反)
  • 它们是生成式自动编码器,这意味着它们可以生成看起来像是从训练集中采样的新实例

这两个属性使得它们与RBM相当类似,但是它们更容易训练,并且采样过程要快得多(使用RMB,需要等到网络稳定到“热平衡”,然后才能采样新实例)。变分自动编码器执行变分贝叶斯推理,这是执行近似贝叶斯推理的有效方法

变分自动编码器不是直接为给定输入生成编码,而是编码器产生平均编码$\mu$和标准差$\sigma$。然后实际编码是从均值$\mu$和标准差$\sigma$的高斯分布中随机采样的。之后解码器正常解码采样到的编码

在训练过程中,成本函数会迫使编码逐渐地在编码空间中内移动,最终看起来像高斯点云。一个很好的结果是,在训练了变分自动编码器之后,可以轻松地生成一个新实例:只需从高斯分布中采样一个随机编码,对其进行解码,然后就伪造出来了

成本函数由两部分组成:第一部分是通常的重构损失,它会迫使自动编码器重现其输入。第二个是潜在损失,它使自动编码器的编码看起像是从简单地高斯分布中采样得到的:它是目标分布(高斯分布)和编码的实际分布之间的KL散度。在数学上比稀疏自动编码器要复杂一些,特别是由于高斯噪声,它限制了可以传输到编码层的信息量(因此迫使自动编码器学习有用的特征)。

变分自动编码器的潜在损失$\mathcal{L}=-\frac12\sum_{i=1}^K1+\log{(\sigma_i^2)}-\sigma_i^2-\mu_i^2$
在这个等式中,$\mathcal{L}$是潜在损失,$n$是编码的维度,$\mu_i$和$\sigma_i$是编码中第$i$个分量的均值和标准差。向量$\mu$和$\sigma$(包含所有$\mu_i$和$\sigma_i$)由编码器输出

变分自动编码器架构的通常调整是使编码器输出$\gamma=\log(\sigma^2)$而不是$\sigma$。然后如下列公式计算潜在损失,这种方法在数值上更稳定,而且可以加快训练速度$$\mathcal{L}=-\frac12\sum_{i=1}^K1+\gamma_i-\exp(\gamma_i)-\mu_i^2$$

下面为Fashion MNIST构建一个变分自动编码器,使用$gamma$调整,首先,给定$\mu$和$\gamma$,需要定义一个自定义层来采样编码

from tensorflow import keras
import tensorflow as tf

K = keras.backend


class Sampling(keras.layers.Layer):
    def call(self, inputs):
        mean, log_var = inputs
        return K.random_normal(tf.shape(log_var)) * K.exp(log_var / 2)

Sampling层接受两个输入:mean$(\mu)$和log_var$(\gamma)$,它使用函数K.random_normal()从正态分布中采样一个均值为0和标准差为1的随机向量(与$\gamma$形状相同),然后将其乘以$\exp(\gamma/2)$(等于$\sigma$),最后将$\mu$加起来并返回结果。该方法从均值$\mu$和标准差$\sigma$的正态分布中采样一个编码向量

接下来使用函数API来创建编码器,因为模型不是完全顺序的:

codings_size = 10

inputs = keras.layers.Input(shape=[28, 28])
z = keras.layers.Flatten()(inputs)
z = keras.layers.Dense(150, activation='gelu')(z)
z = keras.layers.Dense(100, activation='gelu')(z)
codings_mean = keras.layers.Dense(codings_size)(z)
codings_log_var = keras.layers.Dense(codings_size)(z)
codings = Sampling()([codings_mean, codings_log_var])
variational_encoder = keras.Model(inputs=[inputs], outputs=[codings_mean, codings_log_var, codings])

输出codings_mean$(\mu)和codings_log_var$(\gamma)$的Dense层具有相同的形状(第二个Dense输出)。将codings_mean和codings_log_var都传递给Sampling层。最后,如果要检查codings_mean和codings_log_var的值,variational_encoder模型具有三个输出,需要使用的是最后一个codings,现在开始构建解码器:

decoder_inputs = keras.layers.Input(shape=[codings_size])
x = keras.layers.Dense(100, activation='gelu')(decoder_inputs)
x = keras.layers.Dense(150, activation='gelu')(x)
x = keras.layers.Dense(28 * 28, activation='sigmoid')(x)
outputs = keras.layers.Reshape([28, 28])(x)
variational_decoder = keras.Model(inputs=[decoder_inputs], outputs=[outputs])

对于此编码器,可以使用顺序API而不是函数式API,因为它实际上只是一个简单的层堆栈。最后,建立变分自动编码器模型:

_, _, codings = variational_encoder(inputs)
reconstructions = variational_decoder(codings)
variational_ae = keras.Model(inputs=[inputs], outputs=[reconstructions])

最后,必须加上潜在损失和重构损失

latent_loss = -0.5 * K.sum(
    1 + codings_log_var - K.exp(codings_log_var) - K.square(codings_mean), axis=-1
)
variational_ae.add_loss(K.mean(latent_loss) / 784.)
variational_ae.compile(loss='binary_crossentropy', optimizer='rmsprop')

首先应用公式来计算该批次中每个实例的潜在损失(在最后一个轴求和)。然后,计算该批次中所有实例的平均损失,将结果除以784,以确保它和重构损失相比具有合适的比例标度。变分自动编码器的重建损失应该是像素重建误差的总和,但是当Keras计算“binary_crossentropy”损失时,它计算所有784个像素的均值,而不是总和。因此重构损失比需要的少784倍。可以定义一个损失来计算总和而不是平均值,但是把潜在损失除以784更为简单(最终损失要比其应该的小784倍,但这只是意味着需要使用更大一点的学习率)

在这里使用RMSProp优化器,该优化器在这个示例下效果很好,下面训练自动编码器

fashion_mnist = keras.datasets.fashion_mnist
(X_train_all, y_train_all), (X_test, y_test) = fashion_mnist.load_data()
X_valid, X_train = X_train_all[:5000] / 255., X_train_all[5000:] / 255.
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
history = variational_ae.fit(X_train, X_train, epochs=50, batch_size=32, validation_data=(X_valid, X_valid))
Epoch 1/50
1719/1719 [==============================] - 13s 7ms/step - loss: 0.4348 - val_loss: 0.3956
Epoch 2/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3915 - val_loss: 0.3833
Epoch 3/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3822 - val_loss: 0.3752
Epoch 4/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3746 - val_loss: 0.3687
Epoch 5/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3699 - val_loss: 0.3657
Epoch 6/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3672 - val_loss: 0.3646
Epoch 7/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3652 - val_loss: 0.3626
Epoch 8/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3638 - val_loss: 0.3607
Epoch 9/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3627 - val_loss: 0.3598
Epoch 10/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3619 - val_loss: 0.3584
Epoch 11/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3610 - val_loss: 0.3578
Epoch 12/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3603 - val_loss: 0.3577
Epoch 13/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3597 - val_loss: 0.3578
Epoch 14/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3592 - val_loss: 0.3555
Epoch 15/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3587 - val_loss: 0.3564
Epoch 16/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3584 - val_loss: 0.3558
Epoch 17/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3581 - val_loss: 0.3564
Epoch 18/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3581 - val_loss: 0.3561
Epoch 19/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3578 - val_loss: 0.3560
Epoch 20/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3576 - val_loss: 0.3542
Epoch 21/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3574 - val_loss: 0.3548
Epoch 22/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3571 - val_loss: 0.3552
Epoch 23/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3569 - val_loss: 0.3545
Epoch 24/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3567 - val_loss: 0.3555
Epoch 25/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3566 - val_loss: 0.3543
Epoch 26/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3562 - val_loss: 0.3536
Epoch 27/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3561 - val_loss: 0.3563
Epoch 28/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3561 - val_loss: 0.3550
Epoch 29/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3560 - val_loss: 0.3539
Epoch 30/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3562 - val_loss: 0.3539
Epoch 31/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3559 - val_loss: 0.3537
Epoch 32/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3555 - val_loss: 0.3533
Epoch 33/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3555 - val_loss: 0.3524
Epoch 34/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3552 - val_loss: 0.3541
Epoch 35/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3552 - val_loss: 0.3524
Epoch 36/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3552 - val_loss: 0.3531
Epoch 37/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3550 - val_loss: 0.3525
Epoch 38/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3550 - val_loss: 0.3531
Epoch 39/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3550 - val_loss: 0.3532
Epoch 40/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3550 - val_loss: 0.3517
Epoch 41/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3546 - val_loss: 0.3525
Epoch 42/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3546 - val_loss: 0.3523
Epoch 43/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3547 - val_loss: 0.3510
Epoch 44/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3545 - val_loss: 0.3543
Epoch 45/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3546 - val_loss: 0.3532
Epoch 46/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3547 - val_loss: 0.3513
Epoch 47/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3549 - val_loss: 0.3523
Epoch 48/50
1719/1719 [==============================] - 11s 6ms/step - loss: 0.3546 - val_loss: 0.3519
Epoch 49/50
1719/1719 [==============================] - 12s 7ms/step - loss: 0.3549 - val_loss: 0.3510
Epoch 50/50
1719/1719 [==============================] - 11s 7ms/step - loss: 0.3547 - val_loss: 0.3517

生成Fashion MNIST图像

使用变分自动编码器来生成看起来时尚的物品。需要做的就是从高斯分布中采样随机编码并对它们进行解码

coding = tf.random.normal(shape=[12, codings_size])
images = variational_decoder(coding).numpy()
import matplotlib.pyplot as plt


def plot_image(image):
    plt.imshow(image, cmap='binary')
    plt.axis('off')


fig = plt.figure(figsize=(12 * 1.5, 3))
for image_index in range(12):
    plt.subplot(3, 4, image_index + 1)
    plot_image(images[image_index])


可变自动编码器使得语义插值成为可能:可以在编码级别进行插值,而不是在像素级别插值两个图像(看起来好像两个图像被叠加了一样)。首先让两个图像通过编码器,然后对获得的两个编码进行插值,最后对插值的编码进行解码来获得最终图像。它看起来像是常规的Fashion MNIST图像,但它是原始图像之间的中间图像,在下面的代码示例,使用刚刚生成的12个编码器,把它们组织在$3\times4$网格中,然后使用TensorFlow的tf.image.resize()函数将该网格的大小调整为$5\times7$。默认情况下,resize()函数会执行双线性插值,因此每隔一行和一列会包含插值编码。然后,使用解码器生成所有图像:

codings_grid = tf.reshape(coding, [1, 3, 4, codings_size])
larger_grid = tf.image.resize(codings_grid, size=[5, 7])
interpolated_codings = tf.reshape(larger_grid, [-1, codings_size])
images = variational_decoder(interpolated_codings).numpy()
fig = plt.figure(figsize=(6 * 1.5, 6))
for image_index in range(35):
    plt.subplot(5, 7, image_index + 1)
    plot_image(images[image_index])