Skip to content

Commit f32a7dc

Browse files
authored
Bugfix __eq__ for numpy data types (#8646)
* [Python] Sync PythonTest.sh flags with generate_code.py * [Python] Update generated code to latest flatc version for tests * [Python] Fix test support for numpy newer than 2.0.0 * [Python] Remove unused variable * [Python] Fix __eq__ for numpy arrays * [Python] Run clang-format over the entire file
1 parent 860d645 commit f32a7dc

File tree

5 files changed

+108
-88
lines changed

5 files changed

+108
-88
lines changed

src/idl_gen_python.cpp

Lines changed: 54 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ class PythonStubGenerator {
5252
public:
5353
PythonStubGenerator(const Parser &parser, const std::string &path,
5454
const Version &version)
55-
: parser_{parser},
56-
namer_{WithFlagOptions(kStubConfig, parser.opts, path),
57-
Keywords(version)},
55+
: parser_{ parser },
56+
namer_{ WithFlagOptions(kStubConfig, parser.opts, path),
57+
Keywords(version) },
5858
version_(version) {}
5959

6060
bool Generate() {
@@ -140,8 +140,7 @@ class PythonStubGenerator {
140140
return module;
141141
}
142142

143-
template <typename T>
144-
std::string ModuleFor(const T *def) const {
143+
template<typename T> std::string ModuleFor(const T *def) const {
145144
if (parser_.opts.one_file) return ModuleForFile(def->file);
146145
return namer_.NamespacedType(*def);
147146
}
@@ -165,7 +164,7 @@ class PythonStubGenerator {
165164
return "None";
166165
}
167166

168-
template <typename F>
167+
template<typename F>
169168
std::string UnionType(const EnumDef &enum_def, Imports *imports,
170169
F type) const {
171170
imports->Import("typing");
@@ -181,14 +180,9 @@ class PythonStubGenerator {
181180
result += import.name;
182181
break;
183182
}
184-
case BASE_TYPE_STRING:
185-
result += "str";
186-
break;
187-
case BASE_TYPE_NONE:
188-
result += "None";
189-
break;
190-
default:
191-
break;
183+
case BASE_TYPE_STRING: result += "str"; break;
184+
case BASE_TYPE_NONE: result += "None"; break;
185+
default: break;
192186
}
193187
}
194188
return "typing.Union[" + result + "]";
@@ -229,18 +223,14 @@ class PythonStubGenerator {
229223
namer_.Type(*type.struct_def));
230224
return import.name;
231225
}
232-
case BASE_TYPE_STRING:
233-
return "str";
226+
case BASE_TYPE_STRING: return "str";
234227
case BASE_TYPE_ARRAY:
235228
case BASE_TYPE_VECTOR: {
236229
imports->Import("typing");
237230
return "typing.List[" + TypeOf(type.VectorType(), imports) + "]";
238231
}
239-
case BASE_TYPE_UNION:
240-
return UnionType(*type.enum_def, imports);
241-
default:
242-
FLATBUFFERS_ASSERT(0);
243-
return "";
232+
case BASE_TYPE_UNION: return UnionType(*type.enum_def, imports);
233+
default: FLATBUFFERS_ASSERT(0); return "";
244234
}
245235
}
246236

@@ -262,8 +252,7 @@ class PythonStubGenerator {
262252
namer_.ObjectType(*field_type.struct_def));
263253
return field_name + ": " + import.name + " | None";
264254
}
265-
case BASE_TYPE_STRING:
266-
return field_name + ": str | None";
255+
case BASE_TYPE_STRING: return field_name + ": str | None";
267256
case BASE_TYPE_ARRAY:
268257
case BASE_TYPE_VECTOR: {
269258
imports->Import("typing");
@@ -282,8 +271,7 @@ class PythonStubGenerator {
282271
case BASE_TYPE_UNION:
283272
return field_name + ": " +
284273
UnionObjectType(*field->value.type.enum_def, imports);
285-
default:
286-
return field_name;
274+
default: return field_name;
287275
}
288276
}
289277

