{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "pDyCVsqlFA_5"
},
"source": [
"Updated 19/Nov/2021 by Yoshihisa Nitta "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I4thkRCveS2o"
},
"source": [
"# Variational Autoencoder Analysis for the MNIST Dataset with TensorFlow 2 on Google Colab\n",
"\n",
"This notebook assumes that VAE_MNIST_Train.ipynb has already been run, so that the state it leaves behind after training is available."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"executionInfo": {
"elapsed": 3,
"status": "ok",
"timestamp": 1637572808348,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "gLzKsQqnToi6"
},
"outputs": [],
"source": [
"#! pip install tensorflow==2.7.0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 2015,
"status": "ok",
"timestamp": 1637572810361,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "j3o8E-o4TpyQ",
"outputId": "bfcc50a5-ce5d-4195-face-07cc15598efc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7.0\n"
]
}
],
"source": [
"%tensorflow_version 2.x\n",
"\n",
"import tensorflow as tf\n",
"print(tf.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RV4ugCIeerBy"
},
"source": [
"# Check the Google Colab runtime environment"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 1159,
"status": "ok",
"timestamp": 1637572811513,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "2BkP2SYVeOtQ",
"outputId": "36188a02-1704-4f09-b2d8-044d8a4eeb53"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mon Nov 22 09:20:10 2021 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 495.44 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla P100-PCIE... Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 34C P0 26W / 250W | 0MiB / 16280MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n",
"processor\t: 0\n",
"vendor_id\t: GenuineIntel\n",
"cpu family\t: 6\n",
"model\t\t: 79\n",
"model name\t: Intel(R) Xeon(R) CPU @ 2.20GHz\n",
"stepping\t: 0\n",
"microcode\t: 0x1\n",
"cpu MHz\t\t: 2199.998\n",
"cache size\t: 56320 KB\n",
"physical id\t: 0\n",
"siblings\t: 2\n",
"core id\t\t: 0\n",
"cpu cores\t: 1\n",
"apicid\t\t: 0\n",
"initial apicid\t: 0\n",
"fpu\t\t: yes\n",
"fpu_exception\t: yes\n",
"cpuid level\t: 13\n",
"wp\t\t: yes\n",
"flags\t\t: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities\n",
"bugs\t\t: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa\n",
"bogomips\t: 4399.99\n",
"clflush size\t: 64\n",
"cache_alignment\t: 64\n",
"address sizes\t: 46 bits physical, 48 bits virtual\n",
"power management:\n",
"\n",
"processor\t: 1\n",
"vendor_id\t: GenuineIntel\n",
"cpu family\t: 6\n",
"model\t\t: 79\n",
"model name\t: Intel(R) Xeon(R) CPU @ 2.20GHz\n",
"stepping\t: 0\n",
"microcode\t: 0x1\n",
"cpu MHz\t\t: 2199.998\n",
"cache size\t: 56320 KB\n",
"physical id\t: 0\n",
"siblings\t: 2\n",
"core id\t\t: 0\n",
"cpu cores\t: 1\n",
"apicid\t\t: 1\n",
"initial apicid\t: 1\n",
"fpu\t\t: yes\n",
"fpu_exception\t: yes\n",
"cpuid level\t: 13\n",
"wp\t\t: yes\n",
"flags\t\t: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities\n",
"bugs\t\t: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa\n",
"bogomips\t: 4399.99\n",
"clflush size\t: 64\n",
"cache_alignment\t: 64\n",
"address sizes\t: 46 bits physical, 48 bits virtual\n",
"power management:\n",
"\n",
"Ubuntu 18.04.5 LTS \\n \\l\n",
"\n",
" total used free shared buff/cache available\n",
"Mem: 12G 738M 9G 1.2M 2.0G 11G\n",
"Swap: 0B 0B 0B\n"
]
}
],
"source": [
"! nvidia-smi\n",
"! cat /proc/cpuinfo\n",
"! cat /etc/issue\n",
"! free -h"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PLB9-03nfCxx"
},
"source": [
"# Mount Google Drive from Google Colab"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 38581,
"status": "ok",
"timestamp": 1637572850092,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "JS_DxEGYe3bK",
"outputId": "23b987ed-ba88-4dd5-bb1d-5d03577fb250"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1637572850093,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "2B6gVAM6fJEw",
"outputId": "1af2f27c-e878-4476-81b3-871b8058c369"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MyDrive Shareddrives\n"
]
}
],
"source": [
"! ls /content/drive"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0EN3D7nffROf"
},
"source": [
"# Download source file from Google Drive or nw.tsuda.ac.jp\n",
"\n",
"By default, download the file from Google Drive with `gdown`. Download it from nw.tsuda.ac.jp instead only if Google Drive's behavior changes and the download from Google Drive no longer works."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 3977,
"status": "ok",
"timestamp": 1637572854067,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "F3wJxpfjfQxJ",
"outputId": "9393e00c-c465-4f28-ad44-205f8a2063c2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"From: https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3\n",
"To: /content/nw/VariationalAutoEncoder.py\n",
"\r",
" 0% 0.00/18.7k [00:00, ?B/s]\r",
"100% 18.7k/18.7k [00:00<00:00, 14.9MB/s]\n"
]
}
],
"source": [
"# Download the source file into a freshly created ./nw directory.\n",
"nw_path = './nw'\n",
"! rm -rf {nw_path}\n",
"! mkdir -p {nw_path}\n",
"\n",
"# Set to False to download from nw.tsuda.ac.jp instead.\n",
"use_google_drive = True\n",
"\n",
"if use_google_drive:\n",
"    url_model = 'https://drive.google.com/uc?id=1ZCihR7JkMOity4wCr66ZCp-3ZOlfwwo3'\n",
"    ! (cd {nw_path}; gdown {url_model})\n",
"else:\n",
"    URL_NW = 'https://nw.tsuda.ac.jp/lec/GoogleColab/pub'\n",
"    url_model = f'{URL_NW}/models/VariationalAutoEncoder.py'\n",
"    ! wget -nd {url_model} -P {nw_path}"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 23,
"status": "ok",
"timestamp": 1637572854068,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "LdJkkfFtfW72",
"outputId": "6e7b56dc-3708-4be6-e8c5-f64d326fb1f8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import os\n",
"import pickle\n",
"import datetime\n",
"\n",
"class Sampling(tf.keras.layers.Layer):\n",
" def __init__(self, **kwargs):\n",
" super().__init__(**kwargs)\n",
"\n",
" def call(self, inputs):\n",
" mu, log_var = inputs\n",
" epsilon = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(mu), mean=0., stddev=1.)\n",
" return mu + tf.keras.backend.exp(log_var / 2) * epsilon\n",
"\n",
"\n",
"class VAEModel(tf.keras.models.Model):\n",
" def __init__(self, encoder, decoder, r_loss_factor, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.encoder = encoder\n",
" self.decoder = decoder\n",
" self.r_loss_factor = r_loss_factor\n",
"\n",
"\n",
" @tf.function\n",
" def loss_fn(self, x):\n",
" z_mean, z_log_var, z = self.encoder(x)\n",
" reconstruction = self.decoder(z)\n",
" reconstruction_loss = tf.reduce_mean(\n",
" tf.square(x - reconstruction), axis=[1,2,3]\n",
" ) * self.r_loss_factor\n",
" kl_loss = tf.reduce_sum(\n",
" 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),\n",
" axis = 1\n",
" ) * (-0.5)\n",
" total_loss = reconstruction_loss + kl_loss\n",
" return total_loss, reconstruction_loss, kl_loss\n",
"\n",
"\n",
" @tf.function\n",
" def compute_loss_and_grads(self, x):\n",
" with tf.GradientTape() as tape:\n",
" total_loss, reconstruction_loss, kl_loss = self.loss_fn(x)\n",
" grads = tape.gradient(total_loss, self.trainable_weights)\n",
" return total_loss, reconstruction_loss, kl_loss, grads\n",
"\n",
"\n",
" def train_step(self, data):\n",
" if isinstance(data, tuple):\n",
" data = data[0]\n",
" total_loss, reconstruction_loss, kl_loss, grads = self.compute_loss_and_grads(data)\n",
" self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
" return {\n",
" \"loss\": tf.math.reduce_mean(total_loss),\n",
" \"reconstruction_loss\": tf.math.reduce_mean(reconstruction_loss),\n",
" \"kl_loss\": tf.math.reduce_mean(kl_loss),\n",
" }\n",
"\n",
" def call(self,inputs):\n",
" _, _, z = self.encoder(inputs)\n",
" return self.decoder(z)\n",
"\n",
"\n",
"class VariationalAutoEncoder():\n",
" def __init__(self, \n",
" input_dim,\n",
" encoder_conv_filters,\n",
" encoder_conv_kernel_size,\n",
" encoder_conv_strides,\n",
" decoder_conv_t_filters,\n",
" decoder_conv_t_kernel_size,\n",
" decoder_conv_t_strides,\n",
" z_dim,\n",
" r_loss_factor, ### added\n",
" use_batch_norm = False,\n",
" use_dropout = False,\n",
" epoch = 0\n",
" ):\n",
" self.name = 'variational_autoencoder'\n",
" self.input_dim = input_dim\n",
" self.encoder_conv_filters = encoder_conv_filters\n",
" self.encoder_conv_kernel_size = encoder_conv_kernel_size\n",
" self.encoder_conv_strides = encoder_conv_strides\n",
" self.decoder_conv_t_filters = decoder_conv_t_filters\n",
" self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size\n",
" self.decoder_conv_t_strides = decoder_conv_t_strides\n",
" self.z_dim = z_dim\n",
" self.r_loss_factor = r_loss_factor ### added\n",
" \n",
" self.use_batch_norm = use_batch_norm\n",
" self.use_dropout = use_dropout\n",
"\n",
" self.epoch = epoch\n",
" \n",
" self.n_layers_encoder = len(encoder_conv_filters)\n",
" self.n_layers_decoder = len(decoder_conv_t_filters)\n",
" \n",
" self._build()\n",
" \n",
"\n",
" def _build(self):\n",
" ### THE ENCODER\n",
" encoder_input = tf.keras.layers.Input(shape=self.input_dim, name='encoder_input')\n",
" x = encoder_input\n",
" \n",
" for i in range(self.n_layers_encoder):\n",
" x = conv_layer = tf.keras.layers.Conv2D(\n",
" filters = self.encoder_conv_filters[i],\n",
" kernel_size = self.encoder_conv_kernel_size[i],\n",
" strides = self.encoder_conv_strides[i],\n",
" padding = 'same',\n",
" name = 'encoder_conv_' + str(i)\n",
" )(x)\n",
"\n",
" if self.use_batch_norm: ### The order of layers is opposite to AutoEncoder\n",
" x = tf.keras.layers.BatchNormalization()(x) ### AE: LeakyReLU -> BatchNorm\n",
" x = tf.keras.layers.LeakyReLU()(x) ### VAE: BatchNorm -> LeakyReLU\n",
" \n",
" if self.use_dropout:\n",
" x = tf.keras.layers.Dropout(rate = 0.25)(x)\n",
" \n",
" shape_before_flattening = tf.keras.backend.int_shape(x)[1:]\n",
" \n",
" x = tf.keras.layers.Flatten()(x)\n",
" \n",
" self.mu = tf.keras.layers.Dense(self.z_dim, name='mu')(x)\n",
" self.log_var = tf.keras.layers.Dense(self.z_dim, name='log_var')(x) \n",
" self.z = Sampling(name='encoder_output')([self.mu, self.log_var])\n",
" \n",
" self.encoder = tf.keras.models.Model(encoder_input, [self.mu, self.log_var, self.z], name='encoder')\n",
" \n",
" \n",
" ### THE DECODER\n",
" decoder_input = tf.keras.layers.Input(shape=(self.z_dim,), name='decoder_input')\n",
" x = decoder_input\n",
" x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(x)\n",
" x = tf.keras.layers.Reshape(shape_before_flattening)(x)\n",
" \n",
" for i in range(self.n_layers_decoder):\n",
" x = conv_t_layer = tf.keras.layers.Conv2DTranspose(\n",
" filters = self.decoder_conv_t_filters[i],\n",
" kernel_size = self.decoder_conv_t_kernel_size[i],\n",
" strides = self.decoder_conv_t_strides[i],\n",
" padding = 'same',\n",
" name = 'decoder_conv_t_' + str(i)\n",
" )(x)\n",
" \n",
" if i < self.n_layers_decoder - 1:\n",
" if self.use_batch_norm: ### The order of layers is opposite to AutoEncoder\n",
" x = tf.keras.layers.BatchNormalization()(x) ### AE: LeakyReLU -> BatchNorm\n",
" x = tf.keras.layers.LeakyReLU()(x) ### VAE: BatchNorm -> LeakyReLU \n",
" if self.use_dropout:\n",
" x = tf.keras.layers.Dropout(rate=0.25)(x)\n",
" else:\n",
" x = tf.keras.layers.Activation('sigmoid')(x)\n",
" \n",
" decoder_output = x\n",
" self.decoder = tf.keras.models.Model(decoder_input, decoder_output, name='decoder') ### added (name)\n",
" \n",
" ### THE FULL AUTOENCODER\n",
" self.model = VAEModel(self.encoder, self.decoder, self.r_loss_factor)\n",
" \n",
" \n",
" def save(self, folder):\n",
" self.save_params(os.path.join(folder, 'params.pkl'))\n",
" self.save_weights(folder)\n",
"\n",
"\n",
" @staticmethod\n",
" def load(folder, epoch=None): # VariationalAutoEncoder.load(folder)\n",
" params = VariationalAutoEncoder.load_params(os.path.join(folder, 'params.pkl'))\n",
" VAE = VariationalAutoEncoder(*params)\n",
" if epoch is None:\n",
" VAE.load_weights(folder)\n",
" else:\n",
" VAE.load_weights(folder, epoch-1)\n",
" VAE.epoch = epoch\n",
" return VAE\n",
"\n",
" \n",
" def save_params(self, filepath):\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" with open(filepath, 'wb') as f:\n",
" pickle.dump([\n",
" self.input_dim,\n",
" self.encoder_conv_filters,\n",
" self.encoder_conv_kernel_size,\n",
" self.encoder_conv_strides,\n",
" self.decoder_conv_t_filters,\n",
" self.decoder_conv_t_kernel_size,\n",
" self.decoder_conv_t_strides,\n",
" self.z_dim,\n",
" self.r_loss_factor,\n",
" self.use_batch_norm,\n",
" self.use_dropout,\n",
" self.epoch\n",
" ], f)\n",
"\n",
"\n",
" @staticmethod\n",
" def load_params(filepath):\n",
" with open(filepath, 'rb') as f:\n",
" params = pickle.load(f)\n",
" return params\n",
"\n",
"\n",
" def save_weights(self, folder, epoch=None):\n",
" if epoch is None:\n",
" self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights.h5'))\n",
" self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights.h5'))\n",
" else:\n",
" self.save_model_weights(self.encoder, os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))\n",
" self.save_model_weights(self.decoder, os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))\n",
"\n",
"\n",
" def save_model_weights(self, model, filepath):\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" model.save_weights(filepath)\n",
"\n",
"\n",
" def load_weights(self, folder, epoch=None):\n",
" if epoch is None:\n",
" self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights.h5'))\n",
" self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights.h5'))\n",
" else:\n",
" self.encoder.load_weights(os.path.join(folder, f'weights/encoder-weights_{epoch}.h5'))\n",
" self.decoder.load_weights(os.path.join(folder, f'weights/decoder-weights_{epoch}.h5'))\n",
"\n",
"\n",
" def save_images(self, imgs, filepath):\n",
" z_mean, z_log_var, z = self.encoder.predict(imgs)\n",
" reconst_imgs = self.decoder.predict(z)\n",
" txts = [ f'{p[0]:.3f}, {p[1]:.3f}' for p in z ]\n",
" AutoEncoder.showImages(imgs, reconst_imgs, txts, 1.4, 1.4, 0.5, filepath)\n",
" \n",
"\n",
" def compile(self, learning_rate):\n",
" self.learning_rate = learning_rate\n",
" optimizer = tf.keras.optimizers.Adam(lr=learning_rate)\n",
" self.model.compile(optimizer=optimizer) # CAUTION!!!: loss(y_true, y_pred) function is not specified.\n",
" \n",
" \n",
" def train_with_fit(\n",
" self,\n",
" x_train,\n",
" batch_size,\n",
" epochs,\n",
" run_folder='run/'\n",
" ):\n",
" history = self.model.fit(\n",
" x_train,\n",
" x_train,\n",
" batch_size = batch_size,\n",
" shuffle=True,\n",
" initial_epoch = self.epoch,\n",
" epochs = epochs\n",
" )\n",
" if (self.epoch < epochs):\n",
" self.epoch = epochs\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
" \n",
" return history\n",
"\n",
"\n",
" def train_generator_with_fit(\n",
" self,\n",
" data_flow,\n",
" epochs,\n",
" run_folder='run/'\n",
" ):\n",
" history = self.model.fit(\n",
" data_flow,\n",
" initial_epoch = self.epoch,\n",
" epochs = epochs\n",
" )\n",
" if (self.epoch < epochs):\n",
" self.epoch = epochs\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
" \n",
" return history\n",
"\n",
"\n",
" def train_tf(\n",
" self,\n",
" x_train,\n",
" batch_size = 32,\n",
" epochs = 10,\n",
" shuffle = False,\n",
" run_folder = 'run/',\n",
" optimizer = None,\n",
" save_epoch_interval = 100,\n",
" validation_data = None\n",
" ):\n",
" start_time = datetime.datetime.now()\n",
" steps = x_train.shape[0] // batch_size\n",
"\n",
" total_losses = []\n",
" reconstruction_losses = []\n",
" kl_losses = []\n",
"\n",
" val_total_losses = []\n",
" val_reconstruction_losses = []\n",
" val_kl_losses = []\n",
"\n",
" for epoch in range(self.epoch, epochs):\n",
" epoch_loss = 0\n",
" indices = tf.range(x_train.shape[0], dtype=tf.int32)\n",
" if shuffle:\n",
" indices = tf.random.shuffle(indices)\n",
" x_ = x_train[indices]\n",
"\n",
" step_total_losses = []\n",
" step_reconstruction_losses = []\n",
" step_kl_losses = []\n",
" for step in range(steps):\n",
" start = batch_size * step\n",
" end = start + batch_size\n",
"\n",
" total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x_[start:end])\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_weights))\n",
" \n",
" step_total_losses.append(np.mean(total_loss))\n",
" step_reconstruction_losses.append(np.mean(reconstruction_loss))\n",
" step_kl_losses.append(np.mean(kl_loss))\n",
" \n",
" epoch_total_loss = np.mean(step_total_losses)\n",
" epoch_reconstruction_loss = np.mean(step_reconstruction_losses)\n",
" epoch_kl_loss = np.mean(step_kl_losses)\n",
"\n",
" total_losses.append(epoch_total_loss)\n",
" reconstruction_losses.append(epoch_reconstruction_loss)\n",
" kl_losses.append(epoch_kl_loss)\n",
"\n",
" val_str = ''\n",
" if not validation_data is None:\n",
" x_val = validation_data\n",
" tl, rl, kl = self.model.loss_fn(x_val)\n",
" val_tl = np.mean(tl)\n",
" val_rl = np.mean(rl)\n",
" val_kl = np.mean(kl)\n",
" val_total_losses.append(val_tl)\n",
" val_reconstruction_losses.append(val_rl)\n",
" val_kl_losses.append(val_kl)\n",
" val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '\n",
"\n",
" if (epoch+1) % save_epoch_interval == 0 and run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch)\n",
"\n",
" elapsed_time = datetime.datetime.now() - start_time\n",
" print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')\n",
"\n",
" self.epoch += 1\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
"\n",
" dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }\n",
" if not validation_data is None:\n",
" dic['val_loss'] = val_total_losses\n",
" dic['val_reconstruction_loss'] = val_reconstruction_losses\n",
" dic['val_kl_loss'] = val_kl_losses\n",
"\n",
" return dic\n",
" \n",
"\n",
" def train_tf_generator(\n",
" self,\n",
" data_flow,\n",
" epochs = 10,\n",
" run_folder = 'run/',\n",
" optimizer = None,\n",
" save_epoch_interval = 100,\n",
" validation_data_flow = None\n",
" ):\n",
" start_time = datetime.datetime.now()\n",
" steps = len(data_flow)\n",
"\n",
" total_losses = []\n",
" reconstruction_losses = []\n",
" kl_losses = []\n",
"\n",
" val_total_losses = []\n",
" val_reconstruction_losses = []\n",
" val_kl_losses = []\n",
"\n",
" for epoch in range(self.epoch, epochs):\n",
" epoch_loss = 0\n",
"\n",
" step_total_losses = []\n",
" step_reconstruction_losses = []\n",
" step_kl_losses = []\n",
"\n",
" for step in range(steps):\n",
" x, _ = next(data_flow)\n",
"\n",
" total_loss, reconstruction_loss, kl_loss, grads = self.model.compute_loss_and_grads(x)\n",
" optimizer.apply_gradients(zip(grads, self.model.trainable_weights))\n",
" \n",
" step_total_losses.append(np.mean(total_loss))\n",
" step_reconstruction_losses.append(np.mean(reconstruction_loss))\n",
" step_kl_losses.append(np.mean(kl_loss))\n",
" \n",
" epoch_total_loss = np.mean(step_total_losses)\n",
" epoch_reconstruction_loss = np.mean(step_reconstruction_losses)\n",
" epoch_kl_loss = np.mean(step_kl_losses)\n",
"\n",
" total_losses.append(epoch_total_loss)\n",
" reconstruction_losses.append(epoch_reconstruction_loss)\n",
" kl_losses.append(epoch_kl_loss)\n",
"\n",
" val_str = ''\n",
" if not validation_data_flow is None:\n",
" step_val_tl = []\n",
" step_val_rl = []\n",
" step_val_kl = []\n",
" for i in range(len(validation_data_flow)):\n",
" x, _ = next(validation_data_flow)\n",
" tl, rl, kl = self.model.loss_fn(x)\n",
" step_val_tl.append(np.mean(tl))\n",
" step_val_rl.append(np.mean(rl))\n",
" step_val_kl.append(np.mean(kl))\n",
" val_tl = np.mean(step_val_tl)\n",
" val_rl = np.mean(step_val_rl)\n",
" val_kl = np.mean(step_val_kl)\n",
" val_total_losses.append(val_tl)\n",
" val_reconstruction_losses.append(val_rl)\n",
" val_kl_losses.append(val_kl)\n",
" val_str = f'val loss total {val_tl:.3f} reconstruction {val_rl:.3f} kl {val_kl:.3f} '\n",
"\n",
" if (epoch+1) % save_epoch_interval == 0 and run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch)\n",
"\n",
" elapsed_time = datetime.datetime.now() - start_time\n",
" print(f'{epoch+1}/{epochs} {steps} loss: total {epoch_total_loss:.3f} reconstruction {epoch_reconstruction_loss:.3f} kl {epoch_kl_loss:.3f} {val_str}{elapsed_time}')\n",
"\n",
" self.epoch += 1\n",
"\n",
" if run_folder != None:\n",
" self.save(run_folder)\n",
" self.save_weights(run_folder, self.epoch-1)\n",
"\n",
" dic = { 'loss' : total_losses, 'reconstruction_loss' : reconstruction_losses, 'kl_loss' : kl_losses }\n",
" if not validation_data_flow is None:\n",
" dic['val_loss'] = val_total_losses\n",
" dic['val_reconstruction_loss'] = val_reconstruction_losses\n",
" dic['val_kl_loss'] = val_kl_losses\n",
"\n",
" return dic\n",
"\n",
"\n",
" @staticmethod\n",
" def showImages(imgs1, imgs2, txts, w, h, vskip=0.5, filepath=None):\n",
" n = len(imgs1)\n",
" fig, ax = plt.subplots(2, n, figsize=(w * n, (2+vskip) * h))\n",
" for i in range(n):\n",
" if n == 1:\n",
" axis = ax[0]\n",
" else:\n",
" axis = ax[0][i]\n",
" img = imgs1[i].squeeze()\n",
" axis.imshow(img, cmap='gray_r')\n",
" axis.axis('off')\n",
"\n",
" axis.text(0.5, -0.35, txts[i], fontsize=10, ha='center', transform=axis.transAxes)\n",
"\n",
" if n == 1:\n",
" axis = ax[1]\n",
" else:\n",
" axis = ax[1][i]\n",
" img2 = imgs2[i].squeeze()\n",
" axis.imshow(img2, cmap='gray_r')\n",
" axis.axis('off')\n",
"\n",
" if not filepath is None:\n",
" dpath, fname = os.path.split(filepath)\n",
" if dpath != '' and not os.path.exists(dpath):\n",
" os.makedirs(dpath)\n",
" fig.savefig(filepath, dpi=600)\n",
" plt.close()\n",
" else:\n",
" plt.show()\n",
"\n",
" @staticmethod\n",
" def plot_history(vals, labels):\n",
" colors = ['red', 'blue', 'green', 'orange', 'black', 'pink']\n",
" n = len(vals)\n",
" fig, ax = plt.subplots(1, 1, figsize=(9,4))\n",
" for i in range(n):\n",
" ax.plot(vals[i], c=colors[i], label=labels[i])\n",
" ax.legend(loc='upper right')\n",
" ax.set_xlabel('epochs')\n",
" # ax[0].set_ylabel('loss')\n",
" \n",
" plt.show()\n"
]
}
],
"source": [
"! cat {nw_path}/VariationalAutoEncoder.py"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A1sT3O6Ofdd2"
},
"source": [
"# Prepare the MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 18,
"status": "ok",
"timestamp": 1637572854068,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "zXeh_DeCfZjD",
"outputId": "8be082bd-3934-4c9c-f2ad-828986c64b35"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7.0\n"
]
}
],
"source": [
"%tensorflow_version 2.x\n",
"\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"\n",
"print(tf.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 1100,
"status": "ok",
"timestamp": 1637572855155,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "yHVkd5c0fkzB",
"outputId": "54d98c13-6bdf-4c2c-9833-18ddcab165fa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
"11493376/11490434 [==============================] - 0s 0us/step\n",
"11501568/11490434 [==============================] - 0s 0us/step\n",
"(60000, 28, 28)\n",
"(60000,)\n",
"(10000, 28, 28)\n",
"(10000,)\n"
]
}
],
"source": [
"# Load MNIST: 28x28 grayscale digit images with one integer label per image.\n",
"(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = tf.keras.datasets.mnist.load_data()\n",
"print(x_train_raw.shape)\n",
"print(y_train_raw.shape)\n",
"print(x_test_raw.shape)\n",
"print(y_test_raw.shape)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"executionInfo": {
"elapsed": 29,
"status": "ok",
"timestamp": 1637572855158,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "pakzbL16iVP3",
"outputId": "b9236c19-c21d-406d-d303-5d49101b2fd8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(60000, 28, 28, 1)\n",
"(10000, 28, 28, 1)\n"
]
}
],
"source": [
"# Add a trailing channel axis and scale pixel values from [0, 255] to [0.0, 1.0].\n",
"x_train = x_train_raw.reshape(x_train_raw.shape+(1,)).astype('float32') / 255.0\n",
"x_test = x_test_raw.reshape(x_test_raw.shape+(1,)).astype('float32') / 255.0\n",
"print(x_train.shape)\n",
"print(x_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
},
"executionInfo": {
"elapsed": 23,
"status": "ok",
"timestamp": 1637572855160,
"user": {
"displayName": "Yoshihisa Nitta",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgJLeg9AmjfexROvC3P0wzJdd5AOGY_VOu-nxnh=s64",
"userId": "15888006800030996813"
},
"user_tz": -540
},
"id": "45ij1IwMfnFH",
"outputId": "ee169e72-f658-40d2-f58e-efd0e46f95d4"
},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"