# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for creating `messages `_ to be sent to MongoDB. .. note:: This module is for internal use and is generally not needed by application developers. """ import datetime import random import struct import bson from bson import CodecOptions from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.py3compat import b, StringIO from bson.son import SON try: from pymongo import _cmessage _use_c = True except ImportError: _use_c = False from pymongo.errors import (ConfigurationError, CursorNotFound, DocumentTooLarge, ExecutionTimeout, InvalidOperation, NotMasterError, OperationFailure, ProtocolError) from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 # Overhead allowed for encoded command documents. _COMMAND_OVERHEAD = 16382 _INSERT = 0 _UPDATE = 1 _DELETE = 2 _EMPTY = b'' _BSONOBJ = b'\x03' _ZERO_8 = b'\x00' _ZERO_16 = b'\x00\x00' _ZERO_32 = b'\x00\x00\x00\x00' _ZERO_64 = b'\x00\x00\x00\x00\x00\x00\x00\x00' _SKIPLIM = b'\x00\x00\x00\x00\xff\xff\xff\xff' _OP_MAP = { _INSERT: b'\x04documents\x00\x00\x00\x00\x00', _UPDATE: b'\x04updates\x00\x00\x00\x00\x00', _DELETE: b'\x04deletes\x00\x00\x00\x00\x00', } _UJOIN = u"%s.%s" _UNICODE_REPLACE_CODEC_OPTIONS = CodecOptions( unicode_decode_error_handler='replace') def _randint(): """Generate a pseudo random 32 bit integer.""" return random.randint(MIN_INT32, MAX_INT32) def _maybe_add_read_preference(spec, read_preference): """Add $readPreference to spec when appropriate.""" mode = read_preference.mode tag_sets = read_preference.tag_sets max_staleness = read_preference.max_staleness # Only add $readPreference if it's something other than primary to avoid # problems with mongos versions that don't support read preferences. Also, # for maximum backwards compatibility, don't add $readPreference for # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting # the slaveOkay bit has the same effect). if mode and ( mode != ReadPreference.SECONDARY_PREFERRED.mode or tag_sets != [{}] or max_staleness != -1): if "$query" not in spec: spec = SON([("$query", spec)]) spec["$readPreference"] = read_preference.document return spec def _convert_exception(exception): """Convert an Exception into a failure document for publishing.""" return {'errmsg': str(exception), 'errtype': exception.__class__.__name__} def _convert_write_result(operation, command, result): """Convert a legacy write result to write commmand format.""" # Based on _merge_legacy from bulk.py affected = result.get("n", 0) res = {"ok": 1, "n": affected} errmsg = result.get("errmsg", result.get("err", "")) if errmsg: # The write was successful on at least the primary so don't return. if result.get("wtimeout"): res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} else: # The write failed. error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} if "errInfo" in result: error["errInfo"] = result["errInfo"] res["writeErrors"] = [error] return res if operation == "insert": # GLE result for insert is always 0 in most MongoDB versions. res["n"] = len(command['documents']) elif operation == "update": if "upserted" in result: res["upserted"] = [{"index": 0, "_id": result["upserted"]}] # Versions of MongoDB before 2.6 don't return the _id for an # upsert if _id is not an ObjectId. elif result.get("updatedExisting") is False and affected == 1: # If _id is in both the update document *and* the query spec # the update document _id takes precedence. update = command['updates'][0] _id = update["u"].get("_id", update["q"].get("_id")) res["upserted"] = [{"index": 0, "_id": _id}] return res _OPTIONS = SON([ ('tailable', 2), ('oplogReplay', 8), ('noCursorTimeout', 16), ('awaitData', 32), ('allowPartialResults', 128)]) _MODIFIERS = SON([ ('$query', 'filter'), ('$orderby', 'sort'), ('$hint', 'hint'), ('$comment', 'comment'), ('$maxScan', 'maxScan'), ('$maxTimeMS', 'maxTimeMS'), ('$max', 'max'), ('$min', 'min'), ('$returnKey', 'returnKey'), ('$showRecordId', 'showRecordId'), ('$showDiskLoc', 'showRecordId'), # <= MongoDb 3.0 ('$snapshot', 'snapshot')]) def _gen_find_command(coll, spec, projection, skip, limit, batch_size, options, read_concern, collation=None): """Generate a find command document.""" cmd = SON([('find', coll)]) if '$query' in spec: cmd.update([(_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) for key, val in spec.items()]) if '$explain' in cmd: cmd.pop('$explain') if '$readPreference' in cmd: cmd.pop('$readPreference') else: cmd['filter'] = spec if projection: cmd['projection'] = projection if skip: cmd['skip'] = skip if limit: cmd['limit'] = abs(limit) if limit < 0: cmd['singleBatch'] = True if batch_size: cmd['batchSize'] = batch_size if read_concern.level: cmd['readConcern'] = read_concern.document if collation: cmd['collation'] = collation if options: cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val]) return cmd def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms): """Generate a getMore command document.""" cmd = SON([('getMore', cursor_id), ('collection', coll)]) if batch_size: cmd['batchSize'] = batch_size if max_await_time_ms is not None: cmd['maxTimeMS'] = max_await_time_ms return cmd class _Query(object): """A query operation.""" __slots__ = ('flags', 'db', 'coll', 'ntoskip', 'spec', 'fields', 'codec_options', 'read_preference', 'limit', 'batch_size', 'name', 'read_concern', 'collation', 'session', 'client') def __init__(self, flags, db, coll, ntoskip, spec, fields, codec_options, read_preference, limit, batch_size, read_concern, collation, session, client): self.flags = flags self.db = db self.coll = coll self.ntoskip = ntoskip self.spec = spec self.fields = fields self.codec_options = codec_options self.read_preference = read_preference self.read_concern = read_concern self.limit = limit self.batch_size = batch_size self.collation = collation self.session = session self.client = client self.name = 'find' def use_command(self, sock_info, exhaust): use_find_cmd = False if sock_info.max_wire_version >= 4: if not exhaust: use_find_cmd = True elif not self.read_concern.ok_for_legacy: raise ConfigurationError( 'read concern level of %s is not valid ' 'with a max wire version of %d.' % (self.read_concern.level, sock_info.max_wire_version)) if sock_info.max_wire_version < 5 and self.collation is not None: raise ConfigurationError( 'Specifying a collation is unsupported with a max wire ' 'version of %d.' % (sock_info.max_wire_version,)) sock_info.validate_session(self.client, self.session) return use_find_cmd def as_command(self, sock_info): """Return a find command document for this query. Should be called *after* get_message. """ explain = '$explain' in self.spec cmd = _gen_find_command( self.coll, self.spec, self.fields, self.ntoskip, self.limit, self.batch_size, self.flags, self.read_concern, self.collation) if explain: self.name = 'explain' cmd = SON([('explain', cmd)]) session = self.session if session: cmd['lsid'] = session._use_lsid() # Explain does not support readConcern. if (not explain and session.options.causal_consistency and session.operation_time is not None): cmd.setdefault( 'readConcern', {})[ 'afterClusterTime'] = session.operation_time sock_info.send_cluster_time(cmd, session, self.client) return cmd, self.db def get_message(self, set_slave_ok, sock_info, use_cmd=False): """Get a query message, possibly setting the slaveOk bit.""" if set_slave_ok: # Set the slaveOk bit. flags = self.flags | 4 else: flags = self.flags ns = _UJOIN % (self.db, self.coll) spec = self.spec if use_cmd: ns = _UJOIN % (self.db, "$cmd") spec = self.as_command(sock_info)[0] ntoreturn = -1 # All DB commands return 1 document else: # OP_QUERY treats ntoreturn of -1 and 1 the same, return # one document and close the cursor. We have to use 2 for # batch size if 1 is specified. ntoreturn = self.batch_size == 1 and 2 or self.batch_size if self.limit: if ntoreturn: ntoreturn = min(self.limit, ntoreturn) else: ntoreturn = self.limit if sock_info.is_mongos: spec = _maybe_add_read_preference(spec, self.read_preference) return query(flags, ns, self.ntoskip, ntoreturn, spec, self.fields, self.codec_options) class _GetMore(object): """A getmore operation.""" __slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms', 'codec_options', 'session', 'client') name = 'getMore' def __init__(self, db, coll, ntoreturn, cursor_id, codec_options, session, client, max_await_time_ms=None): self.db = db self.coll = coll self.ntoreturn = ntoreturn self.cursor_id = cursor_id self.codec_options = codec_options self.session = session self.client = client self.max_await_time_ms = max_await_time_ms def use_command(self, sock_info, exhaust): sock_info.validate_session(self.client, self.session) return sock_info.max_wire_version >= 4 and not exhaust def as_command(self, sock_info): """Return a getMore command document for this query.""" cmd = _gen_get_more_command(self.cursor_id, self.coll, self.ntoreturn, self.max_await_time_ms) if self.session: cmd['lsid'] = self.session._use_lsid() sock_info.send_cluster_time(cmd, self.session, self.client) return cmd, self.db def get_message(self, dummy0, sock_info, use_cmd=False): """Get a getmore message.""" ns = _UJOIN % (self.db, self.coll) if use_cmd: ns = _UJOIN % (self.db, "$cmd") spec = self.as_command(sock_info)[0] return query(0, ns, 0, -1, spec, None, self.codec_options) return get_more(ns, self.ntoreturn, self.cursor_id) class _RawBatchQuery(_Query): def use_command(self, socket_info, exhaust): # Compatibility checks. super(_RawBatchQuery, self).use_command(socket_info, exhaust) return False def get_message(self, set_slave_ok, sock_info, use_cmd=False): # Always pass False for use_cmd. return super(_RawBatchQuery, self).get_message( set_slave_ok, sock_info, False) class _RawBatchGetMore(_GetMore): def use_command(self, socket_info, exhaust): return False def get_message(self, set_slave_ok, sock_info, use_cmd=False): # Always pass False for use_cmd. return super(_RawBatchGetMore, self).get_message( set_slave_ok, sock_info, False) class _CursorAddress(tuple): """The server address (host, port) of a cursor, with namespace property.""" def __new__(cls, address, namespace): self = tuple.__new__(cls, address) self.__namespace = namespace return self @property def namespace(self): """The namespace this cursor.""" return self.__namespace def __hash__(self): # Two _CursorAddress instances with different namespaces # must not hash the same. return (self + (self.__namespace,)).__hash__() def __eq__(self, other): if isinstance(other, _CursorAddress): return (tuple(self) == tuple(other) and self.namespace == other.namespace) return NotImplemented def __ne__(self, other): return not self == other def __last_error(namespace, args): """Data to send to do a lastError. """ cmd = SON([("getlasterror", 1)]) cmd.update(args) splitns = namespace.split('.', 1) return query(0, splitns[0] + '.$cmd', 0, -1, cmd, None, DEFAULT_CODEC_OPTIONS) def __pack_message(operation, data): """Takes message data and adds a message header based on the operation. Returns the resultant message string. """ request_id = _randint() message = struct.pack(" ctx.max_bson_size) message_length += encoded_length if message_length < ctx.max_message_size and not too_large: data.write(encoded) to_send.append(doc) has_docs = True continue if has_docs: # We have enough data, send this message. try: request_id, msg = _insert_message(data.getvalue(), send_safe) ctx.legacy_write(request_id, msg, 0, send_safe, to_send) # Exception type could be OperationFailure or a subtype # (e.g. DuplicateKeyError) except OperationFailure as exc: # Like it says, continue on error... if continue_on_error: # Store exception details to re-raise after the final batch. last_error = exc # With unacknowledged writes just return at the first error. elif not safe: return # With acknowledged writes raise immediately. else: raise if too_large: _raise_document_too_large( "insert", encoded_length, ctx.max_bson_size) message_length = begin_loc + encoded_length data.seek(begin_loc) data.truncate() data.write(encoded) to_send = [doc] if not has_docs: raise InvalidOperation("cannot do an empty bulk insert") request_id, msg = _insert_message(data.getvalue(), safe) ctx.legacy_write(request_id, msg, 0, safe, to_send) # Re-raise any exception stored due to continue_on_error if last_error is not None: raise last_error if _use_c: _do_batched_insert = _cmessage._do_batched_insert def _do_batched_write_command(namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete command. """ max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size # Max BSON object size + 16k - 2 bytes for ending NUL bytes. # Server guarantees there is enough room: SERVER-10643. max_cmd_size = max_bson_size + _COMMAND_OVERHEAD buf = StringIO() # Save space for message length and request id buf.write(_ZERO_64) # responseTo, opCode buf.write(b"\x00\x00\x00\x00\xd4\x07\x00\x00") # No options buf.write(_ZERO_32) # Namespace as C string buf.write(b(namespace)) buf.write(_ZERO_8) # Skip: 0, Limit: -1 buf.write(_SKIPLIM) # Where to write command document length command_start = buf.tell() buf.write(bson.BSON.encode(command)) # Start of payload buf.seek(-1, 2) # Work around some Jython weirdness. buf.truncate() try: buf.write(_OP_MAP[operation]) except KeyError: raise InvalidOperation('Unknown command') if operation in (_UPDATE, _DELETE): check_keys = False # Where to write list document length list_start = buf.tell() - 4 to_send = [] idx = 0 for doc in docs: # Encode the current operation key = b(str(idx)) value = bson.BSON.encode(doc, check_keys, opts) # Is there enough room to add this document? max_cmd_size accounts for # the two trailing null bytes. enough_data = (buf.tell() + len(key) + len(value)) >= max_cmd_size enough_documents = (idx >= max_write_batch_size) if enough_data or enough_documents: if not idx: write_op = "insert" if operation == _INSERT else None _raise_document_too_large( write_op, len(value), max_bson_size) break buf.write(_BSONOBJ) buf.write(key) buf.write(_ZERO_8) buf.write(value) to_send.append(doc) idx += 1 # Finalize the current OP_QUERY message. # Close list and command documents buf.write(_ZERO_16) # Write document lengths and request id length = buf.tell() buf.seek(list_start) buf.write(struct.pack('