1 #include "smaug/core/backend.h"
3 #include "smaug/operators/greater_op.h"
// Reference elementwise "greater than" kernel: for each i in
// [0, input_size), results[i] = (input0[i] > input1[i]).
//
// input0, input1: input_size floats each; results: input_size bools.
// dmaLoad/dmaStore are not defined in this view (presumably they come from
// smaug/core/backend.h — TODO confirm). Each call here is passed the same
// pointer as both of its first two arguments.
//
// NOTE(review): this view of the function is incomplete. No closing brace
// for the loop or the function is visible, so do not restructure the code
// from this extract alone.
13 void ref_greater(
float* input0,
float* input1,
bool* results,
int input_size) {
// dmaLoad each float input with a byte count of input_size floats.
14 dmaLoad(input0, input0, input_size *
sizeof(
float));
15 dmaLoad(input1, input1, input_size *
sizeof(
float));
// Strict comparison: equal elements (and any comparison involving NaN)
// produce false.
17 for (
int i = 0; i < input_size; i++) {
18 results[i] = input0[i] > input1[i];
// dmaStore the results with a byte count of input_size bools.
20 dmaStore(results, results, input_size *
sizeof(
bool));
// NOTE(review): the header of the function containing these statements is
// not visible in this view. These statements mirror the "greater than"
// kernel but use >=:
//   results[i] = (input0[i] >= input1[i]) for i in [0, input_size).
// Presumably this is the kernel invoked by GreaterEqualOp's reference
// run(). TODO: confirm the function name and signature against the full
// file before editing.
// dmaLoad each float input with a byte count of input_size floats.
31 dmaLoad(input0, input0, input_size *
sizeof(
float));
32 dmaLoad(input1, input1, input_size *
sizeof(
float));
// Non-strict comparison: equal elements produce true. Any comparison
// involving NaN still produces false.
34 for (
int i = 0; i < input_size; i++) {
35 results[i] = input0[i] >= input1[i];
// dmaStore the results with a byte count of input_size bools.
37 dmaStore(results, results, input_size *
sizeof(
bool));
// Reference-backend run() for GreaterOp. It fetches the two inputs and the
// output, requires all three to have identical shapes, then takes raw float
// pointers to the inputs and a bool pointer to the output.
//
// NOTE(review): this view is lossy. The following are not visible:
//   - a preceding `template <>`, normally required when defining a member
//     of the GreaterOp<ReferenceBackend> specialization this way;
//   - the heads of the calls that take the storageSize()-based byte counts
//     below;
//   - the head of the final call ending in `input0Shape.size());`;
//   - the closing brace.
// Presumably the final call runs the "greater than" kernel on these
// buffers. TODO: confirm against the full file.
47 void GreaterOp<ReferenceBackend>::run() {
48 auto input0 = getInput(Input0);
49 auto input1 = getInput(Input1);
50 auto output = getOutput(Outputs);
51 const TensorShape& input0Shape = input0->getShape();
52 const TensorShape& input1Shape = input1->getShape();
53 const TensorShape& outputShape = output->getShape();
// Shapes must match exactly; there is no broadcasting. assert() is compiled
// out under NDEBUG, so this check does not run in release builds.
54 assert(input0Shape == input1Shape && input0Shape == outputShape);
// The inputs are read as float data and the output is written as bool data.
56 float* input0Data = input0->data<
float>();
57 float* input1Data = input1->data<
float>();
58 bool* outputData = output->data<
bool>();
// Byte counts for input0 and input1 (floats) and the output (bools), all
// computed from storageSize(). NOTE(review): the call(s) these arguments
// belong to are not visible in this view.
60 input0Shape.storageSize() *
sizeof(
float));
62 input1Shape.storageSize() *
sizeof(
float));
64 outputShape.storageSize() *
sizeof(
bool));
// NOTE(review): the element count passed here is size(), while the byte
// counts above use storageSize(). If storageSize() can include alignment
// padding, the two differ. Confirm the kernel is meant to process exactly
// size() contiguous elements.
66 outputData, input0Shape.size());
// Reference-backend run() for GreaterEqualOp. It fetches the two inputs and
// the output, requires all three to have identical shapes, then takes raw
// float pointers to the inputs and a bool pointer to the output.
//
// NOTE(review): this view is lossy. The following are not visible:
//   - a preceding `template <>`, normally required when defining a member
//     of the GreaterEqualOp<ReferenceBackend> specialization this way;
//   - the heads of the calls that take the storageSize()-based byte counts
//     below;
//   - the head of the final call ending in `input0Shape.size());`;
//   - the closing brace.
// Presumably the final call runs the "greater than or equal" kernel on
// these buffers. TODO: confirm against the full file.
70 void GreaterEqualOp<ReferenceBackend>::run() {
71 auto input0 = getInput(Input0);
72 auto input1 = getInput(Input1);
73 auto output = getOutput(Outputs);
74 const TensorShape& input0Shape = input0->getShape();
75 const TensorShape& input1Shape = input1->getShape();
76 const TensorShape& outputShape = output->getShape();
// Shapes must match exactly; there is no broadcasting. assert() is compiled
// out under NDEBUG, so this check does not run in release builds.
77 assert(input0Shape == input1Shape && input0Shape == outputShape);
// The inputs are read as float data and the output is written as bool data.
79 float* input0Data = input0->data<
float>();
80 float* input1Data = input1->data<
float>();
81 bool* outputData = output->data<
bool>();
// Byte counts for input0 and input1 (floats) and the output (bools), all
// computed from storageSize(). NOTE(review): the call(s) these arguments
// belong to are not visible in this view.
83 input0Shape.storageSize() *
sizeof(
float));
85 input1Shape.storageSize() *
sizeof(
float));
87 outputShape.storageSize() *
sizeof(
bool));
// NOTE(review): the element count passed here is size(), while the byte
// counts above use storageSize(). If storageSize() can include alignment
// padding, the two differ. Confirm the kernel is meant to process exactly
// size() contiguous elements.
89 outputData, input0Shape.size());