Files
safeline/extract/descpb_to_proto.py
2025-03-11 19:01:30 +08:00

241 lines
8.9 KiB
Python

#!/usr/bin/python3
#-*- encoding: Utf-8 -*-
from google.protobuf.descriptor_pb2 import DescriptorProto, FieldDescriptorProto
from collections import OrderedDict
from itertools import groupby
"""
This script converts back a FileDescriptor structure to a readable .proto file.
There is already a function in the standard C++ library that does this [1], but
- It is not accessible through the Python binding
- This implementation has a few output readability improvements, i.e
-- Declaring enums/messages after first use rather than at top of block
-- Not always using full names when referencing messages types
-- Smaller aesthetic differences (number of tabs, line jumps)
For reference of the FileDescriptor structure, see [2].
Other (less complete) implementations of this are [3] or [4].
[1] https://github.com/google/protobuf/blob/5a76e/src/google/protobuf/descriptor.cc#L2242
[2] https://github.com/google/protobuf/blob/bb77c/src/google/protobuf/descriptor.proto#L59
[3] https://github.com/fry/d3/blob/master/decompile_protobins.py
[4] https://github.com/sarum9in/bunsan_binlogs_python/blob/master/src/python/source.py
"""
# Indentation unit used for every nesting level of the generated .proto text.
INDENT = '    '
def descpb_to_proto(desc):
    """Render a file descriptor as the source text of a .proto file.

    :param desc: the file descriptor to convert.
    :returns: a ``(name, text)`` tuple. ``name`` is ``desc.name`` with every
        ``..`` removed and leading/trailing dots and slashes stripped.
        ``text`` is the generated source, terminated by a newline.
    """
    header = 'syntax = "%s";\n\n' % (desc.syntax or 'proto2')

    # scopes[0] is the scope of the file itself. The other entries are
    # derived from the directory part of each imported file path.
    file_scope = ''
    if desc.package:
        header += 'package %s;\n\n' % desc.package
        file_scope += '.' + desc.package
    scopes = [file_scope]

    for position, dependency in enumerate(desc.dependency):
        modifiers = ''
        if position in desc.public_dependency:
            modifiers += ' public'
        if position in desc.weak_dependency:
            modifiers += ' weak'
        header += 'import%s "%s";\n' % (modifiers, dependency)
        directory = dependency.rsplit('/', 1)[0]
        scopes.append('.' + directory.replace('/', '.'))

    # Make sure a blank line separates the header from the definitions.
    if header[-2] != '\n':
        header += '\n'

    body = parse_msg(desc, scopes, desc.syntax).strip('\n')
    safe_name = desc.name.replace('..', '').strip('.\\/')
    return safe_name, header + body + '\n'
