26 for (
int index : indices_) {
27 if (
index < 0)
return false;
30 if (indices_[d] >
MaxIndexOfDim(static_cast<FlexDimensions>(d)))
38 return MaxIndexOfDim(dimension) == indices_[dimension];
44 int max_index = stride_map_->shape_[dim] - 1;
45 if (dim ==
FD_BATCH)
return max_index;
47 const size_t batch = indices_[
FD_BATCH];
49 if (batch >= stride_map_->heights_.size() ||
50 stride_map_->heights_[batch] > max_index)
52 return stride_map_->heights_[batch] - 1;
54 if (batch >= stride_map_->widths_.size() ||
55 stride_map_->widths_[batch] > max_index)
57 return stride_map_->widths_[batch] - 1;
63 indices_[dimension] += offset;
72 if (!IsLast(static_cast<FlexDimensions>(d))) {
73 t_ += stride_map_->t_increments_[d];
77 t_ -= stride_map_->t_increments_[d] * indices_[d];
89 if (indices_[d] > 0) {
94 InitToLastOfBatch(indices_[
FD_BATCH]);
96 t_ -= stride_map_->t_increments_[d];
100 indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
101 t_ += stride_map_->t_increments_[d] * indices_[d];
109 void StrideMap::Index::InitToLastOfBatch(
int batch) {
112 indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
118 void StrideMap::Index::SetTFromIndices() {
121 t_ += stride_map_->t_increments_[d] * indices_[d];
129 for (
const std::pair<int, int>& hw : h_w_pairs) {
130 int height = hw.first;
131 int width = hw.second;
132 heights_.push_back(height);
133 widths_.push_back(width);
134 if (height > max_height) max_height = height;
135 if (width > max_width) max_width = width;
140 ComputeTIncrements();
145 for (
int& height : heights_) height /= y_factor;
146 for (
int& width : widths_) width /= x_factor;
149 ComputeTIncrements();
154 widths_.assign(widths_.size(), 1);
156 ComputeTIncrements();
162 std::swap(heights_, widths_);
163 ComputeTIncrements();
167 void StrideMap::ComputeTIncrements() {
170 t_increments_[d] = t_increments_[d + 1] * shape_[d + 1];