1949啦网--小小 痛苦,是因为能力和欲望不匹配造成的

TensorFlow创建自定义类继承tf.layers.Layer创建新的layer层,自定义类继承keras.Model创建自定义model

tf.layers.Layer类是tf.layers里所有层都继承的基类,实现了通用的基础功能。用户只需要实例化它,就可以直接调用得到的实例。

Layer的子类一般这样子实现:

__init__():先初始化父类。然后在成员变量中保存配置。

build():一般用于初始化层内的参数和变量。在调用call()方法前,类会自动调用该方法。在该方法末尾需要设置self.built = True,保证build()方法只被调用一次。

call():用于定义层对输入张量的实际操作。

下面是我们自定义一个全连接层的例子。(self.add_weight的参数name一定要定义,否则model.save_weights("./weight/07/07.weight")会报错,我错误找了好久)

class MyDense(keras.layers.Layer):      def __init__(self, outdim):          super().__init__()          self.outdim = outdim                def build(self, input_shape):          self.indim = int(input_shape[-1])                        self.kernel = self.add_weight(              name="w",               shape=[self.indim, self.outdim],               dtype=tf.float32,               initializer=tf.random_normal_initializer()          )          self.built = True                 def call(self, inputs):          inputs = tf.cast(inputs, dtype=tf.float32)          return inputs@self.kernel
class MyModel(keras.Model):      def __init__(self):          super().__init__()          self.f1 = MyDense(256)          self.f2 = MyDense(256)          self.f3 = MyDense(128)          self.f4 = MyDense(32)          self.f5 = MyDense(10)      def call(self, inputs):          inputs = tf.reshape(inputs, [-1, 32*32*3])          out = self.f1(inputs)          out = tf.nn.relu(out)          out = self.f2(out)          out = tf.nn.relu(out)          out = self.f3(out)          out = tf.nn.relu(out)          out = self.f4(out)          out = tf.nn.relu(out)          out = self.f5(out)                    self.out = out          return out
model = MyModel()  model.build([None, 32*32*3])  model.summary()
model.compile(      optimizer=keras.optimizers.Adam(1e-3),      loss=keras.losses.CategoricalCrossentropy(from_logits=True),      metrics=["accuracy"]  )
model.fit(db, epochs=5, validation_data=db_test)

版权声明:本文为期权记的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

原文链接:https://www.qiquanji.com/post/9673.html

微信扫码关注

更新实时通知

作者:xialibing 分类:编程小记 浏览: