Как стать автором
Обновить
859.75
OTUS
Цифровые навыки от ведущих экспертов

Создаем GAN с помощью PyTorch

Время на прочтение8 мин
Количество просмотров16K
Автор оригинала: Ta-Ying Cheng

Реалистичные изображения из ничего?

Генеративно-состязательные сети (Generative Adversarial Networks — GAN), предложенные Goodfellow и др. в 2014 году, произвели революцию в области создания изображений в компьютерном зрении — никто не мог поверить, что эти потрясающие живые изображения на самом деле создаются машинами с нуля. И даже больше — люди раньше думали, что задача генерации невозможна, и были поражены мощью GAN, потому что традиционно в этой области просто не существует каких-либо эталонных данных, с которыми мы могли бы сравнить наши сгенерированные изображения.

В этой статье представлена ​​простая идея, лежащая в основе создания GAN, за которой следует реализация сверточной GAN с помощью PyTorch и процедура ее обучения.

Идея, лежащая в основе GAN

В отличие от традиционной задачи классификации, где прогнозы нашей сети можно напрямую сравнить с правильным ответом из эталонных данных, «правильность» сгенерированного изображения трудно определить и измерить. Goodfellow и др. в своей оригинальной статье Generative Adversarial Networks, предложили интересную идею: использовать хорошо обученный классификатор, чтобы различать сгенерированное изображение и реальное изображение. Если у нас есть такой классификатор, мы можем создать и обучить сеть-генератор, пока она не сможет производить изображения, которые могут полностью обмануть классификатор.

Рисунок 1. Конвейер GAN. Изображение создано автором.
Рисунок 1. Конвейер GAN. Изображение создано автором.

GAN является продуктом этой процедуры: она содержит генератор, который генерирует изображение на основе заданного набора данных, и дискриминатор (классификатор), чтобы различать, является ли изображение реальным или сгенерированным. Разбор конвейера GAN можно увидеть на рисунке 1.

Функция потерь

Оптимизировать одновременно и генератор, и дискриминатор сложно, потому что, как вы могли догадаться, две сети преследуют совершенно противоположные цели: генератор хочет создать что-то как можно более реалистичное, а дискриминатор хочет различать сгенерированные материалы.

Чтобы проиллюстрировать это, пусть D(x) будет выходом дискриминатора, который представляет собой вероятность того, что x является реальным изображением, а G(z) будет выходом нашего генератора. Дискриминатор аналогичен бинарному классификатору, поэтому цель дискриминатора — максимизировать функцию:

По сути, это бинарная перекрестная потеря энтропии без отрицательного знака в начале. С другой стороны, целью генератора было бы минимизировать шансы дискриминатора сделать правильное определение, поэтому его целью было бы минимизировать функцию. Следовательно, окончательная функция потерь будет минимаксной игрой между двумя классификаторами, которую можно проиллюстрировать следующим образом:

которая теоретически сходится к дискриминатору, предсказывающему все с вероятностью 0,5.

Однако на практике минимаксная игра часто приводит к тому, что сеть не сходится, поэтому важно тщательно настроить процесс обучения. Гиперпараметры, такие как скорость обучения, значительно более важны при обучении GAN — небольшие изменения могут привести к тому, что GAN будет генерировать один и тот же выходной сигнал независимо от входных шумов.

Вычислительное окружение

Библиотеки

Вся программа построена с помощью библиотеки PyTorch (включая torchvision). Визуализация результатов, сгенерированных GAN, строится с использованием библиотеки Matplotlib. Следующий код импортирует все библиотеки:

"""
Импортируем необходимые библиотеки для создания генеративно-состязательной сети
Код разработан в основном с использованием библиотеки PyTorch
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt

Наборы данных

Наборы данных являются важным аспектом при обучении GAN. Неструктурированная природа изображений подразумевает, что любой данный класс (например, собаки, кошки или рукописные цифры) может иметь распределение возможных данных, и такое распределение в конечном итоге является основой контента, генерируемого GAN.

Для демонстрации GAN в этой статье будет использоваться простейший набор данных MNIST, который содержит 60000 изображений рукописных цифр от 0 до 9. Неструктурированные наборы данных, такие как MNIST, можно найти на Graviti. Это достаточно молодой стартап, который стремится помочь сообществу с неструктурированными наборами данных, и на их платформе уже можно найти одни из лучших общедоступных неструктурированных наборов данных, включая MNIST.

Требования к оборудованию

Обучать нейронную сеть предпочтительнее на графических процессорах, так как они значительно увеличивают скорость обучения. Однако, даже если доступны только обычные процессоры, вы все равно можете протестировать программу. Чтобы позволить вашей программе самой определить оборудование, просто используйте следующее:

"""
Определяем, доступны ли какие-либо графические процессоры
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Реализация