def parse_msg(desc, scopes, syntax):
    """Recursively render a message, or the top level of a file, as .proto text.

    :param desc: a DescriptorProto for a message. Any other descriptor (at
        the top level, the file descriptor passed by descpb_to_proto) is
        treated as a file: its ``message_type`` and ``service`` are rendered.
    :param scopes: dotted scope names handed to min_name; scopes[0] is the
        scope enclosing ``desc``.
    :param syntax: the file's syntax string, forwarded to fmt_field.
    :returns: the rendered text. What it contains depends on ``desc``:
        - a message: a ``message`` block;
        - a map-entry message: a `` map<K, V>`` fragment that fmt_field
          splices into the field declaration;
        - a file: the file options followed by all top-level definitions.
    """
    out = ''
    is_msg = isinstance(desc, DescriptorProto)
    if is_msg:
        # Copy so the caller's scope list is not modified, then descend
        # into this message's scope.
        scopes = list(scopes)
        scopes[0] += '.' + desc.name
    # Rendered nested messages and enums, keyed by name. fmt_field pops the
    # entries it inlines next to the field that uses them, and stores oneof
    # members under '_oneof_<index>' keys; whatever remains at the end is
    # appended after the fields.
    blocks = OrderedDict()
    for nested_msg in (desc.nested_type if is_msg else desc.message_type):
        blocks[nested_msg.name] = parse_msg(nested_msg, scopes, syntax)
    for enum in desc.enum_type:
        out2 = ''
        for val in enum.value:
            out2 += '%s = %s;\n' % (val.name, fmt_value(val.number, val.options))
        # When no two values share a number, allow_alias is redundant and is
        # dropped. Note: this mutates the enum options of the input descriptor.
        if len(set(i.number for i in enum.value)) == len(enum.value):
            enum.options.ClearField('allow_alias')
        blocks[enum.name] = wrap_block('enum', out2, enum)
    # Map entry messages are not printed as messages; the caller receives a
    # " map<K, V>" fragment (key and value types of the entry's fields).
    if is_msg and desc.options.map_entry:
        return ' map<%s>' % ', '.join(min_name(i.type_name, scopes) \
            if i.type_name else types[i.type] \
                for i in desc.field)
    if is_msg:
        for field in desc.field:
            out += fmt_field(field, scopes, blocks, syntax)
        # Emit each oneof from the member lines fmt_field accumulated.
        for index, oneof in enumerate(desc.oneof_decl):
            out += wrap_block('oneof', blocks.pop('_oneof_%d' % index), oneof)
        out += fmt_ranges('extensions', desc.extension_range)
        out += fmt_ranges('reserved', [*desc.reserved_range, *desc.reserved_name])
    else:
        # File level: services only exist here.
        for service in desc.service:
            out2 = ''
            for method in service.method:
                out2 += 'rpc %s(%s%s) returns (%s%s);\n' % (method.name,
                    'stream ' * method.client_streaming,
                    min_name(method.input_type, scopes),
                    'stream ' * method.server_streaming,
                    min_name(method.output_type, scopes))
            out += wrap_block('service', out2, service)
    # Group extensions by the message they extend, one "extend" block each.
    extendees = OrderedDict()
    for ext in desc.extension:
        extendees.setdefault(ext.extendee, '')
        extendees[ext.extendee] += fmt_field(ext, scopes, blocks, syntax, True)
    # Definitions not inlined by any field; the last character (a trailing
    # newline for wrap_block output) is dropped.
    for name, value in blocks.items():
        out += value[:-1]
    for name, fields in extendees.items():
        out += wrap_block('extend', fields, name=min_name(name, scopes))
    # For a file, type_ is '' so wrap_block only prepends the file options.
    out = wrap_block('message' * is_msg, out, desc)
    return out
def fmt_value(val, options=None, desc=None, optarr=()):
    """Format a value as .proto source text, optionally followed by options.

    :param val: the value to format.
        - Strings are double-quoted, with escapes applied.
        - Bools are written as ``true``/``false``.
        - Integers become the matching enum value name when ``desc`` has an
          ``enum_type`` (falling back to the number when it is unknown).
        - Anything else goes through ``str()``.
    :param options: an options message. Every set field (``ListFields()``)
        is appended inside ``[...]`` as ``name = value``.
    :param desc: descriptor of the field that holds ``val``, used for the
        enum name lookup.
    :param optarr: pre-formatted option strings (e.g. ``default = ...``)
        placed before those coming from ``options``. They are emitted even
        when ``options`` is None or empty.
    :returns: the formatted text.
    """
    if isinstance(val, str):
        # unicode_escape takes care of backslashes, control and non-ASCII
        # characters, but not of double quotes, which would otherwise close
        # the literal early and produce an invalid .proto file.
        escaped = val.encode('unicode_escape').decode('utf8').replace('"', '\\"')
        val = '"%s"' % escaped
    else:
        if isinstance(val, bool):
            val = str(val).lower()
        elif desc and desc.enum_type:
            enum_value = desc.enum_type.values_by_number.get(val)
            if enum_value is not None:
                val = enum_value.name
        val = str(val)

    opts = list(optarr)
    if options:
        for option, value in options.ListFields():
            opts.append('%s = %s' % (option.name, fmt_value(value, desc=option)))
    if opts:
        val += ' [%s]' % ', '.join(opts)
    return val
