TensorFlow Lite源码解析--模型加载和执行

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: TensorFlow Lite是专门针对移动和嵌入式设备的特性重新实现的TensorFlow版本。相比普通的TensorFlow,它的功能更加精简,不支持模型的训练,不支持分布式运行,也没有太多跨平台逻辑,支持的op也比较有限。但正因其精简性,因此比较适合用来探究一个机器学习框架的实现原理。不过准确讲,从TensorFlow Lite只能看到预测(inference)部分,无法看到训练(t

TensorFlow Lite是专门针对移动和嵌入式设备的特性重新实现的TensorFlow版本。相比普通的TensorFlow,它的功能更加精简,不支持模型的训练,不支持分布式运行,也没有太多跨平台逻辑,支持的op也比较有限。但正因其精简性,因此比较适合用来探究一个机器学习框架的实现原理。不过准确讲,从TensorFlow Lite只能看到预测(inference)部分,无法看到训练(training)部分。

TensorFlow Lite的架构如下(图片来自官网

image.png

从架构图中,可以看到TensorFlow Lite主要功能可以分成几块:

  • TensorFlow模型文件的格式转换。将普通的TensorFlow的pb文件转换成TensorFlow Lite需要的FlatBuffers文件。
  • 模型文件的加载和解析。将文本的FlatBuffers文件解析成对应的op,用以执行。
  • op的执行。具体op的实现本文暂不涉及。

TensorFlow Lite的源码结构

TensorFlow Lite源码位于tensorflow/contrib/lite,目录结构如下:

- lite
    - [d] c             一些基础数据结构定义,例如TfLiteContext, TfLiteStatus
    - [d] core          基础API定义
    - [d] delegates     对eager和nnapi的代理封装
    - [d] examples      使用例子,包括android和iOS的demo
    - [d] experimental  
    - [d] g3doc
    - [d] java          Java API
    - [d] kernels       内部实现的op,包括op的注册逻辑
    - [d] lib_package
    - [d] models        内置的模型,例如智能回复功能
    - [d] nnapi         对android平台硬件加速接口的调用
    - [d] profiling     内部分析库,例如模型执行次数之类。可关掉
    - [d] python        Python API
    - [d] schema        对TensorFlow Lite的FlatBuffers文件格式的解析
    - [d] testdata
    - [d] testing
    - [d] toco          TensorFlow Lite Converter,将普通的TensorFlow的pb格式转换成TensorFlow Lite的FlatBuffers格式
    - [d] tools
    - [d] tutorials
    - BUILD
    - README.md
    - allocation.cc
    - allocation.h
    - arena_planner.cc
    - arena_planner.h
    - arena_planner_test.cc
    - build_def.bzl
    - builtin_op_data.h
    - builtin_ops.h
    - context.h
    - context_util.h
    - error_reporter.h
    - graph_info.cc
    - graph_info.h
    - graph_info_test.cc
    - interpreter.cc
    - interpreter.h     TensorFlow Lite解释器,C++的入口API
    - interpreter_test.cc
    - memory_planner.h
    - mmap_allocation.cc
    - mmap_allocation_disabled.cc
    - model.cc
    - model.h           TensorFlow Lite的加载和解析
    - model_test.cc
    - mutable_op_resolver.cc
    - mutable_op_resolver.h
    - mutable_op_resolver_test.cc
    - nnapi_delegate.cc
    - nnapi_delegate.h
    - nnapi_delegate_disabled.cc
    - op_resolver.h     op的查找
    - optional_debug_tools.cc
    - optional_debug_tools.h
    - simple_memory_arena.cc
    - simple_memory_arena.h
    - simple_memory_arena_test.cc
    - special_rules.bzl
    - stderr_reporter.cc
    - stderr_reporter.h
    - string.h
    - string_util.cc
    - string_util.h
    - string_util_test.cc
    - util.cc
    - util.h
    - util_test.cc
    - version.h

我们主要看op_resolver.h,model.h,interpreter.h,schema/,java/,kernels/register.h等文件。

TensorFlow Lite执行过程分析

因为比较熟悉,我们从Java的API入手分析。Java API的核心类是Interpreter.java,其具体实现是在NativeInterpreterWrappter.java,而最终是调用到Native的nativeinterpreterwraptter_jni.h,自此就进入C++实现的逻辑。
image.png

模型文件的加载和解析

image.png

加载模型文件
//nativeinterpreter_jni.cc的createModel()函数是加载和解析文件的入口
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel(
    JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) {
  BufferErrorReporter* error_reporter =
      convertLongToErrorReporter(env, error_handle);
  if (error_reporter == nullptr) return 0;
  const char* path = env->GetStringUTFChars(model_file, nullptr);

  std::unique_ptr<tflite::TfLiteVerifier> verifier;
  verifier.reset(new JNIFlatBufferVerifier());

  //读取并解析文件关键代码
  auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(
      path, verifier.get(), error_reporter);
  if (!model) {
    throwException(env, kIllegalArgumentException,
                   "Contents of %s does not encode a valid "
                   "TensorFlowLite model: %s",
                   path, error_reporter->CachedErrorMessage());
    env->ReleaseStringUTFChars(model_file, path);
    return 0;
  }
  env->ReleaseStringUTFChars(model_file, path);
  return reinterpret_cast<jlong>(model.release());
}

上述逻辑中,最关键之处在于

auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(path, verifier.get(), error_reporter);

这行代码的作用是读取并解析文件看起具体实现。

//代码再model.cc文件中
std::unique_ptr<FlatBufferModel> FlatBufferModel::VerifyAndBuildFromFile(
    const char* filename, TfLiteVerifier* verifier,
    ErrorReporter* error_reporter) {
  error_reporter = ValidateErrorReporter(error_reporter);

  std::unique_ptr<FlatBufferModel> model;
  //读取文件
  auto allocation = GetAllocationFromFile(filename, /*mmap_file=*/true,
                                          error_reporter, /*use_nnapi=*/true);
  if (verifier &&
      !verifier->Verify(static_cast<const char*>(allocation->base()),
                        allocation->bytes(), error_reporter)) {
    return model;
  }
  //用FlatBuffers库解析文件
  model.reset(new FlatBufferModel(allocation.release(), error_reporter));
  if (!model->initialized()) model.reset();
  return model;
}

这段逻辑中GetAllocationFromFile()函数获取了文件内容的地址,FlatBufferModel()构造函数中则利用FlatBuffers库读取文件内容。

解析模型文件

上面的流程将模型文件读到FlatBuffers的Model数据结构中,具体数据结构定义可以见schema.fbs。接下去,需要文件中的数据映射成对应可以执行的op数据结构。这个工作主要由InterpreterBuilder完成。

//nativeinterpreter_jni.cc的createInterpreter()函数是将模型文件映射成可以执行的op的入口函数。
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter(
    JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle,
    jint num_threads) {
  tflite::FlatBufferModel* model = convertLongToModel(env, model_handle);
  if (model == nullptr) return 0;
  BufferErrorReporter* error_reporter =
      convertLongToErrorReporter(env, error_handle);
  if (error_reporter == nullptr) return 0;
  //先注册op,将op的实现和FlatBuffers中的index关联起来。
  auto resolver = ::tflite::CreateOpResolver();
  std::unique_ptr<tflite::Interpreter> interpreter;
  //解析FlatBuffers,将配置文件中的内容映射成可以执行的op实例。
  TfLiteStatus status = tflite::InterpreterBuilder(*model, *(resolver.get()))(
      &interpreter, static_cast<int>(num_threads));
  if (status != kTfLiteOk) {
    throwException(env, kIllegalArgumentException,
                   "Internal error: Cannot create interpreter: %s",
                   error_reporter->CachedErrorMessage());
    return 0;
  }
  // allocates memory
  status = interpreter->AllocateTensors();
  if (status != kTfLiteOk) {
    throwException(
        env, kIllegalStateException,
        "Internal error: Unexpected failure when preparing tensor allocations:"
        " %s",
        error_reporter->CachedErrorMessage());
    return 0;
  }
  return reinterpret_cast<jlong>(interpreter.release());
}

这里首先调用::tflite::CreateOpResolver()完成op的注册,将op实例和FlatBuffers中的索引对应起来,具体索引见schema_generated.h里面的BuiltinOperator枚举。

//builtin_ops_jni.cc
std::unique_ptr<OpResolver> CreateOpResolver() {  // NOLINT
  return std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver>(
      new tflite::ops::builtin::BuiltinOpResolver());
}


//register.cc
BuiltinOpResolver::BuiltinOpResolver() {
  AddBuiltin(BuiltinOperator_RELU, Register_RELU());
  AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
  AddBuiltin(BuiltinOperator_RELU6, Register_RELU6());
  AddBuiltin(BuiltinOperator_TANH, Register_TANH());
  AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC());
  AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
  AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
  AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
  AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
  AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()
  ...
}

InterpreterBuilder则负责根据FlatBuffers内容,构建Interpreter对象。


TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
  
  .....

  // Flatbuffer model schemas define a list of opcodes independent of the graph.
  // We first map those to registrations. This reduces string lookups for custom
  // ops since we only do it once per custom op rather than once per custom op
  // invocation in the model graph.
  // Construct interpreter with correct number of tensors and operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();
  if (subgraphs->size() != 1) {
    error_reporter_->Report("Only 1 subgraph is currently supported.\n");
    return cleanup_and_error();
  }
  const tflite::SubGraph* subgraph = (*subgraphs)[0];
  auto operators = subgraph->operators();
  auto tensors = subgraph->tensors();
  if (!operators || !tensors || !buffers) {
    error_reporter_->Report(
        "Did not get operators, tensors, or buffers in input flat buffer.\n");
    return cleanup_and_error();
  }
  interpreter->reset(new Interpreter(error_reporter_));
  if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
    return cleanup_and_error();
  }
  // Set num threads
  (**interpreter).SetNumThreads(num_threads);
  // Parse inputs/outputs
  (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
  (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));

  // Finally setup nodes and tensors
  if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();
  if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
    return cleanup_and_error();

  std::vector<int> variables;
  for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
    auto* tensor = (*interpreter)->tensor(i);
    if (tensor->is_variable) {
      variables.push_back(i);
    }
  }
  (**interpreter).SetVariables(std::move(variables));

  return kTfLiteOk;
}

