using System;
using System.Collections.Generic;
using Renci.SshNet.Common;
namespace Renci.SshNet
{
/// <summary>
/// Authenticates an <see cref="ISession"/> by first trying the "none" authentication method and then
/// the authentication methods configured on an <see cref="IConnectionInfoInternal"/>, restricted to the
/// methods the server reports as allowed (including chains of partially successful methods).
/// </summary>
internal sealed class ClientAuthentication : IClientAuthentication
{
    private readonly int _partialSuccessLimit;

    /// <summary>
    /// Initializes a new instance of the <see cref="ClientAuthentication"/> class.
    /// </summary>
    /// <param name="partialSuccessLimit">
    /// The number of times an authentication attempt with any given <see cref="IAuthenticationMethod"/> can result in
    /// <see cref="AuthenticationResult.PartialSuccess"/> before it is disregarded.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="partialSuccessLimit"/> is less than one.</exception>
    public ClientAuthentication(int partialSuccessLimit)
    {
        if (partialSuccessLimit < 1)
        {
            throw new ArgumentOutOfRangeException(nameof(partialSuccessLimit), "Cannot be less than one.");
        }

        _partialSuccessLimit = partialSuccessLimit;
    }

    /// <summary>
    /// Gets the number of times an authentication attempt with any given <see cref="IAuthenticationMethod"/> can
    /// result in <see cref="AuthenticationResult.PartialSuccess"/> before it is disregarded.
    /// </summary>
    /// <value>
    /// The number of times an authentication attempt with any given <see cref="IAuthenticationMethod"/> can result
    /// in <see cref="AuthenticationResult.PartialSuccess"/> before it is disregarded.
    /// </value>
    internal int PartialSuccessLimit => _partialSuccessLimit;

    /// <summary>
    /// Attempts to authenticate for a given <see cref="ISession"/> using the authentication methods
    /// of the specified <see cref="IConnectionInfoInternal"/>.
    /// </summary>
    /// <param name="connectionInfo">A <see cref="IConnectionInfoInternal"/> to use for authenticating.</param>
    /// <param name="session">The <see cref="ISession"/> for which to perform authentication.</param>
    /// <remarks>
    /// For the duration of the call the SSH_MSG_USERAUTH_FAILURE, SSH_MSG_USERAUTH_SUCCESS and
    /// SSH_MSG_USERAUTH_BANNER messages are registered on <paramref name="session"/>, and banners are forwarded
    /// to <paramref name="connectionInfo"/>; both are undone in a <c>finally</c> block.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="connectionInfo"/> or <paramref name="session"/> is <see langword="null"/>.</exception>
    /// <exception cref="SshAuthenticationException">No authentication method succeeded.</exception>
    public void Authenticate(IConnectionInfoInternal connectionInfo, ISession session)
    {
        if (connectionInfo is null)
        {
            throw new ArgumentNullException(nameof(connectionInfo));
        }

        if (session is null)
        {
            throw new ArgumentNullException(nameof(session));
        }

        session.RegisterMessage("SSH_MSG_USERAUTH_FAILURE");
        session.RegisterMessage("SSH_MSG_USERAUTH_SUCCESS");
        session.RegisterMessage("SSH_MSG_USERAUTH_BANNER");
        session.UserAuthenticationBannerReceived += connectionInfo.UserAuthenticationBannerReceived;

        try
        {
            // Try the "none" method first; when it does not succeed, its AllowedAuthentications
            // tells us which methods the server is willing to accept.
            var noneAuthenticationMethod = connectionInfo.CreateNoneAuthenticationMethod();
            if (noneAuthenticationMethod.Authenticate(session) == AuthenticationResult.Success)
            {
                return;
            }

            // the exception to report an authentication failure with
            SshAuthenticationException authenticationException = null;

            var authenticationState = new AuthenticationState(connectionInfo.AuthenticationMethods);
            if (!TryAuthenticate(session, authenticationState, noneAuthenticationMethod.AllowedAuthentications, ref authenticationException))
            {
                throw authenticationException;
            }
        }
        finally
        {
            session.UserAuthenticationBannerReceived -= connectionInfo.UserAuthenticationBannerReceived;
            session.UnRegisterMessage("SSH_MSG_USERAUTH_FAILURE");
            session.UnRegisterMessage("SSH_MSG_USERAUTH_SUCCESS");
            session.UnRegisterMessage("SSH_MSG_USERAUTH_BANNER");
        }
    }

