using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;

using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;
using Renci.SshNet.Compression;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Security.Cryptography;

namespace Renci.SshNet.Security
{
    /// <summary>
    /// Represents base class for different key exchange algorithm implementations.
    /// </summary>
    public abstract class KeyExchange : Algorithm, IKeyExchange
    {
        private CipherInfo _clientCipherInfo;
        private CipherInfo _serverCipherInfo;
        private HashInfo _clientHashInfo;
        private HashInfo _serverHashInfo;
        private Type _compressionType;
        private Type _decompressionType;

        /// <summary>
        /// Gets the session.
        /// </summary>
        /// <value>
        /// The session.
        /// </value>
        protected Session Session { get; private set; }

        /// <summary>
        /// Gets or sets key exchange shared key.
        /// </summary>
        /// <value>
        /// The shared key.
        /// </value>
        public byte[] SharedKey { get; protected set; }

        private byte[] _exchangeHash;

        /// <summary>
        /// Gets the exchange hash.
        /// </summary>
        /// <value>
        /// The exchange hash. It is computed by <see cref="CalculateHash"/> on first access and
        /// cached for subsequent accesses.
        /// </value>
        public byte[] ExchangeHash
        {
            get
            {
                _exchangeHash ??= CalculateHash();
                return _exchangeHash;
            }
        }

        /// <summary>
        /// Occurs when host key received.
        /// </summary>
        public event EventHandler<HostKeyEventArgs> HostKeyReceived;

