Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 75 additions & 14 deletions dotnet/src/VectorData/SqlServer/SqlServerCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand Down Expand Up @@ -33,7 +34,8 @@ public class SqlServerCollection<TKey, TRecord>

private static readonly VectorSearchOptions<TRecord> s_defaultVectorSearchOptions = new();

private readonly string _connectionString;
private readonly Func<Task<SqlConnection>>? _sqlConnectionProvider;
private readonly string? _connectionString;
private readonly CollectionModel _model;
private readonly SqlServerMapper<TRecord> _mapper;

Expand Down Expand Up @@ -62,6 +64,50 @@ public SqlServerCollection(
{
}

/// <summary>
/// Initializes a new instance of the <see cref="SqlServerCollection{TKey, TRecord}"/> class.
/// </summary>
/// <param name="sqlConnectionProvider">Provider for the database connection.</param>
/// <param name="name">The name of the collection.</param>
/// <param name="options">Optional configuration options.</param>
[RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")]
[RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")]
public SqlServerCollection(
Func<Task<SqlConnection>> sqlConnectionProvider,
string name,
SqlServerCollectionOptions? options = null)
: this(
sqlConnectionProvider,
name,
static options => typeof(TRecord) == typeof(Dictionary<string, object?>)
? throw new NotSupportedException(VectorDataStrings.NonDynamicCollectionWithDictionaryNotSupported(typeof(SqlServerDynamicCollection)))
: new SqlServerModelBuilder().Build(typeof(TRecord), options.Definition, options.EmbeddingGenerator),
options)
{
}

private SqlServerCollection(Func<Task<SqlConnection>> sqlConnectionProvider, string name, Func<SqlServerCollectionOptions, CollectionModel> modelFactory, SqlServerCollectionOptions? options)
{
Verify.NotNull(sqlConnectionProvider);
Verify.NotNull(name);

this._sqlConnectionProvider = sqlConnectionProvider;
options ??= SqlServerCollectionOptions.Default;
this._schema = options.Schema;

this.Name = name;
this._model = modelFactory(options);

this._mapper = new SqlServerMapper<TRecord>(this._model);

this._collectionMetadata = new()
{
VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName,
VectorStoreName = "(unknown)",
CollectionName = name
};
}

internal SqlServerCollection(string connectionString, string name, Func<SqlServerCollectionOptions, CollectionModel> modelFactory, SqlServerCollectionOptions? options)
{
Verify.NotNullOrWhiteSpace(connectionString);
Expand Down Expand Up @@ -92,7 +138,7 @@ internal SqlServerCollection(string connectionString, string name, Func<SqlServe
/// <inheritdoc/>
public override async Task<bool> CollectionExistsAsync(CancellationToken cancellationToken = default)
{
using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = SqlServerCommandBuilder.SelectTableName(
connection, this._schema, this.Name);

Expand All @@ -113,7 +159,7 @@ public override Task EnsureCollectionExistsAsync(CancellationToken cancellationT

private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken)
{
using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = SqlServerCommandBuilder.CreateTable(
connection,
this._schema,
Expand All @@ -131,7 +177,7 @@ await connection.ExecuteWithErrorHandlingAsync(
/// <inheritdoc/>
public override async Task EnsureCollectionDeletedAsync(CancellationToken cancellationToken = default)
{
using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists(
connection, this._schema, this.Name);

Expand All @@ -147,7 +193,7 @@ public override async Task DeleteAsync(TKey key, CancellationToken cancellationT
{
Verify.NotNull(key);

using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(
connection,
this._schema,
Expand All @@ -167,8 +213,11 @@ public override async Task DeleteAsync(IEnumerable<TKey> keys, CancellationToken
{
Verify.NotNull(keys);

using SqlConnection connection = new(this._connectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
if (connection.State != ConnectionState.Open)
{
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
}

using SqlTransaction transaction = connection.BeginTransaction();
int taken = 0;
Expand Down Expand Up @@ -251,7 +300,7 @@ public override async Task DeleteAsync(IEnumerable<TKey> keys, CancellationToken
throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration);
}

using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = SqlServerCommandBuilder.SelectSingle(
connection,
this._schema,
Expand Down Expand Up @@ -286,7 +335,7 @@ public override async IAsyncEnumerable<TRecord> GetAsync(IEnumerable<TKey> keys,
throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration);
}

using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = connection.CreateCommand();
int taken = 0;

Expand Down Expand Up @@ -373,7 +422,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
}
}

using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = SqlServerCommandBuilder.MergeIntoSingle(
connection,
this._schema,
Expand Down Expand Up @@ -446,8 +495,11 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
}
}

using SqlConnection connection = new(this._connectionString);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
if (connection.State != ConnectionState.Open)
{
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
}

using SqlTransaction transaction = connection.BeginTransaction();
int parametersPerRecord = this._model.Properties.Count;
Expand Down Expand Up @@ -565,7 +617,7 @@ _ when vectorProperty.EmbeddingGenerator is IEmbeddingGenerator<TInput, Embeddin
#pragma warning disable CA2000 // Dispose objects before losing scope
// Connection and command are going to be disposed by the ReadVectorSearchResultsAsync,
// when the user is done with the results.
SqlConnection connection = new(this._connectionString);
SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
SqlCommand command = SqlServerCommandBuilder.SelectVector(
connection,
this._schema,
Expand Down Expand Up @@ -645,7 +697,7 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord

options ??= new();

using SqlConnection connection = new(this._connectionString);
using SqlConnection connection = await this.GetSqlConnectionAsync().ConfigureAwait(false);
using SqlCommand command = SqlServerCommandBuilder.SelectWhere(
filter,
top,
Expand All @@ -670,4 +722,13 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
yield return this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors);
}
}

private async Task<SqlConnection> GetSqlConnectionAsync()
{
if (this._sqlConnectionProvider != null)
{
return await this._sqlConnectionProvider().ConfigureAwait(false);
}
return new SqlConnection(this._connectionString);
}
}