Архитектура сети

Из-за простоты чисел две архитектуры — дискриминатор и генератор — построены из полностью связанных слоев. Обратите внимание, что полностью связанная GAN также иногда легче сходится, чем DCGAN.

Ниже приведены реализации PyTorch обеих архитектур:

"""
Сетевые архитектуры
Ниже приведены архитектуры дискриминатора и генератора
"""
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()
    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, 1, 28, 28)
        return nn.Tanh()(x)

Обучение

При обучении GAN мы оптимизируем результаты дискриминатора и, в то же время, улучшаем наш генератор. Следовательно, будут две потери, которые противоречат друг другу во время каждой итерации их одновременной оптимизации. То, что мы вводим в генератор, — это случайные шумы, а генератор предположительно должен создавать изображения на основе незначительных различий данных шумов:

"""
Процедура обучения сети.
Каждый шаг потери обновляется как для дискиминатора, так и для генератора.
Дискриминатор стремится классифицировать реальные и fakes
Генератор стремится генерировать как можно более реалистичные изображения
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1
        # Обучаем дискриминатор
        # real_inputs - изображения из набора данных MNIST 
        # fake_inputs - изображения от генератора
        # real_inputs должны быть классифицированы как 1, а fake_inputs - как 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)
        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)
        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)
        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()
        # Обучаем генератор
        # Цель генератора получить от дискриминатора 1 по всем изображениям
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()
        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))
    if (epoch+1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')

Результаты

После 100 эпох мы можем построить наборы данных и увидеть результаты — цифры сгенерированные из случайных шумов:

Рисунок 2. Результаты, генерируемые GAN. Изображение создано автором.
Рисунок 2. Результаты, генерируемые GAN. Изображение создано автором.

Как вы можете увидеть выше, полученные результаты действительно похожи на настоящие. Учитывая, что сети довольно просты, результаты действительно кажутся многообещающими!

Помимо создания контента 

GAN сильно отличалось от предыдущих работ в области компьютерного зрения. Последовавшие за этим многочисленные применения удивили академическое сообщество тем, на что способны глубокие сети. Некоторые удивительные работы описаны ниже.

CycleGAN

CycleGAN, автор Zhu и др. вводит концепцию, которая переводит изображение из домена X в домен Y без необходимости парных выборок. Результаты CycleGAN были удивительными и точными: лошади превратились в зебр, а летнее солнце превратилось в снежную бурю.

Рисунок 3. Результаты CycleGAN, представленные Zhu и др. Изображение получено с их страницы на Github.
Рисунок 3. Результаты CycleGAN, представленные Zhu и др. Изображение получено с их страницы на Github.

GauGAN

Nvidia применила GAN для преобразования примитивных рисунков в элегантные и реалистичные фотографии на основе семантики кистей. Хотя обучающие ресурсы были дорогостоящими в вычислительном отношении, они создают совершенно новую область исследований и применений.

Рисунок 3. Результаты сгенерированные GaoGAN. Слева - исходный рисунок, справа - сгенерированный результат. Изображение создано автором.
Рисунок 3. Результаты сгенерированные GaoGAN. Слева - исходный рисунок, справа - сгенерированный результат. Изображение создано автором.

AdvGAN

Также существует расширение генеративно-состязательных сетей, которое очищает состязательные изображения и преобразовывает их в чистые примеры, которые не обманывают классификации. Более подробную информацию о состязательных атаках и защите можно найти здесь.

Заключение

На этом все! Надеюсь, в этой статье представлен вполне достаточный обзор для понимания того, как самостоятельно построить GAN. Полную реализацию можно найти в этом Github репозитории: https://github.com/ttchengab/MnistGAN


Материал подготовлен в рамках курса "Компьютерное зрение". Если вам интересно узнать больше о формате обучения и программе, познакомиться с преподавателем — приглашаем на день открытых дверей онлайн. Регистрация здесь.

Теги:
Хабы:
+10
Комментарии1

Публикации

Информация

Сайт
otus.ru
Дата регистрации
Дата основания
Численность
101–200 человек
Местоположение
Россия
Представитель
OTUS