# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Internal network layer helper methods.""" import datetime import errno import select import struct import threading _HAS_POLL = True _EVENT_MASK = 0 try: from select import poll _EVENT_MASK = ( select.POLLIN | select.POLLPRI | select.POLLERR | select.POLLHUP) except ImportError: _HAS_POLL = False try: from select import error as _SELECT_ERROR except ImportError: _SELECT_ERROR = OSError from bson.py3compat import PY3 from pymongo import helpers, message from pymongo.common import MAX_MESSAGE_SIZE from pymongo.compression_support import decompress, _NO_COMPRESSION from pymongo.errors import (AutoReconnect, NotMasterError, OperationFailure, ProtocolError) from pymongo.message import _UNPACK_REPLY _UNPACK_HEADER = struct.Struct(" max_bson_size): message._raise_document_too_large(name, size, max_bson_size) else: request_id, msg, size = message.query( flags, ns, 0, -1, spec, None, codec_options, check_keys, compression_ctx) if (max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD): message._raise_document_too_large( name, size, max_bson_size + message._COMMAND_OVERHEAD) if publish: encoding_duration = datetime.datetime.now() - start listeners.publish_command_start(orig, dbname, request_id, address) start = datetime.datetime.now() try: sock.sendall(msg) if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. 
response_doc = {"ok": 1} else: reply = receive_message(sock, request_id) unpacked_docs = reply.unpack_response(codec_options=codec_options) response_doc = unpacked_docs[0] if client: client._receive_cluster_time(response_doc, session) if check: helpers._check_command_response( response_doc, None, allowable_errors, parse_write_concern_error=parse_write_concern_error) except Exception as exc: if publish: duration = (datetime.datetime.now() - start) + encoding_duration if isinstance(exc, (NotMasterError, OperationFailure)): failure = exc.details else: failure = message._convert_exception(exc) listeners.publish_command_failure( duration, failure, name, request_id, address) raise if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( duration, response_doc, name, request_id, address) return response_doc _UNPACK_COMPRESSION_HEADER = struct.Struct(" max_message_size: raise ProtocolError("Message length (%r) is larger than server max " "message size (%r)" % (length, max_message_size)) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( _receive_data_on_socket(sock, 9)) data = decompress( _receive_data_on_socket(sock, length - 25), compressor_id) else: data = _receive_data_on_socket(sock, length - 16) try: unpack_reply = _UNPACK_REPLY[op_code] except KeyError: raise ProtocolError("Got opcode %r but expected " "%r" % (op_code, _UNPACK_REPLY.keys())) return unpack_reply(data) # memoryview was introduced in Python 2.7 but we only use it on Python 3 # because before 2.7.4 the struct module did not support memoryview: # https://bugs.python.org/issue10212. # In Jython, using slice assignment on a memoryview results in a # NullPointerException. 
if PY3:
    def _receive_data_on_socket(sock, length):
        """Read exactly ``length`` bytes from ``sock``.

        Data is received directly into a preallocated bytearray through a
        memoryview, and that memoryview is returned. Reads interrupted with
        EINTR are retried; any other socket error is re-raised.

        :Raises:
          - AutoReconnect: if the peer closes the connection (``recv_into``
            reports 0 bytes) before ``length`` bytes have arrived.
        """
        view = memoryview(bytearray(length))
        filled = 0
        while filled < length:
            try:
                count = sock.recv_into(view[filled:])
            except (IOError, OSError) as exc:
                if _errno_from_exception(exc) != errno.EINTR:
                    raise
                # Interrupted system call: try the read again.
                continue
            if count == 0:
                raise AutoReconnect("connection closed")
            filled += count
        return view
else:
    def _receive_data_on_socket(sock, length):
        """Read exactly ``length`` bytes from ``sock`` and return them as bytes.

        This variant copies each received chunk into a bytearray with slice
        assignment instead of using ``recv_into``/memoryview. Reads
        interrupted with EINTR are retried; any other socket error is
        re-raised.

        :Raises:
          - AutoReconnect: if ``recv`` returns an empty chunk before
            ``length`` bytes have arrived.
        """
        buf = bytearray(length)
        offset = 0
        remaining = length
        while remaining:
            try:
                chunk = sock.recv(remaining)
            except (IOError, OSError) as exc:
                if _errno_from_exception(exc) != errno.EINTR:
                    raise
                # Interrupted system call: try the read again.
                continue
            if chunk == b"":
                raise AutoReconnect("connection closed")
            chunk_size = len(chunk)
            buf[offset:offset + chunk_size] = chunk
            offset += chunk_size
            remaining -= chunk_size
        return bytes(buf)


def _errno_from_exception(exc):
    """Best-effort extraction of an errno value from an exception.

    Prefers the ``errno`` attribute. Otherwise falls back to the first
    positional argument, or ``None`` when the exception has no args.
    """
    if hasattr(exc, 'errno'):
        return exc.errno
    return exc.args[0] if exc.args else None


class SocketChecker(object):
    """Check whether a socket has become readable, which is used as a
    signal that it has been closed.

    Uses a shared ``select.poll`` object guarded by a lock when poll is
    available. Otherwise falls back to ``select.select``.
    """

    def __init__(self):
        # Lock and poller exist together, or neither exists.
        self._lock = threading.Lock() if _HAS_POLL else None
        self._poller = poll() if _HAS_POLL else None

    def socket_closed(self, sock):
        """Return True if we know socket has been closed, False otherwise.
        """
        while True:
            try:
                if not self._poller:
                    ready, _, _ = select.select([sock], [], [], 0)
                else:
                    # The poller is shared, so register/poll/unregister must
                    # happen atomically with respect to other threads.
                    with self._lock:
                        self._poller.register(sock, _EVENT_MASK)
                        try:
                            ready = self._poller.poll(0)
                        finally:
                            self._poller.unregister(sock)
            except (RuntimeError, KeyError):
                # RuntimeError is raised during a concurrent poll. KeyError
                # is raised by unregister if the socket is not in the poller.
                # Both should be impossible because the poller is protected
                # by a mutex, so surface them instead of hiding them.
                raise
            except ValueError:
                # register/unregister/select raise ValueError for a negative
                # file descriptor or one beyond select's range (> 1023).
                return True
            except (_SELECT_ERROR, IOError) as exc:
                if _errno_from_exception(exc) not in (errno.EINTR,
                                                      errno.EAGAIN):
                    return True
                # Interrupted or temporarily unavailable: check again.
                continue
            except Exception:
                # Any other failure is attributed to a closed or invalid
                # socket.
                return True
            return bool(ready)