Skip to content
This repository was archived by the owner on Jul 12, 2020. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions src/Cosmonaut/CosmosStore.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
using System;
using Cosmonaut.Extensions;
using Cosmonaut.Response;
using Cosmonaut.Storage;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Client;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using Cosmonaut.Extensions;
using Cosmonaut.Response;
using Cosmonaut.Storage;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Client;

namespace Cosmonaut
{
Expand All @@ -17,11 +17,11 @@ public sealed class CosmosStore<TEntity> : ICosmosStore<TEntity> where TEntity :
public bool IsShared { get; internal set; }

public string CollectionName { get; private set; }

public string DatabaseName { get; }

public CosmosStoreSettings Settings { get; }

public ICosmonautClient CosmonautClient { get; }

private readonly IDatabaseCreator _databaseCreator;
Expand Down Expand Up @@ -80,7 +80,8 @@ internal CosmosStore(ICosmonautClient cosmonautClient,
public IQueryable<TEntity> Query(FeedOptions feedOptions = null)
{
var queryable =
CosmonautClient.Query<TEntity>(DatabaseName, CollectionName, GetFeedOptionsForQuery(feedOptions));
CosmonautClient.Query<TEntity>(DatabaseName, CollectionName, GetFeedOptionsForQuery(feedOptions))
.ApplyInterception(Settings.Interceptors);

return IsShared ? queryable.Where(ExpressionExtensions.SharedCollectionExpression<TEntity>()) : queryable;
}
Expand All @@ -105,7 +106,7 @@ public async Task<T> QuerySingleAsync<T>(string sql, object parameters = null, F
var queryable = CosmonautClient.Query<T>(DatabaseName, CollectionName, collectionSharingFriendlySql, parameters, GetFeedOptionsForQuery(feedOptions));
return await queryable.SingleOrDefaultAsync(cancellationToken);
}

public async Task<IEnumerable<TEntity>> QueryMultipleAsync(string sql, object parameters = null, FeedOptions feedOptions = null, CancellationToken cancellationToken = default)
{
var collectionSharingFriendlySql = sql.EnsureQueryIsCollectionSharingFriendly<TEntity>();
Expand All @@ -125,14 +126,14 @@ public async Task<CosmosResponse<TEntity>> AddAsync(TEntity entity, RequestOptio
return await CosmonautClient.CreateDocumentAsync(DatabaseName, CollectionName, entity,
GetRequestOptions(requestOptions, entity), cancellationToken);
}

public async Task<CosmosMultipleResponse<TEntity>> AddRangeAsync(IEnumerable<TEntity> entities, Func<TEntity, RequestOptions> requestOptions = null, CancellationToken cancellationToken = default)
{
return await ExecuteMultiOperationAsync(entities, x => AddAsync(x, requestOptions?.Invoke(x), cancellationToken));
}

public async Task<CosmosMultipleResponse<TEntity>> RemoveAsync(
Expression<Func<TEntity, bool>> predicate,
Expression<Func<TEntity, bool>> predicate,
FeedOptions feedOptions = null,
Func<TEntity, RequestOptions> requestOptions = null,
CancellationToken cancellationToken = default)
Expand All @@ -148,7 +149,7 @@ public async Task<CosmosResponse<TEntity>> RemoveAsync(TEntity entity, RequestOp
return await CosmonautClient.DeleteDocumentAsync(DatabaseName, CollectionName, documentId,
GetRequestOptions(requestOptions, entity), cancellationToken).ExecuteCosmosCommand(entity);
}

public async Task<CosmosMultipleResponse<TEntity>> RemoveRangeAsync(IEnumerable<TEntity> entities, Func<TEntity, RequestOptions> requestOptions = null, CancellationToken cancellationToken = default)
{
return await ExecuteMultiOperationAsync(entities, x => RemoveAsync(x, requestOptions?.Invoke(x), cancellationToken));
Expand All @@ -161,7 +162,7 @@ public async Task<CosmosResponse<TEntity>> UpdateAsync(TEntity entity, RequestOp
return await CosmonautClient.UpdateDocumentAsync(DatabaseName, CollectionName, document,
GetRequestOptions(requestOptions, entity), cancellationToken).ExecuteCosmosCommand(entity);
}

public async Task<CosmosMultipleResponse<TEntity>> UpdateRangeAsync(IEnumerable<TEntity> entities, Func<TEntity, RequestOptions> requestOptions = null, CancellationToken cancellationToken = default)
{
return await ExecuteMultiOperationAsync(entities, x => UpdateAsync(x, requestOptions?.Invoke(x), cancellationToken));
Expand All @@ -178,7 +179,7 @@ public async Task<CosmosMultipleResponse<TEntity>> UpsertRangeAsync(IEnumerable<
{
return await ExecuteMultiOperationAsync(entities, x => UpsertAsync(x, requestOptions?.Invoke(x), cancellationToken));
}

public async Task<CosmosResponse<TEntity>> RemoveByIdAsync(string id, RequestOptions requestOptions = null, CancellationToken cancellationToken = default)
{
var response = await CosmonautClient.DeleteDocumentAsync(DatabaseName, CollectionName, id,
Expand Down Expand Up @@ -208,12 +209,12 @@ public async Task<TEntity> FindAsync(string id, object partitionKeyValue, Cancel
: null;
return await FindAsync(id, requestOptions, cancellationToken);
}

private void InitialiseCosmosStore(string overridenCollectionName)
{
IsShared = typeof(TEntity).UsesSharedCollection();
CollectionName = GetCosmosStoreCollectionName(overridenCollectionName);

_databaseCreator.EnsureCreatedAsync(DatabaseName).ConfigureAwait(false).GetAwaiter().GetResult();
_collectionCreator.EnsureCreatedAsync<TEntity>(DatabaseName, CollectionName, Settings.DefaultCollectionThroughput, Settings.IndexingPolicy)
.ConfigureAwait(false).GetAwaiter().GetResult();
Expand All @@ -235,7 +236,7 @@ private async Task<CosmosMultipleResponse<TEntity>> ExecuteMultiOperationAsync(I
var entitiesList = entities.ToList();
if (!entitiesList.Any())
return multipleResponse;

var results = (await entitiesList.Select(operationFunc).WhenAllTasksAsync()).ToList();
multipleResponse.SuccessfulEntities.AddRange(results.Where(x => x.IsSuccess));
multipleResponse.FailedEntities.AddRange(results.Where(x => !x.IsSuccess));
Expand Down Expand Up @@ -277,7 +278,7 @@ private RequestOptions GetRequestOptions(string id, RequestOptions requestOption

private FeedOptions GetFeedOptionsForQuery(FeedOptions feedOptions)
{
var shouldEnablePartitionQuery = (typeof(TEntity).HasPartitionKey() && feedOptions?.PartitionKey == null)
var shouldEnablePartitionQuery = (typeof(TEntity).HasPartitionKey() && feedOptions?.PartitionKey == null)
|| (feedOptions != null && feedOptions.EnableCrossPartitionQuery);

if (feedOptions == null)
Expand Down
31 changes: 24 additions & 7 deletions src/Cosmonaut/CosmosStoreSettings.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using System;
using Cosmonaut.Interception;
using Microsoft.Azure.Documents;
using Microsoft.Azure.Documents.Client;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

namespace Cosmonaut
{
Expand All @@ -19,14 +22,16 @@ public class CosmosStoreSettings

public IndexingPolicy IndexingPolicy { get; set; } = CosmosConstants.DefaultIndexingPolicy;

public int DefaultCollectionThroughput { get; set; } = CosmosConstants.MinimumCosmosThroughput;
public int DefaultCollectionThroughput { get; set; } = CosmosConstants.MinimumCosmosThroughput;

public JsonSerializerSettings JsonSerializerSettings { get; set; }

public bool InfiniteRetries { get; set; } = true;

public string CollectionPrefix { get; set; } = string.Empty;

public List<IQueryInterceptor> Interceptors { get; }

public CosmosStoreSettings(string databaseName,
string endpointUrl,
string authKey,
Expand All @@ -42,6 +47,7 @@ public CosmosStoreSettings(string databaseName,
DatabaseName = databaseName ?? throw new ArgumentNullException(nameof(databaseName));
EndpointUrl = endpointUrl ?? throw new ArgumentNullException(nameof(endpointUrl));
AuthKey = authKey ?? throw new ArgumentNullException(nameof(authKey));
Interceptors = new List<IQueryInterceptor>();
settings?.Invoke(this);
}

Expand All @@ -52,18 +58,18 @@ public CosmosStoreSettings(
ConnectionPolicy connectionPolicy = null,
IndexingPolicy indexingPolicy = null,
int defaultCollectionThroughput = CosmosConstants.MinimumCosmosThroughput)
: this(databaseName,
new Uri(endpointUrl),
: this(databaseName,
new Uri(endpointUrl),
authKey,
connectionPolicy,
indexingPolicy,
defaultCollectionThroughput)
{
}

public CosmosStoreSettings(
string databaseName,
Uri endpointUrl,
string databaseName,
Uri endpointUrl,
string authKey,
ConnectionPolicy connectionPolicy = null,
IndexingPolicy indexingPolicy = null,
Expand All @@ -75,6 +81,17 @@ public CosmosStoreSettings(
ConnectionPolicy = connectionPolicy;
DefaultCollectionThroughput = defaultCollectionThroughput;
IndexingPolicy = indexingPolicy ?? CosmosConstants.DefaultIndexingPolicy;
Interceptors = new List<IQueryInterceptor>();
}

public void AddInterceptor<T>(Expression<Func<T, bool>> filter) where T : class
{
Interceptors.Add(new QueryInterceptor<T>(filter));
}
}

public interface IQueryInterceptor
{
Type Type { get; }
}
}
1 change: 1 addition & 0 deletions src/Cosmonaut/Extensions/ExpressionExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

namespace Cosmonaut.Extensions
Expand Down
30 changes: 30 additions & 0 deletions src/Cosmonaut/Extensions/IQueryableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using Cosmonaut.Interception.QueryTranslation;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

namespace Cosmonaut.Extensions
{
public static class IQueryableExtensions
{
public static IQueryable<TEntity> ApplyInterception<TEntity>(this IQueryable<TEntity> queryable, IEnumerable<IQueryInterceptor> interceptors)
{
interceptors = interceptors?.Where(x => x.Type == typeof(TEntity));

if (!interceptors?.Any() ?? true)
{
return queryable;
}

var visitors = interceptors.Cast<ExpressionVisitor>();
return queryable.InterceptWith(visitors.ToArray());
}

public static IQueryable<T> InterceptWith<T>(this IQueryable<T> source, params ExpressionVisitor[] visitors)
{
return new QueryTranslator<T>(source, visitors);
}
}
}

69 changes: 69 additions & 0 deletions src/Cosmonaut/Interception/QueryInterceptor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace Cosmonaut.Interception
{
public class QueryInterceptor<TEntity> : ExpressionVisitor, IQueryInterceptor where TEntity : class
{
public Type Type { get; }

public bool Applied { get; private set; }

private readonly Expression<Func<TEntity, bool>> _predicate;
private readonly ConstantExpression _constant;

public QueryInterceptor(Expression<Func<TEntity, bool>> interceptor)
{
if (interceptor == null)
{
throw new ArgumentException(nameof(interceptor));
}

Type = typeof(TEntity);
_predicate = interceptor;
}

public override Expression Visit(Expression node)
{
if (!(node is ConstantExpression constant))
{
return base.Visit(node);
}

if (!(constant.Value is IQueryable<TEntity>))
{
return base.Visit(node);
}

var method = GetLinqWhere();

Applied = true;

return Expression.Call(method, constant, _predicate);
}

private MethodInfo GetLinqWhere()
{
var method = typeof(Queryable).GetRuntimeMethods()
.Where(x => x.Name == nameof(Queryable.Where))
.Select(x => x.MakeGenericMethod(new[] { typeof(TEntity) }))
.Single(methodInfo =>
{
var parameters = methodInfo.GetParameters();

if (parameters.Count() == 2
&& parameters[0].ParameterType == typeof(IQueryable<TEntity>)
&& parameters[1].ParameterType == typeof(Expression<Func<TEntity, bool>>))
{
return true;
}

return false;
});

return method;
}
}
}
77 changes: 77 additions & 0 deletions src/Cosmonaut/Interception/QueryTranslation/QueryTranslator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

namespace Cosmonaut.Interception.QueryTranslation
{
internal class QueryTranslator<T> : IOrderedQueryable<T>
{
private readonly Expression _expression;
private readonly QueryTranslatorProviderAsync _provider;

/// <summary>
/// Initializes a new instance of the <see cref="QueryTranslator{T}"/> class.
/// </summary>
/// <param name="source">The source.</param>
/// <param name="visitors">The visitors.</param>
public QueryTranslator(IQueryable source, IEnumerable<ExpressionVisitor> visitors)
{
_expression = Expression.Constant(this);
_provider = new QueryTranslatorProviderAsync(source, visitors);
}

/// <summary>
/// Initializes a new instance of the <see cref="QueryTranslator{T}"/> class.
/// </summary>
/// <param name="source">The source.</param>
/// <param name="expression">The expression.</param>
/// <param name="visitors">The visitors.</param>
public QueryTranslator(IQueryable source, Expression expression, IEnumerable<ExpressionVisitor> visitors)
{
_expression = expression;
_provider = new QueryTranslatorProviderAsync(source, visitors);
}

/// <summary>
/// Returns an enumerator that iterates through the collection.
/// </summary>
/// <returns>
/// A <see cref="T:System.Collections.Generic.IEnumerator`1" /> that can be used to iterate through the collection.
/// </returns>
public IEnumerator<T> GetEnumerator()
{
return ((IEnumerable<T>)_provider.ExecuteEnumerable(_expression)).GetEnumerator();
}

/// <summary>
/// Returns an enumerator that iterates through a collection.
/// </summary>
/// <returns>
/// An <see cref="T:System.Collections.IEnumerator" /> object that can be used to iterate through the collection.
/// </returns>
IEnumerator IEnumerable.GetEnumerator()
{
return _provider.ExecuteEnumerable(_expression).GetEnumerator();
}

/// <summary>
/// Gets the type of the element(s) that are returned when the expression tree associated with this instance of <see cref="T:System.Linq.IQueryable" /> is executed.
/// </summary>
/// <returns>A <see cref="T:System.Type" /> that represents the type of the element(s) that are returned when the expression tree associated with this object is executed.</returns>
public Type ElementType => typeof(T);

/// <summary>
/// Gets the expression tree that is associated with the instance of <see cref="T:System.Linq.IQueryable" />.
/// </summary>
/// <returns>The <see cref="T:System.Linq.Expressions.Expression" /> that is associated with this instance of <see cref="T:System.Linq.IQueryable" />.</returns>
public Expression Expression => _expression;

/// <summary>
/// Gets the query provider that is associated with this data source.
/// </summary>
/// <returns>The <see cref="T:System.Linq.IQueryProvider" /> that is associated with this data source.</returns>
public IQueryProvider Provider => _provider;
}
}
Loading