using System;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;

namespace Renci.SshNet.Abstractions
{
    /// <summary>
    /// Helpers for connecting a <see cref="Socket"/> and for reading from and writing to it.
    /// </summary>
    internal static class SocketAbstraction
    {
        /// <summary>
        /// Returns a value indicating whether the specified <see cref="Socket"/> has data
        /// available to be read.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to check.</param>
        /// <returns>
        /// <c>true</c> if <paramref name="socket"/> is connected and has data available; otherwise, <c>false</c>.
        /// </returns>
        /// <remarks>
        /// The socket is polled with a time-out of <c>-1</c>, so for a connected socket this call
        /// waits until the poll completes.
        /// </remarks>
        public static bool CanRead(Socket socket)
        {
            // Same null guard as CanWrite, so both probes report false for a missing socket.
            if (socket != null && socket.Connected)
            {
                // A successful read-poll alone does not prove that data is pending, hence the
                // additional check on Available.
                return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
            }

            return false;
        }

        /// <summary>
        /// Returns a value indicating whether the specified <see cref="Socket"/> can be used
        /// to send data.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to check.</param>
        /// <returns>
        /// <c>true</c> if <paramref name="socket"/> can be written to; otherwise, <c>false</c>.
        /// </returns>
        public static bool CanWrite(Socket socket)
        {
            if (socket != null && socket.Connected)
            {
                return socket.Poll(-1, SelectMode.SelectWrite);
            }

            return false;
        }

