diff --git a/src/EFCore.Cosmos/Extensions/CosmosDatabaseFacadeExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosDatabaseFacadeExtensions.cs index f5c5d179b0b..fcdbf80fca5 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosDatabaseFacadeExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosDatabaseFacadeExtensions.cs @@ -25,6 +25,79 @@ public static class CosmosDatabaseFacadeExtensions public static CosmosClient GetCosmosClient(this DatabaseFacade databaseFacade) => GetService(databaseFacade).Client; + /// + /// Gets the composite session token for the default container for this . + /// + /// Use this when using only 1 container in the same . + /// The for the context. + /// The session token for the default container in the context, or if none present. + public static string? GetSessionToken(this DatabaseFacade databaseFacade) + => GetSessionTokenStorage(databaseFacade).GetDefaultContainerTrackedToken(); + + /// + /// Gets a dictionary that contains the composite session token per container for this . + /// + /// Use this when using multiple containers in the same . + /// The for the context. + /// The session token dictionary. + public static IReadOnlyDictionary GetSessionTokens(this DatabaseFacade databaseFacade) + => GetSessionTokenStorage(databaseFacade).GetTrackedTokens(); + + /// + /// Sets the composite session token for the default container for this . + /// + /// Use this when using only 1 container in the same . + /// The for the context. + /// The session token to set. + public static void UseSessionToken(this DatabaseFacade databaseFacade, string sessionToken) + => GetSessionTokenStorage(databaseFacade).SetDefaultContainerSessionToken(sessionToken); + + /// + /// Appends the composite session token for the default container for this . + /// + /// Use this when using only 1 container in the same . + /// The for the context. + /// The session token to append. + public static void AppendSessionToken(this DatabaseFacade databaseFacade, string sessionToken) + => GetSessionTokenStorage(databaseFacade).AppendDefaultContainerSessionToken(sessionToken); + + /// + /// Sets the composite sessions token per container for this with the tokens specified in . + /// + /// Use this when using multiple containers in the same . + /// The for the context. + /// The session tokens to set per container. + public static void UseSessionTokens(this DatabaseFacade databaseFacade, IReadOnlyDictionary sessionTokens) + { + var sessionTokenStorage = GetSessionTokenStorage(databaseFacade, sessionTokens); + + sessionTokenStorage.SetSessionTokens(sessionTokens); + } + + /// + /// Appends the composite sessions token per container for this with the tokens specified in . + /// + /// Use this when using multiple containers in the same . + /// The for the context. + /// The session tokens to append per container. + public static void AppendSessionTokens(this DatabaseFacade databaseFacade, IReadOnlyDictionary sessionTokens) + { + var sessionTokenStorage = GetSessionTokenStorage(databaseFacade, (IReadOnlyDictionary)sessionTokens); + + sessionTokenStorage.AppendSessionTokens(sessionTokens); + } + + private static ISessionTokenStorage GetSessionTokenStorage(DatabaseFacade databaseFacade, IReadOnlyDictionary? sessionTokens = null) + { + var db = GetService(databaseFacade); + if (db is not CosmosDatabaseWrapper dbWrapper) + { + throw new InvalidOperationException(CosmosStrings.CosmosNotInUse); + } + + return dbWrapper.SessionTokenStorage; + } + private static TService GetService(IInfrastructure databaseFacade) where TService : class { diff --git a/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs index cfff8a16aa9..09018b0f2a5 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosServiceCollectionExtensions.cs @@ -97,6 +97,7 @@ public static IServiceCollection AddEntityFrameworkCosmos(this IServiceCollectio .TryAdd() .TryAdd>() .TryAdd() + .TryAdd(sp => (CosmosDatabaseWrapper)sp.GetRequiredService()) .TryAdd() .TryAdd() .TryAdd() @@ -121,7 +122,8 @@ public static IServiceCollection AddEntityFrameworkCosmos(this IServiceCollectio .TryAddScoped() .TryAddScoped() .TryAddScoped() - .TryAddScoped()); + .TryAddScoped() + .TryAddSingleton()); builder.TryAddCoreServices(); diff --git a/src/EFCore.Cosmos/Infrastructure/CosmosDbContextOptionsBuilder.cs b/src/EFCore.Cosmos/Infrastructure/CosmosDbContextOptionsBuilder.cs index e88fabf78da..605c4551034 100644 --- a/src/EFCore.Cosmos/Infrastructure/CosmosDbContextOptionsBuilder.cs +++ b/src/EFCore.Cosmos/Infrastructure/CosmosDbContextOptionsBuilder.cs @@ -3,6 +3,7 @@ using System.ComponentModel; using System.Net; +using Microsoft.EntityFrameworkCore.Cosmos.Infrastructure; using Microsoft.EntityFrameworkCore.Cosmos.Infrastructure.Internal; namespace Microsoft.EntityFrameworkCore.Infrastructure; @@ -211,6 +212,23 @@ public virtual CosmosDbContextOptionsBuilder MaxRequestsPerTcpConnection(int req public virtual CosmosDbContextOptionsBuilder ContentResponseOnWriteEnabled(bool enabled = true) => WithOption(e => e.ContentResponseOnWriteEnabled(Check.NotNull(enabled))); + + /// + /// Sets the to use. + /// By default, will be used. + /// Any other mode is only relevant when your application needs to manage session tokens manually. + /// For example: If you're using a round-robin load balancer that doesn't maintain session affinity between requests. + /// Manual session token management can break session consistency when not handled properly. + /// See Utilize session tokens for more details. + /// + /// + /// See Using DbContextOptions, and + /// Accessing Azure Cosmos DB with EF Core for more information and examples. + /// + /// The to use. + public virtual CosmosDbContextOptionsBuilder SessionTokenManagementMode(SessionTokenManagementMode mode) + => WithOption(e => e.WithSessionTokenManagementMode(mode)); + /// /// Sets an option by cloning the extension used to store the settings. This ensures the builder /// does not modify options that are already in use elsewhere. diff --git a/src/EFCore.Cosmos/Infrastructure/Internal/CosmosDbOptionExtension.cs b/src/EFCore.Cosmos/Infrastructure/Internal/CosmosDbOptionExtension.cs index f2545174ac3..6758b48ce0a 100644 --- a/src/EFCore.Cosmos/Infrastructure/Internal/CosmosDbOptionExtension.cs +++ b/src/EFCore.Cosmos/Infrastructure/Internal/CosmosDbOptionExtension.cs @@ -36,6 +36,7 @@ public class CosmosOptionsExtension : IDbContextOptionsExtension private bool? _enableContentResponseOnWrite; private DbContextOptionsExtensionInfo? _info; private Func? _httpClientFactory; + private SessionTokenManagementMode _sessionTokenManagementMode = SessionTokenManagementMode.FullyAutomatic; /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -73,6 +74,7 @@ protected CosmosOptionsExtension(CosmosOptionsExtension copyFrom) _maxTcpConnectionsPerEndpoint = copyFrom._maxTcpConnectionsPerEndpoint; _maxRequestsPerTcpConnection = copyFrom._maxRequestsPerTcpConnection; _httpClientFactory = copyFrom._httpClientFactory; + _sessionTokenManagementMode = copyFrom._sessionTokenManagementMode; } /// @@ -564,6 +566,30 @@ public virtual CosmosOptionsExtension WithHttpClientFactory(Func? ht return clone; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SessionTokenManagementMode SessionTokenManagementMode + => _sessionTokenManagementMode; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual CosmosOptionsExtension WithSessionTokenManagementMode(SessionTokenManagementMode mode) + { + var clone = Clone(); + + clone._sessionTokenManagementMode = mode; + + return clone; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -632,6 +658,7 @@ public override int GetServiceProviderHashCode() hashCode.Add(Extension._maxTcpConnectionsPerEndpoint); hashCode.Add(Extension._maxRequestsPerTcpConnection); hashCode.Add(Extension._httpClientFactory); + hashCode.Add(Extension._sessionTokenManagementMode); _serviceProviderHash = hashCode.ToHashCode(); } @@ -656,7 +683,8 @@ public override bool ShouldUseSameServiceProvider(DbContextOptionsExtensionInfo && Extension._gatewayModeMaxConnectionLimit == otherInfo.Extension._gatewayModeMaxConnectionLimit && Extension._maxTcpConnectionsPerEndpoint == otherInfo.Extension._maxTcpConnectionsPerEndpoint && Extension._maxRequestsPerTcpConnection == otherInfo.Extension._maxRequestsPerTcpConnection - && Extension._httpClientFactory == otherInfo.Extension._httpClientFactory; + && Extension._httpClientFactory == otherInfo.Extension._httpClientFactory + && Extension._sessionTokenManagementMode == otherInfo.Extension._sessionTokenManagementMode; public override void PopulateDebugInfo(IDictionary debugInfo) { diff --git a/src/EFCore.Cosmos/Infrastructure/Internal/CosmosSingletonOptions.cs b/src/EFCore.Cosmos/Infrastructure/Internal/CosmosSingletonOptions.cs index af29229cfa5..42e5ea687a9 100644 --- a/src/EFCore.Cosmos/Infrastructure/Internal/CosmosSingletonOptions.cs +++ b/src/EFCore.Cosmos/Infrastructure/Internal/CosmosSingletonOptions.cs @@ -151,6 +151,14 @@ public class CosmosSingletonOptions : ICosmosSingletonOptions /// public virtual Func? HttpClientFactory { get; private set; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SessionTokenManagementMode SessionTokenManagementMode { get; private set; } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -178,6 +186,7 @@ public virtual void Initialize(IDbContextOptions options) MaxTcpConnectionsPerEndpoint = cosmosOptions.MaxTcpConnectionsPerEndpoint; MaxRequestsPerTcpConnection = cosmosOptions.MaxRequestsPerTcpConnection; HttpClientFactory = cosmosOptions.HttpClientFactory; + SessionTokenManagementMode = cosmosOptions.SessionTokenManagementMode; } } @@ -208,6 +217,7 @@ public virtual void Validate(IDbContextOptions options) || MaxTcpConnectionsPerEndpoint != cosmosOptions.MaxTcpConnectionsPerEndpoint || MaxRequestsPerTcpConnection != cosmosOptions.MaxRequestsPerTcpConnection || HttpClientFactory != cosmosOptions.HttpClientFactory + || SessionTokenManagementMode != cosmosOptions.SessionTokenManagementMode )) { throw new InvalidOperationException( diff --git a/src/EFCore.Cosmos/Infrastructure/Internal/ICosmosSingletonOptions.cs b/src/EFCore.Cosmos/Infrastructure/Internal/ICosmosSingletonOptions.cs index a26b79a82b3..cbdbf8c1d79 100644 --- a/src/EFCore.Cosmos/Infrastructure/Internal/ICosmosSingletonOptions.cs +++ b/src/EFCore.Cosmos/Infrastructure/Internal/ICosmosSingletonOptions.cs @@ -155,4 +155,12 @@ public interface ICosmosSingletonOptions : ISingletonOptions /// doing so can result in application failures when updating to a new Entity Framework Core release. /// Func? HttpClientFactory { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + SessionTokenManagementMode SessionTokenManagementMode { get; } } diff --git a/src/EFCore.Cosmos/Infrastructure/SessionTokenManagementMode.cs b/src/EFCore.Cosmos/Infrastructure/SessionTokenManagementMode.cs new file mode 100644 index 00000000000..f023127b05f --- /dev/null +++ b/src/EFCore.Cosmos/Infrastructure/SessionTokenManagementMode.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Cosmos.Infrastructure; + +/// +/// Defines the behaviour of EF regarding the management of Cosmos DB session tokens. +/// +/// +/// See Consistency level choices for more info. +/// +public enum SessionTokenManagementMode +{ + /// + /// The default mode. + /// Uses the underlying Cosmos DB SDK automatic session token management. + /// EF will not track or parse session tokens returned from Cosmos DB. and methods will throw when invoked. + /// Use this mode when every request for the same user will land on the same instance of your app. + /// This means you either have 1 application instance, or maintain session affinity between requests. + /// Otherwhise, use of one of the other modes is required to guarantee session consistency between requests. + /// + FullyAutomatic, + + /// + /// Allows the usage of to overwrite the default Cosmos DB SDK automatic session token management by use of the method on a instance. + /// If has not been invoked for an container, the default Cosmos DB SDK automatic session token management will be used. + /// EF will track and parse session tokens returned from Cosmos DB, which can be retrieved via . + /// + SemiAutomatic, + + /// + /// Fully overwrites the Cosmos DB SDK automatic session token management, and only uses session tokens specified via . + /// If has not been invoked for an container, no session token will be used. + /// EF will track and parse session tokens returned from Cosmos DB, which can be retrieved via . + /// + Manual, + + /// + /// Same as , but will throw an exception if was not invoked before executong a read. + /// + EnforcedManual +} diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs index 6b42a75fb11..eead7f1c210 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs @@ -97,6 +97,14 @@ public static string ContainerContainingPropertyConflict(object? entityType, obj GetString("ContainerContainingPropertyConflict", nameof(entityType), nameof(container), nameof(property)), entityType, container, property); + /// + /// The container with the name '{containerName}' does not exist. + /// + public static string ContainerNameDoesNotExist(object? containerName) + => string.Format( + GetString("ContainerNameDoesNotExist", nameof(containerName)), + containerName); + /// /// An Azure Cosmos DB container name is defined on entity type '{entityType}', which inherits from '{baseEntityType}'. Container names must be defined on the root entity type of a hierarchy. /// @@ -157,6 +165,12 @@ public static string ElementWithValueConverter(object? propertyType, object? str GetString("ElementWithValueConverter", nameof(propertyType), nameof(structuralType), nameof(property), nameof(elementType)), propertyType, structuralType, property, elementType); + /// + /// Enable manual session token management using 'options.SessionTokenManagementMode' to use this method. + /// + public static string EnableManualSessionTokenManagement + => GetString("EnableManualSessionTokenManagement"); + /// /// The type of the etag property '{property}' on '{entityType}' is '{propertyType}'. All etag properties must be strings or have a string value converter. /// @@ -261,6 +275,14 @@ public static string LimitOffsetNotSupportedInSubqueries public static string MissingOrderingInSelectExpression => GetString("MissingOrderingInSelectExpression"); + /// + /// No session token has been set for container: {container}. While using EnforceManual you must always set a session token for any container used. + /// + public static string MissingSessionTokenEnforceManual(object? container) + => string.Format( + GetString("MissingSessionTokenEnforceManual", nameof(container)), + container); + /// /// Root entity type '{entityType1}' is referenced by the query, but '{entityType2}' is already being referenced. A query can only reference a single root entity type. /// diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.resx b/src/EFCore.Cosmos/Properties/CosmosStrings.resx index bb443f30aa5..38b46cfe53d 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.resx +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.resx @@ -147,6 +147,10 @@ The entity type '{entityType}' is mapped to the container '{container}' but it is also configured as being contained in property '{property}'. + + The container with the name '{containerName}' does not exist. + string + An Azure Cosmos DB container name is defined on entity type '{entityType}', which inherits from '{baseEntityType}'. Container names must be defined on the root entity type of a hierarchy. @@ -171,6 +175,9 @@ The property '{propertyType} {structuralType}.{property}' has element type '{elementType}', which requires a value converter. Elements types requiring value converters are not currently supported with the Azure Cosmos DB database provider. + + Enable manual session token management using 'options.SessionTokenManagementMode' to use this method. + The type of the etag property '{property}' on '{entityType}' is '{propertyType}'. All etag properties must be strings or have a string value converter. @@ -253,6 +260,10 @@ 'Reverse' could not be translated to the server because there is no ordering on the server side. + + No session token has been set for container: {container}. While using EnforceManual you must always set a session token for any container used. + string + Root entity type '{entityType1}' is referenced by the query, but '{entityType2}' is already being referenced. A query can only reference a single root entity type. diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs index 09f0723b959..0249b3ebdc2 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryContext.cs @@ -13,7 +13,8 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; /// public class CosmosQueryContext( QueryContextDependencies dependencies, - ICosmosClientWrapper cosmosClient) + ICosmosClientWrapper cosmosClient, + ISessionTokenStorage sessionTokenStorage) : QueryContext(dependencies) { /// @@ -23,4 +24,12 @@ public class CosmosQueryContext( /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual ICosmosClientWrapper CosmosClient { get; } = cosmosClient; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual ISessionTokenStorage SessionTokenStorage { get; } = sessionTokenStorage; } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryContextFactory.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryContextFactory.cs index 12c83848a98..245a89a1700 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryContextFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryContextFactory.cs @@ -20,6 +20,14 @@ public class CosmosQueryContextFactory( /// protected virtual QueryContextDependencies Dependencies { get; } = dependencies; + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected virtual ISessionTokenStorage SessionTokenStorage { get; } = ((CosmosDatabaseWrapper)dependencies.CurrentContext.Context.GetService()).SessionTokenStorage; + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -27,5 +35,5 @@ public class CosmosQueryContextFactory( /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual QueryContext Create() - => new CosmosQueryContext(Dependencies, cosmosClient); + => new CosmosQueryContext(Dependencies, cosmosClient, SessionTokenStorage); } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.PagingQueryingEnumerable.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.PagingQueryingEnumerable.cs index 2766b35a3c9..5687867d455 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.PagingQueryingEnumerable.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.PagingQueryingEnumerable.cs @@ -62,8 +62,8 @@ public PagingQueryingEnumerable( _threadSafetyChecksEnabled = threadSafetyChecksEnabled; _maxItemCountParameterName = maxItemCountParameterName; _continuationTokenParameterName = continuationTokenParameterName; - _responseContinuationTokenLimitInKbParameterName = responseContinuationTokenLimitInKbParameterName; + _responseContinuationTokenLimitInKbParameterName = responseContinuationTokenLimitInKbParameterName; _cosmosContainer = rootEntityType.GetContainer() ?? throw new UnreachableException("Root entity type without a Cosmos container."); _cosmosPartitionKey = GeneratePartitionKey( @@ -155,6 +155,8 @@ public async ValueTask MoveNextAsync() queryRequestOptions.PartitionKey = _cosmosPartitionKey; } + queryRequestOptions.SessionToken = _cosmosQueryContext.SessionTokenStorage.GetSessionToken(_cosmosContainer); + var cosmosClient = _cosmosQueryContext.CosmosClient; _commandLogger.ExecutingSqlQuery(_cosmosContainer, _cosmosPartitionKey, sqlQuery); _cosmosQueryContext.InitializeStateManager(_standAloneStateManager); @@ -165,7 +167,7 @@ public async ValueTask MoveNextAsync() { queryRequestOptions.MaxItemCount = maxItemCount; using var feedIterator = cosmosClient.CreateQuery( - _cosmosContainer, sqlQuery, continuationToken, queryRequestOptions); + _cosmosContainer, sqlQuery, _cosmosQueryContext.SessionTokenStorage, continuationToken, queryRequestOptions); using var responseMessage = await feedIterator.ReadNextAsync(_cancellationToken).ConfigureAwait(false); diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs index 58c976b5ace..d2906182de9 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.QueryingEnumerable.cs @@ -149,7 +149,7 @@ public async ValueTask MoveNextAsync() EntityFrameworkMetricsData.ReportQueryExecuting(); _enumerator = _cosmosQueryContext.CosmosClient - .ExecuteSqlQueryAsync(_cosmosContainer, _cosmosPartitionKey, sqlQuery) + .ExecuteSqlQueryAsync(_cosmosContainer, _cosmosPartitionKey, sqlQuery, _cosmosQueryContext.SessionTokenStorage) .GetAsyncEnumerator(_cancellationToken); _cosmosQueryContext.InitializeStateManager(_standAloneStateManager); } diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs index eae16897ff1..c1461cfa6fe 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.ReadItemQueryingEnumerable.cs @@ -159,6 +159,7 @@ public async ValueTask MoveNextAsync() _cosmosContainer, _cosmosPartitionKey, resourceId, + _cosmosQueryContext.SessionTokenStorage, _cancellationToken) .ConfigureAwait(false); diff --git a/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs b/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs index ed2ab2804bb..62d48e027d7 100644 --- a/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs +++ b/src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs @@ -335,26 +335,28 @@ public virtual Task CreateItemAsync( string containerId, JToken document, IUpdateEntry updateEntry, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default) - => _executionStrategy.ExecuteAsync((containerId, document, updateEntry, this), CreateItemOnceAsync, null, cancellationToken); + => _executionStrategy.ExecuteAsync((containerId, document, updateEntry, sessionTokenStorage, this), CreateItemOnceAsync, null, cancellationToken); private static async Task CreateItemOnceAsync( DbContext _, - (string ContainerId, JToken Document, IUpdateEntry Entry, CosmosClientWrapper Wrapper) parameters, + (string ContainerId, JToken Document, IUpdateEntry Entry, ISessionTokenStorage SessionTokenStorage, CosmosClientWrapper Wrapper) parameters, CancellationToken cancellationToken = default) { using var stream = Serialize(parameters.Document); + var containerId = parameters.ContainerId; var entry = parameters.Entry; var wrapper = parameters.Wrapper; + var sessionTokenStorage = parameters.SessionTokenStorage; var container = wrapper.Client.GetDatabase(wrapper._databaseId).GetContainer(parameters.ContainerId); - var itemRequestOptions = CreateItemRequestOptions(entry, wrapper._enableContentResponseOnWrite); + var itemRequestOptions = CreateItemRequestOptions(entry, wrapper._enableContentResponseOnWrite, sessionTokenStorage.GetSessionToken(containerId)); var partitionKeyValue = ExtractPartitionKeyValue(entry); var preTriggers = GetTriggers(entry, TriggerType.Pre, TriggerOperation.Create); var postTriggers = GetTriggers(entry, TriggerType.Post, TriggerOperation.Create); if (preTriggers != null || postTriggers != null) { - itemRequestOptions ??= new ItemRequestOptions(); if (preTriggers != null) { itemRequestOptions.PreTriggers = preTriggers; @@ -378,10 +380,10 @@ private static async Task CreateItemOnceAsync( response.Headers.RequestCharge, response.Headers.ActivityId, parameters.Document["id"]!.ToString(), - parameters.ContainerId, + containerId, partitionKeyValue); - ProcessResponse(response, entry); + ProcessResponse(containerId, response, entry, sessionTokenStorage); return response.StatusCode == HttpStatusCode.Created; } @@ -397,27 +399,29 @@ public virtual Task ReplaceItemAsync( string documentId, JObject document, IUpdateEntry updateEntry, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default) => _executionStrategy.ExecuteAsync( - (collectionId, documentId, document, updateEntry, this), ReplaceItemOnceAsync, null, cancellationToken); + (collectionId, documentId, document, updateEntry, sessionTokenStorage, this), ReplaceItemOnceAsync, null, cancellationToken); private static async Task ReplaceItemOnceAsync( DbContext _, - (string ContainerId, string ResourceId, JObject Document, IUpdateEntry Entry, CosmosClientWrapper Wrapper) parameters, + (string ContainerId, string ResourceId, JObject Document, IUpdateEntry Entry, ISessionTokenStorage SessionTokenStorage, CosmosClientWrapper Wrapper) parameters, CancellationToken cancellationToken = default) { using var stream = Serialize(parameters.Document); + var containerId = parameters.ContainerId; var entry = parameters.Entry; var wrapper = parameters.Wrapper; + var sessionTokenStorage = parameters.SessionTokenStorage; var container = wrapper.Client.GetDatabase(wrapper._databaseId).GetContainer(parameters.ContainerId); - var itemRequestOptions = CreateItemRequestOptions(entry, wrapper._enableContentResponseOnWrite); + var itemRequestOptions = CreateItemRequestOptions(entry, wrapper._enableContentResponseOnWrite, sessionTokenStorage.GetSessionToken(containerId)); var partitionKeyValue = ExtractPartitionKeyValue(entry); var preTriggers = GetTriggers(entry, TriggerType.Pre, TriggerOperation.Replace); var postTriggers = GetTriggers(entry, TriggerType.Post, TriggerOperation.Replace); if (preTriggers != null || postTriggers != null) { - itemRequestOptions ??= new ItemRequestOptions(); if (preTriggers != null) { itemRequestOptions.PreTriggers = preTriggers; @@ -442,10 +446,10 @@ private static async Task ReplaceItemOnceAsync( response.Headers.RequestCharge, response.Headers.ActivityId, parameters.ResourceId, - parameters.ContainerId, + containerId, partitionKeyValue); - ProcessResponse(response, entry); + ProcessResponse(containerId, response, entry, sessionTokenStorage); return response.StatusCode == HttpStatusCode.OK; } @@ -460,8 +464,57 @@ public virtual Task DeleteItemAsync( string containerId, string documentId, IUpdateEntry entry, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default) - => _executionStrategy.ExecuteAsync((containerId, documentId, entry, this), DeleteItemOnceAsync, null, cancellationToken); + => _executionStrategy.ExecuteAsync((containerId, documentId, entry, sessionTokenStorage, this), DeleteItemOnceAsync, null, cancellationToken); + + private static async Task DeleteItemOnceAsync( + DbContext? _, + (string ContainerId, string ResourceId, IUpdateEntry Entry, ISessionTokenStorage SessionTokenStorage, CosmosClientWrapper Wrapper) parameters, + CancellationToken cancellationToken = default) + { + var containerId = parameters.ContainerId; + var entry = parameters.Entry; + var wrapper = parameters.Wrapper; + var sessionTokenStorage = parameters.SessionTokenStorage; + var items = wrapper.Client.GetDatabase(wrapper._databaseId).GetContainer(parameters.ContainerId); + + var itemRequestOptions = CreateItemRequestOptions(entry, wrapper._enableContentResponseOnWrite, sessionTokenStorage.GetSessionToken(containerId)); + var partitionKeyValue = ExtractPartitionKeyValue(entry); + var preTriggers = GetTriggers(entry, TriggerType.Pre, TriggerOperation.Delete); + var postTriggers = GetTriggers(entry, TriggerType.Post, TriggerOperation.Delete); + if (preTriggers != null || postTriggers != null) + { + if (preTriggers != null) + { + itemRequestOptions.PreTriggers = preTriggers; + } + + if (postTriggers != null) + { + itemRequestOptions.PostTriggers = postTriggers; + } + } + + using var response = await items.DeleteItemStreamAsync( + parameters.ResourceId, + partitionKeyValue, + itemRequestOptions, + cancellationToken: cancellationToken) + .ConfigureAwait(false); + + wrapper._commandLogger.ExecutedDeleteItem( + response.Diagnostics.GetClientElapsedTime(), + response.Headers.RequestCharge, + response.Headers.ActivityId, + parameters.ResourceId, + containerId, + partitionKeyValue); + + ProcessResponse(containerId, response, entry, sessionTokenStorage); + + return response.StatusCode == HttpStatusCode.NoContent; + } /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -481,6 +534,7 @@ public virtual PartitionKey GetPartitionKeyValue(IUpdateEntry updateEntry) public virtual ICosmosTransactionalBatchWrapper CreateTransactionalBatch(string containerId, PartitionKey partitionKeyValue, bool checkSize) { var container = Client.GetDatabase(_databaseId).GetContainer(containerId); + var batch = container.CreateTransactionalBatch(partitionKeyValue); return new CosmosTransactionalBatchWrapper(batch, containerId, partitionKeyValue, checkSize, _enableContentResponseOnWrite); @@ -492,18 +546,32 @@ public virtual ICosmosTransactionalBatchWrapper CreateTransactionalBatch(string /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - public virtual Task ExecuteTransactionalBatchAsync(ICosmosTransactionalBatchWrapper batch, CancellationToken cancellationToken = default) - => _executionStrategy.ExecuteAsync((batch, this), ExecuteTransactionalBatchOnceAsync, null, cancellationToken); + public virtual Task ExecuteTransactionalBatchAsync(ICosmosTransactionalBatchWrapper batch, ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default) + => _executionStrategy.ExecuteAsync((batch, sessionTokenStorage, this), ExecuteTransactionalBatchOnceAsync, null, cancellationToken); private static async Task ExecuteTransactionalBatchOnceAsync(DbContext _, - (ICosmosTransactionalBatchWrapper Batch, CosmosClientWrapper Wrapper) parameters, + (ICosmosTransactionalBatchWrapper Batch, ISessionTokenStorage SessionTokenStorage, CosmosClientWrapper Wrapper) parameters, CancellationToken cancellationToken = default) { var batch = parameters.Batch; var transactionalBatch = batch.GetTransactionalBatch(); var wrapper = parameters.Wrapper; + var sessionTokenStorage = parameters.SessionTokenStorage; + + var options = new TransactionalBatchRequestOptions + { + SessionToken = sessionTokenStorage.GetSessionToken(batch.CollectionId) + }; + + using var response = await transactionalBatch.ExecuteAsync(options, cancellationToken).ConfigureAwait(false); - using var response = await transactionalBatch.ExecuteAsync(cancellationToken).ConfigureAwait(false); + wrapper._commandLogger.ExecutedTransactionalBatch( + response.Diagnostics.GetClientElapsedTime(), + response.Headers.RequestCharge, + response.Headers.ActivityId, + batch.CollectionId, + batch.PartitionKeyValue, + "[ \"" + string.Join("\", \"", batch.Entries.Select(x => x.Id)) + "\" ]"); if (!response.IsSuccessStatusCode) { @@ -518,73 +586,27 @@ private static async Task ExecuteTransactionalBa return new CosmosTransactionalBatchResult(errorEntries, exception); } - wrapper._commandLogger.ExecutedTransactionalBatch( - response.Diagnostics.GetClientElapsedTime(), - response.Headers.RequestCharge, - response.Headers.ActivityId, - batch.CollectionId, - batch.PartitionKeyValue, - "[ \"" + string.Join("\", \"", batch.Entries.Select(x => x.Id)) + "\" ]"); - - ProcessResponse(response, batch.Entries); + ProcessResponse(batch.CollectionId, response, batch.Entries, sessionTokenStorage); return CosmosTransactionalBatchResult.Success; } - private static async Task DeleteItemOnceAsync( - DbContext? _, - (string ContainerId, string ResourceId, IUpdateEntry Entry, CosmosClientWrapper Wrapper) parameters, - CancellationToken cancellationToken = default) + private static ItemRequestOptions CreateItemRequestOptions(IUpdateEntry entry, bool? enableContentResponseOnWrite, string? sessionToken) { - var entry = parameters.Entry; - var wrapper = parameters.Wrapper; - var items = wrapper.Client.GetDatabase(wrapper._databaseId).GetContainer(parameters.ContainerId); + var helper = RequestOptionsHelper.Create(entry, enableContentResponseOnWrite); - var itemRequestOptions = CreateItemRequestOptions(entry, wrapper._enableContentResponseOnWrite); - var partitionKeyValue = ExtractPartitionKeyValue(entry); - var preTriggers = GetTriggers(entry, TriggerType.Pre, TriggerOperation.Delete); - var postTriggers = GetTriggers(entry, TriggerType.Post, TriggerOperation.Delete); - if (preTriggers != null || postTriggers != null) + var itemRequestOptions = new ItemRequestOptions { - itemRequestOptions ??= new ItemRequestOptions(); - if (preTriggers != null) - { - itemRequestOptions.PreTriggers = preTriggers; - } + SessionToken = sessionToken + }; - if (postTriggers != null) - { - itemRequestOptions.PostTriggers = postTriggers; - } + if (helper != null) + { + itemRequestOptions.IfMatchEtag = helper.IfMatchEtag; + itemRequestOptions.EnableContentResponseOnWrite = helper.EnableContentResponseOnWrite; } - using var response = await items.DeleteItemStreamAsync( - parameters.ResourceId, - partitionKeyValue, - itemRequestOptions, - cancellationToken: cancellationToken) - .ConfigureAwait(false); - - wrapper._commandLogger.ExecutedDeleteItem( - response.Diagnostics.GetClientElapsedTime(), - response.Headers.RequestCharge, - response.Headers.ActivityId, - parameters.ResourceId, - parameters.ContainerId, - partitionKeyValue); - - ProcessResponse(response, entry); - - return response.StatusCode == HttpStatusCode.NoContent; - } - - private static ItemRequestOptions? CreateItemRequestOptions(IUpdateEntry entry, bool? enableContentResponseOnWrite) - { - var helper = RequestOptionsHelper.Create(entry, enableContentResponseOnWrite); - - return helper == null - ? null - : new ItemRequestOptions { IfMatchEtag = helper.IfMatchEtag, EnableContentResponseOnWrite = helper.EnableContentResponseOnWrite }; + return itemRequestOptions; } private static IReadOnlyList? GetTriggers(IUpdateEntry entry, TriggerType type, TriggerOperation operation) @@ -620,14 +642,25 @@ private static PartitionKey ExtractPartitionKeyValue(IUpdateEntry entry) return builder.Build(); } - private static void ProcessResponse(ResponseMessage response, IUpdateEntry entry) + private static void ProcessResponse(string containerId, ResponseMessage response, IUpdateEntry entry, ISessionTokenStorage sessionTokenStorage) { response.EnsureSuccessStatusCode(); + + if (!string.IsNullOrWhiteSpace(response.Headers.Session)) + { + sessionTokenStorage.TrackSessionToken(containerId, response.Headers.Session); + } + ProcessResponse(entry, response.Headers.ETag, response.Content); } - private static void ProcessResponse(TransactionalBatchResponse batchResponse, IReadOnlyList entries) + private static void ProcessResponse(string containerId, TransactionalBatchResponse batchResponse, IReadOnlyList entries, ISessionTokenStorage sessionTokenStorage) { + if (!string.IsNullOrWhiteSpace(batchResponse.Headers.Session)) + { + sessionTokenStorage.TrackSessionToken(containerId, batchResponse.Headers.Session); + } + for (var i = 0; i < batchResponse.Count; i++) { var entry = entries[i]; @@ -668,11 +701,12 @@ private static void ProcessResponse(IUpdateEntry entry, string eTag, Stream? con public virtual IAsyncEnumerable ExecuteSqlQueryAsync( string containerId, PartitionKey partitionKeyValue, - CosmosSqlQuery query) + CosmosSqlQuery query, + ISessionTokenStorage sessionTokenStorage) { _commandLogger.ExecutingSqlQuery(containerId, partitionKeyValue, query); - return new DocumentAsyncEnumerable(this, containerId, partitionKeyValue, query); + return new DocumentAsyncEnumerable(this, containerId, partitionKeyValue, query, sessionTokenStorage); } /// @@ -685,12 +719,13 @@ public virtual IAsyncEnumerable ExecuteSqlQueryAsync( string containerId, PartitionKey partitionKeyValue, string resourceId, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default) { _commandLogger.ExecutingReadItem(containerId, partitionKeyValue, resourceId); var response = await _executionStrategy.ExecuteAsync( - (containerId, partitionKeyValue, resourceId, this), + (containerId, partitionKeyValue, resourceId, sessionTokenStorage, this), CreateSingleItemQueryAsync, null, cancellationToken) @@ -707,29 +742,43 @@ public virtual IAsyncEnumerable ExecuteSqlQueryAsync( return JObjectFromReadItemResponseMessage(response); } - private static Task CreateSingleItemQueryAsync( + private static async Task CreateSingleItemQueryAsync( DbContext? _, - (string ContainerId, PartitionKey PartitionKeyValue, string ResourceId, CosmosClientWrapper Wrapper) parameters, + (string ContainerId, PartitionKey PartitionKeyValue, string ResourceId, ISessionTokenStorage SessionTokenStorage, CosmosClientWrapper Wrapper) parameters, CancellationToken cancellationToken = default) { - var (containerId, partitionKeyValue, resourceId, wrapper) = parameters; + var (containerId, partitionKeyValue, resourceId, sessionTokenStorage, wrapper) = parameters; var container = wrapper.Client.GetDatabase(wrapper._databaseId).GetContainer(containerId); - return container.ReadItemStreamAsync( + var itemRequestOptions = new ItemRequestOptions { SessionToken = sessionTokenStorage.GetSessionToken(containerId) }; + + var response = await container.ReadItemStreamAsync( resourceId, partitionKeyValue, - cancellationToken: cancellationToken); + itemRequestOptions, + cancellationToken: cancellationToken).ConfigureAwait(false); + + if (!string.IsNullOrWhiteSpace(response.Headers.Session)) + { + sessionTokenStorage.TrackSessionToken(containerId, response.Headers.Session); + } + + return response; } private static JObject? JObjectFromReadItemResponseMessage(ResponseMessage responseMessage) { - if (responseMessage.StatusCode == HttpStatusCode.NotFound) + const int resourceNotFoundSubStatusCode = 0; + + try + { + responseMessage.EnsureSuccessStatusCode(); + } + catch (CosmosException ex) when (ex.StatusCode == HttpStatusCode.NotFound && ex.SubStatusCode == resourceNotFoundSubStatusCode) { return null; } - responseMessage.EnsureSuccessStatusCode(); - var responseStream = responseMessage.Content; using var reader = new StreamReader(responseStream); using var jsonReader = new JsonTextReader(reader); @@ -746,6 +795,7 @@ private static Task CreateSingleItemQueryAsync( public virtual FeedIterator CreateQuery( string containerId, CosmosSqlQuery query, + ISessionTokenStorage sessionTokenStorage, string? continuationToken = null, QueryRequestOptions? queryRequestOptions = null) { @@ -757,7 +807,7 @@ public virtual FeedIterator CreateQuery( queryDefinition, (current, parameter) => current.WithParameter(parameter.Name, parameter.Value)); - return container.GetItemQueryStreamIterator(queryDefinition, continuationToken, queryRequestOptions); + return new CosmosFeedIteratorWrapper(container.GetItemQueryStreamIterator(queryDefinition, continuationToken, queryRequestOptions), containerId, sessionTokenStorage); } [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -787,13 +837,15 @@ private sealed class DocumentAsyncEnumerable( CosmosClientWrapper cosmosClient, string containerId, PartitionKey partitionKeyValue, - CosmosSqlQuery cosmosSqlQuery) + CosmosSqlQuery cosmosSqlQuery, + ISessionTokenStorage sessionTokenStorage) : IAsyncEnumerable { private readonly CosmosClientWrapper _cosmosClient = cosmosClient; private readonly string _containerId = containerId; private readonly PartitionKey _partitionKeyValue = partitionKeyValue; private readonly CosmosSqlQuery _cosmosSqlQuery = cosmosSqlQuery; + private readonly ISessionTokenStorage _sessionTokenStorage = sessionTokenStorage; public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) => new AsyncEnumerator(this, cancellationToken); @@ -805,6 +857,7 @@ private sealed class AsyncEnumerator(DocumentAsyncEnumerable documentEnumerable, private readonly string _containerId = documentEnumerable._containerId; private readonly PartitionKey _partitionKeyValue = documentEnumerable._partitionKeyValue; private readonly CosmosSqlQuery _cosmosSqlQuery = documentEnumerable._cosmosSqlQuery; + private readonly ISessionTokenStorage _sessionTokenStorage = documentEnumerable._sessionTokenStorage; private JToken? _current; private ResponseMessage? _responseMessage; @@ -830,8 +883,10 @@ public async ValueTask MoveNextAsync() queryRequestOptions.PartitionKey = _partitionKeyValue; } + queryRequestOptions.SessionToken = _sessionTokenStorage.GetSessionToken(_containerId); + _query = _cosmosClientWrapper.CreateQuery( - _containerId, _cosmosSqlQuery, continuationToken: null, queryRequestOptions); + _containerId, _cosmosSqlQuery, _sessionTokenStorage, continuationToken: null, queryRequestOptions); } if (!_query.HasMoreResults) @@ -991,4 +1046,31 @@ public async ValueTask DisposeAsync() } #endregion ResponseMessageEnumerable + + private sealed class CosmosFeedIteratorWrapper : FeedIterator + { + private readonly FeedIterator _inner; + private readonly string _containerName; + private readonly ISessionTokenStorage _sessionTokenStorage; + + public CosmosFeedIteratorWrapper(FeedIterator inner, string containerName, ISessionTokenStorage sessionTokenStorage) + { + _inner = inner; + _containerName = containerName; + _sessionTokenStorage = sessionTokenStorage; + } + + public override bool HasMoreResults => _inner.HasMoreResults; + + public override async Task ReadNextAsync(CancellationToken cancellationToken = default) + { + var response = await _inner.ReadNextAsync(cancellationToken).ConfigureAwait(false); + if (!string.IsNullOrWhiteSpace(response.Headers.Session)) + { + _sessionTokenStorage.TrackSessionToken(_containerName, response.Headers.Session); + } + return response; + } + } + } diff --git a/src/EFCore.Cosmos/Storage/Internal/CosmosDatabaseWrapper.cs b/src/EFCore.Cosmos/Storage/Internal/CosmosDatabaseWrapper.cs index 6fd56faea75..11a74988305 100644 --- a/src/EFCore.Cosmos/Storage/Internal/CosmosDatabaseWrapper.cs +++ b/src/EFCore.Cosmos/Storage/Internal/CosmosDatabaseWrapper.cs @@ -19,7 +19,7 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// -public class CosmosDatabaseWrapper : Database +public class CosmosDatabaseWrapper : Database, IResettableService { private readonly Dictionary _documentCollections = new(); @@ -38,11 +38,13 @@ public CosmosDatabaseWrapper( DatabaseDependencies dependencies, ICurrentDbContext currentDbContext, ICosmosClientWrapper cosmosClient, + ISessionTokenStorageFactory sessionTokenStorageFactory, ILoggingOptions loggingOptions) : base(dependencies) { _currentDbContext = currentDbContext; _cosmosClient = cosmosClient; + SessionTokenStorage = sessionTokenStorageFactory.Create(currentDbContext.Context); if (loggingOptions.IsSensitiveDataLoggingEnabled) { @@ -50,6 +52,14 @@ public CosmosDatabaseWrapper( } } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual ISessionTokenStorage SessionTokenStorage { get; } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -92,7 +102,7 @@ public override async Task SaveChangesAsync( { try { - var response = await _cosmosClient.ExecuteTransactionalBatchAsync(transaction, cancellationToken).ConfigureAwait(false); + var response = await _cosmosClient.ExecuteTransactionalBatchAsync(transaction, SessionTokenStorage, cancellationToken).ConfigureAwait(false); if (!response.IsSuccess) { var exception = WrapUpdateException(response.Exception, response.ErroredEntries); @@ -456,17 +466,20 @@ private async Task SaveAsync(CosmosUpdateEntry updateEntry, CancellationTo updateEntry.CollectionId, updateEntry.Document!, updateEntry.Entry, + SessionTokenStorage, cancellationToken).ConfigureAwait(false), CosmosCudOperation.Update => await _cosmosClient.ReplaceItemAsync( updateEntry.CollectionId, updateEntry.DocumentSource.GetId(updateEntry.Entry.SharedIdentityEntry ?? updateEntry.Entry), updateEntry.Document!, updateEntry.Entry, + SessionTokenStorage, cancellationToken).ConfigureAwait(false), CosmosCudOperation.Delete => await _cosmosClient.DeleteItemAsync( updateEntry.CollectionId, updateEntry.DocumentSource.GetId(updateEntry.Entry), updateEntry.Entry, + SessionTokenStorage, cancellationToken).ConfigureAwait(false), _ => throw new UnreachableException(), }; @@ -549,6 +562,17 @@ private DbUpdateException WrapUpdateException(Exception exception, IReadOnlyList }; } + void IResettableService.ResetState() + { + SessionTokenStorage.Clear(); + } + + Task IResettableService.ResetStateAsync(CancellationToken cancellationToken) + { + ((IResettableService)this).ResetState(); + return Task.CompletedTask; + } + private sealed class SaveGroups { public required IEnumerable SingleUpdateEntries { get; init; } diff --git a/src/EFCore.Cosmos/Storage/Internal/ICosmosClientWrapper.cs b/src/EFCore.Cosmos/Storage/Internal/ICosmosClientWrapper.cs index 2775250ef76..7c401a3c977 100644 --- a/src/EFCore.Cosmos/Storage/Internal/ICosmosClientWrapper.cs +++ b/src/EFCore.Cosmos/Storage/Internal/ICosmosClientWrapper.cs @@ -47,6 +47,7 @@ Task CreateItemAsync( string containerId, JToken document, IUpdateEntry updateEntry, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default); /// @@ -60,6 +61,7 @@ Task ReplaceItemAsync( string documentId, JObject document, IUpdateEntry updateEntry, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default); /// @@ -72,6 +74,7 @@ Task DeleteItemAsync( string containerId, string documentId, IUpdateEntry entry, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default); /// @@ -83,6 +86,7 @@ Task DeleteItemAsync( FeedIterator CreateQuery( string containerId, CosmosSqlQuery query, + ISessionTokenStorage sessionTokenStorage, string? continuationToken = null, QueryRequestOptions? queryRequestOptions = null); @@ -96,6 +100,7 @@ FeedIterator CreateQuery( string containerId, PartitionKey partitionKeyValue, string resourceId, + ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default); /// @@ -107,7 +112,8 @@ FeedIterator CreateQuery( IAsyncEnumerable ExecuteSqlQueryAsync( string containerId, PartitionKey partitionKeyValue, - CosmosSqlQuery query); + CosmosSqlQuery query, + ISessionTokenStorage sessionTokenStorage); /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to @@ -139,5 +145,5 @@ IAsyncEnumerable ExecuteSqlQueryAsync( /// any release. You should only use it directly in your code with extreme caution and knowing that /// doing so can result in application failures when updating to a new Entity Framework Core release. /// - Task ExecuteTransactionalBatchAsync(ICosmosTransactionalBatchWrapper batch, CancellationToken cancellationToken = default); + Task ExecuteTransactionalBatchAsync(ICosmosTransactionalBatchWrapper batch, ISessionTokenStorage sessionTokenStorage, CancellationToken cancellationToken = default); } diff --git a/src/EFCore.Cosmos/Storage/Internal/ISessionTokenStorage.cs b/src/EFCore.Cosmos/Storage/Internal/ISessionTokenStorage.cs new file mode 100644 index 00000000000..d0c160b6e61 --- /dev/null +++ b/src/EFCore.Cosmos/Storage/Internal/ISessionTokenStorage.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + + +namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public interface ISessionTokenStorage +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public void AppendDefaultContainerSessionToken(string sessionToken); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public void AppendSessionTokens(IReadOnlyDictionary sessionTokens); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public void Clear(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public string? GetDefaultContainerTrackedToken(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public string? GetSessionToken(string containerName); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public IReadOnlyDictionary GetTrackedTokens(); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public void TrackSessionToken(string containerName, string sessionToken); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public void SetDefaultContainerSessionToken(string sessionToken); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public void SetSessionTokens(IReadOnlyDictionary sessionTokens); +} diff --git a/src/EFCore.Cosmos/Storage/Internal/ISessionTokenStorageFactory.cs b/src/EFCore.Cosmos/Storage/Internal/ISessionTokenStorageFactory.cs new file mode 100644 index 00000000000..91abc884a01 --- /dev/null +++ b/src/EFCore.Cosmos/Storage/Internal/ISessionTokenStorageFactory.cs @@ -0,0 +1,21 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public interface ISessionTokenStorageFactory +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public ISessionTokenStorage Create(DbContext dbContext); +} diff --git a/src/EFCore.Cosmos/Storage/Internal/SessionTokenStorage.cs b/src/EFCore.Cosmos/Storage/Internal/SessionTokenStorage.cs new file mode 100644 index 00000000000..358f62525e3 --- /dev/null +++ b/src/EFCore.Cosmos/Storage/Internal/SessionTokenStorage.cs @@ -0,0 +1,254 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Cosmos.Infrastructure; +using Microsoft.EntityFrameworkCore.Cosmos.Internal; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SessionTokenStorage : ISessionTokenStorage +{ + private readonly Dictionary _containerSessionTokens; + private readonly string _defaultContainerName; + private readonly HashSet _containerNames; + private readonly SessionTokenManagementMode _mode; + private readonly string? _defaultToken; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SessionTokenStorage(string defaultContainerName, HashSet containerNames, SessionTokenManagementMode mode) + { + Debug.Assert(containerNames.Contains(defaultContainerName)); + _defaultContainerName = defaultContainerName; + _containerNames = containerNames; + _mode = mode; + _defaultToken = _mode == SessionTokenManagementMode.Manual || _mode == SessionTokenManagementMode.EnforcedManual ? "" : null; + + _containerSessionTokens = containerNames.ToDictionary(x => x, x => new CompositeSessionToken(_defaultToken)); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void SetSessionTokens(IReadOnlyDictionary sessionTokens) + { + CheckMode(); + foreach (var sessionToken in sessionTokens) + { + if (!_containerNames.Contains(sessionToken.Key)) + { + throw new InvalidOperationException(CosmosStrings.ContainerNameDoesNotExist(sessionToken.Key)); + } + + _containerSessionTokens[sessionToken.Key] = new CompositeSessionToken(sessionToken.Value, true); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void AppendSessionTokens(IReadOnlyDictionary sessionTokens) + { + CheckMode(); + foreach (var sessionToken in sessionTokens) + { + if (!_containerNames.Contains(sessionToken.Key)) + { + throw new InvalidOperationException(CosmosStrings.ContainerNameDoesNotExist("bad")); + } + + _containerSessionTokens[sessionToken.Key].Add(sessionToken.Value, true); + } + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void AppendDefaultContainerSessionToken(string sessionToken) + { + ArgumentException.ThrowIfNullOrWhiteSpace(sessionToken, nameof(sessionToken)); + CheckMode(); + _containerSessionTokens[_defaultContainerName].Add(sessionToken, true); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void SetDefaultContainerSessionToken(string? sessionToken) + { + CheckMode(); + _containerSessionTokens[_defaultContainerName] = new CompositeSessionToken(sessionToken, true); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual IReadOnlyDictionary GetTrackedTokens() + { + CheckMode(); + return _containerSessionTokens.ToDictionary(x => x.Key, x => x.Value.ConvertToString()); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string? GetDefaultContainerTrackedToken() + { + CheckMode(); + return _containerSessionTokens[_defaultContainerName].ConvertToString(); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string? GetSessionToken(string containerName) + { + ArgumentNullException.ThrowIfNullOrWhiteSpace(containerName, nameof(containerName)); + + if (_mode == SessionTokenManagementMode.FullyAutomatic) + { + return null; + } + + var sessionToken = _containerSessionTokens[containerName]; + + if (!sessionToken.IsSet) + { + if (_mode == SessionTokenManagementMode.EnforcedManual) + { + throw new InvalidOperationException(CosmosStrings.MissingSessionTokenEnforceManual(containerName)); + } + + if (_mode == SessionTokenManagementMode.SemiAutomatic) + { + return null; + } + } + + return sessionToken.ConvertToString(); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void TrackSessionToken(string containerName, string sessionToken) + { + ArgumentNullException.ThrowIfNullOrWhiteSpace(containerName, nameof(containerName)); + ArgumentNullException.ThrowIfNullOrWhiteSpace(sessionToken, nameof(sessionToken)); + + if (_mode == SessionTokenManagementMode.FullyAutomatic) + { + return; + } + + _containerSessionTokens[containerName].Add(sessionToken); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual void Clear() + { + foreach (var key in _containerSessionTokens.Keys) + { + _containerSessionTokens[key] = new CompositeSessionToken(_defaultToken); + } + } + + private void CheckMode() + { + if (_mode == SessionTokenManagementMode.FullyAutomatic) + { + throw new InvalidOperationException(CosmosStrings.EnableManualSessionTokenManagement); + } + } + + private sealed class CompositeSessionToken + { + private string? _string; + private bool _isChanged; + private readonly HashSet _tokens = new(); + + public CompositeSessionToken(string? token, bool isSet = false) + { + if (token != null) + { + Add(token); + } + IsSet = isSet; + } + + public bool IsSet { get; private set; } + + public void Add(string token, bool isSet = false) + { + IsSet = IsSet || isSet; + + if (token == string.Empty && _tokens.Count == 0) + { + _string = ""; + } + + foreach (var tokenPart in token.Split(',')) + { + if (string.IsNullOrEmpty(tokenPart)) + { + continue; + } + + if (_tokens.Add(tokenPart)) + { + _isChanged = true; + } + } + } + + public string? ConvertToString() + { + if (_isChanged) + { + _isChanged = false; + _string = string.Join(",", _tokens); + } + + return _string; + } + } +} diff --git a/src/EFCore.Cosmos/Storage/Internal/SessionTokenStorageFactory.cs b/src/EFCore.Cosmos/Storage/Internal/SessionTokenStorageFactory.cs new file mode 100644 index 00000000000..fd0aa7e5c24 --- /dev/null +++ b/src/EFCore.Cosmos/Storage/Internal/SessionTokenStorageFactory.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.EntityFrameworkCore.Cosmos.Infrastructure; +using Microsoft.EntityFrameworkCore.Cosmos.Infrastructure.Internal; +using Microsoft.EntityFrameworkCore.Cosmos.Metadata.Internal; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; + +/// +/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to +/// the same compatibility standards as public APIs. It may be changed or removed without notice in +/// any release. You should only use it directly in your code with extreme caution and knowing that +/// doing so can result in application failures when updating to a new Entity Framework Core release. +/// +public class SessionTokenStorageFactory : ISessionTokenStorageFactory +{ + private string? _defaultContainerName; + private HashSet? _containerNames; + private readonly SessionTokenManagementMode _mode; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SessionTokenStorageFactory(ICosmosSingletonOptions options) + { + _mode = options.SessionTokenManagementMode; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public ISessionTokenStorage Create(DbContext dbContext) + => new SessionTokenStorage( + _defaultContainerName ??= (string)dbContext.Model.GetAnnotation(CosmosAnnotationNames.ContainerName).Value!, + _containerNames ??= new HashSet([_defaultContainerName, ..GetContainerNames(dbContext.Model)]), + _mode); + + + private static IEnumerable GetContainerNames(IModel model) + => model.GetEntityTypes() + .Where(et => et.FindPrimaryKey() != null) + .Select(et => et.GetContainer()) + .Where(container => container != null)!; +} diff --git a/test/EFCore.Cosmos.FunctionalTests/CosmosSessionTokensTest.cs b/test/EFCore.Cosmos.FunctionalTests/CosmosSessionTokensTest.cs new file mode 100644 index 00000000000..6bf0f9a7dea --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/CosmosSessionTokensTest.cs @@ -0,0 +1,767 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Azure.Cosmos; +using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.EntityFrameworkCore; + +public class CosmosSessionTokensTest(NonSharedFixture fixture) : NonSharedModelTestBase(fixture), IClassFixture +{ + protected const string OtherContainerName = "Other"; + + protected override ITestStoreFactory TestStoreFactory + => CosmosTestStoreFactory.Instance; + + protected override string StoreName => nameof(CosmosSessionTokensTest); + + private bool _mock = true; + protected override IServiceCollection AddServices(IServiceCollection serviceCollection) + { + var services = base.AddServices(serviceCollection); + + return _mock + ? services.Replace(ServiceDescriptor.Singleton()) + : services; + } + + protected override TestStore CreateTestStore() => CosmosTestStore.Create(StoreName, (cfg) => cfg.SessionTokenManagementMode(Cosmos.Infrastructure.SessionTokenManagementMode.SemiAutomatic)); + + private static TestSessionTokenStorage _sessionTokenStorage = null!; + + [ConditionalFact] + public virtual async Task Can_use_session_tokens_no_mock() + { + _mock = false; + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + context.Customers.Add(new Customer { Id = "1", PartitionKey = "1" }); + context.OtherContainerCustomers.Add(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + + await context.SaveChangesAsync(); + + var sessionTokens = context.Database.GetSessionTokens(); + + Assert.NotNull(sessionTokens[nameof(CosmosSessionTokenContext)]); + Assert.NotNull(sessionTokens[OtherContainerName]); + + // Only way we can test this is by setting a session token that will fail the request if used.. + // This will take a couple of seconds to fail + var newTokens = sessionTokens.ToDictionary(x => x.Key, x => x.Value!.Substring(0, x.Value.IndexOf('#') + 1) + int.MaxValue); + context.Database.UseSessionTokens(newTokens!); + + var exes = new List() + { + await Assert.ThrowsAsync(() => context.Customers.ToListAsync()), + await Assert.ThrowsAsync(() => context.OtherContainerCustomers.ToListAsync()) + }; + + foreach (var ex in exes) + { + Assert.Contains("The read session is not available for the input session token.", ex.ResponseBody); + } + } + + [ConditionalFact] + public virtual async Task AppendSessionToken_uses_AppendDefaultContainerSessionToken() + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + var arg = "0:-1#231"; + context.Database.AppendSessionToken(arg); + Assert.Equal(arg, _sessionTokenStorage.AppendDefaultContainerSessionTokenCalls.Single()); + } + + [ConditionalFact] + public virtual async Task AppendSessionTokens_uses_AppendSessionTokens() + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + + var arg = new Dictionary { { OtherContainerName, "0:-1#123" }, { nameof(CosmosSessionTokenContext), "0:-1#231" } }; + context.Database.AppendSessionTokens(arg); + Assert.Equal(arg, _sessionTokenStorage.AppendSessionTokensCalls.Single()); + } + + [ConditionalFact] + public virtual async Task UseSessionToken_uses_SetDefaultContainerSessionToken() + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + var arg = "0:-1#231"; + context.Database.UseSessionToken(arg); + Assert.Equal(arg, _sessionTokenStorage.SetDefaultContainerSessionTokenCalls.Single()); + } + + [ConditionalFact] + public virtual async Task UseSessionTokens_uses_SetSessionTokens() + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + + var arg = new Dictionary { { OtherContainerName, "0:-1#123" }, { nameof(CosmosSessionTokenContext), "0:-1#231" } }; + context.Database.UseSessionTokens(arg); + Assert.Equal(arg, _sessionTokenStorage.SetSessionTokensCalls.Single()); + } + + [ConditionalFact] + public virtual async Task GetSessionTokens_uses_GetTrackedSessionTokens() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + _sessionTokenStorage.SessionTokens = new Dictionary { { OtherContainerName, "0:-1#123" }, { nameof(CosmosSessionTokenContext), "0:-1#231" } }; + var sessionTokens = context.Database.GetSessionTokens(); + Assert.Equal(_sessionTokenStorage.SessionTokens, sessionTokens); + } + + [ConditionalFact] + public virtual async Task Query_uses_session_token() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + _sessionTokenStorage.SessionTokens = new Dictionary { { OtherContainerName, "invalidtoken" }, { nameof(CosmosSessionTokenContext), "invalidtoken" } }; + + var exes = new List + { + await Assert.ThrowsAsync(() => context.Customers.ToListAsync()), + await Assert.ThrowsAsync(() => context.OtherContainerCustomers.ToListAsync()) + }; + + foreach (var ex in exes) + { + Assert.Contains("The session token provided 'invalidtoken' is invalid", ex.ResponseBody); + } + } + + [ConditionalFact] + public virtual async Task New_context_does_not_use_same_SessionTokenStorage() + { + _mock = false; + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + context.Database.UseSessionToken("A"); + + using var newContext = contextFactory.CreateContext(); + Assert.NotSame(context, newContext); + Assert.Null(newContext.Database.GetSessionToken()); + Assert.Equal("A", context.Database.GetSessionToken()); + Assert.NotSame(((CosmosDatabaseWrapper)context.GetService()).SessionTokenStorage, ((CosmosDatabaseWrapper)newContext.GetService()).SessionTokenStorage); + } + + [ConditionalFact] + public virtual async Task Pooled_context_uses_same_SessionTokenStorage() + { + _mock = false; + + var contextFactory = await InitializeAsync(); + DbContext contextCopy; + ISessionTokenStorage sessionTokenStorageCopy; + using (var context = contextFactory.CreateContext()) + { + contextCopy = context; + context.Database.UseSessionToken("A"); + sessionTokenStorageCopy = ((CosmosDatabaseWrapper)context.GetService()).SessionTokenStorage; + } + + using var newContext = contextFactory.CreateContext(); + + Assert.Same(newContext, contextCopy); + Assert.Same(sessionTokenStorageCopy, ((CosmosDatabaseWrapper)newContext.GetService()).SessionTokenStorage); + Assert.Null(newContext.Database.GetSessionToken()); + } + + [ConditionalFact] + public virtual async Task Pooled_context_clears_SessionTokenStorage() + { + var contextFactory = await InitializeAsync(); + DbContext contextCopy; + ISessionTokenStorage sessionTokenStorageCopy; + using (var context = contextFactory.CreateContext()) + { + contextCopy = context; + sessionTokenStorageCopy = ((CosmosDatabaseWrapper)context.GetService()).SessionTokenStorage; + _sessionTokenStorage.ClearCalled = false; + } + + using var newContext = contextFactory.CreateContext(); + + Assert.Same(newContext, contextCopy); + Assert.Same(sessionTokenStorageCopy, ((CosmosDatabaseWrapper)newContext.GetService()).SessionTokenStorage); + Assert.True(_sessionTokenStorage.ClearCalled); + } + + [ConditionalFact] + public virtual async Task PagingQuery_uses_session_token() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + _sessionTokenStorage.SessionTokens = new Dictionary { { OtherContainerName, "invalidtoken" }, { nameof(CosmosSessionTokenContext), "invalidtoken" } }; + + var exes = new List() + { + await Assert.ThrowsAsync(() => context.Customers.ToPageAsync(1, null)), + await Assert.ThrowsAsync(() => context.OtherContainerCustomers.ToPageAsync(1, null)), + }; + + foreach (var ex in exes) + { + Assert.Contains("The session token provided 'invalidtoken' is invalid", ex.ResponseBody); + } + } + + [ConditionalFact] + public virtual async Task Shaped_query_uses_session_token() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + _sessionTokenStorage.SessionTokens = new Dictionary { { OtherContainerName, "invalidtoken" }, { nameof(CosmosSessionTokenContext), "invalidtoken" } }; + + var exes = new List() + { + await Assert.ThrowsAsync(() => context.Customers.Select(x => new { x.Id, x.PartitionKey }).ToListAsync()), + await Assert.ThrowsAsync(() => context.OtherContainerCustomers.Select(x => new { x.Id, x.PartitionKey }).ToListAsync()) + }; + + foreach (var ex in exes) + { + Assert.Contains("The session token provided 'invalidtoken' is invalid", ex.ResponseBody); + } + } + + [ConditionalFact] + public virtual async Task Read_item_uses_session_token() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + _sessionTokenStorage.SessionTokens = new Dictionary { { OtherContainerName, "invalidtoken" }, { nameof(CosmosSessionTokenContext), "invalidtoken" } }; + + var exes = new List() + { + await Assert.ThrowsAsync(() => context.Customers.FirstOrDefaultAsync(x => x.Id == "1" && x.PartitionKey == "1")), + await Assert.ThrowsAsync(() => context.OtherContainerCustomers.FirstOrDefaultAsync(x => x.Id == "1" && x.PartitionKey == "1")) + }; + + foreach (var ex in exes) + { + Assert.Contains("The session token provided 'invalidtoken' is invalid", ex.ResponseBody); + } + } + + [ConditionalFact] + public virtual async Task Query_uses_TrackSessionToken() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + await context.Customers.ToListAsync(); + await context.OtherContainerCustomers.ToListAsync(); + + Assert.Equal(2, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var defaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.First(); + var otherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.Last(); + + Assert.Equal(nameof(CosmosSessionTokenContext), defaultContainerCall.containerName); + Assert.NotEmpty(defaultContainerCall.sessionToken); + + Assert.Equal(OtherContainerName, otherContainerCall.containerName); + Assert.NotEmpty(otherContainerCall.sessionToken); + } + + [ConditionalFact] + public virtual async Task PagingQuery_uses_TrackSessionToken() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + await context.Customers.ToPageAsync(1, null); + await context.OtherContainerCustomers.ToPageAsync(1, null); + + Assert.Equal(2, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var defaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.First(); + var otherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.Last(); + + Assert.Equal(nameof(CosmosSessionTokenContext), defaultContainerCall.containerName); + Assert.NotEmpty(defaultContainerCall.sessionToken); + + Assert.Equal(OtherContainerName, otherContainerCall.containerName); + Assert.NotEmpty(otherContainerCall.sessionToken); + } + + [ConditionalFact] + public virtual async Task Read_item_uses_TrackSessionToken() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + await context.Customers.FirstOrDefaultAsync(x => x.Id == "1" && x.PartitionKey == "1"); + await context.OtherContainerCustomers.FirstOrDefaultAsync(x => x.Id == "1" && x.PartitionKey == "1"); + + Assert.Equal(2, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var defaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.First(); + var otherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.Last(); + + Assert.Equal(nameof(CosmosSessionTokenContext), defaultContainerCall.containerName); + Assert.NotEmpty(defaultContainerCall.sessionToken); + + Assert.Equal(OtherContainerName, otherContainerCall.containerName); + Assert.NotEmpty(otherContainerCall.sessionToken); + } + + [ConditionalFact] + public virtual async Task Read_item_enumerable_uses_TrackSessionToken() + { + var contextFactory = await InitializeAsync(); + using var context = contextFactory.CreateContext(); + + await context.Customers.Where(x => x.Id == "1" && x.PartitionKey == "1").ToListAsync(); + await context.OtherContainerCustomers.Where(x => x.Id == "1" && x.PartitionKey == "1").ToListAsync(); + + Assert.Equal(2, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var defaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.First(); + var otherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.Last(); + + Assert.Equal(nameof(CosmosSessionTokenContext), defaultContainerCall.containerName); + Assert.NotEmpty(defaultContainerCall.sessionToken); + + Assert.Equal(OtherContainerName, otherContainerCall.containerName); + Assert.NotEmpty(otherContainerCall.sessionToken); + } + + [ConditionalFact] + public virtual async Task Add_AutoTransactionBehavior_Never_uses_TrackSessionToken() + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = AutoTransactionBehavior.Never; + context.Customers.Add(new Customer { Id = "1", PartitionKey = "1" }); + context.OtherContainerCustomers.Add(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + + await context.SaveChangesAsync(); + + Assert.Equal(2, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var defaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.First(); + var otherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls.Last(); + + Assert.Equal(nameof(CosmosSessionTokenContext), defaultContainerCall.containerName); + Assert.NotEmpty(defaultContainerCall.sessionToken); + + Assert.Equal(OtherContainerName, otherContainerCall.containerName); + Assert.NotEmpty(otherContainerCall.sessionToken); + } + + [ConditionalTheory] + [InlineData(true)] + [InlineData(false)] + public virtual async Task Add_AutoTransactionBehavior_Always_uses_TrackSessionToken(bool defaultContainer) + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = AutoTransactionBehavior.Always; + if (defaultContainer) + { + context.Customers.Add(new Customer { Id = "1", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Add(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + } + + await context.SaveChangesAsync(); + + Assert.Equal(1, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var call = _sessionTokenStorage.TrackSessionTokenCalls.First(); + + if (defaultContainer) + { + Assert.Equal(nameof(CosmosSessionTokenContext), call.containerName); + } + else + { + Assert.Equal(OtherContainerName, call.containerName); + } + + Assert.NotEmpty(call.sessionToken); + } + + [ConditionalFact] + public virtual async Task Delete_never_uses_TrackSessionToken() + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = AutoTransactionBehavior.Never; + + var customer = new Customer { Id = "1", PartitionKey = "1" }; + var otherContainerCustomer = new OtherContainerCustomer { Id = "1", PartitionKey = "1" }; + context.Customers.Add(customer); + context.OtherContainerCustomers.Add(otherContainerCustomer); + + await context.SaveChangesAsync(); + + var initialDefaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[0]; + var initialOtherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[1]; + + context.Customers.Remove(customer); + context.OtherContainerCustomers.Remove(otherContainerCustomer); + + await context.SaveChangesAsync(); + + Assert.Equal(4, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var defaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[2]; + var otherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[3]; + + Assert.Equal(nameof(CosmosSessionTokenContext), defaultContainerCall.containerName); + Assert.NotEmpty(defaultContainerCall.sessionToken); + + Assert.Equal(OtherContainerName, otherContainerCall.containerName); + Assert.NotEmpty(otherContainerCall.sessionToken); + + Assert.Equal(initialDefaultContainerCall.containerName, defaultContainerCall.containerName); + Assert.Equal(initialOtherContainerCall.containerName, otherContainerCall.containerName); + + Assert.NotEqual(initialDefaultContainerCall.sessionToken, defaultContainerCall.sessionToken); + Assert.NotEqual(initialOtherContainerCall.sessionToken, otherContainerCall.sessionToken); + } + + [ConditionalTheory] + [InlineData(true)] + [InlineData(false)] + public virtual async Task Delete_always_uses_TrackSessionToken(bool defaultContainer) + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = AutoTransactionBehavior.Always; + + if (defaultContainer) + { + context.Customers.Add(new Customer { Id = "1", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Add(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + } + + await context.SaveChangesAsync(); + + context.ChangeTracker.Clear(); + var initialCall = _sessionTokenStorage.TrackSessionTokenCalls[0]; + + if (defaultContainer) + { + context.Customers.Remove(new Customer { Id = "1", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Remove(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + } + + await context.SaveChangesAsync(); + + Assert.Equal(2, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var call = _sessionTokenStorage.TrackSessionTokenCalls[1]; + + if (defaultContainer) + { + Assert.Equal(nameof(CosmosSessionTokenContext), call.containerName); + } + else + { + Assert.Equal(OtherContainerName, call.containerName); + + } + Assert.NotEmpty(call.sessionToken); + + Assert.Equal(initialCall.containerName, call.containerName); + Assert.NotEqual(initialCall.sessionToken, call.sessionToken); + } + + [ConditionalFact] + public virtual async Task Update_never_uses_TrackSessionToken() + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = AutoTransactionBehavior.Never; + + var customer = new Customer { Id = "1", PartitionKey = "1" }; + var otherContainerCustomer = new OtherContainerCustomer { Id = "1", PartitionKey = "1" }; + context.Customers.Add(customer); + context.OtherContainerCustomers.Add(otherContainerCustomer); + + await context.SaveChangesAsync(); + + var initialDefaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[0]; + var initialOtherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[1]; + + customer.Name = "updated"; + otherContainerCustomer.Name = "updated"; + + await context.SaveChangesAsync(); + + Assert.Equal(4, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var defaultContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[2]; + var otherContainerCall = _sessionTokenStorage.TrackSessionTokenCalls[3]; + + Assert.Equal(nameof(CosmosSessionTokenContext), defaultContainerCall.containerName); + Assert.NotEmpty(defaultContainerCall.sessionToken); + + Assert.Equal(OtherContainerName, otherContainerCall.containerName); + Assert.NotEmpty(otherContainerCall.sessionToken); + + Assert.Equal(initialDefaultContainerCall.containerName, defaultContainerCall.containerName); + Assert.Equal(initialOtherContainerCall.containerName, otherContainerCall.containerName); + + Assert.NotEqual(initialDefaultContainerCall.sessionToken, defaultContainerCall.sessionToken); + Assert.NotEqual(initialOtherContainerCall.sessionToken, otherContainerCall.sessionToken); + } + + [ConditionalTheory] + [InlineData(true)] + [InlineData(false)] + public virtual async Task Update_always_uses_TrackSessionToken(bool defaultContainer) + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = AutoTransactionBehavior.Always; + + if (defaultContainer) + { + context.Customers.Add(new Customer { Id = "1", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Add(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + } + + await context.SaveChangesAsync(); + + context.ChangeTracker.Clear(); + var initialCall = _sessionTokenStorage.TrackSessionTokenCalls[0]; + + if (defaultContainer) + { + context.Customers.Update(new Customer { Id = "1", Name = "updated", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Update(new OtherContainerCustomer { Id = "1", Name = "updated", PartitionKey = "1" }); + } + + await context.SaveChangesAsync(); + + Assert.Equal(2, _sessionTokenStorage.TrackSessionTokenCalls.Count); + var call = _sessionTokenStorage.TrackSessionTokenCalls[1]; + + if (defaultContainer) + { + Assert.Equal(nameof(CosmosSessionTokenContext), call.containerName); + } + else + { + Assert.Equal(OtherContainerName, call.containerName); + + } + Assert.NotEmpty(call.sessionToken); + + Assert.Equal(initialCall.containerName, call.containerName); + Assert.NotEqual(initialCall.sessionToken, call.sessionToken); + } + + [ConditionalTheory] + [InlineData(AutoTransactionBehavior.WhenNeeded, true)] + [InlineData(AutoTransactionBehavior.WhenNeeded, false)] + [InlineData(AutoTransactionBehavior.Never, false)] + [InlineData(AutoTransactionBehavior.Never, true)] + [InlineData(AutoTransactionBehavior.Always, false)] + [InlineData(AutoTransactionBehavior.Always, true)] + public virtual async Task Add_uses_GetSessionToken(AutoTransactionBehavior autoTransactionBehavior, bool defaultContainer) + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = autoTransactionBehavior; + + // Only way we can test this is by setting a session token that will fail the request if used.. + // Only way to do this for a write is to set an invalid session token.. + _sessionTokenStorage.SessionTokens = new Dictionary { { defaultContainer ? nameof(CosmosSessionTokenContext) : OtherContainerName, "invalidtoken" } }; + + if (defaultContainer) + { + context.Customers.Add(new Customer { Id = "1", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Add(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + } + + var ex = await Assert.ThrowsAsync(() => context.SaveChangesAsync()); + + Assert.Contains("The session token provided 'invalidtoken' is invalid.", ((CosmosException)ex.InnerException!).ResponseBody); + } + + [ConditionalTheory] + [InlineData(AutoTransactionBehavior.WhenNeeded, true)] + [InlineData(AutoTransactionBehavior.WhenNeeded, false)] + [InlineData(AutoTransactionBehavior.Never, false)] + [InlineData(AutoTransactionBehavior.Never, true)] + [InlineData(AutoTransactionBehavior.Always, false)] + [InlineData(AutoTransactionBehavior.Always, true)] + public virtual async Task Update_uses_session_token(AutoTransactionBehavior autoTransactionBehavior, bool defaultContainer) + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = autoTransactionBehavior; + + var sessionTokens = context.Database.GetSessionTokens(); + // Only way we can test this is by setting a session token that will fail the request if used.. + // Only way to do this for a write is to set an invalid session token.. + _sessionTokenStorage.SessionTokens = new Dictionary { { defaultContainer ? nameof(CosmosSessionTokenContext) : OtherContainerName, "invalidtoken" } }; + + if (defaultContainer) + { + context.Customers.Update(new Customer { Id = "1", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Update(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + } + + var ex = await Assert.ThrowsAsync(() => context.SaveChangesAsync()); + + Assert.Contains("The session token provided 'invalidtoken' is invalid.", ((CosmosException)ex.InnerException!).ResponseBody); + } + + [ConditionalTheory] + [InlineData(AutoTransactionBehavior.WhenNeeded, true)] + [InlineData(AutoTransactionBehavior.WhenNeeded, false)] + [InlineData(AutoTransactionBehavior.Never, false)] + [InlineData(AutoTransactionBehavior.Never, true)] + [InlineData(AutoTransactionBehavior.Always, false)] + [InlineData(AutoTransactionBehavior.Always, true)] + public virtual async Task Delete_uses_session_token(AutoTransactionBehavior autoTransactionBehavior, bool defaultContainer) + { + var contextFactory = await InitializeAsync(); + + using var context = contextFactory.CreateContext(); + context.Database.AutoTransactionBehavior = autoTransactionBehavior; + + var sessionTokens = context.Database.GetSessionTokens(); + // Only way we can test this is by setting a session token that will fail the request if used.. + // Only way to do this for a write is to set an invalid session token.. + _sessionTokenStorage.SessionTokens = new Dictionary { { defaultContainer ? nameof(CosmosSessionTokenContext) : OtherContainerName, "invalidtoken" } }; + + if (defaultContainer) + { + context.Customers.Remove(new Customer { Id = "1", PartitionKey = "1" }); + } + else + { + context.OtherContainerCustomers.Remove(new OtherContainerCustomer { Id = "1", PartitionKey = "1" }); + } + + var ex = await Assert.ThrowsAsync(() => context.SaveChangesAsync()); + + Assert.Contains("The session token provided 'invalidtoken' is invalid.", ((CosmosException)ex.InnerException!).ResponseBody); + } + + private class TestSessionTokenStorageFactory : ISessionTokenStorageFactory + { + public ISessionTokenStorage Create(DbContext _) + => _sessionTokenStorage = new(); + } + + private class TestSessionTokenStorage : ISessionTokenStorage + { + public Dictionary SessionTokens { get; set; } = new() { { nameof(CosmosSessionTokenContext), null }, { OtherContainerName, null } }; + + public List AppendDefaultContainerSessionTokenCalls { get; set; } = new(); + public List> AppendSessionTokensCalls { get; set; } = new(); + public List SetDefaultContainerSessionTokenCalls { get; set; } = new(); + + public List> SetSessionTokensCalls { get; set; } = new(); + public List<(string containerName, string sessionToken)> TrackSessionTokenCalls { get; set; } = new(); + public bool ClearCalled { get; set; } + + + + public void AppendDefaultContainerSessionToken(string sessionToken) => AppendDefaultContainerSessionTokenCalls.Add(sessionToken); + + public void AppendSessionTokens(IReadOnlyDictionary sessionTokens) => AppendSessionTokensCalls.Add(sessionTokens); + public void Clear() => ClearCalled = true; + public string? GetDefaultContainerTrackedToken() => SessionTokens.FirstOrDefault().Value; + public string? GetSessionToken(string containerName) => SessionTokens[containerName]; + public IReadOnlyDictionary GetTrackedTokens() => SessionTokens; + public void SetDefaultContainerSessionToken(string sessionToken) => SetDefaultContainerSessionTokenCalls.Add(sessionToken); + public void SetSessionTokens(IReadOnlyDictionary sessionTokens) => SetSessionTokensCalls.Add(sessionTokens); + public void TrackSessionToken(string containerName, string sessionToken) => TrackSessionTokenCalls.Add((containerName, sessionToken)); + } + + + public class CosmosSessionTokenContext(DbContextOptions options) : PoolableDbContext(options) + { + public DbSet Customers { get; set; } = null!; + public DbSet OtherContainerCustomers { get; set; } = null!; + + protected override void OnModelCreating(ModelBuilder builder) + { + builder.Entity( + b => + { + b.HasKey(c => c.Id); + b.Property(c => c.ETag).IsETagConcurrency(); + b.OwnsMany(x => x.Children); + b.HasPartitionKey(c => c.PartitionKey); + }); + + builder.Entity( + b => + { + b.HasKey(c => c.Id); + b.HasPartitionKey(c => c.PartitionKey); + b.ToContainer(OtherContainerName); + }); + } + } + + public class Customer + { + public string? Id { get; set; } + + public string? Name { get; set; } + + public string? ETag { get; set; } + + public string? PartitionKey { get; set; } + + public ICollection Children { get; } = new HashSet(); + } + + public class DummyChild + { + public string? Id { get; init; } + } + + public class OtherContainerCustomer + { + public string? Id { get; set; } + + public string? Name { get; set; } + + public string? PartitionKey { get; set; } + } +} diff --git a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs index 6107763f180..988350786a2 100644 --- a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs +++ b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/CosmosTestStore.cs @@ -241,7 +241,7 @@ private async Task CreateFromFile(DbContext context) document["$type"] = entityName; await cosmosClient.CreateItemAsync( - containerName!, document, new FakeUpdateEntry()).ConfigureAwait(false); + containerName!, document, new FakeUpdateEntry(), new NullSessionTokenStorage()).ConfigureAwait(false); } else if (reader.TokenType == JsonToken.EndObject) { diff --git a/test/EFCore.Cosmos.FunctionalTests/TestUtilities/NullSessionTokenStorage.cs b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/NullSessionTokenStorage.cs new file mode 100644 index 00000000000..9aec9cdbe9f --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/TestUtilities/NullSessionTokenStorage.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + + +using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; + +namespace Microsoft.EntityFrameworkCore.TestUtilities; + +/// +public class NullSessionTokenStorage : ISessionTokenStorage +{ + /// + public void AppendDefaultContainerSessionToken(string sessionToken) { } + + /// + public void AppendSessionTokens(IReadOnlyDictionary sessionTokens) {} + + /// + public void Clear() {} + + /// + public string? GetDefaultContainerTrackedToken() => null; + + /// + public string? GetSessionToken(string containerName) => null; + + /// + public IReadOnlyDictionary GetTrackedTokens() => null!; + + /// + public void TrackSessionToken(string containerName, string sessionToken) {} + + /// + public void SetDefaultContainerSessionToken(string sessionToken) {} + + /// + public void SetSessionTokens(IReadOnlyDictionary sessionTokens) {} +} diff --git a/test/EFCore.Cosmos.Tests/Extensions/CosmosDbContextOptionsExtensionsTests.cs b/test/EFCore.Cosmos.Tests/Extensions/CosmosDbContextOptionsExtensionsTests.cs index 895f45889bf..a7c8db0ea68 100644 --- a/test/EFCore.Cosmos.Tests/Extensions/CosmosDbContextOptionsExtensionsTests.cs +++ b/test/EFCore.Cosmos.Tests/Extensions/CosmosDbContextOptionsExtensionsTests.cs @@ -65,6 +65,7 @@ public void Can_create_options_with_valid_values() Test(o => o.MaxTcpConnectionsPerEndpoint(3), o => Assert.Equal(3, o.MaxTcpConnectionsPerEndpoint)); Test(o => o.LimitToEndpoint(), o => Assert.True(o.LimitToEndpoint)); Test(o => o.ContentResponseOnWriteEnabled(), o => Assert.True(o.EnableContentResponseOnWrite)); + Test(o => o.SessionTokenManagementMode(Cosmos.Infrastructure.SessionTokenManagementMode.EnforcedManual), o => Assert.Equal(Cosmos.Infrastructure.SessionTokenManagementMode.EnforcedManual, o.SessionTokenManagementMode)); var webProxy = new WebProxy(); Test(o => o.WebProxy(webProxy), o => Assert.Same(webProxy, o.WebProxy)); diff --git a/test/EFCore.Cosmos.Tests/Storage/Internal/SessionTokenStorageTest.cs b/test/EFCore.Cosmos.Tests/Storage/Internal/SessionTokenStorageTest.cs new file mode 100644 index 00000000000..1783efeec50 --- /dev/null +++ b/test/EFCore.Cosmos.Tests/Storage/Internal/SessionTokenStorageTest.cs @@ -0,0 +1,1243 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using Microsoft.EntityFrameworkCore.Cosmos.Infrastructure; +using Microsoft.EntityFrameworkCore.Cosmos.Internal; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; + +public class SessionTokenStorageTest +{ + private readonly string _defaultContainerName = "default"; + private readonly string _otherContainerName = "other"; + private readonly HashSet _containerNames = new(["default", "other"]); + + // ================================================================ + // FUNCTIONAL TESTS - SET AND RETRIEVE + // ================================================================ + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_SetSingle_Default(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "A" } }); + + AssertDefault(storage, "A"); + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_SetSingle_Other(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.SetSessionTokens(new Dictionary { { _otherContainerName, "A" } }); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertDefault(storage, ""); + } + else + { + AssertDefault(storage, null); + } + } + AssertOther(storage, "A"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_Multiple(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.SetSessionTokens(new Dictionary + { + { _defaultContainerName, "Token1" }, + { _otherContainerName, "Token2" } + }); + + AssertDefault(storage, "Token1"); + AssertOther(storage, "Token2"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_OverwritesSet(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.SetSessionTokens(new Dictionary + { + { _defaultContainerName, "A" }, + { _otherContainerName, "B" } + }); + storage.SetSessionTokens(new Dictionary + { + { _defaultContainerName, "" }, + { _otherContainerName, "" } + }); + + AssertDefault(storage, ""); + AssertOther(storage, ""); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_OverwritesTracked(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.TrackSessionToken(_defaultContainerName, "Token1"); + storage.TrackSessionToken(_otherContainerName, "Token2"); + storage.SetSessionTokens(new Dictionary + { + { _defaultContainerName, "A" }, + { _otherContainerName, "B" } + }); + + AssertDefault(storage, "A"); + AssertOther(storage, "B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_SingleContainer_OverwritesOnlySingleContainer(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary + { + { _defaultContainerName, "A" }, + { _otherContainerName, "B" } + }); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "C" } }); + + AssertDefault(storage, "C"); + AssertOther(storage, "B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_Null_SetsNull(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "A" }, { _otherContainerName, "B" } }); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, null } }); + + AssertDefault(storage, null); + AssertOther(storage, "B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetDefaultContainerSessionToken_SetsToken(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetDefaultContainerSessionToken("A"); + + AssertDefault(storage, "A"); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetDefaultContainerSessionToken_OverwritesSet(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetDefaultContainerSessionToken("A"); + storage.SetDefaultContainerSessionToken("B"); + + AssertDefault(storage, "B"); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetDefaultContainerSessionToken_OverwritesTracked(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.SetDefaultContainerSessionToken("B"); + + AssertDefault(storage, "B"); + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_NoPreviousToken_SetsToken(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.AppendDefaultContainerSessionToken("A"); + + AssertDefault(storage, "A"); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_PreviousSetToken_AppendsToken(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetDefaultContainerSessionToken("A"); + storage.AppendDefaultContainerSessionToken("B"); + + AssertDefault(storage, "A,B"); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_PreviousSetToken_Duplicate_DoesNotAppendToken(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetDefaultContainerSessionToken("A"); + storage.AppendDefaultContainerSessionToken("A"); + + AssertDefault(storage, "A"); + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_PreviousTrackedToken_AppendsToken(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.AppendDefaultContainerSessionToken("B"); + + AssertDefault(storage, "A,B"); + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_PreviousTrackedToken_Duplicate_DoesNotAppendToken(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.AppendDefaultContainerSessionToken("A"); + + AssertDefault(storage, "A"); + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertOther(storage, ""); + } + else + { + AssertOther(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_MultipleContainers_NoPreviousTokens_SetsTokens(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.AppendSessionTokens(new Dictionary + { + { _defaultContainerName, "A" }, + { _otherContainerName, "B" } + }); + + AssertDefault(storage, "A"); + AssertOther(storage, "B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_SingleContainer_NoPreviousTokens_SetsTokens(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.AppendSessionTokens(new Dictionary + { + { _otherContainerName, "B" } + }); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertDefault(storage, ""); + } + else + { + AssertDefault(storage, null); + } + } + AssertOther(storage, "B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_PreviousSetTokens_AppendsTokens(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary + { + { _defaultContainerName, "A" }, + { _otherContainerName, "B" } + }); + storage.AppendSessionTokens(new Dictionary + { + { _defaultContainerName, "C" }, + { _otherContainerName, "D" } + }); + + AssertDefault(storage, "A,C"); + AssertOther(storage, "B,D"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_PreviousSetToken_AppendsAndSetsTokens(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary + { + { _otherContainerName, "B" } + }); + storage.AppendSessionTokens(new Dictionary + { + { _defaultContainerName, "C" }, + { _otherContainerName, "D" } + }); + + AssertDefault(storage, "C"); + AssertOther(storage, "B,D"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_PreviousTrackedToken_AppendsAndSetsTokens(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.TrackSessionToken(_otherContainerName, "B"); + storage.AppendSessionTokens(new Dictionary + { + { _defaultContainerName, "C" }, + { _otherContainerName, "D" } + }); + + AssertDefault(storage, "C"); + AssertOther(storage, "B,D"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetDefaultContainerSessionToken_RemovesDuplicates(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetDefaultContainerSessionToken("A,A,B"); + AssertDefault(storage, "A,B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_RemovesDuplicates(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.AppendDefaultContainerSessionToken("A,A,B"); + AssertDefault(storage, "A,B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_RemovesDuplicates(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "A,B,A" }, { _otherContainerName, "B,C,B" } }); + AssertDefault(storage, "A,B"); + AssertOther(storage, "B,C"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_RemovesDuplicates(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "A,B,A" }, { _otherContainerName, "B,C,B" } }); + AssertDefault(storage, "A,B"); + AssertOther(storage, "B,C"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_PreviouslySetTokens_RemovesDuplicates(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "A,C,E" }, { _otherContainerName, "J,K,L" } }); + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "A,B,B" }, { _otherContainerName, "K,A,A" } }); + AssertDefault(storage, "A,C,E,B"); + AssertOther(storage, "J,K,L,A"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_EmptyStrings_DoesNotAppend(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "A,C" }, { _otherContainerName, "J,K" } }); + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "" }, { _otherContainerName, "" } }); + AssertDefault(storage, "A,C"); + AssertOther(storage, "J,K"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_NoPreviousTokens_EmptyStrings_Sets(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "" }, { _otherContainerName, "" } }); + AssertDefault(storage, ""); + AssertOther(storage, ""); + } + + // ================================================================ + // TRACK + // ================================================================ + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void TrackSessionToken_SetsToken(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.TrackSessionToken(_otherContainerName, "A"); + + AssertDefaultTracked(storage, "A"); + AssertOtherTracked(storage, "A"); + + if (mode == SessionTokenManagementMode.Manual) + { + AssertDefaultUsed(storage, "A"); + AssertOtherUsed(storage, "A"); + } + else if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertDefaultUsed(storage, ""); + AssertOtherUsed(storage, ""); + } + else + { + AssertDefaultUsed(storage, null); + AssertOtherUsed(storage, null); + } + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void TrackSessionToken_Appends(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.TrackSessionToken(_defaultContainerName, "B"); + storage.TrackSessionToken(_defaultContainerName, "A"); + + storage.TrackSessionToken(_otherContainerName, "A"); + storage.TrackSessionToken(_otherContainerName, "C"); + storage.TrackSessionToken(_otherContainerName, "A"); + + AssertDefaultTracked(storage, "A,B"); + AssertOtherTracked(storage, "A,C"); + + if (mode == SessionTokenManagementMode.Manual) + { + AssertDefaultUsed(storage, "A,B"); + AssertOtherUsed(storage, "A,C"); + } + else if (mode != SessionTokenManagementMode.EnforcedManual) + { + if (mode == SessionTokenManagementMode.Manual) + { + AssertDefaultUsed(storage, ""); + AssertOtherUsed(storage, ""); + } + else + { + AssertDefaultUsed(storage, null); + AssertOtherUsed(storage, null); + } + } + } + + // ================================================================ + // ENFORCED MANUAL MODE TESTS + // ================================================================ + + [ConditionalFact] + public virtual void EnforcedManual_WhenGettingTokenBeforeSet_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.EnforcedManual); + var ex = Assert.Throws(() => + storage.GetSessionToken(_defaultContainerName)); + Assert.Contains(CosmosStrings.MissingSessionTokenEnforceManual(_defaultContainerName), ex.Message); + } + + [ConditionalFact] + public virtual void EnforcedManual_WhenGettingTokenAfterClear_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.EnforcedManual); + storage.SetDefaultContainerSessionToken("A"); + storage.Clear(); + var ex = Assert.Throws(() => + storage.GetSessionToken(_defaultContainerName)); + Assert.Contains(CosmosStrings.MissingSessionTokenEnforceManual(_defaultContainerName), ex.Message); + } + + [ConditionalFact] + public virtual void EnforcedManual_SetDefaultContainerSessionToken_SetsAndUses() + { + var storage = CreateStorage(SessionTokenManagementMode.EnforcedManual); + storage.SetDefaultContainerSessionToken("A"); + AssertDefault(storage, "A"); + var ex = Assert.Throws(() => + storage.GetSessionToken(_otherContainerName)); + Assert.Contains(CosmosStrings.MissingSessionTokenEnforceManual(_otherContainerName), ex.Message); + } + + [ConditionalFact] + public virtual void EnforcedManual_SetSessionTokens_SetsAndUses() + { + var storage = CreateStorage(SessionTokenManagementMode.EnforcedManual); + storage.SetSessionTokens(new Dictionary + { + { _defaultContainerName, "A" }, + { _otherContainerName, "B" } + }); + + AssertDefault(storage, "A"); + AssertOther(storage, "B"); + } + + [ConditionalFact] + public virtual void EnforcedManual_WhenOneContainerNotSet_ThrowsForThatContainerOnly() + { + var storage = CreateStorage(SessionTokenManagementMode.EnforcedManual); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "A" } }); + + Assert.Equal("A", storage.GetSessionToken(_defaultContainerName)); + var ex = Assert.Throws(() => storage.GetSessionToken(_otherContainerName)); + Assert.Contains(CosmosStrings.MissingSessionTokenEnforceManual(_otherContainerName), ex.Message); + } + + // ================================================================ + // SEMI-AUTOMATIC SPECIFIC TESTS + // ================================================================ + + [ConditionalFact] + public virtual void SemiAutomatic_WhenTrackingToken_SetsButDoesnotUseToken() + { + var storage = CreateStorage(SessionTokenManagementMode.SemiAutomatic); + storage.TrackSessionToken(_defaultContainerName, "A"); + AssertDefaultTracked(storage, "A"); + AssertDefaultUsed(storage, null); + } + + [ConditionalFact] + public virtual void SemiAutomatic_WhenSetToken_SetsAndUses() + { + var storage = CreateStorage(SessionTokenManagementMode.SemiAutomatic); + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.AppendDefaultContainerSessionToken("A"); + AssertDefault(storage, "A"); + } + + // ================================================================ + // MANUAL SPECIFIC TESTS + // ================================================================ + + [ConditionalFact] + public virtual void Manual_TrackedToken_UsesToken() + { + var storage = CreateStorage(SessionTokenManagementMode.Manual); + + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.TrackSessionToken(_otherContainerName, "B"); + + Assert.True(storage.GetSessionToken(_defaultContainerName) == "A"); + Assert.True(storage.GetSessionToken(_otherContainerName) == "B"); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Manual_Constructor_AllContainersHaveEmptyString(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + var tokens = storage.GetTrackedTokens(); + Assert.True(tokens[_defaultContainerName] == ""); + Assert.True(tokens[_otherContainerName] == ""); + Assert.True(storage.GetDefaultContainerTrackedToken() == ""); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + Assert.True(storage.GetSessionToken(_defaultContainerName) == ""); + Assert.True(storage.GetSessionToken(_otherContainerName) == ""); + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Manual_Clear_ResetsAllContainersToEmptyString(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "A" }, { _otherContainerName, "B" } }); + storage.Clear(); + + var tokens = storage.GetTrackedTokens(); + Assert.True(tokens[_defaultContainerName] == ""); + Assert.True(tokens[_otherContainerName] == ""); + Assert.True(storage.GetDefaultContainerTrackedToken() == ""); + if (mode != SessionTokenManagementMode.EnforcedManual) + { + Assert.True(storage.GetSessionToken(_defaultContainerName) == ""); + Assert.True(storage.GetSessionToken(_otherContainerName) == ""); + } + } + + // ================================================================ + // CLEAR + // ================================================================ + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Clear_WhenClearing_ContainersStillExistInTrackedTokens(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.Clear(); + + var tokens = storage.GetTrackedTokens(); + Assert.Contains(_defaultContainerName, tokens.Keys); + Assert.Contains(_otherContainerName, tokens.Keys); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Clear_WhenClearingSetTokens_ResetsAllContainers(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "A" }, { _otherContainerName, "B" } }); + storage.Clear(); + + var tokens = storage.GetTrackedTokens(); + + if (mode == SessionTokenManagementMode.Manual || mode == SessionTokenManagementMode.EnforcedManual) + { + Assert.True(tokens[_defaultContainerName] == ""); + Assert.True(tokens[_otherContainerName] == ""); + Assert.True(storage.GetDefaultContainerTrackedToken() == ""); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + Assert.True(storage.GetSessionToken(_defaultContainerName) == ""); + Assert.True(storage.GetSessionToken(_otherContainerName) == ""); + } + } + else + { + Assert.Null(tokens[_defaultContainerName]); + Assert.Null(tokens[_otherContainerName]); + Assert.Null(storage.GetDefaultContainerTrackedToken()); + + Assert.Null(storage.GetSessionToken(_defaultContainerName)); + Assert.Null(storage.GetSessionToken(_otherContainerName)); + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Clear_WhenClearingTrackedTokens_ResetsAllContainers(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.TrackSessionToken(_otherContainerName, "B"); + storage.Clear(); + + var tokens = storage.GetTrackedTokens(); + + if (mode == SessionTokenManagementMode.Manual || mode == SessionTokenManagementMode.EnforcedManual) + { + Assert.True(tokens[_defaultContainerName] == ""); + Assert.True(tokens[_otherContainerName] == ""); + Assert.True(storage.GetDefaultContainerTrackedToken() == ""); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + Assert.True(storage.GetSessionToken(_defaultContainerName) == ""); + Assert.True(storage.GetSessionToken(_otherContainerName) == ""); + } + } + else + { + Assert.Null(tokens[_defaultContainerName]); + Assert.Null(tokens[_otherContainerName]); + Assert.Null(storage.GetDefaultContainerTrackedToken()); + + Assert.Null(storage.GetSessionToken(_defaultContainerName)); + Assert.Null(storage.GetSessionToken(_otherContainerName)); + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Clear_WhenClearing_CanSetNewTokensAfterClear(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "A" }, { _otherContainerName, "B" } }); + + storage.Clear(); + storage.SetSessionTokens(new Dictionary { { _defaultContainerName, "C" }, { _otherContainerName, "D" } }); + + var expectedDefault = "C"; + var expectedOther = "D"; + var defaultContainerTrackedToken = storage.GetDefaultContainerTrackedToken(); + var tokens = storage.GetTrackedTokens(); + + Assert.Equal(expectedDefault, defaultContainerTrackedToken); + + Assert.Equal(expectedDefault, tokens[_defaultContainerName]); + Assert.Equal(expectedOther, tokens[_otherContainerName]); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + Assert.Equal(expectedDefault, storage.GetSessionToken(_defaultContainerName)); + Assert.Equal(expectedOther, storage.GetSessionToken(_otherContainerName)); + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Clear_WhenClearing_CanAppendNewTokensAfterClear(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "A" }, { _otherContainerName, "B" } }); + + storage.Clear(); + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "C" }, { _otherContainerName, "D" } }); + + var expectedDefault = "C"; + var expectedOther = "D"; + var defaultContainerTrackedToken = storage.GetDefaultContainerTrackedToken(); + var tokens = storage.GetTrackedTokens(); + + Assert.Equal(expectedDefault, defaultContainerTrackedToken); + + Assert.Equal(expectedDefault, tokens[_defaultContainerName]); + Assert.Equal(expectedOther, tokens[_otherContainerName]); + + if (mode != SessionTokenManagementMode.EnforcedManual) + { + Assert.Equal(expectedDefault, storage.GetSessionToken(_defaultContainerName)); + Assert.Equal(expectedOther, storage.GetSessionToken(_otherContainerName)); + } + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Clear_WhenClearing_CanTrackNewTokensAfterClear(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + storage.AppendSessionTokens(new Dictionary { { _defaultContainerName, "A" }, { _otherContainerName, "B" } }); + + storage.Clear(); + + storage.TrackSessionToken(_defaultContainerName, "C"); + storage.TrackSessionToken(_otherContainerName, "D"); + + AssertDefaultTracked(storage, "C"); + AssertOtherTracked(storage, "D"); + } + + // ================================================================ + // INITIALIZATION AND CONTAINER MANAGEMENT TESTS + // ================================================================ + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void Constructor_AllContainersAreInitialized(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + + var tokens = storage.GetTrackedTokens(); + Assert.Equal(2, tokens.Count); + Assert.True(tokens.ContainsKey(_defaultContainerName)); + Assert.True(tokens.ContainsKey(_otherContainerName)); + } + + [ConditionalFact] + public virtual void Constructor_WhenInitializing_AllContainersStartWithNullTokens() + { + var storage = CreateStorage(SessionTokenManagementMode.SemiAutomatic); + + var tokens = storage.GetTrackedTokens(); + Assert.Null(tokens[_defaultContainerName]); + Assert.Null(tokens[_otherContainerName]); + Assert.Null(storage.GetDefaultContainerTrackedToken()); + Assert.Null(storage.GetSessionToken(_defaultContainerName)); + Assert.Null(storage.GetSessionToken(_otherContainerName)); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void GetTrackedTokens_WhenCalled_ReturnsSnapshotNotLiveReference(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + var snapshot = storage.GetTrackedTokens(); + + storage.AppendDefaultContainerSessionToken("A"); + var snapshot2 = storage.GetTrackedTokens(); + + Assert.NotSame(snapshot, snapshot2); + if (mode == SessionTokenManagementMode.Manual || mode == SessionTokenManagementMode.EnforcedManual) + { + Assert.True(snapshot[_defaultContainerName] == ""); + } + else + { + Assert.Null(snapshot[_defaultContainerName]); + } + Assert.Equal("A", snapshot2[_defaultContainerName]); + } + + // ================================================================ + // FULLY AUTOMATIC MODE SPECIFIC TESTS + // ================================================================ + + [ConditionalFact] + public virtual void FullyAutomatic_WhenCallingSetSessionTokens_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + var ex = Assert.Throws(() => + storage.SetSessionTokens(new Dictionary())); + Assert.Equal(CosmosStrings.EnableManualSessionTokenManagement, ex.Message); + } + + [ConditionalFact] + public virtual void FullyAutomatic_WhenCallingGetTrackedTokens_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + var ex = Assert.Throws(() => storage.GetTrackedTokens()); + Assert.Equal(CosmosStrings.EnableManualSessionTokenManagement, ex.Message); + } + + [ConditionalFact] + public virtual void FullyAutomatic_WhenCallingAppendSessionTokens_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + var ex = Assert.Throws(() => + storage.AppendSessionTokens(new Dictionary())); + Assert.Equal(CosmosStrings.EnableManualSessionTokenManagement, ex.Message); + } + + [ConditionalFact] + public virtual void FullyAutomatic_WhenCallingSetDefaultContainerSessionToken_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + var ex = Assert.Throws(() => + storage.SetDefaultContainerSessionToken(null)); + Assert.Equal(CosmosStrings.EnableManualSessionTokenManagement, ex.Message); + } + + [ConditionalFact] + public virtual void FullyAutomatic_WhenCallingAppendDefaultContainerSessionToken_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + var ex = Assert.Throws(() => + storage.AppendDefaultContainerSessionToken("A")); + Assert.Equal(CosmosStrings.EnableManualSessionTokenManagement, ex.Message); + } + + [ConditionalFact] + public virtual void FullyAutomatic_WhenCallingGetDefaultContainerTrackedToken_ThrowsInvalidOperationException() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + var ex = Assert.Throws(() => + storage.GetDefaultContainerTrackedToken()); + Assert.Equal(CosmosStrings.EnableManualSessionTokenManagement, ex.Message); + } + + [ConditionalFact] + public virtual void FullyAutomatic_WhenTrackingToken_AlwaysReturnsNull() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + storage.TrackSessionToken(_defaultContainerName, "A"); + Assert.Null(storage.GetSessionToken(_defaultContainerName)); + } + + [ConditionalFact] + public virtual void FullyAutomatic_WhenTrackingMultipleTokens_AlwaysReturnsNull() + { + var storage = CreateStorage(SessionTokenManagementMode.FullyAutomatic); + storage.TrackSessionToken(_defaultContainerName, "A"); + storage.TrackSessionToken(_defaultContainerName, "B"); + storage.TrackSessionToken(_otherContainerName, "C"); + Assert.Null(storage.GetSessionToken(_defaultContainerName)); + Assert.Null(storage.GetSessionToken(_otherContainerName)); + } + + // ================================================================ + // Argument exceptions + // ================================================================ + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.FullyAutomatic)] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void TrackSessionToken_WhenContainerNameIsNull_ThrowsArgumentNullException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.TrackSessionToken(null!, "A")); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.FullyAutomatic)] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void TrackSessionToken_WhenContainerNameIsWhitespace_ThrowsArgumentNullException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.TrackSessionToken(" ", "A")); + Assert.Throws(() => storage.TrackSessionToken("", "A")); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.FullyAutomatic)] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void TrackSessionToken_WhenTokenIsNull_ThrowsArgumentNullException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.TrackSessionToken(_defaultContainerName, null!)); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.FullyAutomatic)] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void TrackSessionToken_WhenTokenIsWhitespace_ThrowsArgumentNullException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.TrackSessionToken(_defaultContainerName, " ")); + Assert.Throws(() => storage.TrackSessionToken(_defaultContainerName, "")); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.FullyAutomatic)] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_WhenTokenIsNull_ThrowsArgumentNullException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.AppendDefaultContainerSessionToken(null!)); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.FullyAutomatic)] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendDefaultContainerSessionToken_WhenTokenIsWhitespace_ThrowsArgumentException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.AppendDefaultContainerSessionToken(" ")); + Assert.Throws(() => storage.AppendDefaultContainerSessionToken("")); + } + + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void SetSessionTokens_WhenContainerNameIsUnknown_ThrowsInvalidOperationException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + var ex = Assert.Throws(() => + storage.SetSessionTokens(new Dictionary { { "bad", "A" } })); + Assert.Equal(CosmosStrings.ContainerNameDoesNotExist("bad"), ex.Message); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void AppendSessionTokens_WhenContainerNameIsUnknown_ThrowsInvalidOperationException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + var ex = Assert.Throws(() => + storage.AppendSessionTokens(new Dictionary { { "bad", "A" } })); + Assert.Equal(CosmosStrings.ContainerNameDoesNotExist("bad"), ex.Message); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void GetSessionToken_WhenContainerNameIsNull_ThrowsArgumentNullException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.GetSessionToken(null!)); + } + + [ConditionalTheory] + [InlineData(SessionTokenManagementMode.SemiAutomatic)] + [InlineData(SessionTokenManagementMode.Manual)] + [InlineData(SessionTokenManagementMode.EnforcedManual)] + public virtual void GetSessionToken_WhenContainerNameIsWhitespace_ThrowsArgumentNullException(SessionTokenManagementMode mode) + { + var storage = CreateStorage(mode); + Assert.Throws(() => storage.GetSessionToken(" ")); + Assert.Throws(() => storage.GetSessionToken("")); + } + + + private SessionTokenStorage CreateStorage(SessionTokenManagementMode mode) + => new(_defaultContainerName, _containerNames, mode); + + private void AssertDefault(SessionTokenStorage storage, string? value) + { + AssertDefaultTracked(storage, value); + AssertDefaultUsed(storage, value); + } + + private void AssertDefaultUsed(SessionTokenStorage storage, string? value) + { + if (value == null) + { + Assert.Null(storage.GetSessionToken(_defaultContainerName)); + } + else + { + Assert.Equal(value, storage.GetSessionToken(_defaultContainerName)); + } + } + + private void AssertDefaultTracked(SessionTokenStorage storage, string? value) + { + if (value == null) + { + Assert.Null(storage.GetDefaultContainerTrackedToken()); + Assert.Null(storage.GetTrackedTokens()[_defaultContainerName]); + } + else + { + Assert.Equal(value, storage.GetDefaultContainerTrackedToken()); + Assert.Equal(value, storage.GetTrackedTokens()[_defaultContainerName]); + } + } + + private void AssertOther(SessionTokenStorage storage, string? value) + { + AssertOtherTracked(storage, value); + AssertOtherUsed(storage, value); + } + + private void AssertOtherUsed(SessionTokenStorage storage, string? value) + { + if (value == null) + { + Assert.Null(storage.GetSessionToken(_otherContainerName)); + } + else + { + Assert.Equal(value, storage.GetSessionToken(_otherContainerName)); + } + } + + private void AssertOtherTracked(SessionTokenStorage storage, string? value) + { + if (value == null) + { + Assert.Null(storage.GetTrackedTokens()[_otherContainerName]); + } + else + { + Assert.Equal(value, storage.GetTrackedTokens()[_otherContainerName]); + } + } + +}