diff --git a/Directory.Packages.props b/Directory.Packages.props index 2e377c4f..70eb82f3 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -60,7 +60,7 @@ - + @@ -75,8 +75,8 @@ - - + + diff --git a/src/Common/CancellableStreamReader/CancellableStreamReader.cs b/src/Common/CancellableStreamReader/CancellableStreamReader.cs new file mode 100644 index 00000000..f6df72d2 --- /dev/null +++ b/src/Common/CancellableStreamReader/CancellableStreamReader.cs @@ -0,0 +1,1395 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Buffers.Binary; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using ModelContextProtocol; + +namespace System.IO; + +/// +/// Netfx-compatible polyfill of System.IO.StreamReader that supports cancellation. +/// +internal class CancellableStreamReader : TextReader +{ + // CancellableStreamReader.Null is threadsafe. + public static new readonly CancellableStreamReader Null = new NullCancellableStreamReader(); + + // Using a 1K byte buffer and a 4K FileStream buffer works out pretty well + // perf-wise. On even a 40 MB text file, any perf loss by using a 4K + // buffer is negated by the win of allocating a smaller byte[], which + // saves construction time. This does break adaptive buffering, + // but this is slightly faster. + private const int DefaultBufferSize = 1024; // Byte buffer size + private const int DefaultFileStreamBufferSize = 4096; + private const int MinBufferSize = 128; + + private readonly Stream _stream; + private Encoding _encoding = null!; // only null in NullCancellableStreamReader where this is never used + private readonly byte[] _encodingPreamble = null!; // only null in NullCancellableStreamReader where this is never used + private Decoder _decoder = null!; // only null in NullCancellableStreamReader where this is never used + private readonly byte[] _byteBuffer = null!; // only null in NullCancellableStreamReader where this is never used + private char[] _charBuffer = null!; // only null in NullCancellableStreamReader where this is never used + private int _charPos; + private int _charLen; + // Record the number of valid bytes in the byteBuffer, for a few checks. + private int _byteLen; + // This is used only for preamble detection + private int _bytePos; + + // This is the maximum number of chars we can get from one call to + // ReadBuffer. Used so ReadBuffer can tell when to copy data into + // a user's char[] directly, instead of our internal char[]. + private int _maxCharsPerBuffer; + + /// True if the writer has been disposed; otherwise, false. + private bool _disposed; + + // We will support looking for byte order marks in the stream and trying + // to decide what the encoding might be from the byte order marks, IF they + // exist. But that's all we'll do. + private bool _detectEncoding; + + // Whether we must still check for the encoding's given preamble at the + // beginning of this file. + private bool _checkPreamble; + + // Whether the stream is most likely not going to give us back as much + // data as we want the next time we call it. We must do the computation + // before we do any byte order mark handling and save the result. Note + // that we need this to allow users to handle streams used for an + // interactive protocol, where they block waiting for the remote end + // to send a response, like logging in on a Unix machine. + private bool _isBlocked; + + // The intent of this field is to leave open the underlying stream when + // disposing of this CancellableStreamReader. A name like _leaveOpen is better, + // but this type is serializable, and this field's name was _closable. + private readonly bool _closable; // Whether to close the underlying stream. + + // We don't guarantee thread safety on CancellableStreamReader, but we should at + // least prevent users from trying to read anything while an Async + // read from the same thread is in progress. + private Task _asyncReadTask = Task.CompletedTask; + + private void CheckAsyncTaskInProgress() + { + // We are not locking the access to _asyncReadTask because this is not meant to guarantee thread safety. + // We are simply trying to deter calling any Read APIs while an async Read from the same thread is in progress. + if (!_asyncReadTask.IsCompleted) + { + ThrowAsyncIOInProgress(); + } + } + + [DoesNotReturn] + private static void ThrowAsyncIOInProgress() => + throw new InvalidOperationException("Async IO is in progress"); + + // CancellableStreamReader by default will ignore illegal UTF8 characters. We don't want to + // throw here because we want to be able to read ill-formed data without choking. + // The high level goal is to be tolerant of encoding errors when we read and very strict + // when we write. Hence, default StreamWriter encoding will throw on error. + + private CancellableStreamReader() + { + Debug.Assert(this is NullCancellableStreamReader); + _stream = Stream.Null; + _closable = true; + } + + public CancellableStreamReader(Stream stream) + : this(stream, true) + { + } + + public CancellableStreamReader(Stream stream, bool detectEncodingFromByteOrderMarks) + : this(stream, Encoding.UTF8, detectEncodingFromByteOrderMarks, DefaultBufferSize, false) + { + } + + public CancellableStreamReader(Stream stream, Encoding? encoding) + : this(stream, encoding, true, DefaultBufferSize, false) + { + } + + public CancellableStreamReader(Stream stream, Encoding? encoding, bool detectEncodingFromByteOrderMarks) + : this(stream, encoding, detectEncodingFromByteOrderMarks, DefaultBufferSize, false) + { + } + + // Creates a new CancellableStreamReader for the given stream. The + // character encoding is set by encoding and the buffer size, + // in number of 16-bit characters, is set by bufferSize. + // + // Note that detectEncodingFromByteOrderMarks is a very + // loose attempt at detecting the encoding by looking at the first + // 3 bytes of the stream. It will recognize UTF-8, little endian + // unicode, and big endian unicode text, but that's it. If neither + // of those three match, it will use the Encoding you provided. + // + public CancellableStreamReader(Stream stream, Encoding? encoding, bool detectEncodingFromByteOrderMarks, int bufferSize) + : this(stream, encoding, detectEncodingFromByteOrderMarks, bufferSize, false) + { + } + + public CancellableStreamReader(Stream stream, Encoding? encoding = null, bool detectEncodingFromByteOrderMarks = true, int bufferSize = -1, bool leaveOpen = false) + { + Throw.IfNull(stream); + + if (!stream.CanRead) + { + throw new ArgumentException("Stream not readable."); + } + + if (bufferSize == -1) + { + bufferSize = DefaultBufferSize; + } + + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize), bufferSize, "Buffer size must be greater than zero."); + } + + _stream = stream; + _encoding = encoding ??= Encoding.UTF8; + _decoder = encoding.GetDecoder(); + if (bufferSize < MinBufferSize) + { + bufferSize = MinBufferSize; + } + + _byteBuffer = new byte[bufferSize]; + _maxCharsPerBuffer = encoding.GetMaxCharCount(bufferSize); + _charBuffer = new char[_maxCharsPerBuffer]; + _detectEncoding = detectEncodingFromByteOrderMarks; + _encodingPreamble = encoding.GetPreamble(); + + // If the preamble length is larger than the byte buffer length, + // we'll never match it and will enter an infinite loop. This + // should never happen in practice, but just in case, we'll skip + // the preamble check for absurdly long preambles. + int preambleLength = _encodingPreamble.Length; + _checkPreamble = preambleLength > 0 && preambleLength <= bufferSize; + + _closable = !leaveOpen; + } + + public override void Close() + { + Dispose(true); + } + + protected override void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + _disposed = true; + + // Dispose of our resources if this CancellableStreamReader is closable. + if (_closable) + { + try + { + // Note that Stream.Close() can potentially throw here. So we need to + // ensure cleaning up internal resources, inside the finally block. + if (disposing) + { + _stream.Close(); + } + } + finally + { + _charPos = 0; + _charLen = 0; + base.Dispose(disposing); + } + } + } + + public virtual Encoding CurrentEncoding => _encoding; + + public virtual Stream BaseStream => _stream; + + // DiscardBufferedData tells CancellableStreamReader to throw away its internal + // buffer contents. This is useful if the user needs to seek on the + // underlying stream to a known location then wants the CancellableStreamReader + // to start reading from this new point. This method should be called + // very sparingly, if ever, since it can lead to very poor performance. + // However, it may be the only way of handling some scenarios where + // users need to re-read the contents of a CancellableStreamReader a second time. + public void DiscardBufferedData() + { + CheckAsyncTaskInProgress(); + + _byteLen = 0; + _charLen = 0; + _charPos = 0; + // in general we'd like to have an invariant that encoding isn't null. However, + // for startup improvements for NullCancellableStreamReader, we want to delay load encoding. + if (_encoding != null) + { + _decoder = _encoding.GetDecoder(); + } + _isBlocked = false; + } + + public bool EndOfStream + { + get + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos < _charLen) + { + return false; + } + + // This may block on pipes! + int numRead = ReadBuffer(); + return numRead == 0; + } + } + + public override int Peek() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos == _charLen) + { + if (ReadBuffer() == 0) + { + return -1; + } + } + return _charBuffer[_charPos]; + } + + public override int Read() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos == _charLen) + { + if (ReadBuffer() == 0) + { + return -1; + } + } + int result = _charBuffer[_charPos]; + _charPos++; + return result; + } + + public override int Read(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + Throw.IfNegative(index); + Throw.IfNegative(count); + + return ReadSpan(new Span(buffer, index, count)); + } + + private int ReadSpan(Span buffer) + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + int charsRead = 0; + // As a perf optimization, if we had exactly one buffer's worth of + // data read in, let's try writing directly to the user's buffer. + bool readToUserBuffer = false; + int count = buffer.Length; + while (count > 0) + { + int n = _charLen - _charPos; + if (n == 0) + { + n = ReadBuffer(buffer.Slice(charsRead), out readToUserBuffer); + } + if (n == 0) + { + break; // We're at EOF + } + if (n > count) + { + n = count; + } + if (!readToUserBuffer) + { + new Span(_charBuffer, _charPos, n).CopyTo(buffer.Slice(charsRead)); + _charPos += n; + } + + charsRead += n; + count -= n; + // This function shouldn't block for an indefinite amount of time, + // or reading from a network stream won't work right. If we got + // fewer bytes than we requested, then we want to break right here. + if (_isBlocked) + { + break; + } + } + + return charsRead; + } + + public override string ReadToEnd() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + // Call ReadBuffer, then pull data out of charBuffer. + StringBuilder sb = new StringBuilder(_charLen - _charPos); + do + { + sb.Append(_charBuffer, _charPos, _charLen - _charPos); + _charPos = _charLen; // Note we consumed these characters + ReadBuffer(); + } while (_charLen > 0); + return sb.ToString(); + } + + public override int ReadBlock(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + + Throw.IfNegative(index); + Throw.IfNegative(count); + if (buffer.Length - index < count) + { + throw new ArgumentException("invalid offset length."); + } + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + return base.ReadBlock(buffer, index, count); + } + + // Trims n bytes from the front of the buffer. + private void CompressBuffer(int n) + { + Debug.Assert(_byteLen >= n, "CompressBuffer was called with a number of bytes greater than the current buffer length. Are two threads using this CancellableStreamReader at the same time?"); + byte[] byteBuffer = _byteBuffer; + _ = byteBuffer.Length; // allow JIT to prove object is not null + new ReadOnlySpan(byteBuffer, n, _byteLen - n).CopyTo(byteBuffer); + _byteLen -= n; + } + + private void DetectEncoding() + { + Debug.Assert(_byteLen >= 2, "Caller should've validated that at least 2 bytes were available."); + + byte[] byteBuffer = _byteBuffer; + _detectEncoding = false; + bool changedEncoding = false; + + ushort firstTwoBytes = BinaryPrimitives.ReadUInt16LittleEndian(byteBuffer); + if (firstTwoBytes == 0xFFFE) + { + // Big Endian Unicode + _encoding = Encoding.BigEndianUnicode; + CompressBuffer(2); + changedEncoding = true; + } + else if (firstTwoBytes == 0xFEFF) + { + // Little Endian Unicode, or possibly little endian UTF32 + if (_byteLen < 4 || byteBuffer[2] != 0 || byteBuffer[3] != 0) + { + _encoding = Encoding.Unicode; + CompressBuffer(2); + changedEncoding = true; + } + else + { + _encoding = Encoding.UTF32; + CompressBuffer(4); + changedEncoding = true; + } + } + else if (_byteLen >= 3 && firstTwoBytes == 0xBBEF && byteBuffer[2] == 0xBF) + { + // UTF-8 + _encoding = Encoding.UTF8; + CompressBuffer(3); + changedEncoding = true; + } + else if (_byteLen >= 4 && firstTwoBytes == 0 && byteBuffer[2] == 0xFE && byteBuffer[3] == 0xFF) + { + // Big Endian UTF32 + _encoding = new UTF32Encoding(bigEndian: true, byteOrderMark: true); + CompressBuffer(4); + changedEncoding = true; + } + else if (_byteLen == 2) + { + _detectEncoding = true; + } + // Note: in the future, if we change this algorithm significantly, + // we can support checking for the preamble of the given encoding. + + if (changedEncoding) + { + _decoder = _encoding.GetDecoder(); + int newMaxCharsPerBuffer = _encoding.GetMaxCharCount(byteBuffer.Length); + if (newMaxCharsPerBuffer > _maxCharsPerBuffer) + { + _charBuffer = new char[newMaxCharsPerBuffer]; + } + _maxCharsPerBuffer = newMaxCharsPerBuffer; + } + } + + // Trims the preamble bytes from the byteBuffer. This routine can be called multiple times + // and we will buffer the bytes read until the preamble is matched or we determine that + // there is no match. If there is no match, every byte read previously will be available + // for further consumption. If there is a match, we will compress the buffer for the + // leading preamble bytes + private bool IsPreamble() + { + if (!_checkPreamble) + { + return false; + } + + return IsPreambleWorker(); // move this call out of the hot path + bool IsPreambleWorker() + { + Debug.Assert(_checkPreamble); + ReadOnlySpan preamble = _encodingPreamble; + + Debug.Assert(_bytePos < preamble.Length, "_compressPreamble was called with the current bytePos greater than the preamble buffer length. Are two threads using this CancellableStreamReader at the same time?"); + int len = Math.Min(_byteLen, preamble.Length); + + for (int i = _bytePos; i < len; i++) + { + if (_byteBuffer[i] != preamble[i]) + { + _bytePos = 0; // preamble match failed; back up to beginning of buffer + _checkPreamble = false; + return false; + } + } + _bytePos = len; // we've matched all bytes up to this point + + Debug.Assert(_bytePos <= preamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + + if (_bytePos == preamble.Length) + { + // We have a match + CompressBuffer(preamble.Length); + _bytePos = 0; + _checkPreamble = false; + _detectEncoding = false; + } + + return _checkPreamble; + } + } + + internal virtual int ReadBuffer() + { + _charLen = 0; + _charPos = 0; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + bool eofReached = false; + + do + { + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int len = _stream.Read(_byteBuffer, _bytePos, _byteBuffer.Length - _bytePos); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + eofReached = true; + break; + } + + _byteLen += len; + } + else + { + Debug.Assert(_bytePos == 0, "bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + _byteLen = _stream.Read(_byteBuffer, 0, _byteBuffer.Length); + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (_byteLen == 0) + { + eofReached = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change byteLen. + _isBlocked = (_byteLen < _byteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + if (IsPreamble()) + { + continue; + } + + // If we're supposed to detect the encoding and haven't done so yet, + // do it. Note this may need to be called more than once. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + } + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be trying to decode more data if we made progress in an earlier iteration."); + _charLen = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: false); + } while (_charLen == 0); + + if (eofReached) + { + // EOF has been reached - perform final flush. + // We need to reset _bytePos and _byteLen just in case we hadn't + // finished processing the preamble before we reached EOF. + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be looking for EOF unless we have an empty char buffer."); + _charLen = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: true); + _bytePos = 0; + _byteLen = 0; + } + + return _charLen; + } + + + // This version has a perf optimization to decode data DIRECTLY into the + // user's buffer, bypassing CancellableStreamReader's own buffer. + // This gives a > 20% perf improvement for our encodings across the board, + // but only when asking for at least the number of characters that one + // buffer's worth of bytes could produce. + // This optimization, if run, will break SwitchEncoding, so we must not do + // this on the first call to ReadBuffer. + private int ReadBuffer(Span userBuffer, out bool readToUserBuffer) + { + _charLen = 0; + _charPos = 0; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + bool eofReached = false; + int charsRead = 0; + + // As a perf optimization, we can decode characters DIRECTLY into a + // user's char[]. We absolutely must not write more characters + // into the user's buffer than they asked for. Calculating + // encoding.GetMaxCharCount(byteLen) each time is potentially very + // expensive - instead, cache the number of chars a full buffer's + // worth of data may produce. Yes, this makes the perf optimization + // less aggressive, in that all reads that asked for fewer than AND + // returned fewer than _maxCharsPerBuffer chars won't get the user + // buffer optimization. This affects reads where the end of the + // Stream comes in the middle somewhere, and when you ask for + // fewer chars than your buffer could produce. + readToUserBuffer = userBuffer.Length >= _maxCharsPerBuffer; + + do + { + Debug.Assert(charsRead == 0); + + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int len = _stream.Read(_byteBuffer, _bytePos, _byteBuffer.Length - _bytePos); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + eofReached = true; + break; + } + + _byteLen += len; + } + else + { + Debug.Assert(_bytePos == 0, "bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + _byteLen = _stream.Read(_byteBuffer, 0, _byteBuffer.Length); + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (_byteLen == 0) + { + eofReached = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change byteLen. + _isBlocked = (_byteLen < _byteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + // Note: we don't need to recompute readToUserBuffer optimization as IsPreamble + // doesn't change the encoding or affect _maxCharsPerBuffer + if (IsPreamble()) + { + continue; + } + + // On the first call to ReadBuffer, if we're supposed to detect the encoding, do it. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + // DetectEncoding changes some buffer state. Recompute this. + readToUserBuffer = userBuffer.Length >= _maxCharsPerBuffer; + } + + Debug.Assert(charsRead == 0 && _charPos == 0 && _charLen == 0, "We shouldn't be trying to decode more data if we made progress in an earlier iteration."); + if (readToUserBuffer) + { + charsRead = GetChars(_decoder, new ReadOnlySpan(_byteBuffer, 0, _byteLen), userBuffer, flush: false); + } + else + { + charsRead = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: false); + _charLen = charsRead; // Number of chars in CancellableStreamReader's buffer. + } + } while (charsRead == 0); + + if (eofReached) + { + // EOF has been reached - perform final flush. + // We need to reset _bytePos and _byteLen just in case we hadn't + // finished processing the preamble before we reached EOF. + + Debug.Assert(charsRead == 0 && _charPos == 0 && _charLen == 0, "We shouldn't be looking for EOF unless we have an empty char buffer."); + + if (readToUserBuffer) + { + charsRead = GetChars(_decoder, new ReadOnlySpan(_byteBuffer, 0, _byteLen), userBuffer, flush: true); + } + else + { + charsRead = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: true); + _charLen = charsRead; // Number of chars in CancellableStreamReader's buffer. + } + _bytePos = 0; + _byteLen = 0; + } + + _isBlocked &= charsRead < userBuffer.Length; + + return charsRead; + } + + + // Reads a line. A line is defined as a sequence of characters followed by + // a carriage return ('\r'), a line feed ('\n'), or a carriage return + // immediately followed by a line feed. The resulting string does not + // contain the terminating carriage return and/or line feed. The returned + // value is null if the end of the input stream has been reached. + // + public override string? ReadLine() + { + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (_charPos == _charLen) + { + if (ReadBuffer() == 0) + { + return null; + } + } + + var vsb = new ValueStringBuilder(stackalloc char[256]); + do + { + // Look for '\r' or \'n'. + ReadOnlySpan charBufferSpan = _charBuffer.AsSpan(_charPos, _charLen - _charPos); + Debug.Assert(!charBufferSpan.IsEmpty, "ReadBuffer returned > 0 but didn't bump _charLen?"); + + int idxOfNewline = charBufferSpan.IndexOfAny('\r', '\n'); + if (idxOfNewline >= 0) + { + string retVal; + if (vsb.Length == 0) + { + retVal = charBufferSpan.Slice(0, idxOfNewline).ToString(); + } + else + { + retVal = string.Concat(vsb.AsSpan().ToString(), charBufferSpan.Slice(0, idxOfNewline).ToString()); + vsb.Dispose(); + } + + char matchedChar = charBufferSpan[idxOfNewline]; + _charPos += idxOfNewline + 1; + + // If we found '\r', consume any immediately following '\n'. + if (matchedChar == '\r') + { + if (_charPos < _charLen || ReadBuffer() > 0) + { + if (_charBuffer[_charPos] == '\n') + { + _charPos++; + } + } + } + + return retVal; + } + + // We didn't find '\r' or '\n'. Add it to the StringBuilder + // and loop until we reach a newline or EOF. + + vsb.Append(charBufferSpan); + } while (ReadBuffer() > 0); + + return vsb.ToString(); + } + + public override Task ReadLineAsync() => + ReadLineAsync(default).AsTask(); + + /// + /// Reads a line of characters asynchronously from the current stream and returns the data as a string. + /// + /// The token to monitor for cancellation requests. + /// A value task that represents the asynchronous read operation. The value of the TResult + /// parameter contains the next line from the stream, or is if all of the characters have been read. + /// The number of characters in the next line is larger than . + /// The stream reader has been disposed. + /// The reader is currently in use by a previous read operation. + /// + /// The following example shows how to read and print all lines from the file until the end of the file is reached or the operation timed out. + /// + /// using CancellationTokenSource tokenSource = new (TimeSpan.FromSeconds(1)); + /// using CancellableStreamReader reader = File.OpenText("existingfile.txt"); + /// + /// string line; + /// while ((line = await reader.ReadLineAsync(tokenSource.Token)) is not null) + /// { + /// Console.WriteLine(line); + /// } + /// + /// + /// + /// If this method is canceled via , some data + /// that has been read from the current but not stored (by the + /// ) or returned (to the caller) may be lost. + /// + public virtual ValueTask ReadLineAsync(CancellationToken cancellationToken) + { + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return new ValueTask(base.ReadLineAsync()!); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = ReadLineAsyncInternal(cancellationToken); + _asyncReadTask = task; + + return new ValueTask(task); + } + + private async Task ReadLineAsyncInternal(CancellationToken cancellationToken) + { + if (_charPos == _charLen && (await ReadBufferAsync(cancellationToken).ConfigureAwait(false)) == 0) + { + return null; + } + + string retVal; + char[]? arrayPoolBuffer = null; + int arrayPoolBufferPos = 0; + + do + { + char[] charBuffer = _charBuffer; + int charLen = _charLen; + int charPos = _charPos; + + // Look for '\r' or \'n'. + Debug.Assert(charPos < charLen, "ReadBuffer returned > 0 but didn't bump _charLen?"); + + int idxOfNewline = charBuffer.AsSpan(charPos, charLen - charPos).IndexOfAny('\r', '\n'); + if (idxOfNewline >= 0) + { + if (arrayPoolBuffer is null) + { + retVal = new string(charBuffer, charPos, idxOfNewline); + } + else + { + retVal = string.Concat(arrayPoolBuffer.AsSpan(0, arrayPoolBufferPos).ToString(), charBuffer.AsSpan(charPos, idxOfNewline).ToString()); + ArrayPool.Shared.Return(arrayPoolBuffer); + } + + charPos += idxOfNewline; + char matchedChar = charBuffer[charPos++]; + _charPos = charPos; + + // If we found '\r', consume any immediately following '\n'. + if (matchedChar == '\r') + { + if (charPos < charLen || (await ReadBufferAsync(cancellationToken).ConfigureAwait(false)) > 0) + { + if (_charBuffer[_charPos] == '\n') + { + _charPos++; + } + } + } + + return retVal; + } + + // We didn't find '\r' or '\n'. Add the read data to the pooled buffer + // and loop until we reach a newline or EOF. + if (arrayPoolBuffer is null) + { + arrayPoolBuffer = ArrayPool.Shared.Rent(charLen - charPos + 80); + } + else if ((arrayPoolBuffer.Length - arrayPoolBufferPos) < (charLen - charPos)) + { + char[] newBuffer = ArrayPool.Shared.Rent(checked(arrayPoolBufferPos + charLen - charPos)); + arrayPoolBuffer.AsSpan(0, arrayPoolBufferPos).CopyTo(newBuffer); + ArrayPool.Shared.Return(arrayPoolBuffer); + arrayPoolBuffer = newBuffer; + } + charBuffer.AsSpan(charPos, charLen - charPos).CopyTo(arrayPoolBuffer.AsSpan(arrayPoolBufferPos)); + arrayPoolBufferPos += charLen - charPos; + } + while (await ReadBufferAsync(cancellationToken).ConfigureAwait(false) > 0); + + if (arrayPoolBuffer is not null) + { + retVal = new string(arrayPoolBuffer, 0, arrayPoolBufferPos); + ArrayPool.Shared.Return(arrayPoolBuffer); + } + else + { + retVal = string.Empty; + } + + return retVal; + } + + public override Task ReadToEndAsync() => ReadToEndAsync(default); + + /// + /// Reads all characters from the current position to the end of the stream asynchronously and returns them as one string. + /// + /// The token to monitor for cancellation requests. + /// A task that represents the asynchronous read operation. The value of the TResult parameter contains + /// a string with the characters from the current position to the end of the stream. + /// The number of characters is larger than . + /// The stream reader has been disposed. + /// The reader is currently in use by a previous read operation. + /// + /// The following example shows how to read the contents of a file by using the method. + /// + /// using CancellationTokenSource tokenSource = new (TimeSpan.FromSeconds(1)); + /// using CancellableStreamReader reader = File.OpenText("existingfile.txt"); + /// + /// Console.WriteLine(await reader.ReadToEndAsync(tokenSource.Token)); + /// + /// + /// + /// If this method is canceled via , some data + /// that has been read from the current but not stored (by the + /// ) or returned (to the caller) may be lost. + /// + public virtual Task ReadToEndAsync(CancellationToken cancellationToken) + { + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return base.ReadToEndAsync(); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = ReadToEndAsyncInternal(cancellationToken); + _asyncReadTask = task; + + return task; + } + + private async Task ReadToEndAsyncInternal(CancellationToken cancellationToken) + { + // Call ReadBuffer, then pull data out of charBuffer. + StringBuilder sb = new StringBuilder(_charLen - _charPos); + do + { + int tmpCharPos = _charPos; + sb.Append(_charBuffer, tmpCharPos, _charLen - tmpCharPos); + _charPos = _charLen; // We consumed these characters + await ReadBufferAsync(cancellationToken).ConfigureAwait(false); + } while (_charLen > 0); + + return sb.ToString(); + } + + public override Task ReadAsync(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + + Throw.IfNegative(index); + Throw.IfNegative(count); + if (buffer.Length - index < count) + { + throw new ArgumentException("invalid offset length."); + } + + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return base.ReadAsync(buffer, index, count); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = ReadAsyncInternal(new Memory(buffer, index, count), CancellationToken.None).AsTask(); + _asyncReadTask = task; + + return task; + } + + public virtual ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + Debug.Assert(GetType() == typeof(CancellableStreamReader)); + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + return ReadAsyncInternal(buffer, cancellationToken); + } + + private protected virtual async ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) + { + if (_charPos == _charLen && (await ReadBufferAsync(cancellationToken).ConfigureAwait(false)) == 0) + { + return 0; + } + + int charsRead = 0; + + // As a perf optimization, if we had exactly one buffer's worth of + // data read in, let's try writing directly to the user's buffer. + bool readToUserBuffer = false; + + byte[] tmpByteBuffer = _byteBuffer; + Stream tmpStream = _stream; + + int count = buffer.Length; + while (count > 0) + { + // n is the characters available in _charBuffer + int n = _charLen - _charPos; + + // charBuffer is empty, let's read from the stream + if (n == 0) + { + _charLen = 0; + _charPos = 0; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + readToUserBuffer = count >= _maxCharsPerBuffer; + + // We loop here so that we read in enough bytes to yield at least 1 char. + // We break out of the loop if the stream is blocked (EOF is reached). + do + { + Debug.Assert(n == 0); + + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int tmpBytePos = _bytePos; + int len = await tmpStream.ReadAsync(new Memory(tmpByteBuffer, tmpBytePos, tmpByteBuffer.Length - tmpBytePos), cancellationToken).ConfigureAwait(false); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + // EOF but we might have buffered bytes from previous + // attempts to detect preamble that needs to be decoded now + if (_byteLen > 0) + { + if (readToUserBuffer) + { + n = GetChars(_decoder, new ReadOnlySpan(tmpByteBuffer, 0, _byteLen), buffer.Span.Slice(charsRead), flush: false); + _charLen = 0; // CancellableStreamReader's buffer is empty. + } + else + { + n = _decoder.GetChars(tmpByteBuffer, 0, _byteLen, _charBuffer, 0); + _charLen += n; // Number of chars in CancellableStreamReader's buffer. + } + } + + // How can part of the preamble yield any chars? + Debug.Assert(n == 0); + + _isBlocked = true; + break; + } + else + { + _byteLen += len; + } + } + else + { + Debug.Assert(_bytePos == 0, "_bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + + _byteLen = await tmpStream.ReadAsync(new Memory(tmpByteBuffer), cancellationToken).ConfigureAwait(false); + + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (_byteLen == 0) // EOF + { + _isBlocked = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change _byteLen. + _isBlocked = (_byteLen < tmpByteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + // Note: we don't need to recompute readToUserBuffer optimization as IsPreamble + // doesn't change the encoding or affect _maxCharsPerBuffer + if (IsPreamble()) + { + continue; + } + + // On the first call to ReadBuffer, if we're supposed to detect the encoding, do it. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + // DetectEncoding changes some buffer state. Recompute this. + readToUserBuffer = count >= _maxCharsPerBuffer; + } + + Debug.Assert(n == 0); + + _charPos = 0; + if (readToUserBuffer) + { + n = GetChars(_decoder, new ReadOnlySpan(tmpByteBuffer, 0, _byteLen), buffer.Span.Slice(charsRead), flush: false); + _charLen = 0; // CancellableStreamReader's buffer is empty. + } + else + { + n = _decoder.GetChars(tmpByteBuffer, 0, _byteLen, _charBuffer, 0); + _charLen += n; // Number of chars in CancellableStreamReader's buffer. + } + } while (n == 0); + + if (n == 0) + { + break; // We're at EOF + } + } // if (n == 0) + + // Got more chars in charBuffer than the user requested + if (n > count) + { + n = count; + } + + if (!readToUserBuffer) + { + new Span(_charBuffer, _charPos, n).CopyTo(buffer.Span.Slice(charsRead)); + _charPos += n; + } + + charsRead += n; + count -= n; + + // This function shouldn't block for an indefinite amount of time, + // or reading from a network stream won't work right. If we got + // fewer bytes than we requested, then we want to break right here. + if (_isBlocked) + { + break; + } + } // while (count > 0) + + return charsRead; + } + + public override Task ReadBlockAsync(char[] buffer, int index, int count) + { + Throw.IfNull(buffer); + + Throw.IfNegative(index); + Throw.IfNegative(count); + if (buffer.Length - index < count) + { + throw new ArgumentException("invalid offset length."); + } + + // If we have been inherited into a subclass, the following implementation could be incorrect + // since it does not call through to Read() which a subclass might have overridden. + // To be safe we will only use this implementation in cases where we know it is safe to do so, + // and delegate to our base class (which will call into Read) when we are not sure. + if (GetType() != typeof(CancellableStreamReader)) + { + return base.ReadBlockAsync(buffer, index, count); + } + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + Task task = base.ReadBlockAsync(buffer, index, count); + _asyncReadTask = task; + + return task; + } + + public virtual ValueTask ReadBlockAsync(Memory buffer, CancellationToken cancellationToken = default) + { + Debug.Assert(GetType() == typeof(CancellableStreamReader)); + + ThrowIfDisposed(); + CheckAsyncTaskInProgress(); + + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + + ValueTask vt = ReadBlockAsyncInternal(buffer, cancellationToken); + if (vt.IsCompletedSuccessfully) + { + return vt; + } + + Task t = vt.AsTask(); + _asyncReadTask = t; + return new ValueTask(t); + } + + private async ValueTask ReadBufferAsync(CancellationToken cancellationToken) + { + _charLen = 0; + _charPos = 0; + byte[] tmpByteBuffer = _byteBuffer; + Stream tmpStream = _stream; + + if (!_checkPreamble) + { + _byteLen = 0; + } + + bool eofReached = false; + + do + { + if (_checkPreamble) + { + Debug.Assert(_bytePos <= _encodingPreamble.Length, "possible bug in _compressPreamble. Are two threads using this CancellableStreamReader at the same time?"); + int tmpBytePos = _bytePos; + int len = await tmpStream.ReadAsync(tmpByteBuffer.AsMemory(tmpBytePos), cancellationToken).ConfigureAwait(false); + Debug.Assert(len >= 0, "Stream.Read returned a negative number! This is a bug in your stream class."); + + if (len == 0) + { + eofReached = true; + break; + } + + _byteLen += len; + } + else + { + Debug.Assert(_bytePos == 0, "_bytePos can be non zero only when we are trying to _checkPreamble. Are two threads using this CancellableStreamReader at the same time?"); + _byteLen = await tmpStream.ReadAsync(new Memory(tmpByteBuffer), cancellationToken).ConfigureAwait(false); + Debug.Assert(_byteLen >= 0, "Stream.Read returned a negative number! Bug in stream class."); + + if (_byteLen == 0) + { + eofReached = true; + break; + } + } + + // _isBlocked == whether we read fewer bytes than we asked for. + // Note we must check it here because CompressBuffer or + // DetectEncoding will change _byteLen. + _isBlocked = (_byteLen < tmpByteBuffer.Length); + + // Check for preamble before detect encoding. This is not to override the + // user supplied Encoding for the one we implicitly detect. The user could + // customize the encoding which we will loose, such as ThrowOnError on UTF8 + if (IsPreamble()) + { + continue; + } + + // If we're supposed to detect the encoding and haven't done so yet, + // do it. Note this may need to be called more than once. + if (_detectEncoding && _byteLen >= 2) + { + DetectEncoding(); + } + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be trying to decode more data if we made progress in an earlier iteration."); + _charLen = _decoder.GetChars(tmpByteBuffer, 0, _byteLen, _charBuffer, 0, flush: false); + } while (_charLen == 0); + + if (eofReached) + { + // EOF has been reached - perform final flush. + // We need to reset _bytePos and _byteLen just in case we hadn't + // finished processing the preamble before we reached EOF. + + Debug.Assert(_charPos == 0 && _charLen == 0, "We shouldn't be looking for EOF unless we have an empty char buffer."); + _charLen = _decoder.GetChars(_byteBuffer, 0, _byteLen, _charBuffer, 0, flush: true); + _bytePos = 0; + _byteLen = 0; + } + + return _charLen; + } + + private async ValueTask ReadBlockAsyncInternal(Memory buffer, CancellationToken cancellationToken) + { + int n = 0, i; + do + { + i = await ReadAsyncInternal(buffer.Slice(n), cancellationToken).ConfigureAwait(false); + n += i; + } while (i > 0 && n < buffer.Length); + + return n; + } + + private static unsafe int GetChars(Decoder decoder, ReadOnlySpan bytes, Span chars, bool flush = false) + { + Throw.IfNull(decoder); + if (decoder is null || bytes.IsEmpty || chars.IsEmpty) + { + return 0; + } + + fixed (byte* pBytes = bytes) + fixed (char* pChars = chars) + { + return decoder.GetChars(pBytes, bytes.Length, pChars, chars.Length, flush); + } + } + + private void ThrowIfDisposed() + { + if (_disposed) + { + ThrowObjectDisposedException(); + } + + void ThrowObjectDisposedException() => throw new ObjectDisposedException(GetType().Name, "reader has been closed."); + } + + // No data, class doesn't need to be serializable. + // Note this class is threadsafe. + internal sealed class NullCancellableStreamReader : CancellableStreamReader + { + public override Encoding CurrentEncoding => Encoding.Unicode; + + protected override void Dispose(bool disposing) + { + // Do nothing - this is essentially unclosable. + } + + public override int Peek() => -1; + + public override int Read() => -1; + + public override int Read(char[] buffer, int index, int count) => 0; + + public override Task ReadAsync(char[] buffer, int index, int count) => Task.FromResult(0); + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + public override int ReadBlock(char[] buffer, int index, int count) => 0; + + public override Task ReadBlockAsync(char[] buffer, int index, int count) => Task.FromResult(0); + + public override ValueTask ReadBlockAsync(Memory buffer, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + public override string? ReadLine() => null; + + public override Task ReadLineAsync() => Task.FromResult(null); + + public override ValueTask ReadLineAsync(CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + public override string ReadToEnd() => ""; + + public override Task ReadToEndAsync() => Task.FromResult(""); + + public override Task ReadToEndAsync(CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? Task.FromCanceled(cancellationToken) : Task.FromResult(""); + + private protected override ValueTask ReadAsyncInternal(Memory buffer, CancellationToken cancellationToken) => + cancellationToken.IsCancellationRequested ? new ValueTask(Task.FromCanceled(cancellationToken)) : default; + + internal override int ReadBuffer() => 0; + } +} \ No newline at end of file diff --git a/src/Common/CancellableStreamReader/TextReaderExtensions.cs b/src/Common/CancellableStreamReader/TextReaderExtensions.cs new file mode 100644 index 00000000..11ba0565 --- /dev/null +++ b/src/Common/CancellableStreamReader/TextReaderExtensions.cs @@ -0,0 +1,15 @@ +namespace System.IO; + +internal static class TextReaderExtensions +{ + public static ValueTask ReadLineAsync(this TextReader reader, CancellationToken cancellationToken) + { + if (reader is CancellableStreamReader cancellableReader) + { + return cancellableReader.ReadLineAsync(cancellationToken)!; + } + + cancellationToken.ThrowIfCancellationRequested(); + return new ValueTask(reader.ReadLineAsync()); + } +} \ No newline at end of file diff --git a/src/Common/CancellableStreamReader/ValueStringBuilder.cs b/src/Common/CancellableStreamReader/ValueStringBuilder.cs new file mode 100644 index 00000000..27bea693 --- /dev/null +++ b/src/Common/CancellableStreamReader/ValueStringBuilder.cs @@ -0,0 +1,317 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +#nullable enable + +namespace System.Text +{ + internal ref partial struct ValueStringBuilder + { + private char[]? _arrayToReturnToPool; + private Span _chars; + private int _pos; + + public ValueStringBuilder(Span initialBuffer) + { + _arrayToReturnToPool = null; + _chars = initialBuffer; + _pos = 0; + } + + public ValueStringBuilder(int initialCapacity) + { + _arrayToReturnToPool = ArrayPool.Shared.Rent(initialCapacity); + _chars = _arrayToReturnToPool; + _pos = 0; + } + + public int Length + { + get => _pos; + set + { + Debug.Assert(value >= 0); + Debug.Assert(value <= _chars.Length); + _pos = value; + } + } + + public int Capacity => _chars.Length; + + public void EnsureCapacity(int capacity) + { + // This is not expected to be called this with negative capacity + Debug.Assert(capacity >= 0); + + // If the caller has a bug and calls this with negative capacity, make sure to call Grow to throw an exception. + if ((uint)capacity > (uint)_chars.Length) + Grow(capacity - _pos); + } + + /// + /// Get a pinnable reference to the builder. + /// Does not ensure there is a null char after + /// This overload is pattern matched in the C# 7.3+ compiler so you can omit + /// the explicit method call, and write eg "fixed (char* c = builder)" + /// + public ref char GetPinnableReference() + { + return ref MemoryMarshal.GetReference(_chars); + } + + /// + /// Get a pinnable reference to the builder. + /// + /// Ensures that the builder has a null char after + public ref char GetPinnableReference(bool terminate) + { + if (terminate) + { + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; + } + return ref MemoryMarshal.GetReference(_chars); + } + + public ref char this[int index] + { + get + { + Debug.Assert(index < _pos); + return ref _chars[index]; + } + } + + public override string ToString() + { + string s = _chars.Slice(0, _pos).ToString(); + Dispose(); + return s; + } + + /// Returns the underlying storage of the builder. + public Span RawChars => _chars; + + /// + /// Returns a span around the contents of the builder. + /// + /// Ensures that the builder has a null char after + public ReadOnlySpan AsSpan(bool terminate) + { + if (terminate) + { + EnsureCapacity(Length + 1); + _chars[Length] = '\0'; + } + return _chars.Slice(0, _pos); + } + + public ReadOnlySpan AsSpan() => _chars.Slice(0, _pos); + public ReadOnlySpan AsSpan(int start) => _chars.Slice(start, _pos - start); + public ReadOnlySpan AsSpan(int start, int length) => _chars.Slice(start, length); + + public bool TryCopyTo(Span destination, out int charsWritten) + { + if (_chars.Slice(0, _pos).TryCopyTo(destination)) + { + charsWritten = _pos; + Dispose(); + return true; + } + else + { + charsWritten = 0; + Dispose(); + return false; + } + } + + public void Insert(int index, char value, int count) + { + if (_pos > _chars.Length - count) + { + Grow(count); + } + + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + _chars.Slice(index, count).Fill(value); + _pos += count; + } + + public void Insert(int index, string? s) + { + if (s == null) + { + return; + } + + int count = s.Length; + + if (_pos > (_chars.Length - count)) + { + Grow(count); + } + + int remaining = _pos - index; + _chars.Slice(index, remaining).CopyTo(_chars.Slice(index + count)); + s +#if !NET + .AsSpan() +#endif + .CopyTo(_chars.Slice(index)); + _pos += count; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(char c) + { + int pos = _pos; + Span chars = _chars; + if ((uint)pos < (uint)chars.Length) + { + chars[pos] = c; + _pos = pos + 1; + } + else + { + GrowAndAppend(c); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Append(string? s) + { + if (s == null) + { + return; + } + + int pos = _pos; + if (s.Length == 1 && (uint)pos < (uint)_chars.Length) // very common case, e.g. appending strings from NumberFormatInfo like separators, percent symbols, etc. + { + _chars[pos] = s[0]; + _pos = pos + 1; + } + else + { + AppendSlow(s); + } + } + + private void AppendSlow(string s) + { + int pos = _pos; + if (pos > _chars.Length - s.Length) + { + Grow(s.Length); + } + + s +#if !NET + .AsSpan() +#endif + .CopyTo(_chars.Slice(pos)); + _pos += s.Length; + } + + public void Append(char c, int count) + { + if (_pos > _chars.Length - count) + { + Grow(count); + } + + Span dst = _chars.Slice(_pos, count); + for (int i = 0; i < dst.Length; i++) + { + dst[i] = c; + } + _pos += count; + } + + public void Append(scoped ReadOnlySpan value) + { + int pos = _pos; + if (pos > _chars.Length - value.Length) + { + Grow(value.Length); + } + + value.CopyTo(_chars.Slice(_pos)); + _pos += value.Length; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span AppendSpan(int length) + { + int origPos = _pos; + if (origPos > _chars.Length - length) + { + Grow(length); + } + + _pos = origPos + length; + return _chars.Slice(origPos, length); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void GrowAndAppend(char c) + { + Grow(1); + Append(c); + } + + /// + /// Resize the internal buffer either by doubling current buffer size or + /// by adding to + /// whichever is greater. + /// + /// + /// Number of chars requested beyond current position. + /// + [MethodImpl(MethodImplOptions.NoInlining)] + private void Grow(int additionalCapacityBeyondPos) + { + Debug.Assert(additionalCapacityBeyondPos > 0); + Debug.Assert(_pos > _chars.Length - additionalCapacityBeyondPos, "Grow called incorrectly, no resize is needed."); + + const uint ArrayMaxLength = 0x7FFFFFC7; // same as Array.MaxLength + + // Increase to at least the required size (_pos + additionalCapacityBeyondPos), but try + // to double the size if possible, bounding the doubling to not go beyond the max array length. + int newCapacity = (int)Math.Max( + (uint)(_pos + additionalCapacityBeyondPos), + Math.Min((uint)_chars.Length * 2, ArrayMaxLength)); + + // Make sure to let Rent throw an exception if the caller has a bug and the desired capacity is negative. + // This could also go negative if the actual required length wraps around. + char[] poolArray = ArrayPool.Shared.Rent(newCapacity); + + _chars.Slice(0, _pos).CopyTo(poolArray); + + char[]? toReturn = _arrayToReturnToPool; + _chars = _arrayToReturnToPool = poolArray; + if (toReturn != null) + { + ArrayPool.Shared.Return(toReturn); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Dispose() + { + char[]? toReturn = _arrayToReturnToPool; + this = default; // for safety, to avoid using pooled array if this instance is erroneously appended to again + if (toReturn != null) + { + ArrayPool.Shared.Return(toReturn); + } + } + } +} \ No newline at end of file diff --git a/src/Common/Polyfills/System/IO/StreamExtensions.cs b/src/Common/Polyfills/System/IO/StreamExtensions.cs index d58ffaf3..4dc8e2a5 100644 --- a/src/Common/Polyfills/System/IO/StreamExtensions.cs +++ b/src/Common/Polyfills/System/IO/StreamExtensions.cs @@ -1,6 +1,7 @@ using ModelContextProtocol; using System.Buffers; using System.Runtime.InteropServices; +using System.Text; namespace System.IO; @@ -33,4 +34,31 @@ static async ValueTask WriteAsyncCore(Stream stream, ReadOnlyMemory buffer } } } + + public static ValueTask ReadAsync(this Stream stream, Memory buffer, CancellationToken cancellationToken) + { + Throw.IfNull(stream); + if (MemoryMarshal.TryGetArray(buffer, out ArraySegment segment)) + { + return new ValueTask(stream.ReadAsync(segment.Array, segment.Offset, segment.Count, cancellationToken)); + } + else + { + return ReadAsyncCore(stream, buffer, cancellationToken); + static async ValueTask ReadAsyncCore(Stream stream, Memory buffer, CancellationToken cancellationToken) + { + byte[] array = ArrayPool.Shared.Rent(buffer.Length); + try + { + int bytesRead = await stream.ReadAsync(array, 0, buffer.Length, cancellationToken).ConfigureAwait(false); + array.AsSpan(0, bytesRead).CopyTo(buffer.Span); + return bytesRead; + } + finally + { + ArrayPool.Shared.Return(array); + } + } + } + } } \ No newline at end of file diff --git a/src/Common/Polyfills/System/IO/TextReaderExtensions.cs b/src/Common/Polyfills/System/IO/TextReaderExtensions.cs deleted file mode 100644 index 63b3db25..00000000 --- a/src/Common/Polyfills/System/IO/TextReaderExtensions.cs +++ /dev/null @@ -1,10 +0,0 @@ -namespace System.IO; - -internal static class TextReaderExtensions -{ - public static Task ReadLineAsync(this TextReader reader, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - return reader.ReadLineAsync(); - } -} \ No newline at end of file diff --git a/src/Common/Throw.cs b/src/Common/Throw.cs index 0c6927e5..ed80036f 100644 --- a/src/Common/Throw.cs +++ b/src/Common/Throw.cs @@ -25,6 +25,15 @@ public static void IfNullOrWhiteSpace([NotNull] string? arg, [CallerArgumentExpr } } + public static void IfNegative(int arg, [CallerArgumentExpression(nameof(arg))] string? parameterName = null) + { + if (arg < 0) + { + Throw(parameterName); + static void Throw(string? parameterName) => throw new ArgumentOutOfRangeException(parameterName, "must not be negative."); + } + } + [DoesNotReturn] private static void ThrowArgumentNullOrWhiteSpaceException(string? parameterName) { diff --git a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs index 3c7210ec..2ce32cb7 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientSessionTransport.cs @@ -13,7 +13,7 @@ internal sealed class StdioClientSessionTransport : StreamClientSessionTransport private int _cleanedUp = 0; public StdioClientSessionTransport(StdioClientTransportOptions options, Process process, string endpointName, Queue stderrRollingLog, ILoggerFactory? loggerFactory) - : base(process.StandardInput, process.StandardOutput, endpointName, loggerFactory) + : base(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory) { _process = process; _options = options; diff --git a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs index e00ddbab..c026acb9 100644 --- a/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StdioClientTransport.cs @@ -69,8 +69,6 @@ public async Task ConnectAsync(CancellationToken cancellationToken = { LogTransportConnecting(logger, endpointName); - UTF8Encoding noBomUTF8 = new(encoderShouldEmitUTF8Identifier: false); - ProcessStartInfo startInfo = new() { FileName = command, @@ -80,10 +78,10 @@ public async Task ConnectAsync(CancellationToken cancellationToken = UseShellExecute = false, CreateNoWindow = true, WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, - StandardOutputEncoding = noBomUTF8, - StandardErrorEncoding = noBomUTF8, + StandardOutputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding, + StandardErrorEncoding = StreamClientSessionTransport.NoBomUtf8Encoding, #if NET - StandardInputEncoding = noBomUTF8, + StandardInputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding, #endif }; @@ -164,7 +162,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = Encoding originalInputEncoding = Console.InputEncoding; try { - Console.InputEncoding = noBomUTF8; + Console.InputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding; processStarted = process.Start(); } finally diff --git a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs index e35e2b18..dfcccf61 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientSessionTransport.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; +using System.Text; using System.Text.Json; namespace ModelContextProtocol.Client; @@ -7,6 +8,8 @@ namespace ModelContextProtocol.Client; /// Provides the client side of a stream-based session transport. internal class StreamClientSessionTransport : TransportBase { + internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false); + private readonly TextReader _serverOutput; private readonly TextWriter _serverInput; private readonly SemaphoreSlim _sendLock = new(1, 1); @@ -54,6 +57,43 @@ public StreamClientSessionTransport( readTask.Start(); } + /// + /// Initializes a new instance of the class. + /// + /// + /// The server's input stream. Messages written to this stream will be sent to the server. + /// + /// + /// The server's output stream. Messages read from this stream will be received from the server. + /// + /// + /// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null. + /// + /// + /// A name that identifies this transport endpoint in logs. + /// + /// + /// Optional factory for creating loggers. If null, a NullLogger will be used. + /// + /// + /// This constructor starts a background task to read messages from the server output stream. + /// The transport will be marked as connected once initialized. + /// + public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory) + : this( + new StreamWriter(serverInput, encoding ?? NoBomUtf8Encoding), +#if NET + new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding), +#else + new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding), +#endif + endpointName, + loggerFactory) + { + Throw.IfNull(serverInput); + Throw.IfNull(serverOutput); + } + /// public override async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { diff --git a/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs b/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs index 30607e57..a0e335be 100644 --- a/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamClientTransport.cs @@ -47,8 +47,9 @@ public StreamClientTransport( public Task ConnectAsync(CancellationToken cancellationToken = default) { return Task.FromResult(new StreamClientSessionTransport( - new StreamWriter(_serverInput), - new StreamReader(_serverOutput), + _serverInput, + _serverOutput, + encoding: null, "Client (stream)", _loggerFactory)); } diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj index 928c76d9..f3ab7181 100644 --- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj +++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj @@ -7,6 +7,7 @@ ModelContextProtocol.Core Core .NET SDK for the Model Context Protocol (MCP) README.md + True @@ -20,6 +21,7 @@ + diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs b/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs index f891858e..a6188773 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrimitiveCollection.cs @@ -17,7 +17,7 @@ public class McpServerPrimitiveCollection : ICollection, IReadOnlyCollecti /// public McpServerPrimitiveCollection(IEqualityComparer? keyComparer = null) { - _primitives = new(keyComparer); + _primitives = new(keyComparer ?? EqualityComparer.Default); } /// Occurs when the collection is changed. diff --git a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs index 9528e4f4..915d7813 100644 --- a/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamServerTransport.cs @@ -46,7 +46,11 @@ public StreamServerTransport(Stream inputStream, Stream outputStream, string? se _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; +#if NET _inputReader = new StreamReader(inputStream, Encoding.UTF8); +#else + _inputReader = new CancellableStreamReader(inputStream, Encoding.UTF8); +#endif _outputStream = outputStream; SetConnected(); diff --git a/tests/Common/Utils/MockHttpHandler.cs b/tests/Common/Utils/MockHttpHandler.cs index 5e58a6cd..d15ec3dc 100644 --- a/tests/Common/Utils/MockHttpHandler.cs +++ b/tests/Common/Utils/MockHttpHandler.cs @@ -1,4 +1,6 @@ -namespace ModelContextProtocol.Tests.Utils; +using System.Net.Http; + +namespace ModelContextProtocol.Tests.Utils; public class MockHttpHandler : HttpMessageHandler { diff --git a/tests/Common/Utils/ProcessExtensions.cs b/tests/Common/Utils/ProcessExtensions.cs new file mode 100644 index 00000000..186ecc9c --- /dev/null +++ b/tests/Common/Utils/ProcessExtensions.cs @@ -0,0 +1,15 @@ +namespace System.Diagnostics; + +public static class ProcessExtensions +{ + public static async Task WaitForExitAsync(this Process process, TimeSpan timeout) + { +#if NET + using var shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + shutdownCts.CancelAfter(timeout); + await process.WaitForExitAsync(shutdownCts.Token); +#else + process.WaitForExit(milliseconds: (int)timeout.TotalMilliseconds); +#endif + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj index 75de837a..f38a3585 100644 --- a/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj +++ b/tests/ModelContextProtocol.TestServer/ModelContextProtocol.TestServer.csproj @@ -15,6 +15,7 @@ + diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index dec5ad05..ebc7171e 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -25,14 +25,14 @@ public ClientIntegrationTestFixture() TestServerTransportOptions = new() { - Command = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "TestServer.exe" : "dotnet", + Command = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "TestServer.exe" : PlatformDetection.IsMonoRuntime ? "mono" : "dotnet", Name = "TestServer", }; if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { // Change to Arguments to "mcp-server-everything" if you want to run the server locally after creating a symlink - TestServerTransportOptions.Arguments = ["TestServer.dll"]; + TestServerTransportOptions.Arguments = [PlatformDetection.IsMonoRuntime ? "TestServer.exe" : "TestServer.dll"]; } } diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 38c688cc..d2080e1f 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -606,7 +606,7 @@ public async Task HandlesIProgressParameter() McpClientTool progressTool = tools.First(t => t.Name == "sends_progress_notifications"); - TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + TaskCompletionSource tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); int remainingNotifications = 10; ConcurrentQueue notifications = new(); @@ -618,7 +618,7 @@ public async Task HandlesIProgressParameter() notifications.Enqueue(pn); if (Interlocked.Decrement(ref remainingNotifications) == 0) { - tcs.SetResult(); + tcs.SetResult(true); } } diff --git a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs index ffd8859a..7a019c89 100644 --- a/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs +++ b/tests/ModelContextProtocol.Tests/EverythingSseServerFixture.cs @@ -49,7 +49,7 @@ public async ValueTask DisposeAsync() using var stopProcess = Process.Start(stopInfo) ?? throw new InvalidOperationException($"Could not stop process for {stopInfo.FileName} with '{stopInfo.Arguments}'."); - await stopProcess.WaitForExitAsync(); + await stopProcess.WaitForExitAsync(TimeSpan.FromSeconds(10)); } catch (Exception ex) { @@ -60,6 +60,7 @@ public async ValueTask DisposeAsync() private static bool CheckIsDockerAvailable() { +#if NET try { ProcessStartInfo processStartInfo = new() @@ -78,5 +79,9 @@ private static bool CheckIsDockerAvailable() { return false; } +#else + // Do not run docker tests using .NET framework. + return false; +#endif } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/GlobalUsings.cs b/tests/ModelContextProtocol.Tests/GlobalUsings.cs index c802f448..6d129626 100644 --- a/tests/ModelContextProtocol.Tests/GlobalUsings.cs +++ b/tests/ModelContextProtocol.Tests/GlobalUsings.cs @@ -1 +1,2 @@ global using Xunit; +global using System.Net.Http; \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj index 6ddad70f..993564bf 100644 --- a/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj +++ b/tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj @@ -2,7 +2,7 @@ Exe - net9.0;net8.0 + net9.0;net8.0;net472 enable enable @@ -27,6 +27,11 @@ + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -44,7 +49,9 @@ + + @@ -60,16 +67,10 @@ - - PreserveNewest - - - PreserveNewest - - + PreserveNewest - + PreserveNewest diff --git a/tests/ModelContextProtocol.Tests/PlatformDetection.cs b/tests/ModelContextProtocol.Tests/PlatformDetection.cs new file mode 100644 index 00000000..1eef9942 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/PlatformDetection.cs @@ -0,0 +1,6 @@ +namespace ModelContextProtocol.Tests; + +internal static class PlatformDetection +{ + public static bool IsMonoRuntime { get; } = Type.GetType("Mono.Runtime") is not null; +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs index 97b63157..30675f7b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerDelegatesTests.cs @@ -1,10 +1,18 @@ using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Runtime.InteropServices; namespace ModelContextProtocol.Tests.Server; public class McpServerHandlerTests { + public McpServerHandlerTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void AllPropertiesAreSettable() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs index 7cdbdb5b..b2e74873 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs @@ -1,11 +1,19 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Runtime.InteropServices; namespace ModelContextProtocol.Tests.Server; public class McpServerLoggingLevelTests { + public McpServerLoggingLevelTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void CanCreateServerWithLoggingLevelHandler() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 90998e24..39e9b72f 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -7,6 +7,7 @@ using System.ComponentModel; using System.Diagnostics; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -14,6 +15,13 @@ namespace ModelContextProtocol.Tests.Server; public class McpServerPromptTests { + public McpServerPromptTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void Create_InvalidArgs_Throws() { diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index fb0772d0..011c4f2b 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -4,12 +4,20 @@ using ModelContextProtocol.Server; using Moq; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json.Serialization; namespace ModelContextProtocol.Tests.Server; public partial class McpServerResourceTests { + public McpServerResourceTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void CanCreateServerWithResource() { @@ -191,6 +199,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e411.2.3.4", ((TextResourceContents)result.Contents[0]).Text); +#if NET t = McpServerResource.Create((Half a2, Int128 a3, UInt128 a4, IntPtr a5) => (a3 + (Int128)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( @@ -206,6 +215,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); +#endif t = McpServerResource.Create((bool? a2, char? a3, byte? a4, sbyte? a5) => a2?.ToString() + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); @@ -239,6 +249,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e41", ((TextResourceContents)result.Contents[0]).Text); +#if NET t = McpServerResource.Create((Half? a2, Int128? a3, UInt128? a4, IntPtr? a5) => (a3 + (Int128?)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( @@ -254,6 +265,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); +#endif } [Theory] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 260b9bdd..6750b2ca 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -3,6 +3,7 @@ using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Utils; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; @@ -15,6 +16,9 @@ public class McpServerTests : LoggedTest public McpServerTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif _options = CreateOptions(); } @@ -212,6 +216,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Initialize_Requests() { + AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(IMcpServer).Assembly).GetName(); await Can_Handle_Requests( serverCapabilities: null, method: RequestMethods.Initialize, @@ -220,8 +225,8 @@ await Can_Handle_Requests( { var result = JsonSerializer.Deserialize(response, McpJsonUtilities.DefaultOptions); Assert.NotNull(result); - Assert.Equal("ModelContextProtocol.Tests", result.ServerInfo.Name); - Assert.Equal("1.0.0.0", result.ServerInfo.Version); + Assert.Equal(expectedAssemblyName.Name, result.ServerInfo.Name); + Assert.Equal(expectedAssemblyName.Version?.ToString() ?? "1.0.0", result.ServerInfo.Version); Assert.Equal("2024", result.ProtocolVersion); }); } @@ -518,10 +523,10 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s }; await transport.SendMessageAsync( - new JsonRpcRequest - { - Method = method, - Id = new RequestId(55) + new JsonRpcRequest + { + Method = method, + Id = new RequestId(55) } ); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 0f67f2a5..5cc6fa78 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -7,6 +7,7 @@ using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; +using System.Runtime.InteropServices; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; @@ -17,6 +18,13 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerToolTests { + public McpServerToolTests() + { +#if !NET + Assert.SkipWhen(RuntimeInformation.IsOSPlatform(OSPlatform.Windows), "https://github.com/modelcontextprotocol/csharp-sdk/issues/587"); +#endif + } + [Fact] public void Create_InvalidArgs_Throws() { @@ -525,7 +533,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(tool.ProtocolTool.OutputSchema); Assert.Null(result.StructuredContent); - tool = McpServerTool.Create(() => ValueTask.CompletedTask); + tool = McpServerTool.Create(() => default(ValueTask)); request = new RequestContext(mockServer.Object) { Params = new CallToolRequestParams { Name = "tool" }, diff --git a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs index db22ec24..f3927be6 100644 --- a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Tests; public class StdioServerIntegrationTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { - public static bool CanSendSigInt { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX); + public static bool CanSendSigInt { get; } = (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) && !PlatformDetection.IsMonoRuntime; private const int SIGINT = 2; [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(CanSendSigInt))] @@ -46,9 +46,7 @@ public async Task SigInt_DisposesTestServerWithHosting_Gracefully() // https://github.com/dotnet/runtime/issues/109432, https://github.com/dotnet/runtime/issues/44944 Assert.Equal(0, kill(process.Id, SIGINT)); - using var shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - shutdownCts.CancelAfter(TimeSpan.FromSeconds(10)); - await process.WaitForExitAsync(shutdownCts.Token); + await process.WaitForExitAsync(TimeSpan.FromSeconds(10)); Assert.True(process.HasExited); Assert.Equal(0, process.ExitCode); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs index 416d1719..b4954278 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs @@ -15,10 +15,18 @@ public async Task Can_Customize_MessageEndpoint() var transportRunTask = transport.RunAsync(TestContext.Current.CancellationToken); using var responseStreamReader = new StreamReader(responsePipe.Reader.AsStream()); - var firstLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + var firstLine = await responseStreamReader.ReadLineAsync( +#if NET + TestContext.Current.CancellationToken +#endif + ); Assert.Equal("event: endpoint", firstLine); - var secondLine = await responseStreamReader.ReadLineAsync(TestContext.Current.CancellationToken); + var secondLine = await responseStreamReader.ReadLineAsync( +#if NET + TestContext.Current.CancellationToken +#endif + ); Assert.Equal("data: /my-message-endpoint", secondLine); responsePipe.Reader.Complete(); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index 40602a9e..93cbcec8 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -7,6 +7,8 @@ namespace ModelContextProtocol.Tests.Transport; public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { + public static bool IsStdErrCallbackSupported => !PlatformDetection.IsMonoRuntime; + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() { @@ -19,8 +21,8 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.Contains(id, e.ToString()); } - - [Fact] + + [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() { string id = Guid.NewGuid().ToString("N");