12 #include <tensorflow/compiler/xla/array2d.h>
24 class StridemapTest :
public ::testing::Test {
27 std::locale::global(std::locale(
""));
32 std::unique_ptr<xla::Array2D<int>> SetupArray(
int ysize,
int xsize,
int start) {
33 std::unique_ptr<xla::Array2D<int>> a(
new xla::Array2D<int>(ysize, xsize));
35 for (
int y = 0; y < ysize; ++y) {
36 for (
int x = 0; x < xsize; ++x) {
44 TEST_F(StridemapTest, Indexing) {
47 std::vector<std::unique_ptr<xla::Array2D<int>>> arrays;
48 arrays.push_back(SetupArray(3, 4, 0));
49 arrays.push_back(SetupArray(4, 5, 12));
50 arrays.push_back(SetupArray(4, 4, 32));
51 arrays.push_back(SetupArray(3, 5, 48));
52 std::vector<std::pair<int, int>> h_w_sizes;
53 for (
size_t i = 0; i < arrays.size(); ++i) {
54 h_w_sizes.emplace_back(arrays[i].get()->height(), arrays[i].get()->width());
58 StrideMap::Index index(stride_map);
61 EXPECT_GE(index.t(), pos);
66 index.index(
FD_BATCH) == arrays.size() - 1);
73 EXPECT_TRUE(index.IsValid());
75 }
while (index.Increment());
80 EXPECT_GE(index.t(), pos);
84 StrideMap::Index copy(index);
89 EXPECT_FALSE(copy.AddOffset(1,
FD_BATCH));
97 EXPECT_FALSE(copy.AddOffset(-1,
FD_BATCH));
104 EXPECT_FALSE(copy.AddOffset(10,
FD_WIDTH));
106 EXPECT_FALSE(copy.AddOffset(-10,
FD_HEIGHT));
107 EXPECT_TRUE(index.IsValid());
108 }
while (index.Decrement());
111 TEST_F(StridemapTest, Scaling) {
114 std::vector<std::unique_ptr<xla::Array2D<int>>> arrays;
115 arrays.push_back(SetupArray(3, 4, 0));
116 arrays.push_back(SetupArray(4, 5, 12));
117 arrays.push_back(SetupArray(4, 4, 32));
118 arrays.push_back(SetupArray(3, 5, 48));
119 std::vector<std::pair<int, int>> h_w_sizes;
120 for (
size_t i = 0; i < arrays.size(); ++i) {
121 h_w_sizes.emplace_back(arrays[i].get()->height(), arrays[i].get()->width());
127 std::vector<int> values_x2 = {0, 1, 4, 5, 8, 9, 12, 13, 17, 18,
128 22, 23, 27, 28, 32, 33, 36, 37, 40, 41,
129 44, 45, 48, 49, 53, 54, 58, 59};
131 test_map.ScaleXY(2, 1);
132 StrideMap::Index index(test_map);
135 int expected_value = values_x2[pos++];
139 }
while (index.Increment());
140 EXPECT_EQ(pos, values_x2.size());
142 test_map = stride_map;
144 std::vector<int> values_y2 = {0, 1, 2, 3, 12, 13, 14, 15, 16,
145 17, 18, 19, 20, 21, 32, 33, 34, 35,
146 36, 37, 38, 39, 48, 49, 50, 51, 52};
151 int expected_value = values_y2[pos++];
155 }
while (index.Increment());
156 EXPECT_EQ(pos, values_y2.size());
158 test_map = stride_map;
160 std::vector<int> values_xy2 = {0, 1, 12, 13, 17, 18, 32, 33, 36, 37, 48, 49};
165 int expected_value = values_xy2[pos++];
169 }
while (index.Increment());
170 EXPECT_EQ(pos, values_xy2.size());
172 test_map = stride_map;
174 std::vector<int> values_x_to_1 = {0, 4, 8, 12, 17, 22, 27,
175 32, 36, 40, 44, 48, 53, 58};
180 int expected_value = values_x_to_1[pos++];
184 }
while (index.Increment());
185 EXPECT_EQ(pos, values_x_to_1.size());