@@ -52,7 +52,7 @@ Result<Tag> get_tag(
   }
 }
 
-size_t calculate_nbytes(
+Result<size_t> calculate_nbytes(
     Span<const int32_t> sizes,
     executorch::aten::ScalarType scalar_type) {
   size_t n = 1;
@@ -61,7 +61,13 @@ size_t calculate_nbytes(
     prev_n = n;
     n *= sizes[i];
     // Check for overflow
-    ET_CHECK(sizes[i] == 0 || n / sizes[i] == prev_n);
+    ET_CHECK_OR_RETURN_ERROR(
+        sizes[i] == 0 || n / sizes[i] == prev_n,
+        InvalidArgument,
+        "Invalid size[%zu]: %d. Potentially overflowed, expect to be 0 or prev_n: %zu",
+        i,
+        sizes[i],
+        prev_n);
   }
 
   size_t elem_size = executorch::runtime::elementSize(scalar_type);
@@ -70,25 +76,47 @@ size_t calculate_nbytes(
   n = n * elem_size;
 
   // Check for overflow
-  ET_CHECK(elem_size == 0 || n / elem_size == prev_n);
+  ET_CHECK_OR_RETURN_ERROR(
+      elem_size == 0 || n / elem_size == prev_n,
+      InvalidArgument,
+      "Invalid elem_size: %zu. Potentially overflowed, expect to be 0 or prev_n: %zu",
+      elem_size,
+      prev_n);
 
   return n;
 }
 
 } // namespace
 
+/*static*/ Result<TensorInfo> TensorInfo::create(
+    Span<const int32_t> sizes,
+    Span<const uint8_t> dim_order,
+    executorch::aten::ScalarType scalar_type,
+    const bool is_memory_planned,
+    std::string_view name) {
+  auto nbytes = calculate_nbytes(sizes, scalar_type);
+  ET_CHECK_OR_RETURN_ERROR(
+      nbytes.ok(),
+      InvalidArgument,
+      "Failed to calculate nbytes for TensorInfo");
+
+  return TensorInfo(
+      sizes, dim_order, scalar_type, is_memory_planned, name, nbytes.get());
+}
+
 TensorInfo::TensorInfo(
     Span<const int32_t> sizes,
     Span<const uint8_t> dim_order,
     executorch::aten::ScalarType scalar_type,
     const bool is_memory_planned,
-    std::string_view name)
+    std::string_view name,
+    size_t nbytes)
     : sizes_(sizes),
       dim_order_(dim_order),
       name_(name),
       scalar_type_(scalar_type),
       is_memory_planned_(is_memory_planned),
-      nbytes_(calculate_nbytes(sizes_, scalar_type_)) {}
+      nbytes_(nbytes) {}
 
 Span<const int32_t> TensorInfo::sizes() const {
   return sizes_;
@@ -160,7 +188,7 @@ Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const {
   auto input_index = s_plan_->inputs()->Get(index);
   // input_index was already validated by input_tag().
   auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
-  return TensorInfo(
+  return TensorInfo::create(
       Span<const int32_t>(
           tensor_value->sizes()->data(), tensor_value->sizes()->size()),
       Span<const uint8_t>(
@@ -212,7 +240,7 @@ Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const {
   // output_index was already validated by output_tag().
   auto tensor_value = s_plan_->values()->Get(output_index)->val_as_Tensor();
 
-  return TensorInfo(
+  return TensorInfo::create(
       Span<const int32_t>(
          tensor_value->sizes()->data(), tensor_value->sizes()->size()),
      Span<const uint8_t>(
@@ -255,7 +283,7 @@ Result<TensorInfo> MethodMeta::attribute_tensor_meta(size_t index) const {
        auto t_name =
            tensor_value->extra_tensor_info()->fully_qualified_name();
        // Count constant returns as memory planned
-       return TensorInfo(
+       return TensorInfo::create(
           Span<const int32_t>(
               tensor_value->sizes()->data(), tensor_value->sizes()->size()),
           Span<const uint8_t>(
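For context on what this buys callers, the snippet below is a minimal sketch, not part of this diff: the helper name input_nbytes and its signature are invented for illustration, and it assumes the public executorch::runtime API (MethodMeta::input_tensor_meta, TensorInfo::nbytes, Result) visible in the hunks above. It shows that a size/nbytes overflow now surfaces as an error Result at the call site instead of aborting inside calculate_nbytes() via ET_CHECK.

#include <executorch/runtime/core/result.h>
#include <executorch/runtime/executor/method_meta.h>

using executorch::runtime::MethodMeta;
using executorch::runtime::Result;
using executorch::runtime::TensorInfo;

// Hypothetical caller-side helper (name and signature invented for this
// sketch). With this change, an overflow detected while computing nbytes is
// reported as Error::InvalidArgument through the Result chain rather than
// aborting the process.
Result<size_t> input_nbytes(const MethodMeta& method_meta, size_t index) {
  Result<TensorInfo> info = method_meta.input_tensor_meta(index);
  if (!info.ok()) {
    return info.error(); // Propagate e.g. InvalidArgument on size overflow.
  }
  return info->nbytes();
}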