模型的执行

image.png
经过InterpreterBuilder的工作,模型文件的内容已经解析成可执行的op,存储在interpreter.cc的nodes_and_registration_列表中。剩下的工作就是循环遍历调用每个op的invoke接口。

//interpreter.cc
TfLiteStatus Interpreter::Invoke() {
.....

  // Invocations are always done in node order.
  // Note that calling Invoke repeatedly will cause the original memory plan to
  // be reused, unless either ResizeInputTensor() or AllocateTensors() has been
  // called.
  // TODO(b/71913981): we should force recalculation in the presence of dynamic
  // tensors, because they may have new value which in turn may affect shapes
  // and allocations.
  for (int execution_plan_index = 0;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    if (execution_plan_index == next_execution_plan_index_to_prepare_) {
      TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
      TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >=
                                    execution_plan_index);
    }
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;
    SCOPED_OPERATOR_PROFILE(profiler_, node_index);

    // TODO(ycling): This is an extra loop through inputs to check if the data
    // need to be copied from Delegate buffer to raw memory, which is often not
    // needed. We may want to cache this in prepare to know if this needs to be
    // done for a node or not.
    for (int i = 0; i < node.inputs->size; ++i) {
      int tensor_index = node.inputs->data[i];
      if (tensor_index == kOptionalTensor) {
        continue;
      }
      TfLiteTensor* tensor = &tensors_[tensor_index];
      if (tensor->delegate && tensor->delegate != node.delegate &&
          tensor->data_is_stale) {
        EnsureTensorDataIsReadable(tensor_index);
      }
    }

    EnsureTensorsVectorCapacity();
    tensor_resized_since_op_invoke_ = false;
    //逐个op调用
    if (OpInvoke(registration, &node) == kTfLiteError) {
      status = ReportOpError(&context_, node, registration, node_index,
                             "failed to invoke");
    }

    // Force execution prep for downstream ops if the latest op triggered the
    // resize of a dynamic tensor.
    if (tensor_resized_since_op_invoke_ &&
        HasDynamicTensor(context_, node.outputs)) {
      next_execution_plan_index_to_prepare_ = execution_plan_index + 1;
    }
  }

  if (!allow_buffer_handle_output_) {
    for (int tensor_index : outputs_) {
      EnsureTensorDataIsReadable(tensor_index);
    }
  }

  return status;
}

