28 if (indices_[d] < 0)
return false;
31 if (indices_[d] >
MaxIndexOfDim(static_cast<FlexDimensions>(d)))
39 return MaxIndexOfDim(dimension) == indices_[dimension];
45 int max_index = stride_map_->shape_[dim] - 1;
46 if (dim ==
FD_BATCH)
return max_index;
48 const size_t batch = indices_[
FD_BATCH];
50 if (batch >= stride_map_->heights_.size() ||
51 stride_map_->heights_[batch] > max_index)
53 return stride_map_->heights_[batch] - 1;
55 if (batch >= stride_map_->widths_.size() ||
56 stride_map_->widths_[batch] > max_index)
58 return stride_map_->widths_[batch] - 1;
64 indices_[dimension] += offset;
73 if (!IsLast(static_cast<FlexDimensions>(d))) {
74 t_ += stride_map_->t_increments_[d];
78 t_ -= stride_map_->t_increments_[d] * indices_[d];
90 if (indices_[d] > 0) {
95 InitToLastOfBatch(indices_[
FD_BATCH]);
97 t_ -= stride_map_->t_increments_[d];
101 indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
102 t_ += stride_map_->t_increments_[d] * indices_[d];
110 void StrideMap::Index::InitToLastOfBatch(
int batch) {
113 indices_[d] = MaxIndexOfDim(static_cast<FlexDimensions>(d));
119 void StrideMap::Index::SetTFromIndices() {
122 t_ += stride_map_->t_increments_[d] * indices_[d];
130 for (
const std::pair<int, int>& hw : h_w_pairs) {
131 int height = hw.first;
132 int width = hw.second;
133 heights_.push_back(height);
134 widths_.push_back(width);
135 if (height > max_height) max_height = height;
136 if (width > max_width) max_width = width;
141 ComputeTIncrements();
146 for (
int& height : heights_) height /= y_factor;
147 for (
int& width : widths_) width /= x_factor;
150 ComputeTIncrements();
155 widths_.assign(widths_.size(), 1);
157 ComputeTIncrements();
163 std::swap(heights_, widths_);
164 ComputeTIncrements();
168 void StrideMap::ComputeTIncrements() {
171 t_increments_[d] = t_increments_[d + 1] * shape_[d + 1];
void SetStride(const std::vector< std::pair< int, int >> &h_w_pairs)
bool AddOffset(int offset, FlexDimensions dimension)
bool IsLast(FlexDimensions dimension) const
int MaxIndexOfDim(FlexDimensions dim) const
void ScaleXY(int x_factor, int y_factor)