Commit b93bde0d by Ting PAN

Refactor ONNX frontends and backends

Summary:
This commit redesigns the ``vm.onnx`` by referring the official repository.
Frontends and backends are aligned with identical API for dragon, torch and tensorrt.
1 parent e82d2ba4
Showing with 1023 additions and 861 deletions
vm.onnx dragon.onnx
======= ===========
.. toctree:: .. only:: html
:hidden:
Classes
-------
`class BackendRep <onnx/BackendRep.html>`_
: ONNX-Dragon backend to execute repeatedly.
Functions
---------
`export(...) <onnx/export.html>`_
: Export the recorded graph to an onnx model.
onnx/Shell `prepare_backend(...) <onnx/prepare_backend.html>`_
: Create a backend to execute repeatedly.
`record(...) <onnx/record.html>`_
: Context-manger to record the graph.
`run_model(...) <onnx/run_model.html>`_
: Execute an onnx model once.
`supports_device(...) <onnx/supports_device.html>`_
: Query if the given device is supported to execute.
Operators Operators
######### ---------
======================== ========= ======================================== ======================== ========= ========================================
Name Supported Reference Name Supported Reference
...@@ -168,13 +189,15 @@ Name Supported Reference ...@@ -168,13 +189,15 @@ Name Supported Reference
`Xor`_ |v| :func:`dragon.bitwise.bitwise_xor` `Xor`_ |v| :func:`dragon.bitwise.bitwise_xor`
======================== ========= ======================================== ======================== ========= ========================================
.. only:: html .. toctree::
:hidden:
Classes
#######
`class Shell <onnx/Shell.html>`_ onnx/BackendRep
: Context-manger to export or load onnx models. onnx/prepare_backend
onnx/export
onnx/record
onnx/run_model
onnx/supports_device
.. _Abs: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Abs .. _Abs: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Abs
.. _Acos: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Acos .. _Acos: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Acos
...@@ -331,14 +354,14 @@ Name Supported Reference ...@@ -331,14 +354,14 @@ Name Supported Reference
.. _Where: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Where .. _Where: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Where
.. _Xor: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Xor .. _Xor: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Xor
.. |v| image:: ../_static/images/tick.png .. |v| image:: ../../_static/images/tick.png
:height: 18 :height: 18
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "Module: dragon."; content: "Module: ";
color: #103d3e; color: #103d3e;
} }
</style> </style>
Shell BackendRep
===== ==========
.. autoclass:: dragon.vm.onnx.Shell .. autoclass:: dragon.onnx.BackendRep
__init__ __init__
-------- --------
.. automethod:: dragon.vm.onnx.Shell.__init__ .. automethod:: dragon.onnx.BackendRep.__init__
Methods Methods
------- -------
as_default run
########## ###
.. automethod:: dragon.vm.onnx.Shell.as_default .. automethod:: dragon.onnx.BackendRep.run
export
######
.. automethod:: dragon.vm.onnx.Shell.export
load_model
##########
.. automethod:: dragon.vm.onnx.Shell.load_model
.. raw:: html .. raw:: html
......
prepare export
======= ======
.. autofunction:: dragon.vm.tensorrt.backend.prepare .. autofunction:: dragon.onnx.export
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "tensorrt.backend."; content: "dragon.onnx.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
prepare_backend
===============
.. autofunction:: dragon.onnx.prepare_backend
.. raw:: html
<style>
h1:before {
content: "dragon.onnx.";
color: #103d3e;
}
</style>
record
======
.. autofunction:: dragon.onnx.record
.. _dragon.onnx.export(...): export.html
.. raw:: html
<style>
h1:before {
content: "dragon.onnx.";
color: #103d3e;
}
</style>
run_model run_model
========= =========
.. autofunction:: dragon.vm.tensorrt.backend.run_model .. autofunction:: dragon.onnx.run_model
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "tensorrt.backend."; content: "dragon.onnx.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
supports_device supports_device
=============== ===============
.. autofunction:: dragon.vm.tensorrt.backend.supports_device .. autofunction:: dragon.onnx.supports_device
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "tensorrt.backend."; content: "dragon.onnx.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
...@@ -38,6 +38,7 @@ Dragon ...@@ -38,6 +38,7 @@ Dragon
* `dragon.math <dragon/math.html>`_ * `dragon.math <dragon/math.html>`_
* `dragon.metrics <dragon/metrics.html>`_ * `dragon.metrics <dragon/metrics.html>`_
* `dragon.nn <dragon/nn.html>`_ * `dragon.nn <dragon/nn.html>`_
* `dragon.onnx <dragon/onnx.html>`_
* `dragon.optimizers <dragon/optimizers.html>`_ * `dragon.optimizers <dragon/optimizers.html>`_
* `dragon.random <dragon/random.html>`_ * `dragon.random <dragon/random.html>`_
* `dragon.vision <dragon/vision.html>`_ * `dragon.vision <dragon/vision.html>`_
...@@ -146,7 +147,9 @@ ONNX ...@@ -146,7 +147,9 @@ ONNX
This integration involves the following components: This integration involves the following components:
* `onnx <onnx.html>`_ * `dragon.onnx <dragon/onnx.html>`_
* `tensorrt.onnx <tensorrt/onnx.html>`_
* `torch.onnx <torch/onnx.html>`_
TensorRT TensorRT
######## ########
...@@ -158,7 +161,7 @@ TensorRT ...@@ -158,7 +161,7 @@ TensorRT
This integration involves the following components: This integration involves the following components:
* `tensorrt <tensorrt.html>`_ * `tensorrt <tensorrt.html>`_
* `tensorrt.backend <tensorrt/backend.html>`_ * `tensorrt.onnx <tensorrt/onnx.html>`_
Modules Modules
------- -------
...@@ -198,6 +201,9 @@ Modules ...@@ -198,6 +201,9 @@ Modules
`Module nn <dragon/nn.html>`_ `Module nn <dragon/nn.html>`_
: Native API for ``dragon.nn`` namespace. : Native API for ``dragon.nn`` namespace.
`Module onnx <dragon/onnx.html>`_
: Native API for ``dragon.onnx`` namespace.
`Module optimizers <dragon/optimizers.html>`_ `Module optimizers <dragon/optimizers.html>`_
: Native API for ``dragon.optimizers`` namespace. : Native API for ``dragon.optimizers`` namespace.
...@@ -219,9 +225,6 @@ Modules ...@@ -219,9 +225,6 @@ Modules
`Module vm.dali.ops <dali/ops.html>`_ `Module vm.dali.ops <dali/ops.html>`_
: Virtual API for ``dali.ops`` namespace. : Virtual API for ``dali.ops`` namespace.
`Module vm.onnx <onnx.html>`_
: Virtual API for ``onnx`` namespace.
`Module vm.tensorflow <tensorflow.html>`_ `Module vm.tensorflow <tensorflow.html>`_
: Virtual API for ``tensorflow`` namespace. : Virtual API for ``tensorflow`` namespace.
...@@ -255,8 +258,11 @@ Modules ...@@ -255,8 +258,11 @@ Modules
`Module vm.tensorlayer.models <tensorlayer/models.html>`_ `Module vm.tensorlayer.models <tensorlayer/models.html>`_
: Virtual API for ``tensorlayer.models`` namespace. : Virtual API for ``tensorlayer.models`` namespace.
`Module vm.tensorrt.backend <tensorrt/backend.html>`_ `Module vm.tensorrt <tensorrt.html>`_
: Virtual API for ``tensorrt.backend`` namespace. : Virtual API for ``tensorrt`` namespace.
`Module vm.tensorrt.onnx <tensorrt/onnx.html>`_
: Virtual API for ``tensorrt.onnx`` namespace.
`Module vm.torch <torch.html>`_ `Module vm.torch <torch.html>`_
: Virtual API for ``torch`` namespace. : Virtual API for ``torch`` namespace.
...@@ -306,6 +312,7 @@ Modules ...@@ -306,6 +312,7 @@ Modules
dragon/math dragon/math
dragon/metrics dragon/metrics
dragon/nn dragon/nn
dragon/onnx
dragon/optimizers dragon/optimizers
dragon/random dragon/random
dragon/vision dragon/vision
...@@ -313,7 +320,6 @@ Modules ...@@ -313,7 +320,6 @@ Modules
caffe/layers caffe/layers
dali dali
dali/ops dali/ops
onnx
tensorflow tensorflow
tensorflow/bitwise tensorflow/bitwise
tensorflow/dtypes tensorflow/dtypes
...@@ -326,7 +332,7 @@ Modules ...@@ -326,7 +332,7 @@ Modules
tensorlayer/layers tensorlayer/layers
tensorlayer/models tensorlayer/models
tensorrt tensorrt
tensorrt/backend tensorrt/onnx
torch torch
torch/autograd torch/autograd
torch/distributed torch/distributed
......
...@@ -12,15 +12,11 @@ vm.tensorrt ...@@ -12,15 +12,11 @@ vm.tensorrt
`class Engine <tensorrt/Engine.html>`_ `class Engine <tensorrt/Engine.html>`_
: The executing engine with bindings. : The executing engine with bindings.
`class ONNXBackendRep <tensorrt/ONNXBackendRep.html>`_
: Load and run onnx models.
.. toctree:: .. toctree::
:hidden: :hidden:
tensorrt/Binding tensorrt/Binding
tensorrt/Engine tensorrt/Engine
tensorrt/ONNXBackendRep
.. raw:: html .. raw:: html
......
vm.tensorrt.backend vm.tensorrt.onnx
=================== ================
.. only:: html .. only:: html
Classes
-------
`class BackendRep <onnx/BackendRep.html>`_
: ONNX-TensorRT backend to execute repeatedly.
Functions Functions
--------- ---------
`prepare(...) <backend/prepare.html>`_ `prepare_backend(...) <onnx/prepare_backend.html>`_
: Build a TensorRT engine from the onnx model. : Create a backend to execute repeatedly.
`run_model(...) <backend/run_model.html>`_ `run_model(...) <backend/run_model.html>`_
: Build and run a TensorRT engine from the onnx model. : Execute an onnx model once.
`run_node(...) <backend/run_node.html>`_ `run_node(...) <backend/run_node.html>`_
: Build and run a TensorRT engine from the onnx node. : Execute an onnx node once.
`supports_device(...) <backend/supports_device.html>`_ `supports_device(...) <backend/supports_device.html>`_
: Query if given device is supported. : Query if the given device is supported to execute.
.. toctree:: .. toctree::
:hidden: :hidden:
backend/prepare onnx/BackendRep
backend/run_model onnx/prepare
backend/run_node onnx/run_model
backend/supports_device onnx/run_node
onnx/supports_device
.. raw:: html .. raw:: html
......
ONNXBackendRep BackendRep
============== ==========
.. autoclass:: dragon.vm.tensorrt.ONNXBackendRep .. autoclass:: dragon.vm.tensorrt.onnx.BackendRep
__init__ __init__
-------- --------
.. automethod:: dragon.vm.tensorrt.ONNXBackendRep.__init__ .. automethod:: dragon.vm.tensorrt.onnx.BackendRep.__init__
Properties Properties
---------- ----------
engine engine
###### ######
.. autoattribute:: dragon.vm.tensorrt.ONNXBackendRep.engine .. autoattribute:: dragon.vm.tensorrt.onnx.BackendRep.engine
Methods Methods
...@@ -20,7 +20,7 @@ Methods ...@@ -20,7 +20,7 @@ Methods
run run
### ###
.. automethod:: dragon.vm.tensorrt.ONNXBackendRep.run .. automethod:: dragon.vm.tensorrt.onnx.BackendRep.run
.. raw:: html .. raw:: html
......
prepare_backend
===============
.. autofunction:: dragon.vm.tensorrt.onnx.prepare_backend
.. raw:: html
<style>
h1:before {
content: "tensorrt.onnx.";
color: #103d3e;
}
</style>
run_model
=========
.. autofunction:: dragon.vm.tensorrt.onnx.run_model
.. raw:: html
<style>
h1:before {
content: "tensorrt.onnx.";
color: #103d3e;
}
</style>
run_node run_node
======== ========
.. autofunction:: dragon.vm.tensorrt.backend.run_node .. autofunction:: dragon.vm.tensorrt.onnx.run_node
.. raw:: html .. raw:: html
<style> <style>
h1:before { h1:before {
content: "tensorrt.backend."; content: "tensorrt.onnx.";
color: #103d3e; color: #103d3e;
} }
</style> </style>
supports_device
===============
.. autofunction:: dragon.vm.tensorrt.onnx.supports_device
.. raw:: html
<style>
h1:before {
content: "tensorrt.onnx.";
color: #103d3e;
}
</style>
...@@ -6,7 +6,7 @@ vm.torch.onnx ...@@ -6,7 +6,7 @@ vm.torch.onnx
Functions Functions
--------- ---------
`export(...) <onnx/export.html>`_ `export(...) <onnx/export.html>`_
: Export a model into ONNX format. : Export the recorded graph to an onnx model.
.. toctree:: .. toctree::
:hidden: :hidden:
......
...@@ -274,7 +274,7 @@ PYBIND11_MODULE(libdragon_python, m) { ...@@ -274,7 +274,7 @@ PYBIND11_MODULE(libdragon_python, m) {
}) })
/*! \brief Load tensors and graph from a ONNX model */ /*! \brief Load tensors and graph from a ONNX model */
.def("ImportONNXModel", [](Workspace* self, const string& model_path) { .def("PrepareONNXModel", [](Workspace* self, const string& model_path) {
GraphDef init_graph, pred_graph; GraphDef init_graph, pred_graph;
onnx::ONNXBackend onnx_backend; onnx::ONNXBackend onnx_backend;
onnx_backend.Prepare(model_path, &init_graph, &pred_graph); onnx_backend.Prepare(model_path, &init_graph, &pred_graph);
......
...@@ -29,6 +29,7 @@ from dragon._api import losses ...@@ -29,6 +29,7 @@ from dragon._api import losses
from dragon._api import math from dragon._api import math
from dragon._api import metrics from dragon._api import metrics
from dragon._api import nn from dragon._api import nn
from dragon._api import onnx
from dragon._api import optimizers from dragon._api import optimizers
from dragon._api import random from dragon._api import random
from dragon._api import vision from dragon._api import vision
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
# Classes
from dragon.vm.onnx.core.backend.native import BackendRep
# Functions
from dragon.vm.onnx.core.backend.native import prepare as prepare_backend
from dragon.vm.onnx.core.backend.native import run_model
from dragon.vm.onnx.core.backend.native import supports_device
from dragon.vm.onnx.core.frontend.native import export
from dragon.vm.onnx.core.frontend.native import record
# Attributes
__all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -20,8 +20,8 @@ import os ...@@ -20,8 +20,8 @@ import os
from dragon.core.autograph import grad_maker from dragon.core.autograph import grad_maker
from dragon.core.autograph.op_def import OpDef from dragon.core.autograph.op_def import OpDef
from dragon.core.autograph.op_def import OpInfo from dragon.core.autograph.op_def import OpInfo
from dragon.core.autograph.tensor import TensorRef
from dragon.core.framework import config from dragon.core.framework import config
from dragon.core.framework import context
from dragon.core.framework import proto_util from dragon.core.framework import proto_util
from dragon.core.framework import types from dragon.core.framework import types
from dragon.core.framework import workspace from dragon.core.framework import workspace
...@@ -33,12 +33,14 @@ from dragon.core.util import nest ...@@ -33,12 +33,14 @@ from dragon.core.util import nest
def add_device_option(graph_def): def add_device_option(graph_def):
"""Add the device option.""" """Add the device option."""
cfg = config.config() cfg = config.config()
str2idx = {'cpu': 0, 'cuda': 1, 'cnml': 2} spec = context.get_device_spec()
dev_opt = dragon_pb2.DeviceOption() graph_def.device_option.CopyFrom(
dev_opt.device_type = str2idx[cfg.device_type] dragon_pb2.DeviceOption(
dev_opt.device_id = cfg.device_index device_type={'cpu': 0,
dev_opt.random_seed = cfg.random_seed 'cuda': 1,
graph_def.device_option.CopyFrom(dev_opt) 'cnml': 2}[spec.type],
device_id=spec.index,
random_seed=cfg.random_seed))
def add_grad_info(graph_def, targets): def add_grad_info(graph_def, targets):
...@@ -226,17 +228,13 @@ class Function(object): ...@@ -226,17 +228,13 @@ class Function(object):
f.write(str(graph_def)) f.write(str(graph_def))
logging.info('Export meta graph into: {}'.format(path)) logging.info('Export meta graph into: {}'.format(path))
def import_from(self, graph_def, explicit_inputs=False): def import_from(self, graph_def):
"""Import a defined function from a graph def. """Import a defined function from a graph def.
Set ``explicit_inputs`` to **True** to enforce feeding.
Parameters Parameters
---------- ----------
graph_def : GraphDef graph_def : GraphDef
The definition of graph. The definition of graph.
explicit_inputs : bool
Whether to enforce feeding on executing.
Returns Returns
------- -------
...@@ -244,8 +242,9 @@ class Function(object): ...@@ -244,8 +242,9 @@ class Function(object):
The self. The self.
""" """
self.outputs = [TensorRef(name) for name in graph_def.output] current_ws = workspace.get_workspace()
self.inputs = [TensorRef(name).constant() for name in graph_def.input] self.outputs = [current_ws.create_tensor(name) for name in graph_def.output]
self.inputs = [current_ws.create_tensor(name) for name in graph_def.input]
# Fill with all known graph elements. # Fill with all known graph elements.
add_device_option(graph_def) add_device_option(graph_def)
...@@ -253,18 +252,14 @@ class Function(object): ...@@ -253,18 +252,14 @@ class Function(object):
add_phase(graph_def, self.outputs) add_phase(graph_def, self.outputs)
# Notify the backend to create and optimize. # Notify the backend to create and optimize.
current_ws = workspace.get_workspace() graph_def.name = self.graph_def.name
self.graph_def = graph_def self.graph_def = graph_def
self.graph_name = current_ws.create_graph(graph_def) self.graph_name = current_ws.create_graph(graph_def)
# Bind a callback to run this graph. # Bind a callback to run this graph.
self.callback = lambda *args, **kwargs: \ self.callback = lambda *args, **kwargs: \
current_ws.run_graph( current_ws.run_graph(
name=self.graph_name, name=self.graph_name, outputs=self.outputs, **kwargs)
inputs_and_values=(self.inputs if explicit_inputs else [], args),
outputs=self.outputs,
**kwargs
)
return self return self
......
...@@ -26,7 +26,6 @@ register = _GLOBAL_REGISTERED_SPECS.register ...@@ -26,7 +26,6 @@ register = _GLOBAL_REGISTERED_SPECS.register
@register('Accuracy') @register('Accuracy')
def accuracy_spec(args, inputs, outputs): def accuracy_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype, outputs[0].shape = 'float32', [] outputs[0].dtype, outputs[0].shape = 'float32', []
return outputs return outputs
...@@ -97,7 +96,6 @@ def binary_shape_spec(inputs, outputs): ...@@ -97,7 +96,6 @@ def binary_shape_spec(inputs, outputs):
'Where', 'Where',
]) ])
def binary_math_spec(args, inputs, outputs): def binary_math_spec(args, inputs, outputs):
_ = locals()
outputs = binary_shape_spec(inputs, outputs) outputs = binary_shape_spec(inputs, outputs)
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
if inputs[0].dtype is None: if inputs[0].dtype is None:
...@@ -114,7 +112,6 @@ def binary_math_spec(args, inputs, outputs): ...@@ -114,7 +112,6 @@ def binary_math_spec(args, inputs, outputs):
'NotEqual', 'NotEqual',
]) ])
def binary_compare_spec(args, inputs, outputs): def binary_compare_spec(args, inputs, outputs):
_ = locals()
outputs = binary_shape_spec(inputs, outputs) outputs = binary_shape_spec(inputs, outputs)
outputs[0].dtype = 'bool' outputs[0].dtype = 'bool'
return outputs return outputs
...@@ -262,7 +259,6 @@ def depth_to_space_spec(args, inputs, outputs): ...@@ -262,7 +259,6 @@ def depth_to_space_spec(args, inputs, outputs):
@register('Dot') @register('Dot')
def dot_spec(args, inputs, outputs): def dot_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
try: try:
a_shape, b_shape = inputs[0].shape[:], inputs[1].shape[:] a_shape, b_shape = inputs[0].shape[:], inputs[1].shape[:]
...@@ -358,7 +354,6 @@ def expand_dims_spec(args, inputs, outputs): ...@@ -358,7 +354,6 @@ def expand_dims_spec(args, inputs, outputs):
'TruncatedNormal', 'TruncatedNormal',
]) ])
def fill_spec(args, inputs, outputs): def fill_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = args['dtype'] outputs[0].dtype = args['dtype']
try: try:
if 'dims' in args: if 'dims' in args:
...@@ -476,7 +471,6 @@ def index_select_spec(args, inputs, outputs): ...@@ -476,7 +471,6 @@ def index_select_spec(args, inputs, outputs):
@register(['IsInf', 'IsNaN']) @register(['IsInf', 'IsNaN'])
def is_spec(args, inputs, outputs): def is_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = 'bool' outputs[0].dtype = 'bool'
try: try:
outputs[0].shape = inputs[0].shape[:] outputs[0].shape = inputs[0].shape[:]
...@@ -487,7 +481,6 @@ def is_spec(args, inputs, outputs): ...@@ -487,7 +481,6 @@ def is_spec(args, inputs, outputs):
@register('LinSpace') @register('LinSpace')
def linspace_spec(args, inputs, outputs): def linspace_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = args['dtype'] outputs[0].dtype = args['dtype']
outputs[0].shape = args['dims'] outputs[0].shape = args['dims']
return outputs return outputs
...@@ -495,7 +488,6 @@ def linspace_spec(args, inputs, outputs): ...@@ -495,7 +488,6 @@ def linspace_spec(args, inputs, outputs):
@register('MaskedSelect') @register('MaskedSelect')
def masked_select_spec(args, inputs, outputs): def masked_select_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
outputs[0].shape = (None,) outputs[0].shape = (None,)
return outputs return outputs
...@@ -552,7 +544,6 @@ def multinomial_spec(args, inputs, outputs): ...@@ -552,7 +544,6 @@ def multinomial_spec(args, inputs, outputs):
@register('NonZero') @register('NonZero')
def non_zero_spec(args, inputs, outputs): def non_zero_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = 'int64' outputs[0].dtype = 'int64'
try: try:
outputs[0].shape = (None, len(inputs[0].shape)) outputs[0].shape = (None, len(inputs[0].shape))
...@@ -592,7 +583,6 @@ def pad_spec(args, inputs, outputs): ...@@ -592,7 +583,6 @@ def pad_spec(args, inputs, outputs):
@register('Permutation') @register('Permutation')
def permutation_spec(args, inputs, outputs): def permutation_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = args['dtype'] outputs[0].dtype = args['dtype']
if len(inputs) == 1: if len(inputs) == 1:
try: try:
...@@ -638,13 +628,11 @@ def pool_spec(args, inputs, outputs): ...@@ -638,13 +628,11 @@ def pool_spec(args, inputs, outputs):
@register(['PythonPlugin', 'PythonPluginInfer']) @register(['PythonPlugin', 'PythonPluginInfer'])
def python_spec(args, inputs, outputs): def python_spec(args, inputs, outputs):
_ = locals()
return outputs return outputs
@register('Range') @register('Range')
def range_spec(args, inputs, outputs): def range_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = args['dtype'] outputs[0].dtype = args['dtype']
slice_args = args['slice'] slice_args = args['slice']
if len(slice_args) == 2: if len(slice_args) == 2:
...@@ -802,7 +790,6 @@ def roi_pool_spec(args, inputs, outputs): ...@@ -802,7 +790,6 @@ def roi_pool_spec(args, inputs, outputs):
@register('Shape') @register('Shape')
def shape_spec(args, inputs, outputs): def shape_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = 'int64' outputs[0].dtype = 'int64'
try: try:
outputs[0].shape = [len(inputs[0].shape)] outputs[0].shape = [len(inputs[0].shape)]
...@@ -861,7 +848,6 @@ def softmax_loss_spec(args, inputs, outputs): ...@@ -861,7 +848,6 @@ def softmax_loss_spec(args, inputs, outputs):
@register('Sort') @register('Sort')
def sort_spec(args, inputs, outputs): def sort_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
outputs[1].dtype = 'int64' outputs[1].dtype = 'int64'
try: try:
...@@ -1054,7 +1040,6 @@ def top_k_spec(args, inputs, outputs): ...@@ -1054,7 +1040,6 @@ def top_k_spec(args, inputs, outputs):
@register('Unchanged') @register('Unchanged')
def unchanged_spec(args, inputs, outputs): def unchanged_spec(args, inputs, outputs):
_ = locals()
outputs[0].dtype = inputs[0].dtype outputs[0].dtype = inputs[0].dtype
try: try:
outputs[0].shape = inputs[0].shape[:] outputs[0].shape = inputs[0].shape[:]
......
...@@ -80,7 +80,6 @@ def deprecated(date, instructions, warn_once=True): ...@@ -80,7 +80,6 @@ def deprecated(date, instructions, warn_once=True):
def not_installed(package=''): def not_installed(package=''):
"""Return a dummy function for the package that is not installed.""" """Return a dummy function for the package that is not installed."""
def dummy_fn(*args, **kwargs): def dummy_fn(*args, **kwargs):
_ = locals()
raise ImportError('Package <%s> is required but not installed.' % package) raise ImportError('Package <%s> is required but not installed.' % package)
return dummy_fn return dummy_fn
......
...@@ -14,9 +14,6 @@ from __future__ import absolute_import as _absolute_import ...@@ -14,9 +14,6 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
# Classes
from dragon.vm.onnx.core.shell import Shell
# Attributes # Attributes
from dragon.vm.onnx.core import nodes as _nodes from dragon.vm.onnx.core import exporters as _import_exporters
__all__ = [_s for _s in dir() if not _s.startswith('_')] __all__ = [_s for _s in dir() if not _s.startswith('_')]
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Native ONNX backend."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import numpy
try:
import onnx
from onnx.backend.base import Backend
from onnx.backend.base import BackendRep as ONNXBackendRep
from onnx.backend.base import Device
from onnx.backend.base import DeviceType
from onnx.backend.base import namedtupledict
except ImportError:
from dragon.core.util import deprecation
onnx = deprecation.NotInstalled('onnx')
Backend = object
ONNXBackendRep = object
Device = deprecation.NotInstalled('onnx')
DeviceType = deprecation.NotInstalled('onnx')
from dragon.core.autograph import function_lib
from dragon.core.eager.tensor import EagerTensor
from dragon.core.framework import context
from dragon.core.framework import device_spec
from dragon.core.framework import workspace
from dragon.core.proto import dragon_pb2
from dragon.core.util import nest
class BackendRep(ONNXBackendRep):
"""ONNX-Dragon backend to execute repeatedly."""
def __init__(self, model, device, **kwargs):
"""Create a ``BackendRep``.
Parameters
----------
model : str
The path of onnx model file.
device : onnx.Device
The executing device.
"""
if not isinstance(device, Device):
device = Device(device)
graph_str = workspace.get_workspace().PrepareONNXModel(model)
graph_def = dragon_pb2.GraphDef()
graph_def.ParseFromString(graph_str)
if device.type == DeviceType.CPU:
device_type, device_index = 'cpu', 0
elif device.type == DeviceType.CUDA:
device_type, device_index = 'cuda', device.device_id
else:
raise ValueError('Unsupported device type: ' + device.type)
with context.device(device_type, device_index):
self._function = function_lib.Function(name='ONNXGraph') \
.import_from(graph_def)
self._input_dict = collections.OrderedDict(
[(impl.name, EagerTensor(impl=impl, device=device_spec.DeviceSpec(
device_type, device_index))) for impl in self._function.inputs])
self._output_dict = collections.OrderedDict(
[(impl.name, EagerTensor(impl=impl, device=device_spec.DeviceSpec(
device_type, device_index))) for impl in self._function.outputs])
def run(self, inputs, **kwargs):
"""Run the model.
Parameters
----------
inputs : Union[Sequence, Dict]
The input arrays.
Returns
-------
namedtuple
The model outputs.
"""
if isinstance(inputs, numpy.ndarray):
inputs = [inputs]
if isinstance(inputs, dict):
for name, value in inputs.items():
self._input_dict[name]._impl.FromNumpy(value)
elif nest.is_sequence(inputs):
for ref, value in zip(self._input_dict.values(), inputs):
ref._impl.FromNumpy(value)
else:
raise ValueError('Excepted sequence or dict inputs.')
self._function.callback(return_outputs=False)
named_outputs = namedtupledict('Outputs', list(self._output_dict.keys()))
return named_outputs(*(self._output_dict.values()))
class DragonBackend(Backend):
"""ONNX-Dragon backend."""
@classmethod
def prepare(cls, model, device='CPU:0', **kwargs):
"""Create a backend to execute repeatedly.
Parameters
----------
model : str
The path of onnx model file.
device : str, optional, default='CPU:0'
The executing device.
Returns
-------
dragon.onnx.BackendRep
The backend.
"""
if not os.path.exists(model):
raise ValueError('Model({}) is not existed.'.format(model))
return BackendRep(model, device, **kwargs)
@classmethod
def run_model(cls, model, inputs, device='CUDA:0', **kwargs):
"""Execute an onnx model once.
Parameters
----------
model : str
The path of onnx model file.
inputs : Union[Sequence, Dict]
The input arrays.
device : str, optional
The executing device.
Returns
-------
namedtuple
The model outputs.
"""
return cls.prepare(model, device, **kwargs).run(inputs)
@classmethod
def supports_device(cls, device_str):
"""Query if the given device is supported.
Parameters
----------
device_str : str
The device descriptor.
Returns
-------
bool
**True** if device is supported otherwise **False**.
"""
device = Device(device_str)
if device.type in (DeviceType.CPU, DeviceType.CUDA):
return True
return False
prepare = DragonBackend.prepare
run_model = DragonBackend.run_model
supports_device = DragonBackend.supports_device
...@@ -7,11 +7,8 @@ ...@@ -7,11 +7,8 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/onnx/onnx-tensorrt/blob/master/onnx_tensorrt/backend.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
"""TensorRT ONNX backend."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -22,14 +19,14 @@ import numpy ...@@ -22,14 +19,14 @@ import numpy
try: try:
import onnx import onnx
from onnx.backend.base import Backend from onnx.backend.base import Backend
from onnx.backend.base import BackendRep from onnx.backend.base import BackendRep as ONNXBackendRep
from onnx.backend.base import Device from onnx.backend.base import Device
from onnx.backend.base import DeviceType from onnx.backend.base import DeviceType
except ImportError: except ImportError:
from dragon.core.util import deprecation from dragon.core.util import deprecation
onnx = deprecation.NotInstalled('onnx') onnx = deprecation.NotInstalled('onnx')
Backend = object Backend = object
BackendRep = object ONNXBackendRep = object
Device = deprecation.NotInstalled('onnx') Device = deprecation.NotInstalled('onnx')
DeviceType = deprecation.NotInstalled('onnx') DeviceType = deprecation.NotInstalled('onnx')
...@@ -41,8 +38,8 @@ from dragon.vm.tensorrt.core.engine import trt ...@@ -41,8 +38,8 @@ from dragon.vm.tensorrt.core.engine import trt
from dragon.vm.tensorrt.core.engine import TRT_LOGGER from dragon.vm.tensorrt.core.engine import TRT_LOGGER
class ONNXBackendRep(BackendRep): class BackendRep(ONNXBackendRep):
"""Load and run onnx models.""" """ONNX-TensorRT backend to execute repeatedly."""
def __init__( def __init__(
self, self,
...@@ -53,7 +50,7 @@ class ONNXBackendRep(BackendRep): ...@@ -53,7 +50,7 @@ class ONNXBackendRep(BackendRep):
optimization_profiles=None, optimization_profiles=None,
serialize_engine=False, serialize_engine=False,
): ):
"""Create a ``ONNXBackendRep``. """Create a ``BackendRep``.
Parameters Parameters
---------- ----------
...@@ -71,6 +68,8 @@ class ONNXBackendRep(BackendRep): ...@@ -71,6 +68,8 @@ class ONNXBackendRep(BackendRep):
Whether to serialize engine into a file. Whether to serialize engine into a file.
""" """
if not isinstance(device, Device):
device = Device(device)
self._set_device(device) self._set_device(device)
self._logger = TRT_LOGGER self._logger = TRT_LOGGER
self._builder = trt.Builder(self._logger) self._builder = trt.Builder(self._logger)
...@@ -193,12 +192,12 @@ class ONNXBackendRep(BackendRep): ...@@ -193,12 +192,12 @@ class ONNXBackendRep(BackendRep):
cuda.set_device(device.device_id) cuda.set_device(device.device_id)
class ONNXBackend(Backend): class TensorRTBackend(Backend):
"""ONNX-TensorRT backend.""" """ONNX-TensorRT backend."""
@classmethod @classmethod
def prepare(cls, model, device='CUDA:0', **kwargs): def prepare(cls, model, device='CUDA:0', **kwargs):
"""Build a TensorRT engine from the onnx model. """Create a backend to execute repeatedly.
Parameters Parameters
---------- ----------
...@@ -209,17 +208,15 @@ class ONNXBackend(Backend): ...@@ -209,17 +208,15 @@ class ONNXBackend(Backend):
Returns Returns
------- -------
dragon.vm.tensorrt.ONNXBackendRep tensorrt.onnx.BackendRep
The backend rep. The backend.
""" """
if not isinstance(device, Device): return BackendRep(model, device, **kwargs)
device = Device(device)
return ONNXBackendRep(model, device, **kwargs)
@classmethod @classmethod
def run_model(cls, model, inputs, device='CUDA:0', **kwargs): def run_model(cls, model, inputs, device='CUDA:0', **kwargs):
"""Build and run a TensorRT engine from the onnx model. """Execute an onnx model once.
Parameters Parameters
---------- ----------
...@@ -240,7 +237,7 @@ class ONNXBackend(Backend): ...@@ -240,7 +237,7 @@ class ONNXBackend(Backend):
@classmethod @classmethod
def run_node(cls, node, inputs, device='CUDA:0', **kwargs): def run_node(cls, node, inputs, device='CUDA:0', **kwargs):
"""Build and run a TensorRT engine from the onnx node. """Execute an onnx node once.
Parameters Parameters
---------- ----------
...@@ -248,7 +245,7 @@ class ONNXBackend(Backend): ...@@ -248,7 +245,7 @@ class ONNXBackend(Backend):
The onnx node. The onnx node.
inputs : Union[Sequence, Dict] inputs : Union[Sequence, Dict]
The input arrays. The input arrays.
device : str, optional device : str, optional, default='CUDA:0'
The executing device. The executing device.
Returns Returns
...@@ -257,7 +254,7 @@ class ONNXBackend(Backend): ...@@ -257,7 +254,7 @@ class ONNXBackend(Backend):
The model outputs. The model outputs.
""" """
super(ONNXBackend, cls).run_node(node, inputs, device) super(TensorRTBackend, cls).run_node(node, inputs, device)
model = onnx_helper.make_model_from_node(node, inputs, use_weights=True) model = onnx_helper.make_model_from_node(node, inputs, use_weights=True)
try: try:
results = cls.prepare(model, device).run(inputs[:1]) results = cls.prepare(model, device).run(inputs[:1])
...@@ -268,7 +265,7 @@ class ONNXBackend(Backend): ...@@ -268,7 +265,7 @@ class ONNXBackend(Backend):
@classmethod @classmethod
def supports_device(cls, device_str): def supports_device(cls, device_str):
"""Query if given device is supported. """Query if the given device is supported.
Parameters Parameters
---------- ----------
...@@ -285,7 +282,7 @@ class ONNXBackend(Backend): ...@@ -285,7 +282,7 @@ class ONNXBackend(Backend):
return device.type == DeviceType.CUDA return device.type == DeviceType.CUDA
prepare = ONNXBackend.prepare prepare = TensorRTBackend.prepare
run_node = ONNXBackend.run_node run_node = TensorRTBackend.run_node
run_model = ONNXBackend.run_model run_model = TensorRTBackend.run_model
supports_device = ONNXBackend.supports_device supports_device = TensorRTBackend.supports_device
...@@ -13,8 +13,8 @@ from __future__ import absolute_import ...@@ -13,8 +13,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.vm.onnx.core.nodes import activation from dragon.vm.onnx.core.exporters import activation
from dragon.vm.onnx.core.nodes import array from dragon.vm.onnx.core.exporters import array
from dragon.vm.onnx.core.nodes import normalization from dragon.vm.onnx.core.exporters import normalization
from dragon.vm.onnx.core.nodes import math from dragon.vm.onnx.core.exporters import math
from dragon.vm.onnx.core.nodes import vision from dragon.vm.onnx.core.exporters import vision
...@@ -13,26 +13,26 @@ from __future__ import absolute_import ...@@ -13,26 +13,26 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.vm.onnx.core import exporter
from dragon.vm.onnx.core import helper from dragon.vm.onnx.core import helper
from dragon.vm.onnx.core.exporters import utils as export_util
@exporter.register('Dropout') @export_util.register('Dropout')
def dropout_exporter(op_def, shape_dict, ws): def dropout_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
drop_ratio = 0.5 # The prob to set zeros randomly. drop_ratio = 0.5 # The prob to set zeros randomly.
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'prob': if arg.name == 'prob':
drop_ratio = arg.f drop_ratio = arg.f
elif arg.name == 'prob_desc': elif arg.name == 'prob_desc':
drop_ratio = helper.fetch_argument(op_def, arg, ws) drop_ratio = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'ratio', drop_ratio) helper.add_attribute(node, 'ratio', drop_ratio)
return node, const_tensors return node, const_tensors
@exporter.register('HardSigmoid') @export_util.register('HardSigmoid')
def hardsigmoid_exporter(op_def, shape_dict, ws): def hardsigmoid_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
alpha, beta = 0.2, 0.5 alpha, beta = 0.2, 0.5
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'alpha': if arg.name == 'alpha':
...@@ -44,16 +44,16 @@ def hardsigmoid_exporter(op_def, shape_dict, ws): ...@@ -44,16 +44,16 @@ def hardsigmoid_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('PRelu') @export_util.register('PRelu')
def prelu_exporter(op_def, shape_dict, ws): def prelu_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
const_tensors = [helper.from_tensor(op_def.input[1], ws)] const_tensors = [helper.from_tensor(op_def.input[1], context.ws)]
return node, const_tensors return node, const_tensors
@exporter.register('Relu') @export_util.register('Relu')
def relu_exporter(op_def, shape_dict, ws): def relu_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'alpha': if arg.name == 'alpha':
if arg.f > 0: if arg.f > 0:
...@@ -62,9 +62,9 @@ def relu_exporter(op_def, shape_dict, ws): ...@@ -62,9 +62,9 @@ def relu_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Selu') @export_util.register('Selu')
def selu_exporter(op_def, shape_dict, ws): def selu_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
alpha, gamma = 1.67326, 1.0507 alpha, gamma = 1.67326, 1.0507
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'alpha': if arg.name == 'alpha':
...@@ -76,10 +76,10 @@ def selu_exporter(op_def, shape_dict, ws): ...@@ -76,10 +76,10 @@ def selu_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Softmax') @export_util.register('Softmax')
def softmax_exporter(op_def, shape_dict, ws): def softmax_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
ndim = len(shape_dict[op_def.input[0]]) ndim = len(context.blob_shapes[op_def.input[0]])
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axis': if arg.name == 'axis':
axis = arg.i + (ndim if arg.i < 0 else 0) axis = arg.i + (ndim if arg.i < 0 else 0)
......
...@@ -20,13 +20,13 @@ except ImportError: ...@@ -20,13 +20,13 @@ except ImportError:
from dragon.core.util import deprecation from dragon.core.util import deprecation
TensorProto = deprecation.not_installed('ONNX') TensorProto = deprecation.not_installed('ONNX')
from dragon.vm.onnx.core import exporter
from dragon.vm.onnx.core import helper from dragon.vm.onnx.core import helper
from dragon.vm.onnx.core.exporters import utils as export_util
@exporter.register(['ArgMax', 'ArgMin']) @export_util.register(['ArgMax', 'ArgMin'])
def arg_reduce_exporter(op_def, shape_dict, ws): def arg_reduce_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
# ONNX requires indices only, remove the values. # ONNX requires indices only, remove the values.
indices = node.output[0] indices = node.output[0]
node.ClearField('output') node.ClearField('output')
...@@ -39,9 +39,9 @@ def arg_reduce_exporter(op_def, shape_dict, ws): ...@@ -39,9 +39,9 @@ def arg_reduce_exporter(op_def, shape_dict, ws):
return node, None return node, None
@exporter.register('Cast') @export_util.register('Cast')
def cast_exporter(op_def, shape_dict, ws): def cast_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Cast' node.op_type = 'Cast'
if len(node.input) == 0: if len(node.input) == 0:
raise ValueError('ONNX does not support in-place cast.') raise ValueError('ONNX does not support in-place cast.')
...@@ -51,9 +51,9 @@ def cast_exporter(op_def, shape_dict, ws): ...@@ -51,9 +51,9 @@ def cast_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('ChannelAffine') @export_util.register('ChannelAffine')
def channel_affine_exporter(op_def, shape_dict, ws): def channel_affine_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelAffine') helper.add_attribute(node, 'op_type', 'ChannelAffine')
for arg in op_def.arg: for arg in op_def.arg:
...@@ -62,13 +62,13 @@ def channel_affine_exporter(op_def, shape_dict, ws): ...@@ -62,13 +62,13 @@ def channel_affine_exporter(op_def, shape_dict, ws):
elif arg.name == 'num_axes': elif arg.name == 'num_axes':
helper.add_attribute(node, 'num_axes', arg.i) helper.add_attribute(node, 'num_axes', arg.i)
# Weights and biases # Weights and biases
const_tensors = [helper.from_tensor(e, ws) for e in op_def.input[1:]] const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
@exporter.register('ChannelNormalize') @export_util.register('ChannelNormalize')
def channel_normalize_exporter(op_def, shape_dict, ws): def channel_normalize_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx node.op_type = 'ATen' # Currently not supported in ai.onnx
helper.add_attribute(node, 'op_type', 'ChannelNormalize') helper.add_attribute(node, 'op_type', 'ChannelNormalize')
for arg in op_def.arg: for arg in op_def.arg:
...@@ -83,27 +83,27 @@ def channel_normalize_exporter(op_def, shape_dict, ws): ...@@ -83,27 +83,27 @@ def channel_normalize_exporter(op_def, shape_dict, ws):
elif arg.name == 'perm': elif arg.name == 'perm':
helper.add_attribute(node, 'perm', arg.ints) helper.add_attribute(node, 'perm', arg.ints)
elif arg.name == 'perm_desc': elif arg.name == 'perm_desc':
values = helper.fetch_argument(op_def, arg, ws) values = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values) helper.add_attribute(node, 'perm', values)
elif arg.name == 'perm_descs': elif arg.name == 'perm_descs':
if len(arg.strings) > 0: if len(arg.strings) > 0:
values = helper.fetch_arguments(op_def, arg, ws) values = helper.fetch_arguments(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values) helper.add_attribute(node, 'perm', values)
return node, const_tensors return node, const_tensors
@exporter.register('Concat') @export_util.register('Concat')
def concat_exporter(op_def, shape_dict, ws): def concat_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axis': if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i) helper.add_attribute(node, 'axis', arg.i)
return node, const_tensors return node, const_tensors
@exporter.register('CumSum') @export_util.register('CumSum')
def cumulative_exporter(op_def, shape_dict, ws): def cumulative_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
axis = 0 axis = 0
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axis': if arg.name == 'axis':
...@@ -114,27 +114,27 @@ def cumulative_exporter(op_def, shape_dict, ws): ...@@ -114,27 +114,27 @@ def cumulative_exporter(op_def, shape_dict, ws):
helper.add_attribute(node, 'reverse', arg.i) helper.add_attribute(node, 'reverse', arg.i)
axis = helper.from_array( axis = helper.from_array(
numpy.array(axis, 'int64'), numpy.array(axis, 'int64'),
op_def.input[0] + '/cumulative/axis', context.unique_name(op_def.input[0] + '/cumulative/axis'),
) )
node.input.extend([axis.name]) node.input.extend([axis.name])
return node, [axis] return node, [axis]
@exporter.register('Expand') @export_util.register('Expand')
def expand_exporter(op_def, shape_dict, ws): def expand_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
shape = list(shape_dict[op_def.output[0]]) shape = list(context.blob_shapes[op_def.output[0]])
shape = helper.from_array( shape = helper.from_array(
numpy.array(shape, 'int64'), numpy.array(shape, 'int64'),
op_def.input[0] + '/expand/shape', context.unique_name(op_def.input[0] + '/expand/shape'),
) )
node.input.extend([shape.name]) node.input.extend([shape.name])
return node, [shape] return node, [shape]
@exporter.register('ExpandDims') @export_util.register('ExpandDims')
def expand_dims_exporter(op_def, shape_dict, ws): def expand_dims_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Unsqueeze' node.op_type = 'Unsqueeze'
axes = None axes = None
for arg in op_def.arg: for arg in op_def.arg:
...@@ -145,13 +145,13 @@ def expand_dims_exporter(op_def, shape_dict, ws): ...@@ -145,13 +145,13 @@ def expand_dims_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Eye') @export_util.register('Eye')
def eye_exporter(op_def, shape_dict, ws): def eye_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
if len(op_def.input) > 0: if len(op_def.input) > 0:
node.op_type += 'Like' node.op_type += 'Like'
else: else:
output_shape = list(shape_dict[op_def.output[0]]) output_shape = list(context.blob_shapes[op_def.output[0]])
helper.add_attribute(node, 'shape', output_shape) helper.add_attribute(node, 'shape', output_shape)
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'k': if arg.name == 'k':
...@@ -161,9 +161,9 @@ def eye_exporter(op_def, shape_dict, ws): ...@@ -161,9 +161,9 @@ def eye_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Flatten') @export_util.register('Flatten')
def flatten_exporter(op_def, shape_dict, ws): def flatten_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axis': if arg.name == 'axis':
helper.add_attribute(node, 'axis', arg.i) helper.add_attribute(node, 'axis', arg.i)
...@@ -177,9 +177,9 @@ def flatten_exporter(op_def, shape_dict, ws): ...@@ -177,9 +177,9 @@ def flatten_exporter(op_def, shape_dict, ws):
return node, None return node, None
@exporter.register('IndexSelect') @export_util.register('IndexSelect')
def index_select_exporter(op_def, shape_dict, ws): def index_select_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Gather' node.op_type = 'Gather'
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axis': if arg.name == 'axis':
...@@ -190,9 +190,9 @@ def index_select_exporter(op_def, shape_dict, ws): ...@@ -190,9 +190,9 @@ def index_select_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Multinomial') @export_util.register('Multinomial')
def multinomial_exporter(op_def, shape_dict, ws): def multinomial_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
helper.add_attribute(node, 'dtype', helper.tensor_type('int64')) helper.add_attribute(node, 'dtype', helper.tensor_type('int64'))
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'num_samples': if arg.name == 'num_samples':
...@@ -200,12 +200,12 @@ def multinomial_exporter(op_def, shape_dict, ws): ...@@ -200,12 +200,12 @@ def multinomial_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('OneHot') @export_util.register('OneHot')
def one_hot_exporter(op_def, shape_dict, ws): def one_hot_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
helper.add_attribute(node, 'axis', -1) helper.add_attribute(node, 'axis', -1)
depth, on_value, off_value = 1, 1, 0 depth, on_value, off_value = 1, 1, 0
dtype = ws.FetchTensor(node.output[0]).dtype dtype = context.ws.FetchTensor(node.output[0]).dtype
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'depth': if arg.name == 'depth':
depth = arg.i depth = arg.i
...@@ -215,27 +215,27 @@ def one_hot_exporter(op_def, shape_dict, ws): ...@@ -215,27 +215,27 @@ def one_hot_exporter(op_def, shape_dict, ws):
off_value = arg.i off_value = arg.i
depth = helper.from_array( depth = helper.from_array(
numpy.array(depth, 'int64'), numpy.array(depth, 'int64'),
op_def.input[0] + '/one_hot/depth', context.unique_name(op_def.input[0] + '/one_hot/depth'),
) )
values = helper.from_array( values = helper.from_array(
numpy.array([off_value, on_value], dtype), numpy.array([off_value, on_value], dtype),
op_def.input[0] + '/one_hot/values', context.unique_name(op_def.input[0] + '/one_hot/values'),
) )
const_tensors = [depth, values] const_tensors = [depth, values]
node.input.extend([depth.name, values.name]) node.input.extend([depth.name, values.name])
return node, const_tensors return node, const_tensors
def pad_exporter(op_def, shape_dict, ws): def pad_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
pads, value = [], 0 pads, value = [], 0
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'pads': if arg.name == 'pads':
pads = [int(e) for e in arg.ints] pads = [int(e) for e in arg.ints]
elif arg.name == 'pads_desc': elif arg.name == 'pads_desc':
pads = helper.fetch_argument(op_def, arg, ws) pads = helper.fetch_argument(op_def, arg, context.ws)
elif arg.name == 'pads_descs': elif arg.name == 'pads_descs':
pads = helper.fetch_arguments(op_def, arg, ws) pads = helper.fetch_arguments(op_def, arg, context.ws)
elif arg.name == 'mode': elif arg.name == 'mode':
helper.add_attribute(node, 'mode', arg.s.lower()) helper.add_attribute(node, 'mode', arg.s.lower())
elif arg.name == 'value': elif arg.name == 'value':
...@@ -243,44 +243,44 @@ def pad_exporter(op_def, shape_dict, ws): ...@@ -243,44 +243,44 @@ def pad_exporter(op_def, shape_dict, ws):
return node, pads, value return node, pads, value
@exporter.register('Pad-1') @export_util.register('Pad-1')
def pad_exporter_v1(op_def, shape_dict, ws): def pad_exporter_v1(op_def, context):
node, pads, value = pad_exporter(**locals()) node, pads, value = pad_exporter(**locals())
helper.add_attribute(node, 'paddings', pads) helper.add_attribute(node, 'paddings', pads)
helper.add_attribute(node, 'value', value) helper.add_attribute(node, 'value', value)
return node, [] return node, []
@exporter.register('Pad-2') @export_util.register('Pad-2')
def pad_exporter_v2(op_def, shape_dict, ws): def pad_exporter_v2(op_def, context):
node, pads, value = pad_exporter(**locals()) node, pads, value = pad_exporter(**locals())
helper.add_attribute(node, 'pads', pads) helper.add_attribute(node, 'pads', pads)
helper.add_attribute(node, 'value', value) helper.add_attribute(node, 'value', value)
return node, [] return node, []
@exporter.register('Pad-11') @export_util.register('Pad-11')
def pad_exporter_v11(op_def, shape_dict, ws): def pad_exporter_v11(op_def, context):
node, pads, value = pad_exporter(**locals()) node, pads, value = pad_exporter(**locals())
pads = helper.from_array( pads = helper.from_array(
numpy.array(pads, 'int64'), numpy.array(pads, 'int64'),
op_def.input[0] + '/pad/pads', context.unique_name(op_def.input[0] + '/pad/pads'),
) )
value = helper.from_array( value = helper.from_array(
numpy.array(value, 'float64'), numpy.array(value, 'float64'),
op_def.input[0] + '/pad/value', context.unique_name(op_def.input[0] + '/pad/value'),
) )
node.input.extend([pads.name, value.name]) node.input.extend([pads.name, value.name])
return node, [pads, value] return node, [pads, value]
@exporter.register('RandomNormal') @export_util.register('RandomNormal')
def random_normal_exporter(op_def, shape_dict, ws): def random_normal_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
if len(op_def.input) > 0: if len(op_def.input) > 0:
node.op_type += 'Like' node.op_type += 'Like'
else: else:
output_shape = list(shape_dict[op_def.output[0]]) output_shape = list(context.blob_shapes[op_def.output[0]])
helper.add_attribute(node, 'shape', output_shape) helper.add_attribute(node, 'shape', output_shape)
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'mean': if arg.name == 'mean':
...@@ -292,13 +292,13 @@ def random_normal_exporter(op_def, shape_dict, ws): ...@@ -292,13 +292,13 @@ def random_normal_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('RandomUniform') @export_util.register('RandomUniform')
def random_uniform_exporter(op_def, shape_dict, ws): def random_uniform_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
if len(op_def.input) > 0: if len(op_def.input) > 0:
node.op_type += 'Like' node.op_type += 'Like'
else: else:
output_shape = list(shape_dict[op_def.output[0]]) output_shape = list(context.blob_shapes[op_def.output[0]])
helper.add_attribute(node, 'shape', output_shape) helper.add_attribute(node, 'shape', output_shape)
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'low': if arg.name == 'low':
...@@ -310,10 +310,10 @@ def random_uniform_exporter(op_def, shape_dict, ws): ...@@ -310,10 +310,10 @@ def random_uniform_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register(['ReduceMax', 'ReduceMean', 'ReduceMin', 'ReduceSum']) @export_util.register(['ReduceMax', 'ReduceMean', 'ReduceMin', 'ReduceSum'])
def reduce_exporter(op_def, shape_dict, ws): def reduce_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
axes = list(range(len(shape_dict[op_def.input[0]]))) axes = list(range(len(context.blob_shapes[op_def.input[0]])))
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axes': if arg.name == 'axes':
axes = arg.ints axes = arg.ints
...@@ -323,44 +323,44 @@ def reduce_exporter(op_def, shape_dict, ws): ...@@ -323,44 +323,44 @@ def reduce_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Reshape') @export_util.register('Reshape')
def reshape_exporter(op_def, shape_dict, ws): def reshape_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
shape = dims = list(shape_dict[op_def.output[0]]) shape = dims = list(context.blob_shapes[op_def.output[0]])
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'dims': if arg.name == 'dims':
dims = [int(e) for e in arg.ints] dims = [int(e) for e in arg.ints]
elif arg.name == 'dims_desc': elif arg.name == 'dims_desc':
dims = helper.fetch_argument(op_def, arg, ws) dims = helper.fetch_argument(op_def, arg, context.ws)
elif arg.name == 'dims_descs': elif arg.name == 'dims_descs':
dims = helper.fetch_arguments(op_def, arg, ws) dims = helper.fetch_arguments(op_def, arg, context.ws)
for axis, dim in enumerate(dims): for axis, dim in enumerate(dims):
shape[axis] = dim if dim <= 0 else shape[axis] shape[axis] = dim if dim <= 0 else shape[axis]
shape = helper.from_array( shape = helper.from_array(
numpy.array(shape, 'int64'), numpy.array(shape, 'int64'),
op_def.input[0] + '/reshape/shape', context.unique_name(op_def.input[0] + '/reshape/shape'),
) )
node.input.extend([shape.name]) node.input.extend([shape.name])
return node, [shape] return node, [shape]
def slice_exporter(op_def, shape_dict, ws): def slice_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
in_shape = shape_dict[op_def.input[0]] in_shape = context.blob_shapes[op_def.input[0]]
starts, sizes, ends = [], [], [] starts, sizes, ends = [], [], []
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'starts': if arg.name == 'starts':
starts = [int(e) for e in arg.ints] starts = [int(e) for e in arg.ints]
elif arg.name == 'starts_desc': elif arg.name == 'starts_desc':
starts = helper.fetch_argument(op_def, arg, ws) starts = helper.fetch_argument(op_def, arg, context.ws)
elif arg.name == 'starts_descs': elif arg.name == 'starts_descs':
starts = helper.fetch_arguments(op_def, arg, ws) starts = helper.fetch_arguments(op_def, arg, context.ws)
elif arg.name == 'sizes': elif arg.name == 'sizes':
sizes = [int(e) for e in arg.ints] sizes = [int(e) for e in arg.ints]
elif arg.name == 'sizes_desc': elif arg.name == 'sizes_desc':
sizes = helper.fetch_argument(op_def, arg, ws) sizes = helper.fetch_argument(op_def, arg, context.ws)
elif arg.name == 'sizes_descs': elif arg.name == 'sizes_descs':
sizes = helper.fetch_arguments(op_def, arg, ws) sizes = helper.fetch_arguments(op_def, arg, context.ws)
for i, size in enumerate(sizes): for i, size in enumerate(sizes):
if size == -1: if size == -1:
ends.append(in_shape[i]) ends.append(in_shape[i])
...@@ -371,8 +371,8 @@ def slice_exporter(op_def, shape_dict, ws): ...@@ -371,8 +371,8 @@ def slice_exporter(op_def, shape_dict, ws):
return node, starts, ends return node, starts, ends
@exporter.register('Slice-1') @export_util.register('Slice-1')
def slice_exporter_v1(op_def, shape_dict, ws): def slice_exporter_v1(op_def, context):
node, starts, ends = slice_exporter(**locals()) node, starts, ends = slice_exporter(**locals())
helper.add_attribute(node, 'axes', numpy.arange(len(starts))) helper.add_attribute(node, 'axes', numpy.arange(len(starts)))
helper.add_attribute(node, 'ends', ends) helper.add_attribute(node, 'ends', ends)
...@@ -380,40 +380,40 @@ def slice_exporter_v1(op_def, shape_dict, ws): ...@@ -380,40 +380,40 @@ def slice_exporter_v1(op_def, shape_dict, ws):
return node, [] return node, []
@exporter.register('Slice-10') @export_util.register('Slice-10')
def slice_exporter_v10(op_def, shape_dict, ws): def slice_exporter_v10(op_def, context):
node, starts, ends = slice_exporter(**locals()) node, starts, ends = slice_exporter(**locals())
axes = helper.from_array( axes = helper.from_array(
numpy.arange(len(starts), dtype='int64'), numpy.arange(len(starts), dtype='int64'),
op_def.input[0] + '/slice/axes', context.unique_name(op_def.input[0] + '/slice/axes'),
) )
starts = helper.from_array( starts = helper.from_array(
numpy.array(starts, 'int64'), numpy.array(starts, 'int64'),
op_def.input[0] + '/slice/starts', context.unique_name(op_def.input[0] + '/slice/starts'),
) )
ends = helper.from_array( ends = helper.from_array(
numpy.array(ends, 'int64'), numpy.array(ends, 'int64'),
op_def.input[0] + '/slice/ends', context.unique_name(op_def.input[0] + '/slice/ends'),
) )
node.input.extend([starts.name, ends.name, axes.name]) node.input.extend([starts.name, ends.name, axes.name])
return node, [starts, ends, axes] return node, [starts, ends, axes]
@exporter.register('Split') @export_util.register('Split')
def split_exporter(op_def, shape_dict, ws): def split_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
axis = 0 axis = 0
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axis': if arg.name == 'axis':
axis = arg.i axis = arg.i
size_splits = [shape_dict[e][axis] for e in op_def.output] size_splits = [context.blob_shapes[e][axis] for e in op_def.output]
helper.add_attribute(node, 'split', size_splits) helper.add_attribute(node, 'split', size_splits)
return node, const_tensors return node, const_tensors
@exporter.register('Squeeze') @export_util.register('Squeeze')
def squeeze_exporter(op_def, shape_dict, ws): def squeeze_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
axes = None axes = None
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axes': if arg.name == 'axes':
...@@ -423,43 +423,43 @@ def squeeze_exporter(op_def, shape_dict, ws): ...@@ -423,43 +423,43 @@ def squeeze_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Tile') @export_util.register('Tile')
def tile_exporter(op_def, shape_dict, ws): def tile_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
repeats = [] repeats = []
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'repeats': if arg.name == 'repeats':
repeats = [e for e in arg.ints] repeats = [e for e in arg.ints]
elif arg.name == 'repeats_desc': elif arg.name == 'repeats_desc':
repeats = helper.fetch_argument(op_def, arg, ws) repeats = helper.fetch_argument(op_def, arg, context.ws)
elif arg.name == 'repeats_descs': elif arg.name == 'repeats_descs':
repeats = helper.fetch_arguments(op_def, arg, ws) repeats = helper.fetch_arguments(op_def, arg, context.ws)
repeats = helper.from_array( repeats = helper.from_array(
numpy.array(repeats, 'int64'), numpy.array(repeats, 'int64'),
op_def.input[0] + '/tile/repeats', context.unique_name(op_def.input[0] + '/tile/repeats'),
) )
node.input.extend([repeats.name]) node.input.extend([repeats.name])
return node, [repeats] return node, [repeats]
@exporter.register('Transpose') @export_util.register('Transpose')
def transpose_exporter(op_def, shape_dict, ws): def transpose_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'perm': if arg.name == 'perm':
helper.add_attribute(node, 'perm', arg.ints) helper.add_attribute(node, 'perm', arg.ints)
elif arg.name == 'perm_desc': elif arg.name == 'perm_desc':
values = helper.fetch_argument(op_def, arg, ws) values = helper.fetch_argument(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values) helper.add_attribute(node, 'perm', values)
elif arg.name == 'perm_descs': elif arg.name == 'perm_descs':
if len(arg.strings) > 0: if len(arg.strings) > 0:
values = helper.fetch_arguments(op_def, arg, ws) values = helper.fetch_arguments(op_def, arg, context.ws)
helper.add_attribute(node, 'perm', values) helper.add_attribute(node, 'perm', values)
return node, None return node, None
def top_k_exporter(op_def, shape_dict, ws): def top_k_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
k, axis, largest, sorted = 1, -1, True, True k, axis, largest, sorted = 1, -1, True, True
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'k': if arg.name == 'k':
...@@ -473,9 +473,9 @@ def top_k_exporter(op_def, shape_dict, ws): ...@@ -473,9 +473,9 @@ def top_k_exporter(op_def, shape_dict, ws):
return node, (k, axis, largest, sorted) return node, (k, axis, largest, sorted)
@exporter.register('TopK-1') @export_util.register('TopK-1')
def top_k_exporter_v1(op_def, shape_dict, ws): def top_k_exporter_v1(op_def, context):
node, (k, axis, largest, sorted) = top_k_exporter(op_def, shape_dict, ws) node, (k, axis, largest, sorted) = top_k_exporter(**locals())
if largest == 0: if largest == 0:
raise ValueError('TopK-1 does not support smallest mode.') raise ValueError('TopK-1 does not support smallest mode.')
helper.add_attribute(node, 'axis', axis) helper.add_attribute(node, 'axis', axis)
...@@ -483,37 +483,37 @@ def top_k_exporter_v1(op_def, shape_dict, ws): ...@@ -483,37 +483,37 @@ def top_k_exporter_v1(op_def, shape_dict, ws):
return node, None return node, None
@exporter.register('TopK-10') @export_util.register('TopK-10')
def top_k_exporter_v10(op_def, shape_dict, ws): def top_k_exporter_v10(op_def, context):
node, (k, axis, largest, sorted) = top_k_exporter(op_def, shape_dict, ws) node, (k, axis, largest, sorted) = top_k_exporter(**locals())
if largest == 0: if largest == 0:
raise ValueError('TopK-10 does not support smallest mode.') raise ValueError('TopK-10 does not support smallest mode.')
helper.add_attribute(node, 'axis', axis) helper.add_attribute(node, 'axis', axis)
k = helper.from_array( k = helper.from_array(
numpy.array([k], 'int64'), numpy.array([k], 'int64'),
op_def.input[0] + '/top_k/k', context.unique_name(op_def.input[0] + '/top_k/k'),
) )
node.input.extend([k.name]) node.input.extend([k.name])
return node, [k] return node, [k]
@exporter.register('TopK-11') @export_util.register('TopK-11')
def top_k_exporter_v11(op_def, shape_dict, ws): def top_k_exporter_v11(op_def, context):
node, (k, axis, largest, sorted) = top_k_exporter(op_def, shape_dict, ws) node, (k, axis, largest, sorted) = top_k_exporter(**locals())
helper.add_attribute(node, 'axis', axis) helper.add_attribute(node, 'axis', axis)
helper.add_attribute(node, 'largest', largest) helper.add_attribute(node, 'largest', largest)
helper.add_attribute(node, 'sorted', sorted) helper.add_attribute(node, 'sorted', sorted)
k = helper.from_array( k = helper.from_array(
numpy.array([k], 'int64'), numpy.array([k], 'int64'),
op_def.input[0] + '/top_k/k', context.unique_name(op_def.input[0] + '/top_k/k'),
) )
node.input.extend([k.name]) node.input.extend([k.name])
return node, [k] return node, [k]
@exporter.register('Unique') @export_util.register('Unique')
def unique_exporter(op_def, shape_dict, ws): def unique_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
helper.add_attribute(node, 'sorted', 1) helper.add_attribute(node, 'sorted', 1)
return_inverse = return_counts = 0 return_inverse = return_counts = 0
for arg in op_def.arg: for arg in op_def.arg:
......
...@@ -15,35 +15,35 @@ from __future__ import print_function ...@@ -15,35 +15,35 @@ from __future__ import print_function
import numpy import numpy
from dragon.vm.onnx.core import exporter
from dragon.vm.onnx.core import helper from dragon.vm.onnx.core import helper
from dragon.vm.onnx.core.exporters import utils as export_util
@exporter.register('Add') @export_util.register('Add')
def add_exporter(op_def, shape_dict, ws): def add_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
dtype = str(helper.fetch_tensor(op_def.output[0], ws).dtype) dtype = str(helper.fetch_tensor(op_def.output[0], context.ws).dtype)
node.op_type = 'Or' if dtype == 'bool' else 'Add' node.op_type = 'Or' if dtype == 'bool' else 'Add'
const_tensors = [] # Global scalars const_tensors = [] # Global scalars
for e in op_def.input: for name in op_def.input:
if e.startswith('/share/scalar/'): if name.startswith('/share/scalar/'):
const_tensors.append(helper.from_tensor(e, ws)) const_tensors.append(helper.from_tensor(name, context.ws))
return node, const_tensors return node, const_tensors
@exporter.register('Div') @export_util.register('Div')
def div_exporter(op_def, shape_dict, ws): def div_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
const_tensors = [] # Global scalars const_tensors = [] # Global scalars
for e in op_def.input: for name in op_def.input:
if e.startswith('/share/scalar/'): if name.startswith('/share/scalar/'):
const_tensors.append(helper.from_tensor(e, ws)) const_tensors.append(helper.from_tensor(name, context.ws))
return node, const_tensors return node, const_tensors
@exporter.register('Clip') @export_util.register('Clip')
def clip_exporter(op_def, shape_dict, ws): def clip_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'low': if arg.name == 'low':
helper.add_attribute(node, 'min', arg.f) helper.add_attribute(node, 'min', arg.f)
...@@ -52,11 +52,11 @@ def clip_exporter(op_def, shape_dict, ws): ...@@ -52,11 +52,11 @@ def clip_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Clip-11') @export_util.register('Clip-11')
def clip_exporter_v11(op_def, shape_dict, ws): def clip_exporter_v11(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
min_value, max_value, const_tensors = None, None, [] min_value, max_value, const_tensors = None, None, []
dtype = ws.FetchTensor(op_def.output[0]).dtype dtype = context.ws.FetchTensor(op_def.output[0]).dtype
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'low': if arg.name == 'low':
min_value = arg.f min_value = arg.f
...@@ -65,7 +65,7 @@ def clip_exporter_v11(op_def, shape_dict, ws): ...@@ -65,7 +65,7 @@ def clip_exporter_v11(op_def, shape_dict, ws):
if min_value is not None: if min_value is not None:
const_tensors.append(helper.from_array( const_tensors.append(helper.from_array(
numpy.array(min_value, dtype), numpy.array(min_value, dtype),
op_def.input[0] + '/clip/min_value', context.unique_name(op_def.input[0] + '/clip/min_value'),
)) ))
node.input.extend([const_tensors[-1].name]) node.input.extend([const_tensors[-1].name])
else: else:
...@@ -73,7 +73,7 @@ def clip_exporter_v11(op_def, shape_dict, ws): ...@@ -73,7 +73,7 @@ def clip_exporter_v11(op_def, shape_dict, ws):
if max_value is not None: if max_value is not None:
const_tensors.append(helper.from_array( const_tensors.append(helper.from_array(
numpy.array(max_value, dtype), numpy.array(max_value, dtype),
op_def.input[0] + '/clip/max_value', context.unique_name(op_def.input[0] + '/clip/max_value'),
)) ))
node.input.extend([const_tensors[-1].name]) node.input.extend([const_tensors[-1].name])
else: else:
...@@ -81,9 +81,9 @@ def clip_exporter_v11(op_def, shape_dict, ws): ...@@ -81,9 +81,9 @@ def clip_exporter_v11(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('FullyConnected-7') @export_util.register('FullyConnected-7')
def fully_connected_exporter_v7(op_def, shape_dict, ws): def fully_connected_exporter_v7(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Gemm' node.op_type = 'Gemm'
helper.add_attribute(node, 'alpha', 1.) helper.add_attribute(node, 'alpha', 1.)
helper.add_attribute(node, 'beta', 1.) helper.add_attribute(node, 'beta', 1.)
...@@ -91,27 +91,28 @@ def fully_connected_exporter_v7(op_def, shape_dict, ws): ...@@ -91,27 +91,28 @@ def fully_connected_exporter_v7(op_def, shape_dict, ws):
if arg.name == 'transW': if arg.name == 'transW':
helper.add_attribute(node, 'transB', arg.i) helper.add_attribute(node, 'transB', arg.i)
# Weights and biases # Weights and biases
const_tensors = [helper.from_tensor(e, ws) for e in op_def.input[1:]] const_tensors = [helper.from_tensor(name, context.ws)
for name in op_def.input[1:]]
return node, const_tensors return node, const_tensors
@exporter.register('FullyConnected') @export_util.register('FullyConnected')
def fully_connected_exporter(op_def, shape_dict, ws): def fully_connected_exporter(op_def, context):
node, const_tensors = fully_connected_exporter_v7(op_def, shape_dict, ws) node, const_tensors = fully_connected_exporter_v7(op_def, context)
helper.add_attribute(node, 'broadcast', 1) # Removed since opset 7 helper.add_attribute(node, 'broadcast', 1) # Removed since opset 7
return node, const_tensors return node, const_tensors
@exporter.register('Invert') @export_util.register('Invert')
def invert_exporter(op_def, shape_dict, ws): def invert_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Not' node.op_type = 'Not'
return node, const_tensors return node, const_tensors
@exporter.register('Matmul') @export_util.register('Matmul')
def matmul_exporter(op_def, shape_dict, ws): def matmul_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'MatMul' node.op_type = 'MatMul'
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'transA': if arg.name == 'transA':
...@@ -123,57 +124,57 @@ def matmul_exporter(op_def, shape_dict, ws): ...@@ -123,57 +124,57 @@ def matmul_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Maximum') @export_util.register('Maximum')
def maximum_exporter(op_def, shape_dict, ws): def maximum_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Max' # Eltwise, Broadcast node.op_type = 'Max' # Eltwise, Broadcast
const_tensors = [] # Global scalars const_tensors = [] # Global scalars
for e in op_def.input: for name in op_def.input:
if e.startswith('/share/scalar/'): if name.startswith('/share/scalar/'):
const_tensors.append(helper.from_tensor(e, ws)) const_tensors.append(helper.from_tensor(name, context.ws))
return node, const_tensors return node, const_tensors
@exporter.register('Minimum') @export_util.register('Minimum')
def minimum_exporter(op_def, shape_dict, ws): def minimum_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Min' # Eltwise, Broadcast node.op_type = 'Min' # Eltwise, Broadcast
const_tensors = [] # Global scalars const_tensors = [] # Global scalars
for e in op_def.input: for name in op_def.input:
if e.startswith('/share/scalar/'): if name.startswith('/share/scalar/'):
const_tensors.append(helper.from_tensor(e, ws)) const_tensors.append(helper.from_tensor(name, context.ws))
return node, const_tensors return node, const_tensors
@exporter.register('Mul') @export_util.register('Mul')
def mul_exporter(op_def, shape_dict, ws): def mul_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
dtype = str(helper.fetch_tensor(op_def.output[0], ws).dtype) dtype = str(helper.fetch_tensor(op_def.output[0], context.ws).dtype)
node.op_type = 'And' if dtype == 'bool' else 'Mul' node.op_type = 'And' if dtype == 'bool' else 'Mul'
const_tensors = [] # Global scalars const_tensors = [] # Global scalars
for e in op_def.input: for name in op_def.input:
if e.startswith('/share/scalar/'): if name.startswith('/share/scalar/'):
const_tensors.append(helper.from_tensor(e, ws)) const_tensors.append(helper.from_tensor(name, context.ws))
return node, const_tensors return node, const_tensors
@exporter.register('Pow') @export_util.register('Pow')
def pow_exporter(op_def, shape_dict, ws): def pow_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
const_tensors = [] # Global scalars const_tensors = [] # Global scalars
for e in op_def.input: for name in op_def.input:
if e.startswith('/share/scalar/'): if name.startswith('/share/scalar/'):
const_tensors.append(helper.from_tensor(e, ws)) const_tensors.append(helper.from_tensor(name, context.ws))
return node, const_tensors return node, const_tensors
@exporter.register('Sub') @export_util.register('Sub')
def sub_exporter(op_def, shape_dict, ws): def sub_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
dtype = str(helper.fetch_tensor(op_def.output[0], ws).dtype) dtype = str(helper.fetch_tensor(op_def.output[0], context.ws).dtype)
node.op_type = 'Xor' if dtype == 'bool' else 'Sub' node.op_type = 'Xor' if dtype == 'bool' else 'Sub'
const_tensors = [] # Global scalars const_tensors = [] # Global scalars
for e in op_def.input: for name in op_def.input:
if e.startswith('/share/scalar/'): if name.startswith('/share/scalar/'):
const_tensors.append(helper.from_tensor(e, ws)) const_tensors.append(helper.from_tensor(name, context.ws))
return node, const_tensors return node, const_tensors
...@@ -13,13 +13,13 @@ from __future__ import absolute_import ...@@ -13,13 +13,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.vm.onnx.core import exporter
from dragon.vm.onnx.core import helper from dragon.vm.onnx.core import helper
from dragon.vm.onnx.core.exporters import utils as export_util
@exporter.register('BatchNorm') @export_util.register('BatchNorm')
def batch_norm_exporter(op_def, shape_dict, ws): def batch_norm_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'BatchNormalization' node.op_type = 'BatchNormalization'
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'epsilon': if arg.name == 'epsilon':
...@@ -27,13 +27,13 @@ def batch_norm_exporter(op_def, shape_dict, ws): ...@@ -27,13 +27,13 @@ def batch_norm_exporter(op_def, shape_dict, ws):
elif arg.name == 'momentum': elif arg.name == 'momentum':
helper.add_attribute(node, 'momentum', arg.f) helper.add_attribute(node, 'momentum', arg.f)
# Weight, bias, running mean and running variance # Weight, bias, running mean and running variance
const_tensors = [helper.from_tensor(e, ws) for e in op_def.input[1:]] const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
@exporter.register('GroupNorm') @export_util.register('GroupNorm')
def group_norm_exporter(op_def, shape_dict, ws): def group_norm_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'ATen' # Currently not supported in ai.onnx node.op_type = 'ATen' # Currently not supported in ai.onnx
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'epsilon': if arg.name == 'epsilon':
...@@ -46,13 +46,13 @@ def group_norm_exporter(op_def, shape_dict, ws): ...@@ -46,13 +46,13 @@ def group_norm_exporter(op_def, shape_dict, ws):
helper.add_attribute(node, 'op_type', 'GroupNorm') helper.add_attribute(node, 'op_type', 'GroupNorm')
helper.add_attribute(node, 'group', arg.i) helper.add_attribute(node, 'group', arg.i)
# Weight and bias # Weight and bias
const_tensors = [helper.from_tensor(e, ws) for e in op_def.input[1:]] const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
@exporter.register('LpNormalize') @export_util.register('LpNormalize')
def lp_normalize_exporter(op_def, shape_dict, ws): def lp_normalize_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'LpNormalization' node.op_type = 'LpNormalization'
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'axis': if arg.name == 'axis':
......
...@@ -21,13 +21,39 @@ _GLOBAL_REGISTERED_EXPORTERS = _Registry('exporters') ...@@ -21,13 +21,39 @@ _GLOBAL_REGISTERED_EXPORTERS = _Registry('exporters')
register = _GLOBAL_REGISTERED_EXPORTERS.register register = _GLOBAL_REGISTERED_EXPORTERS.register
def translate(op_def, *args, **kwargs): class TranslatorContext(object):
"""Context to pass translator resources."""
def __init__(
self,
workspace,
blob_names,
blob_shapes,
blob_versions,
opset_version,
):
self.ws = workspace
self.blob_names = blob_names
self.blob_shapes = blob_shapes
self.blob_versions = blob_versions
self.opset_version = opset_version
def unique_name(self, name):
self.blob_versions[name] += 1
if self.blob_versions[name] > 1:
return name + '_%d' % (self.blob_versions[name] - 1)
return name
def translate(op_def, context):
"""Translate the OpDef to a NodeProto. """Translate the OpDef to a NodeProto.
Parameters Parameters
---------- ----------
op_def : OperatorDef op_def : OperatorDef
The definition of a operator. The definition of a operator.
context : TranslatorContext
The context of translator.
Returns Returns
------- -------
...@@ -37,9 +63,8 @@ def translate(op_def, *args, **kwargs): ...@@ -37,9 +63,8 @@ def translate(op_def, *args, **kwargs):
The constant tensors. The constant tensors.
""" """
_ = locals()
node = helper.make_node( node = helper.make_node(
op_type=kwargs.get('op_type', op_def.type), op_type=op_def.type,
inputs=op_def.input, inputs=op_def.input,
outputs=op_def.output, outputs=op_def.output,
name=op_def.name if op_def.name != '' else None name=op_def.name if op_def.name != '' else None
...@@ -61,7 +86,6 @@ def registered_exporters(): ...@@ -61,7 +86,6 @@ def registered_exporters():
@register(['PythonPlugin', 'PythonPluginInfer']) @register(['PythonPlugin', 'PythonPluginInfer'])
def _python(op_def, shape_dict, ws): def python_exporter(op_def, context):
"""Export the python operators.""" """Export the python operators."""
_ = locals()
return None, None return None, None
...@@ -16,44 +16,40 @@ from __future__ import print_function ...@@ -16,44 +16,40 @@ from __future__ import print_function
import copy import copy
import numpy import numpy
from dragon.vm.onnx.core import exporter
from dragon.vm.onnx.core import helper from dragon.vm.onnx.core import helper
from dragon.vm.onnx.core.exporters import utils as export_util
@exporter.register([ @export_util.register([
'Conv2d', 'Conv2d',
'ConvTranspose2d', 'ConvTranspose2d',
'DepthwiseConv2d', 'DepthwiseConv2d',
]) ])
def convolution(op_def, shape_dict, ws): def convolution(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'ConvTranspose' if 'Transpose' in op_def.type else 'Conv' node.op_type = 'ConvTranspose' if 'Transpose' in op_def.type else 'Conv'
if 'Depthwise' in op_def.type: if 'Depthwise' in op_def.type:
input_shape = shape_dict[op_def.input[0]] input_shape = context.blob_shapes[op_def.input[0]]
helper.add_attribute(node, 'group', input_shape[1]) helper.add_attribute(node, 'group', input_shape[1])
rank = len(shape_dict[op_def.input[0]]) - 2 rank = len(context.blob_shapes[op_def.input[0]]) - 2
for arg in op_def.arg: for arg in op_def.arg:
_assert_data_format(arg) _assert_data_format(arg)
if arg.name == 'kernel_shape': if arg.name == 'kernel_shape':
helper.add_attribute( helper.add_attribute(
node, 'kernel_shape', node, 'kernel_shape',
_normalize_tuple(arg.ints, rank) _normalize_tuple(arg.ints, rank))
)
elif arg.name == 'dilations': elif arg.name == 'dilations':
helper.add_attribute( helper.add_attribute(
node, 'dilations', node, 'dilations',
_normalize_tuple(arg.ints, rank) _normalize_tuple(arg.ints, rank))
)
elif arg.name == 'strides': elif arg.name == 'strides':
helper.add_attribute( helper.add_attribute(
node, 'strides', node, 'strides',
_normalize_tuple(arg.ints, rank) _normalize_tuple(arg.ints, rank))
)
elif arg.name == 'pads': elif arg.name == 'pads':
helper.add_attribute( helper.add_attribute(
node, 'pads', node, 'pads',
_normalize_pads(arg.ints, rank) _normalize_pads(arg.ints, rank))
)
elif arg.name == 'padding' and arg.s != b'VALID': elif arg.name == 'padding' and arg.s != b'VALID':
helper.add_attribute(node, 'auto_pad', arg.s) helper.add_attribute(node, 'auto_pad', arg.s)
elif arg.name == 'group': elif arg.name == 'group':
...@@ -63,13 +59,13 @@ def convolution(op_def, shape_dict, ws): ...@@ -63,13 +59,13 @@ def convolution(op_def, shape_dict, ws):
elif arg.name == 'output_padding': elif arg.name == 'output_padding':
helper.add_attribute(node, 'output_padding', arg.ints) helper.add_attribute(node, 'output_padding', arg.ints)
# Weights and biases # Weights and biases
const_tensors = [helper.from_tensor(e, ws) for e in op_def.input[1:]] const_tensors = [helper.from_tensor(e, context.ws) for e in op_def.input[1:]]
return node, const_tensors return node, const_tensors
@exporter.register(['DepthToSpace', 'SpaceToDepth']) @export_util.register(['DepthToSpace', 'SpaceToDepth'])
def depth_space_exporter(op_def, shape_dict, ws): def depth_space_exporter(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
for arg in op_def.arg: for arg in op_def.arg:
_assert_data_format(arg) _assert_data_format(arg)
if arg.name == 'block_size': if arg.name == 'block_size':
...@@ -77,28 +73,25 @@ def depth_space_exporter(op_def, shape_dict, ws): ...@@ -77,28 +73,25 @@ def depth_space_exporter(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Pool2d') @export_util.register('Pool2d')
def pool(op_def, shape_dict, ws): def pool(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
rank = len(shape_dict[op_def.input[0]]) - 2 rank = len(context.blob_shapes[op_def.input[0]]) - 2
global_pooling, node_copy = 0, copy.deepcopy(node) global_pooling, node_copy = 0, copy.deepcopy(node)
for arg in op_def.arg: for arg in op_def.arg:
_assert_data_format(arg) _assert_data_format(arg)
if arg.name == 'kernel_shape': if arg.name == 'kernel_shape':
helper.add_attribute( helper.add_attribute(
node, 'kernel_shape', node, 'kernel_shape',
_normalize_tuple(arg.ints, rank) _normalize_tuple(arg.ints, rank))
)
elif arg.name == 'strides': elif arg.name == 'strides':
helper.add_attribute( helper.add_attribute(
node, 'strides', node, 'strides',
_normalize_tuple(arg.ints, rank) _normalize_tuple(arg.ints, rank))
)
elif arg.name == 'pads': elif arg.name == 'pads':
helper.add_attribute( helper.add_attribute(
node, 'pads', node, 'pads',
_normalize_pads(arg.ints, rank) _normalize_pads(arg.ints, rank))
)
elif arg.name == 'padding' and arg.s != b'VALID': elif arg.name == 'padding' and arg.s != b'VALID':
helper.add_attribute(node, 'auto_pad', arg.s) helper.add_attribute(node, 'auto_pad', arg.s)
elif arg.name == 'mode': elif arg.name == 'mode':
...@@ -117,59 +110,53 @@ def pool(op_def, shape_dict, ws): ...@@ -117,59 +110,53 @@ def pool(op_def, shape_dict, ws):
return node, const_tensors return node, const_tensors
@exporter.register('Resize-1') @export_util.register('Resize-1')
def resize_v1(op_def, shape_dict, ws): def resize_v1(op_def, context):
_ = locals()
raise RuntimeError('<Upsample> requires opset version >= 7.') raise RuntimeError('<Upsample> requires opset version >= 7.')
@exporter.register('Resize-7') @export_util.register('Resize-7')
def resize_v7(op_def, shape_dict, ws): def resize_v7(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'Upsample' node.op_type = 'Upsample'
input_shape = shape_dict[op_def.input[0]] input_shape = context.blob_shapes[op_def.input[0]]
output_shape = shape_dict[op_def.output[0]] output_shape = context.blob_shapes[op_def.output[0]]
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'mode': if arg.name == 'mode':
helper.add_attribute(node, 'mode', arg.s.lower()) helper.add_attribute(node, 'mode', arg.s.lower())
helper.add_attribute( helper.add_attribute(
node, 'scales', [ node, 'scales', [float(output_shape[i]) / input_shape[i]
float(output_shape[i]) / input_shape[i] for i in range(len(input_shape))])
for i in range(len(input_shape))
]
)
return node, const_tensors return node, const_tensors
@exporter.register('Resize-9') @export_util.register('Resize-9')
def resize_v9(op_def, shape_dict, ws): def resize_v9(op_def, context):
node, const_tensors = resize_v10(op_def, shape_dict, ws) node, const_tensors = resize_v10(**locals())
node.op_type = 'Upsample' node.op_type = 'Upsample'
return node, const_tensors return node, const_tensors
@exporter.register('Resize-10') @export_util.register('Resize-10')
def resize_v10(op_def, shape_dict, ws): def resize_v10(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
input_shape = shape_dict[op_def.input[0]] input_shape = context.blob_shapes[op_def.input[0]]
output_shape = shape_dict[op_def.output[0]] output_shape = context.blob_shapes[op_def.output[0]]
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'mode': if arg.name == 'mode':
helper.add_attribute(node, 'mode', arg.s.lower()) helper.add_attribute(node, 'mode', arg.s.lower())
scales = helper.from_array( scales = helper.from_array(
numpy.array([ numpy.array([float(output_shape[i]) / input_shape[i]
float(output_shape[i]) / input_shape[i] for i in range(len(input_shape))], 'float32'),
for i in range(len(input_shape)) context.unique_name(op_def.input[0] + '/resize/scales'),
], 'float32'),
op_def.input[0] + '/resize/scales',
) )
node.input.extend([scales.name]) node.input.extend([scales.name])
return node, [scales] return node, [scales]
@exporter.register('Resize-11') @export_util.register('Resize-11')
def resize_v11(op_def, shape_dict, ws): def resize_v11(op_def, context):
node, const_tensors = resize_v10(op_def, shape_dict, ws) node, const_tensors = resize_v10(**locals())
coord_mode = 'half_pixel' coord_mode = 'half_pixel'
for arg in op_def.arg: for arg in op_def.arg:
if arg.name == 'mode': if arg.name == 'mode':
...@@ -179,22 +166,22 @@ def resize_v11(op_def, shape_dict, ws): ...@@ -179,22 +166,22 @@ def resize_v11(op_def, shape_dict, ws):
if arg.i > 0: if arg.i > 0:
coord_mode = 'align_corners' coord_mode = 'align_corners'
helper.add_attribute(node, 'coordinate_transformation_mode', coord_mode) helper.add_attribute(node, 'coordinate_transformation_mode', coord_mode)
rank = len(shape_dict[op_def.input[0]]) rank = len(context.blob_shapes[op_def.input[0]])
roi = helper.from_array( roi = helper.from_array(
numpy.array(([0] * rank + [1] * rank), 'float32'), numpy.array(([0] * rank + [1] * rank), 'float32'),
op_def.input[0] + '/resize/roi', context.unique_name(op_def.input[0] + '/resize/roi'),
) )
node.input[:] = [node.input[0], roi.name, node.input[1]] node.input[:] = [node.input[0], roi.name, node.input[1]]
return node, const_tensors + [roi] return node, const_tensors + [roi]
@exporter.register('RoiAlign') @export_util.register('RoiAlign')
def roi_align(op_def, shape_dict, ws): def roi_align(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
# Make a dummy "batch_indices". # Make a dummy "batch_indices".
batch_indices = helper.from_array( batch_indices = helper.from_array(
numpy.array([1], 'int64'), numpy.array([1], 'int64'),
op_def.input[0] + '/roi_align/batch_indices', context.unique_name(op_def.input[0] + '/roi_align/batch_indices'),
) )
node.input.extend([batch_indices.name]) node.input.extend([batch_indices.name])
for arg in op_def.arg: for arg in op_def.arg:
...@@ -209,9 +196,9 @@ def roi_align(op_def, shape_dict, ws): ...@@ -209,9 +196,9 @@ def roi_align(op_def, shape_dict, ws):
return node, [batch_indices] return node, [batch_indices]
@exporter.register('RoiPool') @export_util.register('RoiPool')
def roi_pool(op_def, shape_dict, ws): def roi_pool(op_def, context):
node, const_tensors = exporter.translate(**locals()) node, const_tensors = export_util.translate(**locals())
node.op_type = 'MaxRoiPool' node.op_type = 'MaxRoiPool'
pooled_shape = [None, None] pooled_shape = [None, None]
for arg in op_def.arg: for arg in op_def.arg:
......
...@@ -8,14 +8,3 @@ ...@@ -8,14 +8,3 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.vm.tensorrt.core.backend import prepare
from dragon.vm.tensorrt.core.backend import run_model
from dragon.vm.tensorrt.core.backend import run_node
from dragon.vm.tensorrt.core.backend import supports_device
__all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -7,11 +7,8 @@ ...@@ -7,11 +7,8 @@
# #
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# Codes are based on:
#
# <https://github.com/pytorch/pytorch/blob/master/caffe2/python/onnx/frontend.py>
#
# ------------------------------------------------------------ # ------------------------------------------------------------
"""Native ONNX frontend."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -20,14 +17,22 @@ from __future__ import print_function ...@@ -20,14 +17,22 @@ from __future__ import print_function
import collections import collections
import itertools import itertools
import numpy
try: try:
import onnx import onnx
except ImportError: except ImportError:
onnx = None onnx = None
from dragon.core.autograph import function_lib
from dragon.core.eager import context as eager_context
from dragon.core.eager import backprop
from dragon.core.framework import types
from dragon.core.framework import workspace as workspace_util
from dragon.core.proto import dragon_pb2
from dragon.core.util import nest from dragon.core.util import nest
from dragon.vm.onnx.core import exporter from dragon.core.util import serialization
from dragon.vm.onnx.core import helper from dragon.vm.onnx.core import helper
from dragon.vm.onnx.core.exporters import utils as export_util
class DragonFrontend(object): class DragonFrontend(object):
...@@ -54,7 +59,7 @@ class DragonFrontend(object): ...@@ -54,7 +59,7 @@ class DragonFrontend(object):
constants=None, constants=None,
value_info=None, value_info=None,
opset_version=None, opset_version=None,
ws=None, workspace=None,
verbose=True, verbose=True,
): ):
input_names = [] if input_names is None else input_names input_names = [] if input_names is None else input_names
...@@ -79,12 +84,12 @@ class DragonFrontend(object): ...@@ -79,12 +84,12 @@ class DragonFrontend(object):
blob_aliases = {} blob_aliases = {}
for i, alias in enumerate(output_names): for i, alias in enumerate(output_names):
blob_aliases[graph_def.output[i]] = alias blob_aliases[graph_def.output[i]] = alias
ws.RegisterAlias(graph_def.output[i], alias) workspace.RegisterAlias(graph_def.output[i], alias)
if graph_def.output[i] in value_info: if graph_def.output[i] in value_info:
value_info[alias] = value_info[graph_def.output[i]] value_info[alias] = value_info[graph_def.output[i]]
for i, alias in enumerate(input_names): for i, alias in enumerate(input_names):
blob_aliases[graph_def.input[i]] = alias blob_aliases[graph_def.input[i]] = alias
ws.RegisterAlias(graph_def.input[i], alias) workspace.RegisterAlias(graph_def.input[i], alias)
if graph_def.input[i] in value_info: if graph_def.input[i] in value_info:
value_info[alias] = value_info[graph_def.input[i]] value_info[alias] = value_info[graph_def.input[i]]
...@@ -100,62 +105,65 @@ class DragonFrontend(object): ...@@ -100,62 +105,65 @@ class DragonFrontend(object):
value_info[k] = (value_info[k][0], v) value_info[k] = (value_info[k][0], v)
# Prepare to make the graph. # Prepare to make the graph.
onnx_graph = onnx.GraphProto( onnx_graph = onnx.GraphProto(name=graph_def.name
name=graph_def.name
if len(graph_def.name) > 0 if len(graph_def.name) > 0
else 'onnx-model' else 'onnx-model')
blob_shapes, blob_names = {}, {}
blob_versions = collections.defaultdict(
int, **dict((blob_aliases.get(k, k), 1)
for k in helper.collect_inputs(graph_def)))
initializers, seen_initializers = [], set()
# Build translator context.
context = export_util.TranslatorContext(
workspace=workspace,
blob_names=blob_names,
blob_shapes=blob_shapes,
blob_versions=blob_versions,
opset_version=opset_version,
) )
graph_inputs = helper.collect_inputs(graph_def)
# Add nodes. # Add nodes.
shapes, blob_names, initializers = {}, {}, []
blob_versions = collections.defaultdict(
int, **dict((blob_aliases.get(k, k), 1) for k in graph_inputs))
for op in graph_def.op: for op in graph_def.op:
# Get the shape of inputs and outputs. # Get the shape of inputs and outputs.
for name in itertools.chain(op.input, op.output): for name in itertools.chain(op.input, op.output):
impl = ws.GetTensor(name) impl = workspace.GetTensor(name)
if impl is not None: if impl is not None:
shapes[name] = impl.dims blob_shapes[name] = impl.dims
else: else:
shapes[name] = value_info[name][1] blob_shapes[name] = value_info[name][1]
# Translate definition. # Translate definition.
nodes, const_tensors = cls._translate(op, opset_version, shapes, ws) nodes, const_tensors = cls._make_node(op, context)
# Rewritten for names. # Rewritten for names.
for node in nodes: for node in nodes:
node.input[:] = [blob_aliases.get(e, e) for e in node.input] node.input[:] = [blob_aliases.get(e, e) for e in node.input]
node.output[:] = [blob_aliases.get(e, e) for e in node.output] node.output[:] = [blob_aliases.get(e, e) for e in node.output]
node, blob_names, blob_versions = \ cls._rewrite_for_ssa(node, context)
cls._ssa_rewrite(node, blob_names, blob_versions)
# Directly convert outputs as const tensors if necessary. # Convert constant outputs if necessary.
if None in nodes: if None in nodes:
const_tensors = [helper.from_tensor(name, ws) for name in op.output] const_tensors = [helper.from_tensor(name, workspace)
for name in op.output]
else: else:
onnx_graph.node.extend(nodes) onnx_graph.node.extend(nodes)
# Merge constant tensors. # Merge constant tensors.
if const_tensors is not None: if const_tensors is not None:
value_info = { value_info = {**value_info,
**value_info, **dict((e.name, (e.data_type, e.dims))
**dict(( for e in const_tensors)}
e.name, (e.data_type, e.dims) for tensor in const_tensors:
) for e in const_tensors) if tensor.name not in seen_initializers:
} initializers.append(tensor)
initializers.extend(const_tensors) seen_initializers.add(tensor.name)
# Add constants. # Add constants.
if constants is not None: if constants is not None:
for k, v in constants.items(): for k, v in constants.items():
initializers.append(helper.from_array(v, name=k)) initializers.append(helper.from_array(v, name=k))
# Add initializers.
onnx_graph.initializer.extend(initializers)
# Add inputs. # Add inputs.
for name in helper.collect_inputs(onnx_graph): for name in helper.collect_inputs(onnx_graph):
try: try:
...@@ -165,18 +173,33 @@ class DragonFrontend(object): ...@@ -165,18 +173,33 @@ class DragonFrontend(object):
elem_type=value_info[name][0], elem_type=value_info[name][0],
shape=value_info[name][1])]) shape=value_info[name][1])])
except KeyError: except KeyError:
impl = workspace.GetTensor(name)
if impl is not None:
initializer = helper.from_tensor(name, workspace)
onnx_graph.input.extend([
helper.make_tensor_value_info(
name=name,
elem_type=initializer.data_type,
shape=initializer.dims)])
if name not in seen_initializers:
initializers.append(initializer)
seen_initializers.add(initializer.name)
else:
raise ValueError( raise ValueError(
'Info of tensor `{}` is missing, ' 'Info of tensor `{}` is missing, '
'specify it in <value_info>.'.format(name)) 'specify it in <value_info>.'.format(name))
# Add initializers.
onnx_graph.initializer.extend(initializers)
# Add outputs. # Add outputs.
onnx_graph.output.extend( onnx_graph.output.extend(
helper.make_tensor_value_info( helper.make_tensor_value_info(
name=blob_names.get(name, name), name=blob_names.get(name_v2, name_v2),
elem_type=value_info[name][0], elem_type=value_info[name_v2][0],
shape=value_info[name][1], shape=value_info[name_v2][1])
) for name in [blob_aliases.get(e, e) for e in set(graph_def.output)] for name_v2 in [blob_aliases.get(name, name)
) for name in set(graph_def.output)])
if verbose: if verbose:
print(helper.printable_graph(onnx_graph)) print(helper.printable_graph(onnx_graph))
...@@ -239,30 +262,31 @@ class DragonFrontend(object): ...@@ -239,30 +262,31 @@ class DragonFrontend(object):
return opset_version return opset_version
@staticmethod @staticmethod
def _ssa_rewrite(op_def, blob_names, blob_versions): def _rewrite_for_ssa(op_def, context):
"""Rewrite a OpDef to satisfy the SSA (Static Single Assignment).""" """Rewrite a OpDef to satisfy the SSA (Static Single Assignment)."""
blob_names = context.blob_names
blob_versions = context.blob_versions
inputs, outputs = [], [] inputs, outputs = [], []
for e in op_def.input: for name in op_def.input:
inputs.append(blob_names[e] if e in blob_names else e) inputs.append(blob_names[name] if name in blob_names else name)
for e in op_def.output: for name in op_def.output:
outputs.append(e + '_%d' % blob_versions[e] outputs.append(name + '_%d' % blob_versions[name]
if blob_versions[e] > 0 else e) if blob_versions[name] > 0 else name)
if e != '': if name != '':
blob_versions[e] += 1 blob_versions[name] += 1
blob_names[e] = outputs[-1] blob_names[name] = outputs[-1]
op_def.ClearField('input') op_def.ClearField('input')
op_def.ClearField('output') op_def.ClearField('output')
op_def.input.extend(inputs) op_def.input.extend(inputs)
op_def.output.extend(outputs) op_def.output.extend(outputs)
return op_def, blob_names, blob_versions
@classmethod @classmethod
def _translate(cls, op_def, opset_version, shape_dict, ws): def _make_node(cls, op_def, context):
"""Return a NodeProto from the OpDef.""" """Return a NodeProto from the OpDef."""
translate_fn = None translate_fn = None
getter = exporter._GLOBAL_REGISTERED_EXPORTERS.try_get getter = export_util._GLOBAL_REGISTERED_EXPORTERS.try_get
# Select the last versioned exporter if necessary. # Select the last versioned exporter if necessary.
for i in range(opset_version, 0, -1): for i in range(context.opset_version, 0, -1):
versioned_op_type = op_def.type + '-%d' % i versioned_op_type = op_def.type + '-%d' % i
if getter(versioned_op_type) is not None: if getter(versioned_op_type) is not None:
translate_fn = getter(versioned_op_type) translate_fn = getter(versioned_op_type)
...@@ -273,10 +297,153 @@ class DragonFrontend(object): ...@@ -273,10 +297,153 @@ class DragonFrontend(object):
translate_fn = getter(op_def.type) translate_fn = getter(op_def.type)
else: else:
# Fallback to the generic exporter. # Fallback to the generic exporter.
translate_fn = exporter.translate translate_fn = export_util.translate
nodes, const_tensors = translate_fn(op_def, shape_dict, ws) nodes, const_tensors = translate_fn(op_def, context)
return nest.flatten(nodes), const_tensors return nest.flatten(nodes), const_tensors
def record():
"""Context-manger to record the graph.
Examples:
```python
with dragon.onnx.record():
...
```
See Also
--------
`dragon.onnx.export(...)`_
"""
tape = backprop.Tape()
tape.retain_graph = True
return backprop._GLOBAL_TAPE_STACK.get_controller(tape)
def export(
inputs,
outputs,
f,
input_names=None,
output_names=None,
input_shapes=None,
opset_version=None,
verbose=False,
enable_onnx_checker=True,
):
"""Export the recorded graph to an onnx model.
Enter into the record mode to export operators into an onnx model:
```python
x = dragon.constant([1, 2, 3])
with dragon.onnx.record():
y = x * x
dragon.onnx.export(inputs=[x], outputs=[y], f='model.onnx')
```
Parameters
----------
inputs : Union[Sequence, Dict]
The model inputs.
outputs : Union[Sequence, Dict]
The model outputs.
f : str
The filename for exporting model.
input_names : Sequence[str], optional
The name to the inputs.
output_names : Sequence[str], optional
The name to the outputs.
input_shapes : Union[Sequence, Dict], optional
The optional rewritten for input shapes.
opset_version : int, optional
The version of operator set.
verbose : bool, optional, default=False
Whether to print the debug string of graph.
enable_onnx_checker : bool, optional, default=True
Whether to check if model is valid.
"""
# Process the inputs.
if isinstance(inputs, dict):
if input_names is not None:
raise ValueError(
'Excepted the input names from <inputs>.\n'
'You should set the <input_names> to None.')
inputs, input_names = list(inputs.values()), list(inputs.keys())
else:
inputs = nest.flatten(inputs)
# Process the outputs.
if isinstance(outputs, dict):
if output_names is not None:
raise ValueError(
'Excepted the output names from <outputs>.\n'
'You should set the <output_names> to None.')
outputs, output_names = list(outputs.values()), list(outputs.keys())
else:
outputs = nest.flatten(outputs)
if eager_context.executing_eagerly():
op_defs = []
tape = backprop.get_default_tape()
if tape is None:
raise RuntimeError('Please enter with ``onnx.frontend.record()``.')
for op_def in tape._defs:
op_defs.append(dragon_pb2.OperatorDef())
op_defs[-1].ParseFromString(op_def.SerializeAs())
graph_def = dragon_pb2.GraphDef(op=op_defs)
else:
symbolic_outputs = []
for output in outputs:
if types.is_symbolic_tensor(output):
symbolic_outputs.append(output)
graph_func = function_lib.create_function(outputs=symbolic_outputs)
graph_func.callback()
graph_def = graph_func.graph_def
graph_def.name = ''
# Add inputs and outputs.
for i, input in enumerate(inputs):
if hasattr(input, 'id'):
graph_def.input.extend([input.id])
elif input_names is not None:
graph_def.input.extend([input_names[i]])
for i, output in enumerate(outputs):
if hasattr(output, 'id'):
graph_def.output.extend([output.id])
elif output_names is not None:
graph_def.output.extend([output_names[i]])
# Make value info from inputs and outputs.
value_names = graph_def.input[:] + graph_def.output[:]
value_info = dict([(k, (helper.tensor_type(v.dtype), v.shape))
for k, v in zip(value_names, inputs + outputs)])
# Extract the constants from inputs and outputs.
constants = collections.OrderedDict()
for k, v in zip(value_names, inputs + outputs):
if isinstance(v, numpy.ndarray):
constants[k] = v
# Export.
model = graph_def_to_onnx_model(
graph_def=graph_def,
input_names=input_names,
output_names=output_names,
input_shapes=input_shapes,
constants=constants,
value_info=value_info,
opset_version=opset_version,
workspace=workspace_util.get_workspace(),
verbose=verbose,
enable_onnx_checker=enable_onnx_checker,
)
serialization.save_bytes(serialization.serialize_proto(model), f)
graph_def_to_onnx_graph = DragonFrontend.graph_def_to_onnx_graph graph_def_to_onnx_graph = DragonFrontend.graph_def_to_onnx_graph
graph_def_to_onnx_model = DragonFrontend.graph_def_to_onnx_model graph_def_to_onnx_model = DragonFrontend.graph_def_to_onnx_model
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause> # <https://opensource.org/licenses/BSD-2-Clause>
# #
# ------------------------------------------------------------ # ------------------------------------------------------------
"""PyTorch ONNX frontend."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -19,7 +20,9 @@ import numpy ...@@ -19,7 +20,9 @@ import numpy
from dragon.core.framework import workspace from dragon.core.framework import workspace
from dragon.core.proto import dragon_pb2 from dragon.core.proto import dragon_pb2
from dragon.core.util import nest from dragon.core.util import nest
from dragon.vm.onnx.core import utils as onnx_util from dragon.core.util import serialization
from dragon.vm.onnx.core import helper
from dragon.vm.onnx.core.frontend.native import graph_def_to_onnx_model
from dragon.vm.torch.core.autograd import backprop from dragon.vm.torch.core.autograd import backprop
...@@ -34,7 +37,7 @@ def export( ...@@ -34,7 +37,7 @@ def export(
verbose=False, verbose=False,
enable_onnx_checker=True, enable_onnx_checker=True,
): ):
"""Export a model into ONNX format. """Export the recorded graph to an onnx model.
The outputs will be obtained by calling ``model(*args)``, The outputs will be obtained by calling ``model(*args)``,
both the tensor or numpy array are allowed: both the tensor or numpy array are allowed:
...@@ -162,10 +165,8 @@ def export( ...@@ -162,10 +165,8 @@ def export(
# Make value info from inputs and outputs. # Make value info from inputs and outputs.
value_names = graph_def.input[:] + graph_def.output[:] value_names = graph_def.input[:] + graph_def.output[:]
value_info = dict([ value_info = dict([(k, (helper.tensor_type(v.dtype), v.shape))
(k, onnx_util.make_value_info(v.shape, v.dtype)) for k, v in zip(value_names, inputs + outputs)])
for k, v in zip(value_names, inputs + outputs)
])
# Extract the constants from inputs and outputs. # Extract the constants from inputs and outputs.
constants = collections.OrderedDict() constants = collections.OrderedDict()
...@@ -175,9 +176,8 @@ def export( ...@@ -175,9 +176,8 @@ def export(
# Export. # Export.
with temporal_ws.as_default(): with temporal_ws.as_default():
onnx_util.export_from_graph( model = graph_def_to_onnx_model(
graph_def=graph_def, graph_def=graph_def,
f=f,
input_names=input_names, input_names=input_names,
output_names=output_names, output_names=output_names,
input_shapes=input_shapes, input_shapes=input_shapes,
...@@ -188,3 +188,4 @@ def export( ...@@ -188,3 +188,4 @@ def export(
verbose=verbose, verbose=verbose,
enable_onnx_checker=enable_onnx_checker, enable_onnx_checker=enable_onnx_checker,
) )
serialization.save_bytes(serialization.serialize_proto(model), f)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Toolkit for manipulating the onnx api."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy
from dragon.core.autograph import function_lib
from dragon.core.eager import context
from dragon.core.eager import backprop
from dragon.core.framework import types
from dragon.core.framework import workspace
from dragon.core.proto import dragon_pb2
from dragon.core.util import nest
from dragon.vm.onnx.core import utils as onnx_util
class Shell(object):
"""Context-manger to export or load onnx models.
Enter a shell to export operators into an onnx model:
```python
x = dragon.constant([1, 2, 3])
with onnx.Shell() as shell, shell.as_default():
y = x * x
shell.export(inputs=[x], outputs=[y], f='model.onnx')
```
The onnx models can also be loaded to execute:
```python
f = shell.load_model('model.onnx', explicit_inputs=True)
print(f(np.array([1, 2, 3]))
```
"""
def __init__(self):
"""Create a ``Shell``."""
self._workspace = None
self._tape = backprop.GradientTape()
def as_default(self):
"""Set as the default shell."""
return self._workspace.as_default()
def export(
self,
inputs,
outputs,
f,
input_names=None,
output_names=None,
input_shapes=None,
opset_version=None,
verbose=False,
enable_onnx_checker=True,
):
"""Export an onnx model.
Parameters
----------
inputs : Union[Sequence, Dict]
The model inputs.
outputs : Union[Sequence, Dict]
The model outputs.
f : str
The filename for exporting model.
input_names : Sequence[str], optional
The name to the inputs.
output_names : Sequence[str], optional
The name to the outputs.
input_shapes : Union[Sequence, Dict], optional
The optional rewritten for input shapes.
opset_version : int, optional
The version of operator set.
verbose : bool, optional, default=False
Whether to print the debug string of graph.
enable_onnx_checker : bool, optional, default=True
Whether to check if model is valid.
"""
# Process the inputs.
if isinstance(inputs, dict):
if input_names is not None:
raise ValueError(
'Excepted the input names from <inputs>.\n'
'You should set the <input_names> to None.')
inputs, input_names = list(inputs.values()), list(inputs.keys())
else:
inputs = nest.flatten(inputs)
# Process the outputs.
if isinstance(outputs, dict):
if output_names is not None:
raise ValueError(
'Excepted the output names from <outputs>.\n'
'You should set the <output_names> to None.')
outputs, output_names = list(outputs.values()), list(outputs.keys())
else:
outputs = nest.flatten(outputs)
if context.executing_eagerly():
# Make graph def.
op_defs = []
for op_def in self._tape._tape._defs:
op_defs.append(dragon_pb2.OperatorDef())
op_defs[-1].ParseFromString(op_def.SerializeAs())
graph_def = dragon_pb2.GraphDef(op=op_defs)
else:
symbolic_outputs = []
for output in outputs:
if types.is_symbolic_tensor(output):
symbolic_outputs.append(output)
with self.as_default():
graph_func = function_lib.create_function(
outputs=symbolic_outputs)
graph_func.callback()
graph_def = graph_func.graph_def
graph_def.name = ''
# Add inputs and outputs.
for i, input in enumerate(inputs):
if hasattr(input, 'id'):
graph_def.input.extend([input.id])
elif input_names is not None:
graph_def.input.extend([input_names[i]])
for i, output in enumerate(outputs):
if hasattr(output, 'id'):
graph_def.output.extend([output.id])
elif output_names is not None:
graph_def.output.extend([output_names[i]])
# Make value info from inputs and outputs.
value_names = graph_def.input[:] + graph_def.output[:]
value_info = dict([
(k, onnx_util.make_value_info(v.shape, v.dtype))
for k, v in zip(value_names, inputs + outputs)
])
# Extract the constants from inputs and outputs.
constants = collections.OrderedDict()
for k, v in zip(value_names, inputs + outputs):
if isinstance(v, numpy.ndarray):
constants[k] = v
# Export.
onnx_util.export_from_graph(
graph_def=graph_def,
f=f,
input_names=input_names,
output_names=output_names,
input_shapes=input_shapes,
constants=constants,
value_info=value_info,
opset_version=opset_version,
workspace=self._workspace,
verbose=verbose,
enable_onnx_checker=enable_onnx_checker,
)
@staticmethod
def load_model(model_path, explicit_inputs=False):
"""Import an onnx model to the function.
Parameters
----------
model_path : str
The path to the onnx model.
explicit_inputs : bool, optional, default=False
**True** to attach model inputs to the function.
Returns
-------
callable
The function to run the model once.
"""
return onnx_util.import_to_function(model_path, explicit_inputs)
def __enter__(self):
self._workspace = workspace.Workspace()
self._workspace.merge_from(workspace.get_workspace())
if context.executing_eagerly():
self._tape._push_tape()
self._tape._tape.retain_graph = True
return self
def __exit__(self, typ, value, traceback):
if self._tape._recording:
self._tape._pop_tape()
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Utilities for exporting and importing models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy
from dragon.core.autograph import function_lib
from dragon.core.framework import workspace
from dragon.core.proto import dragon_pb2
from dragon.core.util import serialization
from dragon.vm.onnx.core.frontend import graph_def_to_onnx_model
from dragon.vm.onnx.core.helper import mapping
def export_from_graph(
graph_def,
f,
input_names=None,
output_names=None,
input_shapes=None,
constants=None,
value_info=None,
opset_version=None,
workspace=None,
verbose=True,
enable_onnx_checker=True,
):
"""Export an onnx model from the graph."""
model = graph_def_to_onnx_model(
graph_def=graph_def,
input_names=input_names,
output_names=output_names,
input_shapes=input_shapes,
constants=constants,
value_info=value_info,
opset_version=opset_version,
workspace=workspace,
verbose=verbose,
enable_onnx_checker=enable_onnx_checker)
serialization.save_bytes(serialization.serialize_proto(model), f)
def import_to_function(model_path, explicit_inputs=False):
"""Import an onnx model to the function."""
return function_lib \
.Function(name='onnx') \
.import_from(
graph_def=import_to_graph(model_path),
explicit_inputs=explicit_inputs,
)
def import_to_graph(model_path):
"""Import an onnx model to the graph."""
if not os.path.exists(model_path):
raise ValueError(
'Model({}) is not existed.'
.format(model_path)
)
graph_def = dragon_pb2.GraphDef()
serialized_proto = workspace \
.get_workspace() \
.ImportONNXModel(model_path)
graph_def.ParseFromString(serialized_proto)
return graph_def
def make_value_info(shape, dtype='float32'):
"""Return a value info from the shape and data type."""
return mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)], shape
...@@ -50,6 +50,7 @@ def constant(value, dtype=None, shape=None, name='Const'): ...@@ -50,6 +50,7 @@ def constant(value, dtype=None, shape=None, name='Const'):
The output tensor. The output tensor.
""" """
dtype = str(dtype) if dtype else None
if dtype is not None: if dtype is not None:
if isinstance(value, numpy.ndarray): if isinstance(value, numpy.ndarray):
value = value.astype(dtype) value = value.astype(dtype)
......
...@@ -22,24 +22,17 @@ from dragon.core.framework import types ...@@ -22,24 +22,17 @@ from dragon.core.framework import types
from dragon.vm.tensorflow.core.framework import constant_op from dragon.vm.tensorflow.core.framework import constant_op
def convert_to_tensor( def convert_to_tensor(value, dtype=None, name=None):
value,
dtype=None,
name=None,
preferred_dtype=None,
):
"""Converts the given value to a Tensor. """Converts the given value to a Tensor.
Parameters Parameters
---------- ----------
value : number, sequence or numpy.ndarray value : Union[number, Sequence, numpy.ndarray]
The value to convert. The value to convert.
dtype : dragon.vm.tensorflow.dtypes.DType, optional dtype : dragon.vm.tensorflow.dtypes.DType, optional
The optional data type. The optional data type.
name : str, optional name : str, optional
The Optional name. The Optional name.
preferred_dtype : dragon.vm.tensorflow.dtypes.DType, optional
The optional type when ``dtype`` is *None*.
Returns Returns
------- -------
...@@ -108,4 +101,4 @@ def device(device_name): ...@@ -108,4 +101,4 @@ def device(device_name):
id = int(id) id = int(id)
except Exception: except Exception:
raise ValueError('The device id should be a integer.') raise ValueError('The device id should be a integer.')
return context.device(device, device_id=id) return context.device(device, device_index=id)
...@@ -348,9 +348,7 @@ class Layer(module.Module): ...@@ -348,9 +348,7 @@ class Layer(module.Module):
else: else:
raise ValueError( raise ValueError(
'Unknown format "%s".\n' 'Unknown format "%s".\n'
'Excepted format in (tf, h5, pkl).' 'Excepted format in (tf, h5, pkl).' % (save_format,))
% (save_format,)
)
if save_format == 'tf': if save_format == 'tf':
raise ValueError('TensorFlow format will never be supported.') raise ValueError('TensorFlow format will never be supported.')
if save_format == 'h5': if save_format == 'h5':
...@@ -379,12 +377,9 @@ class Layer(module.Module): ...@@ -379,12 +377,9 @@ class Layer(module.Module):
'metrics', 'metrics',
} }
if hasattr(self, '_layers'): if hasattr(self, '_layers'):
layers = layer_utils \ layers = layer_utils.filter_empty_layer_containers(self._layers)
.filter_empty_layer_containers(self._layers) return list(itertools.chain.from_iterable(
return list( getattr(layer, attribute) for layer in layers))
itertools.chain.from_iterable(
getattr(layer, attribute) for layer in layers)
)
return [] return []
def _maybe_build(self, inputs): def _maybe_build(self, inputs):
......
...@@ -163,7 +163,7 @@ class Conv2D(Conv): ...@@ -163,7 +163,7 @@ class Conv2D(Conv):
The shape of convolution kernel. The shape of convolution kernel.
strides : Sequence[int], optional, default=1 strides : Sequence[int], optional, default=1
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'valid', 'same'}, Sequence[int]], optional padding : Union[str, Sequence[int]], optional
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'channels_first', 'channels_last'}, optional data_format : {'channels_first', 'channels_last'}, optional
The optional data format. The optional data format.
...@@ -231,7 +231,7 @@ class Conv2DTranspose(Conv2D): ...@@ -231,7 +231,7 @@ class Conv2DTranspose(Conv2D):
The shape of convolution kernel. The shape of convolution kernel.
strides : Sequence[int], optional strides : Sequence[int], optional
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'valid', 'same'}, Sequence[int]], optional padding : Union[str, Sequence[int]], optional
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
output_padding : Sequence[int], optional output_padding : Sequence[int], optional
The sizes of padded to the output. The sizes of padded to the output.
...@@ -396,7 +396,7 @@ class DepthwiseConv2D(Conv2D): ...@@ -396,7 +396,7 @@ class DepthwiseConv2D(Conv2D):
The shape of convolution kernel. The shape of convolution kernel.
strides : Sequence[int], optional, default=1 strides : Sequence[int], optional, default=1
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'valid', 'same'}, Sequence[int]], optional padding : Union[str, Sequence[int]], optional
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'channels_first', 'channels_last'}, optional data_format : {'channels_first', 'channels_last'}, optional
The optional data format. The optional data format.
......
...@@ -106,7 +106,7 @@ class AveragePooling2D(Pooling2D): ...@@ -106,7 +106,7 @@ class AveragePooling2D(Pooling2D):
The size(s) of sliding window. The size(s) of sliding window.
strides : Sequence[int], optional strides : Sequence[int], optional
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'valid', 'same'}, Sequence[int]], optional padding : Union[str, Sequence[int]], optional
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'channels_first', 'channels_last'}, optional data_format : {'channels_first', 'channels_last'}, optional
The optional data format. The optional data format.
...@@ -179,7 +179,7 @@ class MaxPooling2D(Pooling2D): ...@@ -179,7 +179,7 @@ class MaxPooling2D(Pooling2D):
The size(s) of sliding window. The size(s) of sliding window.
strides : Sequence[int], optional strides : Sequence[int], optional
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'valid', 'same'}, Sequence[int]], optional padding : Union[str, Sequence[int]], optional
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'channels_first', 'channels_last'}, optional data_format : {'channels_first', 'channels_last'}, optional
The optional data format. The optional data format.
......
...@@ -44,7 +44,7 @@ def avg_pool( ...@@ -44,7 +44,7 @@ def avg_pool(
The size(s) of sliding window. The size(s) of sliding window.
strides : Sequence[int], optional strides : Sequence[int], optional
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'VALID', 'SAME'}, Sequence[int]] padding : Union[str, Sequence[int]]
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
...@@ -104,7 +104,7 @@ def avg_pool2d( ...@@ -104,7 +104,7 @@ def avg_pool2d(
The size(s) of sliding window. The size(s) of sliding window.
strides : Sequence[int], optional strides : Sequence[int], optional
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'VALID', 'SAME'}, Sequence[int]] padding : Union[str, Sequence[int]]
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
...@@ -157,7 +157,7 @@ def convolution( ...@@ -157,7 +157,7 @@ def convolution(
The weight tensor. The weight tensor.
strides : Sequence[int], optional strides : Sequence[int], optional
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'VALID', 'SAME'}, Sequence[int]], optional padding : Union[str, Sequence[int]], optional
The padding algorithm or padding size(s). The padding algorithm or padding size(s).
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
...@@ -223,7 +223,7 @@ def conv_transpose( ...@@ -223,7 +223,7 @@ def conv_transpose(
The determined shape of output. The determined shape of output.
strides : Sequence[int] strides : Sequence[int]
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'VALID', 'SAME'}, Sequence[int]], optional padding : Union[str, Sequence[int]], optional
The padding algorithm or padding size(s). The padding algorithm or padding size(s).
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
...@@ -292,7 +292,7 @@ def conv2d( ...@@ -292,7 +292,7 @@ def conv2d(
The weight tensor. The weight tensor.
strides : Sequence[int] strides : Sequence[int]
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'VALID', 'SAME'}, Sequence[int]] padding : Union[str, Sequence[int]]
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'NCHW', 'NHWC'}, optional data_format : {'NCHW', 'NHWC'}, optional
The optional data format. The optional data format.
...@@ -332,7 +332,7 @@ def conv2d_transpose( ...@@ -332,7 +332,7 @@ def conv2d_transpose(
The determined shape of output. The determined shape of output.
strides : Sequence[int] strides : Sequence[int]
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'VALID', 'SAME'}, Sequence[int]] padding : Union[str, Sequence[int]]
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'NCHW', 'NHWC'}, optional data_format : {'NCHW', 'NHWC'}, optional
The optional data format. The optional data format.
...@@ -370,7 +370,7 @@ def depthwise_conv2d( ...@@ -370,7 +370,7 @@ def depthwise_conv2d(
The weight tensor. The weight tensor.
strides : Sequence[int] strides : Sequence[int]
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'VALID', 'SAME'}, Sequence[int]] padding : Union[str, Sequence[int]]
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'NCHW', 'NHWC'}, optional data_format : {'NCHW', 'NHWC'}, optional
The optional data format. The optional data format.
...@@ -590,7 +590,7 @@ def max_pool( ...@@ -590,7 +590,7 @@ def max_pool(
The size(s) of sliding window. The size(s) of sliding window.
strides : Sequence[int] strides : Sequence[int]
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'valid', 'same'}, Sequence[int]] padding : Union[str, Sequence[int]]
The padding algorithm or padding sizes. The padding algorithm or padding sizes.
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
...@@ -650,7 +650,7 @@ def max_pool2d( ...@@ -650,7 +650,7 @@ def max_pool2d(
The size(s) of sliding window. The size(s) of sliding window.
strides : Sequence[int], optional strides : Sequence[int], optional
The stride(s) of sliding window. The stride(s) of sliding window.
padding : Union[{'valid', 'same'}, Sequence[int]] padding : Union[str, Sequence[int]]
The padding algorithm or padding size(s). The padding algorithm or padding size(s).
data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional data_format : {'NCHW', 'NCDHW', 'NHWC', 'NDHWC'}, optional
The optional data format. The optional data format.
......
...@@ -33,6 +33,8 @@ class VariableMetaclass(object): ...@@ -33,6 +33,8 @@ class VariableMetaclass(object):
class Variable(VariableMetaclass, EagerTensor): class Variable(VariableMetaclass, EagerTensor):
"""Resource variable."""
def __init__( def __init__(
self, self,
initial_value, initial_value,
...@@ -44,6 +46,7 @@ class Variable(VariableMetaclass, EagerTensor): ...@@ -44,6 +46,7 @@ class Variable(VariableMetaclass, EagerTensor):
"""Create a ``Variable``.""" """Create a ``Variable``."""
super(Variable, self).__init__(trainable=trainable) super(Variable, self).__init__(trainable=trainable)
name = name if name else 'Variable' name = name if name else 'Variable'
dtype = str(dtype) if dtype else None
self._name = context.get_name_scope() + name + ':0' self._name = context.get_name_scope() + name + ':0'
# Determine th value. # Determine th value.
if isinstance(initial_value, EagerTensor): if isinstance(initial_value, EagerTensor):
...@@ -52,7 +55,7 @@ class Variable(VariableMetaclass, EagerTensor): ...@@ -52,7 +55,7 @@ class Variable(VariableMetaclass, EagerTensor):
initial_value = initial_value.get_value() initial_value = initial_value.get_value()
# Determine the data type. # Determine the data type.
if not isinstance(initial_value, numpy.ndarray): if not isinstance(initial_value, numpy.ndarray):
initial_value = numpy.array(initial_value, dtype if dtype else dtype) initial_value = numpy.array(initial_value, dtype)
elif dtype is not None: elif dtype is not None:
initial_value = initial_value.astype(dtype) initial_value = initial_value.astype(dtype)
# Determine the tensor shape. # Determine the tensor shape.
...@@ -103,8 +106,7 @@ def get_default_initializer(name, shape=None, dtype=dtypes.float32): ...@@ -103,8 +106,7 @@ def get_default_initializer(name, shape=None, dtype=dtypes.float32):
else: else:
raise ValueError( raise ValueError(
'An initializer for Variable({}) of {} is required.' 'An initializer for Variable({}) of {} is required.'
.format(name, dtype.base_dtype) .format(name, dtype.base_dtype))
)
return initializer return initializer
......
...@@ -3,6 +3,7 @@ from __future__ import division ...@@ -3,6 +3,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from dragon.core.framework import context from dragon.core.framework import context
from dragon.core.util import string
from dragon.vm.tensorlayer.core.engine import module from dragon.vm.tensorlayer.core.engine import module
...@@ -83,10 +84,10 @@ class Model(module.Module): ...@@ -83,10 +84,10 @@ class Model(module.Module):
return self.forward(inputs, **kwargs) return self.forward(inputs, **kwargs)
def __repr__(self): def __repr__(self):
tmpstr = self.name + '(\n' tmp_str = self.name + '(\n'
for idx, layer in enumerate(self.all_layers): for idx, layer in enumerate(self.all_layers):
modstr = layer.__repr__() mod_str = layer.__repr__()
modstr = self._addindent(modstr, 2) mod_str = string.add_indent(mod_str, 2)
tmpstr = tmpstr + ' (' + layer.name + '): ' + modstr + '\n' tmp_str = tmp_str + ' (' + layer.name + '): ' + mod_str + '\n'
tmpstr = tmpstr + ')' tmp_str = tmp_str + ')'
return tmpstr return tmp_str
...@@ -201,15 +201,13 @@ class Module(object): ...@@ -201,15 +201,13 @@ class Module(object):
else: else:
raise ValueError( raise ValueError(
"Saving format should be in ('hdf5', 'npz', 'pkl', 'npz_dict').\n" "Saving format should be in ('hdf5', 'npz', 'pkl', 'npz_dict').\n"
"Format <%s> is not supported." % format "Format <%s> is not supported." % format)
)
if verbose: if verbose:
for info in matched_info: for info in matched_info:
logging.info( logging.info(
'Weight({}) loaded, Size: ({})' 'Weight({}) loaded, Size: ({})'
.format(info[0], ', '.join([str(d) for d in info[1]])) .format(info[0], ', '.join([str(d) for d in info[1]])))
)
def save_weights(self, filepath, format=None): def save_weights(self, filepath, format=None):
"""Save weights into a binary file. """Save weights into a binary file.
...@@ -241,8 +239,7 @@ class Module(object): ...@@ -241,8 +239,7 @@ class Module(object):
else: else:
raise ValueError( raise ValueError(
"Saving format should be in ('hdf5', 'npz', 'pkl', 'npz_dict').\n" "Saving format should be in ('hdf5', 'npz', 'pkl', 'npz_dict').\n"
"Format <%s> is not supported." % format "Format <%s> is not supported." % format)
)
@staticmethod @staticmethod
def _dedupe_weights(weights): def _dedupe_weights(weights):
......
...@@ -18,15 +18,14 @@ import os as _os ...@@ -18,15 +18,14 @@ import os as _os
import sys as _sys import sys as _sys
# Modules # Modules
from dragon.vm.tensorrt._api import backend from dragon.vm.tensorrt._api import onnx
# Classes # Classes
from dragon.vm.tensorrt.core.backend import ONNXBackendRep
from dragon.vm.tensorrt.core.engine import Binding from dragon.vm.tensorrt.core.engine import Binding
from dragon.vm.tensorrt.core.engine import Engine from dragon.vm.tensorrt.core.engine import Engine
# Attributes # Attributes
_API_MODULE = backend _API_MODULE = onnx
_current_module = _sys.modules[__name__] _current_module = _sys.modules[__name__]
_api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__)) _api_dir = _os.path.dirname(_os.path.dirname(_API_MODULE.__file__))
if not hasattr(_current_module, '__path__'): if not hasattr(_current_module, '__path__'):
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
# Classes
from dragon.vm.onnx.core.backend.tensorrt import BackendRep
# Functions
from dragon.vm.onnx.core.backend.tensorrt import prepare as prepare_backend
from dragon.vm.onnx.core.backend.tensorrt import run_model
from dragon.vm.onnx.core.backend.tensorrt import run_node
from dragon.vm.onnx.core.backend.tensorrt import supports_device
# Attributes
__all__ = [_s for _s in dir() if not _s.startswith('_')]
...@@ -442,12 +442,10 @@ def check_input_validity(index, array, binding): ...@@ -442,12 +442,10 @@ def check_input_validity(index, array, binding):
raise TypeError( raise TypeError(
'Wrong dtype for input %i.\n' 'Wrong dtype for input %i.\n'
'Expected %s, got %s. Cannot safely cast.' % 'Expected %s, got %s. Cannot safely cast.' %
(index, binding.dtype, array.dtype) (index, binding.dtype, array.dtype))
)
else: else:
raise TypeError( raise TypeError(
'Wrong dtype for input %i.\n' 'Wrong dtype for input %i.\n'
'Expected %s, got %s.' % 'Expected %s, got %s.' %
(index, binding.dtype, array.dtype) (index, binding.dtype, array.dtype))
)
return array return array
...@@ -14,6 +14,6 @@ from __future__ import absolute_import as _absolute_import ...@@ -14,6 +14,6 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division from __future__ import division as _division
from __future__ import print_function as _print_function from __future__ import print_function as _print_function
from dragon.vm.torch.core.onnx.utils import export from dragon.vm.onnx.core.frontend.torch import export
__all__ = [_s for _s in dir() if not _s.startswith('_')] __all__ = [_s for _s in dir() if not _s.startswith('_')]
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!