using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;
using Renci.SshNet.Compression;
using Renci.SshNet.Messages;
using Renci.SshNet.Messages.Transport;
using Renci.SshNet.Security.Cryptography;
namespace Renci.SshNet.Security
{
/// <summary>
/// Represents the base class for the different key exchange algorithm implementations.
/// </summary>
public abstract class KeyExchange : Algorithm, IKeyExchange
{
// Cipher negotiated for the client-to-server direction; assigned by Start and consumed by CreateClientCipher.
private CipherInfo _clientCipherInfo;
// Cipher negotiated for the server-to-client direction; assigned by Start and consumed by CreateServerCipher.
private CipherInfo _serverCipherInfo;
// HMAC negotiated for the client-to-server direction; assigned by Start and consumed by CreateClientHash.
private HashInfo _clientHashInfo;
// HMAC negotiated for the server-to-client direction; assigned by Start and consumed by CreateServerHash.
private HashInfo _serverHashInfo;
// Compressor type negotiated for outgoing data; CreateCompressor returns null while this is null.
private Type _compressionType;
// Compressor type negotiated for incoming data; CreateDecompressor returns null while this is null.
private Type _decompressionType;
/// <summary>
/// Gets the session this key exchange operates on.
/// </summary>
/// <value>
/// The session that was passed to <c>Start</c>; it is not assigned anywhere else in this class.
/// </value>
protected Session Session { get; private set; }
/// <summary>
/// Gets or sets the key exchange shared key.
/// </summary>
/// <value>
/// The shared key. This base class only reads it (as input to the session key derivation);
/// the protected setter leaves assigning it to derived classes.
/// </value>
public byte[] SharedKey { get; protected set; }
// Cached result of CalculateHash(); remains null until ExchangeHash is read for the first time.
private byte[] _exchangeHash;

/// <summary>
/// Gets the exchange hash.
/// </summary>
/// <value>
/// The exchange hash. It is computed through <c>CalculateHash</c> on first access and the
/// cached value is returned on subsequent reads.
/// </value>
public byte[] ExchangeHash
{
    get
    {
        if (_exchangeHash is null)
        {
            _exchangeHash = CalculateHash();
        }

        return _exchangeHash;
    }
}
/// <summary>
/// Occurs when the host key is received.
/// </summary>
/// <remarks>
/// Raised from <c>CanTrustHostKey</c> with a <c>HostKeyEventArgs</c> instance; the value of its
/// <c>CanTrust</c> property after all handlers ran decides whether the host key is trusted.
/// </remarks>
public event EventHandler HostKeyReceived;
/// <summary>
/// Starts the key exchange algorithm.
/// </summary>
/// <param name="session">The session.</param>
/// <param name="message">The key exchange init message to negotiate against (presumably the one received from the server - confirm against the caller).</param>
/// <remarks>
/// Sends the client's init message and then, for each algorithm category (encryption, HMAC and
/// compression, in both directions), selects an algorithm name that appears both in the keys of the
/// corresponding <c>session.ConnectionInfo</c> collection and in the list carried by
/// <paramref name="message"/>. The selected names are stored on <c>session.ConnectionInfo</c> as they
/// are determined, and the matching algorithm descriptors are cached for the <c>Create*</c> methods.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="session"/> or <paramref name="message"/> is <c>null</c>.</exception>
/// <exception cref="SshConnectionException">No common algorithm could be found for one of the categories.</exception>
public virtual void Start(Session session, KeyExchangeInitMessage message)
{
    if (session is null)
    {
        throw new ArgumentNullException(nameof(session));
    }

    if (message is null)
    {
        throw new ArgumentNullException(nameof(message));
    }

    Session = session;

    SendMessage(session.ClientInitMessage);

    var connectionInfo = session.ConnectionInfo;

    // Determine client encryption algorithm (client to server)
    var clientEncryptionAlgorithmName = (from b in connectionInfo.Encryptions.Keys
                                         from a in message.EncryptionAlgorithmsClientToServer
                                         where a == b
                                         select a).FirstOrDefault();

    if (string.IsNullOrEmpty(clientEncryptionAlgorithmName))
    {
        throw new SshConnectionException("Client encryption algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentClientEncryption = clientEncryptionAlgorithmName;

    // Determine server decryption algorithm (server to client)
    var serverDecryptionAlgorithmName = (from b in connectionInfo.Encryptions.Keys
                                         from a in message.EncryptionAlgorithmsServerToClient
                                         where a == b
                                         select a).FirstOrDefault();

    if (string.IsNullOrEmpty(serverDecryptionAlgorithmName))
    {
        throw new SshConnectionException("Server decryption algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentServerEncryption = serverDecryptionAlgorithmName;

    // Determine client hmac algorithm (client to server)
    var clientHmacAlgorithmName = (from b in connectionInfo.HmacAlgorithms.Keys
                                   from a in message.MacAlgorithmsClientToServer
                                   where a == b
                                   select a).FirstOrDefault();

    if (string.IsNullOrEmpty(clientHmacAlgorithmName))
    {
        throw new SshConnectionException("Client HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentClientHmacAlgorithm = clientHmacAlgorithmName;

    // Determine server hmac algorithm (server to client)
    var serverHmacAlgorithmName = (from b in connectionInfo.HmacAlgorithms.Keys
                                   from a in message.MacAlgorithmsServerToClient
                                   where a == b
                                   select a).FirstOrDefault();

    if (string.IsNullOrEmpty(serverHmacAlgorithmName))
    {
        throw new SshConnectionException("Server HMAC algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentServerHmacAlgorithm = serverHmacAlgorithmName;

    // Determine client compression algorithm (client to server)
    // NOTE(review): unlike the selections above this takes the LAST match. RFC 4253 section 7.1
    // selects the first client-listed algorithm that the server also supports, so this only agrees
    // with the RFC when at most one algorithm matches. Kept as-is to preserve behavior - confirm intent.
    var compressionAlgorithmName = (from b in connectionInfo.CompressionAlgorithms.Keys
                                    from a in message.CompressionAlgorithmsClientToServer
                                    where a == b
                                    select a).LastOrDefault();

    if (string.IsNullOrEmpty(compressionAlgorithmName))
    {
        throw new SshConnectionException("Compression algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentClientCompressionAlgorithm = compressionAlgorithmName;

    // Determine server decompression algorithm (server to client)
    // NOTE(review): takes the LAST match as well - see the note on the client compression selection.
    var decompressionAlgorithmName = (from b in connectionInfo.CompressionAlgorithms.Keys
                                      from a in message.CompressionAlgorithmsServerToClient
                                      where a == b
                                      select a).LastOrDefault();

    if (string.IsNullOrEmpty(decompressionAlgorithmName))
    {
        throw new SshConnectionException("Decompression algorithm not found", DisconnectReason.KeyExchangeFailed);
    }

    connectionInfo.CurrentServerCompressionAlgorithm = decompressionAlgorithmName;

    // Cache the negotiated algorithm descriptors for the Create* factory methods.
    _clientCipherInfo = connectionInfo.Encryptions[clientEncryptionAlgorithmName];
    _serverCipherInfo = connectionInfo.Encryptions[serverDecryptionAlgorithmName];
    _clientHashInfo = connectionInfo.HmacAlgorithms[clientHmacAlgorithmName];
    _serverHashInfo = connectionInfo.HmacAlgorithms[serverHmacAlgorithmName];
    _compressionType = connectionInfo.CompressionAlgorithms[compressionAlgorithmName];
    _decompressionType = connectionInfo.CompressionAlgorithms[decompressionAlgorithmName];
}
/// <summary>
/// Finishes the key exchange algorithm.
/// </summary>
/// <remarks>
/// Validates the exchange hash and, when it is valid, sends a <c>NewKeysMessage</c>.
/// </remarks>
/// <exception cref="SshConnectionException">The exchange hash failed validation.</exception>
public virtual void Finish()
{
    var exchangeHashIsValid = ValidateExchangeHash();
    if (!exchangeHashIsValid)
    {
        throw new SshConnectionException("Key exchange negotiation failed.", DisconnectReason.KeyExchangeFailed);
    }

    var newKeys = new NewKeysMessage();
    SendMessage(newKeys);
}
/// <summary>
/// Creates the server side cipher to use.
/// </summary>
/// <remarks>
/// The initial IV and the encryption key are derived from the shared key, the exchange hash and
/// the session id using the letters 'B' and 'D' respectively (the server-to-client values of
/// RFC 4253 section 7.2). The key is extended until it holds at least <c>KeySize / 8</c> bytes.
/// </remarks>
/// <returns>Server cipher.</returns>
public Cipher CreateServerCipher()
{
    // Resolve Session ID: fall back to the exchange hash while the session has no id yet
    var sessionId = Session.SessionId ?? ExchangeHash;

    // Calculate server to client initial IV
    var serverVector = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'B', sessionId));

    // Calculate server to client encryption
    var serverKey = Hash(GenerateSessionKey(SharedKey, ExchangeHash, 'D', sessionId));

    serverKey = GenerateSessionKey(SharedKey, ExchangeHash, serverKey, _serverCipherInfo.KeySize / 8);

    // Log only the session id and the cipher name. The derived key and IV are session secrets:
    // writing them to a diagnostic sink would let anyone with access to it decrypt the traffic.
    DiagnosticAbstraction.Log(string.Format("[{0}] Creating server cipher (Name:{1})",
                                            Session.ToHex(Session.SessionId),
                                            Session.ConnectionInfo.CurrentServerEncryption));

    // Create server cipher
    return _serverCipherInfo.Cipher(serverKey, serverVector);
}
/// <summary>
/// Creates the client side cipher to use.
/// </summary>
/// <remarks>
/// The initial IV is derived with the letter 'A' and the encryption key with the letter 'C';
/// the key is extended until it holds at least <c>KeySize / 8</c> bytes.
/// </remarks>
/// <returns>Client cipher.</returns>
public Cipher CreateClientCipher()
{
    // Fall back to the exchange hash while the session has no id yet.
    var sessionId = Session.SessionId ?? ExchangeHash;

    // Client to server initial IV.
    var vectorSeed = GenerateSessionKey(SharedKey, ExchangeHash, 'A', sessionId);
    var initialVector = Hash(vectorSeed);

    // Client to server encryption key: first hash block, then extended to the cipher's key size.
    var keySeed = GenerateSessionKey(SharedKey, ExchangeHash, 'C', sessionId);
    var firstKeyBlock = Hash(keySeed);
    var encryptionKey = GenerateSessionKey(SharedKey, ExchangeHash, firstKeyBlock, _clientCipherInfo.KeySize / 8);

    return _clientCipherInfo.Cipher(encryptionKey, initialVector);
}
/// <summary>
/// Creates the server side hash algorithm to use.
/// </summary>
/// <remarks>
/// The key is derived with the letter 'F' and extended until it holds at least
/// <c>KeySize / 8</c> bytes of the negotiated server HMAC.
/// </remarks>
/// <returns>
/// The server-side hash algorithm.
/// </returns>
public HashAlgorithm CreateServerHash()
{
    // Fall back to the exchange hash while the session has no id yet.
    var sessionId = Session.SessionId ?? ExchangeHash;

    var keySeed = GenerateSessionKey(SharedKey, ExchangeHash, 'F', sessionId);
    var firstKeyBlock = Hash(keySeed);
    var integrityKey = GenerateSessionKey(SharedKey, ExchangeHash, firstKeyBlock, _serverHashInfo.KeySize / 8);

    return _serverHashInfo.HashAlgorithm(integrityKey);
}
/// <summary>
/// Creates the client side hash algorithm to use.
/// </summary>
/// <remarks>
/// The key is derived with the letter 'E' and extended until it holds at least
/// <c>KeySize / 8</c> bytes of the negotiated client HMAC.
/// </remarks>
/// <returns>
/// The client-side hash algorithm.
/// </returns>
public HashAlgorithm CreateClientHash()
{
    // Fall back to the exchange hash while the session has no id yet.
    var sessionId = Session.SessionId ?? ExchangeHash;

    var keySeed = GenerateSessionKey(SharedKey, ExchangeHash, 'E', sessionId);
    var firstKeyBlock = Hash(keySeed);
    var integrityKey = GenerateSessionKey(SharedKey, ExchangeHash, firstKeyBlock, _clientHashInfo.KeySize / 8);

    return _clientHashInfo.HashAlgorithm(integrityKey);
}
/// <summary>
/// Creates the compression algorithm to use to deflate data.
/// </summary>
/// <returns>
/// The compression method, or <c>null</c> when no compressor type was negotiated.
/// </returns>
public Compressor CreateCompressor()
{
    var compressorType = _compressionType;
    if (compressorType is null)
    {
        // Nothing to instantiate for this direction.
        return null;
    }

    // NOTE(review): CreateInstance is declared outside this file; presumably an extension on Type
    // that yields a Compressor - confirm.
    var instance = compressorType.CreateInstance();
    instance.Init(Session);

    return instance;
}
/// <summary>
/// Creates the compression algorithm to use to inflate data.
/// </summary>
/// <returns>
/// The decompression method, or <c>null</c> when no decompressor type was negotiated.
/// </returns>
public Compressor CreateDecompressor()
{
    var decompressorType = _decompressionType;
    if (decompressorType is null)
    {
        // Nothing to instantiate for this direction.
        return null;
    }

    // NOTE(review): CreateInstance is declared outside this file; presumably an extension on Type
    // that yields a Compressor - confirm.
    var instance = decompressorType.CreateInstance();
    instance.Init(Session);

    return instance;
}
/// <summary>
/// Determines whether the specified host key can be trusted.
/// </summary>
/// <param name="host">The host algorithm.</param>
/// <returns>
/// <c>true</c> if the specified host can be trusted; otherwise, <c>false</c>.
/// When nothing is subscribed to <c>HostKeyReceived</c> the key is trusted.
/// </returns>
protected bool CanTrustHostKey(KeyHostAlgorithm host)
{
    // Work on a local copy so the null check and the invocation use the same delegate.
    var subscribers = HostKeyReceived;
    if (subscribers is null)
    {
        return true;
    }

    var eventArgs = new HostKeyEventArgs(host);
    subscribers(this, eventArgs);

    // The handlers decide by setting CanTrust on the event arguments.
    return eventArgs.CanTrust;
}
/// <summary>
/// Validates the exchange hash.
/// </summary>
/// <remarks>
/// Called by <c>Finish</c>, which fails the key exchange when this returns <c>false</c>.
/// </remarks>
/// <returns><c>true</c> if the exchange hash is valid; otherwise <c>false</c>.</returns>
protected abstract bool ValidateExchangeHash();
/// <summary>
/// Calculates the key exchange hash value.
/// </summary>
/// <remarks>
/// Invoked by the <c>ExchangeHash</c> property on first access; the result is cached there.
/// </remarks>
/// <returns>Key exchange hash.</returns>
protected abstract byte[] CalculateHash();
/// <summary>
/// Hashes the specified data bytes.
/// </summary>
/// <param name="hashData">The hash data.</param>
/// <returns>
/// The hash of the data.
/// </returns>
protected abstract byte[] Hash(byte[] hashData);
/// <summary>
/// Sends an SSH message to the server.
/// </summary>
/// <param name="message">The message.</param>
/// <remarks>
/// Forwards to the session's <c>SendMessage</c>; the session must have been assigned first.
/// </remarks>
protected void SendMessage(Message message) => Session.SendMessage(message);
/// <summary>
/// Extends a session key until it holds at least the requested number of bytes.
/// </summary>
/// <param name="sharedKey">The shared key.</param>
/// <param name="exchangeHash">The exchange hash.</param>
/// <param name="key">The initial key material (the first hash block).</param>
/// <param name="size">The minimum number of key bytes required.</param>
/// <remarks>
/// Follows RFC 4253 section 7.2: every additional block is the hash of the shared key, the
/// exchange hash and ALL key material produced so far (K3 = HASH(K || H || K1 || K2), ...).
/// For the first extension round this equals hashing the initial key, so results of up to two
/// blocks are unaffected by feeding the accumulated material.
/// </remarks>
/// <returns>
/// The session key. Whole hash blocks are appended and the result is not trimmed, so it may be
/// longer than <paramref name="size"/>.
/// </returns>
private byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, byte[] key, int size)
{
    var result = new List<byte>(key);

    while (size > result.Count)
    {
        var sessionKeyAdjustment = new SessionKeyAdjustment
        {
            SharedKey = sharedKey,
            ExchangeHash = exchangeHash,

            // Feed everything generated so far, not just the initial block; otherwise the
            // third and later blocks would merely repeat the second one.
            Key = result.ToArray(),
        };

        result.AddRange(Hash(sessionKeyAdjustment.GetBytes()));
    }

    return result.ToArray();
}
/// <summary>
/// Builds the byte sequence that is hashed to obtain an initial session key block.
/// </summary>
/// <param name="sharedKey">The shared key.</param>
/// <param name="exchangeHash">The exchange hash.</param>
/// <param name="p">The letter identifying which key is being derived.</param>
/// <param name="sessionId">The session id.</param>
/// <returns>
/// The serialized shared key, exchange hash, letter and session id. The caller is expected to
/// hash this value; no hashing happens here.
/// </returns>
private static byte[] GenerateSessionKey(byte[] sharedKey, byte[] exchangeHash, char p, byte[] sessionId)
{
    var payload = new SessionKeyGeneration();
    payload.SharedKey = sharedKey;
    payload.ExchangeHash = exchangeHash;
    payload.Char = p;
    payload.SessionId = sessionId;

    return payload.GetBytes();
}
// Write-only payload serialized as: shared key (binary string), exchange hash, one letter byte, session id.
private sealed class SessionKeyGeneration : SshData
{
    public byte[] SharedKey { get; set; }

    public byte[] ExchangeHash { get; set; }

    public char Char { get; set; }

    public byte[] SessionId { get; set; }

    /// <summary>
    /// Gets the size of the message in bytes.
    /// </summary>
    /// <value>
    /// The size of the messages in bytes.
    /// </value>
    protected override int BufferCapacity
    {
        get
        {
            const int sharedKeyLengthPrefix = 4;
            const int letterSize = 1;

            return base.BufferCapacity
                   + sharedKeyLengthPrefix
                   + SharedKey.Length
                   + ExchangeHash.Length
                   + letterSize
                   + SessionId.Length;
        }
    }

    // This payload is only ever written, never parsed.
    protected override void LoadData()
    {
        throw new NotImplementedException();
    }

    // Field order matters: it defines the bytes that get hashed.
    protected override void SaveData()
    {
        WriteBinaryString(SharedKey);
        Write(ExchangeHash);
        Write((byte) Char);
        Write(SessionId);
    }
}
// Write-only payload serialized as: shared key (binary string), exchange hash, key material.
private sealed class SessionKeyAdjustment : SshData
{
    public byte[] SharedKey { get; set; }

    public byte[] ExchangeHash { get; set; }

    public byte[] Key { get; set; }

    /// <summary>
    /// Gets the size of the message in bytes.
    /// </summary>
    /// <value>
    /// The size of the messages in bytes.
    /// </value>
    protected override int BufferCapacity
    {
        get
        {
            const int sharedKeyLengthPrefix = 4;

            return base.BufferCapacity
                   + sharedKeyLengthPrefix
                   + SharedKey.Length
                   + ExchangeHash.Length
                   + Key.Length;
        }
    }

    // This payload is only ever written, never parsed.
    protected override void LoadData()
    {
        throw new NotImplementedException();
    }

    // Field order matters: it defines the bytes that get hashed.
    protected override void SaveData()
    {
        WriteBinaryString(SharedKey);
        Write(ExchangeHash);
        Write(Key);
    }
}
#region IDisposable Members
/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// </summary>
public void Dispose()
{
    Dispose(disposing: true);

    // Cleanup already ran above, so the finalizer no longer needs to be called for this instance.
    GC.SuppressFinalize(this);
}
/// <summary>
/// Releases unmanaged and - optionally - managed resources.
/// </summary>
/// <param name="disposing"><c>true</c> to release both managed and unmanaged resources; <c>false</c> to release only unmanaged resources.</param>
/// <remarks>
/// The base implementation does nothing; it is the extension point for derived classes.
/// </remarks>
protected virtual void Dispose(bool disposing)
{
}
/// <summary>
/// Releases unmanaged resources and performs other cleanup operations before the
/// <see cref="KeyExchange"/> is reclaimed by garbage collection.
/// </summary>
~KeyExchange()
{
    Dispose(disposing: false);
}
#endregion
}
}