SMAUG
Simulating Machine Learning Applications on gem5-Aladdin
load_store_fp16_data.c
2 
3 #ifdef __cplusplus
4 extern "C" {
5 #endif
6 
7 void host_load_fp16(float* local_data,
8  float16* remote_data,
9  int num_elems,
10  int local_offset,
11  int remote_offset) {
12  VEC_ARRAY_1D(v8ph_t, _local_data_hp, local_data);
13  VEC_ARRAY_1D(v8fp_t, _local_data_sp, local_data);
14  const int page_size = (1 << LOG_PAGE_SIZE);
15  const int max_transfer_size = page_size;
16  const int total_bytes =
17  next_multiple(num_elems * sizeof(float16), CACHELINE_SIZE);
18  int num_xfers = FRAC_CEIL(total_bytes, max_transfer_size);
19  int num_bytes_remaining = total_bytes;
20  host_fp16_to_fp32:
21  for (int i = 0; i < num_xfers; i++) {
22  int transfer_size = min2(num_bytes_remaining, max_transfer_size);
23  int curr_offset = (i * page_size * 2) / sizeof(float);
24  hostLoad(local_data + local_offset + curr_offset,
25  remote_data + remote_offset + curr_offset,
26  transfer_size);
27 
28  // This loads N bytes of FP16 data into local_data. We now expand
29  // N bytes of half precision to 2*N bytes of single precision, in
30  // place, 32 bytes at a time. In order to do this without overwriting
31  // the data we're trying to unpack, we need to start from the back.
32  int num_vectors =
33  FRAC_CEIL(transfer_size * 2, VECTOR_SIZE * sizeof(float));
34  int page_offset_vec = (local_offset + curr_offset) / VECTOR_SIZE;
35  vector_fp16_to_fp32:
36  for (int v = num_vectors - 1; v >= 0; v--) {
37  v8ph_t fp16_data = _local_data_hp[page_offset_vec * 2 + v];
38  v8fp_t fp32_data = _CVT_PH_PS_256(fp16_data);
39  _local_data_sp[page_offset_vec + v] = fp32_data;
40  }
41  num_bytes_remaining -= transfer_size;
42  }
43 }
44 
45 void host_store_fp16(float* local_data,
46  float16* remote_data,
47  int num_elems,
48  int local_offset,
49  int remote_offset) {
50  VEC_ARRAY_1D(v8ph_t, _local_data_hp, local_data);
51  VEC_ARRAY_1D(v8fp_t, _local_data_sp, local_data);
52  const int page_size = (1 << LOG_PAGE_SIZE);
53  const int max_transfer_size = page_size;
54  const int total_bytes =
55  next_multiple(num_elems * sizeof(float16), CACHELINE_SIZE);
56  int num_xfers = FRAC_CEIL(total_bytes, max_transfer_size);
57  int num_bytes_remaining = total_bytes;
58  host_fp32_to_fp16:
59  for (int i = 0; i < num_xfers; i++) {
60  int transfer_size = min2(num_bytes_remaining, max_transfer_size);
61  // The effective transfer size is the size in terms of FP32.
62  int eff_transfer_size = transfer_size * 2;
63  int curr_offset = (i * 2 * page_size) / sizeof(float);
64 
65  int num_vectors =
66  FRAC_CEIL(eff_transfer_size, VECTOR_SIZE * sizeof(float));
67  int page_offset_vec = (local_offset + curr_offset) / VECTOR_SIZE;
68  vector_fp32_to_fp16:
69  for (int v = 0; v < num_vectors; v++){
70  v8fp_t fp32_data = _local_data_sp[page_offset_vec + v];
71  v8ph_t fp16_data = _CVT_PS_PH_256(fp32_data, 0);
72  _local_data_hp[page_offset_vec * 2 + v] = fp16_data;
73  }
74 
75  hostStore(remote_data + remote_offset + curr_offset,
76  local_data + local_offset + curr_offset,
77  transfer_size);
78 
79  num_bytes_remaining -= transfer_size;
80  }
81 }
82 
83 #ifdef __cplusplus
84 } // extern "C"
85 #endif
host_store_fp16
void host_store_fp16(float *local_data, float16 *remote_data, int num_elems, int local_offset, int remote_offset)
Definition: load_store_fp16_data.c:45
host_load_fp16
void host_load_fp16(float *local_data, float16 *remote_data, int num_elems, int local_offset, int remote_offset)
Definition: load_store_fp16_data.c:7
FRAC_CEIL
#define FRAC_CEIL(A, B)
Implements the ceiling function of A/B.
Definition: common.h:505
v8ph_t
fp16_t v8ph_t
8 packed 16-bit floating point values.
Definition: common.h:311
v8fp_t
fp_t v8fp_t
8 packed 32-bit floating point values.
Definition: common.h:301
load_store_fp16_data.h
Aladdin kernels to load/store FP16 data to/from host memory.
VECTOR_SIZE
#define VECTOR_SIZE
Vector size used in SMV backends.
Definition: common.h:293
next_multiple
size_t next_multiple(size_t request, size_t align)
Returns the smallest multiple of align that is >= request.
Definition: common.cpp:36