Tensorflow的妙用​

简介: Tensorflow的妙用​

向大家推荐一个 TensorFlow 工具———TensorFlow Hub,它包含各种预训练模型的综合代码库,这些模型稍作调整便可部署到任何设备上。只需几行代码即可重复使用经过训练的模型,例如 BERT 和 Faster R-CNN,实现这些些牛X的应用,简直和把大象装进冰箱一样简单。


640.png


第一步:安装 TensorFlow Hub


Tensorflow_hub 库可与  TensorFlow 一起安装(建议直接上TF2)


pip install "tensorflow>=2.0.0"
pip install --upgrade tensorflow-hub


使用时


import tensorflow as tf
import tensorflow_hub as hub


第二步:从 TF Hub 下载模型


TensorFlow Hub 在 hub.tensorflow.google.cn 中提供了一个开放的训练模型存储库。tensorflow_hub 库可以从这个存储库和其他基于 HTTP 的机器学习模型存储库中加载模型。


640.png


从 下载并解压缩模型后,tensorflow_hub 库会将这些模型缓存到文件系统上。下载位置默认为本地临时目录,但可以通过设置环境变量 TFHUB_CACHE_DIR(推荐)或传递命令行标记 --tfhub_cache_dir 进行自定义。


os.environ['TFHUB_CACHE_DIR'] = '/home/user/workspace/tf_cache'


值得注意的是,TensorFlow Hub Module仅为我们提供了包含模型体系结构的图形以及在某些数据集上训练的权重。大多数模块允许访问模型的内部层,可以根据不同的用例使用。但是,有些模块不能精细调整。在开始开发之前,建议在TensorFlow Hub网站中查看有关该模块的说明。

以目标检测为例:打开网站,动几下鼠标即可
https://hub.tensorflow.google.cn/


640.png


640.png


640.png

拿来直接用


640.png


module_handle = "https://hub.tensorflow.google.cn/google/faster_rcnn/openimages_v4/inception_resnet_v2/1" 
detector = hub.load(module_handle).signatures['default']
def load_img(path):
  img = tf.io.read_file(path)
  img = tf.image.decode_jpeg(img, channels=3)
  return img
def run_detector(detector, path):
  img = load_img(path)
  converted_img  = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]
  start_time = time.time()
  result = detector(converted_img)
  end_time = time.time()
  result = {key:value.numpy() for key,value in result.items()}
  print("Found %d objects." % len(result["detection_scores"]))
  print("Inference time: ", end_time-start_time)
  image_with_boxes = draw_boxes(
      img.numpy(), result["detection_boxes"],
      result["detection_class_entities"], result["detection_scores"])
  display_image(image_with_boxes)
run_detector(detector, downloaded_image_path)


无需重复训练,拿来即用!6不6?

640.png


目录
打赏
0
0
1
0
13
分享
相关文章
【Pytorch写代码技巧--Einsum】Einsum详解+常用写法
不知大家在看论文代码的时候是否会常常看见 torch.einsum(),这玩意儿看起来是真的抽象,但是深入了解后发现它原来这么好用。
704 0
Tensorflow源码解析3 -- TensorFlow核心对象 - Graph
# 1 Graph概述 计算图Graph是TensorFlow的核心对象,TensorFlow的运行流程基本都是围绕它进行的。包括图的构建、传递、剪枝、按worker分裂、按设备二次分裂、执行、注销等。因此理解计算图Graph对掌握TensorFlow运行尤为关键。 # 2 默认Graph ### 默认图替换 之前讲解Session的时候就说过,一个Session只能r
3864 0