SeetaResearch / Dragon
Commit bd84b7fd, authored May 14, 2019 by Ting PAN
Add masked AssignOp
1 parent e90a8f1a
Showing 27 changed files with 630 additions and 219 deletions:
Docs/api/python/contents/ops.rst
Dragon/include/core/context_cuda.h
Dragon/include/operators/control_flow/masked_assign_op.h
Dragon/include/utils/cuda_device.h
Dragon/include/utils/cudnn_device.h
Dragon/include/utils/op_kernel.h
Dragon/python/dragon/core/tensor.py
Dragon/python/dragon/operators/control_flow.py
Dragon/python/dragon/ops.py
Dragon/python/dragon/utils/vision/data_transformer.py
Dragon/python/dragon/vm/torch/module.py
Dragon/python/dragon/vm/torch/ops/builtin.py
Dragon/python/dragon/vm/torch/ops/modules/array.py
Dragon/python/dragon/vm/torch/ops/modules/base.py
Dragon/python/dragon/vm/torch/ops/modules/control_flow.py
Dragon/python/dragon/vm/torch/ops/modules/init.py
Dragon/python/dragon/vm/torch/ops/modules/update.py
Dragon/python/dragon/vm/torch/ops/modules/vision.py
Dragon/python/dragon/vm/torch/ops/tensor.py
Dragon/python/dragon/vm/torch/tensor.py
Dragon/src/core/mixedmem.cc
Dragon/src/core/operator_schema.cc
Dragon/src/kernels/control_flow/assign_op_kernel.cu
Dragon/src/kernels/control_flow/masked_assign_op_kernel.cc
Dragon/src/kernels/control_flow/masked_assign_op_kernel.cu
Dragon/src/operators/control_flow/assign_op.cc
Dragon/src/operators/control_flow/masked_assign_op.cc
Docs/api/python/contents/ops.rst

@@ -168,6 +168,7 @@ List Brief
 =============== ======================================================================
 `Copy`_          Copy the *value* to *ref*.
 `Assign`_        Assign the *value* to *ref*.
+`MaskedAssign`_  Assign the *value* to *ref* where mask is *1*.
 `Equal`_         *Equal* Comparing between A and B.
 `Less`_          *Less* Comparing between A and B.
 `LessEqual`_     *LessEqual* Comparing between A and B.

@@ -308,8 +309,9 @@ List Brief
 .. _Arange: operators/array.html#dragon.operators.array.Arange
 .. _Multinomial: operators/array.html#dragon.operators.array.Multinomial
-.. _Copy: operators/control_flow.html#dAragon.operators.control_flow.Copy
-.. _Assign: operators/control_flow.html#dAragon.operators.control_flow.Assign
+.. _Copy: operators/control_flow.html#dragon.operators.control_flow.Copy
+.. _Assign: operators/control_flow.html#dragon.operators.control_flow.Assign
+.. _MaskedAssign: operators/control_flow.html#dragon.operators.control_flow.MaskedAssign
 .. _Equal: operators/control_flow.html#dragon.operators.control_flow.Equal
 .. _Less: operators/control_flow.html#dragon.operators.control_flow.Less
 .. _LessEqual: operators/control_flow.html#dragon.operators.control_flow.LessEqual
Dragon/include/core/context_cuda.h

@@ -72,7 +72,7 @@ class CUDAObject {
     if (streams.size() <= (unsigned)stream_id)
         streams.resize(stream_id + 1, nullptr);
     if (!streams[stream_id]) {
-        DeviceGuard guard(device_id);
+        CUDADeviceGuard guard(device_id);
         unsigned int flags = !stream_id ?
             cudaStreamDefault :
             cudaStreamNonBlocking;

@@ -97,7 +97,7 @@ class CUDAObject {
     if (handles.size() <= (unsigned)stream_id)
         handles.resize(stream_id + 1, nullptr);
     if (!handles[stream_id]) {
-        DeviceGuard guard(device_id);
+        CUDADeviceGuard guard(device_id);
         CUBLAS_CHECK(cublasCreate_v2(&handles[stream_id]));
         CUBLAS_CHECK(cublasSetStream_v2(
             handles[stream_id],

@@ -120,7 +120,7 @@ class CUDAObject {
     if (handles.size() <= (unsigned)stream_id)
         handles.resize(stream_id + 1, nullptr);
     if (!handles[stream_id]) {
-        DeviceGuard guard(device_id);
+        CUDADeviceGuard guard(device_id);
         CUDNN_CHECK(cudnnCreate(&handles[stream_id]));
         CUDNN_CHECK(cudnnSetStream(
             handles[stream_id],

@@ -292,7 +292,7 @@ class CUDAContext {
     /*! \brief Return the internal cuda random generator */
     curandGenerator_t& curand_generator() {
         if (!curand_generator_) {
-            DeviceGuard guard(device_id_);
+            CUDADeviceGuard guard(device_id_);
             CURAND_CHECK(curandCreateGenerator(
                 &curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
             CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(
Dragon/include/operators/control_flow/masked_assign_op.h (new file, mode 100644)

/*!
 * Copyright (c) 2017-present, SeetaTech, Co.,Ltd.
 *
 * Licensed under the BSD 2-Clause License.
 * You should have received a copy of the BSD 2-Clause License
 * along with the software. If not, See,
 *
 *     <https://opensource.org/licenses/BSD-2-Clause>
 *
 * ------------------------------------------------------------
 */

#ifndef DRAGON_OPERATORS_CONTROL_FLOW_MASKED_ASSIGN_OP_H_
#define DRAGON_OPERATORS_CONTROL_FLOW_MASKED_ASSIGN_OP_H_

#include "core/operator.h"

namespace dragon {

template <class Context>
class MaskedAssignOp final : public Operator<Context> {
 public:
    MaskedAssignOp(const OperatorDef& def, Workspace* ws)
        : Operator<Context>(def, ws) {}
    USE_OPERATOR_FUNCTIONS;

    void RunOnDevice() override;
    template <typename T> void RunImpl();
};

}  // namespace dragon

#endif  // DRAGON_OPERATORS_CONTROL_FLOW_MASKED_ASSIGN_OP_H_
Dragon/include/utils/cuda_device.h

@@ -135,19 +135,19 @@ struct CUDADeviceProps {
     vector<cudaDeviceProp> props;
 };

-inline const cudaDeviceProp& GetDeviceProperty(
+inline const cudaDeviceProp& GetCUDADeviceProp(
     const int device_id) {
     static CUDADeviceProps props;
     CHECK_LT(device_id, (int)props.props.size())
-        << "Invalid device id: " << device_id
-        << "\nDetected " << props.props.size()
-        << " eligible cuda devices.";
+        << "\nInvalid device id: " << device_id
+        << "\nDetected " << props.props.size()
+        << " devices.";
     return props.props[device_id];
 }

 inline bool CUDA_TRUE_FP16_AVAILABLE() {
     int device = CUDA_GET_DEVICE();
-    auto& prop = GetDeviceProperty(device);
+    auto& prop = GetCUDADeviceProp(device);
     return prop.major >= 6;
 }

@@ -156,21 +156,26 @@ inline bool TENSOR_CORE_AVAILABLE() {
     return false;
 #else
     int device = CUDA_GET_DEVICE();
-    auto& prop = GetDeviceProperty(device);
+    auto& prop = GetCUDADeviceProp(device);
     return prop.major >= 7;
 #endif
 }

-class DeviceGuard {
+class CUDADeviceGuard {
  public:
-    DeviceGuard(int new_id) : prev_id(CUDA_GET_DEVICE()) {
-        if (prev_id != new_id) CUDA_CHECK(cudaSetDevice(new_id));
-    }
-    ~DeviceGuard() { CUDA_CHECK(cudaSetDevice(prev_id)); }
+    CUDADeviceGuard(int new_id)
+        : prev_id_(CUDA_GET_DEVICE()) {
+        if (prev_id_ != new_id) {
+            CUDA_CHECK(cudaSetDevice(new_id));
+        }
+    }
+    ~CUDADeviceGuard() { CUDA_CHECK(cudaSetDevice(prev_id_)); }

  private:
-    int prev_id;
+    int prev_id_;
 };

 #else
Dragon/include/utils/cudnn_device.h
Dragon/include/utils/op_kernel.h

@@ -657,6 +657,16 @@ void GreaterEqual(
     bool*               y,
     Context*            ctx);

+/*! control_flow.masked_assign */
+
+template <typename T, class Context>
+void MaskedAssign(
+    const int           count,
+    const uint8_t*      mask,
+    const T*            x,
+    T*                  y,
+    Context*            ctx);
+
 /*! loss.l1_loss */

 template <typename T, class Context>
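For reference, the new kernel's contract can be sketched in NumPy (this reference function and its names are illustrative, not part of the Dragon sources): every element of y whose mask entry is non-zero is overwritten by the corresponding element of x, and the rest of y is left untouched.

import numpy as np

def masked_assign_reference(y, mask, x):
    # y[i] <- x[i] wherever mask[i] != 0; other elements keep their previous value.
    out = y.copy()
    keep = mask.astype(bool)
    out[keep] = x[keep]
    return out

y = np.zeros(5, dtype='float32')
x = np.arange(5, dtype='float32')
mask = np.array([1, 0, 1, 0, 1], dtype='uint8')
print(masked_assign_reference(y, mask, x))  # [0. 0. 2. 0. 4.]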
Dragon/python/dragon/core/tensor.py

@@ -488,9 +488,9 @@ class Tensor(object):
         Parameters
         ----------
-        key : int or slice
+        key : int, slice or Tensor
             The indices.
-        value : Tensor, number or sequence
+        value : number, sequence or Tensor
             The value.

         Returns

@@ -498,11 +498,20 @@ class Tensor(object):
         None

         """
-        starts, sizes = self._process_indices(key)
         if not isinstance(value, Tensor):
             value = self._from_constant(value)
-        return self.CreateOperator('Assign', [value],
-            existing_outputs=[self], starts=starts, sizes=sizes)
+        if isinstance(key, Tensor):
+            return self.CreateOperator(
+                'MaskedAssign', [value, key],
+                existing_outputs=[self],
+            )
+        else:
+            starts, sizes = self._process_indices(key)
+            return self.CreateOperator(
+                'Assign', [value],
+                starts=starts,
+                sizes=sizes,
+                existing_outputs=[self],
+            )

     def _from_constant(self, value, name=None):
         if not isinstance(value, numpy.ndarray):
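With this change, __setitem__ dispatches on the key type: a Tensor key builds a MaskedAssign op over the whole tensor, while int/slice keys keep the original sliced Assign path. A minimal sketch, assuming x and mask are already-constructed dragon.Tensor objects of the same shape (construction omitted, since the factory calls vary between Dragon versions):

# int/slice key: unchanged behaviour, a sliced 'Assign' driven by starts/sizes.
x[2:4] = 1.0

# Tensor key: new behaviour, a 'MaskedAssign' whose inputs are [value, mask]
# and whose existing output is x itself.
x[mask] = 0.0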
Dragon/python/dragon/operators/control_flow.py

@@ -75,10 +75,36 @@ def Assign(inputs, starts=None, sizes=None, **kwargs):
 @OpSchema.ConvertConstantInputs()
 @OpSchema.Inputs(2)
+def MaskedAssign(inputs, mask, **kwargs):
+    """Assign the ``value`` to ``ref`` where ``mask`` is *1*.
+
+    **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)
+
+    Parameters
+    ----------
+    inputs : sequence of Tensor
+        The ``ref`` and ``value`` respectively.
+    mask : Tensor
+        The mask, with the same size as ``ref``.
+
+    Returns
+    -------
+    Tensor
+        The ``ref``.
+
+    """
+    arguments = ParseArgs(locals())
+    arguments['existing_outputs'] = [arguments['inputs'][0]]
+    arguments['inputs'] = [arguments['inputs'][1], mask]
+    return Tensor.CreateOperator('Assign', **arguments)
+
+
+@OpSchema.ConvertConstantInputs()
+@OpSchema.Inputs(2)
 def Equal(inputs, to_uint8=False, **kwargs):
-    """``Equal`` comparing between A and B.
+    """*Equal* comparing between A and B.

-    Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``.
+    Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.

     **Type Constraints**: (*bool*, *int8*, *uint8*, *int32*, *int64*, *float16*, *float32*, *float64*)

@@ -87,7 +113,7 @@ def Equal(inputs, to_uint8=False, **kwargs):
     inputs : sequence of Tensor
         The inputs, represent A and B respectively.
     to_uint8 : bool
-        ``True`` to convert to ``uint8`` results.
+        *True* to convert to *uint8* results.

@@ -102,9 +128,9 @@ def Equal(inputs, to_uint8=False, **kwargs):
 @OpSchema.ConvertConstantInputs()
 @OpSchema.Inputs(2)
 def Less(inputs, to_uint8=False, **kwargs):
-    """``Less`` comparing between A and B.
+    """*Less* comparing between A and B.

-    Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``.
+    Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.

@@ -113,7 +139,7 @@ def Less(inputs, to_uint8=False, **kwargs):
     to_uint8 : bool
-        ``True`` to convert to ``uint8`` results.
+        *True* to convert to *uint8* results.

@@ -128,9 +154,9 @@ def Less(inputs, to_uint8=False, **kwargs):
 @OpSchema.ConvertConstantInputs()
 @OpSchema.Inputs(2)
 def LessEqual(inputs, to_uint8=False, **kwargs):
-    """``LessEqual`` comparing between A and B.
+    """*LessEqual* comparing between A and B.

-    Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``.
+    Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.

@@ -139,7 +165,7 @@ def LessEqual(inputs, to_uint8=False, **kwargs):
     to_uint8 : bool
-        ``True`` to convert to ``uint8`` results.
+        *True* to convert to *uint8* results.

@@ -154,9 +180,9 @@ def LessEqual(inputs, to_uint8=False, **kwargs):
 @OpSchema.ConvertConstantInputs()
 @OpSchema.Inputs(2)
 def Greater(inputs, to_uint8=False, **kwargs):
-    """``Greater`` comparing between A and B.
+    """*Greater* comparing between A and B.

-    Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``.
+    Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.

@@ -165,7 +191,7 @@ def Greater(inputs, to_uint8=False, **kwargs):
     to_uint8 : bool
-        ``True`` to convert to ``uint8`` results.
+        *True* to convert to *uint8* results.

@@ -180,9 +206,9 @@ def Greater(inputs, to_uint8=False, **kwargs):
 @OpSchema.ConvertConstantInputs()
 @OpSchema.Inputs(2)
 def GreaterEqual(inputs, to_uint8=False, **kwargs):
-    """``GreaterEqual`` comparing between A and B.
+    """*GreaterEqual* comparing between A and B.

-    Set ``to_uint8`` if you expect the ``uint8`` results instead of ``bool``.
+    Set ``to_uint8`` if you expect the *uint8* results instead of *bool*.

@@ -191,7 +217,7 @@ def GreaterEqual(inputs, to_uint8=False, **kwargs):
     to_uint8 : bool
-        ``True`` to convert to ``uint8`` results.
+        *True* to convert to *uint8* results.
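The functional form mirrors the docstring above: inputs carries ref and value, and mask selects the elements of ref that receive value. A hedged usage sketch, assuming ref, value and mask are existing dragon.Tensor objects of matching shape:

import dragon as dg

# ref keeps its old content except where mask is 1, where it takes value.
ref = dg.ops.MaskedAssign([ref, value], mask=mask)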
Dragon/python/dragon/ops.py

@@ -143,6 +143,7 @@ Multinomial = _array_ops.Multinomial
 # Control Flow
 Copy = _control_flow_ops.Copy
 Assign = _control_flow_ops.Assign
+MaskedAssign = _control_flow_ops.MaskedAssign
 Equal = _control_flow_ops.Equal
 Less = _control_flow_ops.Less
 LessEqual = _control_flow_ops.LessEqual
Dragon/python/dragon/utils/vision/data_transformer.py

@@ -70,8 +70,8 @@ class DataTransformer(multiprocessing.Process):
         self._cutout_size = kwargs.get('cutout_size', 0)
         self._mirror = kwargs.get('mirror', False)
         self._color_aug = kwargs.get('color_augmentation', False)
-        self._min_random_scale = kwargs.get('min_random_scale', 1.0)
-        self._max_random_scale = kwargs.get('max_random_scale', 1.0)
+        self._min_rand_scale = kwargs.get('min_random_scale', 1.0)
+        self._max_rand_scale = kwargs.get('max_random_scale', 1.0)
         self._force_color = kwargs.get('force_color', False)
         self._phase = kwargs.get('phase', 'TRAIN')
         self._random_seed = _cfg.GetRandomSeed()

@@ -102,12 +102,16 @@ class DataTransformer(multiprocessing.Process):
         im = im.reshape((datum.height, datum.width, datum.channels))

         # Random scale
-        random_scale = numpy.random.uniform() * (
-            self._max_random_scale - self._min_random_scale) \
-                + self._min_random_scale
-        if random_scale != 1.0:
-            im = cv2.resize(im, None, fx=random_scale,
-                fy=random_scale, interpolation=cv2.INTER_LINEAR)
+        rand_scale = numpy.random.uniform() * (
+            self._max_rand_scale - self._min_rand_scale) \
+                + self._min_rand_scale
+        if rand_scale != 1.0:
+            im = cv2.resize(
+                im, None,
+                fx=rand_scale,
+                fy=rand_scale,
+                interpolation=cv2.INTER_LINEAR,
+            )

         # Padding
         if self._padding > 0:

@@ -149,7 +153,7 @@ class DataTransformer(multiprocessing.Process):
         # Gray Transformation
         if self._force_color:
             if im.shape[2] == 1:
-                # duplicate to 3 channels
+                # Duplicate to 3 channels
                 im = numpy.concatenate([im, im, im], axis=2)

         # Color Augmentation
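The renamed fields do not change the sampling rule: the scale is still drawn uniformly from [min_random_scale, max_random_scale) and the resize step is skipped when it comes out as exactly 1.0. A small standalone illustration (the 0.75/1.25 bounds are made up; both options default to 1.0):

import numpy as np

min_scale, max_scale = 0.75, 1.25
rand_scale = np.random.uniform() * (max_scale - min_scale) + min_scale
# rand_scale is uniform in [0.75, 1.25); with the defaults (1.0, 1.0) it is
# always exactly 1.0, so the cv2.resize branch never runs.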
Dragon/python/dragon/vm/torch/module.py

@@ -338,11 +338,13 @@ class Module(object):
     def run(self, inputs, outputs, auto_grad=True, callback=None):
         if self._module_def is None: self._gen_module_def()
-        meta = (self.module_key, self._module_def)
-        return RunOperator(
-            inputs, outputs, meta,
-            auto_grad=auto_grad,
-            callback_on_run=callback)
+        return RunOperator(
+            inputs=inputs,
+            outputs=outputs,
+            meta=(self.module_key, self._module_def),
+            auto_grad=auto_grad,
+            callback_on_run=callback,
+        )

     def train(self, mode=True):
         self.training = mode
Dragon/python/dragon/vm/torch/ops/builtin.py

@@ -17,7 +17,10 @@ from dragon.core import mpi
 from dragon.vm.torch.tensor import Tensor, _LeafTensor, _Device
 from dragon.vm.torch.ops.primitive import MakeDevice, WrapScalar
 from dragon.vm.torch.ops.factory import get_module
-from dragon.vm.torch.ops.modules.control_flow import Compare
+from dragon.vm.torch.ops.modules.control_flow import (
+    Assign, MaskedAssign, Compare,
+)
 from dragon.vm.torch.ops.modules.arithmetic import (
     Fundamental, Log, Exp, Sqrt,

@@ -32,9 +35,8 @@ from dragon.vm.torch.ops.modules.init import (
 from dragon.vm.torch.ops.modules.array import (
     Reshape, Squeeze, UnSqueeze, Permute,
-    Indexing, Assigning, Repeat, Concat, Stack,
-    IndexSelect, Reduce, ArgReduce, OneHot, Multinomial,
+    Indexing, IndexSelect, Repeat, Concat, Stack,
+    Reduce, ArgReduce, OneHot, Multinomial,
 )

@@ -48,8 +50,8 @@ from dragon.vm.torch.ops.modules.vision import (
 __all__ = [
-    'add', 'sub', 'mul', 'div', 'accumulate',
+    'accumulate', 'add', 'sub', 'mul', 'div',
     'maximum', 'minimum', 'clamp',
     'log', 'exp', 'sqrt',
     'mm', 'xw_plus_b',

@@ -59,9 +61,12 @@ __all__ = [
     'gt', 'lt', 'eq', 'ge', 'le',
     'cat', 'stack', 'narrow',
     'index_select',
-    'one_hot', 'multinomial', 'rand', 'randn',
-    'zeros', 'zeros_like', 'ones', 'ones_like',
-    'nn_resize', 'bilinear_resize', 'roi_pool', 'roi_align',
+    'one_hot', 'multinomial',
+    'rand', 'randn',
+    'ones', 'ones_like', 'zeros', 'zeros_like',
+    'nn_resize', 'bilinear_resize', 'roi_pool', 'roi_align',
 ]

@@ -409,52 +414,64 @@ def xw_plus_b(x, w, bias=None, transW=True, out=None):
 def _reshape(input, shape, shape_like=None):
     if shape_like is not None: shape = shape_like.shape
-    dev = MakeDevice(inputs=[input]); n_dim = len(shape)
-    key = 'Reshape/{}/n_dim:{}'.format(dev, n_dim)
-    module = get_module(Reshape, key, dev, n_dim=n_dim)
+    dev = MakeDevice(inputs=[input]); ndim = len(shape)
+    key = 'Reshape/{}/ndim:{}'.format(dev, ndim)
+    module = get_module(Reshape, key, dev, ndim=ndim)
     return module.forward(input, shape)


 def _permute(input, perm):
-    dev = MakeDevice(inputs=[input]); n_perm = len(perm)
-    key = 'Permute/{}/n_perm:{}'.format(dev, n_perm)
-    module = get_module(Permute, key, dev, n_perm=n_perm)
+    dev = MakeDevice(inputs=[input]); nperm = len(perm)
+    key = 'Permute/{}/nperm:{}'.format(dev, nperm)
+    module = get_module(Permute, key, dev, nperm=nperm)
     return module.forward(input, perm)


 def _repeat(input, times):
-    dev = MakeDevice(inputs=[input]); n_times = len(times)
-    key = 'Repeat/{}/n_times:{}'.format(dev, n_times)
-    module = get_module(Repeat, key, dev, n_times=n_times)
+    dev = MakeDevice(inputs=[input]); ntimes = len(times)
+    key = 'Repeat/{}/ntimes:{}'.format(dev, ntimes)
+    module = get_module(Repeat, key, dev, ntimes=ntimes)
     return module.forward(input, times)


 def _fill(input, shape, value):
-    dev = MakeDevice(inputs=[input]); n_dim = len(shape)
-    key = 'Fill/{}/dtype:{}/n_dim:{}/value:{}'.format(
-        dev, input.dtype, n_dim, value)
-    module = get_module(Fill, key, dev, n_dim=n_dim,
-        value=value, dtype=input.dtype)
+    dev = MakeDevice(inputs=[input]); ndim = len(shape)
+    key = 'Fill/{}/dtype:{}/ndim:{}/value:{}' \
+        .format(dev, input.dtype, ndim, value)
+    module = get_module(
+        Fill, key, dev,
+        ndim=ndim,
+        value=value,
+        dtype=input.dtype,
+    )
     return module.forward(input, shape)


 def _uniform(input, shape, low, high):
-    dev = MakeDevice(inputs=[input]); n_dim = len(shape)
-    key = 'Uniform/{}/dtype:{}/n_dim:{}/low:{}/high:{}'.format(
-        dev, input.dtype, n_dim, float(low), float(high))
-    module = get_module(
-        RandomUniform, key, dev, n_dim=n_dim,
-        low=low, high=high, dtype=input.dtype)
+    dev = MakeDevice(inputs=[input]); ndim = len(shape)
+    key = 'Uniform/{}/dtype:{}/ndim:{}/low:{}/high:{}'.format(
+        dev, input.dtype, ndim, float(low), float(high))
+    module = get_module(
+        RandomUniform, key, dev,
+        ndim=ndim,
+        low=low,
+        high=high,
+        dtype=input.dtype,
+    )
     return module.forward(input, shape)


 def _normal(input, shape, mean, std):
-    dev = MakeDevice(inputs=[input]); n_dim = len(shape)
-    key = 'Normal/{}/dtype:{}/n_dim:{}/mean:{}/std:{}'.format(
-        dev, input.dtype, n_dim, float(mean), float(std))
-    module = get_module(
-        RandomNormal, key, dev, n_dim=n_dim,
-        mean=mean, std=std, dtype=input.dtype)
+    dev = MakeDevice(inputs=[input]); ndim = len(shape)
+    key = 'Normal/{}/dtype:{}/ndim:{}/mean:{}/std:{}'.format(
+        dev, input.dtype, ndim, float(mean), float(std))
+    module = get_module(
+        RandomNormal, key, dev,
+        ndim=ndim,
+        mean=mean,
+        std=std,
+        dtype=input.dtype,
+    )
     return module.forward(input, shape)

@@ -464,44 +481,62 @@ def _reduce(input, operation, dim=None, keepdim=False, out=None):
     key = '{}/{}/dim:{}/keepdim:{}'.format(
         operation, dev, dim, int(keepdim))
-    module = get_module(
-        Reduce, key, dev, operation=operation,
-        dim=dim, keepdim=keepdim)
+    module = get_module(
+        Reduce, key, dev,
+        dim=dim,
+        keepdim=keepdim,
+        operation=operation,
+    )
     return module.forward(input, out)


-def _arg_reduce(input, operation, dim=None, keepdim=False, top_k=1, out=None):
+def _arg_reduce(input, operation, dim=None, keepdim=False, topk=1, out=None):
     if dim is None: keepdim = False
     dev = MakeDevice(inputs=[input])
-    key = '{}/{}/dim:{}/keepdim:{}/top_k:{}'.format(
-        operation, dev, dim, int(keepdim), top_k)
+    key = '{}/{}/dim:{}/keepdim:{}/topk:{}'.format(
+        operation, dev, dim, int(keepdim), topk)
     module = get_module(
         ArgReduce, key, dev,
-        operation=operation, axis=dim,
-        keepdim=keepdim, top_k=top_k)
+        axis=dim,
+        topk=topk,
+        keepdim=keepdim,
+        operation=operation,
+    )
     return module.forward(input, out)


-def _indexing(input, starts, sizes):
-    n_starts, n_sizes = len(starts), len(sizes)
+def _index(input, starts, sizes):
+    nstarts, nsizes = len(starts), len(sizes)
     dev = MakeDevice(inputs=[input])
-    key = 'Index/{}/n_starts:{}/n_sizes:{}'.format(dev, n_starts, n_sizes)
-    module = get_module(Indexing, key, dev, n_starts=n_starts, n_sizes=n_sizes)
+    key = 'Index/{}/nstarts:{}/nsizes:{}'.format(dev, nstarts, nsizes)
+    module = get_module(Indexing, key, dev, nstarts=nstarts, nsizes=nsizes)
     return module.forward(input, starts, sizes)


-def _assigning(output, input, starts, sizes):
+def _assign(output, starts, sizes, input):
     if not isinstance(input, Tensor):
         if isinstance(input, (tuple, list)):
             input = Tensor(input, dtype=output.dtype, device=output.device)
         else:
             input = WrapScalar(input, output.dtype, output.device)
-    n_starts, n_sizes = len(starts), len(sizes)
+    nstarts, nsizes = len(starts), len(sizes)
     dev = MakeDevice(inputs=[input])
-    key = 'Assign/{}/n_starts:{}/n_sizes:{}'.format(dev, n_starts, n_sizes)
-    module = get_module(Assigning, key, dev, n_starts=n_starts, n_sizes=n_sizes)
+    key = 'Assign/{}/nstarts:{}/nsizes:{}'.format(dev, nstarts, nsizes)
+    module = get_module(Assign, key, dev, nstarts=nstarts, nsizes=nsizes)
     return module.forward(input, output, starts, sizes)


+def _masked_assign(output, mask, input):
+    if not isinstance(input, Tensor):
+        if isinstance(input, (tuple, list)):
+            input = Tensor(input, dtype=output.dtype, device=output.device)
+        else:
+            input = WrapScalar(input, output.dtype, output.device)
+    dev = MakeDevice(inputs=[input])
+    key = 'MaskedAssign/{}'.format(dev)
+    module = get_module(MaskedAssign, key, dev)
+    return module.forward(input, output, mask)
+
+
 def _compare(input, other, operation, out=None):
     if not isinstance(other, Tensor):
         other = WrapScalar(other, input.dtype, input.device)

@@ -927,7 +962,7 @@ def narrow(input, dimension, start, length):
     """
     sizes = list(input.shape[:]); starts = [0] * len(sizes)
     starts[dimension], sizes[dimension] = start, length
-    return _indexing(input, starts, sizes)
+    return _index(input, starts, sizes)


 def one_hot(input, depth):

@@ -1159,8 +1194,13 @@ def _update(
 ):
     dev = MakeDevice(inputs=[param])
     key = '{}/{}/{}/{}'.format(op_type, dev, slot, param.name)
-    module = get_module(Update, key, dev, op_type=op_type,
-        lr_mult=lr_mult, decay_mult=decay_mult, slot=slot)
+    module = get_module(
+        Update, key, dev,
+        op_type=op_type,
+        lr_mult=lr_mult,
+        decay_mult=decay_mult,
+        slot=slot,
+    )
     return module.forward(param, grad)

@@ -1183,8 +1223,12 @@ def _resize_2d(input, op_type, dsize, fx, fy):
     dev = MakeDevice(inputs=[input])
     key = '{}/{}/dsize:{}/fx:{}/fy:{}'.format(
         op_type, dev, '2' if dsize else 'none', fx, fy)
-    module = get_module(Resize2d, key, dev,
-        op_type=op_type, dsize=dsize, fx=fx, fy=fy)
+    module = get_module(
+        Resize2d, key, dev,
+        dsize=dsize,
+        fx=fx,
+        fy=fy,
+        op_type=op_type,
+    )
     return module.forward(input, dsize)
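Like the sliced _assign helper, the new _masked_assign accepts the value as a scalar, a Python sequence, or a Tensor; non-Tensor values are first wrapped to the output's dtype and device before the MaskedAssign module is run. A hedged sketch of the three call forms, assuming x, mask and other are existing dragon.vm.torch tensors of the same shape:

from dragon.vm.torch.ops.builtin import _masked_assign

_masked_assign(x, mask, 0.0)         # scalar   -> WrapScalar(0.0, x.dtype, x.device)
_masked_assign(x, mask, [1.0, 2.0])  # sequence -> Tensor([...], dtype=x.dtype, device=x.device)
_masked_assign(x, mask, other)       # Tensor   -> used as-is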
Dragon/python/dragon/vm/torch/ops/modules/array.py

@@ -27,8 +27,8 @@ class Indexing(BaseModule):
     """
     def __init__(self, key, dev, **kwargs):
         super(Indexing, self).__init__(key, dev, **kwargs)
-        self.n_starts = kwargs.get('n_starts', 0)
-        self.n_sizes = kwargs.get('n_sizes', 0)
+        self.nstarts = kwargs.get('nstarts', 0)
+        self.nsizes = kwargs.get('nsizes', 0)
         self.register_op()

     def register_op(self):

@@ -37,61 +37,25 @@ class Indexing(BaseModule):
             'arguments': {
                 'starts_desc': [
                     '${{ANCHOR}}/starts[{}]'.format(n)
-                        for n in range(self.n_starts)],
+                        for n in range(self.nstarts)],
                 'sizes_desc': [
                     '${{ANCHOR}}/sizes[{}]'.format(n)
-                        for n in range(self.n_sizes)],
+                        for n in range(self.nsizes)],
             },
         }

-    def update_arguments(self, A, starts, sizes):
+    def update_args(self, A, starts, sizes):
         for i, e in enumerate(starts):
-            self.set_argument_i64('{}/starts[{}]'.format(A, i), e)
-            self.set_argument_i64('{}/sizes[{}]'.format(A, i), sizes[i])
+            self.set_arg_i64('{}/starts[{}]'.format(A, i), e)
+            self.set_arg_i64('{}/sizes[{}]'.format(A, i), sizes[i])

     def forward(self, x, starts, sizes):
         inputs = [x]; self.unify_devices(inputs)
         outputs = [self.register_output()]
-        callback = lambda A: self.update_arguments(A, starts, sizes)
+        callback = lambda A: self.update_args(A, starts, sizes)
         return self.run(inputs, outputs, callback=callback)


-class Assigning(BaseModule):
-    """This module imports the *AssignOp* from backend.
-
-    Arbitrary length of starts and sizes will be take.
-
-    """
-    def __init__(self, key, dev, **kwargs):
-        super(Assigning, self).__init__(key, dev, **kwargs)
-        self.n_starts = kwargs.get('n_starts', 0)
-        self.n_sizes = kwargs.get('n_sizes', 0)
-        self.register_op()
-
-    def register_op(self):
-        self.op_meta = {
-            'op_type': 'Assign',
-            'arguments': {
-                'starts_desc': [
-                    '${{ANCHOR}}/starts[{}]'.format(n)
-                        for n in range(self.n_starts)],
-                'sizes_desc': [
-                    '${{ANCHOR}}/sizes[{}]'.format(n)
-                        for n in range(self.n_sizes)],
-            },
-        }
-
-    def update_arguments(self, A, starts, sizes):
-        for i, e in enumerate(starts):
-            self.set_argument_i64('{}/starts[{}]'.format(A, i), e)
-            self.set_argument_i64('{}/sizes[{}]'.format(A, i), sizes[i])
-
-    def forward(self, x, y, starts, sizes):
-        self.unify_devices([x, y])
-        callback = lambda A: self.update_arguments(A, starts, sizes)
-        return self.run([x], [y], callback=callback, auto_grad=False)
-
-
 class Concat(BaseModule):
     """This module imports the *ConcatOp* from backend.

@@ -200,18 +164,19 @@ class ArgReduce(BaseModule):
         self.operation = kwargs.get('operation', 'ARGMAX')
         self.axis = kwargs.get('axis', None)
         self.keepdim = kwargs.get('keepdim', True)
-        self.top_k = kwargs.get('top_k', 1)
+        self.topk = kwargs.get('topk', 1)
         self.register_op()

     def register_op(self):
         self.op_meta = {
             'op_type': 'ArgReduce',
             'arguments': {
                 'operation': self.operation if 'ARG' in self.operation \
                     else 'ARG' + self.operation,
                 'axis': self.axis if self.axis else 2147483647,
                 'keep_dims': self.keepdim,
-                'top_k': self.top_k,
+                'top_k': self.topk,
             },
         }

@@ -241,7 +206,7 @@ class ArgReduce(BaseModule):
 class Reshape(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(Reshape, self).__init__(key, dev, **kwargs)
-        self.n_dim = kwargs.get('n_dim', 0)
+        self.ndim = kwargs.get('ndim', 0)
         self.register_op()

@@ -250,19 +215,19 @@ class Reshape(BaseModule):
             'arguments': {
                 'dims_desc': [
                     '${{ANCHOR}}/dims[{}]'.format(n)
-                        for n in range(self.n_dim)
+                        for n in range(self.ndim)
                 ],
             },
         }

-    def update_arguments(self, A, shape):
+    def update_args(self, A, shape):
         for i, e in enumerate(shape):
-            self.set_argument_i64('{}/dims[{}]'.format(A, i), e)
+            self.set_arg_i64('{}/dims[{}]'.format(A, i), e)

     def forward(self, x, shape):
         inputs = [x]; self.unify_devices(inputs)
         outputs = [_ReferenceTensor(x)]
-        callback = lambda A: self.update_arguments(A, shape)
+        callback = lambda A: self.update_args(A, shape)
         return self.run(inputs, outputs, callback=callback)

@@ -275,7 +240,9 @@ class Squeeze(BaseModule):
     def register_op(self):
         self.op_meta = {
             'op_type': 'Squeeze',
-            'arguments': {'axis': self.dim},
+            'arguments': {
+                'axis': self.dim,
+            },
         }

     def forward(self, x, out=None):

@@ -293,7 +260,9 @@ class UnSqueeze(BaseModule):
     def register_op(self):
         self.op_meta = {
             'op_type': 'ExpandDims',
-            'arguments': {'axis': self.dim},
+            'arguments': {
+                'axis': self.dim,
+            },
         }

     def forward(self, x, out=None):

@@ -305,7 +274,7 @@ class UnSqueeze(BaseModule):
 class Permute(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(Permute, self).__init__(key, dev, **kwargs)
-        self.n_perm = kwargs.get('n_perm', 0)
+        self.nperm = kwargs.get('nperm', 0)
         self.register_op()

@@ -313,26 +282,26 @@ class Permute(BaseModule):
             'op_type': 'Transpose',
             'arguments': {
                 'perm_desc': ['${{ANCHOR}}/perm[{}]'.format(n)
-                    for n in range(self.n_perm)],
+                    for n in range(self.nperm)],
             },
         }

-    def update_arguments(self, A, perm):
+    def update_args(self, A, perm):
         if perm:
             for i, e in enumerate(perm):
-                self.set_argument_i64('{}/perm[{}]'.format(A, i), e)
+                self.set_arg_i64('{}/perm[{}]'.format(A, i), e)

     def forward(self, x, perm):
         inputs = [x]; self.unify_devices(inputs)
         outputs = [self.register_output()]
-        callback = lambda A: self.update_arguments(A, perm)
+        callback = lambda A: self.update_args(A, perm)
         return self.run(inputs, outputs, callback=callback)


 class Repeat(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(Repeat, self).__init__(key, dev, **kwargs)
-        self.n_times = kwargs.get('n_times', 0)
+        self.ntimes = kwargs.get('ntimes', 0)
         self.register_op()

@@ -341,19 +310,19 @@ class Repeat(BaseModule):
             'arguments': {
                 'multiples_desc': [
                     '${{ANCHOR}}/multiples[{}]'.format(n)
-                        for n in range(self.n_times)
+                        for n in range(self.ntimes)
                 ],
             },
         }

-    def update_arguments(self, A, times):
+    def update_args(self, A, times):
         for i, d in enumerate(times):
-            self.set_argument_i64('{}/multiples[{}]'.format(A, i), d)
+            self.set_arg_i64('{}/multiples[{}]'.format(A, i), d)

     def forward(self, x, times):
         inputs = [x]; self.unify_devices(inputs)
         outputs = [self.register_output()]
-        callback = lambda A: self.update_arguments(A, times)
+        callback = lambda A: self.update_args(A, times)
         return self.run(inputs, outputs, callback=callback)

@@ -409,7 +378,6 @@ class Multinomial(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(Multinomial, self).__init__(key, dev, **kwargs)
         self.num_samples = kwargs.get('num_samples', 1)
-        self.normalize = kwargs.get('normalize', False)
         self.register_op()

@@ -417,7 +385,7 @@ class Multinomial(BaseModule):
             'op_type': 'Multinomial',
             'arguments': {
                 'num_samples': self.num_samples,
-                'normalize': self.normalize,
+                'normalize': False,
             },
         }
Dragon/python/dragon/vm/torch/ops/modules/base.py

@@ -14,9 +14,9 @@ from __future__ import division
 from __future__ import print_function

 import numpy

-from dragon.core import proto_utils as _proto_utils
 from dragon.core import workspace as _workspace
+from dragon.core import proto_utils as _proto_utils
 from dragon.vm.torch.module import Module

@@ -25,10 +25,14 @@ class BaseModule(Module):
         super(BaseModule, self).__init__()
         self._module_key = key
         self._device = dev
-        self._args_dev = _proto_utils. \
-            GetDeviceOption('cpu').SerializeToString()
+        self._arg_dev = _proto_utils \
+            .GetDeviceOption('cpu') \
+            .SerializeToString()

-    def set_argument_i64(self, name, value):
-        _workspace.get_default_workspace() \
-            .FeedTensor(name, numpy.array(
-                value, dtype=numpy.int64), self._args_dev)
+    def set_arg_i64(self, name, value):
+        _workspace.get_default_workspace() \
+            .FeedTensor(
+                name,
+                numpy.array(value, 'int64'),
+                self._arg_dev,
+            )
Dragon/python/dragon/vm/torch/ops/modules/control_flow.py

@@ -47,3 +47,52 @@ class Compare(BaseModule):
         inputs = [x1, x2]; self.unify_devices(inputs)
         outputs = [y] if y else [self.register_output()]
         return self.run(inputs, outputs)
+
+
+class Assign(BaseModule):
+    """This module imports the *AssignOp* from backend.
+
+    Arbitrary length of starts and sizes will be take.
+
+    """
+    def __init__(self, key, dev, **kwargs):
+        super(Assign, self).__init__(key, dev, **kwargs)
+        self.nstarts = kwargs.get('nstarts', 0)
+        self.nsizes = kwargs.get('nsizes', 0)
+        self.register_op()
+
+    def register_op(self):
+        self.op_meta = {
+            'op_type': 'Assign',
+            'arguments': {
+                'starts_desc': [
+                    '${{ANCHOR}}/starts[{}]'.format(n)
+                        for n in range(self.nstarts)],
+                'sizes_desc': [
+                    '${{ANCHOR}}/sizes[{}]'.format(n)
+                        for n in range(self.nsizes)],
+            },
+        }
+
+    def update_args(self, A, starts, sizes):
+        for i, e in enumerate(starts):
+            self.set_arg_i64('{}/starts[{}]'.format(A, i), e)
+            self.set_arg_i64('{}/sizes[{}]'.format(A, i), sizes[i])
+
+    def forward(self, x, y, starts, sizes):
+        self.unify_devices([x, y])
+        callback = lambda A: self.update_args(A, starts, sizes)
+        return self.run([x], [y], callback=callback, auto_grad=False)
+
+
+class MaskedAssign(BaseModule):
+    def __init__(self, key, dev, **kwargs):
+        super(MaskedAssign, self).__init__(key, dev, **kwargs)
+        self.register_op()
+
+    def register_op(self):
+        self.op_meta = {'op_type': 'MaskedAssign', 'arguments': {}}
+
+    def forward(self, x, y, mask):
+        self.unify_devices([x, y])
+        return self.run([x, mask], [y])
Dragon/python/dragon/vm/torch/ops/modules/init.py

@@ -19,16 +19,16 @@ from dragon.vm.torch.ops.modules.base import BaseModule
 class _InitModule(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(_InitModule, self).__init__(key, dev, **kwargs)
-        self.n_dim = kwargs.get('n_dim', 0)
+        self.ndim = kwargs.get('ndim', 0)
         self.dtype = kwargs.get('dtype', 'float32')

-    def update_arguments(self, A, shape):
+    def update_args(self, A, shape):
         for i, e in enumerate(shape):
-            self.set_argument_i64('{}/dims[{}]'.format(A, i), e)
+            self.set_arg_i64('{}/dims[{}]'.format(A, i), e)

     def forward(self, x, shape):
         outputs = [x]; self.unify_devices(outputs)
-        callback = lambda A: self.update_arguments(A, shape)
+        callback = lambda A: self.update_args(A, shape)
         return self.run([], outputs, callback=callback)

@@ -46,7 +46,7 @@ class Fill(_InitModule):
                 'value': float(self.value),
                 'dims_desc': [
                     '${{ANCHOR}}/dims[{}]'.format(n)
-                        for n in range(self.n_dim)
+                        for n in range(self.ndim)
                 ],
             },
         }

@@ -68,7 +68,7 @@ class RandomNormal(_InitModule):
                 'std': float(self.std),
                 'dims_desc': [
                     '${{ANCHOR}}/dims[{}]'.format(n)
-                        for n in range(self.n_dim)
+                        for n in range(self.ndim)
                 ],
             },
         }

@@ -90,7 +90,7 @@ class RandomUniform(_InitModule):
                 'high': float(self.high),
                 'dims_desc': [
                     '${{ANCHOR}}/dims[{}]'.format(n)
-                        for n in range(self.n_dim)
+                        for n in range(self.ndim)
                 ],
             },
         }
Dragon/python/dragon/vm/torch/ops/modules/update.py

@@ -13,7 +13,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import dragon.core.mpi as mpi
+from dragon.core import mpi as _mpi

 from dragon.vm.torch.ops.modules.base import BaseModule

@@ -50,11 +50,13 @@ class Collective(BaseModule):
         self.register_op()

     def register_op(self):
-        idx, group = mpi.AllowParallel()
+        idx, group = _mpi.AllowParallel()
         if idx == -1:
-            raise RuntimeError('The mpi node({}) dost not in '
-                'parallel groups. \nSet it using mpi.Parallel([..]).'.format(mpi.Rank()))
-        mpi_comm, mpi_group = mpi.CreateGroup(root=group[0], incl=group)
+            raise RuntimeError(
+                'The mpi node({}) dost not in groups. \n'
+                'Set it using mpi.Parallel([..]).'.format(_mpi.Rank())
+            )
+        mpi_comm, mpi_group = _mpi.CreateGroup(root=group[0], incl=group)
         self.op_meta = {
             'op_type': 'CollectiveUpdate',
             'arguments': {

@@ -78,7 +80,10 @@ class Accumulate(BaseModule):
     def register_op(self):
         self.op_meta = {
             'op_type': 'Accumulate',
-            'arguments': {'alpha': 1., 'beta': 1.},
+            'arguments': {
+                'alpha': 1.,
+                'beta': 1.,
+            },
         }

     def forward(self, grads):
Dragon/python/dragon/vm/torch/ops/modules/vision.py

@@ -19,10 +19,10 @@ from dragon.vm.torch.ops.modules.base import BaseModule
 class Resize2d(BaseModule):
     def __init__(self, key, dev, **kwargs):
         super(Resize2d, self).__init__(key, dev, **kwargs)
-        self.op_type = kwargs.get('op_type', 'NNResize')
         self.dsize = kwargs.get('dsize', None)
         self.fx = kwargs.get('fx', None)
         self.fy = kwargs.get('fy', None)
+        self.op_type = kwargs.get('op_type', 'NNResize')
         self.register_op()

     def register_op(self):

@@ -38,15 +38,15 @@ class Resize2d(BaseModule):
             },
         }

-    def update_arguments(self, A, dsize):
+    def update_args(self, A, dsize):
         if self.dsize:
             for i, e in enumerate(dsize):
-                self.set_argument_i64('{}/dsize[{}]'.format(A, i), e)
+                self.set_arg_i64('{}/dsize[{}]'.format(A, i), e)

     def forward(self, input, dsize=None):
         inputs = [input]; self.unify_devices(inputs)
         outputs = [self.register_output()]
-        callback = lambda A: self.update_arguments(A, dsize)
+        callback = lambda A: self.update_args(A, dsize)
         return self.run(inputs, outputs, callback=callback)

@@ -62,7 +62,8 @@ class RoIPool(BaseModule):
         self.op_meta = {
             'op_type': 'ROIPool',
             'arguments': {
-                'pool_h': self.pool_h, 'pool_w': self.pool_w,
+                'pool_h': self.pool_h,
+                'pool_w': self.pool_w,
                 'spatial_scale': self.spatial_scale,
             },
         }

@@ -86,7 +87,8 @@ class RoIAlign(BaseModule):
         self.op_meta = {
             'op_type': 'ROIAlign',
             'arguments': {
-                'pool_h': self.pool_h, 'pool_w': self.pool_w,
+                'pool_h': self.pool_h,
+                'pool_w': self.pool_w,
                 'spatial_scale': self.spatial_scale,
                 'sampling_ratio': self.sampling_ratio,
             },
Dragon/python/dragon/vm/torch/ops/tensor.py
View file @
bd84b7f
...
@@ -23,9 +23,9 @@ from dragon.vm.torch.ops.builtin import (
     _fundamental, _rfundamental,
     log, exp, sqrt, clamp,
     _reshape, squeeze, unsqueeze,
     _permute, _repeat,
-    narrow, _indexing, _assigning,
-    _index, index_select,
+    narrow, index_select,
+    _assign, _masked_assign,
     mean, sum, max, min,
     gt, lt, eq, ge, le,
 )
...
@@ -41,6 +41,7 @@ def _type_to(input, dtype='float32', inplace=False):
 Tensor.fill_ = lambda self, value: _fill(self, self.shape, value)
+Tensor.masked_fill_ = lambda *args, **kwargs: _masked_assign(*args, **kwargs)
 Tensor.uniform_ = lambda self, low=0, high=1: _uniform(self, self.shape, low, high)
 Tensor.normal_ = lambda self, mean=0, std=1: _normal(self, self.shape, mean, std)
 Tensor.multinomial = lambda *args, **kwargs: multinomial(*args, **kwargs)
...
@@ -85,8 +86,8 @@ Tensor.le = lambda *args, **kwargs: le(*args, **kwargs)
 Tensor.eq = lambda *args, **kwargs: eq(*args, **kwargs)
 Tensor.index_select = lambda *args, **kwargs: index_select(*args, **kwargs)
 Tensor.narrow = lambda *args, **kwargs: narrow(*args, **kwargs)
-Tensor._indexing = lambda *args, **kwargs: _indexing(*args, **kwargs)
-Tensor._assigning = lambda *args, **kwargs: _assigning(*args, **kwargs)
+Tensor._index = lambda *args, **kwargs: _index(*args, **kwargs)
+Tensor._assign = lambda *args, **kwargs: _assign(*args, **kwargs)
 Tensor.half = lambda self: _type_to(self, dtype='float16', inplace=False)
...
Dragon/python/dragon/vm/torch/tensor.py
...
@@ -533,16 +533,16 @@ class Tensor(object):
         """
         starts, sizes = self._process_indices(item)
-        return self._indexing(starts, sizes)
+        return self._index(starts, sizes)

     def __setitem__(self, key, value):
         """Set the value at the specific indices.

         Parameters
         ----------
-        key : int, slice
+        key : int, slice or dragon.vm.torch.Tensor
             The indices.
-        value : dragon.vm.torch.Tensor, number or sequence
+        value : number, sequence or dragon.vm.torch.Tensor
             The value.

         Returns
...
@@ -550,8 +550,11 @@ class Tensor(object):
...
@@ -550,8 +550,11 @@ class Tensor(object):
None
None
"""
"""
if
isinstance
(
key
,
Tensor
):
return
self
.
masked_fill_
(
key
,
value
)
else
:
starts
,
sizes
=
self
.
_process_indices
(
key
)
starts
,
sizes
=
self
.
_process_indices
(
key
)
return
self
.
_assigning
(
value
,
starts
,
sizes
)
return
self
.
_assign
(
starts
,
sizes
,
value
)
def
__hash__
(
self
):
def
__hash__
(
self
):
return
id
(
self
)
return
id
(
self
)
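To make the new __setitem__ dispatch concrete, here is a small self-contained Python sketch. _TensorSketch is a stand-in written for illustration only (not the real dragon.vm.torch.Tensor), but it mirrors the branch added above: a Tensor key routes to masked_fill_, anything else goes through _process_indices-style slicing and _assign.

    class _TensorSketch(object):
        """Illustrative stand-in for dragon.vm.torch.Tensor (1-D only)."""

        def __init__(self, data):
            self.data = list(data)

        def masked_fill_(self, mask, value):
            # Write ``value`` wherever ``mask`` is non-zero, keep the rest.
            self.data = [value if m else d for m, d in zip(mask.data, self.data)]
            return self

        def _assign(self, starts, sizes, value):
            # Contiguous slice assignment, mirroring Tensor._assign(starts, sizes, value).
            start, size = starts[0], sizes[0]
            self.data[start:start + size] = [value] * size
            return self

        def __setitem__(self, key, value):
            if isinstance(key, _TensorSketch):
                self.masked_fill_(key, value)       # Tensor key -> masked assignment
            else:
                start = key.start or 0
                size = (key.stop or len(self.data)) - start
                self._assign([start], [size], value)

    x = _TensorSketch([1, 2, 3, 4])
    x[_TensorSketch([0, 1, 0, 1])] = 0   # mask key  -> [1, 0, 3, 0]
    x[1:3] = 9                           # slice key -> [1, 9, 9, 0]
    print(x.data)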
...
@@ -886,7 +889,7 @@ class Tensor(object):
         return self

     def fill_(self, value):
-        """Fills self tensor with the specified value.
+        """Fill self with the given value.

         Parameters
         ----------
...
@@ -901,6 +904,24 @@ class Tensor(object):
         """
         raise NotImplementedError('Refer torch.ops.tensor.fill_')

+    def masked_fill_(self, mask, value):
+        """Fill self with the given value where ``mask`` is *1*.
+
+        Parameters
+        ----------
+        mask : dragon.vm.torch.Tensor
+            The mask.
+        value : number
+            The value to fill.
+
+        Returns
+        -------
+        dragon.vm.torch.Tensor
+            The self.
+
+        """
+        raise NotImplementedError('Refer torch.ops.tensor.masked_fill_')
+
     def zero_(self):
         """Fills self tensor with zeros.
...
Dragon/src/core/mixedmem.cc
...
@@ -123,7 +123,7 @@ void MixedMemory::SwitchToCUDADevice(int device_id) {
     if (device_id != ptr_device_) {
         // Move the memory to another device
         void* new_ptr_ = nullptr;
-        DeviceGuard gurad(device_id);
+        CUDADeviceGuard gurad(device_id);
         new_ptr_ = CUDAContext::New(nbytes_);
         CUDAContext::MemcpyEx<CUDAContext, CUDAContext>(
             nbytes_, new_ptr_, cuda_ptr_, ptr_device_);
...
Dragon/src/core/operator_schema.cc
...
@@ -4,24 +4,31 @@ namespace dragon {
 bool OpSchema::Verify(const OperatorDef& def) const {
     if (ignore_verify_) return true;
-    string indicator = "[" + def.name() + ", " + def.type() + "]\n";
-    if (def.input_size() < min_input_ || def.input_size() > max_input_) {
-        LOG(FATAL) << indicator << "Input size: " << def.input_size()
+    auto header = "[" + def.name() + ", " + def.type() + "]\n";
+    if (def.input_size() < min_input_ ||
+        def.input_size() > max_input_) {
+        LOG(FATAL) << header << "Input size: " << def.input_size()
             << " is not in range [min=" << min_input_
             << ", max=" << max_input_ << "]";
     }
-    if (def.output_size() < min_output_ || def.output_size() > max_output_) {
-        LOG(FATAL) << indicator << "Output size: " << def.output_size()
+    if (def.output_size() < min_output_ ||
+        def.output_size() > max_output_) {
+        LOG(FATAL) << header << "Output size: " << def.output_size()
             << " is not in range [min=" << min_output_
             << ", max=" << max_output_ << "]";
     }
-    for (int in = 0; in < def.input_size(); in++) {
-        if (def.input(in) == "NULL") continue;
-        for (int out = 0; out < def.output_size(); out++) {
-            if (def.output(out) == "NULL") continue;
-            if (def.input(in) == def.output(out) && (!CheckInplace(in, out)))
-                LOG(FATAL) << indicator << "Input(" << in << ") and "
-                    << "Output(" << out << ") can not be set to inplace.";
+    for (int i = 0; i < def.input_size(); ++i) {
+        if (def.input(i) == "NULL") continue;
+        for (int j = 0; j < def.output_size(); ++j) {
+            if (def.output(j) == "NULL") continue;
+            if (def.input(i) == def.output(j) &&
+                !CheckInplace(i, j))
+                LOG(FATAL) << header << "Input(" << i << ") and Output(" << j << ") "
+                    << "can not be set to inplace.";
         }
     }
     return true;
...
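A compact pure-Python rendering of the in-place rule enforced by Verify above (names here are illustrative, not part of Dragon): an input may alias an output only when that (input, output) index pair has been whitelisted, which is what CheckInplace tests in the C++ code.

    def verify_inplace(inputs, outputs, allowed_pairs):
        # Mirrors the nested loop in OpSchema::Verify: "NULL" slots are skipped,
        # and name aliasing is only legal for whitelisted (i, j) pairs.
        for i, x in enumerate(inputs):
            if x == 'NULL':
                continue
            for j, y in enumerate(outputs):
                if y == 'NULL':
                    continue
                if x == y and (i, j) not in allowed_pairs:
                    raise ValueError(
                        'Input({}) and Output({}) can not be set to inplace.'.format(i, j))
        return True

    print(verify_inplace(['X', 'NULL'], ['X'], allowed_pairs={(0, 0)}))  # True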
Dragon/src/kernels/control_flow/assign_op_kernel.cu
...
@@ -54,7 +54,7 @@ __global__ void _Assign(
         const T* x, \
         T* y, \
         CUDAContext* ctx) { \
-    _Assign<T> \
+    _Assign \
         << < CUDA_BLOCKS(count), CUDA_THREADS, \
             0, ctx->cuda_stream() >> >( \
             count, \
...
Dragon/src/kernels/control_flow/masked_assign_op_kernel.cc (new file)
#include "utils/op_kernel.h"
#include "utils/math_utils.h"
#include "utils/omp_alternative.h"
namespace
dragon
{
namespace
kernel
{
/* <T = ?, Device = CPU> */
template
<
typename
T
>
void
_MaskedAssign
(
const
int
count
,
const
uint8_t
*
mask
,
const
T
*
x
,
T
*
y
)
{
#ifdef WITH_OMP
#pragma omp parallel for num_threads(OMP_THREADS(count))
#endif
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
y
[
i
]
=
mask
[
i
]
?
x
[
i
]
:
y
[
i
];
}
}
/* Kernel Launchers */
#define DEFINE_ASSIGN_KERNEL_LAUNCHER(T) \
template<> void MaskedAssign<T, CPUContext>( \
const int count, \
const uint8_t* mask, \
const T* x, \
T* y, \
CPUContext* ctx) { \
_MaskedAssign(count, mask, x, y); \
}
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
bool
);
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
int8_t
);
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
uint8_t
);
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
int
);
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
int64_t
);
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
float16
);
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
float
);
DEFINE_ASSIGN_KERNEL_LAUNCHER
(
double
);
#undef DEFINE_ASSIGN_KERNEL_LAUNCHER
}
// namespace kernel
}
//
namepsace
dragon
\ No newline at end of file
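For readers who prefer not to trace the C++ above, an equivalent pure-Python formulation of the kernel's semantics (illustrative only; the real kernel runs over a flat buffer and may parallelize with OpenMP):

    def masked_assign(mask, x, y):
        # y[i] is overwritten by x[i] wherever mask[i] is non-zero.
        return [xi if mi else yi for mi, xi, yi in zip(mask, x, y)]

    print(masked_assign([1, 0, 1, 0], [9, 9, 9, 9], [1, 2, 3, 4]))  # -> [9, 2, 9, 4]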
Dragon/src/kernels/control_flow/masked_assign_op_kernel.cu (new file)
#ifdef WITH_CUDA

#include "core/context_cuda.h"
#include "utils/op_kernel.h"

namespace dragon {

namespace kernel {

/* <T = ?, Device = CUDA> */

template<typename T>
__global__ void _MaskedAssign(
    const int nthreads,
    const uint8_t* mask,
    const T* x,
    T* y) {
    CUDA_1D_KERNEL_LOOP(i, nthreads) {
        y[i] = mask[i] ? x[i] : y[i];
    }
}

/* Kernel Launchers */

#define DEFINE_ASSIGN_KERNEL_LAUNCHER(T) \
    template<> void MaskedAssign<T, CUDAContext>( \
        const int count, \
        const uint8_t* mask, \
        const T* x, \
        T* y, \
        CUDAContext* ctx) { \
        _MaskedAssign \
            << < CUDA_BLOCKS(count), CUDA_THREADS, \
                0, ctx->cuda_stream() >> >( \
                count, mask, x, y \
        ); \
    }

DEFINE_ASSIGN_KERNEL_LAUNCHER(bool);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(uint8_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int);
DEFINE_ASSIGN_KERNEL_LAUNCHER(int64_t);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float16);
DEFINE_ASSIGN_KERNEL_LAUNCHER(float);
DEFINE_ASSIGN_KERNEL_LAUNCHER(double);

#undef DEFINE_ASSIGN_KERNEL_LAUNCHER

}  // namespace kernel

}  // namespace dragon

#endif  // WITH_CUDA
Dragon/src/operators/control_flow/assign_op.cc
...
@@ -146,7 +146,10 @@ DEPLOY_CUDA(Assign);
 #endif

 OPERATOR_SCHEMA(Assign)
-    .NumInputs(1).NumOutputs(1);
+    /* V */
+    .NumInputs(1)
+    /* X */
+    .NumOutputs(1);

 NO_GRADIENT(Assign);
...
Dragon/src/operators/control_flow/masked_assign_op.cc (new file)
#include "core/workspace.h"
#include "utils/op_kernel.h"
#include "utils/math_utils.h"
#include "utils/math_functions.h"
#include "operators/control_flow/masked_assign_op.h"
namespace
dragon
{
template
<
class
Context
>
template
<
typename
T
>
void
MaskedAssignOp
<
Context
>::
RunImpl
()
{
const
T
*
x
=
nullptr
;
auto
*
mask
=
X
(
1
).
template
raw_data
<
Context
>
();
auto
*
y
=
Y
(
0
)
->
template
mutable_data
<
T
,
Context
>
();
if
(
X
(
0
).
count
()
<
Y
(
0
)
->
count
())
{
int
rows
,
cols
;
auto
*
scratch
=
ws
()
->
template
data
<
T
,
Context
>
({
Y
(
0
)
->
count
()
})[
0
];
auto
*
rx
=
X
(
0
).
template
data
<
T
,
Context
>
();
if
(
utils
::
IsRowwiseBroadcast
(
Y
(
0
)
->
dims
(),
X
(
0
).
dims
(),
&
rows
,
&
cols
))
{
math
::
BroadcastSet
(
rows
,
cols
,
0
,
rx
,
scratch
,
ctx
()
);
}
else
if
(
utils
::
IsColwiseBroadcast
(
Y
(
0
)
->
dims
(),
X
(
0
).
dims
(),
&
rows
,
&
cols
))
{
math
::
BroadcastSet
(
rows
,
cols
,
1
,
rx
,
scratch
,
ctx
()
);
}
else
{
LOG
(
FATAL
)
<<
"Could not broadcast "
<<
X
(
0
).
DimString
()
<<
" to "
<<
Y
(
0
)
->
DimString
();
}
x
=
scratch
;
}
else
if
(
X
(
0
).
count
()
==
Y
(
0
)
->
count
())
{
x
=
X
(
0
).
template
data
<
T
,
Context
>
();
}
else
{
LOG
(
FATAL
)
<<
"Could not assign "
<<
X
(
0
).
DimString
()
<<
" to "
<<
Y
(
0
)
->
DimString
();
}
kernel
::
MaskedAssign
(
Y
(
0
)
->
count
(),
(
const
uint8_t
*
)
mask
,
x
,
y
,
ctx
()
);
}
template
<
class
Context
>
void
MaskedAssignOp
<
Context
>::
RunOnDevice
()
{
CHECK_EQ
(
X
(
1
).
count
(),
Y
(
0
)
->
count
())
<<
"
\n
Size of mask and input should be equal."
;
CHECK
(
XIsType
(
X
(
1
),
bool
)
||
XIsType
(
X
(
1
),
uint8_t
))
<<
"
\n
Excepted bool or uint8 mask."
;
if
(
XIsType
(
X
(
0
),
bool
))
{
RunImpl
<
bool
>
();
}
else
if
(
XIsType
(
X
(
0
),
int8_t
))
{
RunImpl
<
int8_t
>
();
}
else
if
(
XIsType
(
X
(
0
),
uint8_t
))
{
RunImpl
<
uint8_t
>
();
}
else
if
(
XIsType
(
X
(
0
),
int
))
{
RunImpl
<
int
>
();
}
else
if
(
XIsType
(
X
(
0
),
int64_t
))
{
RunImpl
<
int64_t
>
();
}
else
if
(
XIsType
(
X
(
0
),
float16
))
{
RunImpl
<
float16
>
();
}
else
if
(
XIsType
(
X
(
0
),
float
))
{
RunImpl
<
float
>
();
}
else
if
(
XIsType
(
X
(
0
),
double
))
{
RunImpl
<
double
>
();
}
else
{
LOG
(
FATAL
)
<<
DTypeString
(
X
(
0
),
{
"bool"
,
"int8"
,
"uint8"
,
"int32"
,
"int64"
,
"float16"
,
"float32"
,
"float64"
,
});
}
}
DEPLOY_CPU
(
MaskedAssign
);
#ifdef WITH_CUDA
DEPLOY_CUDA
(
MaskedAssign
);
#endif
OPERATOR_SCHEMA
(
MaskedAssign
)
/* V, M */
.
NumInputs
(
2
)
/* X */
.
NumOutputs
(
1
);
NO_GRADIENT
(
MaskedAssign
);
}
//
namespace
dragon
\ No newline at end of file