当前位置: 代码迷 >> 综合 >> TensorFlow 14——ch11-CycleGAN 与 非配对图像转换
  详细解决方案

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

热度:100   发布时间:2023-09-26 21:25:18.0

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换
代码:https://github.com/MONI-JUAN/Tensorflow_Study/tree/master/TensorFlowch11-CycleGAN与非配对图像转换

数据集 百度云链接:https://pan.baidu.com/s/1U78OTOAi0bUuJ-bGhtWZ5A
提取码:1r1f

目录

    • 一、CycleGAN 原理
      • 1.损失函数
      • 2.定义
        • 模型
        • 循环一致性损失
        • 生成器损失
        • 判别器损失
        • 调用损失
    • 二、苹果到橘子
      • 1.下载数据集
      • 2.转换成 tfrecords 格式
      • 3.训练模型
      • 4.查看训练情况
      • 5.导出模型
      • 6.测试模型
    • 三、男性到女性
      • 1.下载数据集
      • 2.转换成 tfrecords 格式
      • 3.训练模型
      • 4.查看训练情况
    • 四、书上的一些效果展示

一、CycleGAN 原理

1.损失函数

CycleGAN 与 pix2pix 的不同点在于,可以利用不成对数据训练出从 X 空间到 Y 空间的映射。例如用大量照片和油画图片可以学习到如何把照片转成油画。

算法的目标是学习从空间 X 到空间 Y 的映射,设这个映射为 F。 对应着 GAN 中的生成器, F可以将 X 中的图片 x 转为 Y 中的图片 F(x)。对于生成的图片,还需要 GAN 中的判别器来判别它是否为真实图片,由此构成对抗生成网络。设这个判别器为Dy。根据生成器和判别器可以构造 GAN 的损失了 ,该损失和原始 GAN 中的损失的形式是相同的:
LGAN(F,Dy,X,Y)=Ey?Pdata(y)[ln?DY(y)]+Ex?pdata(x)[ln?(1?DY(F(x)))]L_{GAN}(F,D_{y},X,Y)=E_{y \sim P_{\text {data}}(y)}[\ln D_{Y}(y)]+E_{x \sim p_{\text {data}}(x)}[\ln (1-D_{Y}(F(x)))] LGAN?(F,Dy?,X,Y)=Ey?Pdata?(y)?[lnDY?(y)]+Ex?pdata?(x)?[ln(1?DY?(F(x)))]
但只使用这一个损失是无法进行训练的。原因在于没再成对数据,映射 F可以将所有 x 都映射为 Y 空间中的同一张图片,使损失无效化。对此,作者又提出了所谓的“循环一致性损失”( cycle consistency loss )。让再假设一个映射 G,它可以将 Y 空间中的图片y 转换为 X 中的图片 G(y)。CycleGAN 同时学习 F 和 G 两个映射,并要求 F(G(y))≈yF(G(y))\approx yF(G(y))y,以及G(F(x))≈xG(F(x))\approx xG(F(x))x。 也是说,将 x 的图片转换到 Y 空间后,应该还可以转换回来。这样可以杜绝模型把所高 X 的图片都转换为 Y 空间中的同一张图片。

根据 F(G(y))≈yF(G(y))\approx yF(G(y))yG(F(x))≈xG(F(x))\approx xG(F(x))x ,循环一致性损失定义为:
Lcyc(F,G,X,Y)=Ex?pdata(x)[∥G(F(x))?x∥1]+Ey?pdata(y)[∥F(G(y))?y∥2]L_{cyc}(F,G,X,Y)=E_{x \sim p_{\text {data}}(x)}[{\left\| G(F(x))- x \right\|_1}]+E_{y \sim p_{\text {data}}(y)}[{\left\| F(G(y))- y \right\|_2}] Lcyc?(F,G,X,Y)=Ex?pdata?(x)?[G(F(x))?x1?]+Ey?pdata?(y)?[F(G(y))?y2?]
同时,为 G 也引入一个判别器Dx,由此可以同样定义一个 GAN 损失LGAN(G,Dx,X,Y)L_{GAN}(G,D_{x},X,Y)LGAN?(G,Dx?,X,Y), ,最终的损失由三部分组成:
L=LGAN(F,Dy,X,Y)+LGAN(F,Dx,X,Y)+λLcyc(F,G,X,Y)L=L_{GAN}(F,D_{y},X,Y)+L_{GAN}(F,D_{x},X,Y) +\lambda L_{cyc}(F,G,X,Y) L=LGAN?(F,Dy?,X,Y)+LGAN?(F,Dx?,X,Y)+λLcyc?(F,G,X,Y)
CycleGAN 的主要想法是上述的“循环一致性损失”,利用这个损失,可以巧妙地处理 X 空间和 Y 空间训练样本不一一配对的问题。

