79 Device* device_ptr =
nullptr;
82 layer* layer_ptr_ =
nullptr;
85 Params* params_ptr_ =
nullptr;
88 bool parallelize =
false;
90 backend_t engine = default_engine();
96 op_params_ = std::unique_ptr<OpParams>(
new OpParams());
100 const std::vector<tensor_t*>&
out_data,
102 std::vector<tensor_t*>&
in_grad)
107 op_params_ = std::unique_ptr<OpParams>(
new OpParams());
110 tensor_t& input(
const int idx)
const {
111 return *in_data_[
idx];
114 tensor_t& output(
const int idx)
const {
115 return *out_data_[idx];
118 tensor_t& input_grad(
const int idx)
const {
119 return *in_grad_[idx];
122 tensor_t& output_grad(
const int idx)
const {
123 return *out_grad_[idx];
126 void setParams(Params* params) {
127 op_params_->params_ptr_ = params;
130 Params* params()
const {
131 return op_params_->params_ptr_;
134 void setParallelize(
const bool parallelize) {
135 op_params_->parallelize = parallelize;
138 bool parallelize()
const {
139 return op_params_->parallelize;
142 void setDevice(Device* device) {
143 op_params_->device_ptr = device;
146 Device* device()
const {
147 return op_params_->device_ptr;
150 void setLayer(layer* layer) {
151 op_params_->layer_ptr_ = layer;
154 layer* Layer()
const {
155 return op_params_->layer_ptr_;
158 backend_t engine()
const {
159 return op_params_->engine;
162 void setEngine(
const backend_t engine) {
163 op_params_->engine = engine;
167 std::vector<tensor_t*> in_data_;
168 std::vector<tensor_t*> out_data_;
169 std::vector<tensor_t*> out_grad_;
170 std::vector<tensor_t*> in_grad_;
172 std::unique_ptr<OpParams> op_params_;
Definition op_kernel.h:72
Definition op_kernel.h:74