很多网络的特征提取部分都会用到fine-tunning,比如resnet-50,inception等,该文章以AlexNet为例,分析tensorflow如何进行微调
finetuning的三要素:
- 预训练模型,如resnet_v2_50.npy、resnet_v2_50.ckpt等。
- 模型的网络结构定义。
- 所需从预训练模型中恢复的变量,通常以排除的方式给出。
Tips:
所定义网络结构中的变量名需要和预训练模型中的变量名保持相同。预训练模型中的变量名可有以下方式查看:
1
2
3
4
5
6
7# 使用tf.train.NewCheckpointReader直接读取ckpt文件里的变量
from tensorflow.python import pywrap_tensorflow
reader = pywarp_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)恢复部分变量时可借助slim。如:
1
2
3variables_to_restore = slim.get_varibales_to_restore(exclude=[args.resnet_model] + "logits", "optimizer_vars", "DeepLab_v3/ASPP_layer", "DeepLab_v3/logits"])
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, "./resnet/checkpoints/" + args.resnet_model + "./ckpt")