Source code for pbtools.parser

import os
import re
import textparser
from textparser import Sequence
from textparser import ZeroOrMore
from textparser import ZeroOrMoreDict
from textparser import choice
from textparser import Optional
from textparser import Forward
from textparser import Tag
from textparser import DelimitedList
from textparser import ParseError
from .errors import Error


# Names of the proto3 scalar value types. A field whose type is found
# in this list is given the type kind 'scalar-value-type' and is
# skipped when field types are resolved against messages and enums.
SCALAR_VALUE_TYPES = [
    'int32',
    'int64',
    'sint32',
    'sint64',
    'uint32',
    'uint64',
    'fixed32',
    'fixed64',
    'sfixed32',
    'sfixed64',
    'float',
    'double',
    'bool',
    'string',
    'bytes'
]


class InternalError(Exception):
    """Raised when the loader meets a parse tree item kind it does not
    know how to handle.

    """


def canonical(value):
    """Return `value` with every character that is not an ASCII letter
    or digit replaced by an underscore.

    """

    not_alphanumeric = re.compile(r'[^a-zA-Z0-9]')

    return not_alphanumeric.sub('_', value)


def camel_to_snake_case(value):
    """Convert `value` to a lower case, underscore separated name.

    The value is first made canonical (anything but ASCII letters and
    digits becomes '_'). Underscores are then inserted at word
    boundaries, runs of underscores are collapsed into one, and the
    result is lowercased.

    """

    substitutions = [
        # Split before a capitalized word: 'fooBar' -> 'foo_Bar'.
        (r'(.)([A-Z][a-z]+)', r'\1_\2'),
        # Collapse runs of underscores.
        (r'(_+)', '_'),
        # Split between a lower case letter or digit and an upper case
        # letter: 'fooB' -> 'foo_B'.
        (r'([a-z0-9])([A-Z])', r'\1_\2')
    ]
    result = canonical(value)

    for pattern, replacement in substitutions:
        result = re.sub(pattern, replacement, result)

    return result.lower()


def format_parse_error(string, offset):
    """Return the text of a ``ParseError`` created for `offset` in
    `string`.

    """

    error = ParseError(string, offset)

    return str(error)


def ignore_comments(string):
    """Ignore comments in given string by blanking them out. This
    reduces the parsing time by roughly a factor of two.

    Every comment character is replaced with a space, except for
    newlines inside ``/* ... */`` comments, which are kept. The
    returned string therefore has the same length and the same line
    breaks as `string`, so offsets and line numbers in later parse
    errors still refer to the original text.

    Raises :class:`Error` if a multi-line comment is never closed, or
    if ``/*`` is found inside a multi-line comment.

    NOTE: String literals are not recognized here, so ``//`` or
    ``/*`` inside a quoted string is treated as the start of a
    comment.

    """

    # finditer() yields the matches from left to right, so the markers
    # are already ordered by offset.
    markers = [
        (mo.start(), mo.group(0))
        for mo in re.finditer(r'(/\*|\*/|//|\n)', string)
    ]

    in_single_line_comment = False
    in_multi_line_comment = False
    start_offset = 0
    non_comment_offset = 0
    chunks = []

    for offset, kind in markers:
        if in_single_line_comment:
            if kind == '\n':
                # The newline itself is not part of the comment.
                in_single_line_comment = False
                chunks.append(' ' * (offset - start_offset))
                non_comment_offset = offset
        elif in_multi_line_comment:
            if kind == '/*':
                raise Error(format_parse_error(string, start_offset))
            elif kind == '*/':
                in_multi_line_comment = False
                # Include the closing '*/' in the comment.
                offset += 2
                # Keep the line breaks so that line numbers after the
                # comment are unchanged.
                chunks.append(
                    re.sub(r'[^\n]', ' ', string[start_offset:offset]))
                non_comment_offset = offset
        elif kind == '//':
            in_single_line_comment = True
            start_offset = offset
            chunks.append(string[non_comment_offset:start_offset])
        elif kind == '/*':
            in_multi_line_comment = True
            start_offset = offset
            chunks.append(string[non_comment_offset:start_offset])

    if in_single_line_comment:
        # Single line comment on the last line, without a newline.
        offset = len(string)
        chunks.append(' ' * (offset - start_offset))
        non_comment_offset = offset

    if in_multi_line_comment:
        raise Error(format_parse_error(string, start_offset))

    chunks.append(string[non_comment_offset:])

    return ''.join(chunks)


