2022年 11月 5日

Python 类的调用

class VAE(keras.Model):
    """Fully connected variational autoencoder (VAE).

    The encoder maps an input through Dense(128) -> Dense(64) with ReLU
    activations and then through two parallel linear heads of width
    ``z_dim`` that produce the mean and the log-variance of the latent
    distribution.  The decoder maps a latent sample through
    Dense(64) -> Dense(128) with ReLU activations and a final linear
    Dense(96) layer.

    ``z_dim`` is not a constructor argument; it is read from the enclosing
    (module-level) scope and must be defined before the model is built.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.fc1 = layers.Dense(128)
        self.fc2 = layers.Dense(64)
        self.fc3 = layers.Dense(z_dim)  # head producing the latent mean
        self.fc4 = layers.Dense(z_dim)  # head producing the latent log-variance

        # Decoder
        self.fc5 = layers.Dense(64)
        self.fc6 = layers.Dense(128)
        self.fc7 = layers.Dense(96)  # linear output layer, 96 features

    def encoder(self, x):
        """Encode ``x`` into the parameters of the latent distribution.

        Args:
            x: Input tensor of shape ``[b, features]``.

        Returns:
            A tuple ``(mu, log_var)``, each of shape ``[b, z_dim]``.
        """
        h = tf.nn.relu(self.fc1(x))
        h = tf.nn.relu(self.fc2(h))
        # Mean of the latent distribution.
        mu = self.fc3(h)
        # Log-variance of the latent distribution (see ``reparameterize``,
        # which turns it into a standard deviation via exp(0.5 * log_var)).
        log_var = self.fc4(h)

        return mu, log_var

    def decoder(self, z):
        """Decode a latent sample ``z`` of shape ``[b, z_dim]`` to ``[b, 96]``.

        No activation is applied to the final layer, so the output is
        unbounded.
        """
        out = tf.nn.relu(self.fc5(z))
        out = tf.nn.relu(self.fc6(out))
        out = self.fc7(out)

        return out

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, 1)``.

        Args:
            mu: Latent mean, shape ``[b, z_dim]``.
            log_var: Latent log-variance, same shape as ``mu``.

        Returns:
            A latent sample with the same shape as ``mu``.
        """
        # Use the dynamic shape: the static ``log_var.shape`` contains
        # ``None`` for the batch dimension when the model is traced with an
        # unknown batch size (tf.function / model.fit / symbolic inputs),
        # and ``tf.random.normal`` cannot consume a partially known shape.
        # The noise dtype follows ``log_var`` so non-float32 models work.
        eps = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)
        std = tf.exp(log_var * 0.5)
        z = mu + std * eps

        return z

    def call(self, inputs, training=None):
        """Run the full encode -> sample -> decode pass.

        Args:
            inputs: Input tensor of shape ``[b, features]``.
            training: Accepted for Keras API compatibility; not used here.

        Returns:
            A tuple ``(x_hat, mu, log_var)``: the reconstruction of shape
            ``[b, 96]`` and the latent mean and log-variance, each of shape
            ``[b, z_dim]``.
        """
        # [b, features] => [b, z_dim], [b, z_dim]
        mu, log_var = self.encoder(inputs)
        # Reparameterization trick: sample z while keeping the path from
        # mu / log_var differentiable.
        z = self.reparameterize(mu, log_var)
        x_hat = self.decoder(z)

        return x_hat, mu, log_var
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

模型调用及对应的输入（input）

# Build the model; VAE takes no constructor arguments.
vae=VAE()
# Forward pass. `data` is defined elsewhere -- presumably a pandas DataFrame
# whose `.values` is a [b, 96] numeric array (TODO confirm against the caller).
# VAE.call returns (x_hat, mu, log_var), so despite its name `var` holds the
# log-variance, not the variance.
xhat,mu1,var=vae(data.values)
  • 1
  • 2