        /// <summary>
        /// Starts key exchange algorithm.
        /// </summary>
        /// <param name="session">The session.</param>
        /// <param name="message">Key exchange init message.</param>
        /// <remarks>
        /// Sends the client's KEXINIT message, then negotiates the ciphers, MACs and compression
        /// algorithms. The choice is made by intersecting the algorithms configured in
        /// <c>session.ConnectionInfo</c> with those advertised by the server in <paramref name="message"/>.
        /// The chosen names are stored on <c>session.ConnectionInfo</c>.
        /// </remarks>
        /// <exception cref="SshConnectionException">No common algorithm exists for one of the negotiated categories.</exception>
        public virtual void Start(Session session, KeyExchangeInitMessage message)
        {
            Session = session;

            SendMessage(session.ClientInitMessage);

            // Ciphers and MACs: pick the first client-configured algorithm (in the client's
            // enumeration order) that the server also advertises.

            // Determine encryption algorithm
            var clientEncryptionAlgorithmName = (from b in session.ConnectionInfo.Encryptions.Keys
                                                 from a in message.EncryptionAlgorithmsClientToServer
                                                 where a == b
                                                 select a).FirstOrDefault();

            if (string.IsNullOrEmpty(clientEncryptionAlgorithmName))
            {
                throw new SshConnectionException("Client encryption algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentClientEncryption = clientEncryptionAlgorithmName;

            // Determine encryption algorithm
            var serverDecryptionAlgorithmName = (from b in session.ConnectionInfo.Encryptions.Keys
                                                 from a in message.EncryptionAlgorithmsServerToClient
                                                 where a == b
                                                 select a).FirstOrDefault();
            if (string.IsNullOrEmpty(serverDecryptionAlgorithmName))
            {
                throw new SshConnectionException("Server decryption algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentServerEncryption = serverDecryptionAlgorithmName;

            // Determine client hmac algorithm
            var clientHmacAlgorithmName = (from b in session.ConnectionInfo.HmacAlgorithms.Keys
                                           from a in message.MacAlgorithmsClientToServer
                                           where a == b
                                           select a).FirstOrDefault();
            if (string.IsNullOrEmpty(clientHmacAlgorithmName))
            {
                // This is the client-to-server MAC; the message previously said "Server" (copy-paste error).
                throw new SshConnectionException("Client HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentClientHmacAlgorithm = clientHmacAlgorithmName;

            // Determine server hmac algorithm
            var serverHmacAlgorithmName = (from b in session.ConnectionInfo.HmacAlgorithms.Keys
                                           from a in message.MacAlgorithmsServerToClient
                                           where a == b
                                           select a).FirstOrDefault();
            if (string.IsNullOrEmpty(serverHmacAlgorithmName))
            {
                throw new SshConnectionException("Server HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentServerHmacAlgorithm = serverHmacAlgorithmName;

            // Compression uses LastOrDefault: it picks the LAST client-configured algorithm that
            // the server also advertises.
            // NOTE(review): this differs from the first-match rule used above; confirm intentional.

            // Determine compression algorithm
            var compressionAlgorithmName = (from b in session.ConnectionInfo.CompressionAlgorithms.Keys
                                            from a in message.CompressionAlgorithmsClientToServer
                                            where a == b
                                            select a).LastOrDefault();
            if (string.IsNullOrEmpty(compressionAlgorithmName))
            {
                throw new SshConnectionException("Compression algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentClientCompressionAlgorithm = compressionAlgorithmName;

            // Determine decompression algorithm
            var decompressionAlgorithmName = (from b in session.ConnectionInfo.CompressionAlgorithms.Keys
                                              from a in message.CompressionAlgorithmsServerToClient
                                              where a == b
                                              select a).LastOrDefault();
            if (string.IsNullOrEmpty(decompressionAlgorithmName))
            {
                throw new SshConnectionException("Decompression algorithm not found", DisconnectReason.KeyExchangeFailed);
            }

            session.ConnectionInfo.CurrentServerCompressionAlgorithm = decompressionAlgorithmName;

            _clientCipherInfo = session.ConnectionInfo.Encryptions[clientEncryptionAlgorithmName];
            _serverCipherInfo = session.ConnectionInfo.Encryptions[serverDecryptionAlgorithmName];
            _clientHashInfo = session.ConnectionInfo.HmacAlgorithms[clientHmacAlgorithmName];
            _serverHashInfo = session.ConnectionInfo.HmacAlgorithms[serverHmacAlgorithmName];
            _compressionType = session.ConnectionInfo.CompressionAlgorithms[compressionAlgorithmName];
            _decompressionType = session.ConnectionInfo.CompressionAlgorithms[decompressionAlgorithmName];
        }

        /// <summary>
        /// Finishes key exchange algorithm.
        /// </summary>
        /// <remarks>
        /// Validates the exchange hash and, if it is valid, sends a NEWKEYS message.
        /// </remarks>
        /// <exception cref="SshConnectionException">The exchange hash could not be validated.</exception>
        public virtual void Finish()
        {
            if (!ValidateExchangeHash())
            {
                throw new SshConnectionException("Key exchange negotiation failed.", DisconnectReason.KeyExchangeFailed);
            }

            SendMessage(new NewKeysMessage());
        }

        /// <summary>
        /// Creates the server side cipher to use.
        /// </summary>
        /// <returns>Server cipher.</returns>
        /// <remarks>
        /// The IV is derived with letter 'B' and the key with letter 'D'. The key is extended to the
        /// negotiated cipher's key size.
        /// </remarks>
        public Cipher CreateServerCipher()
        {
            // Resolve Session ID: on the first key exchange the session has no ID yet, so the
            // current exchange hash is used.
            var sessionId = Session.SessionId ?? ExchangeHash;

            // Calculate server to client initial IV
            var serverVector = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'B', sessionId));

            // Calculate server to client encryption
            var serverKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'D', sessionId));

            serverKey = GenerateSessionKey(SharedKey, ExchangeHash, serverKey, _serverCipherInfo.KeySize / 8);

            // NOTE(review): this diagnostic line writes raw key and IV material to the log; consider
            // whether that is acceptable outside debugging builds.
            DiagnosticAbstraction.Log(string.Format("[{0}] Creating server cipher (Name:{1},Key:{2},IV:{3})",
                                                    Session.ToHex(Session.SessionId),
                                                    Session.ConnectionInfo.CurrentServerEncryption,
                                                    Session.ToHex(serverKey),
                                                    Session.ToHex(serverVector)));

            // Create server cipher
            return _serverCipherInfo.Cipher(serverKey, serverVector);
        }

        /// <summary>
        /// Creates the client side cipher to use.
        /// </summary>
        /// <returns>Client cipher.</returns>
        /// <remarks>
        /// The IV is derived with letter 'A' and the key with letter 'C'. The key is extended to the
        /// negotiated cipher's key size.
        /// </remarks>
        public Cipher CreateClientCipher()
        {
            // Resolve Session ID
            var sessionId = Session.SessionId ?? ExchangeHash;

            // Calculate client to server initial IV
            var clientVector = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'A', sessionId));

            // Calculate client to server encryption
            var clientKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'C', sessionId));

            clientKey = GenerateSessionKey(SharedKey, ExchangeHash, clientKey, _clientCipherInfo.KeySize / 8);

            // Create client cipher
            return _clientCipherInfo.Cipher(clientKey, clientVector);
        }

        /// <summary>
        /// Creates the server side hash algorithm to use.
        /// </summary>
        /// <returns>
        /// The server-side hash algorithm, keyed with material derived using letter 'F'.
        /// </returns>
        public HashAlgorithm CreateServerHash()
        {
            // Resolve Session ID
            var sessionId = Session.SessionId ?? ExchangeHash;

            var serverKey = GenerateSessionKey(SharedKey,
                                               ExchangeHash,
                                               Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'F', sessionId)),
                                               _serverHashInfo.KeySize / 8);

            return _serverHashInfo.HashAlgorithm(serverKey);
        }

