TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了。tf.keras 从 1.x 版本迁移到 2.0 版本,需要修改几个地方。

1. 设置随机种子

import tensorflow as tf

# TF 1.x
tf.set_random_seed(args.seed)
# TF 2.0
tf.random.set_seed(args.seed)

2. 设置并行线程数和动态分配显存

import tensorflow as tf
from tensorflow.python.keras import backend as K

# TF 1.x
config = tf.ConfigProto(intra_op_parallelism_threads=1,
                         inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True  # 不全部占满显存, 按需分配
K.set_session(tf.Session(config=config))

# TF 2.0
config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1,
                                  inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True  # 不全部占满显存, 按需分配
K.set_session(tf.compat.v1.Session(config=config))

3. model.fit() 生成的 log 中,acc 改名 accuracy,val_acc 改名 val_accuracy。故在 callbacks.ModelCheckpoint 中需要做修改:

from tensorflow.python.keras import callbacks

# TF 1.x
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_acc', mode='max',
                                            verbose=1, save_best_only=True, save_weights_only=True)

# TF 2.0
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_accuracy', mode='max',
                                            verbose=1, save_best_only=True, save_weights_only=True)
内容来源于网络如有侵权请私信删除

文章来源: 博客园

原文链接: https://www.cnblogs.com/wuliytTaotao/p/12016656.html

你还没有登录,请先登录注册
  • 还没有人评论,欢迎说说您的想法!