class Parser(textparser.Parser):
    """Tokenizer and grammar for proto3-files, built on
    ``textparser.Parser``.

    The three methods below are hooks overridden from the base class;
    how and when they are called is decided by textparser (see its
    documentation).

    """

    # Reserved words of the grammar. They are returned by keywords()
    # and are also accepted wherever the grammar expects an identifier
    # (see ``ident`` in grammar()).
    KEYWORDS = [
        'syntax',
        'package',
        'message',
        'repeated',
        'enum',
        'service',
        'rpc',
        'returns',
        'stream',
        'import',
        'oneof',
        'option',
        'true',
        'false',
        'map',
        'reserved',
        'max',
        'to',
        'weak',
        'public',
        'optional'
    ]

    def token_specs(self):
        """Return the token specifications.

        Each entry is either ``(kind, regex)`` or ``(kind, name,
        regex)``. The entries are presumably tried in the order given
        (TODO confirm against textparser), which is why PROTO3 is
        listed before the more general STRING and why the catch-all
        MISMATCH is last.

        """

        return [
            ('SKIP',          r'[ \r\n\t]+|//.*?\n'),
            ('IDENT',         r'[a-zA-Z]\w*'),
            ('INT',           r'-?(0[xX][a-fA-F0-9]+|[0-9]+)'),
            ('PROTO3',        r'"proto3"'),
            ('STRING',        r'"(\\"|[^"])*?"'),
            ('DOT',      '.', r'\.'),
            ('COMMA',    ',', r','),
            ('EQ',       '=', r'='),
            ('SCOLON',   ';', r';'),
            ('LBRACE',   '{', r'{'),
            ('RBRACE',   '}', r'}'),
            ('LPAREN',   '(', r'\('),
            ('RPAREN',   ')', r'\)'),
            ('LBRACK',   '[', r'\['),
            ('RBRACK',   ']', r'\]'),
            ('LT',       '<', r'<'),
            ('GT',       '>', r'>'),
            ('MISMATCH',      r'.')
        ]

    def keywords(self):
        """Return the reserved words in KEYWORDS as a set."""

        return set(self.KEYWORDS)

    def grammar(self):
        """Return the grammar of a proto3-file.

        The top level rule is the ``syntax`` statement followed by any
        number of imports, packages, options, top level definitions
        (message, enum, service) and empty statements. The loader
        functions and classes in this module index into the resulting
        parse tree by position, so the order of the elements in each
        ``Sequence`` must not be changed without updating them.

        """

        # Keywords are valid identifiers as well.
        ident = choice('IDENT', *self.KEYWORDS)
        full_ident = DelimitedList(ident, '.')
        empty_statement = ';'
        # A type name with an optional leading dot.
        message_type = Sequence(Optional('.'), full_ident)
        # Tagged so that the kind of constant can be told from the
        # parse tree (see Option).
        constant = choice(Tag('bool', choice('true', 'false')),
                          Tag('ident', full_ident),
                          Tag('string', 'STRING'),
                          Tag('integer', 'INT'))
        option_name = Sequence(choice(ident,
                                      Sequence('(', full_ident, ')')),
                               ZeroOrMore(Sequence('.', ident)))
        # Optional field options within brackets.
        options = Optional(Sequence('[',
                                    DelimitedList(
                                        Sequence(option_name, '=', constant)),
                                    ']'))

        # Import.
        #
        # The terminating ';' is not part of this sequence. It is
        # matched as an empty statement by the top level rule.
        import_ = Sequence('import',
                           Optional(choice('weak', 'public')),
                           'STRING')

        # Package.
        package = Sequence('package', full_ident, ';')

        # Option.
        option = Sequence('option', option_name, '=', constant, ';')

        # Enum.
        enum_field = Sequence(ident, '=', 'INT', options, ';')
        enum = Sequence('enum',
                        ident,
                        '{',
                        ZeroOrMore(choice(enum_field, empty_statement)),
                        '}')

        # Oneof.
        oneof_field = Sequence(message_type, ident, '=', 'INT', ';')
        oneof = Sequence('oneof',
                         ident,
                         '{',
                         ZeroOrMore(choice(oneof_field, empty_statement)),
                         '}')

        # Map.
        map_field = Sequence('map', '<', ident, ',', message_type, '>',
                             ident, '=', 'INT', options, ';')

        # Reserved.
        field_number_range = Sequence('INT',
                                      Optional(Sequence('to',
                                                        choice('INT', 'max'))))
        reserved = Sequence('reserved', choice(DelimitedList(field_number_range),
                                               DelimitedList('STRING')))

        # Message.
        field = Sequence(Optional(choice('repeated', 'optional')),
                         message_type, ident, '=', 'INT', options, ';')
        # Forward declared as messages can be nested.
        message = Forward()
        # Only plain fields are tagged. All other alternatives start
        # with a keyword (or are ';') that identifies them.
        message <<= Sequence('message',
                             ident,
                             '{',
                             ZeroOrMore(choice(Tag('field', field),
                                               enum,
                                               message,
                                               oneof,
                                               map_field,
                                               reserved,
                                               empty_statement)),
                             '}')

        # Service.
        rpc_name = ident
        rpc = Sequence('rpc',
                       rpc_name,
                       '(', Optional('stream'), message_type, ')',
                       'returns',
                       '(', Optional('stream'), message_type, ')',
                       choice(Sequence('{',
                                       ZeroOrMore(option),
                                       '}'),
                              ';'))
        service = Sequence('service',
                           ident,
                           '{',
                           ZeroOrMore(choice(option, rpc, empty_statement)),
                           '}')

        # Proto3-file.
        top_level_def = choice(message, enum, service)
        syntax = Sequence('syntax', '=', 'PROTO3', ';')
        # The statements are collected in a dictionary, which the
        # load_*() functions in this module read with keys such as
        # 'message', 'enum' and 'import'.
        proto = Sequence(syntax,
                         ZeroOrMoreDict(choice(import_,
                                               package,
                                               option,
                                               top_level_def,
                                               empty_statement)))

        return proto