参考文章

目录
相关文章
|
2月前
|
机器学习/深度学习 人工智能 算法
模型无关的局部解释(LIME)技术原理解析及多领域应用实践
在当前数据驱动的商业环境中,人工智能(AI)和机器学习(ML)已成为各行业决策的关键工具,但随之而来的是“黑盒”问题:模型内部机制难以理解,引发信任缺失、监管合规难题及伦理考量。LIME(局部可解释模型无关解释)应运而生,通过解析复杂模型的个别预测,提供清晰、可解释的结果。LIME由华盛顿大学的研究者于2016年提出,旨在解决AI模型的透明度问题。它具有模型无关性、直观解释和局部保真度等优点,在金融、医疗等领域广泛应用。LIME不仅帮助企业提升决策透明度,还促进了模型优化和监管合规,是实现可解释AI的重要工具。
92 9
|
2月前
|
开发框架 供应链 监控
并行开发模型详解:类型、步骤及其应用解析
在现代研发环境中,企业需要在有限时间内推出高质量的产品,以满足客户不断变化的需求。传统的线性开发模式往往拖慢进度,导致资源浪费和延迟交付。并行开发模型通过允许多个开发阶段同时进行,极大提高了产品开发的效率和响应能力。本文将深入解析并行开发模型,涵盖其类型、步骤及如何通过辅助工具优化团队协作和管理工作流。
62 3
|
14天前
|
存储 网络协议 安全
30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场
本文精选了 30 道初级网络工程师面试题,涵盖 OSI 模型、TCP/IP 协议栈、IP 地址、子网掩码、VLAN、STP、DHCP、DNS、防火墙、NAT、VPN 等基础知识和技术,帮助小白们充分准备面试,顺利踏入职场。
47 2
|
14天前
|
存储 安全 Linux
Golang的GMP调度模型与源码解析
【11月更文挑战第11天】GMP 调度模型是 Go 语言运行时系统的核心部分,用于高效管理和调度大量协程(goroutine)。它通过少量的操作系统线程(M)和逻辑处理器(P)来调度大量的轻量级协程(G),从而实现高性能的并发处理。GMP 模型通过本地队列和全局队列来减少锁竞争,提高调度效率。在 Go 源码中,`runtime.h` 文件定义了关键数据结构,`schedule()` 和 `findrunnable()` 函数实现了核心调度逻辑。通过深入研究 GMP 模型,可以更好地理解 Go 语言的并发机制。
|
2月前
|
机器学习/深度学习 搜索推荐 大数据
深度解析:如何通过精妙的特征工程与创新模型结构大幅提升推荐系统中的召回率,带你一步步攻克大数据检索难题
【10月更文挑战第2天】在处理大规模数据集的推荐系统项目时,提高检索模型的召回率成为关键挑战。本文分享了通过改进特征工程(如加入用户活跃时段和物品相似度)和优化模型结构(引入注意力机制)来提升召回率的具体策略与实现代码。严格的A/B测试验证了新模型的有效性,为改善用户体验奠定了基础。这次实践加深了对特征工程与模型优化的理解,并为未来的技术探索提供了方向。
100 2
深度解析:如何通过精妙的特征工程与创新模型结构大幅提升推荐系统中的召回率,带你一步步攻克大数据检索难题
|
2月前
|
安全 Java
Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧
【10月更文挑战第20天】Java多线程通信新解:本文通过生产者-消费者模型案例,深入解析wait()、notify()、notifyAll()方法的实用技巧,包括避免在循环外调用wait()、优先使用notifyAll()、确保线程安全及处理InterruptedException等,帮助读者更好地掌握这些方法的应用。
19 1
|
2月前
|
机器学习/深度学习 算法 Python
深度解析机器学习中过拟合与欠拟合现象:理解模型偏差背后的原因及其解决方案,附带Python示例代码助你轻松掌握平衡技巧
【10月更文挑战第10天】机器学习模型旨在从数据中学习规律并预测新数据。训练过程中常遇过拟合和欠拟合问题。过拟合指模型在训练集上表现优异但泛化能力差,欠拟合则指模型未能充分学习数据规律,两者均影响模型效果。解决方法包括正则化、增加训练数据和特征选择等。示例代码展示了如何使用Python和Scikit-learn进行线性回归建模,并观察不同情况下的表现。
331 3
|
3月前
|
机器学习/深度学习 存储 人工智能
让模型评估模型:构建双代理RAG评估系统的步骤解析
在当前大语言模型(LLM)应用开发中,评估模型输出的准确性成为关键问题。本文介绍了一个基于双代理的RAG(检索增强生成)评估系统,使用生成代理和反馈代理对输出进行评估。文中详细描述了系统的构建过程,并展示了基于四种提示工程技术(ReAct、思维链、自一致性和角色提示)的不同结果。实验结果显示,ReAct和思维链技术表现相似,自一致性技术则呈现相反结果,角色提示技术最为不稳定。研究强调了多角度评估的重要性,并提供了系统实现的详细代码。
61 10
让模型评估模型:构建双代理RAG评估系统的步骤解析
|
21天前
|
安全 测试技术 Go
Go语言中的并发编程模型解析####
在当今的软件开发领域,高效的并发处理能力是提升系统性能的关键。本文深入探讨了Go语言独特的并发编程模型——goroutines和channels,通过实例解析其工作原理、优势及最佳实践,旨在为开发者提供实用的Go语言并发编程指南。 ####
|
3月前
|
机器学习/深度学习 数据采集 存储
一文读懂蒙特卡洛算法:从概率模拟到机器学习模型优化的全方位解析
蒙特卡洛方法起源于1945年科学家斯坦尼斯劳·乌拉姆对纸牌游戏中概率问题的思考,与约翰·冯·诺依曼共同奠定了该方法的理论基础。该方法通过模拟大量随机场景来近似复杂问题的解,因命名灵感源自蒙特卡洛赌场。如今,蒙特卡洛方法广泛应用于机器学习领域,尤其在超参数调优、贝叶斯滤波等方面表现出色。通过随机采样超参数空间,蒙特卡洛方法能够高效地找到优质组合,适用于处理高维度、非线性问题。本文通过实例展示了蒙特卡洛方法在估算圆周率π和优化机器学习模型中的应用,并对比了其与网格搜索方法的性能。
315 1

推荐镜像

更多