# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for creating `messages <http://www.mongodb.org/display/DOCS/Mongo+Wire+Protocol>`_ to be sent to MongoDB. .. note:: This module is for internal use and is generally not needed by application developers. """ import datetime import random import struct import bson from bson import (CodecOptions, _bson_to_dict, _dict_to_bson, _make_c_string) from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.py3compat import b, StringIO from bson.son import SON try: from pymongo import _cmessage _use_c = True except ImportError: _use_c = False from pymongo.errors import (ConfigurationError, CursorNotFound, DocumentTooLarge, ExecutionTimeout, InvalidOperation, NotMasterError, OperationFailure, ProtocolError) from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 # Overhead allowed for encoded command documents. _COMMAND_OVERHEAD = 16382 _INSERT = 0 _UPDATE = 1 _DELETE = 2 _EMPTY = b'' _BSONOBJ = b'\x03' _ZERO_8 = b'\x00' _ZERO_16 = b'\x00\x00' _ZERO_32 = b'\x00\x00\x00\x00' _ZERO_64 = b'\x00\x00\x00\x00\x00\x00\x00\x00' _SKIPLIM = b'\x00\x00\x00\x00\xff\xff\xff\xff' _OP_MAP = { _INSERT: b'\x04documents\x00\x00\x00\x00\x00', _UPDATE: b'\x04updates\x00\x00\x00\x00\x00', _DELETE: b'\x04deletes\x00\x00\x00\x00\x00', } _FIELD_MAP = { 'insert': 'documents', 'update': 'updates', 'delete': 'deletes' } _UJOIN = u"%s.%s" _UNICODE_REPLACE_CODEC_OPTIONS = CodecOptions( unicode_decode_error_handler='replace') def _randint(): """Generate a pseudo random 32 bit integer.""" return random.randint(MIN_INT32, MAX_INT32) def _maybe_add_read_preference(spec, read_preference): """Add $readPreference to spec when appropriate.""" mode = read_preference.mode tag_sets = read_preference.tag_sets max_staleness = read_preference.max_staleness # Only add $readPreference if it's something other than primary to avoid # problems with mongos versions that don't support read preferences. Also, # for maximum backwards compatibility, don't add $readPreference for # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting # the slaveOkay bit has the same effect). if mode and ( mode != ReadPreference.SECONDARY_PREFERRED.mode or tag_sets != [{}] or max_staleness != -1): if "$query" not in spec: spec = SON([("$query", spec)]) spec["$readPreference"] = read_preference.document return spec def _convert_exception(exception): """Convert an Exception into a failure document for publishing.""" return {'errmsg': str(exception), 'errtype': exception.__class__.__name__} def _convert_write_result(operation, command, result): """Convert a legacy write result to write commmand format.""" # Based on _merge_legacy from bulk.py affected = result.get("n", 0) res = {"ok": 1, "n": affected} errmsg = result.get("errmsg", result.get("err", "")) if errmsg: # The write was successful on at least the primary so don't return. if result.get("wtimeout"): res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} else: # The write failed. error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} if "errInfo" in result: error["errInfo"] = result["errInfo"] res["writeErrors"] = [error] return res if operation == "insert": # GLE result for insert is always 0 in most MongoDB versions. res["n"] = len(command['documents']) elif operation == "update": if "upserted" in result: res["upserted"] = [{"index": 0, "_id": result["upserted"]}] # Versions of MongoDB before 2.6 don't return the _id for an # upsert if _id is not an ObjectId. elif result.get("updatedExisting") is False and affected == 1: # If _id is in both the update document *and* the query spec # the update document _id takes precedence. update = command['updates'][0] _id = update["u"].get("_id", update["q"].get("_id")) res["upserted"] = [{"index": 0, "_id": _id}] return res _OPTIONS = SON([ ('tailable', 2), ('oplogReplay', 8), ('noCursorTimeout', 16), ('awaitData', 32), ('allowPartialResults', 128)]) _MODIFIERS = SON([ ('$query', 'filter'), ('$orderby', 'sort'), ('$hint', 'hint'), ('$comment', 'comment'), ('$maxScan', 'maxScan'), ('$maxTimeMS', 'maxTimeMS'), ('$max', 'max'), ('$min', 'min'), ('$returnKey', 'returnKey'), ('$showRecordId', 'showRecordId'), ('$showDiskLoc', 'showRecordId'), # <= MongoDb 3.0 ('$snapshot', 'snapshot')]) def _gen_find_command(coll, spec, projection, skip, limit, batch_size, options, read_concern, collation=None, session=None): """Generate a find command document.""" cmd = SON([('find', coll)]) if '$query' in spec: cmd.update([(_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) for key, val in spec.items()]) if '$explain' in cmd: cmd.pop('$explain') if '$readPreference' in cmd: cmd.pop('$readPreference') else: cmd['filter'] = spec if projection: cmd['projection'] = projection if skip: cmd['skip'] = skip if limit: cmd['limit'] = abs(limit) if limit < 0: cmd['singleBatch'] = True if batch_size: cmd['batchSize'] = batch_size if read_concern.level and not (session and session._in_transaction): cmd['readConcern'] = read_concern.document if collation: cmd['collation'] = collation if options: cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val]) return cmd def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms): """Generate a getMore command document.""" cmd = SON([('getMore', cursor_id), ('collection', coll)]) if batch_size: cmd['batchSize'] = batch_size if max_await_time_ms is not None: cmd['maxTimeMS'] = max_await_time_ms return cmd class _Query(object): """A query operation.""" __slots__ = ('flags', 'db', 'coll', 'ntoskip', 'spec', 'fields', 'codec_options', 'read_preference', 'limit', 'batch_size', 'name', 'read_concern', 'collation', 'session', 'client', '_as_command') def __init__(self, flags, db, coll, ntoskip, spec, fields, codec_options, read_preference, limit, batch_size, read_concern, collation, session, client): self.flags = flags self.db = db self.coll = coll self.ntoskip = ntoskip self.spec = spec self.fields = fields self.codec_options = codec_options self.read_preference = read_preference self.read_concern = read_concern self.limit = limit self.batch_size = batch_size self.collation = collation self.session = session self.client = client self.name = 'find' self._as_command = None def use_command(self, sock_info, exhaust): use_find_cmd = False if sock_info.max_wire_version >= 4: if not exhaust: use_find_cmd = True elif not self.read_concern.ok_for_legacy: raise ConfigurationError( 'read concern level of %s is not valid ' 'with a max wire version of %d.' % (self.read_concern.level, sock_info.max_wire_version)) if sock_info.max_wire_version < 5 and self.collation is not None: raise ConfigurationError( 'Specifying a collation is unsupported with a max wire ' 'version of %d.' % (sock_info.max_wire_version,)) sock_info.validate_session(self.client, self.session) return use_find_cmd def as_command(self, sock_info): """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. # Generate it once, for speed and to avoid repeating side-effects. if self._as_command is not None: return self._as_command explain = '$explain' in self.spec cmd = _gen_find_command( self.coll, self.spec, self.fields, self.ntoskip, self.limit, self.batch_size, self.flags, self.read_concern, self.collation, self.session) if explain: self.name = 'explain' cmd = SON([('explain', cmd)]) session = self.session if session: session._apply_to(cmd, False, self.read_preference) # Explain does not support readConcern. if (not explain and session.options.causal_consistency and session.operation_time is not None and not session._in_transaction): cmd.setdefault( 'readConcern', {})[ 'afterClusterTime'] = session.operation_time sock_info.send_cluster_time(cmd, session, self.client) self._as_command = cmd, self.db return self._as_command def get_message(self, set_slave_ok, sock_info, use_cmd=False): """Get a query message, possibly setting the slaveOk bit.""" if set_slave_ok: # Set the slaveOk bit. flags = self.flags | 4 else: flags = self.flags ns = _UJOIN % (self.db, self.coll) spec = self.spec if use_cmd: spec = self.as_command(sock_info)[0] if sock_info.op_msg_enabled: request_id, msg, size, _ = _op_msg( 0, spec, self.db, self.read_preference, set_slave_ok, False, self.codec_options, ctx=sock_info.compression_context) return request_id, msg, size ns = _UJOIN % (self.db, "$cmd") ntoreturn = -1 # All DB commands return 1 document else: # OP_QUERY treats ntoreturn of -1 and 1 the same, return # one document and close the cursor. We have to use 2 for # batch size if 1 is specified. ntoreturn = self.batch_size == 1 and 2 or self.batch_size if self.limit: if ntoreturn: ntoreturn = min(self.limit, ntoreturn) else: ntoreturn = self.limit if sock_info.is_mongos: spec = _maybe_add_read_preference(spec, self.read_preference) return query(flags, ns, self.ntoskip, ntoreturn, spec, None if use_cmd else self.fields, self.codec_options, ctx=sock_info.compression_context) class _GetMore(object): """A getmore operation.""" __slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms', 'codec_options', 'read_preference', 'session', 'client', '_as_command') name = 'getMore' def __init__(self, db, coll, ntoreturn, cursor_id, codec_options, read_preference, session, client, max_await_time_ms=None): self.db = db self.coll = coll self.ntoreturn = ntoreturn self.cursor_id = cursor_id self.codec_options = codec_options self.read_preference = read_preference self.session = session self.client = client self.max_await_time_ms = max_await_time_ms self._as_command = None def use_command(self, sock_info, exhaust): sock_info.validate_session(self.client, self.session) return sock_info.max_wire_version >= 4 and not exhaust def as_command(self, sock_info): """Return a getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. if self._as_command is not None: return self._as_command cmd = _gen_get_more_command(self.cursor_id, self.coll, self.ntoreturn, self.max_await_time_ms) if self.session: self.session._apply_to(cmd, False, self.read_preference) sock_info.send_cluster_time(cmd, self.session, self.client) self._as_command = cmd, self.db return self._as_command def get_message(self, dummy0, sock_info, use_cmd=False): """Get a getmore message.""" ns = _UJOIN % (self.db, self.coll) ctx = sock_info.compression_context if use_cmd: spec = self.as_command(sock_info)[0] if sock_info.op_msg_enabled: request_id, msg, size, _ = _op_msg( 0, spec, self.db, ReadPreference.PRIMARY, False, False, self.codec_options, ctx=sock_info.compression_context) return request_id, msg, size ns = _UJOIN % (self.db, "$cmd") return query(0, ns, 0, -1, spec, None, self.codec_options, ctx=ctx) return get_more(ns, self.ntoreturn, self.cursor_id, ctx) # TODO: Use OP_MSG once the server is able to respond with document streams. class _RawBatchQuery(_Query): def use_command(self, socket_info, exhaust): # Compatibility checks. super(_RawBatchQuery, self).use_command(socket_info, exhaust) return False def get_message(self, set_slave_ok, sock_info, use_cmd=False): # Always pass False for use_cmd. return super(_RawBatchQuery, self).get_message( set_slave_ok, sock_info, False) class _RawBatchGetMore(_GetMore): def use_command(self, socket_info, exhaust): return False def get_message(self, set_slave_ok, sock_info, use_cmd=False): # Always pass False for use_cmd. return super(_RawBatchGetMore, self).get_message( set_slave_ok, sock_info, False) class _CursorAddress(tuple): """The server address (host, port) of a cursor, with namespace property.""" def __new__(cls, address, namespace): self = tuple.__new__(cls, address) self.__namespace = namespace return self @property def namespace(self): """The namespace this cursor.""" return self.__namespace def __hash__(self): # Two _CursorAddress instances with different namespaces # must not hash the same. return (self + (self.__namespace,)).__hash__() def __eq__(self, other): if isinstance(other, _CursorAddress): return (tuple(self) == tuple(other) and self.namespace == other.namespace) return NotImplemented def __ne__(self, other): return not self == other _pack_compression_header = struct.Struct("<iiiiiiB").pack _COMPRESSION_HEADER_SIZE = 25 def _compress(operation, data, ctx): """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" compressed = ctx.compress(data) request_id = _randint() header = _pack_compression_header( _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length request_id, # Request id 0, # responseTo 2012, # operation id operation, # original operation id len(data), # uncompressed message length ctx.compressor_id) # compressor id return request_id, header + compressed def __last_error(namespace, args): """Data to send to do a lastError. """ cmd = SON([("getlasterror", 1)]) cmd.update(args) splitns = namespace.split('.', 1) return query(0, splitns[0] + '.$cmd', 0, -1, cmd, None, DEFAULT_CODEC_OPTIONS) _pack_header = struct.Struct("<iiii").pack def __pack_message(operation, data): """Takes message data and adds a message header based on the operation. Returns the resultant message string. """ rid = _randint() message = _pack_header(16 + len(data), rid, 0, operation) return rid, message + data _pack_int = struct.Struct("<i").pack def _insert(collection_name, docs, check_keys, flags, opts): """Get an OP_INSERT message""" encode = _dict_to_bson # Make local. Uses extensions. if len(docs) == 1: encoded = encode(docs[0], check_keys, opts) return b"".join([ b"\x00\x00\x00\x00", # Flags don't matter for one doc. _make_c_string(collection_name), encoded]), len(encoded) encoded = [encode(doc, check_keys, opts) for doc in docs] if not encoded: raise InvalidOperation("cannot do an empty bulk insert") return b"".join([ _pack_int(flags), _make_c_string(collection_name), b"".join(encoded)]), max(map(len, encoded)) def _insert_compressed( collection_name, docs, check_keys, continue_on_error, opts, ctx): """Internal compressed unacknowledged insert message helper.""" op_insert, max_bson_size = _insert( collection_name, docs, check_keys, continue_on_error, opts) rid, msg = _compress(2002, op_insert, ctx) return rid, msg, max_bson_size def _insert_uncompressed(collection_name, docs, check_keys, safe, last_error_args, continue_on_error, opts): """Internal insert message helper.""" op_insert, max_bson_size = _insert( collection_name, docs, check_keys, continue_on_error, opts) rid, msg = __pack_message(2002, op_insert) if safe: rid, gle, _ = __last_error(collection_name, last_error_args) return rid, msg + gle, max_bson_size return rid, msg, max_bson_size if _use_c: _insert_uncompressed = _cmessage._insert_message def insert(collection_name, docs, check_keys, safe, last_error_args, continue_on_error, opts, ctx=None): """Get an **insert** message.""" if ctx: return _insert_compressed( collection_name, docs, check_keys, continue_on_error, opts, ctx) return _insert_uncompressed(collection_name, docs, check_keys, safe, last_error_args, continue_on_error, opts) def _update(collection_name, upsert, multi, spec, doc, check_keys, opts): """Get an OP_UPDATE message.""" flags = 0 if upsert: flags += 1 if multi: flags += 2 encode = _dict_to_bson # Make local. Uses extensions. encoded_update = encode(doc, check_keys, opts) return b"".join([ _ZERO_32, _make_c_string(collection_name), _pack_int(flags), encode(spec, False, opts), encoded_update]), len(encoded_update) def _update_compressed( collection_name, upsert, multi, spec, doc, check_keys, opts, ctx): """Internal compressed unacknowledged update message helper.""" op_update, max_bson_size = _update( collection_name, upsert, multi, spec, doc, check_keys, opts) rid, msg = _compress(2001, op_update, ctx) return rid, msg, max_bson_size def _update_uncompressed(collection_name, upsert, multi, spec, doc, safe, last_error_args, check_keys, opts): """Internal update message helper.""" op_update, max_bson_size = _update( collection_name, upsert, multi, spec, doc, check_keys, opts) rid, msg = __pack_message(2001, op_update) if safe: rid, gle, _ = __last_error(collection_name, last_error_args) return rid, msg + gle, max_bson_size return rid, msg, max_bson_size if _use_c: _update_uncompressed = _cmessage._update_message def update(collection_name, upsert, multi, spec, doc, safe, last_error_args, check_keys, opts, ctx=None): """Get an **update** message.""" if ctx: return _update_compressed( collection_name, upsert, multi, spec, doc, check_keys, opts, ctx) return _update_uncompressed(collection_name, upsert, multi, spec, doc, safe, last_error_args, check_keys, opts) _pack_op_msg_flags_type = struct.Struct("<IB").pack _pack_byte = struct.Struct("<B").pack def _op_msg_no_header(flags, command, identifier, docs, check_keys, opts): """Get a OP_MSG message. Note: this method handles multiple documents in a type one payload but it does not perform batch splitting and the total message size is only checked *after* generating the entire message. """ # Encode the command document in payload 0 without checking keys. encoded = _dict_to_bson(command, False, opts) flags_type = _pack_op_msg_flags_type(flags, 0) total_size = len(encoded) max_doc_size = 0 if identifier: type_one = _pack_byte(1) cstring = _make_c_string(identifier) encoded_docs = [_dict_to_bson(doc, check_keys, opts) for doc in docs] size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 encoded_size = _pack_int(size) total_size += size max_doc_size = max(len(doc) for doc in encoded_docs) data = ([flags_type, encoded, type_one, encoded_size, cstring] + encoded_docs) else: data = [flags_type, encoded] return b''.join(data), total_size, max_doc_size def _op_msg_compressed(flags, command, identifier, docs, check_keys, opts, ctx): """Internal OP_MSG message helper.""" msg, total_size, max_bson_size = _op_msg_no_header( flags, command, identifier, docs, check_keys, opts) rid, msg = _compress(2013, msg, ctx) return rid, msg, total_size, max_bson_size def _op_msg_uncompressed(flags, command, identifier, docs, check_keys, opts): """Internal compressed OP_MSG message helper.""" data, total_size, max_bson_size = _op_msg_no_header( flags, command, identifier, docs, check_keys, opts) request_id, op_message = __pack_message(2013, data) return request_id, op_message, total_size, max_bson_size if _use_c: _op_msg_uncompressed = _cmessage._op_msg def _op_msg(flags, command, dbname, read_preference, slave_ok, check_keys, opts, ctx=None): """Get a OP_MSG message.""" command['$db'] = dbname if "$readPreference" not in command: if slave_ok and not read_preference.mode: command["$readPreference"] = ( ReadPreference.PRIMARY_PREFERRED.document) else: command["$readPreference"] = read_preference.document name = next(iter(command)) try: identifier = _FIELD_MAP.get(name) docs = command.pop(identifier) except KeyError: identifier = "" docs = None try: if ctx: return _op_msg_compressed( flags, command, identifier, docs, check_keys, opts, ctx) return _op_msg_uncompressed( flags, command, identifier, docs, check_keys, opts) finally: # Add the field back to the command. if identifier: command[identifier] = docs def _query(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys): """Get an OP_QUERY message.""" encoded = _dict_to_bson(query, check_keys, opts) if field_selector: efs = _dict_to_bson(field_selector, False, opts) else: efs = b"" max_bson_size = max(len(encoded), len(efs)) return b"".join([ _pack_int(options), _make_c_string(collection_name), _pack_int(num_to_skip), _pack_int(num_to_return), encoded, efs]), max_bson_size def _query_compressed(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys=False, ctx=None): """Internal compressed query message helper.""" op_query, max_bson_size = _query( options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys) rid, msg = _compress(2004, op_query, ctx) return rid, msg, max_bson_size def _query_uncompressed(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys=False): """Internal query message helper.""" op_query, max_bson_size = _query( options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys) rid, msg = __pack_message(2004, op_query) return rid, msg, max_bson_size if _use_c: _query_uncompressed = _cmessage._query_message def query(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys=False, ctx=None): """Get a **query** message.""" if ctx: return _query_compressed(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys, ctx) return _query_uncompressed(options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, check_keys) _pack_long_long = struct.Struct("<q").pack def _get_more(collection_name, num_to_return, cursor_id): """Get an OP_GET_MORE message.""" return b"".join([ _ZERO_32, _make_c_string(collection_name), _pack_int(num_to_return), _pack_long_long(cursor_id)]) def _get_more_compressed(collection_name, num_to_return, cursor_id, ctx): """Internal compressed getMore message helper.""" return _compress( 2005, _get_more(collection_name, num_to_return, cursor_id), ctx) def _get_more_uncompressed(collection_name, num_to_return, cursor_id): """Internal getMore message helper.""" return __pack_message( 2005, _get_more(collection_name, num_to_return, cursor_id)) if _use_c: _get_more_uncompressed = _cmessage._get_more_message def get_more(collection_name, num_to_return, cursor_id, ctx=None): """Get a **getMore** message.""" if ctx: return _get_more_compressed( collection_name, num_to_return, cursor_id, ctx) return _get_more_uncompressed(collection_name, num_to_return, cursor_id) def _delete(collection_name, spec, opts, flags): """Get an OP_DELETE message.""" encoded = _dict_to_bson(spec, False, opts) # Uses extensions. return b"".join([ _ZERO_32, _make_c_string(collection_name), _pack_int(flags), encoded]), len(encoded) def _delete_compressed(collection_name, spec, opts, flags, ctx): """Internal compressed unacknowledged delete message helper.""" op_delete, max_bson_size = _delete(collection_name, spec, opts, flags) rid, msg = _compress(2006, op_delete, ctx) return rid, msg, max_bson_size def _delete_uncompressed( collection_name, spec, safe, last_error_args, opts, flags=0): """Internal delete message helper.""" op_delete, max_bson_size = _delete(collection_name, spec, opts, flags) rid, msg = __pack_message(2006, op_delete) if safe: rid, gle, _ = __last_error(collection_name, last_error_args) return rid, msg + gle, max_bson_size return rid, msg, max_bson_size def delete( collection_name, spec, safe, last_error_args, opts, flags=0, ctx=None): """Get a **delete** message. `opts` is a CodecOptions. `flags` is a bit vector that may contain the SingleRemove flag or not: http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#op-delete """ if ctx: return _delete_compressed(collection_name, spec, opts, flags, ctx) return _delete_uncompressed( collection_name, spec, safe, last_error_args, opts, flags) def kill_cursors(cursor_ids): """Get a **killCursors** message. """ num_cursors = len(cursor_ids) pack = struct.Struct("<ii" + ("q" * num_cursors)).pack op_kill_cursors = pack(0, num_cursors, *cursor_ids) return __pack_message(2007, op_kill_cursors) class _BulkWriteContext(object): """A wrapper around SocketInfo for use with write splitting functions.""" __slots__ = ('db_name', 'command', 'sock_info', 'op_id', 'name', 'field', 'publish', 'start_time', 'listeners', 'session', 'compress') def __init__(self, database_name, command, sock_info, operation_id, listeners, session): self.db_name = database_name self.command = command self.sock_info = sock_info self.op_id = operation_id self.listeners = listeners self.publish = listeners.enabled_for_commands self.name = next(iter(command)) self.field = _FIELD_MAP[self.name] self.start_time = datetime.datetime.now() if self.publish else None self.session = session self.compress = True if sock_info.compression_context else False @property def max_bson_size(self): """A proxy for SockInfo.max_bson_size.""" return self.sock_info.max_bson_size @property def max_message_size(self): """A proxy for SockInfo.max_message_size.""" return self.sock_info.max_message_size @property def max_write_batch_size(self): """A proxy for SockInfo.max_write_batch_size.""" return self.sock_info.max_write_batch_size def legacy_bulk_insert( self, request_id, msg, max_doc_size, acknowledged, docs, compress): if compress: request_id, msg = _compress( 2002, msg, self.sock_info.compression_context) return self.legacy_write( request_id, msg, max_doc_size, acknowledged, docs) def legacy_write(self, request_id, msg, max_doc_size, acknowledged, docs): """A proxy for SocketInfo.legacy_write that handles event publishing. """ if self.publish: duration = datetime.datetime.now() - self.start_time cmd = self._start(request_id, docs) start = datetime.datetime.now() try: result = self.sock_info.legacy_write( request_id, msg, max_doc_size, acknowledged) if self.publish: duration = (datetime.datetime.now() - start) + duration if result is not None: reply = _convert_write_result(self.name, cmd, result) else: # Comply with APM spec. reply = {'ok': 1} self._succeed(request_id, reply, duration) except OperationFailure as exc: if self.publish: duration = (datetime.datetime.now() - start) + duration self._fail( request_id, _convert_write_result( self.name, cmd, exc.details), duration) raise finally: self.start_time = datetime.datetime.now() return result def write_command(self, request_id, msg, docs): """A proxy for SocketInfo.write_command that handles event publishing. """ if self.publish: duration = datetime.datetime.now() - self.start_time self._start(request_id, docs) start = datetime.datetime.now() try: reply = self.sock_info.write_command(request_id, msg) if self.publish: duration = (datetime.datetime.now() - start) + duration self._succeed(request_id, reply, duration) except OperationFailure as exc: if self.publish: duration = (datetime.datetime.now() - start) + duration self._fail(request_id, exc.details, duration) raise finally: self.start_time = datetime.datetime.now() return reply def _start(self, request_id, docs): """Publish a CommandStartedEvent.""" cmd = self.command.copy() cmd[self.field] = docs self.listeners.publish_command_start( cmd, self.db_name, request_id, self.sock_info.address, self.op_id) return cmd def _succeed(self, request_id, reply, duration): """Publish a CommandSucceededEvent.""" self.listeners.publish_command_success( duration, reply, self.name, request_id, self.sock_info.address, self.op_id) def _fail(self, request_id, failure, duration): """Publish a CommandFailedEvent.""" self.listeners.publish_command_failure( duration, failure, self.name, request_id, self.sock_info.address, self.op_id) def _raise_document_too_large(operation, doc_size, max_size): """Internal helper for raising DocumentTooLarge.""" if operation == "insert": raise DocumentTooLarge("BSON document too large (%d bytes)" " - the connected server supports" " BSON document sizes up to %d" " bytes." % (doc_size, max_size)) else: # There's nothing intelligent we can say # about size for update and delete raise DocumentTooLarge("%r command document too large" % (operation,)) def _do_batched_insert(collection_name, docs, check_keys, safe, last_error_args, continue_on_error, opts, ctx): """Insert `docs` using multiple batches. """ def _insert_message(insert_message, send_safe): """Build the insert message with header and GLE. """ request_id, final_message = __pack_message(2002, insert_message) if send_safe: request_id, error_message, _ = __last_error(collection_name, last_error_args) final_message += error_message return request_id, final_message send_safe = safe or not continue_on_error last_error = None data = StringIO() data.write(struct.pack("<i", int(continue_on_error))) data.write(_make_c_string(collection_name)) message_length = begin_loc = data.tell() has_docs = False to_send = [] encode = _dict_to_bson # Make local compress = ctx.compress and not (safe or send_safe) for doc in docs: encoded = encode(doc, check_keys, opts) encoded_length = len(encoded) too_large = (encoded_length > ctx.max_bson_size) message_length += encoded_length if message_length < ctx.max_message_size and not too_large: data.write(encoded) to_send.append(doc) has_docs = True continue if has_docs: # We have enough data, send this message. try: if compress: rid, msg = None, data.getvalue() else: rid, msg = _insert_message(data.getvalue(), send_safe) ctx.legacy_bulk_insert( rid, msg, 0, send_safe, to_send, compress) # Exception type could be OperationFailure or a subtype # (e.g. DuplicateKeyError) except OperationFailure as exc: # Like it says, continue on error... if continue_on_error: # Store exception details to re-raise after the final batch. last_error = exc # With unacknowledged writes just return at the first error. elif not safe: return # With acknowledged writes raise immediately. else: raise if too_large: _raise_document_too_large( "insert", encoded_length, ctx.max_bson_size) message_length = begin_loc + encoded_length data.seek(begin_loc) data.truncate() data.write(encoded) to_send = [doc] if not has_docs: raise InvalidOperation("cannot do an empty bulk insert") if compress: request_id, msg = None, data.getvalue() else: request_id, msg = _insert_message(data.getvalue(), safe) ctx.legacy_bulk_insert(request_id, msg, 0, safe, to_send, compress) # Re-raise any exception stored due to continue_on_error if last_error is not None: raise last_error if _use_c: _do_batched_insert = _cmessage._do_batched_insert # OP_MSG ------------------------------------------------------------- _OP_MSG_MAP = { _INSERT: b'documents\x00', _UPDATE: b'updates\x00', _DELETE: b'deletes\x00', } def _batched_op_msg_impl( operation, command, docs, check_keys, ack, opts, ctx, buf): """Create a batched OP_MSG write.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size max_message_size = ctx.max_message_size flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" # Flags buf.write(flags) # Type 0 Section buf.write(b"\x00") buf.write(_dict_to_bson(command, False, opts)) # Type 1 Section buf.write(b"\x01") size_location = buf.tell() # Save space for size buf.write(b"\x00\x00\x00\x00") try: buf.write(_OP_MSG_MAP[operation]) except KeyError: raise InvalidOperation('Unknown command') if operation in (_UPDATE, _DELETE): check_keys = False to_send = [] idx = 0 for doc in docs: # Encode the current operation value = _dict_to_bson(doc, check_keys, opts) doc_length = len(value) new_message_size = buf.tell() + doc_length # Does first document exceed max_message_size? doc_too_large = (idx == 0 and (new_message_size > max_message_size)) # When OP_MSG is used unacknowleged we have to check # document size client side or applications won't be notified. # Otherwise we let the server deal with documents that are too large # since ordered=False causes those documents to be skipped instead of # halting the bulk write operation. unacked_doc_too_large = (not ack and (doc_length > max_bson_size)) if doc_too_large or unacked_doc_too_large: write_op = list(_FIELD_MAP.keys())[operation] _raise_document_too_large( write_op, len(value), max_bson_size) # We have enough data, return this batch. if new_message_size > max_message_size: break buf.write(value) to_send.append(doc) idx += 1 # We have enough documents, return this batch. if idx == max_write_batch_size: break # Write type 1 section size length = buf.tell() buf.seek(size_location) buf.write(_pack_int(length - size_location)) return to_send, length def _encode_batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx): """Encode the next batched insert, update, or delete operation as OP_MSG. """ buf = StringIO() to_send, _ = _batched_op_msg_impl( operation, command, docs, check_keys, ack, opts, ctx, buf) return buf.getvalue(), to_send if _use_c: _encode_batched_op_msg = _cmessage._encode_batched_op_msg def _batched_op_msg_compressed( operation, command, docs, check_keys, ack, opts, ctx): """Create the next batched insert, update, or delete operation with OP_MSG, compressed. """ data, to_send = _encode_batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx) request_id, msg = _compress( 2013, data, ctx.sock_info.compression_context) return request_id, msg, to_send def _batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx): """OP_MSG implementation entry point.""" buf = StringIO() # Save space for message length and request id buf.write(_ZERO_64) # responseTo, opCode buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") to_send, length = _batched_op_msg_impl( operation, command, docs, check_keys, ack, opts, ctx, buf) # Header - request id and message length buf.seek(4) request_id = _randint() buf.write(_pack_int(request_id)) buf.seek(0) buf.write(_pack_int(length)) return request_id, buf.getvalue(), to_send if _use_c: _batched_op_msg = _cmessage._batched_op_msg def _do_batched_op_msg( namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete operation using OP_MSG. """ command['$db'] = namespace.split('.', 1)[0] if 'writeConcern' in command: ack = bool(command['writeConcern'].get('w', 1)) else: ack = True if ctx.sock_info.compression_context: return _batched_op_msg_compressed( operation, command, docs, check_keys, ack, opts, ctx) return _batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx) # End OP_MSG ----------------------------------------------------- def _batched_write_command_compressed( namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete command, compressed. """ data, to_send = _encode_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx) request_id, msg = _compress( 2004, data, ctx.sock_info.compression_context) return request_id, msg, to_send def _encode_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Encode the next batched insert, update, or delete command. """ buf = StringIO() to_send, _ = _batched_write_command_impl( namespace, operation, command, docs, check_keys, opts, ctx, buf) return buf.getvalue(), to_send if _use_c: _encode_batched_write_command = _cmessage._encode_batched_write_command def _batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete command. """ buf = StringIO() # Save space for message length and request id buf.write(_ZERO_64) # responseTo, opCode buf.write(b"\x00\x00\x00\x00\xd4\x07\x00\x00") # Write OP_QUERY write command to_send, length = _batched_write_command_impl( namespace, operation, command, docs, check_keys, opts, ctx, buf) # Header - request id and message length buf.seek(4) request_id = _randint() buf.write(_pack_int(request_id)) buf.seek(0) buf.write(_pack_int(length)) return request_id, buf.getvalue(), to_send if _use_c: _batched_write_command = _cmessage._batched_write_command def _do_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Batched write commands entry point.""" if ctx.sock_info.compression_context: return _batched_write_command_compressed( namespace, operation, command, docs, check_keys, opts, ctx) return _batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx) def _do_bulk_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Bulk write commands entry point.""" if ctx.sock_info.max_wire_version > 5: return _do_batched_op_msg( namespace, operation, command, docs, check_keys, opts, ctx) return _do_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx) def _batched_write_command_impl( namespace, operation, command, docs, check_keys, opts, ctx, buf): """Create a batched OP_QUERY write command.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size # Max BSON object size + 16k - 2 bytes for ending NUL bytes. # Server guarantees there is enough room: SERVER-10643. max_cmd_size = max_bson_size + _COMMAND_OVERHEAD # No options buf.write(_ZERO_32) # Namespace as C string buf.write(b(namespace)) buf.write(_ZERO_8) # Skip: 0, Limit: -1 buf.write(_SKIPLIM) # Where to write command document length command_start = buf.tell() buf.write(bson.BSON.encode(command)) # Start of payload buf.seek(-1, 2) # Work around some Jython weirdness. buf.truncate() try: buf.write(_OP_MAP[operation]) except KeyError: raise InvalidOperation('Unknown command') if operation in (_UPDATE, _DELETE): check_keys = False # Where to write list document length list_start = buf.tell() - 4 to_send = [] idx = 0 for doc in docs: # Encode the current operation key = b(str(idx)) value = bson.BSON.encode(doc, check_keys, opts) # Is there enough room to add this document? max_cmd_size accounts for # the two trailing null bytes. enough_data = (buf.tell() + len(key) + len(value)) >= max_cmd_size enough_documents = (idx >= max_write_batch_size) if enough_data or enough_documents: if not idx: write_op = list(_FIELD_MAP.keys())[operation] _raise_document_too_large( write_op, len(value), max_bson_size) break buf.write(_BSONOBJ) buf.write(key) buf.write(_ZERO_8) buf.write(value) to_send.append(doc) idx += 1 # Finalize the current OP_QUERY message. # Close list and command documents buf.write(_ZERO_16) # Write document lengths and request id length = buf.tell() buf.seek(list_start) buf.write(_pack_int(length - list_start - 1)) buf.seek(command_start) buf.write(_pack_int(length - command_start)) return to_send, length class _OpReply(object): """A MongoDB OP_REPLY response message.""" __slots__ = ("flags", "cursor_id", "number_returned", "documents") UNPACK_FROM = struct.Struct("<iqii").unpack_from OP_CODE = 1 def __init__(self, flags, cursor_id, number_returned, documents): self.flags = flags self.cursor_id = cursor_id self.number_returned = number_returned self.documents = documents def raw_response(self, cursor_id=None): """Check the response header from the database, without decoding BSON. Check the response for errors and unpack. Can raise CursorNotFound, NotMasterError, ExecutionTimeout, or OperationFailure. :Parameters: - `cursor_id` (optional): cursor_id we sent to get this response - used for raising an informative exception when we get cursor id not valid at server response. """ if self.flags & 1: # Shouldn't get this response if we aren't doing a getMore if cursor_id is None: raise ProtocolError("No cursor id for getMore operation") # Fake a getMore command response. OP_GET_MORE provides no # document. msg = "Cursor not found, cursor id: %d" % (cursor_id,) errobj = {"ok": 0, "errmsg": msg, "code": 43} raise CursorNotFound(msg, 43, errobj) elif self.flags & 2: error_object = bson.BSON(self.documents).decode() # Fake the ok field if it doesn't exist. error_object.setdefault("ok", 0) if error_object["$err"].startswith("not master"): raise NotMasterError(error_object["$err"], error_object) elif error_object.get("code") == 50: raise ExecutionTimeout(error_object.get("$err"), error_object.get("code"), error_object) raise OperationFailure("database error: %s" % error_object.get("$err"), error_object.get("code"), error_object) return [self.documents] def unpack_response(self, cursor_id=None, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS): """Unpack a response from the database and decode the BSON document(s). Check the response for errors and unpack, returning a dictionary containing the response data. Can raise CursorNotFound, NotMasterError, ExecutionTimeout, or OperationFailure. :Parameters: - `cursor_id` (optional): cursor_id we sent to get this response - used for raising an informative exception when we get cursor id not valid at server response - `codec_options` (optional): an instance of :class:`~bson.codec_options.CodecOptions` """ self.raw_response(cursor_id) return bson.decode_all(self.documents, codec_options) def command_response(self): """Unpack a command response.""" docs = self.unpack_response() assert self.number_returned == 1 return docs[0] @classmethod def unpack(cls, msg): """Construct an _OpReply from raw bytes.""" # PYTHON-945: ignore starting_from field. flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) # Convert Python 3 memoryview to bytes. Note we should call # memoryview.tobytes() if we start using memoryview in Python 2.7. documents = bytes(msg[20:]) return cls(flags, cursor_id, number_returned, documents) class _OpMsg(object): """A MongoDB OP_MSG response message.""" __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") UNPACK_FROM = struct.Struct("<IBi").unpack_from OP_CODE = 2013 def __init__(self, flags, payload_document): self.flags = flags self.payload_document = payload_document def raw_response(self, cursor_id=None): raise NotImplementedError def unpack_response(self, cursor_id=None, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS): """Unpack a OP_MSG command response. :Parameters: - `cursor_id` (optional): Ignored, for compatibility with _OpReply. - `codec_options` (optional): an instance of :class:`~bson.codec_options.CodecOptions` """ return bson.decode_all(self.payload_document, codec_options) def command_response(self): """Unpack a command response.""" return self.unpack_response()[0] @classmethod def unpack(cls, msg): """Construct an _OpMsg from raw bytes.""" flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) if flags != 0: raise ProtocolError("Unsupported OP_MSG flags (%r)" % (flags,)) if first_payload_type != 0: raise ProtocolError( "Unsupported OP_MSG payload type (%r)" % (first_payload_type,)) if len(msg) != first_payload_size + 5: raise ProtocolError("Unsupported OP_MSG reply: >1 section") # Convert Python 3 memoryview to bytes. Note we should call # memoryview.tobytes() if we start using memoryview in Python 2.7. payload_document = bytes(msg[5:]) return cls(flags, payload_document) _UNPACK_REPLY = { _OpReply.OP_CODE: _OpReply.unpack, _OpMsg.OP_CODE: _OpMsg.unpack, } def _first_batch(sock_info, db, coll, query, ntoreturn, slave_ok, codec_options, read_preference, cmd, listeners): """Simple query helper for retrieving a first (and possibly only) batch.""" query = _Query( 0, db, coll, 0, query, None, codec_options, read_preference, ntoreturn, 0, DEFAULT_READ_CONCERN, None, None, None) name = next(iter(cmd)) publish = listeners.enabled_for_commands if publish: start = datetime.datetime.now() request_id, msg, max_doc_size = query.get_message(slave_ok, sock_info) if publish: encoding_duration = datetime.datetime.now() - start listeners.publish_command_start( cmd, db, request_id, sock_info.address) start = datetime.datetime.now() sock_info.send_message(msg, max_doc_size) reply = sock_info.receive_message(request_id) try: docs = reply.unpack_response(None, codec_options) except Exception as exc: if publish: duration = (datetime.datetime.now() - start) + encoding_duration if isinstance(exc, (NotMasterError, OperationFailure)): failure = exc.details else: failure = _convert_exception(exc) listeners.publish_command_failure( duration, failure, name, request_id, sock_info.address) raise # listIndexes if 'cursor' in cmd: result = { u'cursor': { u'firstBatch': docs, u'id': reply.cursor_id, u'ns': u'%s.%s' % (db, coll) }, u'ok': 1.0 } # fsyncUnlock, currentOp else: result = docs[0] if docs else {} result[u'ok'] = 1.0 if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( duration, result, name, request_id, sock_info.address) return result