5 #include "smaug/operators/smv/kernels/params.h"
7 #include "smaug/operators/smv/kernels/activation_functions_simd.h"
54 float16* host_weights,
55 float16* host_results,
65 int inputs_halo_pad[4],
77 int result_rows = results_dims[1];
78 int result_cols = results_dims[2];
79 int result_height = results_dims[3];
80 int results_size = results_dims[0] * result_rows * result_cols *
81 (result_height + results_pad);
83 int k_rows = weights_dims[1];
84 int k_cols = weights_dims[2];
85 int k_height = weights_dims[3];
86 int k_pad = weights_pad;
87 int weights_size = weights_dims[0] * k_rows * k_cols * (k_height + k_pad);
89 int a_rows = inputs_dims[1];
90 int a_cols = inputs_dims[2];
91 int a_height = inputs_dims[3];
92 int a_pad = inputs_align_pad;
93 int inputs_size = inputs_dims[0] * a_rows * a_cols * (a_height + a_pad);
95 int top_pad = inputs_halo_pad[0];
96 int bottom_pad = inputs_halo_pad[1];
97 int left_pad = inputs_halo_pad[2];
98 int right_pad = inputs_halo_pad[3];
99 int end_row = a_rows + top_pad + bottom_pad - k_rows + 1;
100 int end_col = a_cols + left_pad + right_pad - k_cols + 1;
102 int valid_row_end = a_rows - 1;
103 int valid_col_end = a_cols - 1;
107 const v8fp_t zero = { 0, 0, 0, 0, 0, 0, 0, 0 };
110 VEC_ARRAY_4D(
v8fp_t, _kernels, weights, k_rows, k_cols, k_height + k_pad);
112 VEC_ARRAY_3D(
v8fp_t, _a, inputs, a_cols, a_height + a_pad);
115 v8fp_t, _result, results, result_cols, result_height + results_pad);
116 int num_chan_blocks = (k_height - 1) / pe_depth;
121 int num_eff_kernels = min2(weights_dims[0], result_height);
122 int num_kernel_blocks = (num_eff_kernels - 1) / NUM_PE_INSTS;
131 int pe_block_sample = num_kernel_blocks + 1;
132 int kern_row_sample = k_rows;
133 int kern_col_sample = k_cols;
134 int chan_block_sample = num_chan_blocks + 1;
135 int output_row_sample = end_row;
136 int output_col_sample = end_col;
137 int output_row_total_iters =
FRAC_CEIL(end_row, row_stride);
138 int output_col_total_iters =
FRAC_CEIL(end_col, col_stride);
139 int output_row_sample_iters = output_row_total_iters;
140 int output_col_sample_iters = output_col_total_iters;
142 if (sampling->
level >= Low)
143 pe_block_sample = min2(pe_block_sample, sample_num);
144 if (sampling->
level >= Medium) {
145 kern_row_sample = min2(kern_row_sample, sample_num);
146 kern_col_sample = min2(kern_col_sample, sample_num);
148 if (sampling->
level >= High)
149 chan_block_sample = min2(chan_block_sample, sample_num);
150 if (sampling->
level >= VeryHigh) {
151 output_row_sample_iters = min2(output_row_sample_iters, sample_num);
152 output_row_sample = output_row_sample_iters * row_stride;
154 output_col_sample_iters =
155 min2(output_col_sample_iters, max2(2, sample_num));
156 output_col_sample = output_col_sample_iters * col_stride;
158 setSamplingFactor(
"ofmap_block_iteration",
159 (num_kernel_blocks + 1) * 1.0 / pe_block_sample);
160 setSamplingFactor(
"k_row", k_rows * 1.0 / kern_row_sample);
161 setSamplingFactor(
"k_col", k_cols * 1.0 / kern_col_sample);
163 "pe_iteration", (num_chan_blocks + 1) * 1.0 / chan_block_sample);
164 setSamplingFactor(
"conv3d_row",
165 output_row_total_iters * 1.0 / output_row_sample_iters);
166 setSamplingFactor(
"conv3d_col",
167 output_col_total_iters * 1.0 / output_col_sample_iters);
169 ofmap_block_iteration:
170 for (
int ofmap_iters = 0; ofmap_iters < pe_block_sample;
172 int ofmap_offset = ofmap_iters * NUM_PE_INSTS;
174 int kEffNumPeInsts = min2(result_height - ofmap_offset, NUM_PE_INSTS);
177 for (
int kern_row = 0; kern_row < kern_row_sample; kern_row++) {
179 for (
int kern_col = 0; kern_col < kern_col_sample;
184 for (
int ifmap_iters = 0; ifmap_iters < chan_block_sample;
186 bool start_from_zero = (!accumulate && kern_row == 0 &&
187 kern_col == 0 && ifmap_iters == 0);
188 int ifmap_offset = (ifmap_start + ifmap_iters * pe_depth) /
190 int kern_chan_offset =
194 int max_ch_grp = NUM_MACC_INSTS;
197 if (ifmap_iters == num_chan_blocks) {
199 FRAC_CEIL((k_height - ifmap_iters * pe_depth),
205 v8fp_t kernel_reg[NUM_PE_INSTS][NUM_MACC_INSTS] = {
206 { zero }, { zero }, { zero }, { zero },
207 { zero }, { zero }, { zero }, { zero }
211 for (
int pe_id = 0; pe_id < kEffNumPeInsts; pe_id++) {
213 for (
int macc_idx = 0; macc_idx < NUM_MACC_INSTS;
215 kernel_reg[pe_id][macc_idx] =
216 (macc_idx >= max_ch_grp)
218 : _kernels[kern_start +
219 ofmap_offset + pe_id]
227 for (
int out_row = 0; out_row < output_row_sample;
228 out_row += row_stride) {
236 for (
int out_col = 0; out_col < output_col_sample;
237 out_col += col_stride) {
240 v8fp_t smv_conv_product_reg[NUM_PE_INSTS]
242 v8fp_t act_reg[NUM_MACC_INSTS];
243 results_buffer = start_from_zero
245 : _result[out_i][out_j]
247 in_row = out_row - top_pad + kern_row;
248 in_col = out_col - left_pad + kern_col;
249 bool in_padding_row =
250 in_row < 0 || in_row > valid_row_end;
251 bool in_padding_col =
252 in_col < 0 || in_col > valid_col_end;
257 for (
int macc_idx = 0; macc_idx < NUM_MACC_INSTS;
259 bool is_padding = in_padding_row ||
261 macc_idx >= max_ch_grp;
266 [ifmap_offset + macc_idx];
269 v8fp_t accum_vec_reg[NUM_PE_INSTS] = {
270 zero, zero, zero, zero, zero, zero, zero, zero
272 float accum_reg[NUM_PE_INSTS] = { 0, 0, 0, 0,
275 for (
int pe_id = 0; pe_id < kEffNumPeInsts;
278 for (
int macc_idx = 0;
279 macc_idx < NUM_MACC_INSTS;
281 smv_conv_product_reg[pe_id][macc_idx] =
282 kernel_reg[pe_id][macc_idx] *
286 for (
int macc_idx = 0;
287 macc_idx < NUM_MACC_INSTS;
289 accum_vec_reg[pe_id] +=
290 smv_conv_product_reg[pe_id]
297 accum_vec_reg[pe_id][vec_i];
299 results_buffer[pe_id] += accum_reg[pe_id];
303 _result[out_i][out_j][ofmap_iters] = results_buffer;
314 if (act_function != NO_ACTIVATION && send_results) {
316 results, results, results_size, act_function, act_params);