2.定义

模型

def model(self):# 读入X空间和Y空间的数据,保存到 x 和 y 中X_reader = Reader(self.X_train_file, name='X',image_size=self.image_size, batch_size=self.batch_size)Y_reader = Reader(self.Y_train_file, name='Y',image_size=self.image_size, batch_size=self.batch_size)x = X_reader.feed()y = Y_reader.feed()# 定义循环一致性损失:G:X->Y,F:Y->Xcycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y)# G: X -> Yfake_y = self.G(x)G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan) # G生成图片的lossG_loss =  G_gan_loss + cycle_lossD_Y_loss = self.discriminator_loss(self.D_Y, y, self.fake_y, use_lsgan=self.use_lsgan) # Y空间判别器的损失# F: Y -> Xfake_x = self.F(y)F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan) # F生成图片的lossF_loss = F_gan_loss + cycle_lossD_X_loss = self.discriminator_loss(self.D_X, x, self.fake_x, use_lsgan=self.use_lsgan) # Y空间的损失

循环一致性损失

def cycle_consistency_loss(self, G, F, x, y):"""cycle consistency loss (L1 norm)循环一致性损失"""forward_loss = tf.reduce_mean(tf.abs(F(G(x))-x))backward_loss = tf.reduce_mean(tf.abs(G(F(y))-y))loss = self.lambda1*forward_loss + self.lambda2*backward_lossreturn loss

生成器损失

def generator_loss(self, D, fake_y, use_lsgan=True):""" fool discriminator into believing that G(x) is real生成器损失"""if use_lsgan:# use mean squared errorloss = tf.reduce_mean(tf.squared_difference(D(fake_y), REAL_LABEL))else:# heuristic, non-saturating lossloss = -tf.reduce_mean(ops.safe_log(D(fake_y))) / 2return loss

判别器损失

def discriminator_loss(self, D, y, fake_y, use_lsgan=True):""" Note: default: D(y).shape == (batch_size,5,5,1),fake_buffer_size=50, batch_size=1Args:G: generator objectD: discriminator objecty: 4D tensor (batch_size, image_size, image_size, 3)Returns:loss: scalar判别器损失"""if use_lsgan:# use mean squared errorerror_real = tf.reduce_mean(tf.squared_difference(D(y), REAL_LABEL))error_fake = tf.reduce_mean(tf.square(D(fake_y)))else:# use cross entropyerror_real = -tf.reduce_mean(ops.safe_log(D(y)))error_fake = -tf.reduce_mean(ops.safe_log(1-D(fake_y)))loss = (error_real + error_fake) / 2return loss

调用损失

optimize():

# G_loss、F_loss、D_Y_loss、D_X_loss
G_optimizer = make_optimizer(G_loss, self.G.variables, name='Adam_G')
D_Y_optimizer = make_optimizer(D_Y_loss, self.D_Y.variables, name='Adam_D_Y')
F_optimizer =  make_optimizer(F_loss, self.F.variables, name='Adam_F')
D_X_optimizer = make_optimizer(D_X_loss, self.D_X.variables, name='Adam_D_X')with tf.control_dependencies([G_optimizer, D_Y_optimizer, F_optimizer, D_X_optimizer]):
return tf.no_op(name='optimizers')

二、苹果到橘子

1.下载数据集

下载数据集 apple2orange.zip

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

