DL之MaskR-CNN:基于类MaskR-CNN算法(RetinaNet+mask head)训练自己的数据集(.h5文件)从而实现图像分割daiding

本文涉及的产品
视觉智能开放平台,视频资源包5000点
视觉智能开放平台,图像资源包5000点
视觉智能开放平台,分割抠图1万点
简介: DL之MaskR-CNN:基于类MaskR-CNN算法(RetinaNet+mask head)训练自己的数据集(.h5文件)从而实现图像分割daiding

输出结果

image.png

image.png

设计思路

https://yunyaniu.blog.csdn.net/article/details/80330637


核心代码

1、train.py

#!/usr/bin/env python

"""

Copyright 2017-2018 Fizyr (https://fizyr.com)

Licensed under the Apache License, Version 2.0 (the "License");

you may not use this file except in compliance with the License.

You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software

distributed under the License is distributed on an "AS IS" BASIS,

WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

See the License for the specific language governing permissions and

limitations under the License.

"""

import argparse

import os

import sys

import keras

import keras.preprocessing.image

import tensorflow as tf

import keras_retinanet.losses

from keras_retinanet.callbacks import RedirectModel

from keras_retinanet.utils.config import read_config_file, parse_anchor_parameters

from keras_retinanet.utils.transform import random_transform_generator

from keras_retinanet.utils.keras_version import check_keras_version

from keras_retinanet.utils.model import freeze as freeze_model

# Allow relative imports when being executed as script.

if __name__ == "__main__" and __package__ is None:

   sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))

   import keras_maskrcnn.bin

   __package__ = "keras_maskrcnn.bin"

# Change these to absolute imports if you copy this script outside the keras_retinanet package.

from .. import losses

from .. import models

from ..callbacks.eval import Evaluate

def get_session():

   config = tf.ConfigProto()

   config.gpu_options.allow_growth = True

   return tf.Session(config=config)

def model_with_weights(model, weights, skip_mismatch):

   if weights is not None:

       model.load_weights(weights, by_name=True, skip_mismatch=skip_mismatch)

   return model

def create_models(backbone_retinanet, num_classes, weights, freeze_backbone=False, class_specific_filter=True, anchor_params=None):

   modifier = freeze_model if freeze_backbone else None

   model            = model_with_weights(

       backbone_retinanet(

           num_classes,

           nms=True,

           class_specific_filter=class_specific_filter,

           modifier=modifier,

           anchor_params=anchor_params

       ), weights=weights, skip_mismatch=True)

   training_model   = model

   prediction_model = model

   # compile model

   training_model.compile(

       loss={

           'regression'    : keras_retinanet.losses.smooth_l1(),

           'classification': keras_retinanet.losses.focal(),

           'masks'         : losses.mask(),

       },

       optimizer=keras.optimizers.adam(lr=1e-5, clipnorm=0.001)

   )

   return model, training_model, prediction_model

def create_callbacks(model, training_model, prediction_model, validation_generator, args):

   callbacks = []

   # save the prediction model

   if args.snapshots:

       # ensure directory created first; otherwise h5py will error after epoch.

       os.makedirs(args.snapshot_path, exist_ok=True)

       checkpoint = keras.callbacks.ModelCheckpoint(

           os.path.join(

               args.snapshot_path,

               '{backbone}_{dataset_type}_{{epoch:02d}}.h5'.format(backbone=args.backbone, dataset_type=args.dataset_type)

           ),

           verbose=1

       )

       checkpoint = RedirectModel(checkpoint, prediction_model)

       callbacks.append(checkpoint)

   tensorboard_callback = None

   if args.tensorboard_dir:

       tensorboard_callback = keras.callbacks.TensorBoard(

           log_dir                = args.tensorboard_dir,

           histogram_freq         = 0,

           batch_size             = args.batch_size,

           write_graph            = True,

           write_grads            = False,

           write_images           = False,

           embeddings_freq        = 0,

           embeddings_layer_names = None,

           embeddings_metadata    = None

       )

       callbacks.append(tensorboard_callback)

   if args.evaluation and validation_generator:

       if args.dataset_type == 'coco':

           from ..callbacks.coco import CocoEval

           # use prediction model for evaluation

           evaluation = CocoEval(validation_generator)

       else:

           evaluation = Evaluate(validation_generator, tensorboard=tensorboard_callback, weighted_average=args.weighted_average)

       evaluation = RedirectModel(evaluation, prediction_model)

       callbacks.append(evaluation)

   callbacks.append(keras.callbacks.ReduceLROnPlateau(

       monitor  = 'loss',

       factor   = 0.1,

       patience = 2,

       verbose  = 1,

       mode     = 'auto',

       epsilon  = 0.0001,

       cooldown = 0,

       min_lr   = 0

   ))

   return callbacks