class Field:
    """Base class of message and oneof fields.

    `type` is the field type as a list of name components. The last
    component becomes the type name and the leading components the
    namespace.

    """

    def __init__(self, type, name, field_number):
        namespace = type[:-1]

        self.name = name
        self.field_number = field_number
        self.type = type[-1]
        self.namespace = namespace
        self.package = '.'.join(namespace)
        # Not known at construction.
        self.type_kind = None

    @property
    def full_type(self):
        """Namespace and type name joined with dots."""

        components = self.namespace + [self.type]

        return '.'.join(components)

    @property
    def full_type_snake_case(self):
        """The full type in snake case."""

        full_type = self.full_type

        return camel_to_snake_case(full_type)

    @property
    def name_snake_case(self):
        """The field name in snake case."""

        name = self.name

        return camel_to_snake_case(name)


class Option:
    """An option statement, ``option <name> = <constant>;``.

    `tokens` is the parse tree of the statement. ``name`` is the
    option name as written, with a parenthesized full identifier
    rendered as ``(a.b)`` and any trailing ``.ident`` parts appended.
    ``kind`` is one of 'string', 'bool', 'ident' and 'integer', and
    ``value`` is the constant converted accordingly: the string
    without its quotes, a bool, the dotted identifier, or an int.

    """

    def __init__(self, tokens):
        # A string cannot start with '(' as identifiers begin with a
        # letter, so this tells the two option name forms apart.
        if tokens[1][0][0] == '(':
            full_ident = '.'.join(tokens[1][0][1])
            self.name = f'({full_ident})'
        else:
            self.name = tokens[1][0]

        self.name += ''.join([f'.{ident}' for _, ident in tokens[1][1]])

        kind, value = tokens[3]

        if kind == 'string':
            # Strip the surrounding quotes.
            value = value[1:-1]
        elif kind == 'bool':
            value = (value == 'true')
        elif kind == 'ident':
            # The identifier is a list of components; keep them all.
            value = '.'.join(value)
        elif kind == 'integer':
            # The value is the complete INT token text, which is
            # either decimal or hexadecimal with a '0x' prefix.
            if 'x' in value.lower():
                value = int(value, 16)
            else:
                value = int(value)

        self.kind = kind
        self.value = value


