From 66fdb2682647bd349c3cc201e48a05b78d1a1b77 Mon Sep 17 00:00:00 2001 From: Greg Brockman Date: Thu, 23 Feb 2017 12:37:57 -0800 Subject: [PATCH] Support nested definitions and fields with capitals --- lua_protobuf/generator.py | 150 +++++++++++++++++++++++++------------- protoc-gen-lua | 1 - 2 files changed, 99 insertions(+), 52 deletions(-) diff --git a/lua_protobuf/generator.py b/lua_protobuf/generator.py index 7081667..57389c2 100644 --- a/lua_protobuf/generator.py +++ b/lua_protobuf/generator.py @@ -46,6 +46,9 @@ FieldDescriptor.TYPE_SINT64: 'sint64', } +def apply_namespace(namespace, name): + return '_'.join(namespace + [name]) + def lua_protobuf_header(): '''Returns common header included by all produced files''' return ''' @@ -183,10 +186,10 @@ def package_function_prefix(package): def message_function_prefix(package, message): return '%s%s_' % (package_function_prefix(package), message) -def message_open_function_name(package, message): +def message_open_function_name(package, message, namespace=[]): '''Returns function name that registers the Lua library for a message type''' - return '%sopen' % message_function_prefix(package, message) + return '%sopen' % message_function_prefix(package, apply_namespace(namespace, message)) def cpp_class(package, message = None): '''Returns the fully qualified class name for a message type''' @@ -239,7 +242,7 @@ def has_body(package, message, field): lines = [] lines.extend(obtain_message_from_udata(package, message)) - lines.append('lua_pushboolean(L, m->has_%s());' % field) + lines.append('lua_pushboolean(L, m->has_%s());' % field.lower()) lines.append('return 1;') return lines @@ -248,7 +251,7 @@ def clear_body(package, message, field): '''Returns the function body for a clear_ function''' lines = [] lines.extend(obtain_message_from_udata(package, message)) - lines.append('m->clear_%s();' % field) + lines.append('m->clear_%s();' % field.lower()) lines.append('return 0;') return lines @@ -257,7 +260,7 @@ def size_body(package, message, field): '''Returns the function body for a size_ function''' lines = [] lines.extend(obtain_message_from_udata(package, message)) - lines.append('int size = m->%s_size();' % field) + lines.append('int size = m->%s_size();' % field.lower()) lines.append('lua_pushinteger(L, size);') lines.append('return 1;') @@ -268,7 +271,7 @@ def add_body(package, message, field, type_name): lines = [] lines.extend(obtain_message_from_udata(package, message)) lines.extend([ - '%s *msg_new = m->add_%s();' % ( cpp_class(type_name), field ), + '%s *msg_new = m->add_%s();' % ( cpp_class(type_name), field.lower() ), # since the message is allocated out of the containing message, Lua # does not need to do GC @@ -303,9 +306,9 @@ def field_get(package, message, field_descriptor): 'return luaL_error(L, "missing required numeric argument");', '}', 'lua_Integer index = luaL_checkinteger(L, 2);', - 'if (index < 1 || index > m->%s_size()) {' % name, + 'if (index < 1 || index > m->%s_size()) {' % name.lower(), # TODO is returning nil the more Lua way? - 'return luaL_error(L, "index must be between 1 and current size: %%d", m->%s_size());' % name, + 'return luaL_error(L, "index must be between 1 and current size: %%d", m->%s_size());' % name.lower(), '}', ]) @@ -315,30 +318,30 @@ def field_get(package, message, field_descriptor): if repeated: if type in [ FieldDescriptor.TYPE_STRING, FieldDescriptor.TYPE_BYTES ]: lines.extend([ - 'string s = m->%s(index - 1);' % name, + 'string s = m->%s(index - 1);' % name.lower(), 'lua_pushlstring(L, s.c_str(), s.size());', ]) elif type == FieldDescriptor.TYPE_BOOL: - lines.append('lua_pushboolean(L, m->%s(index-1));' % name) + lines.append('lua_pushboolean(L, m->%s(index-1));' % name.lower()) elif type in [FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_UINT32, FieldDescriptor.TYPE_FIXED32, FieldDescriptor.TYPE_SFIXED32, FieldDescriptor.TYPE_SINT32]: - lines.append('lua_pushinteger(L, m->%s(index-1));' % name) + lines.append('lua_pushinteger(L, m->%s(index-1));' % name.lower()) elif type in [ FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_FIXED64, FieldDescriptor.TYPE_SFIXED64, FieldDescriptor.TYPE_SINT64]: - lines.append('lua_pushinteger(L, m->%s(index-1));' % name) + lines.append('lua_pushinteger(L, m->%s(index-1));' % name.lower()) elif type == FieldDescriptor.TYPE_FLOAT or type == FieldDescriptor.TYPE_DOUBLE: - lines.append('lua_pushnumber(L, m->%s(index-1));' % name) + lines.append('lua_pushnumber(L, m->%s(index-1));' % name.lower()) elif type == FieldDescriptor.TYPE_ENUM: - lines.append('lua_pushnumber(L, m->%s(index-1));' % name) + lines.append('lua_pushnumber(L, m->%s(index-1));' % name.lower()) elif type == FieldDescriptor.TYPE_MESSAGE: lines.extend([ - '%s * got_msg = m->mutable_%s(index-1);' % ( type_name.replace('.', '::'), name ), + '%s * got_msg = m->mutable_%s(index-1);' % ( type_name.replace('.', '::'), name.lower() ), 'lua_protobuf%s_pushreference(L, got_msg, NULL, NULL);' % type_name.replace('.', '_'), ]) @@ -348,36 +351,36 @@ def field_get(package, message, field_descriptor): # for scalar fields, we push nil if the value is not defined # this is the Lua way if type == FieldDescriptor.TYPE_STRING or type == FieldDescriptor.TYPE_BYTES: - lines.append('string s = m->%s();' % name) - lines.append('m->has_%s() ? lua_pushlstring(L, s.c_str(), s.size()) : lua_pushnil(L);' % name) + lines.append('string s = m->%s();' % name.lower()) + lines.append('if (m->has_%s()) lua_pushlstring(L, s.c_str(), s.size()); else lua_pushnil(L);' % name.lower()) elif type == FieldDescriptor.TYPE_BOOL: - lines.append('m->has_%s() ? lua_pushboolean(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('m->has_%s() ? lua_pushboolean(L, m->%s()) : lua_pushnil(L);' % ( name.lower(), name.lower() )) elif type in [FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_UINT32, FieldDescriptor.TYPE_FIXED32, FieldDescriptor.TYPE_SFIXED32, FieldDescriptor.TYPE_SINT32]: - lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name.lower(), name.lower() )) elif type in [ FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_FIXED64, FieldDescriptor.TYPE_SFIXED64, FieldDescriptor.TYPE_SINT64]: - lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name.lower(), name.lower() )) elif type == FieldDescriptor.TYPE_FLOAT or type == FieldDescriptor.TYPE_DOUBLE: - lines.append('m->has_%s() ? lua_pushnumber(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('m->has_%s() ? lua_pushnumber(L, m->%s()) : lua_pushnil(L);' % ( name.lower(), name.lower() )) elif type == FieldDescriptor.TYPE_ENUM: - lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name, name )) + lines.append('m->has_%s() ? lua_pushinteger(L, m->%s()) : lua_pushnil(L);' % ( name.lower(), name.lower() )) elif type == FieldDescriptor.TYPE_MESSAGE: lines.extend([ - 'if (!m->has_%s()) {' % name, + 'if (!m->has_%s()) {' % name.lower(), 'lua_pushnil(L);', '}', # we push the message as userdata # since the message is allocated out of the parent message, we # don't need to do garbage collection - '%s * got_msg = m->mutable_%s();' % ( type_name.replace('.', '::'), name ), + '%s * got_msg = m->mutable_%s();' % ( type_name.replace('.', '::'), name.lower() ), 'lua_protobuf%s_pushreference(L, got_msg, NULL, NULL);' % type_name.replace('.', '_'), ]) @@ -393,10 +396,10 @@ def field_get(package, message, field_descriptor): def field_set_assignment(field, args): return [ 'if (index == current_size + 1) {', - 'm->add_%s(%s);' % ( field, args ), + 'm->add_%s(%s);' % ( field.lower(), args ), '}', 'else {', - 'm->set_%s(index-1, %s);' % ( field, args ), + 'm->set_%s(index-1, %s);' % ( field.lower(), args ), '}', ] @@ -422,7 +425,7 @@ def field_set(package, message, field_descriptor): ' return luaL_error(L, "required 2 arguments not passed to function");', '}', 'lua_Integer index = luaL_checkinteger(L, 2);', - 'int current_size = m->%s_size();' % name, + 'int current_size = m->%s_size();' % name.lower(), 'if (index < 1 || index > current_size + 1) {', 'return luaL_error(L, "index must be between 1 and %d", current_size + 1);', '}', @@ -480,7 +483,7 @@ def field_set(package, message, field_descriptor): # this is the Lua way, after all lines.extend([ 'if (lua_isnil(L, 2)) {', - 'm->clear_%s();' % name, + 'm->clear_%s();' % name.lower(), 'return 0;', '}', '', @@ -494,7 +497,7 @@ def field_set(package, message, field_descriptor): 'if (!s) {', 'luaL_error(L, "could not obtain string on stack. weird");', '}', - 'm->set_%s(s, len);' % name, + 'm->set_%s(s, len);' % name.lower(), 'return 0;', ]) @@ -502,7 +505,7 @@ def field_set(package, message, field_descriptor): lines.extend([ 'if (!lua_isnumber(L, 2)) return luaL_error(L, "passed value cannot be converted to a number");', 'lua_Number n = lua_tonumber(L, 2);', - 'm->set_%s(n);' % name, + 'm->set_%s(n);' % name.lower(), 'return 0;', ]) @@ -511,7 +514,7 @@ def field_set(package, message, field_descriptor): lines.extend([ 'lua_Integer v = luaL_checkinteger(L, 2);', - 'm->set_%s(v);' % name, + 'm->set_%s(v);' % name.lower(), 'return 0;', ]) @@ -520,21 +523,21 @@ def field_set(package, message, field_descriptor): lines.extend([ 'lua_Integer i = luaL_checkinteger(L, 2);', - 'm->set_%s(i);' % name, + 'm->set_%s(i);' % name.lower(), 'return 0;', ]) elif type == FieldDescriptor.TYPE_BOOL: lines.extend([ 'bool b = lua_toboolean(L, 2);', - 'm->set_%s(b);' % name, + 'm->set_%s(b);' % name.lower(), 'return 0;', ]) elif type == FieldDescriptor.TYPE_ENUM: lines.extend([ 'lua_Integer i = luaL_checkinteger(L, 2);', - 'm->set_%s((%s)i);' % ( name, type_name.replace('.', '::') ), + 'm->set_%s((%s)i);' % ( name.lower(), type_name.replace('.', '::') ), 'return 0;', ]) @@ -718,14 +721,14 @@ def message_function_array(package, message): '};\n', ] -def message_method_array(package, descriptor): +def message_method_array(package, descriptor, namespace=[]): '''Defines functions for Lua object instances These are functions available to each instance of a message. They take the object userdata as the first parameter. ''' - message = descriptor.name + message = apply_namespace(namespace, descriptor.name) fp = message_function_prefix(package, message) lines = [] @@ -757,10 +760,10 @@ def message_method_array(package, descriptor): return lines -def message_open_function(package, descriptor): +def message_open_function(package, descriptor, namespace): '''Function definition for opening/registering a message type''' - message = descriptor.name + message = apply_namespace(namespace, descriptor.name) lines = [ 'int %s(lua_State *L)' % message_open_function_name(package, message), @@ -785,10 +788,10 @@ def message_open_function(package, descriptor): return lines -def message_header(package, message_descriptor): +def message_header(package, message_descriptor, namespace=[]): '''Returns the lines for a header definition of a message''' - message_name = message_descriptor.name + message_name = apply_namespace(namespace, message_descriptor.name) lines = [] lines.append('// Message %s' % message_name) @@ -877,15 +880,15 @@ def message_header(package, message_descriptor): return lines -def message_source(package, message_descriptor): +def message_source(package, message_descriptor, namespace=[]): '''Returns lines of source code for an individual message type''' lines = [] - message = message_descriptor.name + message = apply_namespace(namespace, message_descriptor.name) lines.extend(message_function_array(package, message)) - lines.extend(message_method_array(package, message_descriptor)) - lines.extend(message_open_function(package, message_descriptor)) + lines.extend(message_method_array(package, message_descriptor, namespace)) + lines.extend(message_open_function(package, message_descriptor, namespace)) lines.extend(message_pushcopy_function(package, message)) lines.extend(message_pushreference_function(package, message)) lines.extend(new_message(package, message)) @@ -925,13 +928,13 @@ def message_source(package, message_descriptor): return lines -def enum_source(descriptor): +def enum_source(descriptor, namespace=[]): '''Returns source code defining an enumeration type''' # this function assumes the module/table the enum should be assigned to # is at the top of the stack when it is called - name = descriptor.name + name = apply_namespace(namespace, descriptor.name) # enums are a little funky # at the core, there is a table whose keys are the enum string names and @@ -989,6 +992,16 @@ def enum_source(descriptor): return lines +def message_header_recursive(package, descriptor, namespace=[]): + lines = [] + + nested_namespace = namespace + [descriptor.name] + for nested in descriptor.nested_type: + lines += message_header_recursive(package, nested, nested_namespace) + + lines += message_header(package, descriptor, namespace) + return lines + def file_header(file_descriptor): filename = file_descriptor.name @@ -999,7 +1012,7 @@ def file_header(file_descriptor): lines.extend(c_header_header(filename, package)) for descriptor in file_descriptor.message_type: - lines.extend(message_header(package, descriptor)) + lines += message_header_recursive(package, descriptor) lines.append('#ifdef __cplusplus') lines.append('}') @@ -1009,6 +1022,40 @@ def file_header(file_descriptor): return '\n'.join(lines) +def enum_source_recursive(descriptor, namespace=[]): + lines = [] + + nested_namespace = namespace + [descriptor.name] + for nested in descriptor.nested_type: + lines += enum_source_recursive(nested, nested_namespace) + + for enum in descriptor.enum_type: + lines += enum_source(enum, nested_namespace) + + return lines + +def message_open_function_recursive(package, descriptor, namespace=[]): + lines = [] + + nested_namespace = namespace + [descriptor.name] + for nested in descriptor.nested_type: + lines += message_open_function_recursive(package, nested, nested_namespace) + + lines.append('%s(L);' % message_open_function_name(package, descriptor.name, namespace)) + + return lines + +def message_source_recursive(package, descriptor, namespace=[]): + lines = [] + + nested_namespace = namespace + [descriptor.name] + for nested in descriptor.nested_type: + lines += message_source_recursive(package, nested, nested_namespace) + + lines += message_source(package, descriptor, namespace) + + return lines + def file_source(file_descriptor): '''Obtains the source code for a FileDescriptor instance''' @@ -1045,8 +1092,10 @@ def file_source(file_descriptor): '}', ]) + for descriptor in file_descriptor.message_type: + enum_source_recursive(descriptor) for descriptor in file_descriptor.enum_type: - lines.extend(enum_source(descriptor)) + lines += enum_source(descriptor) lines.extend([ # don't need main table on stack any more @@ -1058,14 +1107,14 @@ def file_source(file_descriptor): ]) for descriptor in file_descriptor.message_type: - lines.append('%s(L);' % message_open_function_name(package, descriptor.name)) + lines += message_open_function_recursive(package, descriptor) lines.append('return 1;') lines.append('}') lines.append('\n') for descriptor in file_descriptor.message_type: - lines.extend(message_source(package, descriptor)) + lines += message_source_recursive(package, descriptor) # perform some hacky pretty-printing formatted = [] @@ -1085,4 +1134,3 @@ def file_source(file_descriptor): formatted.append((' ' * indent) + line) return '\n'.join(formatted) - diff --git a/protoc-gen-lua b/protoc-gen-lua index 08a4a79..b08277c 100755 --- a/protoc-gen-lua +++ b/protoc-gen-lua @@ -82,4 +82,3 @@ f.content = lua_protobuf_source() stdout.write(response.SerializeToString()) exit(0) -