@@ -312,9 +300,7 @@ class PythonStubGenerator {
312300
field_type = "'" + import_.name + "' | None";
313301
break;
314302
}
315-
case BASE_TYPE_STRING:
316-
field_type = "str | None";
317-
break;
303+
case BASE_TYPE_STRING: field_type = "str | None"; break;
318304
case BASE_TYPE_ARRAY:
319305
case BASE_TYPE_VECTOR: {
320306
imports->Import("typing");
@@ -334,9 +320,7 @@ class PythonStubGenerator {
334320
case BASE_TYPE_UNION:
335321
field_type = UnionObjectType(*type.enum_def, imports);
336322
break;
337-
default:
338-
field_type = "typing.Any";
339-
break;
323+
default: field_type = "typing.Any"; break;
340324
}
341325
}
342326
stub << " " << field_name << ": " << field_type << " = ...,\n";
@@ -485,8 +469,7 @@ class PythonStubGenerator {
485469
stub << " def " << name << "(self) -> table.Table | None: ...\n";
486470
break;
487471
}
488-
default:
489-
break;
472+
default: break;
490473
}
491474
}
492475
}
@@ -530,9 +513,7 @@ class PythonStubGenerator {
530513
stub << '\n';
531514
stub << "def Create" + namer_.Type(*struct_def)
532515
<< "(builder: flatbuffers.Builder";
533-
for (const std::string &arg : args) {
534-
stub << ", " << arg;
535-
}
516+
for (const std::string &arg : args) { stub << ", " << arg; }
536517
stub << ") -> uoffset: ...\n";
537518
}
538519