class EnumField:
    """A single ``<name> = <number>;`` entry of an enum."""

    def __init__(self, tokens):
        name = tokens[0]
        number = tokens[2]

        self.name = name
        self.field_number = int(number)

    @property
    def name_snake_case(self):
        """The field name in snake case."""

        name = self.name

        return camel_to_snake_case(name)


class Enum:
    """An enum.

    `tokens` is the parse tree of the enum definition and `namespace`
    the list of names (package and enclosing messages) it is defined
    in. ``fields`` is a list of :class:`EnumField` objects.

    """

    def __init__(self, tokens, namespace):
        self.name = tokens[1]
        self.namespace = namespace
        # The grammar allows empty statements (';') in the enum body.
        # They carry no information and are not enum fields.
        self.fields = [
            EnumField(item)
            for item in tokens[3]
            if item != ';'
        ]

    @property
    def namespace_snake_case(self):
        """The namespace joined with dots, in snake case."""

        return camel_to_snake_case('.'.join(self.namespace))

    @property
    def full_name(self):
        """Namespace and name joined with dots."""

        return '.'.join(self.namespace + [self.name])

    @property
    def full_name_snake_case(self):
        """The full name in snake case."""

        return camel_to_snake_case(self.full_name)


class OneofField(Field):
    """A field in a oneof, created from its parse tree."""

    def __init__(self, tokens):
        field_type = load_message_type(tokens[0])
        name = tokens[1]
        field_number = int(tokens[3])

        super().__init__(field_type, name, field_number)


class Oneof:
    """A oneof.

    `tokens` is the parse tree of the oneof definition and `namespace`
    the list of names (package and enclosing messages) it is defined
    in. ``fields`` is a list of :class:`OneofField` objects.

    """

    def __init__(self, tokens, namespace):
        self.name = tokens[1]
        self.namespace = namespace
        # The grammar allows empty statements (';') in the oneof body.
        # They carry no information and are not fields.
        self.fields = [
            OneofField(item)
            for item in tokens[3]
            if item != ';'
        ]

    @property
    def full_name(self):
        """Namespace and name joined with dots."""

        return '.'.join(self.namespace + [self.name])

    @property
    def full_name_snake_case(self):
        """The full name in snake case."""

        return camel_to_snake_case(self.full_name)

    @property
    def name_snake_case(self):
        """The name in snake case."""

        return camel_to_snake_case(self.name)


class MessageField(Field):
    """A field in a message, with its ``repeated`` and ``optional``
    flags.

    """

    def __init__(self, type, name, field_number, repeated, optional):
        super().__init__(type, name, field_number)
        self.optional = optional
        self.repeated = repeated

    @classmethod
    def from_field_tokens(cls, tokens):
        """Create a message field from the parse tree of a field."""

        # The first item holds the field label, if any.
        labels = tokens[0] or []
        field_type = load_message_type(tokens[1])
        name = tokens[2]
        field_number = int(tokens[4])

        return cls(field_type,
                   name,
                   field_number,
                   'repeated' in labels,
                   'optional' in labels)


