ONNX Runtime介绍_onnxruntime-程序员宅基地

技术标签: PyTorch  ONNX Runtime  Deep Learning  

      ONNX Runtime:由微软推出,用于优化和加速机器学习推理和训练,适用于ONNX模型,是一个跨平台推理和训练机器学习加速器(ONNX Runtime is a cross-platform inference and training machine-learning accelerator),源码地址:https://github.com/microsoft/onnxruntime,最新发布版本为v1.11.1,License为MIT:

      1.ONNX Runtime Inferencing:高性能推理引擎

      (1).可在不同的操作系统上运行,包括Windows、Linux、Mac、Android、iOS等;

      (2).可利用硬件增加性能,包括CUDA、TensorRT、DirectML、OpenVINO等;

      (3).支持PyTorch、TensorFlow等深度学习框架的模型,需先调用相应接口转换为ONNX模型;

      (4).在Python中训练,确可部署到C++/Java等应用程序中。

      2.ONNX Runtime Training:于2021年4月发布,可加快PyTorch对模型训练,可通过CUDA加速,目前多用于Linux平台。

      通过conda命令安装执行:

conda install -c conda-forge onnxruntime

      以下为测试代码:通过ResNet-50对图像进行分类

import numpy as np
import onnxruntime
import onnx
from onnx import numpy_helper
import urllib.request
import os
import tarfile
import json
import cv2

# reference: https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/inference_demos/resnet50_modelzoo_onnxruntime_inference.ipynb
def download_onnx_model():
    labels_file_name = "imagenet-simple-labels.json"
    model_tar_name = "resnet50v2.tar.gz"
    model_directory_name = "resnet50v2"

    if os.path.exists(model_tar_name) and os.path.exists(labels_file_name):
        print("files exist, don't need to download")
    else:
        print("files don't exist, need to download ...")

        onnx_model_url = "https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v2/resnet50v2.tar.gz"
        imagenet_labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

        # retrieve our model from the ONNX model zoo
        urllib.request.urlretrieve(onnx_model_url, filename=model_tar_name)
        urllib.request.urlretrieve(imagenet_labels_url, filename=labels_file_name)

        print("download completed, start decompress ...")
        file = tarfile.open(model_tar_name)
        file.extractall("./")
        file.close()

    return model_directory_name, labels_file_name

def load_labels(path):
    with open(path) as f:
        data = json.load(f)
    return np.asarray(data)

def images_preprocess(images_path, images_name):
    input_data = []

    for name in images_name:
        img = cv2.imread(images_path + name)
        img = cv2.resize(img, (224, 224))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        data = np.array(img).transpose(2, 0, 1)
        #print(f"name: {name}, opencv image shape(h,w,c): {img.shape}, transpose shape(c,h,w): {data.shape}")
        # convert the input data into the float32 input
        data = data.astype('float32')

        # normalize
        mean_vec = np.array([0.485, 0.456, 0.406])
        stddev_vec = np.array([0.229, 0.224, 0.225])
        norm_data = np.zeros(data.shape).astype('float32')
        for i in range(data.shape[0]):
            norm_data[i,:,:] = (data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]

        # add batch channel
        norm_data = norm_data.reshape(1, 3, 224, 224).astype('float32')
        input_data.append(norm_data)

    return input_data

def softmax(x):
    x = x.reshape(-1)
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def postprocess(result):
    return softmax(np.array(result)).tolist()

def inference(onnx_model, labels, input_data, images_name, images_label):
    session = onnxruntime.InferenceSession(onnx_model, None)
    # get the name of the first input of the model
    input_name = session.get_inputs()[0].name
    count = 0
    for data in input_data:
        print(f"{count+1}. image name: {images_name[count]}, actual value: {images_label[count]}")
        count += 1

        raw_result = session.run([], {input_name: data})

        res = postprocess(raw_result)

        idx = np.argmax(res)
        print(f"  result: idx: {idx}, label: {labels[idx]}, percentage: {round(res[idx]*100, 4)}%")

        sort_idx = np.flip(np.squeeze(np.argsort(res)))
        print("  top 5 labels are:", labels[sort_idx[:5]])

def main():
    model_directory_name, labels_file_name = download_onnx_model()

    labels = load_labels(labels_file_name)
    print("the number of categories is:", len(labels)) # 1000

    images_path = "../../data/image/"
    images_name = ["5.jpg", "6.jpg", "7.jpg", "8.jpg", "9.jpg", "10.jpg"]
    images_label = ["goldfish", "hen", "ostrich", "crocodile", "goose", "sheep"]
    if len(images_name) != len(images_label):
        print("Error: images count and labes'length don't match")
        return

    input_data = images_preprocess(images_path, images_name)

    onnx_model = model_directory_name + "/resnet50v2.onnx"
    inference(onnx_model, labels, input_data, images_name, images_label)

    print("test finish")

if __name__ == "__main__":
    main()

      测试图像如下所示:

      执行结果如下所示:

 

      GitHub: https://github.com/fengbingchun/PyTorch_Test

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/fengbingchun/article/details/125951896

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签