开发者社区> 问答> 正文

在C++中用jit调用自定义函数

我已经创建了以下类并保存了scripted_module装进C++空气污染指数:

torch::jit::script::Module module;
 module = torch::jit::load("scriptmodule.pt");

现在,问题是我怎么打电话func来自C++?

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.ignore
    def get_rand(self):
        return torch.randint(0, 2, size=[1])

    def forward(self):
        pass 

    @torch.jit.export
    def func(self):
        done = self.get_rand()
        print (done)

scripted_module = torch.jit.script(MyModule())

展开
收起
aqal5zs3gkqgc 2019-12-04 22:40:40 1563 0
1 条回答
写回答
取消 提交回答
  • 将模型存下来

    scripted_module.save("scripted_module.pt")
    

    然后下载或者自己编译libpytroch,参考如下代码加载模型

    // One-stop header.
    #include <torch/script.h> 
    
    #include <iostream>
    #include <memory>
    
    int main(int argc, const char* argv[]) {
      if (argc != 2) {
        std::cerr << "usage: example-app <path-to-exported-script-module>\n";
        return -1;
      }
    
    
      torch::jit::script::Module module;
      try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load(argv[1]);
      }
      catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
      }
    
      std::cout << "ok\n";
    }
    

    具体可参考:https://pytorch.org/tutorials/advanced/cpp_export.html

    2020-03-05 19:11:58
    赞同 展开评论 打赏
问答标签:
问答地址:
问答排行榜
最热
最新

相关电子书

更多
使用C++11开发PHP7扩展 立即下载
GPON Class C++ SFP O;T Transce 立即下载
GPON Class C++ SFP OLT Transce 立即下载

相关实验场景

更多