using System;
using System.Globalization;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Transport;
namespace Renci.SshNet.Abstractions
{
    /// <summary>
    /// Helpers for connecting, reading from and writing to a <see cref="Socket"/>, translating socket
    /// time-outs into <see cref="SshOperationTimeoutException"/>.
    /// </summary>
    internal static class SocketAbstraction
    {
        /// <summary>
        /// Number of milliseconds to back off before retrying an operation that failed with an error for which
        /// <see cref="IsErrorResumable(SocketError)"/> returns <see langword="true"/>.
        /// </summary>
        private const int ResumableErrorRetryDelayMilliseconds = 30;

        /// <summary>
        /// Returns a value indicating whether the specified <see cref="Socket"/> has data that can be read.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to check.</param>
        /// <returns>
        /// <see langword="true"/> if <paramref name="socket"/> is connected, became readable and has bytes
        /// available; otherwise, <see langword="false"/>.
        /// </returns>
        /// <remarks>
        /// This method polls without a time-out and therefore blocks until the socket becomes readable.
        /// </remarks>
        public static bool CanRead(Socket socket)
        {
            if (socket != null && socket.Connected)
            {
                // A socket that is readable but has no bytes available was closed by the remote party
                // rather than having received data.
                return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
            }

            return false;
        }

        /// <summary>
        /// Returns a value indicating whether the specified <see cref="Socket"/> can be used to send data.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to check.</param>
        /// <returns>
        /// <see langword="true"/> if <paramref name="socket"/> can be written to; otherwise, <see langword="false"/>.
        /// </returns>
        /// <remarks>
        /// This method polls without a time-out and therefore blocks until the socket becomes writable.
        /// </remarks>
        public static bool CanWrite(Socket socket)
        {
            if (socket != null && socket.Connected)
            {
                return socket.Poll(-1, SelectMode.SelectWrite);
            }

            return false;
        }

