5 #include "smaug/operators/smv/kernels/params.h"
7 #include "smaug/operators/smv/kernels/activation_functions_simd.h"
61 float16* host_results,
79 int a_width = a_dims[1];
80 int a_height = a_dims[0];
81 int b_width = b_dims[1];
82 int b_height = b_dims[0];
83 int results_width = results_dims[1];
84 int results_height = results_dims[0];
86 "Width of b must be a multiple of VECTOR_SIZE!");
89 int a_size = a_height * (a_width + a_pad);
90 int b_size = b_height * (b_width + b_pad);
91 int results_size = results_height * (results_width + results_pad);
94 VEC_ARRAY_2D(
v8fp_t, _a, a, a_width + a_pad);
95 VEC_ARRAY_2D(
v8fp_t, _b, b, b_width + b_pad);
96 VEC_ARRAY_2D(
v8fp_t, _results, results, results_width + results_pad);
105 int b_col_sample = b_width_vec;
106 int b_col_total_iters =
FRAC_CEIL(b_width_vec, NUM_MACC_INSTS);
107 int b_col_sample_iters = b_col_total_iters;
109 if (sampling->
level >= VeryHigh) {
111 b_col_sample_iters = min2(b_col_sample_iters, max2(2, sample_num));
112 b_col_sample = b_col_sample_iters * NUM_MACC_INSTS;
114 setSamplingFactor(
"b_col", b_col_total_iters * 1.0 / b_col_sample_iters);
117 for (
int a_act = 0; a_act < a_height; a_act++) {
119 for (
int b_row = 0; b_row < b_height; b_row += NUM_PE_INSTS) {
122 partial_sums = _results[a_act][(result_start + b_row) /
130 for (
int b_col = 0; b_col < b_col_sample; b_col += NUM_MACC_INSTS) {
138 0, 0, 0, 0, 0, 0, 0, 0
141 v8fp_t a_reg[NUM_MACC_INSTS];
143 for (
int a_vec = 0; a_vec < NUM_MACC_INSTS; a_vec++) {
146 a_col >= a_width_vec ? zero : _a[a_act][a_col];
150 for (
int pe_id = 0; pe_id < NUM_PE_INSTS; pe_id++) {
151 v8fp_t b_reg[NUM_MACC_INSTS];
153 for (
int macc_idx = 0; macc_idx < NUM_MACC_INSTS;
155 int pe_row = b_row + pe_id;
156 int this_b_col = b_col + macc_idx;
158 (pe_row >= b_height ||
159 this_b_col >= b_width_vec)
161 : _b[pe_row][b_col + macc_idx];
164 v8fp_t product_reg[NUM_MACC_INSTS];
166 for (
int macc_idx = 0; macc_idx < NUM_MACC_INSTS;
168 product_reg[macc_idx] =
169 a_reg[macc_idx] * b_reg[macc_idx];
172 v8fp_t accum_vec_reg = zero;
174 for (
int macc_idx = 0; macc_idx < NUM_MACC_INSTS;
176 accum_vec_reg += product_reg[macc_idx];
181 for (
int vec_i = 0; vec_i <
VECTOR_SIZE; vec_i++) {
182 accum_reg += accum_vec_reg[vec_i];
184 partial_sums_inner[pe_id] += accum_reg;
187 for (
int i = 0; i < NUM_PE_INSTS; i++) {
188 partial_sums[i] += partial_sums_inner[i];
192 int next_b_row = b_row + NUM_PE_INSTS;
193 if (next_b_row %
VECTOR_SIZE == 0 || next_b_row >= b_height) {
194 _results[a_act][(result_start + b_row) /
VECTOR_SIZE] =
200 if (act_function != NO_ACTIVATION && send_results) {
202 results, results, results_size, act_function, act_params);