        /// <summary>
        /// Creates the client side hash algorithm to use.
        /// </summary>
        /// <returns>
        /// The client-side hash algorithm, keyed with material derived using letter 'E'.
        /// </returns>
        public HashAlgorithm CreateClientHash()
        {
            // Resolve Session ID
            var sessionId = Session.SessionId ?? ExchangeHash;

            var clientKey = GenerateSessionKey(SharedKey,
                                               ExchangeHash,
                                               Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'E', sessionId)),
                                               _clientHashInfo.KeySize / 8);

            return _clientHashInfo.HashAlgorithm(clientKey);
        }

        /// <summary>
        /// Creates the compression algorithm to use to deflate data.
        /// </summary>
        /// <returns>
        /// The compression method, or <see langword="null"/> when the negotiated algorithm has no
        /// associated compressor type.
        /// </returns>
        public Compressor CreateCompressor()
        {
            if (_compressionType is null)
            {
                return null;
            }

            var compressor = _compressionType.CreateInstance<Compressor>();

            compressor.Init(Session);

            return compressor;
        }

        /// <summary>
        /// Creates the compression algorithm to use to inflate data.
        /// </summary>
        /// <returns>
        /// The decompression method, or <see langword="null"/> when the negotiated algorithm has no
        /// associated compressor type.
        /// </returns>
        public Compressor CreateDecompressor()
        {
            if (_decompressionType is null)
            {
                return null;
            }

            var decompressor = _decompressionType.CreateInstance<Compressor>();

            decompressor.Init(Session);

            return decompressor;
        }

        /// <summary>
        /// Determines whether the specified host key can be trusted.
        /// </summary>
        /// <param name="host">The host algorithm.</param>
        /// <returns>
        /// <see langword="true"/> if the specified host can be trusted; otherwise, <see langword="false"/>.
        /// When no <see cref="HostKeyReceived"/> handler is attached, the key is trusted.
        /// </returns>
        protected bool CanTrustHostKey(KeyHostAlgorithm host)
        {
            // Copy to a local so the null check and the invocation observe the same delegate.
            var handlers = HostKeyReceived;
            if (handlers is null)
            {
                return true;
            }

            var args = new HostKeyEventArgs(host);
            handlers(this, args);
            return args.CanTrust;
        }

        /// <summary>
        /// Validates the exchange hash.
        /// </summary>
        /// <returns><see langword="true"/> if exchange hash is valid; otherwise <see langword="false"/>.</returns>
        protected abstract bool ValidateExchangeHash();

        /// <summary>
        /// Calculates key exchange hash value.
        /// </summary>
        /// <returns>Key exchange hash.</returns>
        protected abstract byte[] CalculateHash();

        /// <summary>
        /// Hashes the specified data bytes.
        /// </summary>
        /// <param name="hashData">The hash data.</param>
        /// <returns>
        /// The hash of the data.
        /// </returns>
        protected abstract byte[] Hash(byte[] hashData);

        /// <summary>
        /// Sends SSH message to the server.
        /// </summary>
        /// <param name="message">The message.</param>
        protected void SendMessage(Message message)
        {
            Session.SendMessage(message);
        }

        /// <summary>
        /// Extends a derived key until it holds at least <paramref name="size"/> bytes (RFC 4253, section 7.2).
        /// </summary>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="key">The initial key block (K1).</param>
        /// <param name="size">The minimum number of key bytes required.</param>
        /// <returns>
        /// <c>K1 || K2 || ...</c> with at least <paramref name="size"/> bytes. The result is not
        /// truncated, so it may be longer than <paramref name="size"/>; only its leading bytes are
        /// meant to be used.
        /// </returns>
        private byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, byte[] key, int size)
        {
            var result = new List<byte>(key);

            while (size > result.Count)
            {
                // Each extension block is HASH(K || H || <entire key so far>):
                //   K2 = HASH(K || H || K1), K3 = HASH(K || H || K1 || K2), ...
                // Hashing only K1 on every round (as before) made K3, K4, ... repeat K2.
                var sessionKeyAdjustment = new SessionKeyAdjustment
                {
                    SharedKey = sharedKey,
                    ExchangeHash = exchangeHash,
                    Key = result.ToArray(),
                };

                result.AddRange(Hash(sessionKeyAdjustment.GetBytes()));
            }

            return result.ToArray();
        }

        /// <summary>
        /// Builds the hash input for the first key block: <c>K || H || letter || session_id</c>.
        /// </summary>
        /// <param name="sharedKey">The shared key.</param>
        /// <param name="exchangeHash">The exchange hash.</param>
        /// <param name="p">The single-character key-type letter ('A' to 'F').</param>
        /// <param name="sessionId">The session id.</param>
        /// <returns>
        /// The serialized bytes to be hashed; this method does not hash them itself.
        /// </returns>
        private static byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, char p, byte[] sessionId)
        {
            var sessionKeyGeneration = new SessionKeyGeneration
            {
                SharedKey = sharedKey,
                ExchangeHash = exchangeHash,
                Char = p,
                SessionId = sessionId
            };
            return sessionKeyGeneration.GetBytes();
        }

        /// <summary>
        /// Serializes <c>string(SharedKey) || ExchangeHash || (byte)Char || SessionId</c>.
        /// </summary>
        private sealed class SessionKeyGeneration : SshData
        {
            public byte[] SharedKey { get; set; }

            public byte[] ExchangeHash { get; set; }

            public char Char { get; set; }

            public byte[] SessionId { get; set; }

            /// <summary>
            /// Gets the size of the message in bytes.
            /// </summary>
            /// <value>
            /// The size of the messages in bytes.
            /// </value>
            protected override int BufferCapacity
            {
                get
                {
                    var capacity = base.BufferCapacity;
                    capacity += 4; // SharedKey length
                    capacity += SharedKey.Length; // SharedKey
                    capacity += ExchangeHash.Length; // ExchangeHash
                    capacity += 1; // Char
                    capacity += SessionId.Length; // SessionId
                    return capacity;
                }
            }

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
                WriteBinaryString(SharedKey);
                Write(ExchangeHash);
                Write((byte) Char);
                Write(SessionId);
            }
        }

        /// <summary>
        /// Serializes <c>string(SharedKey) || ExchangeHash || Key</c>, the input for each key-extension round.
        /// </summary>
        private sealed class SessionKeyAdjustment : SshData
        {
            public byte[] SharedKey { get; set; }

            public byte[] ExchangeHash { get; set; }

            public byte[] Key { get; set; }

            /// <summary>
            /// Gets the size of the message in bytes.
            /// </summary>
            /// <value>
            /// The size of the messages in bytes.
            /// </value>
            protected override int BufferCapacity
            {
                get
                {
                    var capacity = base.BufferCapacity;
                    capacity += 4; // SharedKey length
                    capacity += SharedKey.Length; // SharedKey
                    capacity += ExchangeHash.Length; // ExchangeHash
                    capacity += Key.Length; // Key
                    return capacity;
                }
            }

            protected override void LoadData()
            {
                throw new NotImplementedException();
            }

            protected override void SaveData()
            {
                WriteBinaryString(SharedKey);
                Write(ExchangeHash);
                Write(Key);
            }
        }

        #region IDisposable Members

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            Dispose(disposing: true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases unmanaged and - optionally - managed resources.
        /// </summary>
        /// <param name="disposing"><see langword="true"/> to release both managed and unmanaged resources; <see langword="false"/> to release only unmanaged resources.</param>
        protected virtual void Dispose(bool disposing)
        {
        }

        /// <summary>
        /// Releases unmanaged resources and performs other cleanup operations before the
        /// <see cref="KeyExchange"/> is reclaimed by garbage collection.
        /// </summary>
        ~KeyExchange()
        {
            Dispose(disposing: false);
        }

        #endregion
    }
}