From 8515502da45d4e17d50e44070f1b85f555e4fc89 Mon Sep 17 00:00:00 2001 From: mariam-abdulla Date: Fri, 4 Oct 2024 14:46:01 +0200 Subject: [PATCH] Fix WriteSize() and WriteString() methods --- .../IPC/Serializers/BaseSerializer.cs | 66 ++++++++++++++++--- .../CommandLineOptionMessagesSerializer.cs | 6 +- .../DiscoveredTestMessagesSerializer.cs | 6 +- .../FileArtifactMessagesSerializer.cs | 14 ++-- .../TestResultMessagesSerializer.cs | 22 +++---- .../Serializers/TestSessionEventSerializer.cs | 6 +- 6 files changed, 85 insertions(+), 35 deletions(-) diff --git a/src/Platform/Microsoft.Testing.Platform/IPC/Serializers/BaseSerializer.cs b/src/Platform/Microsoft.Testing.Platform/IPC/Serializers/BaseSerializer.cs index 47190f69e4..013add7468 100644 --- a/src/Platform/Microsoft.Testing.Platform/IPC/Serializers/BaseSerializer.cs +++ b/src/Platform/Microsoft.Testing.Platform/IPC/Serializers/BaseSerializer.cs @@ -42,6 +42,24 @@ protected static string ReadString(Stream stream) } } + protected static string ReadStringValue(Stream stream, int size) + { + byte[] bytes = ArrayPool.Shared.Rent(size); + try + { +#if NET7_0_OR_GREATER + stream.ReadExactly(bytes, 0, size); +#else + _ = stream.Read(bytes, 0, size); +#endif + return Encoding.UTF8.GetString(bytes, 0, size); + } + finally + { + ArrayPool.Shared.Return(bytes); + } + } + protected static void WriteString(Stream stream, string str) { int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str); @@ -61,6 +79,21 @@ protected static void WriteString(Stream stream, string str) } } + protected static void WriteStringValue(Stream stream, string str) + { + int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str); + byte[] bytes = ArrayPool.Shared.Rent(stringutf8TotalBytes); + try + { + Encoding.UTF8.GetBytes(str, bytes); + stream.Write(bytes, 0, stringutf8TotalBytes); + } + finally + { + ArrayPool.Shared.Return(bytes); + } + } + protected static void WriteStringSize(Stream stream, string str) { int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str); @@ -71,9 +104,10 @@ protected static void WriteStringSize(Stream stream, string str) stream.Write(len); } - protected static void WriteSize(Stream stream) + protected static void WriteSize(Stream stream) + where T : struct { - int sizeInBytes = GetSize(); + int sizeInBytes = GetSize(); Span len = stackalloc byte[sizeof(int)]; ApplicationStateGuard.Ensure(BitConverter.TryWriteBytes(len, sizeInBytes), PlatformResources.UnexpectedExceptionDuringByteConversionErrorMessage); @@ -169,6 +203,14 @@ protected static string ReadString(Stream stream) return Encoding.UTF8.GetString(bytes); } + protected static string ReadStringValue(Stream stream, int size) + { + byte[] bytes = new byte[size]; + _ = stream.Read(bytes, 0, bytes.Length); + + return Encoding.UTF8.GetString(bytes); + } + protected static void WriteString(Stream stream, string str) { byte[] bytes = Encoding.UTF8.GetBytes(str); @@ -177,6 +219,12 @@ protected static void WriteString(Stream stream, string str) stream.Write(bytes, 0, bytes.Length); } + protected static void WriteStringValue(Stream stream, string str) + { + byte[] bytes = Encoding.UTF8.GetBytes(str); + stream.Write(bytes, 0, bytes.Length); + } + protected static void WriteStringSize(Stream stream, string str) { byte[] bytes = Encoding.UTF8.GetBytes(str); @@ -184,11 +232,12 @@ protected static void WriteStringSize(Stream stream, string str) stream.Write(len, 0, len.Length); } - protected static void WriteSize(Stream stream) + protected static void WriteSize(Stream stream) + where T : struct { - int sizeInBytes = GetSize(); + int sizeInBytes = GetSize(); byte[] len = BitConverter.GetBytes(sizeInBytes); - stream.Write(len, 0, sizeInBytes); + stream.Write(len, 0, len.Length); } protected static void WriteInt(Stream stream, int value) @@ -257,7 +306,7 @@ protected static void WriteField(Stream stream, ushort id, string? value) WriteShort(stream, id); WriteStringSize(stream, value); - WriteString(stream, value); + WriteStringValue(stream, value); } protected static void WriteField(Stream stream, string? value) @@ -288,7 +337,7 @@ protected static void WriteField(Stream stream, ushort id, bool? value) } WriteShort(stream, id); - WriteSize(stream); + WriteSize(stream); WriteBool(stream, value.Value); } @@ -300,7 +349,7 @@ protected static void WriteField(Stream stream, ushort id, byte? value) } WriteShort(stream, id); - WriteSize(stream); + WriteSize(stream); WriteByte(stream, value.Value); } @@ -320,6 +369,7 @@ protected static void WriteAtPosition(Stream stream, int value, long position) Type type when type == typeof(long) => sizeof(long), Type type when type == typeof(short) => sizeof(short), Type type when type == typeof(bool) => sizeof(bool), + Type type when type == typeof(byte) => sizeof(byte), _ => 0, }; diff --git a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/CommandLineOptionMessagesSerializer.cs b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/CommandLineOptionMessagesSerializer.cs index 3bbf3d6804..a96255e664 100644 --- a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/CommandLineOptionMessagesSerializer.cs +++ b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/CommandLineOptionMessagesSerializer.cs @@ -55,7 +55,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case CommandLineOptionMessagesFieldsId.ModulePath: - moduleName = ReadString(stream); + moduleName = ReadStringValue(stream, fieldSize); break; case CommandLineOptionMessagesFieldsId.CommandLineOptionMessageList: @@ -92,11 +92,11 @@ private static List ReadCommandLineOptionMessagesPaylo switch (fieldId) { case CommandLineOptionMessageFieldsId.Name: - name = ReadString(stream); + name = ReadStringValue(stream, fieldSize); break; case CommandLineOptionMessageFieldsId.Description: - description = ReadString(stream); + description = ReadStringValue(stream, fieldSize); break; case CommandLineOptionMessageFieldsId.IsHidden: diff --git a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/DiscoveredTestMessagesSerializer.cs b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/DiscoveredTestMessagesSerializer.cs index da2fa36ab4..487733b49a 100644 --- a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/DiscoveredTestMessagesSerializer.cs +++ b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/DiscoveredTestMessagesSerializer.cs @@ -47,7 +47,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case DiscoveredTestMessagesFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; case DiscoveredTestMessagesFieldsId.DiscoveredTestMessageList: @@ -83,11 +83,11 @@ private static List ReadDiscoveredTestMessagesPayload(Str switch (fieldId) { case DiscoveredTestMessageFieldsId.Uid: - uid = ReadString(stream); + uid = ReadStringValue(stream, fieldSize); break; case DiscoveredTestMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; default: diff --git a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/FileArtifactMessagesSerializer.cs b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/FileArtifactMessagesSerializer.cs index b05069f510..d9601be217 100644 --- a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/FileArtifactMessagesSerializer.cs +++ b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/FileArtifactMessagesSerializer.cs @@ -63,7 +63,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case FileArtifactMessagesFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; case FileArtifactMessagesFieldsId.FileArtifactMessageList: @@ -99,27 +99,27 @@ private static List ReadFileArtifactMessagesPayload(Stream switch (fieldId) { case FileArtifactMessageFieldsId.FullPath: - fullPath = ReadString(stream); + fullPath = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.Description: - description = ReadString(stream); + description = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.TestUid: - testUid = ReadString(stream); + testUid = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.TestDisplayName: - testDisplayName = ReadString(stream); + testDisplayName = ReadStringValue(stream, fieldSize); break; case FileArtifactMessageFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; default: diff --git a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestResultMessagesSerializer.cs b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestResultMessagesSerializer.cs index 796b934f25..ef94f25bab 100644 --- a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestResultMessagesSerializer.cs +++ b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestResultMessagesSerializer.cs @@ -95,7 +95,7 @@ public object Deserialize(Stream stream) switch (fieldId) { case TestResultMessagesFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; case TestResultMessagesFieldsId.SuccessfulTestMessageList: @@ -139,11 +139,11 @@ private static List ReadSuccessfulTestMessagesPaylo switch (fieldId) { case SuccessfulTestResultMessageFieldsId.Uid: - uid = ReadString(stream); + uid = ReadStringValue(stream, fieldSize); break; case SuccessfulTestResultMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; case SuccessfulTestResultMessageFieldsId.State: @@ -151,11 +151,11 @@ private static List ReadSuccessfulTestMessagesPaylo break; case SuccessfulTestResultMessageFieldsId.Reason: - reason = ReadString(stream); + reason = ReadStringValue(stream, fieldSize); break; case SuccessfulTestResultMessageFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; default: @@ -190,11 +190,11 @@ private static List ReadFailedTestMessagesPayload(Strea switch (fieldId) { case FailedTestResultMessageFieldsId.Uid: - uid = ReadString(stream); + uid = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.DisplayName: - displayName = ReadString(stream); + displayName = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.State: @@ -202,19 +202,19 @@ private static List ReadFailedTestMessagesPayload(Strea break; case FailedTestResultMessageFieldsId.Reason: - reason = ReadString(stream); + reason = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.ErrorMessage: - errorMessage = ReadString(stream); + errorMessage = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.ErrorStackTrace: - errorStackTrace = ReadString(stream); + errorStackTrace = ReadStringValue(stream, fieldSize); break; case FailedTestResultMessageFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; default: diff --git a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestSessionEventSerializer.cs b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestSessionEventSerializer.cs index 1c283f6370..2c5afdefa1 100644 --- a/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestSessionEventSerializer.cs +++ b/src/Platform/Microsoft.Testing.Platform/ServerMode/DotnetTest/IPC/Serializers/TestSessionEventSerializer.cs @@ -35,7 +35,7 @@ public object Deserialize(Stream stream) for (int i = 0; i < fieldCount; i++) { - int fieldId = ReadShort(stream); + ushort fieldId = ReadShort(stream); int fieldSize = ReadInt(stream); switch (fieldId) @@ -45,11 +45,11 @@ public object Deserialize(Stream stream) break; case TestSessionEventFieldsId.SessionUid: - sessionUid = ReadString(stream); + sessionUid = ReadStringValue(stream, fieldSize); break; case TestSessionEventFieldsId.ExecutionId: - executionId = ReadString(stream); + executionId = ReadStringValue(stream, fieldSize); break; default: