diff --git a/NuGet.config b/NuGet.config
index c6d3a0ed1898b0..f04033fb6afa26 100644
--- a/NuGet.config
+++ b/NuGet.config
@@ -10,6 +10,7 @@
+
diff --git a/eng/Versions.props b/eng/Versions.props
index b3e6634bec545b..d6638df6bc6e53 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -1,13 +1,13 @@
- 9.0.1
+ 9.0.2
9
0
- 1
+ 2
9.0.100
- 8.0.11
+ 8.0.$([MSBuild]::Add($(PatchVersion),11))
7.0.20
6.0.36
servicing
@@ -164,7 +164,7 @@
1.0.0-prerelease.24462.2
2.0.0
- 17.10.0-beta1.24272.1
+ 17.12.0-beta1.24603.5
2.0.0-beta4.24324.3
3.1.7
2.1.0
diff --git a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs
index f82bbb96732c95..292a5eb1038d53 100644
--- a/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs
+++ b/src/libraries/System.Formats.Nrbf/ref/System.Formats.Nrbf.cs
@@ -9,7 +9,6 @@ namespace System.Formats.Nrbf
public abstract partial class ArrayRecord : System.Formats.Nrbf.SerializationRecord
{
internal ArrayRecord() { }
- public virtual long FlattenedLength { get { throw null; } }
public override System.Formats.Nrbf.SerializationRecordId Id { get { throw null; } }
public abstract System.ReadOnlySpan Lengths { get; }
public int Rank { get { throw null; } }
diff --git a/src/libraries/System.Formats.Nrbf/src/PACKAGE.md b/src/libraries/System.Formats.Nrbf/src/PACKAGE.md
index c301459358838b..23e5ac389d3d14 100644
--- a/src/libraries/System.Formats.Nrbf/src/PACKAGE.md
+++ b/src/libraries/System.Formats.Nrbf/src/PACKAGE.md
@@ -54,7 +54,7 @@ There are more than a dozen different serialization [record types](https://learn
- `PrimitiveTypeRecord` derives from the non-generic [PrimitiveTypeRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord), which also exposes a [Value](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.primitivetyperecord.value) property. But on the base class, the value is returned as `object` (which introduces boxing for value types).
- [ClassRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.classrecord): describes all `class` and `struct` besides the aforementioned primitive types.
- [ArrayRecord](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.arrayrecord): describes all array records, including jagged and multi-dimensional arrays.
-- [`SZArrayRecord`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `ClassRecord`.
+- [`SZArrayRecord`](https://learn.microsoft.com/dotnet/api/system.formats.nrbf.szarrayrecord-1): describes single-dimensional, zero-indexed array records, where `T` can be either a primitive type or a `SerializationRecord`.
```csharp
SerializationRecord rootObject = NrbfDecoder.Decode(payload); // payload is a Stream
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs
index 063a2430782064..60623ac0dbde3a 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs
@@ -28,12 +28,13 @@ internal enum AllowedRecordTypes : uint
ArraySingleString = 1 << SerializationRecordType.ArraySingleString,
Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple,
+ Arrays = ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray,
///
/// Any .NET object (a primitive, a reference type, a reference or single null).
///
AnyObject = MemberPrimitiveTyped
- | ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray
+ | Arrays
| ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes
| BinaryObjectString
| MemberReference
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs
index 237b7b72a27198..c18208668225f8 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRecord.cs
@@ -4,6 +4,9 @@
using System.Diagnostics.CodeAnalysis;
using System.Reflection.Metadata;
using System.Formats.Nrbf.Utils;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Runtime.Serialization;
namespace System.Formats.Nrbf;
@@ -27,12 +30,6 @@ private protected ArrayRecord(ArrayInfo arrayInfo)
/// A buffer of integers that represent the number of elements in every dimension.
public abstract ReadOnlySpan Lengths { get; }
- ///
- /// When overridden in a derived class, gets the total number of all elements in every dimension.
- ///
- /// A number that represent the total number of all elements in every dimension.
- public virtual long FlattenedLength => ArrayInfo.FlattenedLength;
-
///
/// Gets the rank of the array.
///
@@ -118,4 +115,86 @@ private void HandleNext(object value, NextInfo info, int size)
}
internal abstract (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType();
+
+ internal static void Populate(List source, Array destination, int[] lengths, AllowedRecordTypes allowedRecordTypes, bool allowNulls)
+ {
+ int[] indices = new int[lengths.Length];
+ nuint numElementsWritten = 0; // only for debugging; not used in release builds
+
+ foreach (SerializationRecord record in source)
+ {
+ object? value = GetActualValue(record, allowedRecordTypes, out int incrementCount);
+ if (value is not null)
+ {
+ // null is a default element for all array of reference types, so we don't call SetValue for nulls.
+ destination.SetValue(value, indices);
+ Debug.Assert(incrementCount == 1, "IncrementCount other than 1 is allowed only for null records.");
+ }
+ else if (!allowNulls)
+ {
+ ThrowHelper.ThrowArrayContainedNulls();
+ }
+
+ while (incrementCount > 0)
+ {
+ incrementCount--;
+ numElementsWritten++;
+ int dimension = indices.Length - 1;
+ while (dimension >= 0)
+ {
+ indices[dimension]++;
+ if (indices[dimension] < lengths[dimension])
+ {
+ break;
+ }
+ indices[dimension] = 0;
+ dimension--;
+ }
+
+ if (dimension < 0)
+ {
+ break;
+ }
+ }
+ }
+
+ Debug.Assert(numElementsWritten == (uint)source.Count, "We should have traversed the entirety of the source records collection.");
+ Debug.Assert(numElementsWritten == (ulong)destination.LongLength, "We should have traversed the entirety of the destination array.");
+ }
+
+ private static object? GetActualValue(SerializationRecord record, AllowedRecordTypes allowedRecordTypes, out int repeatCount)
+ {
+ repeatCount = 1;
+
+ if (record is NullsRecord nullsRecord)
+ {
+ repeatCount = nullsRecord.NullCount;
+ return null;
+ }
+ else if (record.RecordType == SerializationRecordType.MemberReference)
+ {
+ record = ((MemberReferenceRecord)record).GetReferencedRecord();
+ }
+
+ if (allowedRecordTypes == AllowedRecordTypes.BinaryObjectString)
+ {
+ if (record is not BinaryObjectStringRecord stringRecord)
+ {
+ throw new SerializationException(SR.Serialization_InvalidReference);
+ }
+
+ return stringRecord.Value;
+ }
+ else if (allowedRecordTypes == AllowedRecordTypes.Arrays)
+ {
+ if (record is not ArrayRecord arrayRecord)
+ {
+ throw new SerializationException(SR.Serialization_InvalidReference);
+ }
+
+ return arrayRecord;
+ }
+
+ return record;
+ }
}
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs
new file mode 100644
index 00000000000000..39c66c5f2af0d9
--- /dev/null
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayRectangularPrimitiveRecord.cs
@@ -0,0 +1,83 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Formats.Nrbf.Utils;
+using System.Linq;
+using System.Reflection.Metadata;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace System.Formats.Nrbf
+{
+ internal sealed class ArrayRectangularPrimitiveRecord : ArrayRecord where T : unmanaged
+ {
+ private readonly int[] _lengths;
+ private readonly IReadOnlyList _values;
+ private TypeName? _typeName;
+
+ internal ArrayRectangularPrimitiveRecord(ArrayInfo arrayInfo, int[] lengths, IReadOnlyList values) : base(arrayInfo)
+ {
+ _lengths = lengths;
+ _values = values;
+ ValuesToRead = 0; // there is nothing to read anymore
+ }
+
+ public override ReadOnlySpan Lengths => _lengths;
+
+ public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
+
+ public override TypeName TypeName
+ => _typeName ??= TypeNameHelpers.GetPrimitiveTypeName(TypeNameHelpers.GetPrimitiveType()).MakeArrayTypeName(Rank);
+
+ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() => throw new InvalidOperationException();
+
+ private protected override void AddValue(object value) => throw new InvalidOperationException();
+
+ [RequiresDynamicCode("May call Array.CreateInstance().")]
+ private protected override Array Deserialize(Type arrayType, bool allowNulls)
+ {
+ Array result =
+#if NET9_0_OR_GREATER
+ Array.CreateInstanceFromArrayType(arrayType, _lengths);
+#else
+ Array.CreateInstance(typeof(T), _lengths);
+#endif
+ int[] indices = new int[_lengths.Length];
+ nuint numElementsWritten = 0; // only for debugging; not used in release builds
+
+ for (int i = 0; i < _values.Count; i++)
+ {
+ result.SetValue(_values[i], indices);
+ numElementsWritten++;
+
+ int dimension = indices.Length - 1;
+ while (dimension >= 0)
+ {
+ indices[dimension]++;
+ if (indices[dimension] < Lengths[dimension])
+ {
+ break;
+ }
+ indices[dimension] = 0;
+ dimension--;
+ }
+
+ if (dimension < 0)
+ {
+ break;
+ }
+ }
+
+ Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection.");
+ Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array.");
+
+ return result;
+ }
+ }
+}
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs
index d0276ff3782e3a..2c402af7c35ab0 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs
@@ -15,9 +15,9 @@ namespace System.Formats.Nrbf;
///
/// ArraySingleObject records are described in [MS-NRBF] 2.4.3.2 .
///
-internal sealed class ArraySingleObjectRecord : SZArrayRecord
+internal sealed class ArraySingleObjectRecord : SZArrayRecord
{
- private ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = [];
+ internal ArraySingleObjectRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = [];
public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleObject;
@@ -27,25 +27,26 @@ public override TypeName TypeName
private List Records { get; }
///
- public override object?[] GetArray(bool allowNulls = true)
- => (object?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));
+ public override SerializationRecord?[] GetArray(bool allowNulls = true)
+ => (SerializationRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));
- private object?[] ToArray(bool allowNulls)
+ private SerializationRecord?[] ToArray(bool allowNulls)
{
- object?[] values = new object?[Length];
+ SerializationRecord?[] values = new SerializationRecord?[Length];
int valueIndex = 0;
for (int recordIndex = 0; recordIndex < Records.Count; recordIndex++)
{
SerializationRecord record = Records[recordIndex];
- int nullCount = record is NullsRecord nullsRecord ? nullsRecord.NullCount : 0;
- if (nullCount == 0)
+ if (record is MemberReferenceRecord referenceRecord)
{
- // "new object[] { }" is special cased because it allows for storing reference to itself.
- values[valueIndex++] = record is MemberReferenceRecord referenceRecord && referenceRecord.Reference.Equals(Id)
- ? values // a reference to self, and a way to get StackOverflow exception ;)
- : record.GetValue();
+ record = referenceRecord.GetReferencedRecord();
+ }
+
+ if (record is not NullsRecord nullsRecord)
+ {
+ values[valueIndex++] = record;
continue;
}
@@ -54,6 +55,7 @@ public override TypeName TypeName
ThrowHelper.ThrowArrayContainedNulls();
}
+ int nullCount = nullsRecord.NullCount;
do
{
values[valueIndex++] = null;
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs
index a13507b97015a0..a28359d9bb13dc 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySinglePrimitiveRecord.cs
@@ -47,6 +47,11 @@ public override T[] GetArray(bool allowNulls = true)
internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int count)
{
+ if (count == 0)
+ {
+ return Array.Empty(); // Empty arrays are allowed.
+ }
+
// For decimals, the input is provided as strings, so we can't compute the required size up-front.
if (typeof(T) == typeof(decimal))
{
@@ -71,18 +76,15 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c
// allocations to be proportional to the amount of data present in the input stream,
// which is a sufficient defense against DoS.
- long requiredBytes = count;
- if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan))
- {
- // We can't assume DateTime as represented by the runtime is 8 bytes.
- // The only assumption we can make is that it's 8 bytes on the wire.
- requiredBytes *= 8;
- }
- else if (typeof(T) != typeof(char))
- {
- requiredBytes *= Unsafe.SizeOf();
- }
+ // We can't assume DateTime as represented by the runtime is 8 bytes.
+ // The only assumption we can make is that it's 8 bytes on the wire.
+ int sizeOfT = typeof(T) == typeof(DateTime) || typeof(T) == typeof(TimeSpan)
+ ? 8
+ : typeof(T) != typeof(char)
+ ? Unsafe.SizeOf()
+ : 1;
+ long requiredBytes = (long)count * sizeOfT;
bool? isDataAvailable = reader.IsDataAvailable(requiredBytes);
if (!isDataAvailable.HasValue)
{
@@ -110,26 +112,49 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c
// It's safe to pre-allocate, as we have ensured there is enough bytes in the stream.
T[] result = new T[count];
- Span resultAsBytes = MemoryMarshal.AsBytes(result);
-#if NET
- reader.BaseStream.ReadExactly(resultAsBytes);
+
+ // MemoryMarshal.AsBytes can fail for inputs that need more than int.MaxValue bytes.
+ // To avoid OverflowException, we read the data in chunks.
+ int MaxChunkLength =
+#if !DEBUG
+ int.MaxValue / sizeOfT;
#else
- byte[] bytes = ArrayPool.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
+ // Let's use a different value for non-release builds to ensure this code path
+ // is covered with tests without the need of decoding enormous payloads.
+ 8_000;
+#endif
- while (!resultAsBytes.IsEmpty)
+#if !NET
+ byte[] rented = ArrayPool.Shared.Rent((int)Math.Min(requiredBytes, 256_000));
+#endif
+
+ Span valuesToRead = result.AsSpan();
+ while (!valuesToRead.IsEmpty)
{
- int bytesRead = reader.Read(bytes, 0, Math.Min(resultAsBytes.Length, bytes.Length));
- if (bytesRead <= 0)
+ int sliceSize = Math.Min(valuesToRead.Length, MaxChunkLength);
+
+ Span resultAsBytes = MemoryMarshal.AsBytes(valuesToRead.Slice(0, sliceSize));
+#if NET
+ reader.BaseStream.ReadExactly(resultAsBytes);
+#else
+ while (!resultAsBytes.IsEmpty)
{
- ArrayPool.Shared.Return(bytes);
- ThrowHelper.ThrowEndOfStreamException();
- }
+ int bytesRead = reader.Read(rented, 0, Math.Min(resultAsBytes.Length, rented.Length));
+ if (bytesRead <= 0)
+ {
+ ArrayPool.Shared.Return(rented);
+ ThrowHelper.ThrowEndOfStreamException();
+ }
- bytes.AsSpan(0, bytesRead).CopyTo(resultAsBytes);
- resultAsBytes = resultAsBytes.Slice(bytesRead);
+ rented.AsSpan(0, bytesRead).CopyTo(resultAsBytes);
+ resultAsBytes = resultAsBytes.Slice(bytesRead);
+ }
+#endif
+ valuesToRead = valuesToRead.Slice(sliceSize);
}
- ArrayPool.Shared.Return(bytes);
+#if !NET
+ ArrayPool.Shared.Return(rented);
#endif
if (!BitConverter.IsLittleEndian)
@@ -176,7 +201,7 @@ internal static IReadOnlyList DecodePrimitiveTypes(BinaryReader reader, int c
{
// See DontCastBytesToBooleans test to see what could go wrong.
bool[] booleans = (bool[])(object)result;
- resultAsBytes = MemoryMarshal.AsBytes(result);
+ Span resultAsBytes = MemoryMarshal.AsBytes(result);
for (int i = 0; i < booleans.Length; i++)
{
// We don't use the bool array to get the value, as an optimizing compiler or JIT could elide this.
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs
index 42b9eadd97bd55..38884aadc54693 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs
@@ -17,7 +17,7 @@ namespace System.Formats.Nrbf;
///
internal sealed class ArraySingleStringRecord : SZArrayRecord
{
- private ArraySingleStringRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = [];
+ internal ArraySingleStringRecord(ArrayInfo arrayInfo) : base(arrayInfo) => Records = [];
public override SerializationRecordType RecordType => SerializationRecordType.ArraySingleString;
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs
deleted file mode 100644
index 41b1f73f03550e..00000000000000
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/BinaryArrayRecord.cs
+++ /dev/null
@@ -1,309 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-using System.Collections.Generic;
-using System.Diagnostics.CodeAnalysis;
-using System.IO;
-using System.Reflection.Metadata;
-using System.Formats.Nrbf.Utils;
-using System.Diagnostics;
-
-namespace System.Formats.Nrbf;
-
-///
-/// Represents an array other than single dimensional array of primitive types or .
-///
-///
-/// BinaryArray records are described in [MS-NRBF] 2.4.3.1 .
-///
-internal sealed class BinaryArrayRecord : ArrayRecord
-{
- private static HashSet PrimitiveTypes { get; } =
- [
- typeof(bool), typeof(char), typeof(byte), typeof(sbyte),
- typeof(short), typeof(ushort), typeof(int), typeof(uint),
- typeof(long), typeof(ulong), typeof(IntPtr), typeof(UIntPtr),
- typeof(float), typeof(double), typeof(decimal), typeof(DateTime),
- typeof(TimeSpan), typeof(string), typeof(object)
- ];
-
- private TypeName? _typeName;
- private long _totalElementsCount;
-
- private BinaryArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
- : base(arrayInfo)
- {
- MemberTypeInfo = memberTypeInfo;
- Values = [];
- // We need to parse all elements of the jagged array to obtain total elements count.
- _totalElementsCount = -1;
- }
-
- public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
-
- ///
- public override ReadOnlySpan Lengths => new int[1] { Length };
-
- ///
- public override long FlattenedLength
- {
- get
- {
- if (_totalElementsCount < 0)
- {
- _totalElementsCount = IsJagged
- ? GetJaggedArrayFlattenedLength(this)
- : ArrayInfo.FlattenedLength;
- }
-
- return _totalElementsCount;
- }
- }
-
- public override TypeName TypeName
- => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);
-
- private int Length => ArrayInfo.GetSZArrayLength();
-
- private MemberTypeInfo MemberTypeInfo { get; }
-
- private List Values { get; }
-
- [RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")]
- private protected override Array Deserialize(Type arrayType, bool allowNulls)
- {
- // We can not deserialize non-primitive types.
- // This method returns arrays of ClassRecord for arrays of complex types.
- Type elementType = MapElementType(arrayType, out bool isClassRecord);
- Type actualElementType = arrayType.GetElementType()!;
- Array array =
-#if NET9_0_OR_GREATER
- isClassRecord
- ? Array.CreateInstance(elementType, Length)
- : Array.CreateInstanceFromArrayType(arrayType, Length);
-#else
- Array.CreateInstance(elementType, Length);
-#endif
-
- int resultIndex = 0;
- foreach (object value in Values)
- {
- object item = value is MemberReferenceRecord referenceRecord
- ? referenceRecord.GetReferencedRecord()
- : value;
-
- if (item is not SerializationRecord record)
- {
- array.SetValue(item, resultIndex++);
- continue;
- }
-
- switch (record.RecordType)
- {
- case SerializationRecordType.BinaryArray:
- case SerializationRecordType.ArraySinglePrimitive:
- case SerializationRecordType.ArraySingleObject:
- case SerializationRecordType.ArraySingleString:
-
- // Recursion depth is bounded by the depth of arrayType, which is
- // a trustworthy Type instance. Don't need to worry about stack overflow.
-
- ArrayRecord nestedArrayRecord = (ArrayRecord)record;
- Array nestedArray = nestedArrayRecord.GetArray(actualElementType, allowNulls);
- array.SetValue(nestedArray, resultIndex++);
- break;
- case SerializationRecordType.ObjectNull:
- case SerializationRecordType.ObjectNullMultiple256:
- case SerializationRecordType.ObjectNullMultiple:
- if (!allowNulls)
- {
- ThrowHelper.ThrowArrayContainedNulls();
- }
-
- int nullCount = ((NullsRecord)item).NullCount;
- Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
- do
- {
- array.SetValue(null, resultIndex++);
- nullCount--;
- }
- while (nullCount > 0);
- break;
- default:
- array.SetValue(record.GetValue(), resultIndex++);
- break;
- }
- }
-
- Debug.Assert(resultIndex == array.Length, "We should have traversed the entirety of the newly created array.");
-
- return array;
- }
-
- internal static ArrayRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options)
- {
- SerializationRecordId objectId = SerializationRecordId.Decode(reader);
- BinaryArrayType arrayType = reader.ReadArrayType();
- int rank = reader.ReadInt32();
-
- bool isRectangular = arrayType is BinaryArrayType.Rectangular;
-
- // It is an arbitrary limit in the current CoreCLR type loader.
- // Don't change this value without reviewing the loop a few lines below.
- const int MaxSupportedArrayRank = 32;
-
- if (rank < 1 || rank > MaxSupportedArrayRank
- || (rank != 1 && !isRectangular)
- || (rank == 1 && isRectangular))
- {
- ThrowHelper.ThrowInvalidValue(rank);
- }
-
- int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32
- long totalElementCount = 1; // to avoid integer overflow during the multiplication below
- for (int i = 0; i < lengths.Length; i++)
- {
- lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
- totalElementCount *= lengths[i];
-
- // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]"
- // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But
- // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least
- // we're consistent.
-
- if (totalElementCount > ArrayInfo.MaxArrayLength)
- {
- ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
- }
- }
-
- // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so
- // we don't need to read the NRBF stream 'LowerBounds' field here.
-
- MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap);
- ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank);
-
- if (isRectangular)
- {
- return RectangularArrayRecord.Create(reader, arrayInfo, memberTypeInfo, lengths);
- }
-
- return memberTypeInfo.ShouldBeRepresentedAsArrayOfClassRecords()
- ? new ArrayOfClassesRecord(arrayInfo, memberTypeInfo)
- : new BinaryArrayRecord(arrayInfo, memberTypeInfo);
- }
-
- private static long GetJaggedArrayFlattenedLength(BinaryArrayRecord jaggedArrayRecord)
- {
- long result = 0;
- Queue? jaggedArrayRecords = null;
-
- do
- {
- if (jaggedArrayRecords is not null)
- {
- jaggedArrayRecord = jaggedArrayRecords.Dequeue();
- }
-
- Debug.Assert(jaggedArrayRecord.IsJagged);
-
- // In theory somebody could create a payload that would represent
- // a very nested array with total elements count > long.MaxValue.
- // That is why this method is using checked arithmetic.
- result = checked(result + jaggedArrayRecord.Length); // count the arrays themselves
-
- foreach (object value in jaggedArrayRecord.Values)
- {
- if (value is not SerializationRecord record)
- {
- continue;
- }
-
- if (record.RecordType == SerializationRecordType.MemberReference)
- {
- record = ((MemberReferenceRecord)record).GetReferencedRecord();
- }
-
- switch (record.RecordType)
- {
- case SerializationRecordType.ArraySinglePrimitive:
- case SerializationRecordType.ArraySingleObject:
- case SerializationRecordType.ArraySingleString:
- case SerializationRecordType.BinaryArray:
- ArrayRecord nestedArrayRecord = (ArrayRecord)record;
- if (nestedArrayRecord.IsJagged)
- {
- (jaggedArrayRecords ??= new()).Enqueue((BinaryArrayRecord)nestedArrayRecord);
- }
- else
- {
- // Don't call nestedArrayRecord.FlattenedLength to avoid any potential recursion,
- // just call nestedArrayRecord.ArrayInfo.FlattenedLength that returns pre-computed value.
- result = checked(result + nestedArrayRecord.ArrayInfo.FlattenedLength);
- }
- break;
- default:
- break;
- }
- }
- }
- while (jaggedArrayRecords is not null && jaggedArrayRecords.Count > 0);
-
- return result;
- }
-
- private protected override void AddValue(object value) => Values.Add(value);
-
- internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
- {
- (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0);
-
- if (allowed != AllowedRecordTypes.None)
- {
- // It's an array, it can also contain multiple nulls
- return (allowed | AllowedRecordTypes.Nulls, primitiveType);
- }
-
- return (allowed, primitiveType);
- }
-
- ///
- /// Complex types must not be instantiated, but represented as ClassRecord.
- /// For arrays of primitive types like int, string and object this method returns the element type.
- /// For array of complex types, it returns ClassRecord.
- /// It takes arrays of arrays into account:
- /// - int[][] => int[]
- /// - MyClass[][][] => ClassRecord[][]
- ///
- [RequiresDynamicCode("May call Type.MakeArrayType().")]
- private static Type MapElementType(Type arrayType, out bool isClassRecord)
- {
- Type elementType = arrayType;
- int arrayNestingDepth = 0;
-
- // Loop iteration counts are bound by the nesting depth of arrayType,
- // which is a trustworthy input. No DoS concerns.
-
- while (elementType.IsArray)
- {
- elementType = elementType.GetElementType()!;
- arrayNestingDepth++;
- }
-
- if (PrimitiveTypes.Contains(elementType) || (Nullable.GetUnderlyingType(elementType) is Type nullable && PrimitiveTypes.Contains(nullable)))
- {
- isClassRecord = false;
- return arrayNestingDepth == 1 ? elementType : arrayType.GetElementType()!;
- }
-
- // Complex types are never instantiated, but represented as ClassRecord
- isClassRecord = true;
- Type complexType = typeof(ClassRecord);
- for (int i = 1; i < arrayNestingDepth; i++)
- {
- complexType = complexType.MakeArrayType();
- }
-
- return complexType;
- }
-}
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs
index c643d3ce8c8465..2762be167b1112 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ClassWithIdRecord.cs
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Formats.Nrbf.Utils;
using System.IO;
using System.Runtime.Serialization;
@@ -27,16 +28,57 @@ private ClassWithIdRecord(SerializationRecordId id, ClassRecord metadataClass) :
internal ClassRecord MetadataClass { get; }
- internal static ClassWithIdRecord Decode(
+ internal static SerializationRecord Decode(
BinaryReader reader,
RecordMap recordMap)
{
SerializationRecordId id = SerializationRecordId.Decode(reader);
SerializationRecordId metadataId = SerializationRecordId.Decode(reader);
- ClassRecord referencedRecord = recordMap.GetRecord(metadataId);
+ SerializationRecord metadataRecord = recordMap.GetRecord(metadataId);
+ if (metadataRecord is ClassRecord referencedClassRecord)
+ {
+ return new ClassWithIdRecord(id, referencedClassRecord);
+ }
+ else if (metadataRecord is PrimitiveTypeRecord primitiveTypeRecord
+ && !primitiveTypeRecord.Id.Equals(default) // such records always have Id provided
+ && metadataRecord is not BinaryObjectStringRecord) // it does not apply to BinaryObjectStringRecord
+ {
+ // BinaryFormatter represents primitive types as MemberPrimitiveTypedRecord
+ // only for arrays of objects. For other arrays, like arrays of some abstraction
+ // (example: new IComparable[] { int.MaxValue }), it uses SystemClassWithMembersAndTypes.
+ // SystemClassWithMembersAndTypes.Decode handles that by returning MemberPrimitiveTypedRecord.
+ // But arrays of such types typically have only one SystemClassWithMembersAndTypes record with
+ // all the member information and multiple ClassWithIdRecord records that just reuse that information.
+ return primitiveTypeRecord switch
+ {
+ MemberPrimitiveTypedRecord => Create(reader.ReadBoolean()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadByte()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadSByte()),
+ MemberPrimitiveTypedRecord => Create(reader.ParseChar()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadInt16()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadUInt16()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadInt32()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadUInt32()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadInt64()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadUInt64()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadSingle()),
+ MemberPrimitiveTypedRecord => Create(reader.ReadDouble()),
+ MemberPrimitiveTypedRecord => Create(new IntPtr(reader.ReadInt64())),
+ MemberPrimitiveTypedRecord => Create(new UIntPtr(reader.ReadUInt64())),
+ MemberPrimitiveTypedRecord => Create(new TimeSpan(reader.ReadInt64())),
+ MemberPrimitiveTypedRecord => SystemClassWithMembersAndTypesRecord.DecodeDateTime(reader, id),
+ MemberPrimitiveTypedRecord => SystemClassWithMembersAndTypesRecord.DecodeDecimal(reader, id),
+ _ => throw new InvalidOperationException()
+ };
+ }
+ else
+ {
+ throw new SerializationException(SR.Serialization_InvalidReference);
+ }
- return new ClassWithIdRecord(id, referencedRecord);
+ SerializationRecord Create(T value) where T : unmanaged
+ => new MemberPrimitiveTypedRecord(value, id);
}
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType()
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs
new file mode 100644
index 00000000000000..6ac97ef40675d6
--- /dev/null
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/JaggedArrayRecord.cs
@@ -0,0 +1,65 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.IO;
+using System.Reflection.Metadata;
+using System.Formats.Nrbf.Utils;
+using System.Diagnostics;
+using System.Runtime.Serialization;
+
+namespace System.Formats.Nrbf;
+
+///
+/// Represents an array of arrays.
+///
+///
+/// BinaryArray records are described in [MS-NRBF] 2.4.3.1 .
+///
+internal sealed class JaggedArrayRecord : ArrayRecord
+{
+ private readonly MemberTypeInfo _memberTypeInfo;
+ private readonly int[] _lengths;
+ private readonly List _records;
+ private readonly AllowedRecordTypes _allowedRecordTypes;
+ private TypeName? _typeName;
+
+ internal JaggedArrayRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo, int[] lengths)
+ : base(arrayInfo)
+ {
+ _memberTypeInfo = memberTypeInfo;
+ _lengths = lengths;
+ _records = [];
+ _allowedRecordTypes = memberTypeInfo.GetNextAllowedRecordType(0).allowed;
+
+ Debug.Assert(TypeName.GetElementType().IsArray, "Jagged arrays are required.");
+ }
+
+ public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
+
+ public override ReadOnlySpan Lengths => _lengths;
+
+ public override TypeName TypeName => _typeName ??= _memberTypeInfo.GetArrayTypeName(ArrayInfo);
+
+ [RequiresDynamicCode("May call Array.CreateInstance().")]
+ private protected override Array Deserialize(Type arrayType, bool allowNulls)
+ {
+ // This method returns arrays of ArrayRecords.
+ Array array = _lengths.Length switch
+ {
+ 1 => new ArrayRecord[_lengths[0]],
+ 2 => new ArrayRecord[_lengths[0], _lengths[1]],
+ _ => Array.CreateInstance(typeof(ArrayRecord), _lengths)
+ };
+
+ Populate(_records, array, _lengths, AllowedRecordTypes.Arrays, allowNulls);
+
+ return array;
+ }
+
+ private protected override void AddValue(object value) => _records.Add((SerializationRecord)value);
+
+ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
+ => (_allowedRecordTypes, default);
+}
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs
index 57e47a02eec688..84c1073b0ef67a 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs
@@ -86,10 +86,12 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
// Every class can be a null or a reference and a ClassWithId
const AllowedRecordTypes Classes = AllowedRecordTypes.ClassWithId
| AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference
- | AllowedRecordTypes.MemberPrimitiveTyped
| AllowedRecordTypes.BinaryLibrary; // Classes may be preceded with a library record (System too!)
// but System Classes can be expressed only by System records
- const AllowedRecordTypes SystemClass = Classes | AllowedRecordTypes.SystemClassWithMembersAndTypes;
+ const AllowedRecordTypes SystemClass = Classes | AllowedRecordTypes.SystemClassWithMembersAndTypes
+ // All primitive types can be stored by using one of the interfaces they implement.
+ // Example: `new IEnumerable[1] { "hello" }` or `new IComparable[1] { int.MaxValue }`.
+ | AllowedRecordTypes.BinaryObjectString | AllowedRecordTypes.MemberPrimitiveTyped;
const AllowedRecordTypes NonSystemClass = Classes | AllowedRecordTypes.ClassWithMembersAndTypes;
return binaryType switch
@@ -106,43 +108,6 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt
};
}
- internal bool ShouldBeRepresentedAsArrayOfClassRecords()
- {
- // This library tries to minimize the number of concepts the users need to learn to use it.
- // Since SZArrays are most common, it provides an SZArrayRecord abstraction.
- // Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord.
- // The goal of this method is to determine whether given array can be represented as SZArrayRecord.
-
- (BinaryType binaryType, object? additionalInfo) = Infos[0];
-
- if (binaryType == BinaryType.Class)
- {
- // An array of arrays can not be represented as SZArrayRecord.
- return !((ClassTypeInfo)additionalInfo!).TypeName.IsArray;
- }
- else if (binaryType == BinaryType.SystemClass)
- {
- TypeName typeName = (TypeName)additionalInfo!;
-
- // An array of arrays can not be represented as SZArrayRecord.
- if (typeName.IsArray)
- {
- return false;
- }
-
- if (!typeName.IsConstructedGenericType)
- {
- return true;
- }
-
- // Can't use SZArrayRecord for Nullable[]
- // as it consists of MemberPrimitiveTypedRecord and NullsRecord
- return typeName.GetGenericTypeDefinition().FullName != typeof(Nullable<>).FullName;
- }
-
- return false;
- }
-
internal TypeName GetArrayTypeName(ArrayInfo arrayInfo)
{
(BinaryType binaryType, object? additionalInfo) = Infos[0];
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs
index a315b37cff0234..65bb6e8beb9c5c 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs
@@ -9,6 +9,7 @@
using System.Text;
using System.Runtime.Serialization;
using System.Runtime.InteropServices;
+using System.Reflection.Metadata;
namespace System.Formats.Nrbf;
@@ -223,7 +224,7 @@ private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap rec
SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader),
SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader),
SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader),
- SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options),
+ SerializationRecordType.BinaryArray => DecodeBinaryArrayRecord(reader, recordMap, options),
SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options),
SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader),
SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap),
@@ -269,11 +270,16 @@ private static SerializationRecord DecodeMemberPrimitiveTypedRecord(BinaryReader
};
}
- private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader)
+ private static ArrayRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader)
{
ArrayInfo info = ArrayInfo.Decode(reader);
PrimitiveType primitiveType = reader.ReadPrimitiveType();
+ return DecodeArraySinglePrimitiveRecord(reader, info, primitiveType);
+ }
+
+ private static ArrayRecord DecodeArraySinglePrimitiveRecord(BinaryReader reader, ArrayInfo info, PrimitiveType primitiveType)
+ {
return primitiveType switch
{
PrimitiveType.Boolean => Decode(info, reader),
@@ -294,10 +300,171 @@ private static SerializationRecord DecodeArraySinglePrimitiveRecord(BinaryReader
_ => throw new InvalidOperationException()
};
- static SerializationRecord Decode(ArrayInfo info, BinaryReader reader) where T : unmanaged
+ static ArrayRecord Decode(ArrayInfo info, BinaryReader reader) where T : unmanaged
=> new ArraySinglePrimitiveRecord(info, ArraySinglePrimitiveRecord.DecodePrimitiveTypes(reader, info.GetSZArrayLength()));
}
+ private static ArrayRecord DecodeArrayRectangularPrimitiveRecord(PrimitiveType primitiveType, ArrayInfo info, int[] lengths, BinaryReader reader)
+ {
+ return primitiveType switch
+ {
+ PrimitiveType.Boolean => Decode(info, lengths, reader),
+ PrimitiveType.Byte => Decode(info, lengths, reader),
+ PrimitiveType.SByte => Decode(info, lengths, reader),
+ PrimitiveType.Char => Decode(info, lengths, reader),
+ PrimitiveType.Int16 => Decode(info, lengths, reader),
+ PrimitiveType.UInt16 => Decode(info, lengths, reader),
+ PrimitiveType.Int32 => Decode(info, lengths, reader),
+ PrimitiveType.UInt32 => Decode(info, lengths, reader),
+ PrimitiveType.Int64 => Decode(info, lengths, reader),
+ PrimitiveType.UInt64 => Decode(info, lengths, reader),
+ PrimitiveType.Single => Decode(info, lengths, reader),
+ PrimitiveType.Double => Decode(info, lengths, reader),
+ PrimitiveType.Decimal => Decode(info, lengths, reader),
+ PrimitiveType.DateTime => Decode(info, lengths, reader),
+ PrimitiveType.TimeSpan => Decode(info, lengths, reader),
+ _ => throw new InvalidOperationException()
+ };
+
+ static ArrayRecord Decode(ArrayInfo info, int[] lengths, BinaryReader reader) where T : unmanaged
+ {
+ // We limit the length of multi-dimensional array to max length of SZArray.
+ // Because of that, it's possible to re-use the same decoding logic for both MD and SZ arrays.
+ IReadOnlyList values = ArraySinglePrimitiveRecord.DecodePrimitiveTypes(reader, info.GetSZArrayLength());
+ return new ArrayRectangularPrimitiveRecord(info, lengths, values);
+ }
+ }
+
+ private static ArrayRecord DecodeBinaryArrayRecord(BinaryReader reader, RecordMap recordMap, PayloadOptions options)
+ {
+ SerializationRecordId objectId = SerializationRecordId.Decode(reader);
+ BinaryArrayType arrayType = reader.ReadArrayType();
+ int rank = reader.ReadInt32();
+
+ bool isRectangular = arrayType is BinaryArrayType.Rectangular;
+
+ // It is an arbitrary limit in the current CoreCLR type loader.
+ // Don't change this value without reviewing the loop a few lines below.
+ const int MaxSupportedArrayRank = 32;
+
+ if (rank < 1 || rank > MaxSupportedArrayRank
+ || (rank != 1 && !isRectangular)
+ || (rank == 1 && isRectangular))
+ {
+ ThrowHelper.ThrowInvalidValue(rank);
+ }
+
+ int[] lengths = new int[rank]; // adversary-controlled, but acceptable since upper limit of 32
+ long totalElementCount = 1; // to avoid integer overflow during the multiplication below
+ for (int i = 0; i < lengths.Length; i++)
+ {
+ lengths[i] = ArrayInfo.ParseValidArrayLength(reader);
+ totalElementCount *= lengths[i];
+
+ // n.b. This forbids "new T[Array.MaxLength, Array.MaxLength, Array.MaxLength, ..., 0]"
+ // but allows "new T[0, Array.MaxLength, Array.MaxLength, Array.MaxLength, ...]". But
+ // that's the same behavior that newarr and Array.CreateInstance exhibit, so at least
+ // we're consistent.
+
+ if (totalElementCount > ArrayInfo.MaxArrayLength)
+ {
+ ThrowHelper.ThrowInvalidValue(lengths[i]); // max array size exceeded
+ }
+ }
+
+ // Per BinaryReaderExtensions.ReadArrayType, we do not support nonzero offsets, so
+ // we don't need to read the NRBF stream 'LowerBounds' field here.
+
+ MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, 1, options, recordMap);
+ ArrayInfo arrayInfo = new(objectId, totalElementCount, arrayType, rank);
+
+ (BinaryType binaryType, object? additionalInfo) = memberTypeInfo.Infos[0];
+ if (arrayType == BinaryArrayType.Rectangular)
+ {
+ if (binaryType == BinaryType.Primitive)
+ {
+ return DecodeArrayRectangularPrimitiveRecord((PrimitiveType)additionalInfo!, arrayInfo, lengths, reader);
+ }
+ else if (binaryType == BinaryType.String)
+ {
+ return new RectangularArrayRecord(typeof(string), arrayInfo, memberTypeInfo, lengths);
+ }
+ else if (binaryType == BinaryType.Object)
+ {
+ return new RectangularArrayRecord(typeof(SerializationRecord), arrayInfo, memberTypeInfo, lengths);
+ }
+ else if (binaryType is BinaryType.SystemClass or BinaryType.Class)
+ {
+ TypeName typeName = binaryType == BinaryType.SystemClass ? (TypeName)additionalInfo! : ((ClassTypeInfo)additionalInfo!).TypeName;
+ // BinaryArrayType.Rectangular can be also a jagged array.
+ return typeName.IsArray
+ ? new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths)
+ : new RectangularArrayRecord(typeof(SerializationRecord), arrayInfo, memberTypeInfo, lengths);
+ }
+ else if (binaryType is BinaryType.PrimitiveArray or BinaryType.StringArray or BinaryType.ObjectArray)
+ {
+ // A multi-dimensional array of single dimensional arrays. Example: int[][,]
+ return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths);
+ }
+ }
+ else if (arrayType == BinaryArrayType.Single)
+ {
+ if (binaryType is BinaryType.SystemClass or BinaryType.Class)
+ {
+ TypeName typeName = binaryType == BinaryType.SystemClass ? (TypeName)additionalInfo! : ((ClassTypeInfo)additionalInfo!).TypeName;
+ // BinaryArrayType.Single that describes an array is just a jagged array.
+ return typeName.IsArray
+ ? new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths)
+ : new SZArrayOfRecords(arrayInfo, memberTypeInfo);
+ }
+ else if (binaryType == BinaryType.String)
+ {
+ // BinaryArrayRecord can represent string[] (but BF always uses ArraySingleStringRecord for that).
+ return new ArraySingleStringRecord(arrayInfo);
+ }
+ else if (binaryType == BinaryType.Primitive)
+ {
+ // BinaryArrayRecord can represent Primitive[] (but BF always uses ArraySinglePrimitiveRecord for that).
+ return DecodeArraySinglePrimitiveRecord(reader, arrayInfo, (PrimitiveType)additionalInfo!);
+ }
+ else if (binaryType == BinaryType.Object)
+ {
+ // BinaryArrayRecord can represent object[] (but BF always uses ArraySingleObjectRecord for that).
+ return new ArraySingleObjectRecord(arrayInfo);
+ }
+ else if (binaryType is BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.PrimitiveArray)
+ {
+ // It's a Jagged array that does not use BinaryArrayType.Jagged.
+ return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths);
+ }
+ }
+ else if (arrayType == BinaryArrayType.Jagged)
+ {
+ if (binaryType is BinaryType.ObjectArray or BinaryType.StringArray or BinaryType.PrimitiveArray)
+ {
+ // It's a Jagged array that does not use BinaryArrayType.Jagged.
+ return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths);
+ }
+ else if (binaryType == BinaryType.SystemClass && ((TypeName)additionalInfo!).IsArray)
+ {
+ // BinaryType.SystemClass can be used to describe arrays of system class records.
+ // Example: new Exception[] { new Exception("test") };
+ return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths);
+ }
+ else if (binaryType == BinaryType.Class && ((ClassTypeInfo)additionalInfo!).TypeName.IsArray)
+ {
+ // BinaryType.Class can be used to describe arrays of class records.
+ // Example: new MyCustomType[] { new MyCustomType(0) };
+ return new JaggedArrayRecord(arrayInfo, memberTypeInfo, lengths);
+ }
+
+ // It's invalid, the element type must be an array.
+ throw new SerializationException(SR.Format(SR.Serialization_InvalidValue, binaryType));
+ }
+
+ throw new InvalidOperationException();
+ }
+
///
/// This method is responsible for pushing only the FIRST read info
/// of the NESTED record into the .
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs
index eafcbf93249c57..dd5862c7b2b862 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs
@@ -61,18 +61,7 @@ internal void Add(SerializationRecord record)
}
}
- internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header)
- {
- SerializationRecord rootRecord = GetRecord(header.RootId);
-
- if (rootRecord is SystemClassWithMembersAndTypesRecord systemClass)
- {
- // update the record map, so it's visible also to those who access it via Id
- _map[header.RootId] = rootRecord = systemClass.TryToMapToUserFriendly();
- }
-
- return rootRecord;
- }
+ internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) => GetRecord(header.RootId);
internal SerializationRecord GetRecord(SerializationRecordId recordId)
=> _map.TryGetValue(recordId, out SerializationRecord? record)
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs
index f64dde36163d69..f10bc3f51efdae 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RectangularArrayRecord.cs
@@ -9,24 +9,29 @@
using System.Runtime.InteropServices;
using System.Formats.Nrbf.Utils;
using System.Diagnostics;
+using System.Runtime.Serialization;
namespace System.Formats.Nrbf;
internal sealed class RectangularArrayRecord : ArrayRecord
{
+ private readonly Type _elementType;
private readonly int[] _lengths;
- private readonly List _values;
+ private readonly List _records;
+ private readonly AllowedRecordTypes _allowedRecordTypes;
+ private readonly MemberTypeInfo _memberTypeInfo;
private TypeName? _typeName;
- private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
- MemberTypeInfo memberTypeInfo, int[] lengths, bool canPreAllocate) : base(arrayInfo)
+ internal RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo, int[] lengths) : base(arrayInfo)
{
- ElementType = elementType;
- MemberTypeInfo = memberTypeInfo;
+ _elementType = elementType;
_lengths = lengths;
+ _memberTypeInfo = memberTypeInfo;
+ _records = new List(Math.Min(4, arrayInfo.GetSZArrayLength()));
+ _allowedRecordTypes = memberTypeInfo.GetNextAllowedRecordType(0).allowed;
- // ArrayInfo.GetSZArrayLength ensures to return a value <= Array.MaxLength
- _values = new List(canPreAllocate ? arrayInfo.GetSZArrayLength() : Math.Min(4, arrayInfo.GetSZArrayLength()));
+ Debug.Assert(elementType == typeof(string) || elementType == typeof(SerializationRecord));
+ Debug.Assert(!TypeName.GetElementType().IsArray, "Use JaggedArrayRecord instead.");
}
public override SerializationRecordType RecordType => SerializationRecordType.BinaryArray;
@@ -34,230 +39,32 @@ private RectangularArrayRecord(Type elementType, ArrayInfo arrayInfo,
public override ReadOnlySpan Lengths => _lengths.AsSpan();
public override TypeName TypeName
- => _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);
-
- private Type ElementType { get; }
-
- private MemberTypeInfo MemberTypeInfo { get; }
+ => _typeName ??= _memberTypeInfo.GetArrayTypeName(ArrayInfo);
[RequiresDynamicCode("May call Array.CreateInstance() and Type.MakeArrayType().")]
private protected override Array Deserialize(Type arrayType, bool allowNulls)
{
- // We can not deserialize non-primitive types.
- // This method returns arrays of ClassRecord for arrays of complex types.
+ bool storeStrings = _elementType == typeof(string);
+
+ // We can not deserialize non-string types.
+ // This method returns arrays of SerializationRecord for arrays of complex types.
Array result =
#if NET9_0_OR_GREATER
- ElementType == typeof(ClassRecord)
- ? Array.CreateInstance(ElementType, _lengths)
- : Array.CreateInstanceFromArrayType(arrayType, _lengths);
+ storeStrings
+ ? Array.CreateInstanceFromArrayType(arrayType, _lengths)
+ : Array.CreateInstance(_elementType, _lengths);
#else
- Array.CreateInstance(ElementType, _lengths);
+ Array.CreateInstance(_elementType, _lengths);
#endif
-#if !NET8_0_OR_GREATER
- int[] indices = new int[_lengths.Length];
- nuint numElementsWritten = 0; // only for debugging; not used in release builds
-
- foreach (object value in _values)
- {
- result.SetValue(GetActualValue(value), indices);
- numElementsWritten++;
-
- int dimension = indices.Length - 1;
- while (dimension >= 0)
- {
- indices[dimension]++;
- if (indices[dimension] < Lengths[dimension])
- {
- break;
- }
- indices[dimension] = 0;
- dimension--;
- }
-
- if (dimension < 0)
- {
- break;
- }
- }
-
- Debug.Assert(numElementsWritten == (uint)_values.Count, "We should have traversed the entirety of the source values collection.");
- Debug.Assert(numElementsWritten == (ulong)result.LongLength, "We should have traversed the entirety of the destination array.");
+ AllowedRecordTypes allowedRecordTypes = storeStrings ? AllowedRecordTypes.BinaryObjectString : AllowedRecordTypes.AnyObject;
+ Populate(_records, result, _lengths, allowedRecordTypes, allowNulls);
return result;
-#else
- // Idea from Array.CoreCLR that maps an array of int indices into
- // an internal flat index.
- if (ElementType.IsValueType)
- {
- if (ElementType == typeof(bool)) CopyTo(_values, result);
- else if (ElementType == typeof(byte)) CopyTo(_values, result);
- else if (ElementType == typeof(sbyte)) CopyTo(_values, result);
- else if (ElementType == typeof(short)) CopyTo(_values, result);
- else if (ElementType == typeof(ushort)) CopyTo(_values, result);
- else if (ElementType == typeof(char)) CopyTo(_values, result);
- else if (ElementType == typeof(int)) CopyTo(_values, result);
- else if (ElementType == typeof(float)) CopyTo(_values, result);
- else if (ElementType == typeof(long)) CopyTo(_values, result);
- else if (ElementType == typeof(ulong)) CopyTo(_values, result);
- else if (ElementType == typeof(double)) CopyTo(_values, result);
- else if (ElementType == typeof(TimeSpan)) CopyTo(_values, result);
- else if (ElementType == typeof(DateTime)) CopyTo(_values, result);
- else if (ElementType == typeof(decimal)) CopyTo(_values, result);
- else throw new InvalidOperationException();
- }
- else
- {
- CopyTo(_values, result);
- }
-
- return result;
-
- static void CopyTo(List list, Array array)
- {
- ref byte arrayDataRef = ref MemoryMarshal.GetArrayDataReference(array);
- ref T firstElementRef = ref Unsafe.As(ref arrayDataRef);
- nuint flattenedIndex = 0;
- foreach (object value in list)
- {
- ref T targetElement = ref Unsafe.Add(ref firstElementRef, flattenedIndex);
- targetElement = (T)GetActualValue(value)!;
- flattenedIndex++;
- }
-
- Debug.Assert(flattenedIndex == (ulong)array.LongLength, "We should have traversed the entirety of the array.");
- }
-#endif
}
- private protected override void AddValue(object value) => _values.Add(value);
+ private protected override void AddValue(object value) => _records.Add((SerializationRecord)value);
internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType()
- {
- (AllowedRecordTypes allowed, PrimitiveType primitiveType) = MemberTypeInfo.GetNextAllowedRecordType(0);
-
- if (allowed != AllowedRecordTypes.None)
- {
- // It's an array, it can also contain multiple nulls
- return (allowed | AllowedRecordTypes.Nulls, primitiveType);
- }
-
- return (allowed, primitiveType);
- }
-
- internal static RectangularArrayRecord Create(BinaryReader reader, ArrayInfo arrayInfo,
- MemberTypeInfo memberTypeInfo, int[] lengths)
- {
- BinaryType binaryType = memberTypeInfo.Infos[0].BinaryType;
- Type elementType = binaryType switch
- {
- BinaryType.Primitive => MapPrimitive((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!),
- BinaryType.PrimitiveArray => MapPrimitiveArray((PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!),
- BinaryType.String => typeof(string),
- BinaryType.Object => typeof(object),
- _ => typeof(ClassRecord)
- };
-
- bool canPreAllocate = false;
- if (binaryType == BinaryType.Primitive)
- {
- int sizeOfSingleValue = (PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo! switch
- {
- PrimitiveType.Boolean => sizeof(bool),
- PrimitiveType.Byte => sizeof(byte),
- PrimitiveType.SByte => sizeof(sbyte),
- PrimitiveType.Char => sizeof(byte), // it's UTF8 (see comment below)
- PrimitiveType.Int16 => sizeof(short),
- PrimitiveType.UInt16 => sizeof(ushort),
- PrimitiveType.Int32 => sizeof(int),
- PrimitiveType.UInt32 => sizeof(uint),
- PrimitiveType.Single => sizeof(float),
- PrimitiveType.Int64 => sizeof(long),
- PrimitiveType.UInt64 => sizeof(ulong),
- PrimitiveType.Double => sizeof(double),
- PrimitiveType.TimeSpan => sizeof(ulong),
- PrimitiveType.DateTime => sizeof(ulong),
- PrimitiveType.Decimal => -1, // represented as variable-length string
- _ => throw new InvalidOperationException()
- };
-
- if (sizeOfSingleValue > 0)
- {
- // NRBF encodes rectangular char[,,,...] by converting each standalone UTF-16 code point into
- // its UTF-8 encoding. This means that surrogate code points (including adjacent surrogate
- // pairs) occurring within a char[,,,...] cannot be encoded by NRBF. BinaryReader will detect
- // that they're ill-formed and reject them on read.
- //
- // Per the comment in ArraySinglePrimitiveRecord.DecodePrimitiveTypes, we'll assume best-case
- // encoding where 1 UTF-16 char encodes as a single UTF-8 byte, even though this might lead
- // to encountering an EOF if we realize later that we actually need to read more bytes in
- // order to fully populate the char[,,,...] array. Any such allocation is still linearly
- // proportional to the length of the incoming payload, so it's not a DoS vector.
- // The multiplication below is guaranteed not to overflow because FlattenedLength is bounded
- // to <= Array.MaxLength (see BinaryArrayRecord.Decode) and sizeOfSingleValue is at most 8.
- Debug.Assert(arrayInfo.FlattenedLength >= 0 && arrayInfo.FlattenedLength <= long.MaxValue / sizeOfSingleValue);
-
- long size = arrayInfo.FlattenedLength * sizeOfSingleValue;
- bool? isDataAvailable = reader.IsDataAvailable(size);
- if (isDataAvailable.HasValue)
- {
- if (!isDataAvailable.Value)
- {
- ThrowHelper.ThrowEndOfStreamException();
- }
-
- canPreAllocate = true;
- }
- }
- }
-
- return new RectangularArrayRecord(elementType, arrayInfo, memberTypeInfo, lengths, canPreAllocate);
- }
-
- private static Type MapPrimitive(PrimitiveType primitiveType)
- => primitiveType switch
- {
- PrimitiveType.Boolean => typeof(bool),
- PrimitiveType.Byte => typeof(byte),
- PrimitiveType.Char => typeof(char),
- PrimitiveType.Decimal => typeof(decimal),
- PrimitiveType.Double => typeof(double),
- PrimitiveType.Int16 => typeof(short),
- PrimitiveType.Int32 => typeof(int),
- PrimitiveType.Int64 => typeof(long),
- PrimitiveType.SByte => typeof(sbyte),
- PrimitiveType.Single => typeof(float),
- PrimitiveType.TimeSpan => typeof(TimeSpan),
- PrimitiveType.DateTime => typeof(DateTime),
- PrimitiveType.UInt16 => typeof(ushort),
- PrimitiveType.UInt32 => typeof(uint),
- PrimitiveType.UInt64 => typeof(ulong),
- _ => throw new InvalidOperationException()
- };
-
- private static Type MapPrimitiveArray(PrimitiveType primitiveType)
- => primitiveType switch
- {
- PrimitiveType.Boolean => typeof(bool[]),
- PrimitiveType.Byte => typeof(byte[]),
- PrimitiveType.Char => typeof(char[]),
- PrimitiveType.Decimal => typeof(decimal[]),
- PrimitiveType.Double => typeof(double[]),
- PrimitiveType.Int16 => typeof(short[]),
- PrimitiveType.Int32 => typeof(int[]),
- PrimitiveType.Int64 => typeof(long[]),
- PrimitiveType.SByte => typeof(sbyte[]),
- PrimitiveType.Single => typeof(float[]),
- PrimitiveType.TimeSpan => typeof(TimeSpan[]),
- PrimitiveType.DateTime => typeof(DateTime[]),
- PrimitiveType.UInt16 => typeof(ushort[]),
- PrimitiveType.UInt32 => typeof(uint[]),
- PrimitiveType.UInt64 => typeof(ulong[]),
- _ => throw new InvalidOperationException()
- };
-
- private static object? GetActualValue(object value)
- => value is SerializationRecord serializationRecord
- ? serializationRecord.GetValue()
- : value; // it must be a primitive type
+ => (_allowedRecordTypes, default);
}
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs
similarity index 69%
rename from src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs
rename to src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs
index f345292c693a61..b77a4a57a2a348 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArrayOfClassesRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SZArrayOfRecords.cs
@@ -8,11 +8,15 @@
namespace System.Formats.Nrbf;
-internal sealed class ArrayOfClassesRecord : SZArrayRecord
+// This library tries to minimize the number of concepts the users need to learn to use it.
+// Since SZArrays are most common, it provides an SZArrayRecord abstraction.
+// Every other array (jagged, multi-dimensional etc) is represented using ArrayRecord.
+// The goal of this class is to let the users use SZArrayRecord abstraction.
+internal sealed class SZArrayOfRecords : SZArrayRecord
{
private TypeName? _typeName;
- internal ArrayOfClassesRecord(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
+ internal SZArrayOfRecords(ArrayInfo arrayInfo, MemberTypeInfo memberTypeInfo)
: base(arrayInfo)
{
MemberTypeInfo = memberTypeInfo;
@@ -29,12 +33,12 @@ public override TypeName TypeName
=> _typeName ??= MemberTypeInfo.GetArrayTypeName(ArrayInfo);
///
- public override ClassRecord?[] GetArray(bool allowNulls = true)
- => (ClassRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));
+ public override SerializationRecord?[] GetArray(bool allowNulls = true)
+ => (SerializationRecord?[])(allowNulls ? _arrayNullsAllowed ??= ToArray(true) : _arrayNullsNotAllowed ??= ToArray(false));
- private ClassRecord?[] ToArray(bool allowNulls)
+ private SerializationRecord?[] ToArray(bool allowNulls)
{
- ClassRecord?[] result = new ClassRecord?[Length];
+ SerializationRecord?[] result = new SerializationRecord?[Length];
int resultIndex = 0;
foreach (SerializationRecord record in Records)
@@ -43,9 +47,9 @@ public override TypeName TypeName
? referenceRecord.GetReferencedRecord()
: record;
- if (actual is ClassRecord classRecord)
+ if (actual is not NullsRecord nullsRecord)
{
- result[resultIndex++] = classRecord;
+ result[resultIndex++] = actual;
}
else
{
@@ -54,7 +58,7 @@ public override TypeName TypeName
ThrowHelper.ThrowArrayContainedNulls();
}
- int nullCount = ((NullsRecord)actual).NullCount;
+ int nullCount = nullsRecord.NullCount;
Debug.Assert(nullCount > 0, "All implementations of NullsRecord are expected to return a positive value for NullCount.");
do
{
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs
index ccecc2246e8c22..0c5193cd92272a 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SystemClassWithMembersAndTypesRecord.cs
@@ -3,6 +3,7 @@
using System.IO;
using System.Formats.Nrbf.Utils;
+using System.Reflection.Metadata;
namespace System.Formats.Nrbf;
@@ -21,92 +22,100 @@ private SystemClassWithMembersAndTypesRecord(ClassInfo classInfo, MemberTypeInfo
public override SerializationRecordType RecordType => SerializationRecordType.SystemClassWithMembersAndTypes;
- internal static SystemClassWithMembersAndTypesRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options)
+ internal static SerializationRecord Decode(BinaryReader reader, RecordMap recordMap, PayloadOptions options)
{
ClassInfo classInfo = ClassInfo.Decode(reader);
MemberTypeInfo memberTypeInfo = MemberTypeInfo.Decode(reader, classInfo.MemberNames.Count, options, recordMap);
// the only difference with ClassWithMembersAndTypesRecord is that we don't read library id here
classInfo.LoadTypeName(options);
- return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo);
- }
+ TypeName typeName = classInfo.TypeName;
- internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType()
- => MemberTypeInfo.GetNextAllowedRecordType(MemberValues.Count);
+ // BinaryFormatter represents primitive types as MemberPrimitiveTypedRecord
+ // only for arrays of objects. For other arrays, like arrays of some abstraction
+ // (example: new IComparable[] { int.MaxValue }), it uses SystemClassWithMembersAndTypes.
+ // The same goes for root records that turn out to be primitive types.
+ // We want to have the behavior unified, so we map such records to
+ // PrimitiveTypeRecord so the users don't need to learn the BF internals
+ // to get a single primitive value.
+ // We need to be as strict as possible, as we don't want to map anything else by accident.
+ // That is why the code below is VERY defensive.
- // For the root records that turn out to be primitive types, we map them to
- // PrimitiveTypeRecord so the users don't need to learn the BF internals
- // to get a single primitive value!
- internal SerializationRecord TryToMapToUserFriendly()
- {
- if (!TypeName.IsSimple)
+ if (!classInfo.TypeName.IsSimple || classInfo.MemberNames.Count == 0 || memberTypeInfo.Infos[0].BinaryType != BinaryType.Primitive)
{
- return this;
+ return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo);
}
-
- if (MemberValues.Count == 1)
+ else if (classInfo.MemberNames.Count == 1)
{
- if (HasMember("m_value"))
+ PrimitiveType primitiveType = (PrimitiveType)memberTypeInfo.Infos[0].AdditionalInfo!;
+ // Get the member name without allocating on the heap.
+ Collections.Generic.Dictionary.Enumerator structEnumerator = classInfo.MemberNames.GetEnumerator();
+ _ = structEnumerator.MoveNext();
+ string memberName = structEnumerator.Current.Key;
+ // Everything needs to match: primitive type, type name name and member name.
+ return (primitiveType, typeName.FullName, memberName) switch
{
- return MemberValues[0] switch
- {
- // there can be a value match, but no TypeName match
- bool value when TypeNameMatches(typeof(bool)) => Create(value),
- byte value when TypeNameMatches(typeof(byte)) => Create(value),
- sbyte value when TypeNameMatches(typeof(sbyte)) => Create(value),
- char value when TypeNameMatches(typeof(char)) => Create(value),
- short value when TypeNameMatches(typeof(short)) => Create(value),
- ushort value when TypeNameMatches(typeof(ushort)) => Create(value),
- int value when TypeNameMatches(typeof(int)) => Create(value),
- uint value when TypeNameMatches(typeof(uint)) => Create(value),
- long value when TypeNameMatches(typeof(long)) => Create(value),
- ulong value when TypeNameMatches(typeof(ulong)) => Create(value),
- float value when TypeNameMatches(typeof(float)) => Create(value),
- double value when TypeNameMatches(typeof(double)) => Create(value),
- _ => this
- };
- }
- else if (HasMember("value"))
- {
- return MemberValues[0] switch
- {
- // there can be a value match, but no TypeName match
- long value when TypeNameMatches(typeof(IntPtr)) => Create(new IntPtr(value)),
- ulong value when TypeNameMatches(typeof(UIntPtr)) => Create(new UIntPtr(value)),
- _ => this
- };
- }
- else if (HasMember("_ticks") && GetRawValue("_ticks") is long ticks && TypeNameMatches(typeof(TimeSpan)))
- {
- return Create(new TimeSpan(ticks));
- }
+ (PrimitiveType.Boolean, "System.Boolean", "m_value") => Create(reader.ReadBoolean()),
+ (PrimitiveType.Byte, "System.Byte", "m_value") => Create(reader.ReadByte()),
+ (PrimitiveType.SByte, "System.SByte", "m_value") => Create(reader.ReadSByte()),
+ (PrimitiveType.Char, "System.Char", "m_value") => Create(reader.ParseChar()),
+ (PrimitiveType.Int16, "System.Int16", "m_value") => Create(reader.ReadInt16()),
+ (PrimitiveType.UInt16, "System.UInt16", "m_value") => Create(reader.ReadUInt16()),
+ (PrimitiveType.Int32, "System.Int32", "m_value") => Create(reader.ReadInt32()),
+ (PrimitiveType.UInt32, "System.UInt32", "m_value") => Create(reader.ReadUInt32()),
+ (PrimitiveType.Int64, "System.Int64", "m_value") => Create(reader.ReadInt64()),
+ (PrimitiveType.Int64, "System.IntPtr", "value") => Create(new IntPtr(reader.ReadInt64())),
+ (PrimitiveType.Int64, "System.TimeSpan", "_ticks") => Create(new TimeSpan(reader.ReadInt64())),
+ (PrimitiveType.UInt64, "System.UInt64", "m_value") => Create(reader.ReadUInt64()),
+ (PrimitiveType.UInt64, "System.UIntPtr", "value") => Create(new UIntPtr(reader.ReadUInt64())),
+ (PrimitiveType.Single, "System.Single", "m_value") => Create(reader.ReadSingle()),
+ (PrimitiveType.Double, "System.Double", "m_value") => Create(reader.ReadDouble()),
+ _ => new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo)
+ };
}
- else if (MemberValues.Count == 2
- && HasMember("ticks") && HasMember("dateData")
- && GetRawValue("ticks") is long && GetRawValue("dateData") is ulong dateData
- && TypeNameMatches(typeof(DateTime)))
+ else if (classInfo.MemberNames.Count == 2 && typeName.FullName == "System.DateTime"
+ && HasMember("ticks", 0, PrimitiveType.Int64)
+ && HasMember("dateData", 1, PrimitiveType.UInt64))
{
- return Create(Utils.BinaryReaderExtensions.CreateDateTimeFromData(dateData));
+ return DecodeDateTime(reader, classInfo.Id);
}
- else if (MemberValues.Count == 4
- && HasMember("lo") && HasMember("mid") && HasMember("hi") && HasMember("flags")
- && GetRawValue("lo") is int lo && GetRawValue("mid") is int mid
- && GetRawValue("hi") is int hi && GetRawValue("flags") is int flags
- && TypeNameMatches(typeof(decimal)))
+ else if (classInfo.MemberNames.Count == 4 && typeName.FullName == "System.Decimal"
+ && HasMember("flags", 0, PrimitiveType.Int32)
+ && HasMember("hi", 1, PrimitiveType.Int32)
+ && HasMember("lo", 2, PrimitiveType.Int32)
+ && HasMember("mid", 3, PrimitiveType.Int32))
{
- int[] bits =
- [
- lo,
- mid,
- hi,
- flags
- ];
-
- return Create(new decimal(bits));
+ return DecodeDecimal(reader, classInfo.Id);
}
- return this;
+ return new SystemClassWithMembersAndTypesRecord(classInfo, memberTypeInfo);
SerializationRecord Create(T value) where T : unmanaged
- => new MemberPrimitiveTypedRecord(value, Id);
+ => new MemberPrimitiveTypedRecord(value, classInfo.Id);
+
+ bool HasMember(string name, int order, PrimitiveType primitiveType)
+ => classInfo.MemberNames.TryGetValue(name, out int memberOrder)
+ && memberOrder == order
+ && ((PrimitiveType)memberTypeInfo.Infos[order].AdditionalInfo!) == primitiveType;
+ }
+
+ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetNextAllowedRecordType()
+ => MemberTypeInfo.GetNextAllowedRecordType(MemberValues.Count);
+
+ internal static MemberPrimitiveTypedRecord DecodeDateTime(BinaryReader reader, SerializationRecordId id)
+ {
+ _ = reader.ReadInt64(); // ticks are not used, but they need to be read as they go first in the payload
+ ulong dateData = reader.ReadUInt64();
+
+ return new MemberPrimitiveTypedRecord(BinaryReaderExtensions.CreateDateTimeFromData(dateData), id);
+ }
+
+ internal static MemberPrimitiveTypedRecord DecodeDecimal(BinaryReader reader, SerializationRecordId id)
+ {
+ int flags = reader.ReadInt32();
+ int hi = reader.ReadInt32();
+ int lo = reader.ReadInt32();
+ int mid = reader.ReadInt32();
+
+ return new MemberPrimitiveTypedRecord(new decimal([lo, mid, hi, flags]), id);
}
}
diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs
index d5baa09dbd8fc4..8bb3ac3a1107bd 100644
--- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs
+++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/BinaryReaderExtensions.cs
@@ -33,7 +33,7 @@ internal static BinaryArrayType ReadArrayType(this BinaryReader reader)
{
// To simplify the behavior and security review of the BinaryArrayRecord type, we
// do not support reading non-zero-offset arrays. If this should change in the
- // future, the BinaryArrayRecord.Decode method and supporting infrastructure
+ // future, the NrbfDecoder.DecodeBinaryArrayRecord method and supporting infrastructure
// will need re-review.
byte arrayType = reader.ReadByte();
diff --git a/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs
new file mode 100644
index 00000000000000..18e39a5fd68e1f
--- /dev/null
+++ b/src/libraries/System.Formats.Nrbf/tests/ArrayOfSerializationRecordsTests.cs
@@ -0,0 +1,516 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Runtime.Serialization;
+using Microsoft.DotNet.XUnitExtensions;
+using Xunit;
+
+namespace System.Formats.Nrbf.Tests
+{
+ public class ArrayOfSerializationRecordsTests : ReadTests
+ {
+ public enum ElementType
+ {
+ Object,
+ NonGeneric,
+ Generic
+ }
+
+ [Serializable]
+ public class CustomClassThatImplementsIEnumerable : IEnumerable
+ {
+ public int Field;
+
+ public IEnumerator GetEnumerator() => Array.Empty().GetEnumerator();
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsStringRecord_SZ(ElementType elementType)
+ {
+ const string Text = "hello";
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[] { Text },
+ ElementType.NonGeneric => new IEnumerable[] { Text },
+ ElementType.Generic => new IEnumerable[] { Text },
+ _ => throw new InvalidOperationException()
+ };
+
+ SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[] output = arrayRecord.GetArray();
+
+ Verify(input, arrayRecord, output, recordMap);
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output.Single();
+ Assert.Equal(Text, stringRecord.Value);
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsStringRecord_MD(ElementType elementType)
+ {
+ const string Text = "hello";
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1, 1],
+ ElementType.NonGeneric => new IEnumerable[1, 1],
+ ElementType.Generic => new IEnumerable[1, 1],
+ _ => throw new InvalidOperationException()
+ };
+ input.SetValue(Text, 0, 0);
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, 0];
+ Assert.Equal(Text, stringRecord.Value);
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsStringRecord_Jagged(ElementType elementType)
+ {
+ const string Text = "hello";
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1][] { [Text] },
+ ElementType.NonGeneric => new IEnumerable[1][] { [Text] },
+ ElementType.Generic => new IEnumerable[1][] { [Text] },
+ _ => throw new InvalidOperationException()
+ };
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+
+ SZArrayRecord contained = (SZArrayRecord)output.Single();
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)contained.GetArray().Single();
+ Assert.Equal(Text, stringRecord.Value);
+ }
+
+ [ConditionalTheory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_SZ(ElementType elementType)
+ {
+ if (elementType != ElementType.Object && !IsPatched)
+ {
+ throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix.");
+ }
+
+ const int Integer = 123;
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[] { Integer },
+ ElementType.NonGeneric => new IComparable[] { Integer },
+ ElementType.Generic => new IComparable[] { Integer },
+ _ => throw new InvalidOperationException()
+ };
+
+ SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[] output = arrayRecord.GetArray();
+
+ Verify(input, arrayRecord, output, recordMap);
+ PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)output.Single();
+ Assert.Equal(Integer, intRecord.Value);
+ }
+
+ [ConditionalTheory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_MD(ElementType elementType)
+ {
+ if (elementType != ElementType.Object && !IsPatched)
+ {
+ throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix.");
+ }
+
+ const int Integer = 123;
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1, 1],
+ ElementType.NonGeneric => new IComparable[1, 1],
+ ElementType.Generic => new IComparable[1, 1],
+ _ => throw new InvalidOperationException()
+ };
+ input.SetValue(Integer, 0, 0);
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)output[0, 0];
+ Assert.Equal(Integer, intRecord.Value);
+ }
+
+ [ConditionalTheory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsMemberPrimitiveTypedRecord_Jagged(ElementType elementType)
+ {
+ if (elementType != ElementType.Object && !IsPatched)
+ {
+ throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix.");
+ }
+
+ const int Integer = 123;
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1][] { [Integer] },
+ ElementType.NonGeneric => new IComparable[1][] { [Integer] },
+ ElementType.Generic => new IComparable[1][] { [Integer] },
+ _ => throw new InvalidOperationException()
+ };
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ SZArrayRecord contained = (SZArrayRecord)output.Single();
+ PrimitiveTypeRecord intRecord = (PrimitiveTypeRecord)contained.GetArray().Single();
+ Assert.Equal(Integer, intRecord.Value);
+ }
+
+ public static IEnumerable NullAndArrayPermutations()
+ {
+ foreach (ElementType elementType in Enum.GetValues(typeof(ElementType)))
+ {
+ yield return new object[] { elementType, 1 }; // ObjectNullRecord
+ yield return new object[] { elementType, 200 }; // ObjectNullMultiple256Record
+ yield return new object[] { elementType, 1_000 }; // ObjectNullMultipleRecord
+ }
+ }
+
+ [Theory]
+ [MemberData(nameof(NullAndArrayPermutations))]
+ public void CanReadArrayThatContainsNullRecords_SZ(ElementType elementType, int nullCount)
+ {
+ const string Text = "notNull";
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[nullCount + 1],
+ ElementType.NonGeneric => new IEnumerable[nullCount + 1],
+ ElementType.Generic => new IEnumerable[nullCount + 1],
+ _ => throw new InvalidOperationException()
+ };
+ input.SetValue(Text, nullCount);
+
+ SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord?[] output = arrayRecord.GetArray();
+
+ Verify(input, arrayRecord, output, recordMap);
+ Assert.All(output.Take(nullCount), Assert.Null);
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[nullCount];
+ Assert.Equal(Text, stringRecord.Value);
+ }
+
+ [Theory]
+ [MemberData(nameof(NullAndArrayPermutations))]
+ public void CanReadArrayThatContainsNullRecords_MD(ElementType elementType, int nullCount)
+ {
+ const string Text = "notNull";
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1, nullCount + 1],
+ ElementType.NonGeneric => new IEnumerable[1, nullCount + 1],
+ ElementType.Generic => new IEnumerable[1, nullCount + 1],
+ _ => throw new InvalidOperationException()
+ };
+ input.SetValue(Text, 0, nullCount);
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ for (int i = 0; i < nullCount; i++)
+ {
+ Assert.Null(output[0, i]);
+ }
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, nullCount];
+ Assert.Equal(Text, stringRecord.Value);
+ }
+
+ [Theory]
+ [MemberData(nameof(NullAndArrayPermutations))]
+ public void CanReadArrayThatContainsNullRecords_Jagged(ElementType elementType, int nullCount)
+ {
+ const string Text = "notNull";
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1][] { new object[nullCount + 1] },
+ ElementType.NonGeneric => new IEnumerable[1][] { new IEnumerable[nullCount + 1] },
+ ElementType.Generic => new IEnumerable[1][] { new IEnumerable[nullCount + 1] },
+ _ => throw new InvalidOperationException()
+ };
+ ((Array)input.GetValue(0)).SetValue(Text, nullCount);
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ SZArrayRecord contained = (SZArrayRecord)output.Single();
+ Assert.All(contained.GetArray().Take(nullCount), Assert.Null);
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)contained.GetArray()[nullCount];
+ Assert.Equal(Text, stringRecord.Value);
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsArrayRecord_SZ(ElementType elementType)
+ {
+ int[] intArray = [1, 2, 3];
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[] { intArray },
+ ElementType.NonGeneric => new IEnumerable[] { intArray },
+ ElementType.Generic => new IEnumerable[] { intArray },
+ _ => throw new InvalidOperationException()
+ };
+
+ SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[] output = arrayRecord.GetArray();
+
+ Verify(input, arrayRecord, output, recordMap);
+ SZArrayRecord intArrayRecord = (SZArrayRecord)output.Single();
+ Assert.Equal(intArray, intArrayRecord.GetArray());
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsArrayRecord_MD(ElementType elementType)
+ {
+ int[] intArray = [1, 2, 3];
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1, 1],
+ ElementType.NonGeneric => new IEnumerable[1, 1],
+ ElementType.Generic => new IEnumerable[1, 1],
+ _ => throw new InvalidOperationException()
+ };
+ input.SetValue(intArray, 0, 0);
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ SZArrayRecord intArrayRecord = (SZArrayRecord)output[0, 0];
+ Assert.Equal(intArray, intArrayRecord.GetArray());
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ [InlineData(ElementType.Generic)]
+ public void CanReadArrayThatContainsArrayRecord_Jagged(ElementType elementType)
+ {
+ int[] intArray = [1, 2, 3];
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1][] { [intArray] },
+ ElementType.NonGeneric => new IEnumerable[1][] { [intArray] },
+ ElementType.Generic => new IEnumerable[1][] { [intArray] },
+ _ => throw new InvalidOperationException()
+ };
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ SZArrayRecord contained = (SZArrayRecord)output.Single();
+ SZArrayRecord intArrayRecord = (SZArrayRecord)contained.GetArray().Single();
+ Assert.Equal(intArray, intArrayRecord.GetArray());
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ public void CanReadArrayThatContainsAllRecordTypes_SZ(ElementType elementType)
+ {
+ const string Text = "hello";
+ int[] intArray = [1, 2, 3];
+ CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 };
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[]
+ {
+ Text, // BinaryObjectStringRecord
+ intArray, // ArraySinglePrimitiveRecord
+ classThatImplementsIEnumerable, // ClassWithMembersAndTypesRecord,
+ null // ObjectNullRecord
+ },
+ ElementType.NonGeneric => new IEnumerable[] { Text, intArray, classThatImplementsIEnumerable, null },
+ _ => throw new InvalidOperationException()
+ };
+
+ SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[] output = arrayRecord.GetArray();
+
+ Verify(input, arrayRecord, output, recordMap);
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0];
+ Assert.Equal(Text, stringRecord.Value);
+ SZArrayRecord intArrayRecord = (SZArrayRecord)output[1];
+ Assert.Equal(intArray, intArrayRecord.GetArray());
+ ClassRecord classRecord = (ClassRecord)output[2];
+ Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field)));
+ Assert.Null(output[3]);
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ public void CanReadArrayThatContainsAllRecordTypes_MD(ElementType elementType)
+ {
+ const string Text = "hello";
+ int[] intArray = [1, 2, 3];
+ CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 };
+
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1, 4],
+ ElementType.NonGeneric => new IEnumerable[1, 4],
+ _ => throw new InvalidOperationException()
+ };
+ input.SetValue(Text, 0, 0);
+ input.SetValue(intArray, 0, 1);
+ input.SetValue(classThatImplementsIEnumerable, 0, 2);
+ input.SetValue(null, 0, 3);
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ SerializationRecord[,] output = (SerializationRecord[,])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)output[0, 0];
+ Assert.Equal(Text, stringRecord.Value);
+ SZArrayRecord intArrayRecord = (SZArrayRecord)output[0, 1];
+ Assert.Equal(intArray, intArrayRecord.GetArray());
+ ClassRecord classRecord = (ClassRecord)output[0, 2];
+ Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field)));
+ Assert.Null(output[0, 3]);
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ public void CanReadArrayThatContainsAllRecordTypes_Jagged(ElementType elementType)
+ {
+ const string Text = "hello";
+ int[] intArray = [1, 2, 3];
+ CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 };
+
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1][] { [Text, intArray, classThatImplementsIEnumerable, null] },
+ ElementType.NonGeneric => new IEnumerable[1][] { [Text, intArray, classThatImplementsIEnumerable, null] },
+ _ => throw new InvalidOperationException()
+ };
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ SZArrayRecord contained = (SZArrayRecord)output.Single();
+ SerializationRecord[] records = contained.GetArray();
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)records[0];
+ Assert.Equal(Text, stringRecord.Value);
+ SZArrayRecord intArrayRecord = (SZArrayRecord)records[1];
+ Assert.Equal(intArray, intArrayRecord.GetArray());
+ ClassRecord classRecord = (ClassRecord)records[2];
+ Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field)));
+ Assert.Null(records[3]);
+ }
+
+ [Theory]
+ [InlineData(ElementType.Object)]
+ [InlineData(ElementType.NonGeneric)]
+ public void CanReadArrayThatContainsAllRecordTypes_Jagged_MD(ElementType elementType)
+ {
+ const string Text = "hello";
+ int[] intArray = [1, 2, 3];
+ CustomClassThatImplementsIEnumerable classThatImplementsIEnumerable = new() { Field = 456 };
+
+ Array input = elementType switch
+ {
+ ElementType.Object => new object[1, 1][,],
+ ElementType.NonGeneric => new IEnumerable[1, 1][,],
+ _ => throw new InvalidOperationException()
+ };
+ Array contained = elementType switch
+ {
+ ElementType.Object => new object[2, 2],
+ ElementType.NonGeneric => new IEnumerable[2, 2],
+ _ => throw new InvalidOperationException()
+ };
+ contained.SetValue(Text, 0, 0);
+ contained.SetValue(intArray, 0, 1);
+ contained.SetValue(classThatImplementsIEnumerable, 1, 0);
+ input.SetValue(contained, 0, 0);
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input), out var recordMap);
+ ArrayRecord[,] output = (ArrayRecord[,])arrayRecord.GetArray(input.GetType());
+
+ Verify(input, arrayRecord, output, recordMap);
+ SerializationRecord[,] records = (SerializationRecord[,])output[0, 0].GetArray(contained.GetType());
+ PrimitiveTypeRecord stringRecord = (PrimitiveTypeRecord)records[0, 0];
+ Assert.Equal(Text, stringRecord.Value);
+ SZArrayRecord intArrayRecord = (SZArrayRecord)records[0, 1];
+ Assert.Equal(intArray, intArrayRecord.GetArray());
+ ClassRecord classRecord = (ClassRecord)records[1, 0];
+ Assert.Equal(classThatImplementsIEnumerable.Field, classRecord.GetInt32(nameof(CustomClassThatImplementsIEnumerable.Field)));
+ Assert.Null(records[1, 1]);
+ }
+
+ [Fact]
+ public void TypeMismatch()
+ {
+ // An array of strings that contains non-string.
+ byte[] bytes = Convert.FromBase64String("AAEAAAD/////AQAAAAAAAAAHAQAAAAICAAAAAQAAAAEAAAABCQEAAAAL");
+
+ ArrayRecord arrRecord = (ArrayRecord)NrbfDecoder.Decode(new MemoryStream(bytes));
+
+ Assert.Throws(() => arrRecord.GetArray(typeof(string[,])));
+ }
+
+ private static void Verify(Array input, ArrayRecord arrayRecord, Array output,
+ IReadOnlyDictionary recordMap)
+ {
+ Assert.Equal(input.Rank, arrayRecord.Rank);
+ Assert.Equal(input.Rank, output.Rank);
+
+ for (int i = 0; i < input.Rank; i++)
+ {
+ Assert.Equal(input.GetLength(i), arrayRecord.Lengths[i]);
+ Assert.Equal(input.GetLength(i), output.GetLength(i));
+ }
+
+ foreach (object? recordOrNull in output)
+ {
+ if (recordOrNull is SerializationRecord record && !record.Id.Equals(default))
+ {
+ // An array of abstractions always uses SystemClassWithMembersAndTypesRecord to represent primitive values.
+ // This requires some non-trivial mapping and we need to ensure that it's reflected not only in what
+ // has been stored in the array, but also in the record map.
+ Assert.Same(record, recordMap[record.Id]);
+ }
+ }
+ }
+ }
+}
diff --git a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs
index 49d523088a89fe..7ef801808e4e95 100644
--- a/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs
+++ b/src/libraries/System.Formats.Nrbf/tests/ArraySinglePrimitiveRecordTests.cs
@@ -5,6 +5,7 @@
using System.IO;
using System.Runtime.Serialization;
using System.Text;
+using Microsoft.DotNet.XUnitExtensions;
using Xunit;
namespace System.Formats.Nrbf.Tests;
@@ -71,63 +72,63 @@ public void DontCastBytesToDateTimes()
Assert.Throws(() => NrbfDecoder.Decode(stream));
}
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Bool(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Byte(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_SByte(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Char(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Int16(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_UInt16(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Int32(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_UInt32(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Int64(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_UInt64(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Single(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_Double(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_TimeSpan(int size, bool canSeek) => Test(size, canSeek);
- [Theory]
+ [ConditionalTheory]
[MemberData(nameof(GetCanReadArrayOfAnySizeArgs))]
public void CanReadArrayOfAnySize_DateTime(int size, bool canSeek) => Test(size, canSeek);
- private void Test(int size, bool canSeek)
+ private void Test(int size, bool canSeek) where T : IComparable
{
Random constSeed = new Random(27644437);
T[] input = new T[size];
@@ -136,17 +137,69 @@ private void Test(int size, bool canSeek)
input[i] = GenerateValue(constSeed);
}
+ TestSZArrayOfT(input, size, canSeek);
+ TestSZArrayOfIComparable(input, size, canSeek);
+ }
+
+ private void TestSZArrayOfT(T[] input, int size, bool canSeek)
+ {
MemoryStream stream = Serialize(input);
stream = canSeek ? stream : new NonSeekableStream(stream.ToArray());
SZArrayRecord arrayRecord = (SZArrayRecord)NrbfDecoder.Decode(stream);
Assert.Equal(size, arrayRecord.Length);
- Assert.Equal(size, arrayRecord.FlattenedLength);
T?[] output = arrayRecord.GetArray();
Assert.Equal(input, output);
Assert.Same(output, arrayRecord.GetArray());
}
+ private void TestSZArrayOfIComparable(T[] input, int size, bool canSeek) where T : IComparable
+ {
+ if (!IsPatched)
+ {
+ throw new SkipTestException("Current machine has not been patched with the most recent BinaryFormatter fix.");
+ }
+
+ // Arrays of abstractions that store primitive values (example: new IComparable[1] { int.MaxValue })
+ // are represented by BinaryFormatter with a single SystemClassWithMembersAndTypesRecord
+ // and multiple ClassWithIdRecord that re-use the information from the system record.
+ // This requires some non-trivial mapping and this test is very important as it covers that code path.
+ IComparable[] comparables = new IComparable[size];
+ for (int i = 0; i < input.Length; i++)
+ {
+ comparables[i] = input[i];
+ }
+
+ TestArrayOfSerializationRecords(input, comparables, canSeek);
+ }
+
+ private void TestSZArrayOfObjects(T[] input, int size, bool canSeek)
+ {
+ // Arrays of objects that store primitive values (example: new object[1] { int.MaxValue })
+ // are represented by BinaryFormatter with MemberPrimitiveTypedRecord instances.
+ object[] objects = new object[size];
+ for (int i = 0; i < input.Length; i++)
+ {
+ objects[i] = input[i];
+ }
+
+ TestArrayOfSerializationRecords(input, objects, canSeek);
+ }
+
+ private void TestArrayOfSerializationRecords(T[] values, object input, bool canSeek)
+ {
+ MemoryStream stream = Serialize(input);
+
+ stream = canSeek ? stream : new NonSeekableStream(stream.ToArray());
+ SZArrayRecord arrayRecordOfPrimitiveRecords = (SZArrayRecord)NrbfDecoder.Decode(stream);
+ SerializationRecord[] arrayOfPrimitiveRecords = arrayRecordOfPrimitiveRecords.GetArray();
+ for (int i = 0; i < values.Length; i++)
+ {
+ Assert.Equal(values[i], ((PrimitiveTypeRecord)arrayOfPrimitiveRecords[i]).Value);
+ Assert.Equal(values[i], ((PrimitiveTypeRecord)arrayOfPrimitiveRecords[i]).Value);
+ }
+ }
+
private static T GenerateValue(Random random)
{
if (typeof(T) == typeof(byte))
diff --git a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs
index fe780d94698df0..3a81e3f131c823 100644
--- a/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs
+++ b/src/libraries/System.Formats.Nrbf/tests/AttackTests.cs
@@ -50,18 +50,51 @@ public void CyclicReferencesInSystemClassesDoNotCauseStackOverflow()
}
[Fact]
- public void CyclicReferencesInArraysOfObjectsDoNotCauseStackOverflow()
+ public void CyclicReferencesInSZArraysOfObjectsDoNotCauseStackOverflow()
{
object[] input = new object[2];
input[0] = "not an array";
input[1] = input;
ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
- object?[] output = ((SZArrayRecord)arrayRecord).GetArray();
+ SerializationRecord?[] output = ((SZArrayRecord)arrayRecord).GetArray();
- Assert.Equal(input[0], output[0]);
+ Assert.Equal(input[0], ((PrimitiveTypeRecord)output[0]).Value);
Assert.Same(input, input[1]);
- Assert.Same(output, output[1]);
+ Assert.Same(arrayRecord, output[1]);
+ }
+
+ [Fact]
+ public void CyclicReferencesInMDArraysOfObjectsDoNotCauseStackOverflow()
+ {
+ object[,] input = new object[2, 2];
+ input[0, 0] = "not an array";
+ input[1, 1] = input;
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
+ SerializationRecord?[,] output = (SerializationRecord?[,])arrayRecord.GetArray(typeof(object[,]));
+
+ Assert.Equal(input[0, 0], ((PrimitiveTypeRecord)output[0, 0]).Value);
+ Assert.Same(input, input[1, 1]);
+ Assert.Same(arrayRecord, output[1, 1]);
+ }
+
+ [Fact]
+ public void CyclicReferencesInJaggedArraysOfObjectsDoNotCauseStackOverflow()
+ {
+ object[][] input = new object[1][];
+ input[0] = new object[2];
+ input[0][0] = "not an array";
+ input[0][1] = input;
+
+ ArrayRecord arrayRecord = (ArrayRecord)NrbfDecoder.Decode(Serialize(input));
+ ArrayRecord[] output = (ArrayRecord[])arrayRecord.GetArray(typeof(object[][]));
+ SZArrayRecord row = (SZArrayRecord)output.Single();
+ SerializationRecord[] contained = row.GetArray();
+
+ Assert.Equal(input[0][0], ((PrimitiveTypeRecord)contained[0]).Value);
+ Assert.Same(input, input[0][1]);
+ Assert.Same(arrayRecord, contained[1]);
}
[Serializable]
@@ -81,8 +114,8 @@ public void CyclicClassReferencesInArraysOfObjectsDoNotCauseStackOverflow()
ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input));
Assert.Equal(input.Name, classRecord.GetString(nameof(WithCyclicReferenceInArrayOfObjects.Name)));
- SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfObjects.ArrayWithReferenceToSelf))!;
- object?[] array = arrayRecord.GetArray();
+ SZArrayRecord arrayRecord = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfObjects.ArrayWithReferenceToSelf))!;
+ SerializationRecord?[] array = arrayRecord.GetArray();
Assert.Same(classRecord, array.Single());
}
@@ -103,7 +136,7 @@ public void CyclicClassReferencesInArraysOfTDoNotCauseStackOverflow()
ClassRecord classRecord = NrbfDecoder.DecodeClassRecord(Serialize(input));
Assert.Equal(input.Name, classRecord.GetString(nameof(WithCyclicReferenceInArrayOfT.Name)));
- SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfT.ArrayWithReferenceToSelf))!;
+ SZArrayRecord classRecords = (SZArrayRecord)classRecord.GetSerializationRecord(nameof(WithCyclicReferenceInArrayOfT.ArrayWithReferenceToSelf))!;
Assert.Same(classRecord, classRecords.GetArray().Single());
}
diff --git a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs
index 6acb44d03697d2..2d78954d649094 100644
--- a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs
+++ b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs
@@ -355,6 +355,36 @@ public void ThrowsForInvalidPositiveArrayRank(int rank, byte arrayType)
Assert.Throws