645 lines
25 KiB
Plaintext
645 lines
25 KiB
Plaintext
|
|
#!/usr/bin/env python3
|
||
|
|
|
||
|
|
# Copyright 2022 Google LLC
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
|
||
|
|
"""Custom mmi2grpc gRPC compiler."""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
|
||
|
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
||
|
|
|
||
|
|
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest, CodeGeneratorResponse
|
||
|
|
from google.protobuf.descriptor import (
|
||
|
|
FieldDescriptor
|
||
|
|
)
|
||
|
|
from google.protobuf.descriptor_pb2 import (
|
||
|
|
FileDescriptorProto,
|
||
|
|
EnumDescriptorProto,
|
||
|
|
DescriptorProto,
|
||
|
|
ServiceDescriptorProto,
|
||
|
|
MethodDescriptorProto,
|
||
|
|
FieldDescriptorProto,
|
||
|
|
)
|
||
|
|
|
||
|
|
# The CodeGeneratorRequest that protoc passes on stdin; read once at import
# time so the lookup helpers and the generation loop below can consult it.
_REQUEST = CodeGeneratorRequest.FromString(sys.stdin.buffer.read())
|
||
|
|
|
||
|
|
|
||
|
|
def find_type_in_file(proto_file: FileDescriptorProto, type_name: str) -> Optional[Union[DescriptorProto, EnumDescriptorProto]]:
    """Look up a top-level enum or message called `type_name` in `proto_file`.

    Enums are searched before messages; returns None when nothing matches.
    """
    declarations = (*proto_file.enum_type, *proto_file.message_type)
    return next((decl for decl in declarations if decl.name == type_name), None)
|
||
|
|
|
||
|
|
|
||
|
|
def find_type(package: str, type_name: str) -> Tuple[FileDescriptorProto, Union[DescriptorProto, EnumDescriptorProto]]:
    """Locate `type_name` inside `package` across every file of the request.

    Returns the defining file together with the descriptor; raises when the
    fully-qualified type cannot be resolved anywhere.
    """
    for candidate in _REQUEST.proto_file:
        if candidate.package != package:
            continue
        found = find_type_in_file(candidate, type_name)
        if found:
            return candidate, found
    raise Exception(f'Type {package}.{type_name} not found')
|
||
|
|
|
||
|
|
|
||
|
|
def add_import(imports: List[str], import_str: str) -> None:
    """Append `import_str` to `imports`, keeping entries unique.

    Empty strings are ignored: several callers pass '' when a field needs no
    default-value import, and a blank entry would otherwise be sorted to the
    front and surface as a stray blank line in the generated import section.
    """
    if import_str and import_str not in imports:
        imports.append(import_str)
|
||
|
|
|
||
|
|
|
||
|
|
def import_type(imports: List[str], type: str, local: Optional[FileDescriptorProto]) -> Tuple[str, Union[DescriptorProto, EnumDescriptorProto], str]:
    """Resolve a fully-qualified proto type name (`.pkg.Type`) for codegen.

    Returns `(python_name, descriptor, default_import)`: `python_name` is how
    the type is spelled from the generated module, and `default_import` is the
    extra import line needed for an enum's first value ('' otherwise).  When
    the type lives outside `local`, its `_pb2` module import is recorded into
    `imports`.
    """
    last_dot = type.rindex('.')
    package, type_name = type[1:last_dot], type[last_dot + 1:]
    file, desc = find_type(package, type_name)
    if file == local:
        # Same file: the name is directly visible, nothing to import.
        return f'{type_name}', desc, ''
    dotted = file.name.replace('.proto', '').replace('/', '.')
    split_at = dotted.rindex('.')
    module_path = dotted[:split_at]
    module_name = dotted[split_at + 1:] + '_pb2'
    add_import(imports, f'from {module_path} import {module_name}')
    dft_import = ''
    if isinstance(desc, EnumDescriptorProto):
        # Enum defaults are the first declared value, which must be imported.
        dft_import = f'from {module_path}.{module_name} import {desc.value[0].name}'
    return f'{module_name}.{type_name}', desc, dft_import
