diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs index e3425b253..f6c8f3a92 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ResumabilityIntegrationTests.cs @@ -279,6 +279,7 @@ public async Task Client_CanResumePostResponseStream_AfterDisconnection() [Fact] public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() { + var timeout = TimeSpan.FromSeconds(10); using var faultingStreamHandler = new FaultingStreamHandler() { InnerHandler = SocketsHttpHandler, @@ -304,12 +305,12 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() await using var client = await ConnectClientAsync(); // Get the server instance - var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + var server = await serverTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Set up notification tracking with unique messages - var clientReceivedInitialNotificationTcs = new TaskCompletionSource(); - var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(); - var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(); + var clientReceivedInitialNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); const string CustomNotificationMethod = "test/custom_notification"; const string InitialMessage = "Initial notification"; @@ -343,11 +344,14 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() return default; }); + // Wait for the client's unsolicited message stream to be established before sending notifications + await faultingStreamHandler.WaitForUnsolicitedMessageStreamAsync(TestContext.Current.CancellationToken); + // Send a custom notification to the client on the unsolicited message stream await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken); // Wait for client to receive the first notification - await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedInitialNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Fault the unsolicited message stream (GET SSE) var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken); @@ -359,13 +363,13 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection() reconnectAttempt.Continue(); // Wait for client to receive the notification via replay - await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedReplayedNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Send a final notification while the client has reconnected - this should be handled by the transport await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReconnectMessage }, cancellationToken: TestContext.Current.CancellationToken); // Wait for the client to receive the final notification - await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken); + await clientReceivedReconnectNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken); // Assert each notification was received exactly once Assert.Equal(1, initialNotificationReceivedCount); @@ -531,7 +535,7 @@ public async Task PostResponse_EndsAndSseEventStreamWriterIsDisposed_WhenWriteEv timeoutCts.CancelAfter(TimeSpan.FromSeconds(10)); // The call task should throw an OCE due to cancellation - await Assert.ThrowsAsync(() => callTask).WaitAsync(timeoutCts.Token); + await Assert.ThrowsAnyAsync(() => callTask).WaitAsync(timeoutCts.Token); // Wait for the writer to be disposed await blockingStore.DisposedTask.WaitAsync(timeoutCts.Token); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs index cace4d8be..dc157735f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/FaultingStreamHandler.cs @@ -11,6 +11,12 @@ internal sealed class FaultingStreamHandler : DelegatingHandler { private FaultingStream? _lastStream; private TaskCompletionSource? _reconnectTcs; + private TaskCompletionSource _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public Task WaitForUnsolicitedMessageStreamAsync(CancellationToken cancellationToken = default) + => _unsolicitedMessageStreamReadyTcs.Task.WaitAsync(cancellationToken); + + internal void SignalUnsolicitedMessageStreamReady() => _unsolicitedMessageStreamReadyTcs.TrySetResult(); public async Task TriggerFaultAsync(CancellationToken cancellationToken) { @@ -24,6 +30,9 @@ public async Task TriggerFaultAsync(CancellationToken cancella throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection."); } + // Reset the TCS so we can wait for the reconnected unsolicited message stream + _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + _reconnectTcs = new(); await _lastStream.TriggerFaultAsync(cancellationToken); @@ -46,6 +55,7 @@ protected override async Task SendAsync( _reconnectTcs = null; } + var isGetRequest = request.Method == HttpMethod.Get; var response = await base.SendAsync(request, cancellationToken); // Only wrap SSE streams (text/event-stream) @@ -63,6 +73,13 @@ protected override async Task SendAsync( } response.Content = newContent; + + // For GET requests (unsolicited message stream), set up the stream to signal + // when first data is read. This ensures the server's transport handler is ready. + if (isGetRequest) + { + _lastStream.SetReadyCallback(SignalUnsolicitedMessageStreamReady); + } } return response; @@ -89,10 +106,14 @@ private sealed class FaultingStream(Stream innerStream) : Stream { private readonly CancellationTokenSource _cts = new(); private TaskCompletionSource? _faultTcs; + private Action? _readyCallback; + private bool _readySignaled; private bool _disposed; public bool IsDisposed => _disposed; + public void SetReadyCallback(Action callback) => _readyCallback = callback; + public async Task TriggerFaultAsync(CancellationToken cancellationToken) { if (_faultTcs is not null) @@ -131,6 +152,12 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation _cts.Token.ThrowIfCancellationRequested(); + if (bytesRead > 0 && !_readySignaled) + { + _readySignaled = true; + _readyCallback?.Invoke(); + } + return bytesRead; } catch (OperationCanceledException) when (_cts.IsCancellationRequested)