@@ -610,11 +591,10 @@ class PythonStubGenerator {
610591

611592
imports->Import("typing", "cast");
612593

613-
if (version_.major == 3){
594+
if (version_.major == 3) {
614595
imports->Import("enum", "IntEnum");
615596
stub << "(IntEnum)";
616-
}
617-
else {
597+
} else {
618598
stub << "(object)";
619599
}
620600

@@ -637,16 +617,15 @@ class PythonStubGenerator {
637617
ss << "from __future__ import annotations\n";
638618
ss << '\n';
639619
ss << "import flatbuffers\n";
640-
if (parser_.opts.python_gen_numpy) {
641-
ss << "import numpy as np\n";
642-
}
620+
if (parser_.opts.python_gen_numpy) { ss << "import numpy as np\n"; }
643621
ss << '\n';
644622

645623
std::set<std::string> modules;
646624
std::map<std::string, std::set<std::string>> names_by_module;
647625
for (const Import &import : imports.imports) {
648626
if (import.IsLocal()) continue; // skip all local imports
649-
if (import.module == "flatbuffers" && import.name == "") continue; // skip double include hardcoded flatbuffers
627+
if (import.module == "flatbuffers" && import.name == "")
628+
continue; // skip double include hardcoded flatbuffers
650629
if (import.name == "") {
651630
modules.insert(import.module);
652631
} else {
@@ -686,7 +665,8 @@ class PythonStubGenerator {
686665
const Parser &parser_;
687666
const IdlNamer namer_;
688667
const Version version_;
689-
};} // namespace
668+
};
669+
} // namespace
690670

691671
class PythonGenerator : public BaseGenerator {
692672
public:
@@ -695,8 +675,8 @@ class PythonGenerator : public BaseGenerator {
695675
: BaseGenerator(parser, path, file_name, "" /* not used */,
696676
"" /* not used */, "py"),
697677
float_const_gen_("float('nan')", "float('inf')", "float('-inf')"),
698-
namer_(WithFlagOptions(kConfig, parser.opts, path),
699-
Keywords(version)) {}
678+
namer_(WithFlagOptions(kConfig, parser.opts, path), Keywords(version)) {
679+
}
700680

701681
// Most field accessors need to retrieve and test the field offset first,
702682
// this is the prefix code for that.
@@ -886,9 +866,8 @@ class PythonGenerator : public BaseGenerator {
886866
GenReceiver(struct_def, code_ptr);
887867
code += namer_.Method(field);
888868

889-
const ImportMapEntry import_entry = {
890-
GenPackageReference(field.value.type), TypeName(field)
891-
};
869+
const ImportMapEntry import_entry = { GenPackageReference(field.value.type),
870+
TypeName(field) };
892871

893872
if (parser_.opts.python_typing) {
894873
const std::string return_type = ReturnType(struct_def, field);
@@ -948,9 +927,8 @@ class PythonGenerator : public BaseGenerator {
948927
GenReceiver(struct_def, code_ptr);
949928
code += namer_.Method(field) + "(self)";
950929

951-
const ImportMapEntry import_entry = {
952-
GenPackageReference(field.value.type), TypeName(field)
953-
};
930+
const ImportMapEntry import_entry = { GenPackageReference(field.value.type),
931+
TypeName(field) };
954932

955933
if (parser_.opts.python_typing) {
956934
const std::string return_type = ReturnType(struct_def, field);
@@ -1036,11 +1014,8 @@ class PythonGenerator : public BaseGenerator {
10361014
code += Indent + Indent + "return None\n\n";
10371015
}
10381016

1039-
template <typename T>
1040-
std::string ModuleFor(const T *def) const {
1041-
if (!parser_.opts.one_file) {
1042-
return namer_.NamespacedType(*def);
1043-
}
1017+
template<typename T> std::string ModuleFor(const T *def) const {
1018+
if (!parser_.opts.one_file) { return namer_.NamespacedType(*def); }
10441019

10451020
std::string filename =
10461021
StripExtension(def->file) + parser_.opts.filename_suffix;
@@ -1070,9 +1045,8 @@ class PythonGenerator : public BaseGenerator {
10701045

10711046
GenReceiver(struct_def, code_ptr);
10721047
code += namer_.Method(field);
1073-
const ImportMapEntry import_entry = {
1074-
GenPackageReference(field.value.type), TypeName(field)
1075-
};
1048+
const ImportMapEntry import_entry = { GenPackageReference(field.value.type),
1049+
TypeName(field) };
10761050

10771051
if (parser_.opts.python_typing) {
10781052
const std::string return_type = ReturnType(struct_def, field);
@@ -1195,8 +1169,7 @@ class PythonGenerator : public BaseGenerator {
11951169
std::string qualified_name = NestedFlatbufferType(unqualified_name);
11961170
if (qualified_name.empty()) { qualified_name = nested->constant; }
11971171

1198-
const ImportMapEntry import_entry = { qualified_name,
1199-
unqualified_name };
1172+
const ImportMapEntry import_entry = { qualified_name, unqualified_name };
12001173

12011174
auto &code = *code_ptr;
12021175
GenReceiver(struct_def, code_ptr);
@@ -1808,8 +1781,8 @@ class PythonGenerator : public BaseGenerator {
18081781
}
18091782
field_type = "Optional[List[" + field_type + "]";
18101783
} else {
1811-
field_type =
1812-
"Optional[List[" + GetBasePythonTypeForScalarAndString(base_type) + "]]";
1784+
field_type = "Optional[List[" +
1785+
GetBasePythonTypeForScalarAndString(base_type) + "]]";
18131786
}
18141787
}
18151788

@@ -1858,11 +1831,12 @@ class PythonGenerator : public BaseGenerator {
18581831
const auto field_field = namer_.Field(field);
18591832

18601833
// Build signature with keyword arguments, type hints, and default values.
1861-
signature_params += GenIndents(2) + field_field + " = " + default_value + ",";
1834+
signature_params +=
1835+
GenIndents(2) + field_field + " = " + default_value + ",";
18621836

18631837
// Build the body of the __init__ method.
18641838
init_body += GenIndents(2) + "self." + field_field + " = " + field_field +
1865-
" # type: " + field_type;
1839+
" # type: " + field_type;
18661840
}
18671841

18681842
// Writes __init__ method.
@@ -1954,10 +1928,16 @@ class PythonGenerator : public BaseGenerator {
19541928
auto &field = **it;
19551929
if (field.deprecated) continue;
19561930

1957-
// Wrties the comparison statement for this field.
1958-
const auto field_field = namer_.Field(field);
1959-
code += " and \\" + GenIndents(3) + "self." + field_field +
1960-
" == " + "other." + field_field;
1931+
// Writes the comparison statement for this field.
1932+
const auto field_name = namer_.Field(field);
1933+
if (parser_.opts.python_gen_numpy &&
1934+
field.value.type.base_type == BASE_TYPE_VECTOR) {
1935+
code += " and \\" + GenIndents(3) + "np.array_equal(self." +
1936+
field_name + ", " + "other." + field_name + ")";
1937+
} else {
1938+
code += " and \\" + GenIndents(3) + "self." + field_name +
1939+
" == " + "other." + field_name;
1940+
}
19611941
}
19621942
code += "\n";
19631943
}
@@ -2154,7 +2134,6 @@ class PythonGenerator : public BaseGenerator {
21542134
auto &field = **it;
21552135
if (field.deprecated) continue;
21562136

2157-
auto field_type = TypeName(field);
21582137
switch (field.value.type.base_type) {
21592138
case BASE_TYPE_STRUCT: {
21602139
GenUnPackForStruct(struct_def, field, &code);
@@ -2338,9 +2317,9 @@ class PythonGenerator : public BaseGenerator {
23382317

23392318
if (parser_.opts.python_gen_numpy) {
23402319
code_prefix += GenIndents(3) + "if np is not None and type(self." +
2341-
field_field + ") is np.ndarray:";
2320+
field_field + ") is np.ndarray:";
23422321
code_prefix += GenIndents(4) + field_field +
2343-
" = builder.CreateNumpyVector(self." + field_field + ")";
2322+
" = builder.CreateNumpyVector(self." + field_field + ")";
23442323
code_prefix += GenIndents(3) + "else:";
23452324
GenPackForScalarVectorFieldHelper(struct_def, field, code_prefix_ptr, 4);
23462325
code_prefix += "(self." + field_field + "[i])";
@@ -2788,9 +2767,7 @@ class PythonGenerator : public BaseGenerator {
27882767
}
27892768
}
27902769
}
2791-
if (parser_.opts.python_gen_numpy) {
2792-
code += "np = import_numpy()\n\n";
2793-
}
2770+
if (parser_.opts.python_gen_numpy) { code += "np = import_numpy()\n\n"; }
27942771
}
27952772
}
27962773

@@ -2828,7 +2805,7 @@ class PythonGenerator : public BaseGenerator {
28282805

28292806
static bool GeneratePython(const Parser &parser, const std::string &path,
28302807
const std::string &file_name) {
2831-
python::Version version{parser.opts.python_version};
2808+
python::Version version{ parser.opts.python_version };
28322809
if (!version.IsValid()) return false;
28332810

28342811
python::PythonGenerator generator(parser, path, file_name, version);

tests/MyGame/MonsterExtra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def __eq__(self, other):
282282
self.f1 == other.f1 and \
283283
self.f2 == other.f2 and \
284284
self.f3 == other.f3 and \
285-
self.dvec == other.dvec and \
286-
self.fvec == other.fvec
285+
np.array_equal(self.dvec, other.dvec) and \
286+
np.array_equal(self.fvec, other.fvec)
287287

288288
# MonsterExtraT
289289
def _UnPack(self, monsterExtra):

tests/PythonTest.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ ${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_test.fbs --g
2626
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_test.fbs --gen-object-api --gen-onefile
2727
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test monster_extra.fbs --gen-object-api --python-typing --gen-compare
2828
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test arrays_test.fbs --gen-object-api --python-typing
29-
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test nested_union_test.fbs --gen-object-api --python-typing
29+
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test nested_union_test.fbs --gen-object-api --python-typing --python-decode-obj-api-strings
3030
${test_dir}/../flatc -p -o ${gen_code_path} -I include_test service_test.fbs --grpc --grpc-python-typed-handlers --python-typing --no-python-gen-numpy --gen-onefile
3131

3232
# Syntax: run_tests <interpreter> <benchmark vtable dedupes>

0 commit comments

Comments
 (0)