# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""

import asyncio
import sys
from typing import Any, Iterable, Optional, Sequence, List

import grpc
from grpc import _common, _compression, _grpcio_metadata
from grpc._cython import cygrpc

from . import _base_call, _base_channel
from ._call import (StreamStreamCall, StreamUnaryCall, UnaryStreamCall,
                    UnaryUnaryCall)
from ._interceptor import (
    InterceptedUnaryUnaryCall, InterceptedUnaryStreamCall,
    InterceptedStreamUnaryCall, InterceptedStreamStreamCall, ClientInterceptor,
    UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
    StreamUnaryClientInterceptor, StreamStreamClientInterceptor)
from ._metadata import Metadata
from ._typing import (ChannelArgumentType, DeserializingFunction,
                      SerializingFunction, RequestIterableType)
from ._utils import _timeout_to_deadline

_USER_AGENT = 'grpc-python-asyncio/{}'.format(_grpcio_metadata.__version__)

# asyncio.all_tasks() only exists from Python 3.7 on; older interpreters expose
# the same information through the (since removed) Task.all_tasks() classmethod.
if sys.version_info < (3, 7):

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the tasks known to the current event loop (Python < 3.7)."""
        return asyncio.Task.all_tasks()
else:

    def _all_tasks() -> Iterable[asyncio.Task]:
        """Returns the not-yet-finished tasks of the running event loop."""
        return asyncio.all_tasks()


def _augment_channel_arguments(base_options: ChannelArgumentType,
                               compression: Optional[grpc.Compression]):
    """Appends the compression and user-agent options to the user's options.

    Args:
      base_options: The channel arguments supplied by the user.
      compression: The channel-wide compression algorithm, or None.

    Returns:
      A tuple with the user's options followed by the compression option(s)
      and the primary user-agent option identifying this asyncio stack.
    """
    compression_channel_argument = _compression.create_channel_option(
        compression)

    user_agent_channel_argument = ((
        cygrpc.ChannelArgKey.primary_user_agent_string,
        _USER_AGENT,
    ),)
    return tuple(base_options
                ) + compression_channel_argument + user_agent_channel_argument


class _BaseMultiCallable:
    """Base class of all multi callable objects.

    Handles the initialization logic and stores common attributes.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _method: bytes
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction
    _interceptors: Optional[Sequence[ClientInterceptor]]

    # pylint: disable=too-many-arguments
    def __init__(
            self,
            channel: cygrpc.AioChannel,
            method: bytes,
            request_serializer: SerializingFunction,
            response_deserializer: DeserializingFunction,
            interceptors: Optional[Sequence[ClientInterceptor]],
            loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._channel = channel
        self._method = method
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer
        self._interceptors = interceptors

    @staticmethod
    def _init_metadata(metadata: Optional[Metadata] = None,
                       compression: Optional[grpc.Compression] = None
                      ) -> Metadata:
        """Builds the final metadata to send with the current call.

        Starts from the provided metadata (an empty ``Metadata`` when none is
        given) and, when a per-call compression algorithm is requested,
        augments it with the corresponding compression entries.

        Args:
          metadata: Optional user-supplied metadata for the call.
          compression: Optional per-call compression algorithm.

        Returns:
          The metadata object to use for the call.
        """
        metadata = metadata or Metadata()
        if compression:
            metadata = Metadata(
                *_compression.augment_metadata(metadata, compression))
        return metadata


class UnaryUnaryMultiCallable(_BaseMultiCallable,
                              _base_channel.UnaryUnaryMultiCallable):
    """Creates unary-request / unary-response calls for a single method."""

    def __call__(self,
                 request: Any,
                 *,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.UnaryUnaryCall:

        metadata = self._init_metadata(metadata, compression)
        if not self._interceptors:
            # Direct calls take an absolute deadline derived from the timeout.
            call = UnaryUnaryCall(request, _timeout_to_deadline(timeout),
                                  metadata, credentials, wait_for_ready,
                                  self._channel, self._method,
                                  self._request_serializer,
                                  self._response_deserializer, self._loop)
        else:
            # Intercepted calls take the relative timeout as given by the user.
            call = InterceptedUnaryUnaryCall(
                self._interceptors, request, timeout, metadata, credentials,
                wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call


class UnaryStreamMultiCallable(_BaseMultiCallable,
                               _base_channel.UnaryStreamMultiCallable):
    """Creates unary-request / stream-response calls for a single method."""

    def __call__(self,
                 request: Any,
                 *,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.UnaryStreamCall:

        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # Direct calls take an absolute deadline derived from the timeout.
            call = UnaryStreamCall(request, _timeout_to_deadline(timeout),
                                   metadata, credentials, wait_for_ready,
                                   self._channel, self._method,
                                   self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Intercepted calls take the relative timeout, exactly like the
            # unary-unary path. Passing an already-converted deadline here
            # would hand interceptors a bogus timeout value.
            call = InterceptedUnaryStreamCall(
                self._interceptors, request, timeout, metadata, credentials,
                wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call


class StreamUnaryMultiCallable(_BaseMultiCallable,
                               _base_channel.StreamUnaryMultiCallable):
    """Creates stream-request / unary-response calls for a single method."""

    def __call__(self,
                 request_iterator: Optional[RequestIterableType] = None,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.StreamUnaryCall:

        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # Direct calls take an absolute deadline derived from the timeout.
            call = StreamUnaryCall(request_iterator,
                                   _timeout_to_deadline(timeout), metadata,
                                   credentials, wait_for_ready, self._channel,
                                   self._method, self._request_serializer,
                                   self._response_deserializer, self._loop)
        else:
            # Intercepted calls take the relative timeout, exactly like the
            # unary-unary path.
            call = InterceptedStreamUnaryCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call


class StreamStreamMultiCallable(_BaseMultiCallable,
                                _base_channel.StreamStreamMultiCallable):
    """Creates stream-request / stream-response calls for a single method."""

    def __call__(self,
                 request_iterator: Optional[RequestIterableType] = None,
                 timeout: Optional[float] = None,
                 metadata: Optional[Metadata] = None,
                 credentials: Optional[grpc.CallCredentials] = None,
                 wait_for_ready: Optional[bool] = None,
                 compression: Optional[grpc.Compression] = None
                ) -> _base_call.StreamStreamCall:

        metadata = self._init_metadata(metadata, compression)

        if not self._interceptors:
            # Direct calls take an absolute deadline derived from the timeout.
            call = StreamStreamCall(request_iterator,
                                    _timeout_to_deadline(timeout), metadata,
                                    credentials, wait_for_ready, self._channel,
                                    self._method, self._request_serializer,
                                    self._response_deserializer, self._loop)
        else:
            # Intercepted calls take the relative timeout, exactly like the
            # unary-unary path.
            call = InterceptedStreamStreamCall(
                self._interceptors, request_iterator, timeout, metadata,
                credentials, wait_for_ready, self._channel, self._method,
                self._request_serializer, self._response_deserializer,
                self._loop)

        return call


class Channel(_base_channel.Channel):
    """Asynchronous gRPC channel backed by a Cython ``AioChannel``.

    Interceptors passed at construction are sorted by kind and handed only to
    the multi-callables of the matching RPC type.
    """
    _loop: asyncio.AbstractEventLoop
    _channel: cygrpc.AioChannel
    _unary_unary_interceptors: List[UnaryUnaryClientInterceptor]
    _unary_stream_interceptors: List[UnaryStreamClientInterceptor]
    _stream_unary_interceptors: List[StreamUnaryClientInterceptor]
    _stream_stream_interceptors: List[StreamStreamClientInterceptor]

    def __init__(self, target: str, options: ChannelArgumentType,
                 credentials: Optional[grpc.ChannelCredentials],
                 compression: Optional[grpc.Compression],
                 interceptors: Optional[Sequence[ClientInterceptor]]):
        """Constructor.

        Args:
          target: The target to which to connect.
          options: Configuration options for the channel.
          credentials: A cygrpc.ChannelCredentials or None.
          compression: An optional value indicating the compression method to be
            used over the lifetime of the channel.
          interceptors: An optional list of interceptors that would be used for
            intercepting any RPC executed with that channel.

        Raises:
          ValueError: If an interceptor is not one of the four supported
            client interceptor types.
        """
        self._unary_unary_interceptors = []
        self._unary_stream_interceptors = []
        self._stream_unary_interceptors = []
        self._stream_stream_interceptors = []

        if interceptors is not None:
            for interceptor in interceptors:
                if isinstance(interceptor, UnaryUnaryClientInterceptor):
                    self._unary_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, UnaryStreamClientInterceptor):
                    self._unary_stream_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamUnaryClientInterceptor):
                    self._stream_unary_interceptors.append(interceptor)
                elif isinstance(interceptor, StreamStreamClientInterceptor):
                    self._stream_stream_interceptors.append(interceptor)
                else:
                    raise ValueError(
                        "Interceptor {} must be ".format(interceptor) +
                        "{} or ".format(UnaryUnaryClientInterceptor.__name__) +
                        "{} or ".format(UnaryStreamClientInterceptor.__name__) +
                        "{} or ".format(StreamUnaryClientInterceptor.__name__) +
                        "{}. ".format(StreamStreamClientInterceptor.__name__))

        self._loop = cygrpc.get_working_loop()
        self._channel = cygrpc.AioChannel(
            _common.encode(target),
            _augment_channel_arguments(options, compression), credentials,
            self._loop)

    async def __aenter__(self):
        """Returns the channel itself for use in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Closes the channel immediately (no grace period) on exit."""
        await self._close(None)

    async def _close(self, grace):  # pylint: disable=too-many-branches
        """Closes the channel, cancelling the calls that belong to it.

        Args:
          grace: Seconds to wait for this channel's in-flight calls before
            cancelling them. A falsy value (None or 0) cancels immediately.
        """
        if self._channel.closed():
            return

        # No new calls will be accepted by the Cython channel.
        self._channel.closing()

        # Iterate through running tasks
        tasks = _all_tasks()
        calls = []
        call_tasks = []
        for task in tasks:
            try:
                stack = task.get_stack(limit=1)
            except AttributeError as attribute_error:
                # NOTE(lidiz) tl;dr: If the Task is created with a CPython
                # object, it will trigger AttributeError.
                #
                # In the global finalizer, the event loop schedules
                # a CPython PyAsyncGenAThrow object.
                # https://github.com/python/cpython/blob/00e45877e33d32bb61aa13a2033e3bba370bda4d/Lib/asyncio/base_events.py#L484
                #
                # However, the PyAsyncGenAThrow object is written in C and
                # failed to include the normal Python frame objects. Hence,
                # this exception is a false negative, and it is safe to ignore
                # the failure. It is fixed by https://github.com/python/cpython/pull/18669,
                # but not available until 3.9 or 3.8.3. So, we have to keep it
                # for a while.
                # TODO(lidiz) drop this hack after 3.8 deprecation
                if 'frame' not in str(attribute_error):
                    raise
                continue

            # If the Task is created by a C-extension, the stack will be empty.
            if not stack:
                continue

            # Locate ones created by `aio.Call`.
            frame = stack[0]
            candidate = frame.f_locals.get('self')
            if candidate:
                if isinstance(candidate, _base_call.Call):
                    if hasattr(candidate, '_channel'):
                        # For intercepted Call object
                        if candidate._channel is not self._channel:
                            continue
                    elif hasattr(candidate, '_cython_call'):
                        # For normal Call object
                        if candidate._cython_call._channel is not self._channel:
                            continue
                    else:
                        # Unidentified Call object
                        raise cygrpc.InternalError(
                            f'Unrecognized call object: {candidate}')

                    calls.append(candidate)
                    call_tasks.append(task)

        # If needed, try to wait for them to finish.
        # Call objects are not always awaitables.
        if grace and call_tasks:
            # The tasks come from the running loop (see _all_tasks), so no
            # explicit loop is needed; the `loop` argument of asyncio.wait was
            # removed in Python 3.10.
            await asyncio.wait(call_tasks, timeout=grace)

        # Time to cancel existing calls.
        for call in calls:
            call.cancel()

        # Destroy the channel
        self._channel.close()

    async def close(self, grace: Optional[float] = None):
        """Closes the channel, optionally waiting `grace` seconds for calls."""
        await self._close(grace)

    def get_state(self,
                  try_to_connect: bool = False) -> grpc.ChannelConnectivity:
        """Returns the channel's current connectivity state."""
        result = self._channel.check_connectivity_state(try_to_connect)
        return _common.CYGRPC_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY[result]

    async def wait_for_state_change(
            self,
            last_observed_state: grpc.ChannelConnectivity,
    ) -> None:
        """Waits until the state differs from `last_observed_state`."""
        state_changed = await self._channel.watch_connectivity_state(
            last_observed_state.value[0], None)
        # The await must stay outside the assert: `python -O` strips assert
        # statements entirely, which would otherwise skip the wait.
        assert state_changed

    async def channel_ready(self) -> None:
        """Waits until the channel reaches READY, prompting it to connect."""
        state = self.get_state(try_to_connect=True)
        while state != grpc.ChannelConnectivity.READY:
            await self.wait_for_state_change(state)
            state = self.get_state(try_to_connect=True)

    def unary_unary(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryUnaryMultiCallable:
        """Creates a UnaryUnaryMultiCallable for `method`."""
        return UnaryUnaryMultiCallable(self._channel, _common.encode(method),
                                       request_serializer,
                                       response_deserializer,
                                       self._unary_unary_interceptors,
                                       self._loop)

    def unary_stream(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> UnaryStreamMultiCallable:
        """Creates a UnaryStreamMultiCallable for `method`."""
        return UnaryStreamMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._unary_stream_interceptors,
                                        self._loop)

    def stream_unary(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamUnaryMultiCallable:
        """Creates a StreamUnaryMultiCallable for `method`."""
        return StreamUnaryMultiCallable(self._channel, _common.encode(method),
                                        request_serializer,
                                        response_deserializer,
                                        self._stream_unary_interceptors,
                                        self._loop)

    def stream_stream(
            self,
            method: str,
            request_serializer: Optional[SerializingFunction] = None,
            response_deserializer: Optional[DeserializingFunction] = None
    ) -> StreamStreamMultiCallable:
        """Creates a StreamStreamMultiCallable for `method`."""
        return StreamStreamMultiCallable(self._channel, _common.encode(method),
                                         request_serializer,
                                         response_deserializer,
                                         self._stream_stream_interceptors,
                                         self._loop)


def insecure_channel(
        target: str,
        options: Optional[ChannelArgumentType] = None,
        compression: Optional[grpc.Compression] = None,
        interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates an insecure asynchronous Channel to a server.

    Args:
      target: The server address
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel. This is an EXPERIMENTAL option.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      A Channel.
    """
    return Channel(target, () if options is None else options, None,
                   compression, interceptors)


def secure_channel(target: str,
                   credentials: grpc.ChannelCredentials,
                   options: Optional[ChannelArgumentType] = None,
                   compression: Optional[grpc.Compression] = None,
                   interceptors: Optional[Sequence[ClientInterceptor]] = None):
    """Creates a secure asynchronous Channel to a server.

    Args:
      target: The server address.
      credentials: A ChannelCredentials instance.
      options: An optional list of key-value pairs (:term:`channel_arguments`
        in gRPC Core runtime) to configure the channel.
      compression: An optional value indicating the compression method to be
        used over the lifetime of the channel. This is an EXPERIMENTAL option.
      interceptors: An optional sequence of interceptors that will be executed for
        any call executed with this channel.

    Returns:
      An aio.Channel.
    """
    return Channel(target, () if options is None else options,
                   credentials._credentials, compression, interceptors)