我已经创建了以下类并保存了scripted_module装进C++空气污染指数:
torch::jit::script::Module module;
module = torch::jit::load("scriptmodule.pt");
现在,问题是我怎么打电话func来自C++?
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
@torch.jit.ignore
def get_rand(self):
return torch.randint(0, 2, size=[1])
def forward(self):
pass
@torch.jit.export
def func(self):
done = self.get_rand()
print (done)
scripted_module = torch.jit.script(MyModule())
将模型存下来
scripted_module.save("scripted_module.pt")
然后下载或者自己编译libpytroch,参考如下代码加载模型
// One-stop header.
#include <torch/script.h>
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::cout << "ok\n";
}
具体可参考:https://pytorch.org/tutorials/advanced/cpp_export.html
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。