5 #include "smaug/operators/smv/kernels/params.h"
7 #include "smaug/operators/smv/kernels/activation_functions_simd.h"
30 v8fp_t scale = recip_sqrt_var * gamma;
31 v8fp_t shift = input - mean;
32 return shift * scale + beta;
42 float16* host_weights,
43 float16* host_results,
54 int inputs_nums = inputs_dims[0];
55 int inputs_acts = inputs_dims[1];
56 int inputs_size = inputs_nums * (inputs_acts + inputs_pad);
57 int weights_size = 4 * (weights_acts + inputs_pad);
58 int results_size = inputs_size;
62 if (inputs_start == 0)
66 VEC_ARRAY_2D(
v8fp_t, _inputs, inputs, inputs_size + inputs_pad);
67 VEC_ARRAY_2D(
v8fp_t, _weights, weights, weights_acts + inputs_pad);
68 VEC_ARRAY_2D(
v8fp_t, _results, results, inputs_size + inputs_pad);
71 for (
int i = 0; i < inputs_nums; i++) {
73 for (
int j = 0; j < weights_acts /
VECTOR_SIZE; j++) {
74 _results[i][j + inputs_start_vec] =
83 if (act_function != NO_ACTIVATION && send_results) {
85 results, results, results_size, act_function, act_params);
101 float16* host_weights,
102 float16* host_results,
113 int inputs_nums = inputs_dims[0];
114 int inputs_chans = inputs_dims[1];
115 int inputs_rows = inputs_dims[2];
116 int inputs_cols = inputs_dims[3];
117 int inputs_size = inputs_nums * inputs_chans * inputs_rows *
118 (inputs_cols + inputs_pad);
119 int weights_size = 4 * (weights_chans + weights_pad);
120 int results_size = inputs_size;
121 int weights_start_vec = weights_start /
VECTOR_SIZE;
125 if (weights_start == 0)
133 inputs_cols + inputs_pad);
134 VEC_ARRAY_2D(
v8fp_t, _weights, weights, weights_chans + weights_pad);
140 inputs_cols + inputs_pad);
143 for (
int i = 0; i < inputs_nums; i++) {
148 float mean = _weights[0][h + weights_start_vec][v];
149 float recip_sqrt_var = _weights[1][h + weights_start_vec][v];
150 float gamma = _weights[2][h + weights_start_vec][v];
151 float beta = _weights[3][h + weights_start_vec][v];
152 v8fp_t mean_vec = { mean, mean, mean, mean,
153 mean, mean, mean, mean };
154 v8fp_t recip_sqrt_var_vec = { recip_sqrt_var, recip_sqrt_var,
155 recip_sqrt_var, recip_sqrt_var,
156 recip_sqrt_var, recip_sqrt_var,
157 recip_sqrt_var, recip_sqrt_var };
158 v8fp_t gamma_vec = { gamma, gamma, gamma, gamma,
159 gamma, gamma, gamma, gamma };
160 v8fp_t beta_vec = { beta, beta, beta, beta,
161 beta, beta, beta, beta };
165 for (
int r = 0; r < inputs_rows; r++) {
169 _results[i][ofmap][r][c] =
180 if (act_function != NO_ACTIVATION) {
182 results, results, results_size, act_function, act_params);
197 float16* host_weights,
198 float16* host_results,
210 int inputs_nums = inputs_dims[0];
211 int inputs_rows = inputs_dims[1];
212 int inputs_cols = inputs_dims[2];
213 int inputs_chans = inputs_dims[3];
214 int inputs_size = inputs_nums * inputs_rows * inputs_cols *
215 (inputs_chans + inputs_pad);
216 int weights_size = 4 * (weights_chans + weights_pad);
217 int results_size = inputs_size;
218 int weights_start_vec = weights_start /
VECTOR_SIZE;
223 if (weights_start == 0)
226 VEC_ARRAY_4D(
v8fp_t, _inputs, inputs, inputs_rows, inputs_cols,
227 inputs_chans + inputs_pad);
228 VEC_ARRAY_2D(
v8fp_t, _weights, weights, weights_chans + weights_pad);
229 VEC_ARRAY_4D(
v8fp_t, _results, results, inputs_rows, inputs_cols,
230 inputs_chans + inputs_pad);
234 int batch_sample = inputs_nums;
235 int chan_sample = inputs_chans_vec;
236 int row_sample = inputs_rows;
237 int col_sample = inputs_cols;
239 if (sampling->
level >= VeryHigh) {
240 batch_sample = min2(batch_sample, sample_num);
241 chan_sample = min2(chan_sample, sample_num);
242 row_sample = min2(row_sample, sample_num);
243 col_sample = min2(col_sample, sample_num);
245 setSamplingFactor(
"bn_batch", inputs_nums * 1.0 / batch_sample);
246 setSamplingFactor(
"bn_chan", inputs_chans_vec * 1.0 / chan_sample);
247 setSamplingFactor(
"bn_row", inputs_rows * 1.0 / row_sample);
248 setSamplingFactor(
"bn_col", inputs_cols * 1.0 / col_sample);
251 for (
int i = 0; i < batch_sample; i++) {
253 for (
int h = 0; h < chan_sample; h++) {
254 v8fp_t mean = _weights[0][h + weights_start_vec];
255 v8fp_t recip_sqrt_var = _weights[1][h + weights_start_vec];
256 v8fp_t gamma = _weights[2][h + weights_start_vec];
257 v8fp_t beta = _weights[3][h + weights_start_vec];
259 for (
int r = 0; r < row_sample; r++) {
261 for (
int c = 0; c < col_sample; c++) {
262 _results[i][r][c][h] =
272 if (act_function != NO_ACTIVATION) {
274 results, results, results_size, act_function, act_params);