Finetuning with Tensorflow

很多网络的特征提取部分都会用到fine-tunning,比如resnet-50,inception等,该文章以AlexNet为例,分析tensorflow如何进行微调

finetuning的三要素:

  • 预训练模型,如resnet_v2_50.npy、resnet_v2_50.ckpt等。
  • 模型的网络结构定义。
  • 所需从预训练模型中恢复的变量,通常以排除的方式给出。

Tips:

  • 所定义网络结构中的变量名需要和预训练模型中的变量名保持相同。预训练模型中的变量名可有以下方式查看:

    1
    2
    3
    4
    5
    6
    7
    # 使用tf.train.NewCheckpointReader直接读取ckpt文件里的变量
    from tensorflow.python import pywrap_tensorflow

    reader = pywarp_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
    print("tensor_name: ", key)
  • 恢复部分变量时可借助slim。如:

    1
    2
    3
    variables_to_restore = slim.get_varibales_to_restore(exclude=[args.resnet_model] + "logits", "optimizer_vars",  "DeepLab_v3/ASPP_layer", "DeepLab_v3/logits"])
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, "./resnet/checkpoints/" + args.resnet_model + "./ckpt")