|
||
|
|
|
||
|
|
|
||
|
|
def collect_type(imports: List[str], parent: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[str, str, str]:
    """Map a proto field to `(python_type, default_expr, default_import)`.

    `default_expr` is the source text used as the keyword default in generated
    `__init__` signatures; `default_import` is an extra import line required
    by that default ('' when none).  Recurses into map entries to build
    `Dict[K, V]` types.
    """
    dft: str
    dft_import: str = ''
    if field.type == FieldDescriptor.TYPE_BYTES:
        type = 'bytes'
        dft = 'b\'\''
    elif field.type == FieldDescriptor.TYPE_STRING:
        type = 'str'
        dft = '\'\''
    elif field.type == FieldDescriptor.TYPE_BOOL:
        type = 'bool'
        dft = 'False'
    elif field.type in [
        FieldDescriptor.TYPE_FLOAT,
        FieldDescriptor.TYPE_DOUBLE
    ]:
        type = 'float'
        dft = '0.0'
    elif field.type in [
        # Every proto integer encoding collapses to a Python int.
        FieldDescriptor.TYPE_INT64,
        FieldDescriptor.TYPE_UINT64,
        FieldDescriptor.TYPE_INT32,
        FieldDescriptor.TYPE_FIXED64,
        FieldDescriptor.TYPE_FIXED32,
        FieldDescriptor.TYPE_UINT32,
        FieldDescriptor.TYPE_SFIXED32,
        FieldDescriptor.TYPE_SFIXED64,
        FieldDescriptor.TYPE_SINT32,
        FieldDescriptor.TYPE_SINT64
    ]:
        type = 'int'
        dft = '0'
    elif field.type in [FieldDescriptor.TYPE_ENUM, FieldDescriptor.TYPE_MESSAGE]:
        # A type name of the form `.pkg.Parent.Nested` means the type is
        # nested inside `parent`; the only nested types handled here are
        # synthesized map entries.
        parts = field.type_name.split(f".{parent.name}.", 2)
        if len(parts) == 2:
            type = parts[1]
            for nested_type in parent.nested_type:
                if nested_type.name == type:
                    # Proto maps are repeated map-entry messages with exactly
                    # two fields: key (index 0) and value (index 1).
                    assert nested_type.options.map_entry
                    assert field.label == FieldDescriptor.LABEL_REPEATED
                    key_type, _, _ = collect_type(imports, nested_type, nested_type.field[0], local)
                    val_type, _, _ = collect_type(imports, nested_type, nested_type.field[1], local)
                    add_import(imports, 'from typing import Dict')
                    return f'Dict[{key_type}, {val_type}]', '{}', ''
        type, desc, enum_dft = import_type(imports, field.type_name, local)
        if isinstance(desc, EnumDescriptorProto):
            # Enums default to their first declared value.
            dft_import = enum_dft
            dft = desc.value[0].name
        else:
            # Messages default to an empty instance.
            dft = f'{type}()'
    else:
        raise Exception(f'TODO: {field}')

    if field.label == FieldDescriptor.LABEL_REPEATED:
        # Repeated fields (maps were already returned above) become lists.
        add_import(imports, 'from typing import List')
        type = f'List[{type}]'
        dft = '[]'

    return type, dft, dft_import
|
||
|
|
|
||
|
|
|
||
|
|
def collect_field(imports: List[str], message: DescriptorProto, field: FieldDescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[Optional[int], str, str, str, str]:
    """Describe one message field for stub generation.

    Returns `(oneof_index, name, python_type, default_expr, default_import)`;
    `oneof_index` is None for fields that are not members of a oneof.
    """
    type, dft, dft_import = collect_type(imports, message, field, local)
    # HasField is the supported presence check for the optional oneof_index
    # field.  The previous code searched the field's text-proto representation
    # for the substring 'oneof_index', which could also match unrelated
    # content (e.g. a field literally named `oneof_index`).
    oneof_index = field.oneof_index if field.HasField('oneof_index') else None
    return oneof_index, field.name, type, dft, dft_import
|
||
|
|
|
||
|
|
|
||
|
|
def collect_message(imports: List[str], message: DescriptorProto, local: Optional[FileDescriptorProto]) -> Tuple[
    List[Tuple[str, str, str]],
    Dict[str, List[Tuple[str, str]]],
]:
    """Collect field information for `message`.

    Returns `(fields, oneof)`:
      * `fields`: `(name, python_type, default_expr)` for every field; oneof
        members are appended last as `Optional[...]` defaulting to None.
      * `oneof`: maps each oneof name to its `(name, python_type)` members.
    """
    fields: List[Tuple[str, str, str]] = []
    # Annotation made consistent with the rest of the file (`List` instead of
    # the mixed lowercase `list[...]`, which also requires Python 3.9+).
    oneof: Dict[str, List[Tuple[str, str]]] = {}

    for field in message.field:
        idx, name, type, dft, dft_import = collect_field(imports, message, field, local)
        if idx is not None:
            # Group oneof members under their declaration's name.
            oneof_name = message.oneof_decl[idx].name
            oneof.setdefault(oneof_name, []).append((name, type))
        else:
            if dft_import:
                # Only record real import lines; scalar fields produce ''
                # and an empty entry would pollute the generated imports.
                add_import(imports, dft_import)
            fields.append((name, type, dft))

    # Oneof members surface as optional keyword arguments after plain fields.
    for oneof_fields in oneof.values():
        for name, type in oneof_fields:
            add_import(imports, 'from typing import Optional')
            fields.append((name, f'Optional[{type}]', 'None'))

    return fields, oneof
|
||
|
|
|
||
|
|
|
||
|
|
def generate_enum(imports: List[str], file: FileDescriptorProto, enum: EnumDescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    """Emit the .pyi stub lines for `enum` and register its runtime shim.

    A placeholder class is inserted into the generated `_pb2.py` module so
    the stubbed name exists at runtime; the returned lines declare the enum
    class and its values for the .pyi file.
    """
    runtime_shim = CodeGeneratorResponse.File(
        name=file.name.replace('.proto', '_pb2.py'),
        insertion_point=f'module_scope',
        content=f'class {enum.name}: ...\n\n',
    )
    res.append(runtime_shim)
    add_import(imports, 'from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper')
    value_lines = [f'{value.name}: {enum.name}' for value in enum.value]
    header_lines = [
        f'class {enum.name}(int, EnumTypeWrapper):',
        f'    pass',
        f'',
    ]
    return header_lines + value_lines + ['']
|
||
|
|
|
||
|
|
|
||
|
|
def generate_message(imports: List[str], file: FileDescriptorProto, message: DescriptorProto, res: List[CodeGeneratorResponse.File]) -> List[str]:
    """Emit .pyi stub lines for `message` plus runtime oneof helper patches.

    Returns the stub class lines followed by any TypedDict helper classes
    generated for oneof `_asdict` accessors.  For every oneof, runtime
    implementations (property accessor, `_variant` and `_asdict` helpers) are
    appended to `res` as insertions into the generated `_pb2.py` module.
    """
    nested_message_lines: List[str] = []
    message_lines: List[str] = [f'class {message.name}(Message):']

    add_import(imports, 'from google.protobuf.message import Message')
    fields, oneof = collect_message(imports, message, file)

    # Field annotations.
    for (name, type, _) in fields:
        message_lines.append(f'    {name}: {type}')

    # Typed __init__ with a keyword default for every field.
    args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in fields])
    if args: args = ', ' + args
    message_lines.extend([
        f'',
        f'    def __init__(self{args}) -> None: ...',
        f''
    ])

    for oneof_name, oneof_fields in oneof.items():
        literals: str = ', '.join((f'Literal[\'{name}\']' for name, _ in oneof_fields))
        types: Set[str] = set((type for _, type in oneof_fields))
        if len(types) == 1:
            # All members share one type: Optional reads better than Union.
            type = 'Optional[' + types.pop() + ']'
        else:
            types.add('None')
            type = 'Union[' + ', '.join(types) + ']'

        # TypedDict returned by the generated `<oneof>_asdict` helper.
        nested_message_lines.extend([
            f'class {message.name}_{oneof_name}_dict(TypedDict, total=False):',
            '\n'.join([f'    {name}: {type}' for name, type in oneof_fields]),
            f'',
        ])

        add_import(imports, 'from typing import Union')
        add_import(imports, 'from typing_extensions import TypedDict')
        add_import(imports, 'from typing_extensions import Literal')
        # Fixed: the original list was missing commas after the `{oneof_name}`
        # and `{oneof_name}_variant` entries, so implicit string concatenation
        # swallowed the intended blank-line separators in the generated stub.
        message_lines.extend([
            f'    @property',
            f'    def {oneof_name}(self) -> {type}: ...',
            f'',
            f'    def {oneof_name}_variant(self) -> Union[{literals}, None]: ...',
            f'',
            f'    def {oneof_name}_asdict(self) -> {message.name}_{oneof_name}_dict: ...',
            f'',
        ])

        # Runtime implementations of the helpers declared above, patched into
        # the generated _pb2 module via the `module_scope` insertion point.
        return_variant = '\n    '.join([f'if variant == \'{name}\': return unwrap(self.{name})' for name, _ in oneof_fields])
        return_asdict = '\n    '.join([f'if variant == \'{name}\': return {{\'{name}\': unwrap(self.{name})}} # type: ignore' for name, _ in oneof_fields])
        if return_variant: return_variant += '\n    '
        if return_asdict: return_asdict += '\n    '

        res.append(CodeGeneratorResponse.File(
            name=file.name.replace('.proto', '_pb2.py'),
            insertion_point=f'module_scope',
            content=f"""
def _{message.name}_{oneof_name}(self: {message.name}):
    variant = self.{oneof_name}_variant()
    if variant is None: return None
    {return_variant}raise Exception('Field `{oneof_name}` not found.')

def _{message.name}_{oneof_name}_variant(self: {message.name}):
    return self.WhichOneof('{oneof_name}') # type: ignore

def _{message.name}_{oneof_name}_asdict(self: {message.name}):
    variant = self.{oneof_name}_variant()
    if variant is None: return {{}}
    {return_asdict}raise Exception('Field `{oneof_name}` not found.')

setattr({message.name}, '{oneof_name}', property(_{message.name}_{oneof_name}))
setattr({message.name}, '{oneof_name}_variant', _{message.name}_{oneof_name}_variant)
setattr({message.name}, '{oneof_name}_asdict', _{message.name}_{oneof_name}_asdict)
"""))

    return message_lines + nested_message_lines
|
||
|
|
|
||
|
|
|
||
|
|
def generate_service_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    """Build the client stub method lines for one RPC.

    Emits one of four shapes (unary/stream request x unary/stream response)
    for either the sync (`grpc`) or asyncio (`grpc.aio`) API, recording any
    imports the generated code needs into `imports`.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, input_msg, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    # NOTE(review): these repeat the two calls above with identical arguments
    # and always produce the same names — presumably kept separate for
    # readability at the serializer/deserializer call sites below.
    input_type_pb2, _, _ = import_type(imports, method.input_type, None)
    output_type_pb2, _, _ = import_type(imports, method.output_type, None)

    if output_mode == 'stream':
        if input_mode == 'stream':
            # Bidirectional streaming: expose a StreamStream combining a
            # request Sender with the response Stream.
            output_type_hint = f'StreamStream[{input_type}, {output_type}]'
            if sync:
                add_import(imports, f'from ._utils import Sender')
                add_import(imports, f'from ._utils import Stream')
                add_import(imports, f'from ._utils import StreamStream')
            else:
                add_import(imports, f'from ._utils import AioSender as Sender')
                add_import(imports, f'from ._utils import AioStream as Stream')
                add_import(imports, f'from ._utils import AioStreamStream as StreamStream')
        else:
            # Server streaming only: responses arrive as a Stream.
            output_type_hint = f'Stream[{output_type}]'
            if sync:
                add_import(imports, f'from ._utils import Stream')
            else:
                add_import(imports, f'from ._utils import AioStream as Stream')
    else:
        # Unary response: the aio API returns an awaitable of the message.
        output_type_hint = output_type if sync else f'Awaitable[{output_type}]'
        if not sync: add_import(imports, f'from typing import Awaitable')

    if input_mode == 'stream' and output_mode == 'stream':
        add_import(imports, f'from typing import Optional')
        # NOTE(review): `timeout` is accepted but not forwarded to the channel
        # call in this branch (nor in the client-streaming one below) — only
        # the unary-request branch passes it on.  Confirm whether intended.
        return (
            f'def {method.name}(self, timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    tx: Sender[{input_type}] = Sender()\n'
            f'    rx: Stream[{output_type}] = self.channel.{input_mode}_{output_mode}( # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString, # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString # type: ignore\n'
            f'    )(tx)\n'
            f'    return StreamStream(tx, rx)'
        ).split('\n')
    if input_mode == 'stream':
        # Client streaming: the caller supplies an iterator of requests.
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        add_import(imports, f'from typing import Optional')
        return (
            f'def {method.name}(self, iterator: {iterator_type}[{input_type}], timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    return self.channel.{input_mode}_{output_mode}( # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString, # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString # type: ignore\n'
            f'    )(iterator)'
        ).split('\n')
    else:
        add_import(imports, f'from typing import Optional')
        # Unary request: expand the request message's fields into keyword
        # arguments so callers need not construct the message themselves.
        assert isinstance(input_msg, DescriptorProto)
        input_fields, _ = collect_message(imports, input_msg, None)
        args = ', '.join([f'{name}: {type} = {dft}' for name, type, dft in input_fields])
        args_name = ', '.join([f'{name}={name}' for name, _, _ in input_fields])
        if args: args = ', ' + args
        return (
            f'def {method.name}(self{args}, wait_for_ready: Optional[bool] = None, timeout: Optional[float] = None) -> {output_type_hint}:\n'
            f'    return self.channel.{input_mode}_{output_mode}( # type: ignore\n'
            f"        '/{file.package}.{service.name}/{method.name}',\n"
            f'        request_serializer={input_type_pb2}.SerializeToString, # type: ignore\n'
            f'        response_deserializer={output_type_pb2}.FromString # type: ignore\n'
            f'    )({input_type_pb2}({args_name}), wait_for_ready=wait_for_ready, timeout=timeout) # type: ignore'
        ).split('\n')
|
||
|
|
|
||
|
|
|
||
|
|
def generate_service(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Build the client-side service class lines (sync or asyncio flavour)."""
    rendered = []
    for method in service.method:
        rendered.append('\n    '.join(generate_service_method(imports, file, service, method, sync)))
    methods = '\n\n    '.join(rendered)
    channel_type = 'grpc.Channel' if sync else 'grpc.aio.Channel'
    body = (
        f'class {service.name}:\n'
        f'    channel: {channel_type}\n'
        f'\n'
        f'    def __init__(self, channel: {channel_type}) -> None:\n'
        f'        self.channel = channel\n'
        f'\n'
        f'    {methods}\n'
    )
    return body.split('\n')
|
||
|
|
|
||
|
|
|
||
|
|
def generate_servicer_method(imports: List[str], method: MethodDescriptorProto, sync: bool = True) -> List[str]:
    """Build the default (unimplemented) servicer method stub for `method`.

    The generated body raises UNIMPLEMENTED; for server-streaming methods an
    unreachable trailing `yield` marks the stub as a generator so type
    checkers accept the declared return type.
    """
    input_mode = 'stream' if method.client_streaming else 'unary'
    output_mode = 'stream' if method.server_streaming else 'unary'

    input_type, _, _ = import_type(imports, method.input_type, None)
    output_type, _, _ = import_type(imports, method.output_type, None)

    output_type_hint = output_type
    if output_mode == 'stream':
        # Streamed responses are produced by a (possibly async) generator.
        if sync:
            output_type_hint = f'Generator[{output_type}, None, None]'
            add_import(imports, f'from typing import Generator')
        else:
            output_type_hint = f'AsyncGenerator[{output_type}, None]'
            add_import(imports, f'from typing import AsyncGenerator')

    # The original computed `iterator_type` twice (once unconditionally, once
    # inside the stream branch); the dead first assignment is removed.
    if input_mode == 'stream':
        # Client-streaming methods receive an iterator of requests.
        iterator_type = 'Iterator' if sync else 'AsyncIterator'
        add_import(imports, f'from typing import {iterator_type}')
        request_hint = f'{iterator_type}[{input_type}]'
    else:
        request_hint = input_type

    # The two original branches differed only in the request annotation, so
    # the stub body is built once.
    lines = (('' if sync else 'async ') + (
        f'def {method.name}(self, request: {request_hint}, context: grpc.ServicerContext) -> {output_type_hint}:\n'
        f'    context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore\n'
        f'    context.set_details("Method not implemented!") # type: ignore\n'
        f'    raise NotImplementedError("Method not implemented!")'
    )).split('\n')
    if output_mode == 'stream':
        lines.append(f'    yield {output_type}() # no-op: to make the linter happy')
    return lines
|
||
|
|
|
||
|
|
|
||
|
|
def generate_servicer(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Build the servicer base-class lines for `service`."""
    per_method = [
        '\n    '.join(generate_servicer_method(imports, method, sync))
        for method in service.method
    ]
    # A service without methods still needs a syntactically valid class body.
    methods = '\n\n    '.join(per_method) or 'pass'
    body = (
        f'class {service.name}Servicer:\n'
        f'    {methods}\n'
    )
    return body.split('\n')
|
||
|
|
|
||
|
|
|
||
|
|
def generate_rpc_method_handler(imports: List[str], method: MethodDescriptorProto) -> List[str]:
    """Render one `rpc_method_handlers` dict entry for `method`."""
    client_kind = 'stream' if method.client_streaming else 'unary'
    server_kind = 'stream' if method.server_streaming else 'unary'

    request_type, _, _ = import_type(imports, method.input_type, None)
    response_type, _, _ = import_type(imports, method.output_type, None)

    entry = (
        f"'{method.name}': grpc.{client_kind}_{server_kind}_rpc_method_handler( # type: ignore\n"
        f'    servicer.{method.name},\n'
        f'    request_deserializer={request_type}.FromString, # type: ignore\n'
        f'    response_serializer={response_type}.SerializeToString, # type: ignore\n'
        f'),\n'
    )
    return entry.split('\n')
|
||
|
|
|
||
|
|
|
||
|
|
def generate_add_servicer_to_server_method(imports: List[str], file: FileDescriptorProto, service: ServiceDescriptorProto, sync: bool = True) -> List[str]:
    """Render the `add_<Service>Servicer_to_server` registration function."""
    handler_blocks = []
    for method in service.method:
        handler_blocks.append('\n        '.join(generate_rpc_method_handler(imports, method)))
    method_handlers = '        '.join(handler_blocks)
    server_type = 'grpc.Server' if sync else 'grpc.aio.Server'
    body = (
        f'def add_{service.name}Servicer_to_server(servicer: {service.name}Servicer, server: {server_type}) -> None:\n'
        f'    rpc_method_handlers = {{\n'
        f'        {method_handlers}\n'
        f'    }}\n'
        f'    generic_handler = grpc.method_handlers_generic_handler( # type: ignore\n'
        f"        '{file.package}.{service.name}', rpc_method_handlers)\n"
        f'    server.add_generic_rpc_handlers((generic_handler,)) # type: ignore\n'
    )
    return body.split('\n')
|
||
|
|
|
||
|
|
|
||
|
|
_HEADER = '''# Copyright 2022 Google LLC
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
|
||
|
|
"""Generated python gRPC interfaces."""
|
||
|
|
'''
|
||
|
|
|
||
|
|
_UTILS_PY = f'''{_HEADER}
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import queue
|
||
|
|
import grpc
|
||
|
|
import sys
|
||
|
|
|
||
|
|
from typing import Any, AsyncIterable, AsyncIterator, Generic, Iterator, TypeVar
|
||
|
|
|
||
|
|
|
||
|
|
_T_co = TypeVar('_T_co', covariant=True)
|
||
|
|
_T = TypeVar('_T')
|
||
|
|
|
||
|
|
|
||
|
|
class Stream(Iterator[_T_co], grpc.RpcContext): ...
|
||
|
|
|
||
|
|
|
||
|
|
class AioStream(AsyncIterable[_T_co], grpc.RpcContext): ...
|
||
|
|
|
||
|
|
|
||
|
|
class Sender(Iterator[_T]):
|
||
|
|
if sys.version_info >= (3, 8):
|
||
|
|
_inner: queue.Queue[_T]
|
||
|
|
else:
|
||
|
|
_inner: queue.Queue
|
||
|
|
|
||
|
|
def __init__(self) -> None:
|
||
|
|
self._inner = queue.Queue()
|
||
|
|
|
||
|
|
def __iter__(self) -> Iterator[_T]:
|
||
|
|
return self
|
||
|
|
|
||
|
|
def __next__(self) -> _T:
|
||
|
|
return self._inner.get()
|
||
|
|
|
||
|
|
def send(self, item: _T) -> None:
|
||
|
|
self._inner.put(item)
|
||
|
|
|
||
|
|
|
||
|
|
class AioSender(AsyncIterator[_T]):
|
||
|
|
if sys.version_info >= (3, 8):
|
||
|
|
_inner: asyncio.Queue[_T]
|
||
|
|
else:
|
||
|
|
_inner: asyncio.Queue
|
||
|
|
|
||
|
|
def __init__(self) -> None:
|
||
|
|
self._inner = asyncio.Queue()
|
||
|
|
|
||
|
|
def __iter__(self) -> AsyncIterator[_T]:
|
||
|
|
return self
|
||
|
|
|
||
|
|
async def __anext__(self) -> _T:
|
||
|
|
return await self._inner.get()
|
||
|
|
|
||
|
|
async def send(self, item: _T) -> None:
|
||
|
|
await self._inner.put(item)
|
||
|
|
|
||
|
|
def send_nowait(self, item: _T) -> None:
|
||
|
|
self._inner.put_nowait(item)
|
||
|
|
|
||
|
|
|
||
|
|
class StreamStream(Generic[_T, _T_co], Iterator[_T_co], grpc.RpcContext):
|
||
|
|
_sender: Sender[_T]
|
||
|
|
_receiver: Stream[_T_co]
|
||
|
|
|
||
|
|
def __init__(self, sender: Sender[_T], receiver: Stream[_T_co]) -> None:
|
||
|
|
self._sender = sender
|
||
|
|
self._receiver = receiver
|
||
|
|
|
||
|
|
def send(self, item: _T) -> None:
|
||
|
|
self._sender.send(item)
|
||
|
|
|
||
|
|
def __iter__(self) -> Iterator[_T_co]:
|
||
|
|
return self._receiver.__iter__()
|
||
|
|
|
||
|
|
def __next__(self) -> _T_co:
|
||
|
|
return self._receiver.__next__()
|
||
|
|
|
||
|
|
def is_active(self) -> bool:
|
||
|
|
return self._receiver.is_active() # type: ignore
|
||
|
|
|
||
|
|
def time_remaining(self) -> float:
|
||
|
|
return self._receiver.time_remaining() # type: ignore
|
||
|
|
|
||
|
|
def cancel(self) -> None:
|
||
|
|
self._receiver.cancel() # type: ignore
|
||
|
|
|
||
|
|
def add_callback(self, callback: Any) -> None:
|
||
|
|
self._receiver.add_callback(callback) # type: ignore
|
||
|
|
|
||
|
|
|
||
|
|
class AioStreamStream(Generic[_T, _T_co], AsyncIterator[_T_co], grpc.RpcContext):
|
||
|
|
_sender: AioSender[_T]
|
||
|
|
_receiver: AioStream[_T_co]
|
||
|
|
|
||
|
|
def __init__(self, sender: AioSender[_T], receiver: AioStream[_T_co]) -> None:
|
||
|
|
self._sender = sender
|
||
|
|
self._receiver = receiver
|
||
|
|
|
||
|
|
def __aiter__(self) -> AsyncIterator[_T_co]:
|
||
|
|
return self._receiver.__aiter__()
|
||
|
|
|
||
|
|
async def __anext__(self) -> _T_co:
|
||
|
|
return await self._receiver.__aiter__().__anext__()
|
||
|
|
|
||
|
|
async def send(self, item: _T) -> None:
|
||
|
|
await self._sender.send(item)
|
||
|
|
|
||
|
|
def send_nowait(self, item: _T) -> None:
|
||
|
|
self._sender.send_nowait(item)
|
||
|
|
|
||
|
|
def is_active(self) -> bool:
|
||
|
|
return self._receiver.is_active() # type: ignore
|
||
|
|
|
||
|
|
def time_remaining(self) -> float:
|
||
|
|
return self._receiver.time_remaining() # type: ignore
|
||
|
|
|
||
|
|
def cancel(self) -> None:
|
||
|
|
self._receiver.cancel() # type: ignore
|
||
|
|
|
||
|
|
def add_callback(self, callback: Any) -> None:
|
||
|
|
self._receiver.add_callback(callback) # type: ignore
|
||
|
|
'''
|
||
|
|
|
||
|
|
|
||
|
|
# Accumulated generator output: whole files plus insertion-point patches.
_FILES: List[CodeGeneratorResponse.File] = []
# Output paths for which a `_utils.py` helper has already been emitted.
_UTILS_FILES: Set[str] = set()


for file_name in _REQUEST.file_to_generate:
    # Look up this file's descriptor among all files in the request.
    file: FileDescriptorProto = next(filter(lambda x: x.name == file_name, _REQUEST.proto_file))

    # Patch a small `unwrap` helper into the generated _pb2 module; the oneof
    # accessor functions patched in by generate_message() rely on it.
    _FILES.append(CodeGeneratorResponse.File(
        name=file.name.replace('.proto', '_pb2.py'),
        insertion_point=f'module_scope',
        content='def unwrap(x):\n    assert x\n    return x\n'
    ))

    # Import lists are filled in as the generators discover what they need;
    # one list per output module (.pyi, _grpc.py, _grpc_aio.py).
    pyi_imports: List[str] = []
    grpc_imports: List[str] = ['import grpc']
    grpc_aio_imports: List[str] = ['import grpc', 'import grpc.aio']

    enums = '\n'.join(sum([generate_enum(pyi_imports, file, enum, _FILES) for enum in file.enum_type], []))
    messages = '\n'.join(sum([generate_message(pyi_imports, file, message, _FILES) for message in file.message_type], []))

    # Sync and asyncio variants are generated into separate modules.
    services = '\n'.join(sum([generate_service(grpc_imports, file, service) for service in file.service], []))
    aio_services = '\n'.join(sum([generate_service(grpc_aio_imports, file, service, False) for service in file.service], []))

    servicers = '\n'.join(sum([generate_servicer(grpc_imports, file, service) for service in file.service], []))
    aio_servicers = '\n'.join(sum([generate_servicer(grpc_aio_imports, file, service, False) for service in file.service], []))

    add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_imports, file, service) for service in file.service], []))
    aio_add_servicer_methods = '\n'.join(sum([generate_add_servicer_to_server_method(grpc_aio_imports, file, service, False) for service in file.service], []))

    pyi_imports.sort()
    grpc_imports.sort()
    grpc_aio_imports.sort()

    pyi_imports_str: str = '\n'.join(pyi_imports)
    grpc_imports_str: str = '\n'.join(grpc_imports)
    grpc_aio_imports_str: str = '\n'.join(grpc_aio_imports)

    # Emit the shared `_utils.py` helper once per output directory.
    utils_filename = file_name.replace(os.path.basename(file_name), '_utils.py')
    if utils_filename not in _UTILS_FILES:
        _UTILS_FILES.add(utils_filename)
        _FILES.extend([
            CodeGeneratorResponse.File(
                name=utils_filename,
                content=_UTILS_PY,
            )
        ])

    _FILES.extend([
        CodeGeneratorResponse.File(
            name=file.name.replace('.proto', '_pb2.pyi'),
            content=f'{_HEADER}\n\n{pyi_imports_str}\n\n{enums}\n\n{messages}\n'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc.py'),
            content=f'{_HEADER}\n\n{grpc_imports_str}\n\n{services}\n\n{servicers}\n\n{add_servicer_methods}'
        ),
        CodeGeneratorResponse.File(
            name=file_name.replace('.proto', '_grpc_aio.py'),
            content=f'{_HEADER}\n\n{grpc_aio_imports_str}\n\n{aio_services}\n\n{aio_servicers}\n\n{aio_add_servicer_methods}'
        )
    ])


# Hand the assembled response back to protoc on stdout.
sys.stdout.buffer.write(CodeGeneratorResponse(file=_FILES).SerializeToString())
|