def create_generators(args):

   # create random transform generator for augmenting training data

   transform_generator = random_transform_generator(flip_x_chance=0.5)

   if args.dataset_type == 'coco':

       # import here to prevent unnecessary dependency on cocoapi

       from ..preprocessing.coco import CocoGenerator

       train_generator = CocoGenerator(

           args.coco_path,

           'train2017',

           transform_generator=transform_generator,

           batch_size=args.batch_size,

           config=args.config

       )

       validation_generator = CocoGenerator(

           args.coco_path,

           'val2017',

           batch_size=args.batch_size,

           config=args.config

       )

   elif args.dataset_type == 'csv':

       from ..preprocessing.csv_generator import CSVGenerator

       train_generator = CSVGenerator(

           args.annotations,

           args.classes,

           transform_generator=transform_generator,

           batch_size=args.batch_size,

           config=args.config

       )

       if args.val_annotations:

           validation_generator = CSVGenerator(

               args.val_annotations,

               args.classes,

               batch_size=args.batch_size,

               config=args.config

           )

       else:

           validation_generator = None

   else:

       raise ValueError('Invalid data type received: {}'.format(args.dataset_type))

   return train_generator, validation_generator

def check_args(parsed_args):

   """

   Function to check for inherent contradictions within parsed arguments.

   For example, batch_size < num_gpus

   Intended to raise errors prior to backend initialisation.

   :param parsed_args: parser.parse_args()

   :return: parsed_args

   """

   return parsed_args

def parse_args(args):

   parser     = argparse.ArgumentParser(description='Simple training script for training a RetinaNet mask network.')

   subparsers = parser.add_subparsers(help='Arguments for specific dataset types.', dest='dataset_type')

   subparsers.required = True

   coco_parser = subparsers.add_parser('coco')

   coco_parser.add_argument('coco_path', help='Path to dataset directory (ie. /tmp/COCO).')

   csv_parser = subparsers.add_parser('csv')

   csv_parser.add_argument('annotations', help='Path to CSV file containing annotations for training.')

   csv_parser.add_argument('classes', help='Path to a CSV file containing class label mapping.')

   csv_parser.add_argument('--val-annotations', help='Path to CSV file containing annotations for validation (optional).')

   group = parser.add_mutually_exclusive_group()

   group.add_argument('--snapshot',          help='Resume training from a snapshot.')

   group.add_argument('--imagenet-weights',  help='Initialize the model with pretrained imagenet weights. This is the default behaviour.', action='store_const', const=True, default=True)

   group.add_argument('--weights',           help='Initialize the model with weights from a file.')

   group.add_argument('--no-weights',        help='Don\'t initialize the model with any weights.', dest='imagenet_weights', action='store_const', const=False)

   parser.add_argument('--backbone',         help='Backbone model used by retinanet.', default='resnet50', type=str)

   parser.add_argument('--batch-size',       help='Size of the batches.', default=1, type=int)

   parser.add_argument('--gpu',              help='Id of the GPU to use (as reported by nvidia-smi).')

   parser.add_argument('--epochs',           help='Number of epochs to train.', type=int, default=50)

   parser.add_argument('--steps',            help='Number of steps per epoch.', type=int, default=10000)

   parser.add_argument('--snapshot-path',    help='Path to store snapshots of models during training (defaults to \'./snapshots\')', default='./snapshots')

   parser.add_argument('--tensorboard-dir',  help='Log directory for Tensorboard output', default='./logs')

   parser.add_argument('--no-snapshots',     help='Disable saving snapshots.', dest='snapshots', action='store_false')

   parser.add_argument('--no-evaluation',    help='Disable per epoch evaluation.', dest='evaluation', action='store_false')

   parser.add_argument('--freeze-backbone',  help='Freeze training of backbone layers.', action='store_true')

   parser.add_argument('--no-class-specific-filter', help='Disables class specific filtering.', dest='class_specific_filter', action='store_false')

   parser.add_argument('--config',           help='Path to a configuration parameters .ini file.')

   parser.add_argument('--weighted-average', help='Compute the mAP using the weighted average of precisions among classes.', action='store_true')

   return check_args(parser.parse_args(args))