class Message:
    """A message.

    `tokens` is the parse tree of the message definition and
    `namespace` the list of names (package and enclosing messages) it
    is defined in. Fields, enums, nested messages and oneofs are
    loaded into ``fields``, ``enums``, ``messages`` and ``oneofs``.

    Raises :class:`InternalError` if the message body contains an
    item of an unknown kind.

    """

    def __init__(self, tokens, namespace):
        self.name = tokens[1]
        self.fields = []
        self.enums = []
        self.messages = []
        self.oneofs = []
        self.namespace = namespace

        for item in tokens[3]:
            # Plain fields are tagged with 'field'. All other items
            # start with their keyword, or are a lone ';'.
            kind = item[0]
            sub_namespace = namespace + [self.name]

            if kind == 'field':
                self.fields.append(MessageField.from_field_tokens(item[1]))
            elif kind == 'enum':
                self.enums.append(Enum(item, sub_namespace))
            elif kind == 'message':
                self.messages.append(Message(item, sub_namespace))
            elif kind == 'oneof':
                self.oneofs.append(Oneof(item, sub_namespace))
            elif kind == 'map':
                self._load_map(item, sub_namespace)
            elif kind in ['reserved', ';']:
                pass
            else:
                raise InternalError(kind)

    def _load_map(self, tokens, sub_namespace):
        """Load given map field as a repeated field of a generated
        nested message with the two fields ``key`` (1) and ``value``
        (2).

        """

        key_type = tokens[2]
        # Only the last component of the value type name is used.
        value_type = load_message_type(tokens[4])[-1]
        name = tokens[6]
        field_number = int(tokens[8])
        map_type = f'PbtoolsMap{name.title()}'
        # Same shape as the parse tree of a message with two fields.
        map_message_tokens = [
            'message',
            map_type,
            '{',
            [
                ('field', [[], [[], [key_type]], 'key', '=', '1', [], ';']),
                ('field', [[], [[], [value_type]], 'value', '=', '2', [], ';'])
            ],
            '}'
        ]
        self.messages.append(Message(map_message_tokens, sub_namespace))
        self.fields.append(MessageField([map_type],
                                        name,
                                        field_number,
                                        True,
                                        False))

    @property
    def repeated_fields(self):
        """A list of all repeated fields."""

        return [field for field in self.fields if field.repeated]

    @property
    def type_names(self):
        """Names of the enums and messages defined directly in this
        message.

        """

        type_names = [enum.name for enum in self.enums]
        type_names += [message.name for message in self.messages]

        return type_names

    @property
    def full_name(self):
        """Namespace and name joined with dots."""

        return '.'.join(self.namespace + [self.name])

    @property
    def full_name_snake_case(self):
        """The full name in snake case."""

        return camel_to_snake_case(self.full_name)

    def used_types(self):
        """Return the full names of all non-scalar types used by
        fields and oneof fields in this message and its nested
        messages. Each name is given once, in no particular order.

        """

        used_types = []

        for field in self.fields:
            if field.full_type in SCALAR_VALUE_TYPES:
                continue

            used_types.append(field.full_type)

        for oneof in self.oneofs:
            for field in oneof.fields:
                if field.full_type in SCALAR_VALUE_TYPES:
                    continue

                used_types.append(field.full_type)

        for message in self.messages:
            used_types += message.used_types()

        return list(set(used_types))


class Rpc:
    """An rpc of a service.

    The request and response types are the first component of the
    type names given in the rpc definition, and the stream flags tell
    if the ``stream`` keyword was given for them.

    """

    def __init__(self, tokens):
        self.name = tokens[1]
        self.request_type = tokens[4][1][0]
        self.request_stream = bool(tokens[3])
        self.response_type = tokens[9][1][0]
        self.response_stream = bool(tokens[8])


class Service:
    """A service and its rpcs.

    Raises :class:`InternalError` if the service body contains an
    item of an unknown kind.

    """

    def __init__(self, tokens):
        self.name = tokens[1]
        self.rpcs = []

        for item in tokens[3]:
            kind = item[0]

            if kind == 'rpc':
                self.rpcs.append(Rpc(item))
            elif kind in ['option', ';']:
                # The grammar allows options in a service. They are
                # accepted but not stored.
                pass
            else:
                raise InternalError(kind)


class ImportedProto:
    """A proto3-file imported by another proto3-file.

    The imported file is located using `import_paths`, read and
    parsed. Only its package and the names of its top level enums and
    messages are kept.

    """

    def __init__(self, tokens, import_paths):
        self.path = tokens[2].strip('"')
        self.abspath = find_file(self.path, import_paths)

        with open(self.abspath, encoding='utf-8') as fin:
            tree = Parser().parse(ignore_comments(fin.read()))

        # ToDo: Public imports should be found as well.
        self.package = load_package(tree)
        self.enums = [
            enum_tokens[1]
            for enum_tokens in tree[1].get('enum', [])
        ]
        self.messages = [
            message_tokens[1]
            for message_tokens in tree[1].get('message', [])
        ]

    @property
    def type_names(self):
        """Names of all top level enums and messages in the file."""

        return self.enums + self.messages


