using System;
using System.Collections.Generic;
using System.Globalization;

using Renci.SshNet.Common;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Authentication;
using Renci.SshNet.Messages.Connection;
using Renci.SshNet.Messages.Transport;

namespace Renci.SshNet
{
    /// <summary>
    /// Creates SSH message instances from a message number.
    /// </summary>
    /// <remarks>
    /// Several message types share a number (for example 31 and 60 in <see cref="AllMessages"/>).
    /// At most one message type can be <em>enabled</em> per number at a time.
    /// A message that is <em>activated</em> is additionally remembered, so that
    /// <see cref="EnableActivatedMessages"/> can re-enable it after
    /// <see cref="DisableNonKeyExchangeMessages"/> has disabled it.
    /// </remarks>
    internal sealed class SshMessageFactory
    {
        /// <summary>
        /// Indexed by message number (0 through <see cref="HighestMessageNumber"/>).
        /// Holds the metadata of the message type that <see cref="Create(byte)"/> instantiates for
        /// that number, or <see langword="null"/> when no message type is enabled for it.
        /// </summary>
        private readonly MessageMetadata[] _enabledMessagesByNumber;

        /// <summary>
        /// Indexed by <see cref="MessageMetadata.Id"/>. An entry is <see langword="true"/> when the message was
        /// activated through <see cref="EnableAndActivateMessage(string)"/> and has not since been
        /// deactivated or reset.
        /// </summary>
        private readonly bool[] _activatedMessagesById;

        /// <summary>
        /// Private lock object used by <see cref="EnableAndActivateMessage(string)"/> and
        /// <see cref="DisableAndDeactivateMessage(string)"/>. It is used instead of <c>lock (this)</c> so that
        /// outside code locking on this instance cannot contend with the factory.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <see cref="Reset"/>, <see cref="EnableActivatedMessages"/> and
        /// <see cref="DisableNonKeyExchangeMessages"/> do not take this lock. Presumably their callers
        /// serialize them; confirm.
        /// </remarks>
        private readonly object _lock = new object();

        /// <summary>
        /// Metadata for every supported message. Each entry's <see cref="MessageMetadata.Id"/> equals its index.
        /// </summary>
        internal static readonly MessageMetadata[] AllMessages;

        /// <summary>
        /// Lookup of <see cref="AllMessages"/> by <see cref="MessageMetadata.Name"/>.
        /// </summary>
        private static readonly Dictionary<string, MessageMetadata> MessagesByName;

        /// <summary>
        /// Defines the highest message number that is currently supported.
        /// </summary>
        internal const byte HighestMessageNumber = 100;

        /// <summary>
        /// Defines the total number of supported messages.
        /// </summary>
        internal const int TotalMessageCount = 32;

