第三十四节,目标检测之谷歌Object Detection API源码解析

时间:2018-07-05 12:15:49   收藏:0   阅读:2745

在《第三十二节,使用谷歌Object Detection API进行目标检测、训练新的模型(使用VOC 2012数据集)》中,我们介绍了如何使用谷歌Object Detection API进行目标检测,以及如何使用谷歌提供的目标检测模型训练自己的数据。在训练自己的数据集时,主要包括以下几步:

技术分享图片

在这里我主要解析一下train.py文件的工作流程。

一 train.py文件解析

先附上源码:

技术分享图片
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Training executable for detection models.

This executable is used to train DetectionModels. There are two ways of
configuring the training job:

1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
can be specified by --pipeline_config_path.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --pipeline_config_path=pipeline_config.pbtxt

2) Three configuration files can be provided: a model_pb2.DetectionModel
configuration file to define what type of DetectionModel is being trained, an
input_reader_pb2.InputReader file to specify what training data will be used and
a train_pb2.TrainConfig file to configure training parameters.

Example usage:
    ./train \
        --logtostderr \
        --train_dir=path/to/train_dir \
        --model_config_path=model_config.pbtxt \
        --train_config_path=train_config.pbtxt \
        --input_config_path=train_input_config.pbtxt
"""

import functools
import json
import os
import tensorflow as tf

from object_detection import trainer
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util
from object_detection.utils import dataset_util

tf.logging.set_verbosity(tf.logging.INFO)

# Command-line flags. Each DEFINE_* call takes (name, default, help text);
# all three string arguments must be quoted string literals.
flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU.  Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')

# Option 1: a single pipeline config file.
flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

# Option 2: three separate config files (used only when
# pipeline_config_path is empty).
flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


def main(_):
  """Loads the training configuration and launches `trainer.train`.

  Steps performed:
    1. Read the configs, either from the single pipeline file given by
       --pipeline_config_path or from the three separate model / train /
       input config files, and (on task 0 only) copy the config file(s)
       into --train_dir.
    2. Build the model factory and the input-dict factory.
    3. Derive the cluster / task layout from the TF_CONFIG environment
       variable, starting a `tf.train.Server` when parameter servers are
       configured.
    4. Call `trainer.train` with the collected settings.

  Args:
    _: Unused positional argument.

  Raises:
    AssertionError: If --train_dir is empty.
    ValueError: If more than one worker replica is configured but no
      parameter-server task is present in the cluster description.
  """
  assert FLAGS.train_dir, '`train_dir` is missing.'
  # Only task 0 creates the output directory and copies config files into it.
  if FLAGS.task == 0:
    tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  # Factory that builds the detection model in training mode when called.
  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  def get_next(config):
    """Builds the dataset for `config` and returns its next-element op."""
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)

  # TF_CONFIG is a JSON document; default to an empty JSON object so that
  # json.loads always receives a string.
  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  # Ad-hoc class so the task fields can be read as attributes
  # (task_info.type / task_info.index).
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # Number of total worker replicas include "worker"s and the "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      # Parameter-server tasks only serve; they never run the training loop.
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target

  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=True)

  trainer.train(
      create_input_dict_fn,
      model_fn,
      train_config,
      master,
      task,
      FLAGS.num_clones,
      worker_replicas,
      FLAGS.clone_on_cpu,
      ps_tasks,
      worker_job_name,
      is_chief,
      FLAGS.train_dir,
      graph_hook_fn=graph_rewriter_fn)


# Run only when executed as a script; tf.app.run() is called with no
# arguments here.
if __name__ == '__main__':
  tf.app.run()
View Code

1、先定义了tf.app.flags,用于支持接受命令行传递参数,相当于接受argv。

# Command-line flags. Each DEFINE_* call takes (name, default, help text);
# all three string arguments must be quoted string literals.
flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU.  Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')

flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')

flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS

这里面有几个比较重要的参数:train_dir目录用于保存训练的模型和日志文件;pipeline_config_path用于指定pipeline_pb2.TrainEvalPipelineConfig配置文件的全路径(如果不指定这个参数,则需要指定train_config_path、input_config_path、model_config_path三个配置文件,其实这三个文件就是把pipeline_pb2.TrainEvalPipelineConfig配置文件分成了三部分)。

2、再来看一下main函数,我们把它分成几部分来解读。

假设我们在控制台下的命令如下:

python train.py --train_dir voc/train_dir/ --pipeline_config_path voc/faster_rcnn_inception_resnet_v2_atrous_voc.config

 

  # Excerpt from main(): validate --train_dir, then load the configs.
  assert FLAGS.train_dir, '`train_dir` is missing.'
  # Only task 0 creates the output directory and copies config files into it.
  if FLAGS.task == 0:
    tf.gfile.MakeDirs(FLAGS.train_dir)
  if FLAGS.pipeline_config_path:
    # Single pipeline file: parse it and keep a copy next to the checkpoints.
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    # Three separate files: parse them and copy each one under a fixed name.
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

因为我们传入了train_dir,pipeline_config_path参数,程序执行时会:

  # Excerpt from main(): pick the individual configs out of the loaded dict.
  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']

  # Factory that builds the detection model in training mode when called.
  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  def get_next(config):
    """Builds the dataset for `config` and returns its next-element op."""
    return dataset_util.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  # Factory that produces the input dictionary from the training input config.
  create_input_dict_fn = functools.partial(get_next, input_config)

 

评论(0
© 2014 mamicode.com 版权所有 京ICP备13008772号-2  联系我们:gaon5@hotmail.com
迷上了代码!