class Proto:
    """A proto3-file.

    :func:`~pbtools.parse_file()` returns an instance of this class.

    `tree` is the parse tree of the file, `abspath` the absolute path
    of the file and `import_paths` the list of paths to search for
    imported files in.

    """

    def __init__(self, tree, abspath, import_paths):
        self.abspath = abspath
        self._package = load_package(tree)
        self._imports = load_imports(tree, import_paths)
        namespace = self.namespace_base()
        self._options = load_options(tree)
        self._messages = load_messages(tree, namespace)
        self._services = load_services(tree)
        self._enums = load_enums(tree, namespace)
        # Messages enclosing the field currently being resolved,
        # outermost first.
        self.messages_stack = []
        self.resolve_messages_fields_types()
        self.resolve_messages_fields_type_kinds()
        self._messages = self.sort_messages_by_usage(self._messages)

    @property
    def package(self):
        """Package name, or ``None`` if missing.

        """

        return self._package

    @property
    def imports(self):
        """A list of all imports.

        """

        return self._imports

    @property
    def options(self):
        """A list of all options.

        """

        return self._options

    @property
    def services(self):
        """A list of all services.

        """

        return self._services

    @property
    def messages(self):
        """A list of all messages.

        """

        return self._messages

    @property
    def enums(self):
        """A list of all enums.

        """

        return self._enums

    @property
    def type_names(self):
        """Names of all top level enums and messages in the file."""

        type_names = [enum.name for enum in self._enums]
        type_names += [message.name for message in self._messages]

        return type_names

    def namespace_base(self):
        """Return the namespace of top level definitions: the package
        as a single item list, or an empty list if there is no
        package.

        """

        if self._package is None:
            return []
        else:
            return [self._package]

    def resolve_messages_fields_types(self):
        """Resolve namespace and package of all fields in all
        messages.

        """

        for message in self._messages:
            self.resolve_message_fields_types(message)

    def resolve_message_fields_types(self, message):
        """Resolve namespace and package of all non-scalar fields in
        given message, its nested messages and its oneofs.

        """

        self.messages_stack.append(message)

        for field in message.fields:
            if field.type in SCALAR_VALUE_TYPES:
                continue

            self.resolve_field_type(field)

        for sub_message in message.messages:
            self.resolve_message_fields_types(sub_message)

        for oneof in message.oneofs:
            for field in oneof.fields:
                if field.type in SCALAR_VALUE_TYPES:
                    continue

                self.resolve_field_type(field)

        self.messages_stack.pop()

    def resolve_field_type(self, field):
        """Set namespace and package of given field, unless the field
        type was given with a namespace.

        The type name is searched for in the enclosing messages,
        innermost first, then among the top level types of this file,
        and last in imported files with the same package as this
        file. If not found, the namespace is set to an empty list and
        the package to ``None``.

        """

        if field.namespace:
            return

        for message in reversed(self.messages_stack):
            if field.type in message.type_names:
                namespace = message.namespace + [message.name]
                package = self._package
                break
        else:
            if field.type in self.type_names:
                namespace = self.namespace_base()
                package = self._package
            else:
                for imported in self._imports:
                    if imported.package == self._package:
                        if field.type in imported.type_names:
                            namespace = self.namespace_base()
                            package = self._package
                            break
                else:
                    namespace = []
                    package = None

        field.namespace = namespace
        field.package = package

    def resolve_messages_fields_type_kinds(self):
        """Resolve the type kind of all fields in all messages."""

        for message in self._messages:
            self.resolve_message_fields_type_kinds(message)

    def resolve_message_fields_type_kinds(self, message):
        """Resolve the type kind of all fields in given message, its
        nested messages and its oneofs.

        """

        self.messages_stack.append(message)

        for field in message.fields:
            self.resolve_field_type_kind(field)

        for sub_message in message.messages:
            self.resolve_message_fields_type_kinds(sub_message)

        for oneof in message.oneofs:
            for field in oneof.fields:
                self.resolve_field_type_kind(field)

        self.messages_stack.pop()

    def resolve_field_type_kind(self, field):
        """Set the type kind of given field to 'scalar-value-type',
        'enum' or 'message'.

        """

        if field.type in SCALAR_VALUE_TYPES:
            field.type_kind = 'scalar-value-type'
        elif self.is_field_enum(field):
            field.type_kind = 'enum'
        else:
            field.type_kind = 'message'

    def is_field_enum(self, field):
        """Return ``True`` if the type of given field is an enum and
        ``False`` otherwise.

        Imported files with the same package as the field are
        searched first, then this file if the field belongs to its
        package. Raises :class:`Error` if the type was found in no
        import and the field does not belong to the package of this
        file.

        """

        for imported in self._imports:
            if field.package == imported.package:
                if field.type in imported.enums:
                    return True
                elif field.type in imported.messages:
                    return False

        if field.package == self._package:
            # Look up relative to the top level of this file.
            offset = len(self.namespace_base())
            found = self.lookup_type(field.namespace[offset:] + [field.type],
                                     self._enums,
                                     self._messages)

            return isinstance(found, Enum)

        raise Error(f"'{field.type}' is not defined.")

    def lookup_type(self, path, enums, messages):
        """Return the enum or message found by following `path`, a
        list of names, from given enums and messages. All names but
        the last must be messages. Returns ``None`` if not found.

        """

        name = path[0]
        rest = path[1:]

        if rest:
            for message in messages:
                if message.name == name:
                    return self.lookup_type(rest,
                                            message.enums,
                                            message.messages)
        else:
            for enum in enums:
                if enum.name == name:
                    return enum

            for message in messages:
                if message.name == name:
                    return message

    def sort_messages_by_usage(self, messages):
        """Return given messages ordered so that a message comes
        before the messages using it, as far as the insertion order
        below allows. Nested messages are sorted the same way, in
        place.

        """

        reversed_sorted_messages = []

        for message in messages:
            message.messages = self.sort_messages_by_usage(message.messages)

            # Insert first in the reversed list if there are no types
            # using this type.
            insert_index = 0

            for i, sorted_message in enumerate(reversed_sorted_messages, 1):
                if message.full_name in sorted_message.used_types():
                    if i > insert_index:
                        insert_index = i

            reversed_sorted_messages.insert(insert_index, message)

        return list(reversed(reversed_sorted_messages))