        /// <summary>
        /// Creates a TCP <see cref="Socket"/> with <see cref="Socket.NoDelay"/> enabled and connects it to the
        /// specified endpoint.
        /// </summary>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
        /// <returns>The connected <see cref="Socket"/>.</returns>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        /// <exception cref="SocketException">The connection attempt failed.</exception>
        /// <remarks>
        /// The created socket is disposed when the connection attempt fails or times out.
        /// </remarks>
        public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
            var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
            ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: true);
            return socket;
        }

        /// <summary>
        /// Connects the specified <see cref="Socket"/> to the specified endpoint.
        /// </summary>
        /// <param name="socket">The socket to connect.</param>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        /// <exception cref="SocketException">The connection attempt failed.</exception>
        /// <remarks>
        /// The caller keeps ownership of <paramref name="socket"/>; it is not disposed when the attempt fails.
        /// </remarks>
        public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
        {
            ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false);
        }

        /// <summary>
        /// Asynchronously connects the specified <see cref="Socket"/> to the specified endpoint.
        /// </summary>
        /// <param name="socket">The socket to connect.</param>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
        /// <returns>A task that completes when the connection has been established.</returns>
        public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
        {
            await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Connects <paramref name="socket"/> to <paramref name="remoteEndpoint"/>, waiting at most
        /// <paramref name="connectTimeout"/> for the connection to be established.
        /// </summary>
        /// <param name="socket">The socket to connect.</param>
        /// <param name="remoteEndpoint">The endpoint to connect to.</param>
        /// <param name="connectTimeout">The maximum time to wait for the connection to be established.</param>
        /// <param name="ownsSocket">
        /// <see langword="true"/> to dispose <paramref name="socket"/> when the connection attempt fails or times
        /// out; otherwise, <see langword="false"/>.
        /// </param>
        /// <exception cref="SshOperationTimeoutException">The connection was not established within <paramref name="connectTimeout"/>.</exception>
        /// <exception cref="SocketException">The connection attempt failed.</exception>
        private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
        {
#if FEATURE_SOCKET_EAP
            var connectCompleted = new ManualResetEvent(initialState: false);
            var args = new SocketAsyncEventArgs
                {
                    UserToken = connectCompleted,
                    RemoteEndPoint = remoteEndpoint
                };
            args.Completed += ConnectCompleted;

            try
            {
                // ConnectAsync returns false when the operation completed synchronously; the Completed
                // event is then not raised and there is nothing to wait for.
                if (socket.ConnectAsync(args) && !connectCompleted.WaitOne(connectTimeout))
                {
                    // avoid ObjectDisposedException in ConnectCompleted once the wait handle is disposed below
                    args.Completed -= ConnectCompleted;
                    throw CreateConnectTimeoutException(connectTimeout);
                }

                if (args.SocketError != SocketError.Success)
                {
                    throw new SocketException((int) args.SocketError);
                }
            }
            catch
            {
                DisposeIfOwned(socket, ownsSocket);
                throw;
            }
            finally
            {
                // released on every path, including a synchronous failure of ConnectAsync
                connectCompleted.Dispose();
                args.Dispose();
            }
#elif FEATURE_SOCKET_APM
            try
            {
                var connectResult = socket.BeginConnect(remoteEndpoint, null, null);
                if (!connectResult.AsyncWaitHandle.WaitOne(connectTimeout, false))
                {
                    throw CreateConnectTimeoutException(connectTimeout);
                }

                socket.EndConnect(connectResult);
            }
            catch
            {
                DisposeIfOwned(socket, ownsSocket);
                throw;
            }
#elif FEATURE_SOCKET_TAP
            try
            {
                var connectTask = socket.ConnectAsync(remoteEndpoint);

                bool completed;
                try
                {
                    completed = connectTask.Wait(connectTimeout);
                }
                catch (AggregateException)
                {
                    // The connect task faulted or was canceled; GetResult() below rethrows the
                    // original exception instead of the AggregateException wrapper.
                    completed = true;
                }

                if (!completed)
                {
                    throw CreateConnectTimeoutException(connectTimeout);
                }

                connectTask.GetAwaiter().GetResult();
            }
            catch
            {
                DisposeIfOwned(socket, ownsSocket);
                throw;
            }
#else
#error Connecting to a remote endpoint is not implemented.
#endif
        }

        /// <summary>
        /// Disposes <paramref name="socket"/> when <paramref name="ownsSocket"/> is <see langword="true"/>.
        /// </summary>
        /// <param name="socket">The socket to dispose.</param>
        /// <param name="ownsSocket">Whether this class owns, and must therefore dispose, <paramref name="socket"/>.</param>
        private static void DisposeIfOwned(Socket socket, bool ownsSocket)
        {
            if (ownsSocket)
            {
                socket.Dispose();
            }
        }

        /// <summary>
        /// Creates the exception thrown when a connection is not established within the allotted time.
        /// </summary>
        /// <param name="connectTimeout">The time-out that was exceeded.</param>
        /// <returns>The exception to throw.</returns>
        private static SshOperationTimeoutException CreateConnectTimeoutException(TimeSpan connectTimeout)
        {
            return new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                                                                  "Connection failed to establish within {0:F0} milliseconds.",
                                                                  connectTimeout.TotalMilliseconds));
        }

        /// <summary>
        /// Creates the exception thrown when a socket read does not complete within the allotted time.
        /// </summary>
        /// <param name="readTimeout">The time-out that was exceeded.</param>
        /// <returns>The exception to throw.</returns>
        private static SshOperationTimeoutException CreateReadTimeoutException(TimeSpan readTimeout)
        {
            return new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
                                                                  "Socket read operation has timed out after {0:F0} milliseconds.",
                                                                  readTimeout.TotalMilliseconds));
        }

        /// <summary>
        /// Discards incoming data on the specified <see cref="Socket"/>. It reads in chunks of at most 256 bytes,
        /// allowing 500 milliseconds per read, until a read returns no data.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <exception cref="SshOperationTimeoutException">A read did not receive any data within 500 milliseconds.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        /// <remarks>
        /// NOTE(review): the loop only ends normally when a read returns zero bytes. If the peer keeps the connection
        /// open but stops sending, the time-out raised by <see cref="ReadPartial"/> propagates to the caller.
        /// Confirm whether callers rely on this exception or expect a silent return.
        /// </remarks>
        public static void ClearReadBuffer(Socket socket)
        {
            var timeout = TimeSpan.FromMilliseconds(500);
            var buffer = new byte[256];
            int bytesReceived;

            do
            {
                bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout);
            }
            while (bytesReceived > 0);
        }

        /// <summary>
        /// Performs a single receive of at most <paramref name="size"/> bytes from the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The buffer to store the received data in.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> at which to store the received data.</param>
        /// <param name="size">The maximum number of bytes to receive.</param>
        /// <param name="timeout">The maximum time to wait for data.</param>
        /// <returns>The number of bytes received, which may be less than <paramref name="size"/>.</returns>
        /// <exception cref="SshOperationTimeoutException">No data was received within <paramref name="timeout"/>.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        /// <remarks>
        /// <see cref="Socket.ReceiveTimeout"/> is set from <paramref name="timeout"/> and stays in effect after the call.
        /// </remarks>
        public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
        {
            socket.ReceiveTimeout = (int) timeout.TotalMilliseconds;

            try
            {
                return socket.Receive(buffer, offset, size, SocketFlags.None);
            }
            catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
            {
                throw CreateReadTimeoutException(timeout);
            }
        }

        /// <summary>
        /// Reads from the specified <see cref="Socket"/> until the connection is closed. Each received chunk is
        /// passed to <paramref name="processReceivedBytesAction"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The buffer to receive data into.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> at which each receive stores its data.</param>
        /// <param name="size">The maximum number of bytes to receive per read.</param>
        /// <param name="processReceivedBytesAction">
        /// Invoked with <paramref name="buffer"/>, <paramref name="offset"/> and the number of bytes received. The
        /// same region of <paramref name="buffer"/> is reused for every receive, so the bytes must be consumed
        /// before the action returns.
        /// </param>
        /// <exception cref="SocketException">The read failed with an error that does not indicate a closed connection.</exception>
        /// <remarks>
        /// Receives never time out; <see cref="Socket.ReceiveTimeout"/> is set to zero. Reading stops silently when
        /// a receive returns no data, when the socket is no longer connected, or when a <see cref="SocketException"/>
        /// indicates that the connection was aborted, reset or interrupted. The same handling applies to
        /// <see cref="SocketException"/>s thrown by <paramref name="processReceivedBytesAction"/>.
        /// </remarks>
        public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action<byte[], int, int> processReceivedBytesAction)
        {
            // do not time-out receive
            socket.ReceiveTimeout = 0;

            while (socket.Connected)
            {
                try
                {
                    var bytesRead = socket.Receive(buffer, offset, size, SocketFlags.None);
                    if (bytesRead == 0)
                    {
                        break;
                    }

                    processReceivedBytesAction(buffer, offset, bytesRead);
                }
                catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
                {
                    // back off instead of spinning on a transient error, consistent with Read and Send
                    ThreadAbstraction.Sleep(ResumableErrorRetryDelayMilliseconds);
                }
                catch (SocketException ex) when (IsConnectionClosedError(ex.SocketErrorCode))
                {
                    return;
                }
            }
        }

        /// <summary>
        /// Returns a value indicating whether <paramref name="socketError"/> means the connection was closed, so that
        /// <see cref="ReadContinuous"/> should stop reading without reporting an error.
        /// </summary>
        /// <param name="socketError">The error to check.</param>
        /// <returns><see langword="true"/> for a closed connection; otherwise, <see langword="false"/>.</returns>
        private static bool IsConnectionClosedError(SocketError socketError)
        {
            // ConnectionAborted / ConnectionReset: the connection was closed.
            // Interrupted: the connection was closed because FIN/ACK was not received in time after
            // shutting down the (send part of the) socket.
            return socketError == SocketError.ConnectionAborted
                || socketError == SocketError.ConnectionReset
                || socketError == SocketError.Interrupted;
        }

        /// <summary>
        /// Reads a byte from the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
        /// <returns>
        /// The byte read, or <c>-1</c> if the socket was closed.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        public static int ReadByte(Socket socket, TimeSpan timeout)
        {
            var buffer = new byte[1];
            if (Read(socket, buffer, 0, 1, timeout) == 0)
            {
                return -1;
            }

            return buffer[0];
        }

        /// <summary>
        /// Sends a byte using the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="value">The value to send.</param>
        /// <exception cref="SocketException">The write failed.</exception>
        public static void SendByte(Socket socket, byte value)
        {
            var buffer = new[] { value };
            Send(socket, buffer, 0, 1);
        }

        /// <summary>
        /// Receives exactly <paramref name="size"/> bytes from the specified <see cref="Socket"/> into a new buffer.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="size">The number of bytes to receive.</param>
        /// <param name="timeout">Specifies the amount of time after which each receive will time out.</param>
        /// <returns>
        /// The bytes received.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">The read operation timed out.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        /// <remarks>
        /// NOTE(review): the count returned by <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> is discarded.
        /// If the connection is closed before <paramref name="size"/> bytes arrive, the returned buffer is only
        /// partially filled (the rest stays zero) and the caller cannot tell.
        /// </remarks>
        public static byte[] Read(Socket socket, int size, TimeSpan timeout)
        {
            var buffer = new byte[size];
            _ = Read(socket, buffer, 0, size, timeout);
            return buffer;
        }

        /// <summary>
        /// Asynchronously receives up to <paramref name="length"/> bytes from the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The buffer to store the received data in.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> at which to store the received data.</param>
        /// <param name="length">The maximum number of bytes to receive.</param>
        /// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
        /// <returns>
        /// The task produced by the <c>ReceiveAsync</c> socket extension. Its result is presumably the number of
        /// bytes received; confirm against that extension.
        /// </returns>
        public static Task<int> ReadAsync(Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
        {
            return socket.ReceiveAsync(buffer, offset, length, cancellationToken);
        }

        /// <summary>
        /// Receives exactly <paramref name="size"/> bytes from the specified <see cref="Socket"/> into
        /// <paramref name="buffer"/>, unless the connection is closed first.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to read from.</param>
        /// <param name="buffer">The buffer to store the received data in.</param>
        /// <param name="offset">The position in <paramref name="buffer"/> at which to store the received data.</param>
        /// <param name="size">The number of bytes to receive.</param>
        /// <param name="readTimeout">The maximum time each individual receive may wait for data.</param>
        /// <returns>
        /// <paramref name="size"/> when all bytes were received, or <c>0</c> when the connection was closed before
        /// that. Bytes received before the close remain in <paramref name="buffer"/>.
        /// </returns>
        /// <exception cref="SshOperationTimeoutException">A receive did not get any data within <paramref name="readTimeout"/>.</exception>
        /// <exception cref="SocketException">The read failed.</exception>
        /// <remarks>
        /// Transient errors (see <see cref="IsErrorResumable(SocketError)"/>) are retried after a short delay.
        /// <see cref="Socket.ReceiveTimeout"/> is set from <paramref name="readTimeout"/> and stays in effect after the call.
        /// </remarks>
        public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout)
        {
            var totalBytesRead = 0;

            socket.ReceiveTimeout = (int) readTimeout.TotalMilliseconds;

            do
            {
                try
                {
                    var bytesRead = socket.Receive(buffer, offset + totalBytesRead, size - totalBytesRead, SocketFlags.None);
                    if (bytesRead == 0)
                    {
                        return 0;
                    }

                    totalBytesRead += bytesRead;
                }
                catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
                {
                    ThreadAbstraction.Sleep(ResumableErrorRetryDelayMilliseconds);
                }
                catch (SocketException ex) when (ex.SocketErrorCode == SocketError.TimedOut)
                {
                    throw CreateReadTimeoutException(readTimeout);
                }
            }
            while (totalBytesRead < size);

            return totalBytesRead;
        }

        /// <summary>
        /// Sends all of <paramref name="data"/> using the specified <see cref="Socket"/>.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="data">The data to send.</param>
        /// <exception cref="SshConnectionException">The socket accepted no bytes.</exception>
        /// <exception cref="SocketException">The write failed.</exception>
        public static void Send(Socket socket, byte[] data)
        {
            Send(socket, data, 0, data.Length);
        }

        /// <summary>
        /// Sends <paramref name="size"/> bytes of <paramref name="data"/>, starting at <paramref name="offset"/>, using
        /// the specified <see cref="Socket"/>. Partial sends are repeated until every byte has been sent.
        /// </summary>
        /// <param name="socket">The <see cref="Socket"/> to write to.</param>
        /// <param name="data">The buffer holding the data to send.</param>
        /// <param name="offset">The position in <paramref name="data"/> of the first byte to send.</param>
        /// <param name="size">The number of bytes to send.</param>
        /// <exception cref="SshConnectionException">A send call accepted no bytes.</exception>
        /// <exception cref="SocketException">The write failed with an error that is not resumable.</exception>
        public static void Send(Socket socket, byte[] data, int offset, int size)
        {
            var totalBytesSent = 0; // how many bytes are already sent

            do
            {
                try
                {
                    var bytesSent = socket.Send(data, offset + totalBytesSent, size - totalBytesSent, SocketFlags.None);
                    if (bytesSent == 0)
                    {
                        throw new SshConnectionException("An established connection was aborted by the server.",
                                                         DisconnectReason.ConnectionLost);
                    }

                    totalBytesSent += bytesSent;
                }
                catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
                {
                    // socket buffer is probably full, wait and try again; any other error occurred is rethrown
                    ThreadAbstraction.Sleep(ResumableErrorRetryDelayMilliseconds);
                }
            }
            while (totalBytesSent < size);
        }

        /// <summary>
        /// Returns a value indicating whether a socket operation that failed with <paramref name="socketError"/>
        /// can be retried.
        /// </summary>
        /// <param name="socketError">The error to check.</param>
        /// <returns>
        /// <see langword="true"/> for <see cref="SocketError.WouldBlock"/>, <see cref="SocketError.IOPending"/> and
        /// <see cref="SocketError.NoBufferSpaceAvailable"/>; otherwise, <see langword="false"/>.
        /// </returns>
        public static bool IsErrorResumable(SocketError socketError)
        {
#pragma warning disable IDE0010 // Add missing cases
            switch (socketError)
            {
                case SocketError.WouldBlock:
                case SocketError.IOPending:
                case SocketError.NoBufferSpaceAvailable:
                    return true;
                default:
                    return false;
            }
#pragma warning restore IDE0010 // Add missing cases
        }

#if FEATURE_SOCKET_EAP
        /// <summary>
        /// <see cref="SocketAsyncEventArgs.Completed"/> handler for connect operations. It signals the
        /// <see cref="ManualResetEvent"/> stored in <see cref="SocketAsyncEventArgs.UserToken"/>.
        /// </summary>
        /// <param name="sender">The source of the event.</param>
        /// <param name="e">The event arguments of the completed connect operation.</param>
        private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
        {
            var eventWaitHandle = (ManualResetEvent) e.UserToken;
            _ = eventWaitHandle?.Set();
        }
#endif // FEATURE_SOCKET_EAP
    }
}