Tensorflow源码解析3 -- TensorFlow核心对象 - Graph

本文涉及的产品
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
云解析 DNS,旗舰版 1个月
简介: # 1 Graph概述 计算图Graph是TensorFlow的核心对象,TensorFlow的运行流程基本都是围绕它进行的。包括图的构建、传递、剪枝、按worker分裂、按设备二次分裂、执行、注销等。因此理解计算图Graph对掌握TensorFlow运行尤为关键。 # 2 默认Graph ### 默认图替换 之前讲解Session的时候就说过,一个Session只能r

1 Graph概述

计算图Graph是TensorFlow的核心对象,TensorFlow的运行流程基本都是围绕它进行的。包括图的构建、传递、剪枝、按worker分裂、按设备二次分裂、执行、注销等。因此理解计算图Graph对掌握TensorFlow运行尤为关键。

2 默认Graph

默认图替换

之前讲解Session的时候就说过,一个Session只能run一个Graph,但一个Graph可以运行在多个Session中。常见情况是,session会运行全局唯一的隐式的默认的Graph,operation也是注册到这个Graph中。

也可以显示创建Graph,并调用as_default()使他替换默认Graph。在该上下文管理器中创建的op都会注册到这个graph中。退出上下文管理器后,则恢复原来的默认graph。一般情况下,我们不用显式创建Graph,使用系统创建的那个默认Graph即可。

print tf.get_default_graph()

with tf.Graph().as_default() as g:
    print tf.get_default_graph() is g
    print tf.get_default_graph()

print tf.get_default_graph()

输出如下

<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
True
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>

由此可见,在上下文管理器中,当前线程的默认图被替换了,而退出上下文管理后,则恢复为了原来的默认图。

默认图管理

默认graph和默认session一样,也是线程作用域的。当前线程中,永远都有且仅有一个graph为默认图。TensorFlow同样通过栈来管理线程的默认graph。

@tf_export("Graph")
class Graph(object):
    # 替换线程默认图
    def as_default(self):
        return _default_graph_stack.get_controller(self)
    
    # 栈式管理,push pop
    @tf_contextlib.contextmanager
    def get_controller(self, default):
        try:
          context.context_stack.push(default.building_function, default.as_default)
        finally:
          context.context_stack.pop()

替换默认图采用了堆栈的管理方式,通过push pop操作进行管理。获取默认图的操作如下,通过默认graph栈_default_graph_stack来获取。

@tf_export("get_default_graph")
def get_default_graph():
  return _default_graph_stack.get_default()

下面来看_default_graph_stack的创建

_default_graph_stack = _DefaultGraphStack()
class _DefaultGraphStack(_DefaultStack):  
  def __init__(self):
    # 调用父类来创建
    super(_DefaultGraphStack, self).__init__()
    self._global_default_graph = None
    
class _DefaultStack(threading.local):
  def __init__(self):
    super(_DefaultStack, self).__init__()
    self._enforce_nesting = True
    # 和默认session栈一样,本质上也是一个list
    self.stack = []

_default_graph_stack的创建如上所示,最终和默认session栈一样,本质上也是一个list。

3 前端Graph数据结构

Graph数据结构

理解一个对象,先从它的数据结构开始。我们先来看Python前端中,Graph的数据结构。Graph主要的成员变量是Operation和Tensor。Operation是Graph的节点,它代表了运算算子。Tensor是Graph的边,它代表了运算数据。

@tf_export("Graph")
class Graph(object):
    def __init__(self):
           # 加线程锁,使得注册op时,不会有其他线程注册op到graph中,从而保证共享graph是线程安全的
        self._lock = threading.Lock()
        
        # op相关数据。
        # 为graph的每个op分配一个id,通过id可以快速索引到相关op。故创建了_nodes_by_id字典
        self._nodes_by_id = dict()  # GUARDED_BY(self._lock)
        self._next_id_counter = 0  # GUARDED_BY(self._lock)
        # 同时也可以通过name来快速索引op,故创建了_nodes_by_name字典
        self._nodes_by_name = dict()  # GUARDED_BY(self._lock)
        self._version = 0  # GUARDED_BY(self._lock)
        
        # tensor相关数据。
        # 处理tensor的placeholder
        self._handle_feeders = {}
        # 处理tensor的read操作
        self._handle_readers = {}
        # 处理tensor的move操作
        self._handle_movers = {}
        # 处理tensor的delete操作
        self._handle_deleters = {}

下面看graph如何添加op的,以及保证线程安全的。

  def _add_op(self, op):
    # graph被设置为final后,就是只读的了,不能添加op了。
    self._check_not_finalized()
    
    # 保证共享graph的线程安全
    with self._lock:
      # 将op以id和name分别构建字典,添加到_nodes_by_id和_nodes_by_name字典中,方便后续快速索引
      self._nodes_by_id[op._id] = op
      self._nodes_by_name[op.name] = op
      self._version = max(self._version, op._id)

GraphKeys 图分组

