1 #include "smaug/core/backend.h"
3 #include "smaug/operators/batch_norm_op.h"
4 #include "smaug/operators/ref/ref_activation_fun_op.h"
5 #include "smaug/utility/debug_stream.h"
// Fragment of a scalar helper; its signature is not visible in this excerpt.
// Computes (input - mean) * (recip_sqrt_var * gamma) + beta.
// NOTE(review): the name recip_sqrt_var suggests callers pass a precomputed
// reciprocal square root of the variance rather than the raw variance --
// confirm at the call sites.
28 float scale = recip_sqrt_var * gamma;
29 float shift = input - mean;
30 return shift * scale + beta;
// Fragment of a batch-norm kernel body for a 2D input. The function signature
// and several interior lines (the head of the per-element assignment, closing
// braces) are not visible in this excerpt.
//
// Sizes below are element counts (floats). inputs_size covers input_nums rows,
// each input_size elements plus input_pad trailing elements.
51 int inputs_size = input_nums * (input_size + input_pad);
// NOTE(review): mean/variance/gamma/beta are only indexed by j < input_size
// in the loop below, yet kernel_size is set to the full inputs_size. If those
// arrays hold fewer than inputs_size floats, the four dmaLoad calls that use
// kernel_size would read past their end -- confirm against the caller.
52 int kernel_size = inputs_size;
53 int result_size = inputs_size;
// Each dmaLoad receives the same pointer for its first two arguments
// (presumably destination and source of a DMA hook -- confirm against its
// declaration); the third argument is a size in bytes.
54 dmaLoad(inputs, inputs, inputs_size *
sizeof(
float));
55 dmaLoad(mean, mean, kernel_size *
sizeof(
float));
56 dmaLoad(variance, variance, kernel_size *
sizeof(
float));
57 dmaLoad(gamma, gamma, kernel_size *
sizeof(
float));
58 dmaLoad(beta, beta, kernel_size *
sizeof(
float));
// ARRAY_2D presumably declares _inputs/_result as 2D views over the flat
// buffers with a row length of input_size + input_pad -- confirm against the
// macro definition.
60 ARRAY_2D(
float, _inputs, inputs, input_size + input_pad);
61 ARRAY_2D(
float, _result, result, input_size + input_pad);
// Visit input_nums rows and the first input_size columns of each row; the
// input_pad tail of a row is not visited by these loops.
64 for (
int i = 0; i < input_nums; i++) {
66 for (
int j = 0; j < input_size; j++) {
// Continuation of a call whose first line is elided here: the input element
// [i][j] is passed together with the j-th mean/variance/gamma/beta entries.
68 _inputs[i][j], mean[j], variance[j], gamma[j], beta[j]);
// When an activation is requested, activation_fun is given result as both of
// its first two arguments (presumably in-place) and all result_size elements,
// which includes the padding since result_size == inputs_size.
71 if (act_function != NO_ACTIVATION) {
72 activation_fun(result, result, result_size, act_function, act_params);
// Counterpart of the dmaLoad calls above for the output buffer.
74 dmaStore(result, result, result_size *
sizeof(
float));
// Fragment of a batch-norm kernel body for a 4D input whose innermost
// dimension is img_cols + img_pad. The function signature, the innermost loop
// body and the closing braces are not visible in this excerpt.
//
// Sizes are element counts (floats). input_size spans
// img_nums x img_chans x img_rows x (img_cols + img_pad).
99 int input_size = img_nums * img_chans * img_rows * (img_cols + img_pad);
// One mean/variance/gamma/beta entry per channel (indexed by h below).
100 int kernel_size = img_chans;
101 int result_size = input_size;
// Each dmaLoad receives the same pointer for its first two arguments
// (presumably destination and source of a DMA hook -- confirm); the third
// argument is a size in bytes.
102 dmaLoad(inputs, inputs, input_size *
sizeof(
float));
103 dmaLoad(mean, mean, kernel_size *
sizeof(
float));
104 dmaLoad(variance, variance, kernel_size *
sizeof(
float));
105 dmaLoad(gamma, gamma, kernel_size *
sizeof(
float));
106 dmaLoad(beta, beta, kernel_size *
sizeof(
float));
// ARRAY_4D presumably declares 4D views over the flat buffers with trailing
// dimensions (img_chans, img_rows, img_cols + img_pad) -- confirm against the
// macro definition.
108 ARRAY_4D(
float, _inputs, inputs, img_chans, img_rows, img_cols + img_pad);
109 ARRAY_4D(
float, _result, result, img_chans, img_rows, img_cols + img_pad);
// Loop nest: image i, channel h, row r, column c. Columns stop at img_cols,
// so the img_pad tail of each row is not visited.
112 for (
int i = 0; i < img_nums; i++) {
114 for (
int h = 0; h < img_chans; h++) {
// Per-channel parameters are read once per (image, channel) pair, outside the
// row/column loops.
115 float mean_val = mean[h];
// NOTE(review): the variance buffer is read straight into a variable named
// recip_sqrt_var_val, so the caller presumably supplies 1/sqrt(var) already
// -- confirm where the variance tensor is produced.
116 float recip_sqrt_var_val = variance[h];
117 float gamma_val = gamma[h];
118 float beta_val = beta[h];
121 for (
int r = 0; r < img_rows; r++) {
123 for (
int c = 0; c < img_cols; c++) {
// When an activation is requested, activation_fun is given result as both of
// its first two arguments (presumably in-place) and all result_size elements,
// padding included.
133 if (act_function != NO_ACTIVATION) {
134 activation_fun(result, result, result_size, act_function, act_params);
// Counterpart of the dmaLoad calls above for the output buffer.
136 dmaStore(result, result, result_size *
sizeof(
float));
// Fragment of a batch-norm kernel body for a 4D input whose innermost
// dimension is img_chans + img_pad. The function signature, the innermost loop
// body and the closing braces are not visible in this excerpt.
//
// Sizes are element counts (floats). input_size spans
// img_nums x img_rows x img_cols x (img_chans + img_pad).
161 int input_size = img_nums * img_rows * img_cols * (img_chans + img_pad);
// One mean/variance/gamma/beta entry per channel (indexed by h below).
162 int kernel_size = img_chans;
163 int result_size = input_size;
// Each dmaLoad receives the same pointer for its first two arguments
// (presumably destination and source of a DMA hook -- confirm); the third
// argument is a size in bytes.
164 dmaLoad(inputs, inputs, input_size *
sizeof(
float));
165 dmaLoad(mean, mean, kernel_size *
sizeof(
float));
166 dmaLoad(variance, variance, kernel_size *
sizeof(
float));
167 dmaLoad(gamma, gamma, kernel_size *
sizeof(
float));
168 dmaLoad(beta, beta, kernel_size *
sizeof(
float));
// ARRAY_4D presumably declares 4D views over the flat buffers with trailing
// dimensions (img_rows, img_cols, img_chans + img_pad) -- confirm against the
// macro definition.
170 ARRAY_4D(
float, _inputs, inputs, img_rows, img_cols, img_chans + img_pad);
171 ARRAY_4D(
float, _result, result, img_rows, img_cols, img_chans + img_pad);
// Loop nest: image i, channel h, row r, column c. The channel loop stops at
// img_chans, so the img_pad tail of the innermost dimension is not visited.
// The element access inside the innermost loop is not visible here.
174 for (
int i = 0; i < img_nums; i++) {
176 for (
int h = 0; h < img_chans; h++) {
// Per-channel parameters are read once per (image, channel) pair, outside the
// row/column loops.
177 float mean_val = mean[h];
// NOTE(review): the variance buffer is read straight into a variable named
// recip_sqrt_var_val, so the caller presumably supplies 1/sqrt(var) already
// -- confirm where the variance tensor is produced.
178 float recip_sqrt_var_val = variance[h];
179 float gamma_val = gamma[h];
180 float beta_val = beta[h];
182 for (
int r = 0; r < img_rows; r++) {
184 for (
int c = 0; c < img_cols; c++) {
// When an activation is requested, activation_fun is given result as both of
// its first two arguments (presumably in-place) and all result_size elements,
// padding included.
194 if (act_function != NO_ACTIVATION) {
195 activation_fun(result, result, result_size, act_function, act_params);
// Counterpart of the dmaLoad calls above for the output buffer.
197 dmaStore(result, result, result_size *
sizeof(
float));
// Reference-backend entry point for the batch-norm operator: fetches the
// operand tensors, obtains raw float pointers, and dispatches to a kernel
// chosen by input rank and layout. Several lines of this function (the heads
// of some calls, the branch structure, the closing brace) are not visible in
// this excerpt.
207 void BatchNormOp<ReferenceBackend>::run() {
// Operand tensors, looked up by their slot identifiers.
208 auto input = getInput(Inputs);
209 auto mean = getInput(Mean);
210 auto variance = getInput(Variance);
211 auto gamma = getInput(Gamma);
212 auto beta = getInput(Beta);
213 auto output = getOutput(Outputs);
214 const TensorShape& inputShape = input->getShape();
// kernelShape is taken from the mean tensor and reused below for the sizes of
// all four parameter tensors.
// NOTE(review): this assumes variance/gamma/beta share the mean's shape --
// confirm.
215 const TensorShape& kernelShape = mean->getShape();
216 const TensorShape& outputShape = output->getShape();
// A 4-dimensional input is treated as the post-convolution case.
217 bool isPostConv = (input->ndims() == 4);
// Stream the parameter tensors to the debug stream (the argument 2 is
// presumably a verbosity level -- confirm against dout's declaration).
218 dout(2) << *mean <<
"\n";
219 dout(2) << *variance<<
"\n";
220 dout(2) << *gamma <<
"\n";
221 dout(2) << *beta <<
"\n";
// Raw float pointers into each tensor's storage.
223 float* inputData = input->data<
float>();
224 float* meanData = mean->data<
float>();
225 float* varianceData = variance->data<
float>();
226 float* gammaData = gamma->data<
float>();
227 float* betaData = beta->data<
float>();
228 float* outputData = output->data<
float>();
// The next six fragments are trailing arguments of calls whose first lines
// are elided in this excerpt. Each is a byte count: storageSize() * sizeof
// (float), using inputShape once, kernelShape four times, then outputShape.
230 inputShape.storageSize() *
sizeof(
float));
232 kernelShape.storageSize() *
sizeof(
float));
234 kernelShape.storageSize() *
sizeof(
float));
236 kernelShape.storageSize() *
sizeof(
float));
238 kernelShape.storageSize() *
sizeof(
float));
240 outputShape.storageSize() *
sizeof(
float));
// Layout check for the 4D case; the line(s) that define `func` from it are
// not visible here.
242 bool isNCHW = input->getShape().getLayout() == NCHW;
// Dispatch the 4D kernel. Dimensions 0..3 of the input shape are forwarded in
// storage order, followed by the padding of dimension 3 for both the input
// and the kernel shape.
// NOTE(review): kernelShape.getPadding(3) presumes the mean tensor has a
// dimension 3 -- confirm its rank.
245 invokeKernel(ref::kBatchNormHw, func, inputData, meanData, varianceData,
246 gammaData, betaData, outputData, inputShape[0],
247 inputShape[1], inputShape[2], inputShape[3],
248 inputShape.getPadding(3), kernelShape.getPadding(3),
249 actInfo.function, actInfo.params);
// Non-4D path: input and output are both required to be in NC layout.
251 assert(inputShape.getLayout() == DataLayout::NC);
252 assert(outputShape.getLayout() == DataLayout::NC);
// Continuation of a call whose first line is elided: forwards the parameter
// and output pointers, the two input dimensions, the padding of dimension 1
// and the activation settings.
254 meanData, varianceData, gammaData, betaData, outputData,
255 inputShape[0], inputShape[1], inputShape.getPadding(1),
256 actInfo.function, actInfo.params);