#include "operators/misc/python_op.h"

namespace dragon {

namespace {

// Invokes `self.<method>(inputs, outputs)` through the Python C API.
// PyObject_CallMethod returns NULL when the Python method raises. In that case
// the pending exception is printed (PyErr_Print also clears it) and the
// operator aborts, rather than continuing as if the outputs were produced.
// On success the returned object is a new reference and is released here.
void CallPyMethod(PyObject* self, const char* method,
                  PyObject* inputs, PyObject* outputs) {
    // const_cast: Python 2's C API declares these parameters as `char*`.
    PyObject* result = PyObject_CallMethod(
        self, const_cast<char*>(method), const_cast<char*>("OO"),
        inputs, outputs);
    if (result == nullptr) PyErr_Print();
    CHECK(result) << "\nFail to call method: " << method;
    Py_DECREF(result);
}

}  // namespace

// Imports the Python module named by the "module" argument, instantiates the
// class named by "op" with no arguments, and stores "param_str" on it
// (as both `param_str` and `param_str_`). The input/output tensor names are
// handed to Python as two lists. If this operator is allowed to run and the
// object defines `setup`, it is called once as setup(inputs, outputs).
// NOTE(review): if String() (declared in the header) returns a new reference,
// the objects it creates for PyObject_SetAttr/PyObject_HasAttr are never
// released in this file; confirm against its definition.
template <class Context>
RunOp<Context>::RunOp(const OperatorDef& op_def, Workspace* ws)
    : Operator<Context>(op_def, ws),
      module(OperatorBase::GetSingleArg<string>("module", "")),
      op(OperatorBase::GetSingleArg<string>("op", "")),
      param_str(OperatorBase::GetSingleArg<string>("param_str", "")) {
    // init interpreter & load module
    Py_Initialize();
    PyObject* py_module = PyImport_ImportModule(module.c_str());
    // Surface the Python traceback (e.g. the ImportError) before aborting.
    if (py_module == nullptr) PyErr_Print();
    CHECK(py_module) << "\nFail to import py module: " << module;
    PyObject* py_dict = PyModule_GetDict(py_module);
    PyObject* py_op = PyDict_GetItemString(py_dict, op.c_str());
    CHECK(py_op) << "\nFail to import operator: " << op
                 << " from module: " << module;

    // Instantiate the operator; a raising __init__ yields NULL.
    self = PyObject_CallObject(py_op, nullptr);
    if (self == nullptr) PyErr_Print();
    CHECK(self) << "\nFail to create operator: " << op
                << " from module: " << module;

    // pass param string
    PyObject_SetAttr(self, String("param_str"), String(param_str.c_str()));
    PyObject_SetAttr(self, String("param_str_"), String(param_str.c_str()));

    // build inputs and outputs for Python
    // (PyList_SetItem steals the reference to each item.)
    inputs = PyList_New(InputSize());
    for (int i = 0; i < InputSize(); i++)
        PyList_SetItem(inputs, i, String(input(i).name().c_str()));
    outputs = PyList_New(OutputSize());
    for (int i = 0; i < OutputSize(); i++)
        PyList_SetItem(outputs, i, String(output(i)->name().c_str()));

    if (!this->allow_run()) return;

    // setup
    if (PyObject_HasAttr(self, String("setup")))
        CallPyMethod(self, "setup", inputs, outputs);
}

// Forward pass: publishes the current phase on the Python object, calls its
// optional `reshape`, then `forward` if defined, otherwise `run` if defined.
template <class Context>
void RunOp<Context>::RunOnDevice() {
    // init phase
    PyObject_SetAttr(self, String("phase"), String(this->phase().c_str()));

    // reshape
    if (PyObject_HasAttr(self, String("reshape")))
        CallPyMethod(self, "reshape", inputs, outputs);

    // run
    if (PyObject_HasAttr(self, String("forward"))) {
        CallPyMethod(self, "forward", inputs, outputs);
    } else if (PyObject_HasAttr(self, String("run"))) {
        CallPyMethod(self, "run", inputs, outputs);
    }
}

DEPLOY_CPU(Run);
#ifdef WITH_CUDA
DEPLOY_CUDA(Run);
#endif
OPERATOR_SCHEMA(Run);

NO_GRADIENT(Run);

// Backward pass: publishes the current phase on the Python object, calls its
// optional `reshape`, then `backward` if defined, otherwise `grad` if defined.
template <class Context>
void TemplateGradientOp<Context>::RunOnDevice() {
    // init phase
    PyObject_SetAttr(this->self, String("phase"),
                     String(this->phase().c_str()));

    // reshape
    if (PyObject_HasAttr(this->self, String("reshape")))
        CallPyMethod(this->self, "reshape", this->inputs, this->outputs);

    // run: the method invoked must be the one whose presence was checked.
    if (PyObject_HasAttr(this->self, String("backward"))) {
        CallPyMethod(this->self, "backward", this->inputs, this->outputs);
    } else if (PyObject_HasAttr(this->self, String("grad"))) {
        CallPyMethod(this->self, "grad", this->inputs, this->outputs);
    }
}

DEPLOY_CPU(Template);
#ifdef WITH_CUDA
DEPLOY_CUDA(Template);
#endif
OPERATOR_SCHEMA(Template);

DEPLOY_CPU(TemplateGradient);
#ifdef WITH_CUDA
DEPLOY_CUDA(TemplateGradient);
#endif
OPERATOR_SCHEMA(TemplateGradient);

// Builds a single "<type>Gradient" op. Its inputs are all forward inputs
// followed by GO(i) for each forward output. Its outputs are GI(i) for each
// forward input. (GO/GI presumably name the output/input gradient tensors;
// confirm in GradientMakerBase.)
class GetTemplateGradient final : public GradientMakerBase {
 public:
    GRADIENT_MAKER_CTOR(GetTemplateGradient);
    vector<OperatorDef> MakeDefs() override {
        vector<string> inputs, outputs;
        for (const auto& input : def.input()) inputs.push_back(input);
        for (int i = 0; i < def.output_size(); i++) inputs.push_back(GO(i));
        for (int i = 0; i < def.input_size(); i++) outputs.push_back(GI(i));
        return SingleDef(def.type() + "Gradient", "", inputs, outputs);
    }
};
REGISTER_GRADIENT(Template, GetTemplateGradient);

}  // namespace dragon