using System;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

namespace Renci.SshNet.Connection
{
    /// <summary>
    /// Base class for connectors: provides helpers to establish a TCP socket connection to a host
    /// and to perform blocking reads on that socket.
    /// </summary>
    internal abstract class ConnectorBase : IConnector
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="ConnectorBase"/> class.
        /// </summary>
        /// <param name="socketFactory">The factory used to create sockets.</param>
        /// <exception cref="ArgumentNullException"><paramref name="socketFactory"/> is <see langword="null"/>.</exception>
        protected ConnectorBase(ISocketFactory socketFactory)
        {
            if (socketFactory is null)
            {
                throw new ArgumentNullException(nameof(socketFactory));
            }

            SocketFactory = socketFactory;
        }

        /// <summary>
        /// Gets the factory used to create sockets.
        /// </summary>
        internal ISocketFactory SocketFactory { get; private set; }

        public abstract Socket Connect(IConnectionInfo connectionInfo);

        public abstract Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken);

        /// <summary>
        /// Establishes a socket connection to the specified host and port.
        /// </summary>
        /// <param name="host">The host name of the server to connect to.</param>
        /// <param name="port">The port to connect to.</param>
        /// <param name="timeout">The maximum time to wait for the connection to be established.</param>
        /// <returns>The connected socket.</returns>
        /// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <paramref name="timeout"/>.</exception>
        /// <exception cref="SocketException">An error occurred trying to establish the connection, or <paramref name="host"/> did not resolve to any address.</exception>
        protected Socket SocketConnect(string host, int port, TimeSpan timeout)
        {
            var ep = CreateEndPoint(DnsAbstraction.GetHostAddresses(host), port);

            DiagnosticAbstraction.Log(string.Format(CultureInfo.InvariantCulture, "Initiating connection to '{0}:{1}'.", host, port));

            var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

            try
            {
                SocketAbstraction.Connect(socket, ep, timeout);
                ConfigureBufferSizes(socket);
                return socket;
            }
            catch (Exception)
            {
                // Never leak a half-initialized socket to the caller.
                socket.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Asynchronously establishes a socket connection to the specified host and port.
        /// </summary>
        /// <param name="host">The host name of the server to connect to.</param>
        /// <param name="port">The port to connect to.</param>
        /// <param name="cancellationToken">The cancellation token to observe.</param>
        /// <returns>A task that completes with the connected socket.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already canceled on entry.</exception>
        /// <exception cref="SocketException">An error occurred trying to establish the connection, or <paramref name="host"/> did not resolve to any address.</exception>
        protected async Task<Socket> SocketConnectAsync(string host, int port, CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var addresses = await DnsAbstraction.GetHostAddressesAsync(host).ConfigureAwait(false);
            var ep = CreateEndPoint(addresses, port);

            DiagnosticAbstraction.Log(string.Format(CultureInfo.InvariantCulture, "Initiating connection to '{0}:{1}'.", host, port));

            var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);

            try
            {
                await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false);
                ConfigureBufferSizes(socket);
                return socket;
            }
            catch (Exception)
            {
                // Never leak a half-initialized socket to the caller.
                socket.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Reads a single byte from the socket, waiting indefinitely for it to arrive.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <returns>The byte that was read.</returns>
        /// <exception cref="SshConnectionException">The socket is closed.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        protected static byte SocketReadByte(Socket socket)
        {
            return SocketReadByte(socket, Session.InfiniteTimeSpan);
        }

        /// <summary>
        /// Reads a single byte from the socket, waiting at most <paramref name="readTimeout"/> for it to arrive.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="readTimeout">The maximum time to wait until the byte has been received.</param>
        /// <returns>The byte that was read.</returns>
        /// <exception cref="SshConnectionException">The socket is closed.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        protected static byte SocketReadByte(Socket socket, TimeSpan readTimeout)
        {
            var buffer = new byte[1];
            _ = SocketRead(socket, buffer, 0, 1, readTimeout);
            return buffer[0];
        }

        /// <summary>
        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
        /// <param name="length">The number of bytes to read.</param>
        /// <returns>
        /// The number of bytes read.
        /// </returns>
        /// <exception cref="SshConnectionException">The socket is closed.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        protected static int SocketRead(Socket socket, byte[] buffer, int offset, int length)
        {
            return SocketRead(socket, buffer, offset, length, Session.InfiniteTimeSpan);
        }

        /// <summary>
        /// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
        /// <param name="length">The number of bytes to read.</param>
        /// <param name="readTimeout">The maximum time to wait until <paramref name="length"/> bytes have been received.</param>
        /// <returns>
        /// The number of bytes read.
        /// </returns>
        /// <exception cref="SshConnectionException">The socket is closed.</exception>
        /// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        protected static int SocketRead(Socket socket, byte[] buffer, int offset, int length, TimeSpan readTimeout)
        {
            var bytesRead = SocketAbstraction.Read(socket, buffer, offset, length, readTimeout);
            if (bytesRead == 0)
            {
                // A zero-byte read means the remote end closed the connection.
                throw new SshConnectionException("An established connection was aborted by the server.", DisconnectReason.ConnectionLost);
            }

            return bytesRead;
        }

        /// <summary>
        /// Builds the endpoint to connect to from the result of a DNS lookup, using the first address.
        /// </summary>
        /// <param name="addresses">The addresses returned by the DNS lookup.</param>
        /// <param name="port">The port to connect to.</param>
        /// <returns>An endpoint for the first resolved address and <paramref name="port"/>.</returns>
        /// <exception cref="SocketException"><paramref name="addresses"/> is <see langword="null"/> or empty.</exception>
        private static IPEndPoint CreateEndPoint(IPAddress[] addresses, int port)
        {
            if (addresses is null || addresses.Length == 0)
            {
                // Surface an empty lookup through the documented SocketException contract
                // instead of letting an IndexOutOfRangeException escape.
                throw new SocketException((int) SocketError.HostNotFound);
            }

            return new IPEndPoint(addresses[0], port);
        }

        /// <summary>
        /// Sizes the send and receive buffers of a freshly connected socket to twice the maximum SSH packet size.
        /// </summary>
        /// <param name="socket">The connected socket to configure.</param>
        private static void ConfigureBufferSizes(Socket socket)
        {
            const int socketBufferSize = 2 * Session.MaximumSshPacketSize;

            socket.SendBufferSize = socketBufferSize;
            socket.ReceiveBufferSize = socketBufferSize;
        }
    }
}