def load_message_type(tokens):
    """Return the list of name components of given message type parse
    tree. The optional leading dot is dropped.

    """

    return tokens[1]


def load_package(tokens):
    """Return the package name of given file parse tree, or ``None``
    if the file has no package statement.

    """

    try:
        return '.'.join(tokens[1]['package'][0][1])
    except KeyError:
        return None


def load_imports(tokens, import_paths):
    """Return a list of :class:`ImportedProto` objects, one per import
    statement in given file parse tree.

    """

    return [
        ImportedProto(imported, import_paths)
        for imported in tokens[1].get('import', [])
    ]


def load_options(tokens):
    """Return a list of :class:`Option` objects, one per top level
    option statement in given file parse tree.

    """

    return [
        Option(option)
        for option in tokens[1].get('option', [])
    ]


def load_messages(tokens, namespace):
    """Return a list of :class:`Message` objects, one per top level
    message in given file parse tree.

    """

    return [
        Message(message, namespace)
        for message in tokens[1].get('message', [])
    ]


def load_services(tokens):
    """Return a list of :class:`Service` objects, one per service in
    given file parse tree.

    """

    return [
        Service(service)
        for service in tokens[1].get('service', [])
    ]


def load_enums(tokens, namespace):
    """Return a list of :class:`Enum` objects, one per top level enum
    in given file parse tree.

    """

    return [
        Enum(enum, namespace)
        for enum in tokens[1].get('enum', [])
    ]


def find_file(filename, import_paths):
    """Return the absolute path of `filename`.

    The file is searched for in each of `import_paths`, in order. If
    it is found in none of them, `filename` itself is used. Raises
    :class:`Error` if the resulting path does not exist.

    """

    for import_path in import_paths:
        filepath = os.path.join(import_path, filename)

        if os.path.exists(filepath):
            break
    else:
        filepath = filename

    filepath = os.path.abspath(filepath)

    if not os.path.exists(filepath):
        raise Error(f"'{filename}' not found in import path.")

    return filepath


def parse_file(filename, import_paths=None):
    """Parse given proto3-file `filename` and its imports. Returns a
    :class:`~pbtools.parser.Proto` object.

    `import_paths` is a list of paths where to search for imported
    files.

    """

    if import_paths is None:
        import_paths = []

    with open(filename, 'r', encoding='utf-8') as fin:
        text = fin.read()

    tree = Parser().parse(ignore_comments(text))

    # NOTE(review): The file is opened as given, while the recorded
    # absolute path is looked up in `import_paths` first. The two can
    # refer to different files if a file with the same relative name
    # exists in an import path - confirm that this is intended.
    abspath = find_file(filename, import_paths)

    return Proto(tree, abspath, import_paths)