    /// <summary>
    /// Tries the supported authentication methods whose names appear in <paramref name="allowedAuthenticationMethods"/>,
    /// recursing with a method's own <see cref="IAuthenticationMethod.AllowedAuthentications"/> whenever that method
    /// results in <see cref="AuthenticationResult.PartialSuccess"/>.
    /// </summary>
    /// <param name="session">The session to authenticate.</param>
    /// <param name="authenticationState">The outcome of earlier attempts within the current authentication run.</param>
    /// <param name="allowedAuthenticationMethods">
    /// The names of the methods that may be used to continue; <see langword="null"/> is treated as an empty list.
    /// </param>
    /// <param name="authenticationException">
    /// Receives an exception describing the most recent failure; reset to <see langword="null"/> on success.
    /// </param>
    /// <returns><see langword="true"/> if authentication completed successfully; otherwise, <see langword="false"/>.</returns>
    private bool TryAuthenticate(ISession session,
                                 AuthenticationState authenticationState,
                                 string[] allowedAuthenticationMethods,
                                 ref SshAuthenticationException authenticationException)
    {
        // a method may leave AllowedAuthentications unset; report that like an empty list
        // instead of failing with a NullReferenceException
        if (allowedAuthenticationMethods is null || allowedAuthenticationMethods.Length == 0)
        {
            authenticationException = new SshAuthenticationException("No authentication methods defined on SSH server.");
            return false;
        }

        // we want to try authentication methods in the order in which they were
        // passed in the ctor, not the order in which the SSH server returns
        // the allowed authentication methods
        var matchingAuthenticationMethods = authenticationState.GetSupportedAuthenticationMethods(allowedAuthenticationMethods);
        if (matchingAuthenticationMethods.Count == 0)
        {
            var allowedNames = string.Join(",", allowedAuthenticationMethods);
            authenticationException = new SshAuthenticationException($"No suitable authentication method found to complete authentication ({allowedNames}).");
            return false;
        }

        foreach (var authenticationMethod in authenticationState.GetActiveAuthenticationMethods(matchingAuthenticationMethods))
        {
            // guard against a stack overflow for servers that do not update the list of allowed authentication
            // methods after a partial success
            if (authenticationState.GetPartialSuccessCount(authenticationMethod) >= _partialSuccessLimit)
            {
                // TODO Get list of all authentication methods that have reached the partial success limit?
                authenticationException = new SshAuthenticationException($"Reached authentication attempt limit for method ({authenticationMethod.Name}).");
                continue;
            }

            switch (authenticationMethod.Authenticate(session))
            {
                case AuthenticationResult.Success:
                    authenticationException = null;
                    return true;

                case AuthenticationResult.PartialSuccess:
                    authenticationState.RecordPartialSuccess(authenticationMethod);

                    // the server accepted this step but wants more; continue with the methods it now allows
                    if (TryAuthenticate(session, authenticationState, authenticationMethod.AllowedAuthentications, ref authenticationException))
                    {
                        return true;
                    }

                    break;

                case AuthenticationResult.Failure:
                    authenticationState.RecordFailure(authenticationMethod);
                    authenticationException = new SshAuthenticationException($"Permission denied ({authenticationMethod.Name}).");
                    break;
            }
        }

        return false;
    }

    /// <summary>
    /// Tracks the outcome of authentication attempts during a single authentication run.
    /// </summary>
    private sealed class AuthenticationState
    {
        /// <summary>
        /// The authentication methods supplied by the client, in order of preference.
        /// </summary>
        private readonly IList<IAuthenticationMethod> _supportedAuthenticationMethods;

        /// <summary>
        /// Records how many times an authentication attempt with a given <see cref="IAuthenticationMethod"/>
        /// resulted in <see cref="AuthenticationResult.PartialSuccess"/>.
        /// </summary>
        /// <remarks>
        /// When there's no entry for a given <see cref="IAuthenticationMethod"/>, it has not (yet) resulted in a
        /// partial success. Failed methods are tracked separately and never get an entry here.
        /// </remarks>
        private readonly Dictionary<IAuthenticationMethod, int> _authenticationMethodPartialSuccessRegister;

        /// <summary>
        /// Holds the list of authentication methods that failed.
        /// </summary>
        private readonly List<IAuthenticationMethod> _failedAuthenticationMethods;

        public AuthenticationState(IList<IAuthenticationMethod> supportedAuthenticationMethods)
        {
            _supportedAuthenticationMethods = supportedAuthenticationMethods;
            _failedAuthenticationMethods = new List<IAuthenticationMethod>();
            _authenticationMethodPartialSuccessRegister = new Dictionary<IAuthenticationMethod, int>();
        }

        /// <summary>
        /// Records a <see cref="AuthenticationResult.Failure"/> authentication attempt for the specified
        /// <see cref="IAuthenticationMethod"/>.
        /// </summary>
        /// <param name="authenticationMethod">An <see cref="IAuthenticationMethod"/> for which to record the result of an authentication attempt.</param>
        public void RecordFailure(IAuthenticationMethod authenticationMethod)
        {
            _failedAuthenticationMethods.Add(authenticationMethod);
        }

