#include "operators/vision/nn_resize_op.h" #include "core/workspace.h" #include "utils/math_functions.h" #include "utils/op_kernel.h" namespace dragon { template <class Context> template <typename T> void NNResizeOp<Context>::RunWithType() { if (data_format == "NCHW") { n = input(0).dim(0); c = input(0).dim(1); h = input(0).dim(2); w = input(0).dim(3); out_h = output(0)->dim(2); out_w = output(0)->dim(3); } else if (data_format == "NHWC") { n = input(0).dim(0); h = input(0).dim(1); w = input(0).dim(2); c = input(0).dim(3); out_h = output(0)->dim(1); out_w = output(0)->dim(2); } auto* Xdata = input(0).template data<T, Context>(); auto* Ydata = output(0)->template mutable_data<T, Context>(); kernel::NNResize<T, Context>(output(0)->count(), n, c, h, w, out_h, out_w, data_format, Xdata, Ydata); } template <class Context> void NNResizeOp<Context>::RunOnDevice() { vector<TIndex> dims = input(0).dims(); if (dynamic_dsize.size() > 0) { CHECK_EQ(dynamic_dsize.size(), 2) << "\nThe dsize should be a scalar with 2 elements."; for (int i = 0; i < 2; i++) { Tensor* t = ws()->GetTensor(dynamic_dsize[i]); if (t->IsType<int>()) { dims[spatial_axis + i] = t->template data<int, CPUContext>()[0]; } else if (t->IsType<float>()) { dims[spatial_axis + i] = t->template data<float, CPUContext>()[0]; } else { LOG(FATAL) << "Unsupported types of dsize."; } } } else if (static_dsize.size() > 0) { CHECK_EQ(static_dsize.size(), 2) << "\nThe dsize should be a scalar with 2 elements."; for (int i = 0; i < 2; i++) dims[spatial_axis + i] = static_dsize[i]; } else { CHECK(fy != -1.0 && fx != -1.0) << "\nThe fx and fy should be set."; dims[spatial_axis] = int(dims[spatial_axis] * fy); dims[spatial_axis + 1] = int(dims[spatial_axis + 1] * fx); } output(0)->Reshape(dims); if (input(0).template IsType<float>()) return RunWithType<float>(); else LOG(FATAL) << "Unsupported input types."; } DEPLOY_CPU(NNResize); #ifdef WITH_CUDA DEPLOY_CUDA(NNResize); #endif OPERATOR_SCHEMA(NNResize).NumInputs(1).NumOutputs(1); template <class Context> template <typename T> void NNResizeGradientOp<Context>::RunWithType() { if (data_format == "NCHW") { n = input(0).dim(0); c = input(0).dim(1); h = input(0).dim(2); w = input(0).dim(3); out_h = input(-1).dim(2); out_w = input(-1).dim(3); } else if (data_format == "NHWC") { n = input(0).dim(0); h = input(0).dim(1); w = input(0).dim(2); c = input(0).dim(3); out_h = input(-1).dim(1); out_w = input(-1).dim(2); } auto* dYdata = input(-1).template data<T, Context>(); auto* dXdata = output(0)->template mutable_data<T, Context>(); kernel::NNResizeGrad<T, Context>(input(-1).count(), n, c, h, w, out_h, out_w, data_format, dYdata, dXdata); } template <class Context> void NNResizeGradientOp<Context>::RunOnDevice() { output(0)->ReshapeLike(input(0)); if (input(0).template IsType<float>()) return RunWithType<float>(); else LOG(FATAL) << "Unsupported input types."; } DEPLOY_CPU(NNResizeGradient); #ifdef WITH_CUDA DEPLOY_CUDA(NNResizeGradient); #endif OPERATOR_SCHEMA(NNResizeGradient).NumInputs(2).NumOutputs(1); class GetNNResizeGradient final : public GradientMakerBase { public: GRADIENT_MAKER_CTOR(GetNNResizeGradient); vector<OperatorDef> MakeDefs() override { return SingleDef(def.type() + "Gradient", "", vector<string> {I(0), GO(0)}, vector<string> {GI(0)}); } }; REGISTER_GRADIENT(NNResize, GetNNResizeGradient); } // namespace dragon