1 #ifndef _OPERATORS_REF_ACTIVATION_FUN_OP_H_
2 #define _OPERATORS_REF_ACTIVATION_FUN_OP_H_
/*
 * relu: per-element activation over the first `input_size` floats of
 * `inputs`; `results` is presumably the output buffer.
 * NOTE(review): only the loop header and the read of `value` survive in this
 * copy; the write to results[i] and the closing braces are missing. By its
 * name this is presumably results[i] = max(value, 0) -- confirm against the
 * original source before relying on it.
 */
22 static inline void relu(
float* inputs,
float* results,
int input_size) {
24 for (
int i = 0; i < input_size; i++) {
25 float value = inputs[i];
/*
 * lrelu: per-element activation. activation_fun calls it as
 *   lrelu(inputs, results, inputs_size, params.slope).
 * The visible branch writes slope * value into results[i].
 * NOTE(review): the remaining parameter lines (results, input_size, slope),
 * the condition that selects this branch, the other branch, and the closing
 * braces are missing from this copy. Presumably a leaky ReLU (slope applied
 * to negative inputs, identity otherwise) -- confirm against the original.
 */
35 static inline void lrelu(
float* inputs,
40 for (
int i = 0; i < input_size; i++) {
41 float value = inputs[i];
43 results[i] = slope * value;
/*
 * elu: per-element activation. activation_fun calls it as
 *   elu(inputs, results, inputs_size, params.alpha).
 * The visible branch writes alpha * (exp(value) - 1) into results[i]
 * (exp works in double; the result is narrowed to float on store).
 * NOTE(review): the remaining parameter lines, the branch condition, the
 * other branch, and the closing braces are missing from this copy;
 * presumably the standard ELU (identity for positive inputs) -- confirm.
 * NOTE(review): expm1(value) is more accurate than exp(value) - 1 when
 * value is close to 0 -- consider switching.
 */
51 static inline void elu(
float* inputs,
56 for (
int i = 0; i < input_size; i++) {
57 float value = inputs[i];
59 results[i] = alpha * (exp(value) - 1);
/*
 * selu: computes elu(inputs, results, input_size, alpha) into `results`,
 * then multiplies every element of `results` by `lambda`,
 * i.e. results[i] = lambda * elu(inputs[i], alpha).
 * activation_fun calls it as
 *   selu(inputs, results, inputs_size, params.alpha, params.lambda);
 * the alpha/lambda values come from the caller's params.
 * NOTE(review): the parameter lines after `inputs` and the closing braces
 * are missing from this copy.
 */
67 static inline void selu(
float* inputs,
72 elu(inputs, results, input_size, alpha);
/* Second pass: scale the ELU output in place by lambda. */
74 for (
int i = 0; i < input_size; i++) {
75 results[i] = lambda * results[i];
/*
 * sigmoid: element-wise logistic function,
 *   results[i] = 1 / (1 + exp(-inputs[i]))   for 0 <= i < input_size.
 * Evaluated in double (1.0 literals, exp) and narrowed to float on store.
 * Each element is read before it is written at the same index, so calling
 * it in place (inputs == results) is safe; tanh_act does exactly that.
 * NOTE(review): the closing braces are missing from this copy.
 */
80 static inline void sigmoid(
float* inputs,
float* results,
int input_size) {
82 for (
int i = 0; i < input_size; i++) {
83 results[i] = 1.0 / (1.0 + exp(-inputs[i]));
/*
 * tanh_act: element-wise hyperbolic tangent via the identity
 *   tanh(x) = 2 * sigmoid(2x) - 1,
 * computed in three in-place passes over `results`:
 *   results = 2 * inputs;  results = sigmoid(results);  results = 2 * results - 1.
 * NOTE(review): the intermediate sigmoid value is stored as float, so for
 * small |x| the final "2 * s - 1" cancels badly (for |x| below a few times
 * 1e-8 the result rounds to exactly 0). tanhf() from <math.h> avoids this --
 * consider using it.
 * NOTE(review): the declaration of the loop index `i` and the closing braces
 * are missing from this copy.
 */
88 static inline void tanh_act(
float* inputs,
float* results,
int input_size) {
91 for (i = 0; i < input_size; i++) {
92 results[i] = 2 * inputs[i];
/* In-place sigmoid of 2x (sigmoid reads each element before writing it). */
95 sigmoid(results, results, input_size);
/* Map sigmoid(2x) from (0, 1) onto tanh's range (-1, 1). */
98 for (i = 0; i < input_size; i++) {
99 results[i] = 2 * results[i] - 1;
/*
 * hard_tanh_act: element-wise clamp of inputs to [min, max]:
 *   results[i] = min        if inputs[i] < min,
 *                max        if inputs[i] > max,
 *                inputs[i]  otherwise.
 * NaN inputs fail both comparisons and are passed through unchanged.
 * Each value is read before results[i] is written, so in-place use
 * (inputs == results) is safe.
 * NOTE(review): the closing braces are missing from this copy.
 */
104 static inline void hard_tanh_act(
105 float* inputs,
float* results,
int input_size,
float min,
float max) {
107 for (
int i = 0; i < input_size; i++) {
108 float value = inputs[i];
109 results[i] = (value < min) ? min : (value > max) ? max : value;
114 static inline void activation_fun(
float* inputs,
119 if (
function == RELU) {
120 relu(inputs, results, inputs_size);
121 }
else if (
function == LRELU) {
122 lrelu(inputs, results, inputs_size, params.slope);
123 }
else if (
function == ELU) {
124 elu(inputs, results, inputs_size, params.alpha);
125 }
else if (
function == SELU) {
126 selu(inputs, results, inputs_size, params.alpha, params.lambda);
127 }
else if (
function == TANH) {
128 tanh_act(inputs, results, inputs_size);
129 }
else if (
function == HARD_TANH) {
130 hard_tanh_act(inputs, results, inputs_size, params.min, params.max);
131 }
else if (
function == SIGMOID) {
132 sigmoid(inputs, results, inputs_size);
133 }
else if (
function == SOFTMAX) {
134 assert(
false &&
"Softmax not added yet!");