        /// <summary>
        /// Records a <see cref="AuthenticationResult.PartialSuccess"/> authentication attempt for the specified
        /// <see cref="IAuthenticationMethod"/>.
        /// </summary>
        /// <param name="authenticationMethod">An <see cref="IAuthenticationMethod"/> for which to record the result of an authentication attempt.</param>
        public void RecordPartialSuccess(IAuthenticationMethod authenticationMethod)
        {
            // TryGetValue leaves the count at 0 when there is no entry yet
            _ = _authenticationMethodPartialSuccessRegister.TryGetValue(authenticationMethod, out var partialSuccessCount);
            _authenticationMethodPartialSuccessRegister[authenticationMethod] = partialSuccessCount + 1;
        }

        /// <summary>
        /// Returns the number of times an authentication attempt with the specified <see cref="IAuthenticationMethod"/>
        /// has resulted in <see cref="AuthenticationResult.PartialSuccess"/>.
        /// </summary>
        /// <param name="authenticationMethod">An <see cref="IAuthenticationMethod"/>.</param>
        /// <returns>
        /// The number of times an authentication attempt with the specified <see cref="IAuthenticationMethod"/>
        /// has resulted in <see cref="AuthenticationResult.PartialSuccess"/>.
        /// </returns>
        public int GetPartialSuccessCount(IAuthenticationMethod authenticationMethod)
        {
            return _authenticationMethodPartialSuccessRegister.TryGetValue(authenticationMethod, out var partialSuccessCount)
                ? partialSuccessCount
                : 0;
        }

        /// <summary>
        /// Returns a list of supported authentication methods that match one of the specified allowed authentication
        /// methods.
        /// </summary>
        /// <param name="allowedAuthenticationMethods">A list of allowed authentication methods.</param>
        /// <returns>
        /// A list of supported authentication methods that match one of the specified allowed authentication methods.
        /// </returns>
        /// <remarks>
        /// The authentication methods are returned in the order in which they were specified in the list that was
        /// used to initialize the current instance. Names are compared with ordinal (case-sensitive) equality.
        /// </remarks>
        public List<IAuthenticationMethod> GetSupportedAuthenticationMethods(string[] allowedAuthenticationMethods)
        {
            var result = new List<IAuthenticationMethod>();

            foreach (var supportedAuthenticationMethod in _supportedAuthenticationMethods)
            {
                if (Array.IndexOf(allowedAuthenticationMethods, supportedAuthenticationMethod.Name) >= 0)
                {
                    result.Add(supportedAuthenticationMethod);
                }
            }

            return result;
        }

        /// <summary>
        /// Returns the authentication methods from the specified list that have not yet failed.
        /// </summary>
        /// <param name="matchingAuthenticationMethods">A list of authentication methods.</param>
        /// <returns>
        /// The authentication methods from <paramref name="matchingAuthenticationMethods"/> that have not yet failed.
        /// </returns>
        /// <remarks>
        /// <para>
        /// This method first returns the authentication methods that have not yet resulted in a
        /// <see cref="AuthenticationResult.PartialSuccess"/>, and only then returns those that have.
        /// </para>
        /// <para>
        /// Any <see cref="IAuthenticationMethod"/> that has failed is skipped. The sequence is evaluated lazily,
        /// so failures recorded while it is being enumerated are taken into account.
        /// </para>
        /// </remarks>
        public IEnumerable<IAuthenticationMethod> GetActiveAuthenticationMethods(List<IAuthenticationMethod> matchingAuthenticationMethods)
        {
            var skippedAuthenticationMethods = new List<IAuthenticationMethod>();

            foreach (var authenticationMethod in matchingAuthenticationMethods)
            {
                // skip authentication methods that have already failed
                if (_failedAuthenticationMethods.Contains(authenticationMethod))
                {
                    continue;
                }

                // delay use of authentication methods that had a PartialSuccess result
                if (_authenticationMethodPartialSuccessRegister.ContainsKey(authenticationMethod))
                {
                    skippedAuthenticationMethods.Add(authenticationMethod);
                    continue;
                }

                yield return authenticationMethod;
            }

            foreach (var authenticationMethod in skippedAuthenticationMethods)
            {
                // a deferred method may have failed (e.g. in a nested attempt) after it was deferred;
                // re-check so that failed methods are never retried
                if (!_failedAuthenticationMethods.Contains(authenticationMethod))
                {
                    yield return authenticationMethod;
                }
            }
        }
    }
}
}