每个Operation节点都有一个特定的标签,从而实现节点的分类。相同标签的节点归为一类,放到同一个Collection中。标签是一个唯一的GraphKey,GraphKey被定义在类GraphKeys中,如下

@tf_export("GraphKeys")
class GraphKeys(object):
    GLOBAL_VARIABLES = "variables"
    QUEUE_RUNNERS = "queue_runners"
    SAVERS = "savers"
    WEIGHTS = "weights"
    BIASES = "biases"
    ACTIVATIONS = "activations"
    UPDATE_OPS = "update_ops"
    LOSSES = "losses"
    TRAIN_OP = "train_op"
    # 省略其他

name_scope 节点命名空间

使用name_scope对graph中的节点进行层次化管理,上下层之间通过斜杠分隔。

# graph节点命名空间
g = tf.get_default_graph()
with g.name_scope("scope1"):
    c = tf.constant("hello, world", name="c")
    print c.op.name

    with g.name_scope("scope2"):
        c = tf.constant("hello, world", name="c")
        print c.op.name

输出如下

scope1/c
scope1/scope2/c  # 内层的scope会继承外层的,类似于栈,形成层次化管理


4 后端Graph数据结构

Graph

先来看graph.h文件中的Graph类的定义,只看关键代码

 class Graph {
     private:
      // 所有已知的op计算函数的注册表
      FunctionLibraryDefinition ops_;

      // GraphDef版本号
      const std::unique_ptr<VersionDef> versions_;

      // 节点node列表,通过id来访问
      std::vector<Node*> nodes_;

      // node个数
      int64 num_nodes_ = 0;

      // 边edge列表,通过id来访问
      std::vector<Edge*> edges_;

      // graph中非空edge的数目
      int num_edges_ = 0;

      // 已分配了内存,但还没使用的node和edge
      std::vector<Node*> free_nodes_;
      std::vector<Edge*> free_edges_;
 }

后端中的Graph主要成员也是节点node和边edge。节点node为计算算子Operation,边为算子所需要的数据,或者代表节点间的依赖关系。这一点和Python中的定义相似。边Edge的持有它的源节点和目标节点的指针,从而将两个节点连接起来。下面看Edge类的定义。

Edge

class Edge {
     private:
      Edge() {}

      friend class EdgeSetTest;
      friend class Graph;
      // 源节点, 边的数据就来源于源节点的计算。源节点是边的生产者
      Node* src_;

      // 目标节点,边的数据提供给目标节点进行计算。目标节点是边的消费者
      Node* dst_;

      // 边id,也就是边的标识符
      int id_;

      // 表示当前边为源节点的第src_output_条边。源节点可能会有多条输出边
      int src_output_;

      // 表示当前边为目标节点的第dst_input_条边。目标节点可能会有多条输入边。
      int dst_input_;
};

Edge既可以承载tensor数据,提供给节点Operation进行运算,也可以用来表示节点之间有依赖关系。对于表示节点依赖的边,其src_output_, dst_input_均为-1,此时边不承载任何数据。

下面来看Node类的定义。

Node

class Node {
 public:
    // NodeDef,节点算子Operation的信息,比如op分配到哪个设备上了,op的名字等,运行时有可能变化。
      const NodeDef& def() const;
    
    // OpDef, 节点算子Operation的元数据,不会变的。比如Operation的入参列表,出参列表等
      const OpDef& op_def() const;
 private:
      // 输入边,传递数据给节点。可能有多条
      EdgeSet in_edges_;

      // 输出边,节点计算后得到的数据。可能有多条
      EdgeSet out_edges_;
}

节点Node中包含的主要数据有输入边和输出边的集合,从而能够由Node找到跟他关联的所有边。Node中还包含NodeDef和OpDef两个成员。NodeDef表示节点算子的信息,运行时可能会变,创建Node时会new一个NodeDef对象。OpDef表示节点算子的元信息,运行时不会变,创建Node时不需要new OpDef,只需要从OpDef仓库中取出即可。因为元信息是确定的,比如Operation的入参个数等。

由Node和Edge,即可以组成图Graph,通过任何节点和任何边,都可以遍历完整图。Graph执行计算时,按照拓扑结构,依次执行每个Node的op计算,最终即可得到输出结果。入度为0的节点,也就是依赖数据已经准备好的节点,可以并发执行,从而提高运行效率。

系统中存在默认的Graph,初始化Graph时,会添加一个Source节点和Sink节点。Source表示Graph的起始节点,Sink为终止节点。Source的id为0,Sink的id为1,其他节点id均大于1.

5 Graph运行时生命周期