网页:https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/

或者用脚本:

bash download_dataset.sh apple2orange

或者百度云:

链接:https://pan.baidu.com/s/1U78OTOAi0bUuJ-bGhtWZ5A
提取码:1r1f

在目录创建 data 文件夹,把 apple2orange 放进去

2.转换成 tfrecords 格式

python build_data.py \--X_input_dir data/apple2orange/trainA \--Y_input_dir data/apple2orange/trainB \--X_output_file data/tfrecords/apple.tfrecords \--Y_output_file data/tfrecords/orange.tfrecords
python build_data.py --X_input_dir data/apple2orange/trainA --Y_input_dir data/apple2orange/trainB --X_output_file data/tfrecords/apple.tfrecords --Y_output_file data/tfrecords/orange.tfrecords

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

3.训练模型

python train.py \--X data/tfrecords/apple.tfrecords \--Y data/tfrecords/orange.tfrecords \--image_size 256
python train.py --X data/tfrecords/apple.tfrecords --Y data/tfrecords/orange.tfrecords --image_size 256

4.查看训练情况

后面的路径改成自己的

tensorboard --logdir checkpoints/20200912-1241

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换
花了两个小时才step100,运行之前没发现step每隔100存一下,导致模型一直看不到情况,其实可以设置小一点。

if step % 100 == 0:train_writer.add_summary(summary, step)train_writer.flush()if step % 100 == 0:logging.info('-----------Step %d:-------------' % step)logging.info(' G_loss : {}'.format(G_loss_val))logging.info(' D_Y_loss : {}'.format(D_Y_loss_val))logging.info(' F_loss : {}'.format(F_loss_val))logging.info(' D_X_loss : {}'.format(D_X_loss_val))if step % 10000 == 0:save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)logging.info("Model saved in file: %s" % save_path)step += 1

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

5.导出模型

后面的路径改成自己的

python export_graph.py \--checkpoint_dir checkpoints/20200912-1241 \--XtoY_model apple2orange.pb \--YtoX_model orange2apple.pb \--image_size 256
python export_graph.py --checkpoint_dir checkpoints/20200912-1241 --XtoY_model apple2orange.pb --YtoX_model orange2apple.pb --image_size 256

会在 生成两个 文件夹生成 apple2orange.pborange2apple.pb两个模型

6.测试模型

python inference.py \--model pretrained/apple2orange.pb \--input data/apple2orange/testA/n07740461_1661.jpg \--output data/apple2orange/output_sample.jpg \--image_size 256
python inference.py --model pretrained/apple2orange.pb --input data/apple2orange/testA/n07740461_1661.jpg --output data/apple2orange/output_sample.jpg --image_size 256

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

刚刚100步的结果。。。[老爷爷看手机]

三、男性到女性

步骤比较类似,其他的数据集也是只要分到两个文件夹,一个是X一个是Y即可。

1.下载数据集

链接:https://pan.baidu.com/s/1U78OTOAi0bUuJ-bGhtWZ5A

2.转换成 tfrecords 格式

python build_data.py \--X_input_dir data/man2woman/a_resized/ \--Y_input_dir data/man2woman/b_resized/ \--X_output_file data/man2woman/man.tfrecords \--Y_output_file data/man2woman/woman.tfrecords
python build_data.py --X_input_dir data/man2woman/a_resized/ --Y_input_dir data/man2woman/b_resized/ --X_output_file data/man2woman/man.tfrecords --Y_output_file data/man2woman/woman.tfrecords

3.训练模型

python train.py \--X data/man2woman/man.tfrecords \--Y data/man2woman/woman.tfrecords \--image_size 256
python train.py --X data/man2woman/man.tfrecords --Y data/man2woman/woman.tfrecords --image_size 256

4.查看训练情况

后面的路径改成自己的

tensorboard --logdir checkpoints/xxxxxxxxxxx

四、书上的一些效果展示

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

TensorFlow 14——ch11-CycleGAN 与 非配对图像转换
TensorFlow 14——ch11-CycleGAN 与 非配对图像转换

  相关解决方案