        /// <summary>
        /// Creates a TCP stream socket with <see cref="Socket.NoDelay"/> enabled and connects it
        /// to the specified endpoint.
        /// </summary>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
        /// <returns>
        /// The connected <see cref="Socket"/>.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        /// <remarks>
        /// The socket is created here, so it is disposed again when the connection attempt fails.
        /// </remarks>
        public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
            var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
            ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: true);
            return socket;
        }

        /// <summary>
        /// Connects a caller-supplied <see cref="Socket"/> to the specified endpoint.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to connect. It is never disposed by this method.</param>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
            ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false);
        }

        /// <summary>
        /// Asynchronously connects the specified <see cref="Socket"/> to the specified endpoint.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to connect.</param>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="cancellationToken">The token passed on to the underlying connect operation.</param>
        /// <returns>
        /// A <see cref="Task"/> that completes when the underlying connect operation completes.
        /// </returns>
        public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
        {
            await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/>, waiting at most
        /// <paramref name="connectTimeout"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to connect.</param>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
        /// <param name="ownsSocket">
        /// <c>true</c> to dispose <paramref name="socket"/> when the attempt fails for any reason;
        /// <c>false</c> to leave it untouched.
        /// </param>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
        {
#if FEATURE_SOCKET_EAP
            var connectCompleted = new ManualResetEvent(initialState: false);
            var args = new SocketAsyncEventArgs
                {
                    UserToken = connectCompleted,
                    RemoteEndPoint = remoteEndpoint
                };
            args.Completed += ConnectCompleted;

            try
            {
                // ConnectAsync returns true only when the operation is still pending; in that
                // case the Completed handler signals the wait handle.
                if (socket.ConnectAsync(args) && !connectCompleted.WaitOne(connectTimeout))
                {
                    // avoid ObjectDisposedException in ConnectCompleted
                    args.Completed -= ConnectCompleted;

                    throw CreateConnectTimeoutException(connectTimeout);
                }

                if (args.SocketError != SocketError.Success)
                {
                    throw new SocketException((int) args.SocketError);
                }
            }
            catch
            {
                // The catch block runs before the finally block, so an owned socket is
                // disposed before the wait handle and the event args.
                if (ownsSocket)
                {
                    socket.Dispose();
                }

                throw;
            }
            finally
            {
                // Released on every path, including when ConnectAsync or WaitOne throws.
                connectCompleted.Dispose();
                args.Dispose();
            }
#elif FEATURE_SOCKET_APM
            try
            {
                var connectResult = socket.BeginConnect(remoteEndpoint, null, null);
                if (!connectResult.AsyncWaitHandle.WaitOne(connectTimeout, false))
                {
                    throw CreateConnectTimeoutException(connectTimeout);
                }

                socket.EndConnect(connectResult);
            }
            catch
            {
                if (ownsSocket)
                {
                    socket.Dispose();
                }

                throw;
            }
#elif FEATURE_SOCKET_TAP
            try
            {
                if (!socket.ConnectAsync(remoteEndpoint).Wait(connectTimeout))
                {
                    throw CreateConnectTimeoutException(connectTimeout);
                }
            }
            catch
            {
                if (ownsSocket)
                {
                    socket.Dispose();
                }

                throw;
            }
#else
#error Connecting to a remote endpoint is not implemented.
#endif
        }

        /// <summary>
        /// Reads and discards data from the specified <see cref="Socket"/> until a read returns
        /// zero bytes.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to drain.</param>
        /// <exception cref="SshOperationTimeoutException">A single read did not complete within 500 milliseconds.</exception>
        public static void ClearReadBuffer(Socket socket)
        {
            var timeout = TimeSpan.FromMilliseconds(500);
            var buffer = new byte[256];
            int bytesReceived;

            do
            {
                bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout);
            }
            while (bytesReceived > 0);
        }

        /// <summary>
        /// Performs a single receive on the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> to store the received data.</param>
        /// <param name="size">The maximum number of bytes to receive.</param>
        /// <param name="timeout">The maximum time to wait for data; assigned to <see cref="Socket.ReceiveTimeout"/>.</param>
        /// <returns>
        /// The number of bytes received, which may be less than <paramref name="size"/>.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
        {
            socket.ReceiveTimeout = (int) timeout.TotalMilliseconds;

            try
            {
                return socket.Receive(buffer, offset, size, SocketFlags.None);
            }
            catch (SocketException ex)
            {
                if (ex.SocketErrorCode == SocketError.TimedOut)
                {
                    throw CreateReadTimeoutException(timeout);
                }

                throw;
            }
        }

        /// <summary>
        /// Keeps receiving from the specified <see cref="Socket"/> without a receive time-out and
        /// hands every chunk to <paramref name="processReceivedBytesAction"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The storage location for the received data; reused for every receive.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> to store the received data.</param>
        /// <param name="size">The maximum number of bytes to receive per call.</param>
        /// <param name="processReceivedBytesAction">
        /// Invoked after each non-empty receive with <paramref name="buffer"/>, <paramref name="offset"/>
        /// and the number of bytes received.
        /// </param>
        /// <exception cref="SocketException">
        /// The read failed with an error that is neither resumable nor one of the handled
        /// "connection closed" errors.
        /// </exception>
        /// <remarks>
        /// The method returns when the socket is no longer connected, when a receive returns zero
        /// bytes, or when a receive fails with <see cref="SocketError.ConnectionAborted"/>,
        /// <see cref="SocketError.ConnectionReset"/> or <see cref="SocketError.Interrupted"/>.
        /// </remarks>
        public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action<byte[], int, int> processReceivedBytesAction)
        {
            // do not time-out receive
            socket.ReceiveTimeout = 0;

            while (socket.Connected)
            {
                try
                {
                    var bytesRead = socket.Receive(buffer, offset, size, SocketFlags.None);
                    if (bytesRead == 0)
                    {
                        break;
                    }

                    processReceivedBytesAction(buffer, offset, bytesRead);
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                    {
                        continue;
                    }

#pragma warning disable IDE0010 // Add missing cases
                    switch (ex.SocketErrorCode)
                    {
                        case SocketError.ConnectionAborted:
                        case SocketError.ConnectionReset:
                            // connection was closed
                            return;
                        case SocketError.Interrupted:
                            // connection was closed because FIN/ACK was not received in time after
                            // shutting down the (send part of the) socket
                            return;
                        default:
                            throw; // throw any other error
                    }
#pragma warning restore IDE0010 // Add missing cases
                }
            }
        }

        /// <summary>
        /// Reads a byte from the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
        /// <returns>
        /// The byte read, or <c>-1</c> if the socket was closed.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        public static int ReadByte(Socket socket, TimeSpan timeout)
        {
            var buffer = new byte[1];
            if (Read(socket, buffer, 0, 1, timeout) == 0)
            {
                return -1;
            }

            return buffer[0];
        }

        /// <summary>
        /// Sends a byte using the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="value">The value to send.</param>
        /// <exception cref="SocketException">The write failed.</exception>
        public static void SendByte(Socket socket, byte value)
        {
            var buffer = new[] { value };
            Send(socket, buffer, 0, 1);
        }

        /// <summary>
        /// Receives data from a bound <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="size">The number of bytes to receive.</param>
        /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
        /// <returns>
        /// The bytes received.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        /// <remarks>
        /// <para>
        /// If no data is available for reading, the method will block until data is available or
        /// the time-out value is exceeded. If the time-out value is exceeded, the call will throw
        /// a <see cref="SshOperationTimeoutException"/>.
        /// </para>
        /// <para>
        /// The number of bytes actually read is discarded, so a receive that returns zero bytes
        /// before <paramref name="size"/> bytes arrived is not reported to the caller; the buffer
        /// of <paramref name="size"/> bytes is returned regardless.
        /// </para>
        /// </remarks>
        public static byte[] Read(Socket socket, int size, TimeSpan timeout)
        {
            var buffer = new byte[size];
            _ = Read(socket, buffer, 0, size, timeout);
            return buffer;
        }

        /// <summary>
        /// Starts an asynchronous receive on the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> to store the received data.</param>
        /// <param name="length">The maximum number of bytes to receive.</param>
        /// <param name="cancellationToken">The token passed on to the underlying receive operation.</param>
        /// <returns>
        /// The task returned by the underlying receive operation.
        /// </returns>
        /// <remarks>
        /// NOTE(review): the declared return type is the non-generic <see cref="Task"/>; if the
        /// underlying receive yields a byte count, it is not exposed through this signature -
        /// confirm whether that is intended.
        /// </remarks>
        public static Task ReadAsync(Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
        {
            return socket.ReceiveAsync(buffer, offset, length, cancellationToken);
        }

        /// <summary>
        /// Receives data from a bound <see cref="Socket"/> into a receive buffer.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
        /// <param name="size">The number of bytes to receive.</param>
        /// <param name="readTimeout">The maximum time to wait until <paramref name="size"/> bytes have been received.</param>
        /// <returns>
        /// The number of bytes received, or <c>0</c> if a receive returned zero bytes before
        /// <paramref name="size"/> bytes were read (even when some bytes had already arrived).
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        /// <remarks>
        /// <para>
        /// If no data is available for reading, the method will block until data is available or
        /// the time-out value is exceeded. If the time-out value is exceeded, the call will throw
        /// a <see cref="SshOperationTimeoutException"/>.
        /// </para>
        /// <para>
        /// Errors reported as resumable by <see cref="IsErrorResumable(SocketError)"/> are retried
        /// after a pause of 30 milliseconds.
        /// </para>
        /// </remarks>
        public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout)
        {
            var totalBytesRead = 0;
            var totalBytesToRead = size;

            socket.ReceiveTimeout = (int) readTimeout.TotalMilliseconds;

            do
            {
                try
                {
                    var bytesRead = socket.Receive(buffer,
                                                   offset + totalBytesRead,
                                                   totalBytesToRead - totalBytesRead,
                                                   SocketFlags.None);
                    if (bytesRead == 0)
                    {
                        return 0;
                    }

                    totalBytesRead += bytesRead;
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                    {
                        ThreadAbstraction.Sleep(30);
                        continue;
                    }

                    if (ex.SocketErrorCode == SocketError.TimedOut)
                    {
                        throw CreateReadTimeoutException(readTimeout);
                    }

                    throw;
                }
            }
            while (totalBytesRead < totalBytesToRead);

            return totalBytesRead;
        }

        /// <summary>
        /// Sends all bytes of <paramref name="data"/> using the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="data">The data to send.</param>
        /// <exception cref="SshConnectionException">A send reported that zero bytes were sent.</exception>
        /// <exception cref="SocketException">The write failed.</exception>
        public static void Send(Socket socket, byte[] data)
        {
            Send(socket, data, 0, data.Length);
        }

        /// <summary>
        /// Sends <paramref name="size"/> bytes of <paramref name="data"/>, starting at
        /// <paramref name="offset"/>, using the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="data">The buffer that holds the data to send.</param>
        /// <param name="offset">The position in <paramref name="data"/> of the first byte to send.</param>
        /// <param name="size">The number of bytes to send.</param>
        /// <exception cref="SshConnectionException">A send reported that zero bytes were sent.</exception>
        /// <exception cref="SocketException">The write failed with an error that is not resumable.</exception>
        /// <remarks>
        /// The method loops until all bytes are sent. Errors reported as resumable by
        /// <see cref="IsErrorResumable(SocketError)"/> are retried after a pause of 30 milliseconds.
        /// </remarks>
        public static void Send(Socket socket, byte[] data, int offset, int size)
        {
            var totalBytesSent = 0;  // how many bytes are already sent
            var totalBytesToSend = size;

            do
            {
                try
                {
                    var bytesSent = socket.Send(data,
                                                offset + totalBytesSent,
                                                totalBytesToSend - totalBytesSent,
                                                SocketFlags.None);
                    if (bytesSent == 0)
                    {
                        throw new SshConnectionException("An established connection was aborted by the server.",
                                                         DisconnectReason.ConnectionLost);
                    }

                    totalBytesSent += bytesSent;
                }
                catch (SocketException ex)
                {
                    if (IsErrorResumable(ex.SocketErrorCode))
                    {
                        // socket buffer is probably full, wait and try again
                        ThreadAbstraction.Sleep(30);
                    }
                    else
                    {
                        throw;  // any serious error occurred
                    }
                }
            }
            while (totalBytesSent < totalBytesToSend);
        }

        /// <summary>
        /// Returns a value indicating whether an operation that failed with the specified
        /// <see cref="SocketError"/> may simply be retried.
        /// </summary>
        /// <param name="socketError">The error to classify.</param>
        /// <returns>
        /// <c>true</c> for <see cref="SocketError.WouldBlock"/>, <see cref="SocketError.IOPending"/>
        /// and <see cref="SocketError.NoBufferSpaceAvailable"/>; otherwise, <c>false</c>.
        /// </returns>
        public static bool IsErrorResumable(SocketError socketError)
        {
#pragma warning disable IDE0010 // Add missing cases
            switch (socketError)
            {
                case SocketError.WouldBlock:
                case SocketError.IOPending:
                case SocketError.NoBufferSpaceAvailable:
                    return true;
                default:
                    return false;
            }
#pragma warning restore IDE0010 // Add missing cases
        }

        /// <summary>
        /// Creates the exception thrown when a connection is not established in time.
        /// </summary>
        /// <param name="connectTimeout">The time-out that was exceeded.</param>
        /// <returns>
        /// The <see cref="SshOperationTimeoutException"/> to throw.
        /// </returns>
        private static SshOperationTimeoutException CreateConnectTimeoutException(TimeSpan connectTimeout)
        {
            return new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                                                                  "Connection failed to establish within {0:F0} milliseconds.",
                                                                  connectTimeout.TotalMilliseconds));
        }

        /// <summary>
        /// Creates the exception thrown when a socket read does not complete in time.
        /// </summary>
        /// <param name="readTimeout">The time-out that was exceeded.</param>
        /// <returns>
        /// The <see cref="SshOperationTimeoutException"/> to throw.
        /// </returns>
        private static SshOperationTimeoutException CreateReadTimeoutException(TimeSpan readTimeout)
        {
            return new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                                                                  "Socket read operation has timed out after {0:F0} milliseconds.",
                                                                  readTimeout.TotalMilliseconds));
        }

#if FEATURE_SOCKET_EAP
        /// <summary>
        /// Signals the wait handle stored in <see cref="SocketAsyncEventArgs.UserToken"/> once the
        /// connect operation has completed.
        /// </summary>
        /// <param name="sender">The source of the event.</param>
        /// <param name="e">The event data of the completed connect operation.</param>
        private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
        {
            var eventWaitHandle = (ManualResetEvent) e.UserToken;
            _ = eventWaitHandle?.Set();
        }
#endif // FEATURE_SOCKET_EAP
    }
}