// --------------------------------------------------------
// Dragon
// Copyright(c) 2017 SeetaTech
// Written by Ting Pan
// --------------------------------------------------------

#ifndef DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_
#define DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_

#include "core/operator.h"

namespace dragon {

template <class Context>
class ReluOp : public Operator<Context> {
 public:
    ReluOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {}

    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    float slope;
};

template <class Context>
class ReluGradientOp : public Operator<Context> {
 public:
    ReluGradientOp(const OperatorDef& op_def, Workspace* ws)
        : Operator<Context>(op_def, ws),
          slope(OperatorBase::GetSingleArg<float>("slope", 0.0)) {
        DISABLE_SHARE_GRADIENT;
    }

    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    float slope;
};

#ifdef WITH_CUDNN

template <class Context>
class CuDNNReluOp final : public ReluOp<Context> {
public:
    CuDNNReluOp(const OperatorDef& op_def, Workspace* ws) 
        : ReluOp<Context>(op_def, ws) {
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
        CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
        CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc, 
            CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
    }
    void RunOnDevice() override;
    template <typename T> void RunWithType(); 

 protected:
    cudnnTensorDescriptor_t input_desc, output_desc;
    cudnnActivationDescriptor_t act_desc;
};

template <class Context>
class CuDNNReluGradientOp final : public ReluGradientOp<Context> {
 public:
    CuDNNReluGradientOp(const OperatorDef& op_def, Workspace* ws)
        : ReluGradientOp<Context>(op_def, ws) {
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&input_desc));
        CUDNN_CHECK(cudnnCreateTensorDescriptor(&output_desc));
        CUDNN_CHECK(cudnnCreateActivationDescriptor(&act_desc));
        CUDNN_CHECK(cudnnSetActivationDescriptor(act_desc,
            CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0));
    }
    void RunOnDevice() override;
    template <typename T> void RunWithType();

 protected:
    cudnnTensorDescriptor_t input_desc, output_desc;
    cudnnActivationDescriptor_t act_desc;
};

#endif // WITH_CUDNN

}    // namespace dragon

#endif    // DRAGON_OPERATORS_ACTIVATION_RELU_OP_H_