class VAE(keras.Model):
    """Fully-connected variational autoencoder.

    The encoder maps an input batch through two ReLU Dense layers (128 -> 64
    units) to two linear heads: the mean and the log-variance of a diagonal
    Gaussian over a ``latent_dim``-dimensional latent space. The decoder maps
    a latent sample through two ReLU Dense layers (64 -> 128 units) to a
    linear ``output_dim``-unit reconstruction.

    Args:
        latent_dim: Size of the latent space. When omitted, the module-level
            ``z_dim`` is used, which is what the original code always did.
        output_dim: Units in the reconstruction layer. Defaults to 96, the
            value previously hard-coded. NOTE(review): presumably this must
            equal the number of input features so reconstructions line up
            with the inputs -- confirm against the data.
    """

    def __init__(self, latent_dim=None, output_dim=96):
        super(VAE, self).__init__()
        # Fall back to the module-level ``z_dim`` so ``VAE()`` behaves exactly
        # as before; the global is only looked up when no value is passed.
        if latent_dim is None:
            latent_dim = z_dim
        # Encoder
        self.fc1 = layers.Dense(128)
        self.fc2 = layers.Dense(64)
        self.fc3 = layers.Dense(latent_dim)  # mean head (linear)
        self.fc4 = layers.Dense(latent_dim)  # log-variance head (linear)
        # Decoder
        self.fc5 = layers.Dense(64)
        self.fc6 = layers.Dense(128)
        self.fc7 = layers.Dense(output_dim)  # reconstruction (linear)

    def encoder(self, x):
        """Encode ``x`` into the parameters of the approximate posterior.

        Args:
            x: Input batch; features are on the last axis.

        Returns:
            ``(mu, log_var)``: the mean and the *log* variance, each with
            ``latent_dim`` features on the last axis.
        """
        h = tf.nn.relu(self.fc1(x))
        h = tf.nn.relu(self.fc2(h))
        mu = self.fc3(h)
        # The head predicts log(sigma^2), not sigma^2: a linear layer can emit
        # any real number, and reparameterize() exponentiates it to get a
        # strictly positive standard deviation.
        log_var = self.fc4(h)
        return mu, log_var

    def decoder(self, z):
        """Decode a latent batch ``z`` into a reconstruction.

        Returns:
            The raw (linear, no activation) output of the final Dense layer.
        """
        out = tf.nn.relu(self.fc5(z))
        out = tf.nn.relu(self.fc6(out))
        return self.fc7(out)

    def reparameterize(self, mu, log_var):
        """Sample ``z ~ N(mu, exp(log_var))`` via the reparameterization trick.

        Drawing ``eps ~ N(0, I)`` and returning ``mu + std * eps`` keeps the
        sample differentiable with respect to ``mu`` and ``log_var``.
        """
        # Use the *dynamic* shape: the static ``log_var.shape`` has ``None``
        # for the batch axis when this is traced in graph mode (tf.function,
        # model.fit), and tf.random.normal cannot build a tensor from that.
        # Matching the dtype keeps this working under mixed precision.
        eps = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)
        std = tf.exp(log_var * 0.5)
        return mu + std * eps

    def call(self, inputs, training=None):
        """Run a full encode -> sample -> decode pass.

        A fresh latent sample is drawn on every call; ``training`` is accepted
        for the Keras ``call`` signature but is not consulted.

        Args:
            inputs: Input batch; the original comment describes it as
                ``[b, 96]``.
            training: Unused.

        Returns:
            ``(x_hat, mu, log_var)``: the reconstruction plus the posterior
            mean and log-variance produced by the encoder.
        """
        # [b, features] => [b, latent_dim], [b, latent_dim]
        mu, log_var = self.encoder(inputs)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decoder(z)
        return x_hat, mu, log_var
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
模型调用：将对应的输入数据传入模型
# Build the model and run one forward pass over the input data.
# NOTE(review): assumes ``data`` is a pandas DataFrame whose columns are the
# model's numeric input features -- confirm against where it is loaded.
vae = VAE()
# to_numpy() is pandas' recommended replacement for ``.values``. Requesting
# float32 matches Keras' default float dtype and raises on non-numeric columns
# instead of silently yielding an object array.
# NOTE: ``var`` receives the *log*-variance returned by VAE.call, not the
# variance itself.
xhat, mu1, var = vae(data.to_numpy(dtype="float32"))
- 1
- 2