Commit ae4d6834 by Ting PAN

Fix the missing extra arguments in compiled function

Summary:
This commit correctly passes the extra arguments
when executing a compiled graph function.
1 parent aa2ec8c3
......@@ -35,7 +35,7 @@ class Blob(object):
class Net(object):
"""The abstraction ``caffe.Net``.
"""The base net class to connect layers.
This class accepts a network file, and an optional parameter file.
Besides, a phase tag is required to compute gradients or not:
......@@ -193,7 +193,7 @@ class Net(object):
current_ws = workspace.get_workspace()
for name, blob in diffs.items():
current_ws.feed_tensor(self.blobs[name].diff, blob)
self._forward_backward_impl(return_outputs=False, stage='backward')
self._forward_backward_impl(executing_stage='backward')
def copy_from(self, other):
"""Copy layers from the other.
......@@ -228,7 +228,7 @@ class Net(object):
current_ws = workspace.get_workspace()
for name, blob in inputs.items():
current_ws.feed_tensor(self._blobs[name]['data'], blob)
self._forward_backward_impl(return_outputs=False, stage='forward')
self._forward_backward_impl(executing_stage='forward')
return lambda: dict(
(output, current_ws.fetch_tensor(self.blobs[output].data))
for output in self.outputs)
......@@ -250,7 +250,7 @@ class Net(object):
current_ws = workspace.get_workspace()
for name, blob in inputs.items():
current_ws.feed_tensor(self._blobs[name]['data'], blob)
self._forward_backward_impl(return_outputs=False)
self._forward_backward_impl()
return lambda: dict(
(output, current_ws.fetch_tensor(self.blobs[output].data))
for output in self.outputs)
......
......@@ -15,7 +15,7 @@ vm.caffe
`[Sutskever et.al, 2013] <http://www.cs.toronto.edu/~hinton/absps/momentum.pdf>`_.
`class Net <caffe/Net.html>`_
: The abstraction ``caffe.Net``.
: The base net class to connect layers.
`class RMSPropSolver <caffe/RMSPropSolver.html>`_
: The RMSProp solver.
......
......@@ -13,10 +13,13 @@ dragon.distributed
: Broadcast the input from root node in a group.
`is_initialized(...) <distributed/is_initialized.html>`_
: Whether the distributed environment has initialized.
: Return whether the distributed environment is initialized.
`init(...) <distributed/init.html>`_
: Initialize the distributed environment.
`is_mpi_available(...) <distributed/is_mpi_available.html>`_
: Return whether the MPI backend is available.
`is_nccl_available(...) <distributed/is_nccl_available.html>`_
: Return whether the NCCL backend is available.
`get_backend(...) <distributed/get_backend.html>`_
: Return the backend of given process group.
......@@ -38,8 +41,9 @@ dragon.distributed
distributed/all_reduce
distributed/broadcast
distributed/init
distributed/is_initialized
distributed/is_mpi_available
distributed/is_nccl_available
distributed/get_backend
distributed/get_group
distributed/get_rank
......
init
====
is_mpi_available
================
.. autofunction:: dragon.distributed.init
.. autofunction:: dragon.distributed.is_mpi_available
.. raw:: html
......
is_nccl_available
=================
.. autofunction:: dragon.distributed.is_nccl_available
.. raw:: html
<style>
h1:before {
content: "dragon.distributed.";
color: #103d3e;
}
</style>
......@@ -15,9 +15,15 @@ dragon.logging
`fatal(...) <logging/fatal.html>`_
: Log message at the FATAL level.
`get_verbosity(...) <logging/get_verbosity.html>`_
: Return the current logging level.
`info(...) <logging/info.html>`_
: Log message at the INFO level.
`log(...) <logging/log.html>`_
: Log message at the given level.
`set_directory(...) <logging/set_directory.html>`_
: Set the directory for logging files.
......@@ -33,7 +39,9 @@ dragon.logging
logging/debug
logging/error
logging/fatal
logging/get_verbosity
logging/info
logging/log
logging/set_directory
logging/set_verbosity
logging/warning
......
get_verbosity
=============
.. autofunction:: dragon.logging.get_verbosity
.. raw:: html
<style>
h1:before {
content: "dragon.logging.";
color: #103d3e;
}
</style>
log
===
.. autofunction:: dragon.logging.log
.. raw:: html
<style>
h1:before {
content: "dragon.logging.";
color: #103d3e;
}
</style>
......@@ -67,7 +67,7 @@ class CudaStream {
};
void RegisterModule(py::module& m) {
/*! \brief Reporting if CUDA is available */
/*! \brief Return whether CUDA driver is sufficient */
m.def("cudaIsDriverSufficient", []() {
#ifdef USE_CUDA
int count;
......@@ -79,8 +79,8 @@ void RegisterModule(py::module& m) {
#endif
});
/*! \brief Reporting if CUDA is available */
m.def("cudaIsNCCLSufficient", []() {
/*! \brief Return whether NCCL is available */
m.def("ncclIsAvailable", []() {
#ifdef USE_NCCL
#ifdef USE_CUDA
int count;
......
......@@ -32,8 +32,17 @@ namespace mpi {
} while (0)
void RegisterModule(py::module& m) {
/*! \brief Return whether MPI is available */
m.def("mpiIsAvailable", []() {
#ifdef USE_MPI
return true;
#else
return false;
#endif
});
/*! \brief Initialize the MPI environment */
m.def("MPIInit", []() {
m.def("mpiInitialize", []() {
#ifdef USE_MPI
// Enabling the multi-threads for Python is meaningless
// While we will still hold this interface here
......@@ -53,7 +62,7 @@ void RegisterModule(py::module& m) {
});
/*! \brief Return the world rank of current node */
m.def("MPIRank", []() {
m.def("mpiWorldRank", []() {
#ifdef USE_MPI
int world_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
......@@ -64,7 +73,7 @@ void RegisterModule(py::module& m) {
});
/*! \brief Return the world size of current node */
m.def("MPISize", []() {
m.def("mpiWorldSize", []() {
#ifdef USE_MPI
int world_size;
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
......@@ -74,8 +83,8 @@ void RegisterModule(py::module& m) {
#endif
});
/*! \brief Create a MPI group from the nodes */
m.def("MPICreateGroup", [](const vec32_t& ranks, bool verbose = false) {
/*! \brief Create a MPI group from the ranks */
m.def("mpiCreateGroup", [](const vec32_t& ranks, bool verbose = false) {
#ifdef USE_MPI
// Skip the empty ranks to avoid asserting
if (ranks.empty()) return vector<long>();
......@@ -125,7 +134,7 @@ void RegisterModule(py::module& m) {
});
/*! \brief Finalize the MPI environment */
m.def("MPIFinalize", []() {
m.def("mpiFinalize", []() {
#ifdef USE_MPI
MPI_Finalize();
#else
......
......@@ -13,8 +13,9 @@ from __future__ import absolute_import as _absolute_import
from __future__ import division as _division
from __future__ import print_function as _print_function
from dragon.core.distributed.backend import init
from dragon.core.distributed.backend import is_initialized
from dragon.core.distributed.backend import is_mpi_available
from dragon.core.distributed.backend import is_nccl_available
from dragon.core.distributed.backend import get_backend
from dragon.core.distributed.backend import get_group
from dragon.core.distributed.backend import get_rank
......
......@@ -16,7 +16,9 @@ from __future__ import print_function as _print_function
from dragon.core.util.logging import debug
from dragon.core.util.logging import error
from dragon.core.util.logging import fatal
from dragon.core.util.logging import get_verbosity
from dragon.core.util.logging import info
from dragon.core.util.logging import log
from dragon.core.util.logging import set_directory
from dragon.core.util.logging import set_verbosity
from dragon.core.util.logging import warning
......
......@@ -82,10 +82,7 @@ def class_method_to_instance_method(original_function, instance):
return decorator.make_decorator(
original_function.python_function,
type(original_function)(
decorator.make_decorator(
bound_method,
bound_method_wrapper,
),
decorator.make_decorator(bound_method, bound_method_wrapper),
input_signature=original_function.input_signature,
),
)
......@@ -138,17 +135,7 @@ class FunctionSpec(object):
raise ValueError(
'When <input_signature> is provided, '
'only pass arguments covered by it.\n'
'Received %d argument(s).' % len(args)
)
for arg in kwargs.keys():
index = self._args_to_indices.get(arg, None)
if index is not None and \
index >= len(self._input_signature):
raise ValueError(
'When <input_signature> is provided, '
'only pass arguments covered by it.\n'
'Received argument <%s>.' % arg
)
'Received %d argument(s).' % len(args))
# Determine the args from kwargs and default-values.
if not kwargs:
# The simplest case: args only.
......@@ -167,7 +154,7 @@ class FunctionSpec(object):
index = self._args_to_indices.get(arg, None)
if index is not None:
arg_indices_to_values[index] = value
elif arg in extra_args:
else:
extra_args[arg] = value
args2 = tuple(arg_indices_to_values[key]
for key in sorted(arg_indices_to_values))
......@@ -250,8 +237,7 @@ class FunctionGuard(object):
'When <input_signature> is provided, '
'only define arguments covered by it.\n'
'Got %d signature(s) and %d argument(s).'
% (len(input_signature), self._function_spec.num_inputs)
)
% (len(input_signature), self._function_spec.num_inputs))
shape = input_signature[i].shape
dtype = input_signature[i].dtype
inputs.append(Tensor(shape, dtype, name).constant())
......@@ -263,22 +249,23 @@ class FunctionGuard(object):
outputs.append(obj)
else:
dummies.append(obj)
executables = [function_lib.create_function(inputs, outputs)]
executables = [function_lib.create_function(outputs=outputs)]
for obj in dummies:
if isinstance(obj, optimizer.Optimizer):
executables.append(function_lib.create_function(optimizer=obj))
self.inputs = inputs
self.outputs = returns
self.executables = executables
# In this case, we have compiled executables.
# Notify the backend to run directly.
executables = self.executables
inputs, kwargs = self.canonicalize_inputs(*args, **kwargs)
executables[0](*inputs, return_outputs=False, **kwargs)
current_ws = workspace.get_workspace()
for input, value in zip(self.inputs, inputs):
current_ws.feed_tensor(input, value)
executables[0](return_outputs=False, **kwargs)
[func(return_outputs=False) for func in executables[1:]]
outputs = []
current_ws = workspace.get_workspace()
for output in self.outputs:
if isinstance(output, Tensor):
impl = current_ws.GetTensor(output.id)
......@@ -286,7 +273,7 @@ class FunctionGuard(object):
outputs.append(EagerTensor(impl=impl, device=device))
else:
outputs.append(output)
return outputs
return outputs[0] if len(outputs) == 1 else outputs
def __get__(self, instance, owner):
"""Override to patch the instance methods."""
......
......@@ -77,8 +77,6 @@ def add_phase(graph_def, targets):
def add_update_defs(graph_def, optimizer):
"""Add the update defs."""
if optimizer is None:
return
grads, update_defs = [], []
extra_arguments = optimizer._extra_kwargs
extra_arguments['handle'] = optimizer._op_handle
......
......@@ -7,10 +7,6 @@
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/caffe2/caffe2/blob/master/caffe2/python/core.py>
#
# ------------------------------------------------------------
"""Simple gradient maker implementation."""
......
......@@ -110,12 +110,7 @@ class OpDef(object):
)
# Return the outputs.
if num_outputs > 1:
return outputs
elif num_outputs == 1:
return outputs[0]
else:
return None
return outputs[0] if num_outputs == 1 else outputs
@staticmethod
def add_spec(op_type, arguments, inputs, outputs):
......
......@@ -13,7 +13,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from dragon.core.distributed.backend import init
from dragon.core.distributed.backend import is_initialized
from dragon.core.distributed.backend import get_backend
from dragon.core.distributed.backend import get_group
......
......@@ -37,14 +37,14 @@ class Backend(object):
raise ValueError('Backend name must be a string, but got: {}'.format(name))
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
if value == 'AUTO':
if _b.cudaIsNCCLSufficient():
if is_nccl_available():
return Backend.NCCL
return Backend.MPI
elif value == 'NCCL':
if not _b.cudaIsNCCLSufficient():
if not is_nccl_available():
raise ValueError('NCCL backend is not available.')
elif value == Backend.UNDEFINED:
raise ValueError("Invalid backend: '{}'".format(name))
raise ValueError('Invalid backend:', name)
return value
......@@ -58,10 +58,12 @@ class ProcessGroup(object):
self._backend = Backend('AUTO')
else:
self._backend = Backend(backend)
# Stored for executing the collective ops.
self._arguments = {
# Stored for executing the collective ops.
'comm': self._comm, 'group': self._handle,
'backend': self._backend, 'ranks': self._ranks,
'comm': self._comm,
'group': self._handle,
'backend': self._backend,
'ranks': self._ranks,
}
@property
......@@ -130,31 +132,45 @@ class ProcessGroup(object):
return _GLOBAL_PROCESS_GROUP_STACK.get_controller(self)
def __repr__(self):
return '%s:%d' % (self._backend, self._handle)
def init():
"""Init the distributed env."""
if is_initialized():
# ATTENTION: MPI env can only be initialized once per process.
return
_b.MPIInit()
global _GLOBAL_MPI_CONTEXT
_GLOBAL_MPI_CONTEXT = _MPIContext()
return '{}:{}'.format(self._backend, self._handle)
def is_initialized():
"""Whether the distributed env has initialized.
"""Return whether the distributed environment is initialized.
Returns
-------
bool
**True** if env has initialized otherwise **False**.
**True** if initialized otherwise **False**.
"""
return _GLOBAL_MPI_CONTEXT is not None
def is_mpi_available():
"""Return whether the MPI backend is available.
Returns
-------
bool
**True** if available otherwise **False**.
"""
return _b.mpiIsAvailable()
def is_nccl_available():
"""Return whether the NCCL backend is available.
Returns
-------
bool
**True** if available otherwise **False**.
"""
return _b.ncclIsAvailable()
def get_backend(group):
"""Return the backend of given process group.
......@@ -201,12 +217,12 @@ def get_rank(group=None):
The rank.
"""
init()
world_rank = _b.MPIRank()
_maybe_initialize()
world_rank = _b.mpiWorldRank()
if group is not None:
for idx, rank in enumerate(group.ranks):
for i, rank in enumerate(group.ranks):
if rank == world_rank:
return idx
return i
return world_rank
......@@ -219,8 +235,8 @@ def get_world_size():
The world size.
"""
init()
return _b.MPISize()
_maybe_initialize()
return _b.mpiWorldSize()
def new_group(ranks=None, backend=None, verbose=False):
......@@ -247,17 +263,27 @@ def new_group(ranks=None, backend=None, verbose=False):
if ranks is None:
return ProcessGroup(None, None, None, backend)
else:
init()
_maybe_initialize()
ranks = nest.flatten(ranks)
comm, handle = _b.MPICreateGroup(ranks, verbose)
comm, handle = _b.mpiCreateGroup(ranks, verbose)
return ProcessGroup(ranks, comm, handle, backend)
def _maybe_initialize():
"""Maybe initialize the distributed environment."""
if is_initialized():
# ATTENTION: MPI env can only be initialized once per process.
return
_b.mpiInitialize()
global _GLOBAL_MPI_CONTEXT
_GLOBAL_MPI_CONTEXT = _MPIContext()
class _MPIContext(object):
"""Context to finalize mpi under destruction."""
def __del__(self):
_b.MPIFinalize()
_b.mpiFinalize()
_GLOBAL_MPI_CONTEXT = None
......
......@@ -80,4 +80,4 @@ def run_operator(
ws.run_operator(op_def)
# Return the outputs.
return outputs if len(outputs) > 1 else outputs[0]
return outputs[0] if len(outputs) == 1 else outputs
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Utilities to fly an operator."""
"""Wrapper and utilities for operator."""
from __future__ import absolute_import
from __future__ import division
......@@ -27,7 +27,7 @@ from dragon.core.framework import workspace
class Operator(object):
"""Wrapper to unify the symbolic and eager operator abstraction."""
"""The operator wrapper."""
def __init__(self, cache_key, device, **kwargs):
self._def = None
......@@ -46,8 +46,14 @@ class Operator(object):
return self.__call__(*args, **kwargs)
def attributes(self):
"""Define the attributes to generate OpDef."""
return {}
"""Define the attributes to generate OpDef.
Returns
-------
dict
The attribute dict.
"""
@classmethod
def blend(cls, op_type=None, **kwargs):
......@@ -97,7 +103,6 @@ class Operator(object):
def forward(self, *inputs, **kwargs):
"""Define the execution."""
pass
def _gen_def(self):
"""Generate the OpDef from attributes."""
......@@ -142,8 +147,7 @@ def scalar_to_tensor(input, dtype):
except (TypeError, ValueError):
raise ValueError(
'<input> should be a python number, got {}.'
.format(type(input).__name__)
)
.format(type(input).__name__))
name = '/share/scalar/{}/{}'.format(dtype, str(input))
ws = workspace.get_workspace()
if not ws.has_tensor(name):
......
......@@ -21,6 +21,7 @@ import dragon
parser = argparse.ArgumentParser(add_help=False)
TEST_CUDA = dragon.cuda.is_available()
TEST_MPI = dragon.distributed.is_mpi_available()
def run_tests(argv=None):
......
......@@ -84,10 +84,7 @@ def not_installed(package=''):
"""Return a dummy function for the package that is not installed."""
def dummy_fn(*args, **kwargs):
_ = locals()
raise ImportError(
'Package <%s> is required but not installed.'
% package
)
raise ImportError('Package <%s> is required but not installed.' % package)
return dummy_fn
......@@ -98,7 +95,4 @@ class NotInstalled(object):
self._package = package
def __getattr__(self, item):
raise ImportError(
'Package <%s> is required but not installed.'
% self._package
)
raise ImportError('Package <%s> is required but not installed.' % self._package)
......@@ -7,11 +7,8 @@
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# Codes are based on:
#
# <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/platform/tf_logging.py>
#
# ------------------------------------------------------------
"""The logging utility."""
from __future__ import absolute_import
from __future__ import division
......@@ -26,28 +23,23 @@ import threading
from dragon import backend
from dragon.core.framework import config
_logger = None
_logger_lock = threading.Lock()
def get_logger():
"""Return the current logger."""
global _logger
# Use double-checked locking to avoid taking lock unnecessarily.
if _logger:
return _logger
_logger_lock.acquire()
try:
if _logger:
return _logger
logger = _logging.getLogger('dragon')
logger.setLevel('INFO')
logger.propagate = False
if True:
# Determine whether we are in an interactive environment.
_interactive = False
......@@ -58,7 +50,6 @@ def get_logger():
except AttributeError:
# Even now, we may be in an interactive shell with `python -i`.
_interactive = _sys.flags.interactive
# If we are in an interactive environment (like Jupyter), set loglevel
# to INFO and pipe the output to stdout.
if _interactive:
......@@ -66,26 +57,28 @@ def get_logger():
_logging_target = _sys.stdout
else:
_logging_target = _sys.stderr
# Add the output handler.
_handler = _logging.StreamHandler(_logging_target)
_handler.setFormatter(_logging.Formatter('%(levelname)s %(message)s'))
logger.addHandler(_handler)
_logger = logger
return _logger
finally:
_logger_lock.release()
def _detailed_msg(msg):
file, lineno = inspect.stack()[:3][2][1:3]
return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
def log(level, msg, *args, **kwargs):
"""Log message at the given level.
Parameters
----------
level : Union[int,str]
The logging level value.
msg: str
The message.
def log(level, msg, *args, **kwargs):
"""Log message at the given level."""
"""
level = _logging._checkLevel(level)
get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
......@@ -126,7 +119,14 @@ def fatal(msg, *args, **kwargs):
def get_verbosity():
"""Return the logging level."""
"""Return the current logging level.
Returns
-------
int
The logging level value.
"""
return get_logger().getEffectiveLevel()
......@@ -147,7 +147,7 @@ def set_directory(path):
Parameters
----------
path : str
path : str, optional
The path of the directory.
"""
......@@ -171,26 +171,14 @@ def set_verbosity(level):
Parameters
----------
level : str
The logging level.
level : Union[int, str]
The logging level value.
"""
get_logger().setLevel(level)
backend.SetLoggingLevel(level)
def warn(msg, *args, **kwargs):
"""Log message at the WARNING level.
Parameters
----------
msg: str
The message.
"""
get_logger().warn(_detailed_msg(msg), *args, **kwargs)
def warning(msg, *args, **kwargs):
"""Log message at the WARNING level.
......@@ -201,3 +189,9 @@ def warning(msg, *args, **kwargs):
"""
get_logger().warning(_detailed_msg(msg), *args, **kwargs)
def _detailed_msg(msg):
"""Return the formatted message with file and lineno."""
file, lineno = inspect.stack()[:3][2][1:3]
return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""The nest utility."""
from __future__ import absolute_import
from __future__ import division
......@@ -16,8 +17,8 @@ from __future__ import print_function
from dragon.core.util import six
def is_sequence(input):
"""Return a bool indicating whether input is a sequence.
def is_nested(input):
"""Return a bool indicating whether input is a sequence or dict.
Parameters
----------
......@@ -27,16 +28,14 @@ def is_sequence(input):
Returns
-------
bool
**True** if input is a sequence otherwise **False**.
**True** if input is a sequence or dict otherwise **False**.
"""
return \
isinstance(input, six.collections_abc.Sequence) and \
not isinstance(input, six.string_types)
return is_sequence(input) or isinstance(input, dict)
def is_nested(input):
"""Return a bool indicating whether input is a sequence or dict.
def is_sequence(input):
"""Return a bool indicating whether input is a sequence.
Parameters
----------
......@@ -46,10 +45,11 @@ def is_nested(input):
Returns
-------
bool
**True** if input is a sequence or dict otherwise **False**.
**True** if input is a sequence otherwise **False**.
"""
return is_sequence(input) or isinstance(input, dict)
return (isinstance(input, six.collections_abc.Sequence) and
not isinstance(input, six.string_types))
def flatten(input):
......@@ -84,7 +84,7 @@ def flatten(input):
return output_list
def flatten_with_tuple_paths(input):
def flatten_with_paths(input):
"""Return a flat list, yield as *(paths, element)*.
Parameters
......@@ -94,14 +94,14 @@ def flatten_with_tuple_paths(input):
Returns
-------
List[Tuple[Tuple, Object]]
List[Tuple[Tuple, object]]
The flat list of input.
"""
return list(zip(yield_flat_paths(input), flatten(input)))
return list(zip(yield_flatten_paths(input), flatten(input)))
def yield_flat_paths(input):
def yield_flatten_paths(input):
"""Yield paths for nested structure.
Parameters
......@@ -115,27 +115,29 @@ def yield_flat_paths(input):
The iterator of paths.
"""
for k, _ in _yield_flat_up_to(input, input, is_nested):
for k, _ in _yield_flatten_up_to(input, input, is_nested):
yield k
def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()):
def _yield_flatten_up_to(shallow_tree, input_tree, is_seq, path=()):
"""Return the tuple of path and element for iterable."""
if not is_seq(shallow_tree):
yield (path, input_tree)
yield path, input_tree
else:
input_tree = dict(_yield_sorted_items(input_tree))
for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
subpath = path + (shallow_key,)
sub_path = path + (shallow_key,)
input_subtree = input_tree[shallow_key]
for leaf_path, leaf_value in _yield_flat_up_to(
for leaf_path, leaf_value in _yield_flatten_up_to(
shallow_subtree,
input_subtree,
is_seq,
path=subpath):
path=sub_path):
yield leaf_path, leaf_value
def _yield_sorted_items(iterable):
"""Return the sorted iterable."""
if isinstance(iterable, six.collections_abc.Mapping):
for key in sorted(iterable):
yield key, iterable[key]
......
......@@ -8,7 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Registry utilities."""
"""The registry utility."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -15,7 +15,6 @@ from __future__ import print_function
from typing import cast
from typing import IO
from typing import Optional
from typing import Text
......@@ -44,8 +43,8 @@ def serialize_proto(proto):
return b''
elif isinstance(proto, bytes):
return proto
elif hasattr(proto, 'SerializeToString') and \
callable(proto.SerializeToString):
elif (hasattr(proto, 'SerializeToString') and
callable(proto.SerializeToString)):
result = proto.SerializeToString()
return result
else:
......@@ -57,16 +56,11 @@ def serialize_proto(proto):
def deserialize_proto(s, proto):
"""Deserialize the protocol buffer object."""
if not isinstance(s, bytes):
raise ValueError(
'Excepted serialized bytes, got type: {}'.format(type(s)))
raise ValueError('Excepted serialized bytes, got: {}'.format(type(s)))
if not (hasattr(proto, 'ParseFromString') and
callable(proto.ParseFromString)):
raise ValueError(
'No <ParseFromString> method. Type is {}'
.format(type(proto)))
decoded = cast(Optional[int], proto.ParseFromString(s))
if decoded is not None and decoded != len(s):
raise RuntimeError(
'Protobuf decoding consumed too few bytes: {} out of {}'
.format(decoded, len(s)))
proto.ParseFromString(s)
return proto
......@@ -62,13 +62,10 @@ class Stack(threading.local):
try:
yield default
finally:
# Stack may be empty if reset() was called.
if self.stack:
if self._enforce_nesting:
if self.stack[-1] is not default:
raise AssertionError(
"Nesting violated for default stack of %s objects" %
type(default))
raise RuntimeError('Nesting violated by the push or pop.')
self.stack.pop()
else:
self.stack.remove(default)
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test the autograph module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import dragon
from dragon.core.testing.unittest.common_utils import run_tests
class TestFunction(unittest.TestCase):
"""Test the graph function."""
@dragon.function(input_signature=[
dragon.Tensor(dtype='int32'),
dragon.Tensor(dtype='int32'),
dragon.Tensor(dtype='int32'),
])
def func1(self, a, b, c=0, **kwargs):
_ = kwargs
return a + b + c
def test_def_function(self):
@dragon.function(input_signature=[dragon.Tensor()])
def func2(a, b):
return a + b
self.assertEqual(self.func1([1, 2], [3, 4]).get_value().tolist(), [4, 6])
self.assertEqual(self.func1([1, 2], b=[3, 4]).get_value().tolist(), [4, 6])
self.assertEqual(self.func1([1, 2], b=[3, 4], c=1).get_value().tolist(), [5, 7])
self.assertEqual(self.func1([1, 2], b=[3, 4], c=1).get_value().tolist(), [5, 7])
self.assertEqual(self.func1([1, 2], [3, 4], executing_stage='forward').get_value().tolist(), [4, 6])
dragon.function(func=lambda: dragon.optimizers.SGD())()
try:
self.func1(1, 2, 3, 4)
except ValueError:
pass
try:
func2(1, 2)
except ValueError:
pass
def test_update_function(self):
optimizer = dragon.optimizers.SGD()
try:
_ = optimizer.op_type
except KeyError:
pass
value = dragon.Tensor(dtype='float32').set_value(1.)
grad = dragon.Tensor(dtype='float32').set_value(1.)
optimizer.apply_gradients([(value, grad)])
dragon.create_function(optimizer=optimizer)()
if __name__ == '__main__':
run_tests()
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test the device module."""
from __future__ import absolute_import
from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test the distributed module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import dragon
from dragon.core.testing.unittest.common_utils import run_tests
from dragon.core.testing.unittest.common_utils import TEST_MPI
class TestBackend(unittest.TestCase):
"""Test the backend components."""
def test_empty_group(self):
for backend in (None, 'AUTO', 'NCCL', 'MPI', 'UNKNOWN', 0):
try:
group = dragon.distributed.new_group(backend=backend)
self.assertEqual(group.ranks, None)
self.assertEqual(group.size, 0)
self.assertEqual(group.arguments['backend'], group.backend)
self.assertEqual(repr(group), '%s:None' % group.backend)
with group.as_default():
self.assertEqual(dragon.distributed.get_group(), None)
self.assertEqual(dragon.distributed.get_backend(group), group.backend)
except ValueError:
pass
@unittest.skipIf(not TEST_MPI, 'MPI unavailable')
def test_mpi_single_process(self):
self.assertEqual(dragon.distributed.get_rank(), 0)
self.assertEqual(dragon.distributed.get_world_size(), 1)
group = dragon.distributed.new_group(ranks=[0], backend='MPI')
with group.as_default():
self.assertEqual(dragon.distributed.get_rank(group), 0)
if __name__ == '__main__':
run_tests()
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test the framework module."""
from __future__ import absolute_import
from __future__ import division
......
......@@ -8,6 +8,7 @@
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test the ops module."""
from __future__ import absolute_import
from __future__ import division
......
# ------------------------------------------------------------
# Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
#
# Licensed under the BSD 2-Clause License.
# You should have received a copy of the BSD 2-Clause License
# along with the software. If not, See,
#
# <https://opensource.org/licenses/BSD-2-Clause>
#
# ------------------------------------------------------------
"""Test the util module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import logging
import os
import threading
import unittest
import dragon
from dragon.core.proto import dragon_pb2
from dragon.core.util import deprecation
from dragon.core.util import nest
from dragon.core.util import registry
from dragon.core.util import serialization
from dragon.core.util import tls
from dragon.core.testing.unittest.common_utils import run_tests
class TestDeprecation(unittest.TestCase):
"""Test the deprecation utility."""
try:
@deprecation.deprecated('2077-01-01', 'Bad deprecated.')
@property
def deprecated_property(self):
return None
except ValueError:
pass
def test_deprecated(self):
@deprecation.deprecated(
'2077-01-01', 'This function is deprecated since 2077.')
def func():
pass
dragon.logging.set_verbosity('FATAL')
func()
dragon.logging.set_verbosity('INFO')
try:
@deprecation.deprecated('2077-01', 'Bad deprecated.')
def func():
pass
func()
except ValueError:
pass
try:
@deprecation.deprecated('2077-01-01', '')
def func():
pass
func()
except ValueError:
pass
def test_not_installed(self):
try:
deprecation.not_installed('!@#$%^&**()')()
except ImportError:
pass
module = deprecation.NotInstalled('!@#$%^&**()')
try:
module.func()
except ImportError:
pass
class TestLogging(unittest.TestCase):
"""Test the logging utility."""
def test_message(self):
dragon.logging.set_verbosity('FATAL')
self.assertEqual(dragon.logging.get_verbosity(), logging.FATAL)
dragon.logging.debug('Test dragon.logging.debug(..)')
dragon.logging.info('Test dragon.logging.info(..)')
dragon.logging.warning('Test dragon.logging.warning(..)')
dragon.logging.error('Test dragon.logging.error(..)')
dragon.logging.log('INFO', 'Test dragon.logging.error(..)')
dragon.logging.set_verbosity('INFO')
self.assertEqual(dragon.logging.get_verbosity(), logging.INFO)
def test_logging_file(self):
dragon.logging.set_directory(None)
class TestNest(unittest.TestCase):
"""Test the nest utility."""
def test_nested(self):
self.assertTrue(nest.is_nested(list()))
self.assertTrue(nest.is_sequence(list()))
self.assertTrue(nest.is_nested(tuple()))
self.assertTrue(nest.is_sequence(tuple()))
self.assertTrue(nest.is_nested(dict()))
self.assertFalse(nest.is_sequence(dict()))
def test_flatten(self):
self.assertEqual(nest.flatten(1), [1])
for a, b in zip(nest.flatten_with_paths([2, 4, [31, 32], 1]),
[((0,), 2), ((1,), 4), ((2, 0), 31), ((2, 1), 32), ((3,), 1)]):
self.assertEqual(a, b)
for a, b in zip(nest.flatten_with_paths({2: 2, 4: 4, 3: {1: 31, 2: 32}, 1: 1}),
[((1,), 1), ((2,), 2), ((3, 1), 31), ((3, 2), 32), ((4,), 4)]):
self.assertEqual(a, b)
class TestRegistry(unittest.TestCase):
"""Test the registry utility."""
def test_register(self):
reg = registry.Registry('test_registry')
reg.register('a+b', lambda a, b: a + b)
self.assertTrue('a+b' in reg.keys)
self.assertTrue(reg.has('a+b'))
self.assertEqual(reg.get('a+b')(1, 2), 3)
try:
reg.get('c+d')
except KeyError:
pass
class TestSerialization(unittest.TestCase):
"""Test the serialization utility."""
def test_bytes(self):
f = io.BytesIO(b'123')
serialization.save_bytes(b'456', f)
f.seek(0)
self.assertEqual(serialization.load_bytes(f), b'456')
save_file = '/tmp/test_dragon_serialization_save_bytes'
try:
serialization.save_bytes(b'789', save_file)
except OSError:
pass
try:
s = serialization.load_bytes(save_file)
self.assertEqual(s, b'789')
except FileNotFoundError:
pass
try:
if os.path.exists(save_file):
os.remove(save_file)
except PermissionError:
pass
def test_proto(self):
self.assertEqual(serialization.serialize_proto(None), b'')
s = serialization.serialize_proto(dragon_pb2.OperatorDef(name='!@#$%^&**()'))
s = serialization.serialize_proto(s)
proto = serialization.deserialize_proto(s, dragon_pb2.OperatorDef())
self.assertEqual(proto, dragon_pb2.OperatorDef(name='!@#$%^&**()'))
try:
serialization.serialize_proto(1)
except ValueError:
pass
try:
serialization.deserialize_proto(2, dragon_pb2.OperatorDef())
except ValueError:
pass
try:
serialization.deserialize_proto(s, 2)
except ValueError:
pass
class TestTLS(unittest.TestCase):
"""Test the tls utility."""
def test_constant(self):
def write(i, q):
c.value = i
q.append(c.value)
c, q = tls.Constant(value=-1), []
threads = [threading.Thread(target=write, args=[i, q]) for i in range(4)]
for t in threads:
t.start()
t.join()
self.assertEqual(c.value, -1)
def test_stack(self):
s = tls.Stack()
s.enforce_nesting = True
self.assertEqual(s.enforce_nesting, True)
try:
with s.get_controller('!@#$%^&**()'):
s.push('123456')
except RuntimeError:
pass
s.enforce_nesting = False
with s.get_controller('!@#$%^&**()'):
s.push('123456')
if __name__ == '__main__':
run_tests()
......@@ -19,9 +19,12 @@ import subprocess
import argparse
TESTS_AND_SOURCES = [
('dragon/core/test_autograph', 'dragon.core'),
('dragon/core/test_device', 'dragon.core'),
('dragon/core/test_distributed', 'dragon.core'),
('dragon/core/test_framework', 'dragon.core'),
('dragon/core/test_ops', 'dragon.core'),
('dragon/core/test_util', 'dragon.core'),
]
TESTS = [t[0] for t in TESTS_AND_SOURCES]
......
......@@ -39,8 +39,14 @@ class Function(object):
return self.__call__(*args, **kwargs)
def attributes(self):
"""Define the attributes to generate OpDef."""
return {}
"""Define the attributes to generate OpDef.
Returns
-------
dict
The attribute dict.
"""
def dispatch(
self,
......@@ -92,7 +98,6 @@ class Function(object):
def forward(self, *inputs, **kwargs):
"""Define the execution."""
raise RuntimeError('The base function can not be called.')
def _gen_def(self):
"""Generate the OpDef from attributes."""
......
......@@ -88,7 +88,7 @@ def run_operator(
ws.run_operator(op_def)
# Return the outputs.
return outputs if len(outputs) > 1 else outputs[0]
return outputs[0] if len(outputs) == 1 else outputs
def run_backward(tensors, grad_tensors=None, retain_graph=False):
......
......@@ -104,8 +104,7 @@ class FunctionGuard(object):
'When <example_inputs> is provided, '
'only define arguments covered by it.\n'
'Got %d inputs(s) and %d argument(s).'
% (len(input_signature),
self._function_spec.num_inputs)
% (len(input_signature), self._function_spec.num_inputs)
)
if input_signature[i] is not None:
inputs.append(Tensor(
......
......@@ -66,6 +66,7 @@ class Tensor(object):
```
"""
def __init__(self, *args, **kwargs):
self._tape = None
self._gc = kwargs.get('gc', None)
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!