def main(args=None):

   # parse arguments

   if args is None:

       args = sys.argv[1:]

   args = parse_args(args)

   # make sure keras is the minimum required version

   check_keras_version()

   # create object that stores backbone information

   backbone = models.backbone(args.backbone)

   # optionally choose specific GPU

   if args.gpu:

       os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

   keras.backend.tensorflow_backend.set_session(get_session())

   # optionally load config parameters

   if args.config:

       args.config = read_config_file(args.config)

   # create the generators

   train_generator, validation_generator = create_generators(args)

   # create the model

   if args.snapshot is not None:

       print('Loading model, this may take a second...')

       model            = models.load_model(args.snapshot, backbone_name=args.backbone)

       training_model   = model

       prediction_model = model

   else:

       weights = args.weights

       # default to imagenet if nothing else is specified

       if weights is None and args.imagenet_weights:

           weights = backbone.download_imagenet()

       anchor_params = None

       if args.config and 'anchor_parameters' in args.config:

           anchor_params = parse_anchor_parameters(args.config)

       print('Creating model, this may take a second...')

       model, training_model, prediction_model = create_models(

           backbone_retinanet=backbone.maskrcnn,

           num_classes=train_generator.num_classes(),

           weights=weights,

           freeze_backbone=args.freeze_backbone,

           class_specific_filter=args.class_specific_filter,

           anchor_params=anchor_params

       )

   # print model summary

   print(model.summary())

   # create the callbacks

   callbacks = create_callbacks(

       model,

       training_model,

       prediction_model,

       validation_generator,

       args,

   )

   # start training

   training_model.fit_generator(

       generator=train_generator,

       steps_per_epoch=args.steps,

       epochs=args.epochs,

       verbose=1,

       callbacks=callbacks,

       max_queue_size=1,

   )

if __name__ == '__main__':

   main()


相关文章
|
19天前
|
算法 数据挖掘 数据安全/隐私保护
基于FCM模糊聚类算法的图像分割matlab仿真
本项目展示了基于模糊C均值(FCM)算法的图像分割技术。算法运行效果良好,无水印。使用MATLAB 2022a开发,提供完整代码及中文注释,附带操作步骤视频。FCM算法通过隶属度矩阵和聚类中心矩阵实现图像分割,适用于灰度和彩色图像,广泛应用于医学影像、遥感图像等领域。
|
2月前
|
存储 机器学习/深度学习 算法
蓝桥杯练习题(三):Python组之算法训练提高综合五十题
蓝桥杯Python编程练习题的集合,涵盖了从基础到提高的多个算法题目及其解答。
66 3
蓝桥杯练习题(三):Python组之算法训练提高综合五十题
|
27天前
|
分布式计算 Java 开发工具
阿里云MaxCompute-XGBoost on Spark 极限梯度提升算法的分布式训练与模型持久化oss的实现与代码浅析
本文介绍了XGBoost在MaxCompute+OSS架构下模型持久化遇到的问题及其解决方案。首先简要介绍了XGBoost的特点和应用场景,随后详细描述了客户在将XGBoost on Spark任务从HDFS迁移到OSS时遇到的异常情况。通过分析异常堆栈和源代码,发现使用的`nativeBooster.saveModel`方法不支持OSS路径,而使用`write.overwrite().save`方法则能成功保存模型。最后提供了完整的Scala代码示例、Maven配置和提交命令,帮助用户顺利迁移模型存储路径。
|
2月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化CNN-LSTM网络的数据分类识别算法matlab仿真
本项目展示了基于贝叶斯优化(BO)的CNN-LSTM网络在数据分类中的应用。通过MATLAB 2022a实现,优化前后效果对比明显。核心代码附带中文注释和操作视频,涵盖BO、CNN、LSTM理论,特别是BO优化CNN-LSTM网络的batchsize和学习率,显著提升模型性能。
|
2月前
|
机器学习/深度学习 算法 决策智能
【机器学习】揭秘深度学习优化算法:加速训练与提升性能
【机器学习】揭秘深度学习优化算法:加速训练与提升性能
|
2月前
|
算法 Java C++
【贪心算法】算法训练 ALGO-1003 礼物(C/C++)
【贪心算法】算法训练 ALGO-1003 礼物(C/C++)
【贪心算法】算法训练 ALGO-1003 礼物(C/C++)
|
2月前
|
机器学习/深度学习 算法 数据安全/隐私保护
基于贝叶斯优化卷积神经网络(Bayes-CNN)的多因子数据分类识别算法matlab仿真
本项目展示了贝叶斯优化在CNN中的应用,包括优化过程、训练与识别效果对比,以及标准CNN的识别结果。使用Matlab2022a开发,提供完整代码及视频教程。贝叶斯优化通过构建代理模型指导超参数优化,显著提升模型性能,适用于复杂数据分类任务。
|
2月前
|
算法 C++
蓝桥 算法训练 共线(C++)
蓝桥 算法训练 共线(C++)
|
3月前
|
机器学习/深度学习 数据采集 数据可视化
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
本文详细介绍如何使用PyTorch构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行图像分类。从数据预处理、模型定义到训练过程及结果可视化,文章全面展示了深度学习项目的全流程。通过实际操作,读者可以深入了解CNN在图像分类任务中的应用,并掌握PyTorch的基本使用方法。希望本文为您的深度学习项目提供有价值的参考与启示。
|
4月前
|
机器学习/深度学习
CNN网络编译和训练
【8月更文挑战第10天】CNN网络编译和训练。
93 20

热门文章

最新文章