        static SshMessageFactory()
        {
            // NOTE(review): every entry below appears to have lost its generic type argument
            // (MessageMetadata<TMessage>) in this copy of the file. The abstract MessageMetadata cannot be
            // instantiated as written. Restore the type arguments from version control; they cannot be
            // recovered from this file alone.
            AllMessages = new MessageMetadata[]
            {
                new MessageMetadata(0, "SSH_MSG_KEXINIT", 20),
                new MessageMetadata(1, "SSH_MSG_NEWKEYS", 21),
                new MessageMetadata(2, "SSH_MSG_REQUEST_FAILURE", 82),
                new MessageMetadata(3, "SSH_MSG_CHANNEL_OPEN_FAILURE", 92),
                new MessageMetadata(4, "SSH_MSG_CHANNEL_FAILURE", 100),
                new MessageMetadata(5, "SSH_MSG_CHANNEL_EXTENDED_DATA", 95),
                new MessageMetadata(6, "SSH_MSG_CHANNEL_DATA", 94),
                new MessageMetadata(7, "SSH_MSG_CHANNEL_REQUEST", 98),
                new MessageMetadata(8, "SSH_MSG_USERAUTH_BANNER", 53),
                new MessageMetadata(9, "SSH_MSG_USERAUTH_INFO_RESPONSE", 61),
                new MessageMetadata(10, "SSH_MSG_USERAUTH_FAILURE", 51),
                new MessageMetadata(11, "SSH_MSG_DEBUG", 4),
                new MessageMetadata(12, "SSH_MSG_GLOBAL_REQUEST", 80),
                new MessageMetadata(13, "SSH_MSG_CHANNEL_OPEN", 90),
                new MessageMetadata(14, "SSH_MSG_CHANNEL_OPEN_CONFIRMATION", 91),
                new MessageMetadata(15, "SSH_MSG_USERAUTH_INFO_REQUEST", 60),
                new MessageMetadata(16, "SSH_MSG_UNIMPLEMENTED", 3),
                new MessageMetadata(17, "SSH_MSG_REQUEST_SUCCESS", 81),
                new MessageMetadata(18, "SSH_MSG_CHANNEL_SUCCESS", 99),
                new MessageMetadata(19, "SSH_MSG_USERAUTH_PASSWD_CHANGEREQ", 60),
                new MessageMetadata(20, "SSH_MSG_DISCONNECT", 1),
                new MessageMetadata(21, "SSH_MSG_USERAUTH_SUCCESS", 52),
                new MessageMetadata(22, "SSH_MSG_USERAUTH_PK_OK", 60),
                new MessageMetadata(23, "SSH_MSG_IGNORE", 2),
                new MessageMetadata(24, "SSH_MSG_CHANNEL_WINDOW_ADJUST", 93),
                new MessageMetadata(25, "SSH_MSG_CHANNEL_EOF", 96),
                new MessageMetadata(26, "SSH_MSG_CHANNEL_CLOSE", 97),
                new MessageMetadata(27, "SSH_MSG_SERVICE_ACCEPT", 6),
                new MessageMetadata(28, "SSH_MSG_KEX_DH_GEX_GROUP", 31),
                new MessageMetadata(29, "SSH_MSG_KEXDH_REPLY", 31),
                new MessageMetadata(30, "SSH_MSG_KEX_DH_GEX_REPLY", 33),
                new MessageMetadata(31, "SSH_MSG_KEX_ECDH_REPLY", 31)
            };

            MessagesByName = new Dictionary<string, MessageMetadata>(AllMessages.Length);
            foreach (var messageMetadata in AllMessages)
            {
                // Add (not the indexer) so that a duplicated name fails loudly during type initialization.
                MessagesByName.Add(messageMetadata.Name, messageMetadata);
            }
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="SshMessageFactory"/> class.
        /// No message is enabled or activated initially.
        /// </summary>
        public SshMessageFactory()
        {
            _activatedMessagesById = new bool[TotalMessageCount];
            _enabledMessagesByNumber = new MessageMetadata[HighestMessageNumber + 1];
        }

        /// <summary>
        /// Disables and deactivates all messages.
        /// </summary>
        public void Reset()
        {
            Array.Clear(_activatedMessagesById, 0, _activatedMessagesById.Length);
            Array.Clear(_enabledMessagesByNumber, 0, _enabledMessagesByNumber.Length);
        }

        /// <summary>
        /// Creates a new instance of the message type that is currently enabled for <paramref name="messageNumber"/>.
        /// </summary>
        /// <param name="messageNumber">The message number.</param>
        /// <returns>A new message instance.</returns>
        /// <exception cref="SshException">
        /// <paramref name="messageNumber"/> is greater than <see cref="HighestMessageNumber"/>.
        /// -or- no message in <see cref="AllMessages"/> has that number.
        /// -or- a message with that number exists but none is currently enabled.
        /// </exception>
        public Message Create(byte messageNumber)
        {
            if (messageNumber > HighestMessageNumber)
            {
                throw CreateMessageTypeNotSupportedException(messageNumber);
            }

            var enabledMessageMetadata = _enabledMessagesByNumber[messageNumber];
            if (enabledMessageMetadata is not null)
            {
                return enabledMessageMetadata.Create();
            }

            // Report an unknown number differently from a known number that is merely not enabled right now.
            if (!IsDefinedMessageNumber(messageNumber))
            {
                throw CreateMessageTypeNotSupportedException(messageNumber);
            }

            throw new SshException(string.Format(CultureInfo.InvariantCulture, "Message type {0} is not valid in the current context.", messageNumber));
        }

        /// <summary>
        /// Disables, but does not deactivate, every message whose number is 3 through 19 or greater than 30.
        /// </summary>
        /// <remarks>
        /// Of the entries in <see cref="AllMessages"/>, only SSH_MSG_DISCONNECT (1), SSH_MSG_IGNORE (2),
        /// SSH_MSG_KEXINIT (20) and SSH_MSG_NEWKEYS (21) are left untouched. This includes all the
        /// key-exchange replies, which share numbers 31 and 33.
        /// Activation flags are not modified, so <see cref="EnableActivatedMessages"/> can restore the
        /// activated set afterwards.
        /// </remarks>
        public void DisableNonKeyExchangeMessages()
        {
            foreach (var messageMetadata in AllMessages)
            {
                var messageNumber = messageMetadata.Number;
                if (messageNumber is (> 2 and < 20) or > 30)
                {
                    _enabledMessagesByNumber[messageNumber] = null;
                }
            }
        }

        /// <summary>
        /// Enables every message that is currently activated.
        /// </summary>
        /// <exception cref="SshException">
        /// An activated message's number is already enabled for a different message.
        /// </exception>
        public void EnableActivatedMessages()
        {
            foreach (var messageMetadata in AllMessages)
            {
                if (!_activatedMessagesById[messageMetadata.Id])
                {
                    continue;
                }

                EnsureNumberNotEnabledForOtherMessage(messageMetadata);
                _enabledMessagesByNumber[messageMetadata.Number] = messageMetadata;
            }
        }

        /// <summary>
        /// Enables and activates the message with the specified name.
        /// </summary>
        /// <param name="messageName">The name of the message, for example <c>SSH_MSG_CHANNEL_DATA</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="messageName"/> is <see langword="null"/>.</exception>
        /// <exception cref="SshException">
        /// The message is not supported, or its number is already enabled for a different message.
        /// </exception>
        public void EnableAndActivateMessage(string messageName)
        {
            if (messageName is null)
            {
                throw new ArgumentNullException(nameof(messageName));
            }

            lock (_lock)
            {
                var messageMetadata = GetMessageMetadata(messageName);
                EnsureNumberNotEnabledForOtherMessage(messageMetadata);

                _enabledMessagesByNumber[messageMetadata.Number] = messageMetadata;
                _activatedMessagesById[messageMetadata.Id] = true;
            }
        }

        /// <summary>
        /// Disables and deactivates the message with the specified name.
        /// </summary>
        /// <param name="messageName">The name of the message, for example <c>SSH_MSG_CHANNEL_DATA</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="messageName"/> is <see langword="null"/>.</exception>
        /// <exception cref="SshException">
        /// The message is not supported, or its number is currently enabled for a different message.
        /// </exception>
        public void DisableAndDeactivateMessage(string messageName)
        {
            if (messageName is null)
            {
                throw new ArgumentNullException(nameof(messageName));
            }

            lock (_lock)
            {
                var messageMetadata = GetMessageMetadata(messageName);

                // NOTE(review): this conflict check reuses the "Cannot enable message ..." wording even though
                // this is a disable. The text is kept unchanged in case callers or tests match on it.
                EnsureNumberNotEnabledForOtherMessage(messageMetadata);

                _activatedMessagesById[messageMetadata.Id] = false;
                _enabledMessagesByNumber[messageMetadata.Number] = null;
            }
        }

        /// <summary>
        /// Returns whether any entry in <see cref="AllMessages"/> has the specified message number.
        /// </summary>
        private static bool IsDefinedMessageNumber(byte messageNumber)
        {
            foreach (var messageMetadata in AllMessages)
            {
                if (messageMetadata.Number == messageNumber)
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Looks up a message by name.
        /// </summary>
        /// <exception cref="SshException">No supported message has the specified name.</exception>
        private static MessageMetadata GetMessageMetadata(string messageName)
        {
            if (!MessagesByName.TryGetValue(messageName, out var messageMetadata))
            {
                throw CreateMessageNotSupportedException(messageName);
            }

            return messageMetadata;
        }

        /// <summary>
        /// Throws when the number of <paramref name="messageMetadata"/> is currently enabled for a different message.
        /// </summary>
        private void EnsureNumberNotEnabledForOtherMessage(MessageMetadata messageMetadata)
        {
            var enabledMessageMetadata = _enabledMessagesByNumber[messageMetadata.Number];
            if (enabledMessageMetadata is not null && !ReferenceEquals(enabledMessageMetadata, messageMetadata))
            {
                throw CreateMessageTypeAlreadyEnabledForOtherMessageException(messageMetadata.Number, messageMetadata.Name, enabledMessageMetadata.Name);
            }
        }

        // The helpers below only build the exception; callers use "throw Create...(...)" so that the
        // stack trace starts at the call site and the compiler sees the throw.

        private static SshException CreateMessageTypeNotSupportedException(byte messageNumber)
        {
            return new SshException(string.Format(CultureInfo.InvariantCulture, "Message type {0} is not supported.", messageNumber));
        }

        private static SshException CreateMessageNotSupportedException(string messageName)
        {
            return new SshException(string.Format(CultureInfo.InvariantCulture, "Message '{0}' is not supported.", messageName));
        }

        private static SshException CreateMessageTypeAlreadyEnabledForOtherMessageException(byte messageNumber, string messageName, string currentEnabledForMessageName)
        {
            return new SshException(string.Format(CultureInfo.InvariantCulture, "Cannot enable message '{0}'. Message type {1} is already enabled for '{2}'.", messageName, messageNumber, currentEnabledForMessageName));
        }

        /// <summary>
        /// Describes a supported message and knows how to create an instance of it.
        /// </summary>
        internal abstract class MessageMetadata
        {
            /// <summary>
            /// Initializes a new instance of the <see cref="MessageMetadata"/> class.
            /// </summary>
            /// <param name="id">Index of this message in <see cref="AllMessages"/>.</param>
            /// <param name="name">The message name, for example <c>SSH_MSG_KEXINIT</c>.</param>
            /// <param name="number">The message number; not unique across messages.</param>
            protected MessageMetadata(byte id, string name, byte number)
            {
                Id = id;
                Name = name;
                Number = number;
            }

            /// <summary>Index of this message in <see cref="AllMessages"/> and in the activation flags.</summary>
            public readonly byte Id;

            /// <summary>The message name.</summary>
            public readonly string Name;

            /// <summary>The message number; several messages may share one.</summary>
            public readonly byte Number;

            /// <summary>
            /// Creates a new instance of the described message.
            /// </summary>
            /// <returns>A new message instance.</returns>
            public abstract Message Create();
        }

        // NOTE(review): the "where T" constraint and "new T()" show this class was declared with a type
        // parameter (MessageMetadata<T>). The type parameter list appears to be missing from this copy;
        // restore it from version control.
        internal sealed class MessageMetadata : MessageMetadata
            where T : Message, new()
        {
            /// <summary>
            /// Initializes a new instance of the generic message metadata class.
            /// </summary>
            /// <param name="id">Index of this message in <see cref="AllMessages"/>.</param>
            /// <param name="name">The message name.</param>
            /// <param name="number">The message number.</param>
            public MessageMetadata(byte id, string name, byte number)
                : base(id, name, number)
            {
            }

            /// <summary>
            /// Creates a new instance of <c>T</c> using its parameterless constructor.
            /// </summary>
            /// <returns>A new message instance.</returns>
            public override Message Create()
            {
                return new T();
            }
        }
    }
}