def _enum_short_names(enum_type):
    """Map each number of a protobuf enum to the lowercase second '_'-separated
    part of its name (e.g. TYPE_INT32 -> 'int32', LABEL_REPEATED -> 'repeated')."""
    return {number: full_name.split('_')[1].lower()
            for full_name, number in enum_type.items()}

# Field type number -> .proto scalar keyword, and label number -> keyword.
types = _enum_short_names(FieldDescriptorProto.Type)
labels = _enum_short_names(FieldDescriptorProto.Label)
def fmt_field(field, scopes, blocks, syntax, extend=False):
    """Render a single field or extension declaration.

    :param field: the field descriptor to render. It is not modified.
    :param scopes: scope list used by min_name to shorten type names.
    :param blocks: rendered nested messages/enums of the enclosing scope,
        keyed by name. A referenced block is popped from it and inlined right
        after the declaration. Members of a oneof are accumulated under the
        ``'_oneof_<index>'`` key instead of being returned.
    :param syntax: the file's syntax string; in proto3 a leading
        ``optional`` label is omitted.
    :param extend: True when rendering an extension inside an ``extend`` block.
    :returns: the declaration text, or '' when it was stored into a oneof
        entry of ``blocks``.
    """
    type_ = types[field.type]

    default = ''
    if field.default_value:
        if field.type == field.TYPE_STRING:
            default = ['default = %s' % fmt_value(field.default_value)]
        elif field.type == field.TYPE_BYTES:
            default = ['default = "%s"' % field.default_value]
        else:
            default_value = field.default_value
            # Guess whether it ought to be more readable as base 10 or 16,
            # based on the presence of repeated digits. The result is kept in
            # a local: writing it back into field.default_value would make a
            # second rendering of the same descriptor call int('0x...') and
            # raise ValueError.
            if ('int' in type_ or 'fixed' in type_) and \
                    int(default_value) >= 0x10000 and \
                    not any(len(list(i)) > 3 for _, i in groupby(str(default_value))):
                default_value = hex(int(default_value))
            default = ['default = %s' % default_value]

    out = ''
    in_oneof = field.HasField('oneof_index')
    if field.type_name:
        type_ = min_name(field.type_name, scopes)
        short_type = type_.split('.')[-1]
        # Inline the referenced nested definition right after its first use.
        # Maps and groups are always inlined because their definition is part
        # of the field syntax itself; a group left behind (e.g. inside a oneof
        # or an extend block) would be printed as "group X = N" with no body.
        if short_type in blocks and ((not extend and not in_oneof) or
                                     field.type == field.TYPE_GROUP or
                                     blocks[short_type].startswith(' map<')):
            out += blocks.pop(short_type)[1:]

    if out.startswith('map<'):
        line = out + ' %s = %s;\n' % (field.name, fmt_value(field.number, field.options, optarr=default))
        out = ''
    elif field.type != field.TYPE_GROUP:
        line = '%s %s %s = %s;\n' % (labels[field.label], type_, field.name,
                                     fmt_value(field.number, field.options, optarr=default))
    else:
        line = '%s group %s = %d ' % (labels[field.label], type_, field.number)
        # Keep only the "{ ... }" body of the popped "message Name { ... }" block.
        out = out.split(' ', 2)[-1]

    # Oneof members carry no label, and proto3 has no explicit "optional".
    if in_oneof or (syntax == 'proto3' and line.startswith('optional')):
        line = line.split(' ', 1)[-1]
    if out:
        line = '\n' + line

    if in_oneof:
        key = '_oneof_%d' % field.oneof_index
        blocks[key] = blocks.get(key, '') + line + out
        return ''
    return line + out
