{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "lEMI1oxe8KNy"
},
"source": [
"Updated 21/Nov/2021 by Yoshihisa Nitta "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X4_Pjeum8NUa"
},
"source": [
"\n",
"# Further Training of a Variational Autoencoder on the CelebA Dataset with TensorFlow 2 on Google Colab\n",
"\n",
"Continue training the Variational Autoencoder (VAE) on the CelebA dataset.\n",
"This notebook assumes the state left after running VAE_CelebA_Train.ipynb."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"executionInfo": {
"elapsed": 270,
"status": "ok",
"timestamp": 1637505886000,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "CnbfjOX_7wEa"
},
"outputs": [],
"source": [
"#! pip install tensorflow==2.7.0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2260,
"status": "ok",
"timestamp": 1637505895914,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "woOXJdh57sIx",
"outputId": "7d22b67d-2269-4366-b102-b2e16ed4a396"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7.0\n"
]
}
],
"source": [
"%tensorflow_version 2.x\n",
"\n",
"import tensorflow as tf\n",
"print(tf.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bXj23n8r9Tac"
},
"source": [
"# Check the Google Colab runtime environment"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 659,
"status": "ok",
"timestamp": 1637505915895,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "4xRE6QCs9QO1",
"outputId": "5ece69df-eb07-4802-eca0-0de3b00b986a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sun Nov 21 14:45:15 2021 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 495.44 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 33C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n",
"processor\t: 0\n",
"vendor_id\t: GenuineIntel\n",
"cpu family\t: 6\n",
"model\t\t: 79\n",
"model name\t: Intel(R) Xeon(R) CPU @ 2.20GHz\n",
"stepping\t: 0\n",
"microcode\t: 0x1\n",
"cpu MHz\t\t: 2199.998\n",
"cache size\t: 56320 KB\n",
"physical id\t: 0\n",
"siblings\t: 2\n",
"core id\t\t: 0\n",
"cpu cores\t: 1\n",
"apicid\t\t: 0\n",
"initial apicid\t: 0\n",
"fpu\t\t: yes\n",
"fpu_exception\t: yes\n",
"cpuid level\t: 13\n",
"wp\t\t: yes\n",
"flags\t\t: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities\n",
"bugs\t\t: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa\n",
"bogomips\t: 4399.99\n",
"clflush size\t: 64\n",
"cache_alignment\t: 64\n",
"address sizes\t: 46 bits physical, 48 bits virtual\n",
"power management:\n",
"\n",
"processor\t: 1\n",
"vendor_id\t: GenuineIntel\n",
"cpu family\t: 6\n",
"model\t\t: 79\n",
"model name\t: Intel(R) Xeon(R) CPU @ 2.20GHz\n",
"stepping\t: 0\n",
"microcode\t: 0x1\n",
"cpu MHz\t\t: 2199.998\n",
"cache size\t: 56320 KB\n",
"physical id\t: 0\n",
"siblings\t: 2\n",
"core id\t\t: 0\n",
"cpu cores\t: 1\n",
"apicid\t\t: 1\n",
"initial apicid\t: 1\n",
"fpu\t\t: yes\n",
"fpu_exception\t: yes\n",
"cpuid level\t: 13\n",
"wp\t\t: yes\n",
"flags\t\t: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities\n",
"bugs\t\t: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa\n",
"bogomips\t: 4399.99\n",
"clflush size\t: 64\n",
"cache_alignment\t: 64\n",
"address sizes\t: 46 bits physical, 48 bits virtual\n",
"power management:\n",
"\n",
"Ubuntu 18.04.5 LTS \\n \\l\n",
"\n",
" total used free shared buff/cache available\n",
"Mem: 12G 733M 9G 1.2M 2.0G 11G\n",
"Swap: 0B 0B 0B\n"
]
}
],
"source": [
"! nvidia-smi\n",
"! cat /proc/cpuinfo\n",
"! cat /etc/issue\n",
"! free -h"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zeGrymCg9ZtL"
},
"source": [
"# Mount Google Drive from Google Colab"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 24864,
"status": "ok",
"timestamp": 1637505975236,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "9B4MX6GC9Vf9",
"outputId": "bcf572f3-bb59-4a1f-f51c-bb315fd77731"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1637505975237,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "P4voAIIh9aiO",
"outputId": "a7621009-b46a-4e31-990b-daf067d60055"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MyDrive Shareddrives\n"
]
}
],
"source": [
"! ls /content/drive"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cED1p2U998IE"
},
"source": [
"# Download source file from Google Drive or nw.tsuda.ac.jp\n",
"\n",
"By default, use gdown to download the file from Google Drive.\n",
"Download from nw.tsuda.ac.jp only if the specifications of Google Drive change and the download from Google Drive no longer works."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2707,
"status": "ok",
"timestamp": 1637506093686,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "K13qk7Td9mH_",
"outputId": "c24b8044-e953-48a7-8133-630c28eb58d7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3\n",
"To: /content/nw/VariationalAutoEncoder.py\n",
"\r",
" 0% 0.00/18.7k [00:00, ?B/s]\r",
"100% 18.7k/18.7k [00:00<00:00, 16.3MB/s]\n"
]
}
],
"source": [
"# Download source file\n",
"nw_path = './nw'\n",
"! rm -rf {nw_path}\n",
"! mkdir -p {nw_path}\n",
"\n",
"if True: # from Google Drive\n",
" url_model = 'https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3'\n",
" ! (cd {nw_path}; gdown {url_model})\n",
"else: # from nw.tsuda.ac.jp\n",
" URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'\n",
" url_model = f'{URL_NW}/models/VariationalAutoEncoder.py'\n",
" ! wget -nd {url_model} -P {nw_path}"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 244,
"status": "ok",
"timestamp": 1637506102894,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "WmOyk35j-AZ7",
"outputId": "c3dd811f-3888-4b94-da39-98454e895051"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import os\n",
"import pickle\n",
"import datetime\n",
"\n",
"class Sampling(tf.keras.layers.Layer):\n",
" def __init__(self, **kwargs):\n",
" super().__init__(**kwargs)\n",
"\n",
" def call(self, inputs):\n",
" mu, log_var = inputs\n",
" epsilon = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(mu), mean=0., stddev=1.)\n",
" return mu + tf.keras.backend.exp(log_var / 2) * epsilon\n",
"\n",
"\n",
"class VAEModel(tf.keras.models.Model):\n",
" def __init__(self, encoder, decoder, r_loss_factor, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.r_loss_factor = r_loss_factor\n",
"\n",
"\n",
" @tf.function\n",
" def loss_fn(self, x):\n",
" z_mean, z_log_var, z = self.encoder(x)\n",
" reconstruction = self.decoder(z)\n",
" reconstruction_loss = tf.reduce_mean(\n",
" tf.square(x - reconstruction), axis=[1,2,3]\n",
" ) * self.r_loss_factor\n",
" kl_loss = tf.reduce_sum(\n",
" 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),\n",
" axis = 1\n",
" ) * (-0.5)\n",
" total_loss = reconstruction_loss + kl_loss\n",
" return total_loss, reconstruction_loss, kl_loss\n",
"\n",
"\n",
" @tf.function\n",
" def compute_loss_and_grads(self, x):\n",
" with tf.GradientTape() as tape:\n",
" total_loss, reconstruction_loss, kl_loss = self.loss_fn(x)\n",
" grads = tape.gradient(total_loss, self.trainable_weights)\n",
" return total_loss, reconstruction_loss, kl_loss, grads\n",
"\n",
"\n",
" def train_step(self, data):\n",
" if isinstance(data, tuple):\n",
" data = data[0]\n",
" total_loss, reconstruction_loss, kl_loss, grads = self.compute_loss_and_grads(data)\n",
" self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
" return {\n",
" \"loss\": tf.math.reduce_mean(total_loss),\n",
" \"reconstruction_loss\": tf.math.reduce_mean(reconstruction_loss),\n",
" \"kl_loss\": tf.math.reduce_mean(kl_loss),\n",
" }\n",
"\n",
" def call(self,inputs):\n",
" _, _, z = self.encoder(inputs)\n",
" return self.decoder(z)\n",
"\n",
"\n",
"class VariationalAutoEncoder():\n",
" def __init__(self, \n",
" input_dim,\n",
" encoder_conv_filters,\n",
" encoder_conv_kernel_size,\n",
" encoder_conv_strides,\n",
" decoder_conv_t_filters,\n",
" decoder_conv_t_kernel_size,\n",
" decoder_conv_t_strides,\n",
" z_dim,\n",
" r_loss_factor, ### added\n",
" use_batch_norm = False,\n",
" use_dropout = False,\n",
" epoch = 0\n",
" ):\n",
" self.name = 'variational_autoencoder'\n",
" self.input_dim = input_dim\n",
" self.encoder_conv_filters = encoder_conv_filters\n",
" self.encoder_conv_kernel_size = encoder_conv_kernel_size\n",
" self.encoder_conv_strides = encoder_conv_strides\n",
" self.decoder_conv_t_filters = decoder_conv_t_filters\n",
" self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size\n",
" self.decoder_conv_t_strides = decoder_conv_t_strides\n",
" self.z_dim = z_dim\n",
" self.r_loss_factor = r_loss_factor ### added\n",
" \n",
" self.use_batch_norm = use_batch_norm\n",
" self.use_dropout = use_dropout\n",
"\n",
" self.epoch = epoch\n",
" \n",
" self.n_layers_encoder = len(encoder_conv_filters)\n",
" self.n_layers_decoder = len(decoder_conv_t_filters)\n",
" \n",
" self._build()\n",
" \n",
"\n",
" def _build(self):\n",
" ### THE ENCODER\n",
" encoder_input = tf.keras.layers.Input(shape=self.input_dim, name='encoder_input')\n",
" x = encoder_input\n",
" \n",
" for i in range(self.n_layers_encoder):\n",
" x = conv_layer = tf.keras.layers.Conv2D(\n",
" filters = self.encoder_conv_filters[i],\n",
" kernel_size = self.encoder_conv_kernel_size[i],\n",
" strides = self.encoder_conv_strides[i],\n",
" padding = 'same',\n",
" name = 'encoder_conv_' + str(i)\n",
" )(x)\n",
"\n",
" if self.use_batch_norm: ### The order of layers is opposite to AutoEncoder\n",
" x = tf.keras.layers.BatchNormalization()(x) ### AE: LeakyReLU -> BatchNorm\n",
" x = tf.keras.layers.LeakyReLU()(x) ### VAE: BatchNorm -> LeakyReLU\n",
" \n",
" if self.use_dropout:\n",
" x = tf.keras.layers.Dropout(rate = 0.25)(x)\n",
" \n",
" shape_before_flattening = tf.keras.backend.int_shape(x)[1:]\n",
" \n",
" x = tf.keras.layers.Flatten()(x)\n",
" \n",
" self.mu = tf.keras.layers.Dense(self.z_dim, name='mu')(x)\n",
" self.log_var = tf.keras.layers.Dense(self.z_dim, name='log_var')(x) \n",
" self.z = Sampling(name='encoder_output')([self.mu, self.log_var])\n",
" \n",
" self.encoder = tf.keras.models.Model(encoder_input, [self.mu, self.log_var, self.z], name='encoder')\n",
" \n",
" \n",
" ### THE DECODER\n",
" decoder_input = tf.keras.layers.Input(shape=(self.z_dim,), name='decoder_input')\n",
" x = decoder_input\n",
" x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(x)\n",
" x = tf.keras.layers.Reshape(shape_before_flattening)(x)\n",
" \n",
" for i in range(self.n_layers_decoder):\n",
" x = conv_t_layer = tf.keras.layers.Conv2DTranspose(\n",
" filters = self.decoder_conv_t_filters[i],\n",
" kernel_size = self.decoder_conv_t_kernel_size[i],\n",
" strides = self.decoder_conv_t_strides[i],\n",
" padding = 'same',\n",
" name = 'decoder_conv_t_' + str(i)\n",
" )(x)\n",
" \n",
" if i < self.n_layers_decoder - 1:\n",
" if self.use_batch_norm: ### The order of layers is opposite to AutoEncoder\n",
" x = tf.keras.layers.BatchNormalization()(x) ### AE: LeakyReLU -> BatchNorm\n",
" x = tf.keras.layers.LeakyReLU()(x) ### VAE: BatchNorm -> LeakyReLU \n",
" if self.use_dropout:\n",
" x = tf.keras.layers.Dropout(rate=0.25)(x)\n",
" else:\n",
" x = tf.keras.layers.Activation('sigmoid')(x)\n",
" \n",
" decoder_output = x\n",
" self.decoder = tf.keras.models.Model(decoder_input, decoder_output, name='decoder') ### added (name)\n",
" \n",
" ### THE FULL AUTOENCODER\n",
" self.model = VAEModel(self.encoder, self.decoder, self.r_loss_factor)\n",
" \n",
" \n",
" def save(self, folder):\n",
" self.save_params(os.path.join(folder, 'params.pkl'))\n",
" self.save_weights(folder)\n",
"\n",
"\n",
" @staticmethod\n",
" def load(folder, epoch=None): # VariationalAutoEncoder.load(folder)\n",
" params = VariationalAutoEncoder.load_params(os.path.join(folder, 'params.pkl'))\n",
" VAE = VariationalAutoEncoder(*params)\n",
" if epoch is None:\n",
" VAE.load_weights(folder)\n",
" else:\n",
" VAE.load_weights(folder, epoch-1)\n",
" VAE.epoch = epoch\n",
" return VAE\n",
"\n",
" \n",
" def save_params(self, filepath):\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" with open(filepath, 'wb') as f:\n",
" pickle.dump([\n",
" self.input_dim,\n",
" self.encoder_conv_filters,\n",
" self.encoder_conv_kernel_size,\n",
" self.encoder_conv_strides,\n",
" self.decoder_conv_t_filters,\n",
" self.decoder_conv_t_kernel_size,\n",
" self.decoder_conv_t_strides,\n",
" self.z_dim,\n",
" self.r_loss_factor,\n",
" self.use_batch_norm,\n",
" self.use_dropout,\n",
" self.epoch\n",
" ], f)\n",
"\n",
"\n",
" @staticmethod\n",
" def load_params(filepath):\n",
" with open(filepath, 'rb') as f:\n",
" params = pickle.load(f)\n",
" return params\n",
"\n",
"\n",
" def save_weights(self, folder, epoch=None):\n",
" if epoch is None:\n",
" self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights.h5'))\n",
" self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights.h5'))\n",
" else:\n",
" self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))\n",
" self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))\n",
"\n",
"\n",
" def save_model_weights(self, model, filepath):\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" model.save_weights(filepath)\n",
"\n",
"\n",
" def load_weights(self, folder, epoch=None):\n",
" if epoch is None:\n",
" self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights.h5'))\n",
" self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights.h5'))\n",
" else:\n",
" self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))\n",
" self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))\n",
"\n",
"\n",
" def save_images(self, imgs, filepath):\n",
" z_mean, z_log_var, z = self.encoder.predict(imgs)\n",
" reconst_imgs = self.decoder.predict(z)\n",
" txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z ]\n",
" AutoEncoder.showImages(imgs, reconst_imgs, txts, 1.4, 1.4, 0.5, filepath)\n",
" \n",
"\n",
" def compile(self, learning_rate):\n",
" self.learning_rate = learning_rate\n",
" optimizer = tf.keras.optimizers.Adam(lr=learning_rate)\n",
" self.model.compile(optimizer=optimizer) # CAUTION!!!: loss(y_true, y_pred) function is not specified.\n",
" \n",
" \n",
" def train_with_fit(\n",
" self,\n",
" x_train,\n",
" batch_size,\n",
" epochs,\n",
" run_folder='run/'\n",
" ):\n",
" history = self.model.fit(\n",
" x_train,\n",
" x_train,\n",
" batch_size = batch_size,\n",
" shuffle=True,\n",
" initial_epoch = self.epoch,\n",
" epochs = epochs\n",
" )\n",
" if (self.epoch < epochs):\n",
" self.epoch = epochs\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
" \n",
" return history\n",
"\n",
"\n",
" def train_generator_with_fit(\n",
" self,\n",
" data_flow,\n",
" epochs,\n",
" run_folder='run/'\n",
" ):\n",
" history = self.model.fit(\n",
" data_flow,\n",
" initial_epoch = self.epoch,\n",
" epochs = epochs\n",
" )\n",
" if (self.epoch < epochs):\n",
" self.epoch = epochs\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
" \n",
" return history\n",
"\n",
"\n",
" def train_tf(\n",
" self,\n",
" x_train,\n",
" batch_size = 32,\n",
" epochs = 10,\n",
" shuffle = False,\n",
" run_folder = 'run/',\n",
" optimizer = None,\n",
" save_epoch_interval = 100,\n",
" validation_data = None\n",
" ):\n",
" start_time = datetime.datetime.now()\n",
" steps = x_train.shape[0] // batch_size\n",
"\n",
" total_losses = []\n",
" reconstruction_losses = []\n",
" kl_losses = []\n",
"\n",
" val_total_losses = []\n",
" val_reconstruction_losses = []\n",
" val_kl_losses = []\n",
"\n",
" for epoch in range(self.epoch, epochs):\n",
" epoch_loss = 0\n",
" indices = tf.range(x_train.shape[0], dtype=tf.int32)\n",
" if shuffle:\n",
" indices = tf.random.shuffle(indices)\n",
" x_ = x_train[indices]\n",
"\n",
" step_total_losses = []\n",
" step_reconstruction_losses = []\n",
" step_kl_losses = []\n",
" for step in range(steps):\n",
" start = batch_size * step\n",
" end = start + batch_size\n",
"\n",
" total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x_[start:end])\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_weights))\n",
" \n",
" step_total_losses.append(np.mean(total_loss))\n",
" step_reconstruction_losses.append(np.mean(reconstruction_loss))\n",
" step_kl_losses.append(np.mean(kl_loss))\n",
" \n",
" epoch_total_loss = np.mean(step_total_losses)\n",
" epoch_reconstruction_loss = np.mean(step_reconstruction_losses)\n",
" epoch_kl_loss = np.mean(step_kl_losses)\n",
"\n",
" total_losses.append(epoch_total_loss)\n",
" reconstruction_losses.append(epoch_reconstruction_loss)\n",
" kl_losses.append(epoch_kl_loss)\n",
"\n",
" val_str = ''\n",
" if not validation_data is None:\n",
" x_val = validation_data\n",
" tl, rl, kl = self.model.loss_fn(x_val)\n",
" val_tl = np.mean(tl)\n",
" val_rl = np.mean(rl)\n",
" val_kl = np.mean(kl)\n",
" val_total_losses.append(val_tl)\n",
" val_reconstruction_losses.append(val_rl)\n",
" val_kl_losses.append(val_kl)\n",
" val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '\n",
"\n",
" if (epoch+1) % save_epoch_interval == 0 and run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch)\n",
"\n",
" elapsed_time = datetime.datetime.now() - start_time\n",
" print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')\n",
"\n",
" self.epoch += 1\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
"\n",
" dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }\n",
" if not validation_data is None:\n",
" dic['val_loss'] = val_total_losses\n",
" dic['val_reconstruction_loss'] = val_reconstruction_losses\n",
" dic['val_kl_loss'] = val_kl_losses\n",
"\n",
" return dic\n",
" \n",
"\n",
" def train_tf_generator(\n",
" self,\n",
" data_flow,\n",
" epochs = 10,\n",
" run_folder = 'run/',\n",
" optimizer = None,\n",
" save_epoch_interval = 100,\n",
" validation_data_flow = None\n",
" ):\n",
" start_time = datetime.datetime.now()\n",
" steps = len(data_flow)\n",
"\n",
" total_losses = []\n",
" reconstruction_losses = []\n",
" kl_losses = []\n",
"\n",
" val_total_losses = []\n",
" val_reconstruction_losses = []\n",
" val_kl_losses = []\n",
"\n",
" for epoch in range(self.epoch, epochs):\n",
" epoch_loss = 0\n",
"\n",
" step_total_losses = []\n",
" step_reconstruction_losses = []\n",
" step_kl_losses = []\n",
"\n",
" for step in range(steps):\n",
" x, _ = next(data_flow)\n",
"\n",
" total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x)\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_weights))\n",
" \n",
" step_total_losses.append(np.mean(total_loss))\n",
" step_reconstruction_losses.append(np.mean(reconstruction_loss))\n",
" step_kl_losses.append(np.mean(kl_loss))\n",
" \n",
" epoch_total_loss = np.mean(step_total_losses)\n",
" epoch_reconstruction_loss = np.mean(step_reconstruction_losses)\n",
" epoch_kl_loss = np.mean(step_kl_losses)\n",
"\n",
" total_losses.append(epoch_total_loss)\n",
" reconstruction_losses.append(epoch_reconstruction_loss)\n",
" kl_losses.append(epoch_kl_loss)\n",
"\n",
" val_str = ''\n",
" if not validation_data_flow is None:\n",
" step_val_tl = []\n",
" step_val_rl = []\n",
" step_val_kl = []\n",
" for i in range(len(validation_data_flow)):\n",
" x, _ = next(validation_data_flow)\n",
" tl, rl, kl = self.model.loss_fn(x)\n",
" step_val_tl.append(np.mean(tl))\n",
" step_val_rl.append(np.mean(rl))\n",
" step_val_kl.append(np.mean(kl))\n",
" val_tl = np.mean(step_val_tl)\n",
" val_rl = np.mean(step_val_rl)\n",
" val_kl = np.mean(step_val_kl)\n",
" val_total_losses.append(val_tl)\n",
" val_reconstruction_losses.append(val_rl)\n",
" val_kl_losses.append(val_kl)\n",
" val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '\n",
"\n",
" if (epoch+1) % save_epoch_interval == 0 and run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch)\n",
"\n",
" elapsed_time = datetime.datetime.now() - start_time\n",
" print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')\n",
"\n",
" self.epoch += 1\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
"\n",
" dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }\n",
" if not validation_data_flow is None:\n",
" dic['val_loss'] = val_total_losses\n",
" dic['val_reconstruction_loss'] = val_reconstruction_losses\n",
" dic['val_kl_loss'] = val_kl_losses\n",
"\n",
" return dic\n",
"\n",
"\n",
" @staticmethod\n",
" def showImages(imgs1, imgs2, txts, w, h, vskip=0.5, filepath=None):\n",
" n = len(imgs1)\n",
" fig, ax = plt.subplots(2, n, figsize=(w * n, (2+vskip) * h))\n",
" for i in range(n):\n",
" if n == 1:\n",
" axis = ax[0]\n",
" else:\n",
" axis = ax[0][i]\n",
" img = imgs1[i].squeeze()\n",
" axis.imshow(img, cmap='gray_r')\n",
" axis.axis('off')\n",
"\n",
" axis.text(0.5, -0.35, txts[i], fontsize=10, ha='center', transform=axis.transAxes)\n",
"\n",
" if n == 1:\n",
" axis = ax[1]\n",
" else:\n",
" axis = ax[1][i]\n",
" img2 = imgs2[i].squeeze()\n",
" axis.imshow(img2, cmap='gray_r')\n",
" axis.axis('off')\n",
"\n",
" if not filepath is None:\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" fig.savefig(filepath, dpi=600)\n",
" plt.close()\n",
" else:\n",
" plt.show()\n",
"\n",
" @staticmethod\n",
" def plot_history(vals, labels):\n",
" colors = ['red', 'blue', 'green', 'orange', 'black', 'pink']\n",
" n = len(vals)\n",
" fig, ax = plt.subplots(1, 1, figsize=(9,4))\n",
" for i in range(n):\n",
" ax.plot(vals[i], c=colors[i], label=labels[i])\n",
" ax.legend(loc='upper right')\n",
" ax.set_xlabel('epochs')\n",
" # ax[0].set_ylabel('loss')\n",
" \n",
" plt.show()\n"
]
}
],
"source": [
"! cat {nw_path}/VariationalAutoEncoder.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K29zyLNo-JG-"
},
"source": [
"# Preparing the CelebA dataset\n",
"\n",
"Official website of the CelebA dataset:\n",
"\n",
"https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html\n",
"\n",
"\n",
"Google Drive folder of the CelebA dataset:\n",
"\n",
"https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg\n",
"\n",
"\n",
"img_align_celeba.zip mirrored on my Google Drive:\n",
"\n",
"https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 20130,
"status": "ok",
"timestamp": 1637506176166,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "g99ZWERz-DP8",
"outputId": "62a1e72f-154f-4c4b-f00b-275f79503df9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx\n",
"To: /content/img_align_celeba.zip\n",
"100% 1.44G/1.44G [00:06<00:00, 238MB/s]\n"
]
}
],
"source": [
"# Download img_align_celeba.zip from Google Drive\n",
"\n",
"MIRRORED_URL = 'https://drive.google.com/uc?id=1LFKeoI-hb96jlV0K10dO1o04iQPBoFdx'\n",
"\n",
"! gdown {MIRRORED_URL}"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 13,
"status": "ok",
"timestamp": 1637506176167,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "xXFiRu9y-QSj",
"outputId": "65bdaf54-4f9e-40a0-9124-56c6bf8c25f9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total 1409676\n",
"drwx------ 6 root root 4096 Nov 21 14:46 drive\n",
"-rw-r--r-- 1 root root 1443490838 Nov 21 14:49 img_align_celeba.zip\n",
"drwxr-xr-x 2 root root 4096 Nov 21 14:48 nw\n",
"drwxr-xr-x 1 root root 4096 Nov 18 14:36 sample_data\n"
]
}
],
"source": [
"! ls -l"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"executionInfo": {
"elapsed": 260,
"status": "ok",
"timestamp": 1637506709488,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "0SnX7ijr-UBg"
},
"outputs": [],
"source": [
"DATA_DIR = 'data'\n",
"DATA_SUBDIR = 'img_align_celeba'"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"executionInfo": {
"elapsed": 18168,
"status": "ok",
"timestamp": 1637506735151,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "CeMWTJWeAXVq"
},
"outputs": [],
"source": [
"! rm -rf {DATA_DIR}\n",
"! unzip -d {DATA_DIR} -q {DATA_SUBDIR}.zip"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 1999,
"status": "ok",
"timestamp": 1637506737138,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "fDN_8kaFAZPV",
"outputId": "09c85014-f7cf-4a4b-80a9-5da2d03f1e31"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"total 1737936\n",
"-rw-r--r-- 1 root root 11440 Sep 28 2015 000001.jpg\n",
"-rw-r--r-- 1 root root 7448 Sep 28 2015 000002.jpg\n",
"-rw-r--r-- 1 root root 4253 Sep 28 2015 000003.jpg\n",
"-rw-r--r-- 1 root root 10747 Sep 28 2015 000004.jpg\n",
"-rw-r--r-- 1 root root 6351 Sep 28 2015 000005.jpg\n",
"-rw-r--r-- 1 root root 8073 Sep 28 2015 000006.jpg\n",
"-rw-r--r-- 1 root root 8203 Sep 28 2015 000007.jpg\n",
"-rw-r--r-- 1 root root 7725 Sep 28 2015 000008.jpg\n",
"-rw-r--r-- 1 root root 8641 Sep 28 2015 000009.jpg\n",
" 202599 202599 2228589\n"
]
}
],
"source": [
"! ls -l {DATA_DIR}/{DATA_SUBDIR} | head\n",
"! ls {DATA_DIR}/{DATA_SUBDIR} | wc"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JTblqnRqAvLW"
},
"source": [
"# Check the CelebA dataset"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 809,
"status": "ok",
"timestamp": 1637506845329,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "DR3yVPDZAuRw",
"outputId": "4680c5f5-0eb3-4e08-b512-490a9bcf2663"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"202599\n"
]
}
],
"source": [
"# paths to all the image files.\n",
"\n",
"import os\n",
"import glob\n",
"import numpy as np\n",
"\n",
"all_file_paths = np.array(glob.glob(os.path.join(DATA_DIR, DATA_SUBDIR, '*.jpg')))\n",
"n_all_images = len(all_file_paths)\n",
"\n",
"print(n_all_images)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"executionInfo": {
"elapsed": 2,
"status": "ok",
"timestamp": 1637506845329,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "OmuBX5z_A1qG"
},
"outputs": [],
"source": [
"# Select some image files.\n",
"\n",
"n_to_show = 10\n",
"selected_indices = np.random.choice(range(n_all_images), n_to_show)\n",
"selected_paths = all_file_paths[selected_indices]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 107
},
"executionInfo": {
"elapsed": 721,
"status": "ok",
"timestamp": 1637506846299,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "RABvNE7zA3nl",
"outputId": "63e14d31-9936-4e6c-827a-1297abfc97a7"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"