Graph是TensorFlow的核心对象,TensorFlow的运行均是围绕Graph进行的。运行时Graph大致经过了以下阶段

  1. 图构建:client端用户将创建的节点注册到Graph中,一般不需要显示创建Graph,使用系统创建的默认的即可。
  2. 图发送:client通过session.run()执行运行时,将构建好的整图序列化为GraphDef后,传递给master
  3. 图剪枝:master先反序列化拿到Graph,然后根据session.run()传递的fetches和feeds列表,反向遍历全图full graph,实施剪枝,得到最小依赖子图。
  4. 图分裂:master将最小子图分裂为多个Graph Partition,并注册到多个worker上。一个worker对应一个Graph Partition。
  5. 图二次分裂:worker根据当前可用硬件资源,如CPU GPU,将Graph Partition按照op算子设备约束规范(例如tf.device(’/cpu:0’),二次分裂到不同设备上。每个计算设备对应一个Graph Partition。
  6. 图运行:对于每一个计算设备,worker依照op在kernel中的实现,完成op的运算。设备间数据通信可以使用send/recv节点,而worker间通信,则使用GRPC或RDMA协议。

这些阶段根据TensorFlow运行时的不同,会进行不同的处理。运行时有两种,本地运行时和分布式运行时。故Graph生命周期到后面分析本地运行时和分布式运行时的时候,再详细讲解。

目录
相关文章
|
3月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
124 2
|
2月前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
创建型模式的主要关注点是“怎样创建对象?”,它的主要特点是"将对象的创建与使用分离”。这样可以降低系统的耦合度,使用者不需要关注对象的创建细节。创建型模式分为5种:单例模式、工厂方法模式抽象工厂式、原型模式、建造者模式。
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
2月前
|
存储 设计模式 算法
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为,后者采用组合或聚合在对象间分配行为。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象行为模式比类行为模式具有更大的灵活性。 行为型模式分为: • 模板方法模式 • 策略模式 • 命令模式 • 职责链模式 • 状态模式 • 观察者模式 • 中介者模式 • 迭代器模式 • 访问者模式 • 备忘录模式 • 解释器模式
【23种设计模式·全精解析 | 行为型模式篇】11种行为型模式的结构概述、案例实现、优缺点、扩展对比、使用场景、源码解析
|
2月前
|
设计模式 存储 安全
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
结构型模式描述如何将类或对象按某种布局组成更大的结构。它分为类结构型模式和对象结构型模式,前者采用继承机制来组织接口和类,后者釆用组合或聚合来组合对象。由于组合关系或聚合关系比继承关系耦合度低,满足“合成复用原则”,所以对象结构型模式比类结构型模式具有更大的灵活性。 结构型模式分为以下 7 种: • 代理模式 • 适配器模式 • 装饰者模式 • 桥接模式 • 外观模式 • 组合模式 • 享元模式
【23种设计模式·全精解析 | 创建型模式篇】5种创建型模式的结构概述、实现、优缺点、扩展、使用场景、源码解析
|
2月前
|
机器学习/深度学习 人工智能 算法
深入解析图神经网络:Graph Transformer的算法基础与工程实践
Graph Transformer是一种结合了Transformer自注意力机制与图神经网络(GNNs)特点的神经网络模型,专为处理图结构数据而设计。它通过改进的数据表示方法、自注意力机制、拉普拉斯位置编码、消息传递与聚合机制等核心技术,实现了对图中节点间关系信息的高效处理及长程依赖关系的捕捉,显著提升了图相关任务的性能。本文详细解析了Graph Transformer的技术原理、实现细节及应用场景,并通过图书推荐系统的实例,展示了其在实际问题解决中的强大能力。
257 30
|
21天前
|
自然语言处理 数据处理 索引
mindspeed-llm源码解析(一)preprocess_data
mindspeed-llm是昇腾模型套件代码仓,原来叫"modelLink"。这篇文章带大家阅读一下数据处理脚本preprocess_data.py(基于1.0.0分支),数据处理是模型训练的第一步,经常会用到。
40 0
|
3月前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
72 12
|
2月前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
2月前
|
安全 搜索推荐 数据挖掘
陪玩系统源码开发流程解析,成品陪玩系统源码的优点
我们自主开发的多客陪玩系统源码,整合了市面上主流陪玩APP功能,支持二次开发。该系统适用于线上游戏陪玩、语音视频聊天、心理咨询等场景,提供用户注册管理、陪玩者资料库、预约匹配、实时通讯、支付结算、安全隐私保护、客户服务及数据分析等功能,打造综合性社交平台。随着互联网技术发展,陪玩系统正成为游戏爱好者的新宠,改变游戏体验并带来新的商业模式。
|
3月前
|
存储 安全 Linux
Golang的GMP调度模型与源码解析
【11月更文挑战第11天】GMP 调度模型是 Go 语言运行时系统的核心部分,用于高效管理和调度大量协程(goroutine)。它通过少量的操作系统线程(M)和逻辑处理器(P)来调度大量的轻量级协程(G),从而实现高性能的并发处理。GMP 模型通过本地队列和全局队列来减少锁竞争,提高调度效率。在 Go 源码中,`runtime.h` 文件定义了关键数据结构,`schedule()` 和 `findrunnable()` 函数实现了核心调度逻辑。通过深入研究 GMP 模型,可以更好地理解 Go 语言的并发机制。

推荐镜像

更多