Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WriteSize() and WriteString() methods #3913

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ protected static string ReadString(Stream stream)
}
}

protected static string ReadStringValue(Stream stream, int size)
{
byte[] bytes = ArrayPool<byte>.Shared.Rent(size);
try
{
#if NET7_0_OR_GREATER
stream.ReadExactly(bytes, 0, size);
#else
_ = stream.Read(bytes, 0, size);
#endif
return Encoding.UTF8.GetString(bytes, 0, size);
}
finally
{
ArrayPool<byte>.Shared.Return(bytes);
}
}

protected static void WriteString(Stream stream, string str)
{
int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str);
Expand All @@ -61,6 +79,21 @@ protected static void WriteString(Stream stream, string str)
}
}

protected static void WriteStringValue(Stream stream, string str)
{
int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str);
byte[] bytes = ArrayPool<byte>.Shared.Rent(stringutf8TotalBytes);
try
{
Encoding.UTF8.GetBytes(str, bytes);
stream.Write(bytes, 0, stringutf8TotalBytes);
}
finally
{
ArrayPool<byte>.Shared.Return(bytes);
}
}

protected static void WriteStringSize(Stream stream, string str)
{
int stringutf8TotalBytes = Encoding.UTF8.GetByteCount(str);
Expand All @@ -71,9 +104,10 @@ protected static void WriteStringSize(Stream stream, string str)
stream.Write(len);
}

protected static void WriteSize(Stream stream)
protected static void WriteSize<T>(Stream stream)
where T : struct
{
int sizeInBytes = GetSize<int>();
int sizeInBytes = GetSize<T>();
Span<byte> len = stackalloc byte[sizeof(int)];

ApplicationStateGuard.Ensure(BitConverter.TryWriteBytes(len, sizeInBytes), PlatformResources.UnexpectedExceptionDuringByteConversionErrorMessage);
Expand Down Expand Up @@ -169,6 +203,14 @@ protected static string ReadString(Stream stream)
return Encoding.UTF8.GetString(bytes);
}

protected static string ReadStringValue(Stream stream, int size)
{
byte[] bytes = new byte[size];
_ = stream.Read(bytes, 0, bytes.Length);

return Encoding.UTF8.GetString(bytes);
}

protected static void WriteString(Stream stream, string str)
{
byte[] bytes = Encoding.UTF8.GetBytes(str);
Expand All @@ -177,18 +219,25 @@ protected static void WriteString(Stream stream, string str)
stream.Write(bytes, 0, bytes.Length);
}

protected static void WriteStringValue(Stream stream, string str)
{
byte[] bytes = Encoding.UTF8.GetBytes(str);
stream.Write(bytes, 0, bytes.Length);
}

protected static void WriteStringSize(Stream stream, string str)
{
byte[] bytes = Encoding.UTF8.GetBytes(str);
byte[] len = BitConverter.GetBytes(bytes.Length);
stream.Write(len, 0, len.Length);
}

protected static void WriteSize(Stream stream)
protected static void WriteSize<T>(Stream stream)
where T : struct
{
int sizeInBytes = GetSize<int>();
int sizeInBytes = GetSize<T>();
byte[] len = BitConverter.GetBytes(sizeInBytes);
stream.Write(len, 0, sizeInBytes);
stream.Write(len, 0, len.Length);
}

protected static void WriteInt(Stream stream, int value)
Expand Down Expand Up @@ -257,7 +306,7 @@ protected static void WriteField(Stream stream, ushort id, string? value)

WriteShort(stream, id);
WriteStringSize(stream, value);
WriteString(stream, value);
WriteStringValue(stream, value);
}

protected static void WriteField(Stream stream, string? value)
Expand Down Expand Up @@ -288,7 +337,7 @@ protected static void WriteField(Stream stream, ushort id, bool? value)
}

WriteShort(stream, id);
WriteSize(stream);
WriteSize<bool>(stream);
WriteBool(stream, value.Value);
}

Expand All @@ -300,7 +349,7 @@ protected static void WriteField(Stream stream, ushort id, byte? value)
}

WriteShort(stream, id);
WriteSize(stream);
WriteSize<byte>(stream);
WriteByte(stream, value.Value);
}