"""
Find the smallest name to refer to another message from our scopes.
For this, we take the final part of its name, and expand it until
the path both scopes don't have in common (if any) is specified; and
expand it again if there are multiple outer packages/messages in the
scopes sharing the same name, and that the first part of the obtained
partial name is one of them, leading to ambiguity.
"""
def min_name(name, scopes):
    """Return the shortest dotted reference to ``name`` usable from ``scopes``.

    :param name: fully qualified type name, e.g. ``.pkg.Outer.Inner``.
    :param scopes: dotted scope names; scopes[0] is the current scope.
    :returns: the trailing components of ``name``. Components are added
        until the name is anchored in the current scope and its first part
        does not also appear deeper inside any of the scopes. The result can
        be the whole name, including the leading dot.
    """
    remaining = name.split('.')
    here = scopes[0].split('.')
    kept = [remaining.pop()]

    while remaining:
        anchored = here[:len(remaining)] == remaining
        if anchored:
            head, depth = kept[0], len(remaining)
            clashes = any(list_rfind(scope.split('.'), head) > depth
                          for scope in scopes)
            if not clashes:
                break
        kept.insert(0, remaining.pop())

    return '.'.join(kept)
def wrap_block(type_, value, desc=None, name=None):
    """Wrap ``value`` into an indented ``type_ name { ... }`` block.

    :param type_: block keyword (``message``, ``enum``, ``oneof``, ...). When
        empty, no header, braces or indentation are added; this is used for
        the file level.
    :param value: the block's inner text.
    :param desc: optional descriptor. Its set options are emitted first, as
        ``option x = y;`` lines in ``ListFields()`` order. Its ``name`` is
        used when ``name`` is not given.
    :param name: explicit block name, overriding ``desc.name``.
    :returns: the wrapped text. Runs of blank lines are collapsed to one, and
        blank lines carry no indentation.
    """
    if desc:
        option_lines = ''.join(
            'option %s = %s;\n' % (option.name, fmt_value(optval, desc=option))
            for option, optval in desc.options.ListFields())
        value = option_lines + value

    # A single replace() leaves a run of 4+ newlines with three of them; loop
    # until at most one blank line separates consecutive items.
    while '\n\n\n' in value:
        value = value.replace('\n\n\n', '\n\n')

    if not type_:
        return value

    body = '\n'.join(INDENT + line if line else ''
                     for line in value.strip('\n').split('\n'))
    return '\n%s %s {\n%s\n}\n\n' % (type_, name or desc.name, body)
def fmt_ranges(name, ranges):
    """Render ``extensions``/``reserved`` statements.

    :param name: statement keyword (``'extensions'`` or ``'reserved'``).
    :param ranges: range objects with an inclusive ``start`` and exclusive
        ``end``, and/or plain strings (reserved field names).
    :returns: '' when ``ranges`` is empty. Otherwise the statements, preceded
        by a newline. Numeric ranges and names are emitted as two separate
        statements, because protoc rejects a ``reserved`` statement that mixes
        field numbers and field names.
    """
    numbers, names = [], []
    for range_ in ranges:
        if isinstance(range_, str):
            names.append(fmt_value(range_))
        elif range_.end - 1 > range_.start:
            # An exclusive end at or past 2**29 means "up to the max field number".
            if range_.end < 0x20000000:
                numbers.append('%d to %d' % (range_.start, range_.end - 1))
            else:
                numbers.append('%d to max' % range_.start)
        else:
            numbers.append('%d' % range_.start)

    statements = ['%s %s;\n' % (name, ', '.join(group))
                  for group in (numbers, names) if group]
    if statements:
        return '\n' + ''.join(statements)
    return ''
# Python lists have no rfind(); this provides the missing helper.
def list_rfind(x, i):
    """Return the index of the last occurrence of ``i`` in ``x``, or -1 if absent."""
    try:
        return len(x) - 1 - x[::-1].index(i)
    except ValueError:
        return -1