Expand All @@ -320,6 +369,7 @@ protected static void WriteAtPosition(Stream stream, int value, long position)
Type type when type == typeof(long) => sizeof(long),
Type type when type == typeof(short) => sizeof(short),
Type type when type == typeof(bool) => sizeof(bool),
Type type when type == typeof(byte) => sizeof(byte),
_ => 0,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public object Deserialize(Stream stream)
switch (fieldId)
{
case CommandLineOptionMessagesFieldsId.ModulePath:
moduleName = ReadString(stream);
moduleName = ReadStringValue(stream, fieldSize);
break;

case CommandLineOptionMessagesFieldsId.CommandLineOptionMessageList:
Expand Down Expand Up @@ -92,11 +92,11 @@ private static List<CommandLineOptionMessage> ReadCommandLineOptionMessagesPaylo
switch (fieldId)
{
case CommandLineOptionMessageFieldsId.Name:
name = ReadString(stream);
name = ReadStringValue(stream, fieldSize);
break;

case CommandLineOptionMessageFieldsId.Description:
description = ReadString(stream);
description = ReadStringValue(stream, fieldSize);
break;

case CommandLineOptionMessageFieldsId.IsHidden:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public object Deserialize(Stream stream)
switch (fieldId)
{
case DiscoveredTestMessagesFieldsId.ExecutionId:
executionId = ReadString(stream);
executionId = ReadStringValue(stream, fieldSize);
break;

case DiscoveredTestMessagesFieldsId.DiscoveredTestMessageList:
Expand Down Expand Up @@ -83,11 +83,11 @@ private static List<DiscoveredTestMessage> ReadDiscoveredTestMessagesPayload(Str
switch (fieldId)
{
case DiscoveredTestMessageFieldsId.Uid:
uid = ReadString(stream);
uid = ReadStringValue(stream, fieldSize);
break;

case DiscoveredTestMessageFieldsId.DisplayName:
displayName = ReadString(stream);
displayName = ReadStringValue(stream, fieldSize);
break;

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public object Deserialize(Stream stream)
switch (fieldId)
{
case FileArtifactMessagesFieldsId.ExecutionId:
executionId = ReadString(stream);
executionId = ReadStringValue(stream, fieldSize);
break;

case FileArtifactMessagesFieldsId.FileArtifactMessageList:
Expand Down Expand Up @@ -99,27 +99,27 @@ private static List<FileArtifactMessage> ReadFileArtifactMessagesPayload(Stream
switch (fieldId)
{
case FileArtifactMessageFieldsId.FullPath:
fullPath = ReadString(stream);
fullPath = ReadStringValue(stream, fieldSize);
break;

case FileArtifactMessageFieldsId.DisplayName:
displayName = ReadString(stream);
displayName = ReadStringValue(stream, fieldSize);
break;

case FileArtifactMessageFieldsId.Description:
description = ReadString(stream);
description = ReadStringValue(stream, fieldSize);
break;

case FileArtifactMessageFieldsId.TestUid:
testUid = ReadString(stream);
testUid = ReadStringValue(stream, fieldSize);
break;

case FileArtifactMessageFieldsId.TestDisplayName:
testDisplayName = ReadString(stream);
testDisplayName = ReadStringValue(stream, fieldSize);
break;

case FileArtifactMessageFieldsId.SessionUid:
sessionUid = ReadString(stream);
sessionUid = ReadStringValue(stream, fieldSize);
break;

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public object Deserialize(Stream stream)
switch (fieldId)
{
case TestResultMessagesFieldsId.ExecutionId:
executionId = ReadString(stream);
executionId = ReadStringValue(stream, fieldSize);
break;

case TestResultMessagesFieldsId.SuccessfulTestMessageList:
Expand Down Expand Up @@ -139,23 +139,23 @@ private static List<SuccessfulTestResultMessage> ReadSuccessfulTestMessagesPaylo
switch (fieldId)
{
case SuccessfulTestResultMessageFieldsId.Uid:
uid = ReadString(stream);
uid = ReadStringValue(stream, fieldSize);
break;

case SuccessfulTestResultMessageFieldsId.DisplayName:
displayName = ReadString(stream);
displayName = ReadStringValue(stream, fieldSize);
break;

case SuccessfulTestResultMessageFieldsId.State:
state = ReadByte(stream);
break;

case SuccessfulTestResultMessageFieldsId.Reason:
reason = ReadString(stream);
reason = ReadStringValue(stream, fieldSize);
break;

case SuccessfulTestResultMessageFieldsId.SessionUid:
sessionUid = ReadString(stream);
sessionUid = ReadStringValue(stream, fieldSize);
break;

default:
Expand Down Expand Up @@ -190,31 +190,31 @@ private static List<FailedTestResultMessage> ReadFailedTestMessagesPayload(Strea
switch (fieldId)
{
case FailedTestResultMessageFieldsId.Uid:
uid = ReadString(stream);
uid = ReadStringValue(stream, fieldSize);
break;

case FailedTestResultMessageFieldsId.DisplayName:
displayName = ReadString(stream);
displayName = ReadStringValue(stream, fieldSize);
break;

case FailedTestResultMessageFieldsId.State:
state = ReadByte(stream);
break;

case FailedTestResultMessageFieldsId.Reason:
reason = ReadString(stream);
reason = ReadStringValue(stream, fieldSize);
break;

case FailedTestResultMessageFieldsId.ErrorMessage:
errorMessage = ReadString(stream);
errorMessage = ReadStringValue(stream, fieldSize);
break;

case FailedTestResultMessageFieldsId.ErrorStackTrace:
errorStackTrace = ReadString(stream);
errorStackTrace = ReadStringValue(stream, fieldSize);
break;

case FailedTestResultMessageFieldsId.SessionUid:
sessionUid = ReadString(stream);
sessionUid = ReadStringValue(stream, fieldSize);
break;

default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public object Deserialize(Stream stream)

for (int i = 0; i < fieldCount; i++)
{
int fieldId = ReadShort(stream);
ushort fieldId = ReadShort(stream);
int fieldSize = ReadInt(stream);

switch (fieldId)
Expand All @@ -45,11 +45,11 @@ public object Deserialize(Stream stream)
break;

case TestSessionEventFieldsId.SessionUid:
sessionUid = ReadString(stream);
sessionUid = ReadStringValue(stream, fieldSize);
break;

case TestSessionEventFieldsId.ExecutionId:
executionId = ReadString(stream);
executionId = ReadStringValue(stream, fieldSize);
break;

default:
Expand Down
Loading