a fork of iceshrimp.net but a tweaked frontend to my personal liking. waow
fediverse social-media social iceshrimp fedi
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

[sln] Switch to extension blocks

+1656 -1560
+10 -7
Iceshrimp.Backend/Controllers/Shared/Attributes/LinkPaginationAttribute.cs
··· 89 89 { 90 90 private const string Key = "link-pagination"; 91 91 92 - internal static void SetPaginationData(this HttpContext ctx, IEnumerable<IIdentifiable> entities) 92 + extension(HttpContext ctx) 93 93 { 94 - ctx.Items.Add(Key, entities); 95 - } 94 + internal void SetPaginationData(IEnumerable<IIdentifiable> entities) 95 + { 96 + ctx.Items.Add(Key, entities); 97 + } 96 98 97 - public static IEnumerable<IIdentifiable>? GetPaginationData(this HttpContext ctx) 98 - { 99 - ctx.Items.TryGetValue(Key, out var entities); 100 - return entities as IEnumerable<IIdentifiable>; 99 + public IEnumerable<IIdentifiable>? GetPaginationData() 100 + { 101 + ctx.Items.TryGetValue(Key, out var entities); 102 + return entities as IEnumerable<IIdentifiable>; 103 + } 101 104 } 102 105 } 103 106
+41 -38
Iceshrimp.Backend/Core/Extensions/ArrayDestructuringExtensions.cs
··· 2 2 3 3 public static class ArrayDestructuringExtensions 4 4 { 5 - public static void Deconstruct<T>(this T[] array, out T item1) 5 + extension<T>(T[] array) 6 6 { 7 - if (array.Length != 1) 8 - throw new Exception("This deconstructor only takes arrays of length 1"); 7 + public void Deconstruct(out T item1) 8 + { 9 + if (array.Length != 1) 10 + throw new Exception("This deconstructor only takes arrays of length 1"); 9 11 10 - item1 = array[0]; 11 - } 12 + item1 = array[0]; 13 + } 12 14 13 - public static void Deconstruct<T>(this T[] array, out T item1, out T item2) 14 - { 15 - if (array.Length != 2) 16 - throw new Exception("This deconstructor only takes arrays of length 2"); 15 + public void Deconstruct(out T item1, out T item2) 16 + { 17 + if (array.Length != 2) 18 + throw new Exception("This deconstructor only takes arrays of length 2"); 17 19 18 - item1 = array[0]; 19 - item2 = array[1]; 20 - } 20 + item1 = array[0]; 21 + item2 = array[1]; 22 + } 21 23 22 - public static void Deconstruct<T>(this T[] array, out T item1, out T item2, out T item3) 23 - { 24 - if (array.Length != 3) 25 - throw new Exception("This deconstructor only takes arrays of length 3"); 24 + public void Deconstruct(out T item1, out T item2, out T item3) 25 + { 26 + if (array.Length != 3) 27 + throw new Exception("This deconstructor only takes arrays of length 3"); 26 28 27 - item1 = array[0]; 28 - item2 = array[1]; 29 - item3 = array[2]; 30 - } 29 + item1 = array[0]; 30 + item2 = array[1]; 31 + item3 = array[2]; 32 + } 31 33 32 - public static void Deconstruct<T>(this T[] array, out T item1, out T item2, out T item3, out T item4) 33 - { 34 - if (array.Length != 4) 35 - throw new Exception("This deconstructor only takes arrays of length 4"); 34 + public void Deconstruct(out T item1, out T item2, out T item3, out T item4) 35 + { 36 + if (array.Length != 4) 37 + throw new Exception("This deconstructor only takes arrays of length 4"); 36 38 37 - item1 = array[0]; 38 - item2 = array[1]; 39 - item3 = array[2]; 40 - item4 = array[3]; 41 - } 39 + item1 = array[0]; 40 + item2 = array[1]; 41 + item3 = array[2]; 42 + item4 = array[3]; 43 + } 42 44 43 - public static void Deconstruct<T>(this T[] array, out T item1, out T item2, out T item3, out T item4, out T item5) 44 - { 45 - if (array.Length != 5) 46 - throw new Exception("This deconstructor only takes arrays of length 5"); 45 + public void Deconstruct(out T item1, out T item2, out T item3, out T item4, out T item5) 46 + { 47 + if (array.Length != 5) 48 + throw new Exception("This deconstructor only takes arrays of length 5"); 47 49 48 - item1 = array[0]; 49 - item2 = array[1]; 50 - item3 = array[2]; 51 - item4 = array[3]; 52 - item5 = array[4]; 50 + item1 = array[0]; 51 + item2 = array[1]; 52 + item3 = array[2]; 53 + item4 = array[3]; 54 + item5 = array[4]; 55 + } 53 56 } 54 57 }
+4 -12
Iceshrimp.Backend/Core/Extensions/DateTimeExtensions.cs
··· 2 2 3 3 public static class DateTimeExtensions 4 4 { 5 - public static string ToStringIso8601Like(this DateTime dateTime) 6 - { 7 - return dateTime.ToString("yyyy'-'MM'-'dd'T'HH':'mm':'ss'.'fffK"); 8 - } 9 - 10 - public static string ToDisplayString(this DateTime dateTime) 11 - { 12 - return dateTime.ToString("yyyy'-'MM'-'dd' 'HH':'mm"); 13 - } 14 - 15 - public static string ToDisplayStringTz(this DateTime dateTime) 5 + extension(DateTime dateTime) 16 6 { 17 - return dateTime.ToString("yyyy'-'MM'-'dd' 'HH':'mm':'sszz"); 7 + public string ToStringIso8601Like() => dateTime.ToString("yyyy'-'MM'-'dd'T'HH':'mm':'ss'.'fffK"); 8 + public string ToDisplayString() => dateTime.ToString("yyyy'-'MM'-'dd' 'HH':'mm"); 9 + public string ToDisplayStringTz() => dateTime.ToString("yyyy'-'MM'-'dd' 'HH':'mm':'sszz"); 18 10 } 19 11 }
+18 -12
Iceshrimp.Backend/Core/Extensions/EnumerableExtensions.cs
··· 1 + using System.Diagnostics.CodeAnalysis; 2 + 1 3 namespace Iceshrimp.Backend.Core.Extensions; 2 4 3 5 public static class EnumerableExtensions ··· 37 39 foreach (var task in tasks) await task; 38 40 } 39 41 40 - public static bool IsDisjoint<T>(this IEnumerable<T> x, IEnumerable<T> y) 42 + extension<T>(IEnumerable<T> x) 41 43 { 42 - return x.All(item => !y.Contains(item)); 43 - } 44 + public bool IsDisjoint(IEnumerable<T> y) 45 + { 46 + return x.All(item => !y.Contains(item)); 47 + } 44 48 45 - public static bool Intersects<T>(this IEnumerable<T> x, IEnumerable<T> y) 46 - { 47 - return x.Any(y.Contains); 48 - } 49 + public bool Intersects(IEnumerable<T> y) 50 + { 51 + return x.Any(y.Contains); 52 + } 49 53 50 - public static bool IsEquivalent<T>(this IEnumerable<T> x, IEnumerable<T> y) 51 - { 52 - var xArray = x as T[] ?? x.ToArray(); 53 - var yArray = y as T[] ?? y.ToArray(); 54 - return xArray.Length == yArray.Length && xArray.All(yArray.Contains); 54 + public bool IsEquivalent(IEnumerable<T> y) 55 + { 56 + var xArray = x as T[] ?? x.ToArray(); 57 + var yArray = y as T[] ?? y.ToArray(); 58 + return xArray.Length == yArray.Length && xArray.All(yArray.Contains); 59 + } 55 60 } 56 61 62 + [SuppressMessage("ReSharper", "MoveToExtensionBlock", Justification = "Nullability does not match")] 57 63 public static IEnumerable<T> NotNull<T>(this IEnumerable<T?> @enum) => @enum.OfType<T>(); 58 64 59 65 public static IEnumerable<T> StructNotNull<T>(this IEnumerable<T?> @enum) where T : struct =>
+15 -12
Iceshrimp.Backend/Core/Extensions/HttpClientExtensions.cs
··· 4 4 { 5 5 private static readonly HttpRequestOptionsKey<bool?> AutoRedirectOptionsKey = new("RequestAutoRedirect"); 6 6 7 - public static HttpRequestMessage DisableAutoRedirects(this HttpRequestMessage request) 7 + extension(HttpRequestMessage request) 8 8 { 9 - request.SetAutoRedirect(false); 10 - return request; 11 - } 9 + public HttpRequestMessage DisableAutoRedirects() 10 + { 11 + request.SetAutoRedirect(false); 12 + return request; 13 + } 12 14 13 - private static void SetAutoRedirect(this HttpRequestMessage request, bool autoRedirect) 14 - { 15 - request.Options.Set(AutoRedirectOptionsKey, autoRedirect); 16 - } 15 + private void SetAutoRedirect(bool autoRedirect) 16 + { 17 + request.Options.Set(AutoRedirectOptionsKey, autoRedirect); 18 + } 17 19 18 - public static bool? GetAutoRedirect(this HttpRequestMessage request) 19 - { 20 - request.Options.TryGetValue(AutoRedirectOptionsKey, out var value); 21 - return value; 20 + public bool? GetAutoRedirect() 21 + { 22 + request.Options.TryGetValue(AutoRedirectOptionsKey, out var value); 23 + return value; 24 + } 22 25 } 23 26 24 27 public static HttpMessageHandler? GetMostInnerHandler(this HttpMessageHandler? self)
+27 -24
Iceshrimp.Backend/Core/Extensions/HttpContextExtensions.cs
··· 9 9 10 10 public static partial class HttpContextExtensions 11 11 { 12 - public static PaginationWrapper<TData> CreatePaginationWrapper<TData>( 13 - this HttpContext ctx, PaginationQuery query, IEnumerable<IIdentifiable> paginationData, TData data 14 - ) 12 + extension(HttpContext ctx) 15 13 { 16 - var attr = ctx.GetEndpoint()?.Metadata.GetMetadata<RestPaginationAttribute>(); 17 - if (attr == null) throw new Exception("Route doesn't have a RestPaginationAttribute"); 14 + public PaginationWrapper<TData> CreatePaginationWrapper<TData>( 15 + PaginationQuery query, IEnumerable<IIdentifiable> paginationData, TData data 16 + ) 17 + { 18 + var attr = ctx.GetEndpoint()?.Metadata.GetMetadata<RestPaginationAttribute>(); 19 + if (attr == null) throw new Exception("Route doesn't have a RestPaginationAttribute"); 18 20 19 - var limit = Math.Min(query.Limit ?? attr.DefaultLimit, attr.MaxLimit); 20 - if (limit < 1) throw GracefulException.BadRequest("Limit cannot be less than 1"); 21 + var limit = Math.Min(query.Limit ?? attr.DefaultLimit, attr.MaxLimit); 22 + if (limit < 1) throw GracefulException.BadRequest("Limit cannot be less than 1"); 21 23 22 - var ids = paginationData.Select(p => p.Id).ToList(); 23 - if (query.MinId != null) ids.Reverse(); 24 + var ids = paginationData.Select(p => p.Id).ToList(); 25 + if (query.MinId != null) ids.Reverse(); 24 26 25 - var next = ids.Count >= limit ? new QueryBuilder { { "max_id", ids.Last() } } : null; 26 - var prev = ids.Count > 0 ? new QueryBuilder { { "min_id", ids.First() } } : null; 27 + var next = ids.Count >= limit ? new QueryBuilder { { "max_id", ids.Last() } } : null; 28 + var prev = ids.Count > 0 ? new QueryBuilder { { "min_id", ids.First() } } : null; 27 29 28 - var links = new PaginationData 29 - { 30 - Limit = limit, 31 - Next = next?.ToQueryString().ToString(), 32 - Prev = prev?.ToQueryString().ToString() 33 - }; 30 + var links = new PaginationData 31 + { 32 + Limit = limit, 33 + Next = next?.ToQueryString().ToString(), 34 + Prev = prev?.ToQueryString().ToString() 35 + }; 34 36 35 - return new PaginationWrapper<TData> { Data = data, Links = links }; 36 - } 37 + return new PaginationWrapper<TData> { Data = data, Links = links }; 38 + } 37 39 38 - public static PaginationWrapper<TData> CreatePaginationWrapper<TData>( 39 - this HttpContext ctx, PaginationQuery query, TData data 40 - ) where TData : IEnumerable<IIdentifiable> 41 - { 42 - return CreatePaginationWrapper(ctx, query, data, data); 40 + public PaginationWrapper<TData> CreatePaginationWrapper<TData>( 41 + PaginationQuery query, TData data 42 + ) where TData : IEnumerable<IIdentifiable> 43 + { 44 + return CreatePaginationWrapper(ctx, query, data, data); 45 + } 43 46 } 44 47 } 45 48
+14 -11
Iceshrimp.Backend/Core/Extensions/HttpResponseExtensions.cs
··· 4 4 5 5 public static class HttpResponseExtensions 6 6 { 7 - public static bool IsClientError(this HttpResponseMessage res) 8 - => res.StatusCode is >= HttpStatusCode.BadRequest and <= (HttpStatusCode)499; 7 + extension(HttpResponseMessage res) 8 + { 9 + public bool IsClientError() 10 + => res.StatusCode is >= HttpStatusCode.BadRequest and <= (HttpStatusCode)499; 9 11 10 - public static bool IsRetryableClientError(this HttpResponseMessage res) 11 - => res.StatusCode is HttpStatusCode.TooManyRequests; 12 + public bool IsRetryableClientError() 13 + => res.StatusCode is HttpStatusCode.TooManyRequests; 12 14 13 - public static void EnsureSuccessStatusCode( 14 - this HttpResponseMessage res, bool excludeClientErrors, Func<Exception> exceptionFactory 15 - ) 16 - { 17 - if (excludeClientErrors && res.IsClientError() && !res.IsRetryableClientError()) 18 - throw exceptionFactory(); 19 - res.EnsureSuccessStatusCode(); 15 + public void EnsureSuccessStatusCode( 16 + bool excludeClientErrors, Func<Exception> exceptionFactory 17 + ) 18 + { 19 + if (excludeClientErrors && res.IsClientError() && !res.IsRetryableClientError()) 20 + throw exceptionFactory(); 21 + res.EnsureSuccessStatusCode(); 22 + } 20 23 } 21 24 }
+10 -7
Iceshrimp.Backend/Core/Extensions/IPAddressExtensions.cs
··· 5 5 6 6 public static class IPAddressExtensions 7 7 { 8 - public static bool IsLoopback(this IPAddress address) => IPAddress.IsLoopback(address); 8 + extension(IPAddress address) 9 + { 10 + public bool IsLoopback() => IPAddress.IsLoopback(address); 9 11 10 - public static bool IsLocalIPv6(this IPAddress address) => address.AddressFamily == AddressFamily.InterNetworkV6 && 11 - (address.IsIPv6LinkLocal || 12 - address.IsIPv6SiteLocal || 13 - address.IsIPv6UniqueLocal); 12 + public bool IsLocalIPv6() => address.AddressFamily == AddressFamily.InterNetworkV6 && 13 + (address.IsIPv6LinkLocal || 14 + address.IsIPv6SiteLocal || 15 + address.IsIPv6UniqueLocal); 14 16 15 - public static bool IsLocalIPv4(this IPAddress address) => address.AddressFamily == AddressFamily.InterNetwork && 16 - IsPrivateIPv4(address.GetAddressBytes()); 17 + public bool IsLocalIPv4() => address.AddressFamily == AddressFamily.InterNetwork && 18 + IsPrivateIPv4(address.GetAddressBytes()); 19 + } 17 20 18 21 private static bool IsPrivateIPv4(byte[] ipv4Bytes) 19 22 {
+43 -40
Iceshrimp.Backend/Core/Extensions/ListDestructuringExtensions.cs
··· 2 2 3 3 public static class ListDestructuringExtensions 4 4 { 5 - public static void Deconstruct<T>(this IList<T> list, out T item1) 5 + extension<T>(IList<T> list) 6 6 { 7 - if (list.Count != 1) 8 - throw new Exception("This deconstructor only takes lists of length 1"); 7 + public void Deconstruct(out T item1) 8 + { 9 + if (list.Count != 1) 10 + throw new Exception("This deconstructor only takes lists of length 1"); 9 11 10 - item1 = list[0]; 11 - } 12 + item1 = list[0]; 13 + } 12 14 13 - public static void Deconstruct<T>(this IList<T> list, out T item1, out T item2) 14 - { 15 - if (list.Count != 2) 16 - throw new Exception("This deconstructor only takes lists of length 2"); 15 + public void Deconstruct(out T item1, out T item2) 16 + { 17 + if (list.Count != 2) 18 + throw new Exception("This deconstructor only takes lists of length 2"); 17 19 18 - item1 = list[0]; 19 - item2 = list[1]; 20 - } 20 + item1 = list[0]; 21 + item2 = list[1]; 22 + } 21 23 22 - public static void Deconstruct<T>(this IList<T> list, out T item1, out T item2, out T item3) 23 - { 24 - if (list.Count != 3) 25 - throw new Exception("This deconstructor only takes lists of length 3"); 24 + public void Deconstruct(out T item1, out T item2, out T item3) 25 + { 26 + if (list.Count != 3) 27 + throw new Exception("This deconstructor only takes lists of length 3"); 26 28 27 - item1 = list[0]; 28 - item2 = list[1]; 29 - item3 = list[2]; 30 - } 29 + item1 = list[0]; 30 + item2 = list[1]; 31 + item3 = list[2]; 32 + } 31 33 32 - public static void Deconstruct<T>(this IList<T> list, out T item1, out T item2, out T item3, out T item4) 33 - { 34 - if (list.Count != 4) 35 - throw new Exception("This deconstructor only takes lists of length 4"); 34 + public void Deconstruct(out T item1, out T item2, out T item3, out T item4) 35 + { 36 + if (list.Count != 4) 37 + throw new Exception("This deconstructor only takes lists of length 4"); 36 38 37 - item1 = list[0]; 38 - item2 = list[1]; 39 - item3 = list[2]; 40 - item4 = list[3]; 41 - } 39 + item1 = list[0]; 40 + item2 = list[1]; 41 + item3 = list[2]; 42 + item4 = list[3]; 43 + } 42 44 43 - public static void Deconstruct<T>( 44 - this IList<T> list, out T item1, out T item2, out T item3, out T item4, out T item5 45 - ) 46 - { 47 - if (list.Count != 5) 48 - throw new Exception("This deconstructor only takes lists of length 5"); 45 + public void Deconstruct( 46 + out T item1, out T item2, out T item3, out T item4, out T item5 47 + ) 48 + { 49 + if (list.Count != 5) 50 + throw new Exception("This deconstructor only takes lists of length 5"); 49 51 50 - item1 = list[0]; 51 - item2 = list[1]; 52 - item3 = list[2]; 53 - item4 = list[3]; 54 - item5 = list[4]; 52 + item1 = list[0]; 53 + item2 = list[1]; 54 + item3 = list[2]; 55 + item4 = list[3]; 56 + item5 = list[4]; 57 + } 55 58 } 56 59 }
+59 -56
Iceshrimp.Backend/Core/Extensions/MvcBuilderExtensions.cs
··· 26 26 }); 27 27 } 28 28 29 - public static IMvcBuilder AddMultiFormatter(this IMvcBuilder builder) 29 + extension(IMvcBuilder builder) 30 30 { 31 - builder.Services.AddOptions<MvcOptions>() 32 - .PostConfigure<IOptions<JsonOptions>, IOptions<MvcNewtonsoftJsonOptions>, ArrayPool<char>, 33 - ObjectPoolProvider, ILoggerFactory>((opts, jsonOpts, _, _, _, loggerFactory) => 34 - { 35 - // We need to re-add these one since .AddNewtonsoftJson() removes them 36 - if (!opts.InputFormatters.OfType<SystemTextJsonInputFormatter>().Any()) 31 + public IMvcBuilder AddMultiFormatter() 32 + { 33 + builder.Services.AddOptions<MvcOptions>() 34 + .PostConfigure<IOptions<JsonOptions>, IOptions<MvcNewtonsoftJsonOptions>, ArrayPool<char>, 35 + ObjectPoolProvider, ILoggerFactory>((opts, jsonOpts, _, _, _, loggerFactory) => 37 36 { 38 - var systemInputLogger = loggerFactory.CreateLogger<SystemTextJsonInputFormatter>(); 39 - // We need to set this, otherwise characters like '+' will be escaped in responses 40 - jsonOpts.Value.JsonSerializerOptions.Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping; 41 - opts.InputFormatters.Add(new SystemTextJsonInputFormatter(jsonOpts.Value, systemInputLogger)); 42 - } 37 + // We need to re-add these one since .AddNewtonsoftJson() removes them 38 + if (!opts.InputFormatters.OfType<SystemTextJsonInputFormatter>().Any()) 39 + { 40 + var systemInputLogger = loggerFactory.CreateLogger<SystemTextJsonInputFormatter>(); 41 + // We need to set this, otherwise characters like '+' will be escaped in responses 42 + jsonOpts.Value.JsonSerializerOptions.Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping; 43 + opts.InputFormatters.Add(new SystemTextJsonInputFormatter(jsonOpts.Value, systemInputLogger)); 44 + } 43 45 44 - if (!opts.OutputFormatters.OfType<SystemTextJsonOutputFormatter>().Any()) 45 - opts.OutputFormatters.Add(new SystemTextJsonOutputFormatter(jsonOpts.Value 46 - .JsonSerializerOptions)); 46 + if (!opts.OutputFormatters.OfType<SystemTextJsonOutputFormatter>().Any()) 47 + opts.OutputFormatters.Add(new SystemTextJsonOutputFormatter(jsonOpts.Value 48 + .JsonSerializerOptions)); 47 49 48 - opts.InputFormatters.Insert(0, new JsonInputMultiFormatter()); 49 - opts.OutputFormatters.Insert(0, new JsonOutputMultiFormatter()); 50 - opts.OutputFormatters.Add(new XmlSerializerOutputFormatter()); 51 - }); 50 + opts.InputFormatters.Insert(0, new JsonInputMultiFormatter()); 51 + opts.OutputFormatters.Insert(0, new JsonOutputMultiFormatter()); 52 + opts.OutputFormatters.Add(new XmlSerializerOutputFormatter()); 53 + }); 52 54 53 - return builder; 54 - } 55 + return builder; 56 + } 55 57 56 - public static IMvcBuilder ConfigureNewtonsoftJson(this IMvcBuilder builder) 57 - { 58 - JsonConvert.DefaultSettings = () => new JsonSerializerSettings 58 + public IMvcBuilder ConfigureNewtonsoftJson() 59 59 { 60 - DateTimeZoneHandling = DateTimeZoneHandling.Utc 61 - }; 60 + JsonConvert.DefaultSettings = () => new JsonSerializerSettings 61 + { 62 + DateTimeZoneHandling = DateTimeZoneHandling.Utc 63 + }; 62 64 63 - return builder; 64 - } 65 + return builder; 66 + } 65 67 66 - public static IMvcBuilder AddModelBindingProviders(this IMvcBuilder builder) 67 - { 68 - builder.Services.AddOptions<MvcOptions>() 69 - .PostConfigure(options => { options.ModelBinderProviders.AddHybridBindingProvider(); }); 68 + public IMvcBuilder AddModelBindingProviders() 69 + { 70 + builder.Services.AddOptions<MvcOptions>() 71 + .PostConfigure(options => { options.ModelBinderProviders.AddHybridBindingProvider(); }); 70 72 71 - return builder; 72 - } 73 + return builder; 74 + } 73 75 74 - public static IMvcBuilder AddValueProviderFactories(this IMvcBuilder builder) 75 - { 76 - builder.Services.AddOptions<MvcOptions>() 77 - .PostConfigure(options => 78 - { 79 - options.ValueProviderFactories.Add(new JQueryQueryStringValueProviderFactory()); 80 - }); 76 + public IMvcBuilder AddValueProviderFactories() 77 + { 78 + builder.Services.AddOptions<MvcOptions>() 79 + .PostConfigure(options => 80 + { 81 + options.ValueProviderFactories.Add(new JQueryQueryStringValueProviderFactory()); 82 + }); 81 83 82 - return builder; 83 - } 84 + return builder; 85 + } 84 86 85 - public static IMvcBuilder AddApiBehaviorOptions(this IMvcBuilder builder) 86 - { 87 - builder.ConfigureApiBehaviorOptions(o => 87 + public IMvcBuilder AddApiBehaviorOptions() 88 88 { 89 - o.InvalidModelStateResponseFactory = actionContext => 89 + builder.ConfigureApiBehaviorOptions(o => 90 90 { 91 - var details = new ValidationProblemDetails(actionContext.ModelState); 91 + o.InvalidModelStateResponseFactory = actionContext => 92 + { 93 + var details = new ValidationProblemDetails(actionContext.ModelState); 92 94 93 - var status = (HttpStatusCode?)details.Status ?? HttpStatusCode.BadRequest; 94 - var message = details.Title ?? "One or more validation errors occurred."; 95 - if (details.Detail != null) 96 - message += $" - {details.Detail}"; 95 + var status = (HttpStatusCode?)details.Status ?? HttpStatusCode.BadRequest; 96 + var message = details.Title ?? "One or more validation errors occurred."; 97 + if (details.Detail != null) 98 + message += $" - {details.Detail}"; 97 99 98 - throw new ValidationException(status, status.ToString(), message, details.Errors); 99 - }; 100 - }); 100 + throw new ValidationException(status, status.ToString(), message, details.Errors); 101 + }; 102 + }); 101 103 102 - return builder; 104 + return builder; 105 + } 103 106 } 104 107 } 105 108
+271 -266
Iceshrimp.Backend/Core/Extensions/QueryableExtensions.cs
··· 18 18 19 19 public static class QueryableExtensions 20 20 { 21 - public static IQueryable<T> Paginate<T>( 22 - this IQueryable<T> query, 23 - MastodonPaginationQuery pq, 24 - int defaultLimit, 25 - int maxLimit 26 - ) where T : IIdentifiable 21 + extension<T>(IQueryable<T> query) where T : IIdentifiable 27 22 { 28 - if (pq.Limit is < 1) 29 - throw GracefulException.BadRequest("Limit cannot be less than 1"); 23 + public IQueryable<T> Paginate( 24 + MastodonPaginationQuery pq, 25 + int defaultLimit, 26 + int maxLimit 27 + ) 28 + { 29 + if (pq.Limit is < 1) 30 + throw GracefulException.BadRequest("Limit cannot be less than 1"); 30 31 31 - if (pq is { SinceId: not null, MinId: not null }) 32 - throw GracefulException.BadRequest("Can't use sinceId and minId params simultaneously"); 32 + if (pq is { SinceId: not null, MinId: not null }) 33 + throw GracefulException.BadRequest("Can't use sinceId and minId params simultaneously"); 33 34 34 - // @formatter:off 35 - query = pq switch 35 + // @formatter:off 36 + query = pq switch 37 + { 38 + { SinceId: not null, MaxId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.SinceId) && p.Id.IsLessThan(pq.MaxId)) 39 + .OrderByDescending(p => p.Id), 40 + { MinId: not null, MaxId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId) && p.Id.IsLessThan(pq.MaxId)) 41 + .OrderBy(p => p.Id), 42 + { SinceId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.SinceId)) 43 + .OrderByDescending(p => p.Id), 44 + { MinId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId)) 45 + .OrderBy(p => p.Id), 46 + { MaxId: not null } => query.Where(p => p.Id.IsLessThan(pq.MaxId)) 47 + .OrderByDescending(p => p.Id), 48 + _ => query.OrderByDescending(p => p.Id) 49 + }; 50 + // @formatter:on 51 + 52 + return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 53 + } 54 + 55 + public IQueryable<T> Paginate( 56 + Expression<Func<T, string>> predicate, 57 + MastodonPaginationQuery pq, 58 + int defaultLimit, 59 + int maxLimit 60 + ) 36 61 { 37 - { SinceId: not null, MaxId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.SinceId) && p.Id.IsLessThan(pq.MaxId)) 38 - .OrderByDescending(p => p.Id), 39 - { MinId: not null, MaxId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId) && p.Id.IsLessThan(pq.MaxId)) 40 - .OrderBy(p => p.Id), 41 - { SinceId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.SinceId)) 42 - .OrderByDescending(p => p.Id), 43 - { MinId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId)) 44 - .OrderBy(p => p.Id), 45 - { MaxId: not null } => query.Where(p => p.Id.IsLessThan(pq.MaxId)) 46 - .OrderByDescending(p => p.Id), 47 - _ => query.OrderByDescending(p => p.Id) 48 - }; 49 - // @formatter:on 62 + if (pq.Limit is < 1) 63 + throw GracefulException.BadRequest("Limit cannot be less than 1"); 50 64 51 - return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 52 - } 65 + if (pq is { SinceId: not null, MinId: not null }) 66 + throw GracefulException.BadRequest("Can't use sinceId and minId params simultaneously"); 53 67 54 - public static IQueryable<T> Paginate<T>( 55 - this IQueryable<T> query, 56 - Expression<Func<T, string>> predicate, 57 - MastodonPaginationQuery pq, 58 - int defaultLimit, 59 - int maxLimit 60 - ) where T : IIdentifiable 61 - { 62 - if (pq.Limit is < 1) 63 - throw GracefulException.BadRequest("Limit cannot be less than 1"); 68 + // @formatter:off 69 + query = pq switch 70 + { 71 + { SinceId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.SinceId) && id.IsLessThan(pq.MaxId))) 72 + .OrderByDescending(predicate), 73 + { MinId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.MinId) && id.IsLessThan(pq.MaxId))) 74 + .OrderBy(predicate), 75 + { SinceId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.SinceId))) 76 + .OrderByDescending(predicate), 77 + { MinId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.MinId))) 78 + .OrderBy(predicate), 79 + { MaxId: not null } => query.Where(predicate.Compose(id => id.IsLessThan(pq.MaxId))) 80 + .OrderByDescending(predicate), 81 + _ => query.OrderByDescending(predicate) 82 + }; 83 + // @formatter:on 64 84 65 - if (pq is { SinceId: not null, MinId: not null }) 66 - throw GracefulException.BadRequest("Can't use sinceId and minId params simultaneously"); 85 + return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 86 + } 67 87 68 - // @formatter:off 69 - query = pq switch 88 + public IQueryable<T> Paginate( 89 + Expression<Func<T, long>> predicate, 90 + MastodonPaginationQuery pq, 91 + int defaultLimit, 92 + int maxLimit 93 + ) 70 94 { 71 - { SinceId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.SinceId) && id.IsLessThan(pq.MaxId))) 72 - .OrderByDescending(predicate), 73 - { MinId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.MinId) && id.IsLessThan(pq.MaxId))) 74 - .OrderBy(predicate), 75 - { SinceId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.SinceId))) 76 - .OrderByDescending(predicate), 77 - { MinId: not null } => query.Where(predicate.Compose(id => id.IsGreaterThan(pq.MinId))) 78 - .OrderBy(predicate), 79 - { MaxId: not null } => query.Where(predicate.Compose(id => id.IsLessThan(pq.MaxId))) 80 - .OrderByDescending(predicate), 81 - _ => query.OrderByDescending(predicate) 82 - }; 83 - // @formatter:on 95 + if (pq.Limit is < 1) 96 + throw GracefulException.BadRequest("Limit cannot be less than 1"); 84 97 85 - return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 86 - } 98 + if (pq is { SinceId: not null, MinId: not null }) 99 + throw GracefulException.BadRequest("Can't use sinceId and minId params simultaneously"); 87 100 88 - public static IQueryable<T> Paginate<T>( 89 - this IQueryable<T> query, 90 - Expression<Func<T, long>> predicate, 91 - MastodonPaginationQuery pq, 92 - int defaultLimit, 93 - int maxLimit 94 - ) where T : IIdentifiable 95 - { 96 - if (pq.Limit is < 1) 97 - throw GracefulException.BadRequest("Limit cannot be less than 1"); 101 + long? sinceId = null; 102 + long? minId = null; 103 + long? maxId = null; 98 104 99 - if (pq is { SinceId: not null, MinId: not null }) 100 - throw GracefulException.BadRequest("Can't use sinceId and minId params simultaneously"); 105 + if (pq.SinceId != null) 106 + { 107 + if (!long.TryParse(pq.SinceId, out var res)) 108 + throw GracefulException.BadRequest("sinceId must be an integer"); 109 + sinceId = res; 110 + } 101 111 102 - long? sinceId = null; 103 - long? minId = null; 104 - long? maxId = null; 112 + if (pq.MinId != null) 113 + { 114 + if (!long.TryParse(pq.MinId, out var res)) 115 + throw GracefulException.BadRequest("minId must be an integer"); 116 + minId = res; 117 + } 105 118 106 - if (pq.SinceId != null) 107 - { 108 - if (!long.TryParse(pq.SinceId, out var res)) 109 - throw GracefulException.BadRequest("sinceId must be an integer"); 110 - sinceId = res; 111 - } 119 + if (pq.MaxId != null) 120 + { 121 + if (!long.TryParse(pq.MaxId, out var res)) 122 + throw GracefulException.BadRequest("maxId must be an integer"); 123 + maxId = res; 124 + } 112 125 113 - if (pq.MinId != null) 114 - { 115 - if (!long.TryParse(pq.MinId, out var res)) 116 - throw GracefulException.BadRequest("minId must be an integer"); 117 - minId = res; 118 - } 126 + // @formatter:off 127 + query = pq switch 128 + { 129 + { SinceId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id > sinceId && id < maxId)) 130 + .OrderByDescending(predicate), 131 + { MinId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id > minId && id < maxId)) 132 + .OrderBy(predicate), 133 + { SinceId: not null } => query.Where(predicate.Compose(id => id > sinceId)) 134 + .OrderByDescending(predicate), 135 + { MinId: not null } => query.Where(predicate.Compose(id => id > minId)) 136 + .OrderBy(predicate), 137 + { MaxId: not null } => query.Where(predicate.Compose(id => id < maxId)) 138 + .OrderByDescending(predicate), 139 + _ => query.OrderByDescending(predicate) 140 + }; 141 + // @formatter:on 119 142 120 - if (pq.MaxId != null) 121 - { 122 - if (!long.TryParse(pq.MaxId, out var res)) 123 - throw GracefulException.BadRequest("maxId must be an integer"); 124 - maxId = res; 143 + return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 125 144 } 126 145 127 - // @formatter:off 128 - query = pq switch 146 + public IQueryable<T> Paginate( 147 + PaginationQuery pq, 148 + int defaultLimit, 149 + int maxLimit 150 + ) 129 151 { 130 - { SinceId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id > sinceId && id < maxId)) 131 - .OrderByDescending(predicate), 132 - { MinId: not null, MaxId: not null } => query.Where(predicate.Compose(id => id > minId && id < maxId)) 133 - .OrderBy(predicate), 134 - { SinceId: not null } => query.Where(predicate.Compose(id => id > sinceId)) 135 - .OrderByDescending(predicate), 136 - { MinId: not null } => query.Where(predicate.Compose(id => id > minId)) 137 - .OrderBy(predicate), 138 - { MaxId: not null } => query.Where(predicate.Compose(id => id < maxId)) 139 - .OrderByDescending(predicate), 140 - _ => query.OrderByDescending(predicate) 141 - }; 142 - // @formatter:on 152 + if (pq.Limit is < 1) 153 + throw GracefulException.BadRequest("Limit cannot be less than 1"); 143 154 144 - return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 145 - } 155 + // @formatter:off 156 + query = pq switch 157 + { 158 + { MinId: not null, MaxId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId) && p.Id.IsLessThan(pq.MaxId)) 159 + .OrderBy(p => p.Id), 160 + { MinId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId)) 161 + .OrderBy(p => p.Id), 162 + { MaxId: not null } => query.Where(p => p.Id.IsLessThan(pq.MaxId)) 163 + .OrderByDescending(p => p.Id), 164 + _ => query.OrderByDescending(p => p.Id) 165 + }; 166 + // @formatter:on 146 167 147 - public static IQueryable<T> Paginate<T>( 148 - this IQueryable<T> query, 149 - PaginationQuery pq, 150 - int defaultLimit, 151 - int maxLimit 152 - ) where T : IIdentifiable 153 - { 154 - if (pq.Limit is < 1) 155 - throw GracefulException.BadRequest("Limit cannot be less than 1"); 168 + return query.Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 169 + } 156 170 157 - // @formatter:off 158 - query = pq switch 171 + public IQueryable<T> Paginate( 172 + MastodonPaginationQuery pq, 173 + ControllerContext context 174 + ) 159 175 { 160 - { MinId: not null, MaxId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId) && p.Id.IsLessThan(pq.MaxId)) 161 - .OrderBy(p => p.Id), 162 - { MinId: not null } => query.Where(p => p.Id.IsGreaterThan(pq.MinId)) 163 - .OrderBy(p => p.Id), 164 - { MaxId: not null } => query.Where(p => p.Id.IsLessThan(pq.MaxId)) 165 - .OrderByDescending(p => p.Id), 166 - _ => query.OrderByDescending(p => p.Id) 167 - }; 168 - // @formatter:on 176 + var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 177 + if (attr == null) 178 + throw new Exception("Route doesn't have a IPaginationAttribute"); 169 179 170 - return query.Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 171 - } 180 + return Paginate(query, pq, attr.DefaultLimit, attr.MaxLimit); 181 + } 172 182 173 - public static IQueryable<T> Paginate<T>( 174 - this IQueryable<T> query, 175 - MastodonPaginationQuery pq, 176 - ControllerContext context 177 - ) where T : IIdentifiable 178 - { 179 - var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 180 - if (attr == null) 181 - throw new Exception("Route doesn't have a IPaginationAttribute"); 183 + public IQueryable<T> PaginateByOffset( 184 + MastodonPaginationQuery pq, 185 + int defaultLimit, 186 + int maxLimit 187 + ) 188 + { 189 + if (pq.Limit is < 1) 190 + throw GracefulException.BadRequest("Limit cannot be less than 1"); 182 191 183 - return Paginate(query, pq, attr.DefaultLimit, attr.MaxLimit); 184 - } 192 + return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 193 + } 185 194 186 - public static IQueryable<T> PaginateByOffset<T>( 187 - this IQueryable<T> query, 188 - MastodonPaginationQuery pq, 189 - int defaultLimit, 190 - int maxLimit 191 - ) where T : IIdentifiable 192 - { 193 - if (pq.Limit is < 1) 194 - throw GracefulException.BadRequest("Limit cannot be less than 1"); 195 + public IQueryable<T> PaginateByOffset( 196 + MastodonPaginationQuery pq, 197 + ControllerContext context 198 + ) 199 + { 200 + var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 201 + if (attr == null) 202 + throw new Exception("Route doesn't have a IPaginationAttribute"); 195 203 196 - return query.Skip(pq.Offset ?? 0).Take(Math.Min(pq.Limit ?? defaultLimit, maxLimit)); 197 - } 204 + return PaginateByOffset(query, pq, attr.DefaultLimit, attr.MaxLimit); 205 + } 198 206 199 - public static IQueryable<T> PaginateByOffset<T>( 200 - this IQueryable<T> query, 201 - MastodonPaginationQuery pq, 202 - ControllerContext context 203 - ) where T : IIdentifiable 204 - { 205 - var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 206 - if (attr == null) 207 - throw new Exception("Route doesn't have a IPaginationAttribute"); 207 + public IQueryable<T> Paginate( 208 + Expression<Func<T, string>> predicate, 209 + MastodonPaginationQuery pq, 210 + ControllerContext context 211 + ) 212 + { 213 + var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 214 + if (attr == null) 215 + throw new Exception("Route doesn't have a IPaginationAttribute"); 208 216 209 - return PaginateByOffset(query, pq, attr.DefaultLimit, attr.MaxLimit); 210 - } 217 + return Paginate(query, predicate, pq, attr.DefaultLimit, attr.MaxLimit); 218 + } 211 219 212 - public static IQueryable<T> Paginate<T>( 213 - this IQueryable<T> query, 214 - Expression<Func<T, string>> predicate, 215 - MastodonPaginationQuery pq, 216 - ControllerContext context 217 - ) where T : IIdentifiable 218 - { 219 - var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 220 - if (attr == null) 221 - throw new Exception("Route doesn't have a IPaginationAttribute"); 220 + public IQueryable<T> Paginate( 221 + Expression<Func<T, long>> predicate, 222 + MastodonPaginationQuery pq, 223 + ControllerContext context 224 + ) 225 + { 226 + var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 227 + if (attr == null) 228 + throw new Exception("Route doesn't have a IPaginationAttribute"); 222 229 223 - return Paginate(query, predicate, pq, attr.DefaultLimit, attr.MaxLimit); 224 - } 230 + return Paginate(query, predicate, pq, attr.DefaultLimit, attr.MaxLimit); 231 + } 225 232 226 - public static IQueryable<T> Paginate<T>( 227 - this IQueryable<T> query, 228 - Expression<Func<T, long>> predicate, 229 - MastodonPaginationQuery pq, 230 - ControllerContext context 231 - ) where T : IIdentifiable 232 - { 233 - var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 234 - if (attr == null) 235 - throw new Exception("Route doesn't have a IPaginationAttribute"); 233 + public IQueryable<T> Paginate( 234 + PaginationQuery pq, 235 + ControllerContext context 236 + ) 237 + { 238 + var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 239 + if (attr == null) 240 + throw new Exception("Route doesn't have a IPaginationAttribute"); 236 241 237 - return Paginate(query, predicate, pq, attr.DefaultLimit, attr.MaxLimit); 238 - } 242 + return Paginate(query, pq, attr.DefaultLimit, attr.MaxLimit); 243 + } 239 244 240 - public static IQueryable<T> Paginate<T>( 241 - this IQueryable<T> query, 242 - PaginationQuery pq, 243 - ControllerContext context 244 - ) where T : IIdentifiable 245 - { 246 - var attr = context.HttpContext.GetEndpoint()?.Metadata.GetMetadata<IPaginationAttribute>(); 247 - if (attr == null) 248 - throw new Exception("Route doesn't have a IPaginationAttribute"); 249 - 250 - return Paginate(query, pq, attr.DefaultLimit, attr.MaxLimit); 245 + public IQueryable<EntityWrapper<TResult>> Wrap<TResult>( 246 + Expression<Func<T, TResult>> predicate 247 + ) 248 + { 249 + return query.Select(p => new EntityWrapper<TResult> { Id = p.Id, Entity = predicate.Compile().Invoke(p) }); 250 + } 251 251 } 252 252 253 - public static IQueryable<EntityWrapper<TResult>> Wrap<TSource, TResult>( 254 - this IQueryable<TSource> query, Expression<Func<TSource, TResult>> predicate 255 - ) where TSource : IIdentifiable 253 + extension(IQueryable<Note> query) 256 254 { 257 - return query.Select(p => new EntityWrapper<TResult> { Id = p.Id, Entity = predicate.Compile().Invoke(p) }); 258 - } 259 - 260 - public static IQueryable<Note> HasVisibility(this IQueryable<Note> query, Note.NoteVisibility visibility) 261 - { 262 - return query.Where(note => note.Visibility == visibility); 263 - } 255 + public IQueryable<Note> HasVisibility(Note.NoteVisibility visibility) 256 + { 257 + return query.Where(note => note.Visibility == visibility); 258 + } 264 259 265 - public static IQueryable<Note> FilterByUser(this IQueryable<Note> query, User user) 266 - { 267 - return query.Where(note => note.User == user); 268 - } 260 + public IQueryable<Note> FilterByUser(User user) 261 + { 262 + return query.Where(note => note.User == user); 263 + } 269 264 270 - public static IQueryable<Note> FilterByUser(this IQueryable<Note> query, string? userId) 271 - { 272 - return userId != null ? query.Where(note => note.UserId == userId) : query; 273 - } 265 + public IQueryable<Note> FilterByUser(string? userId) 266 + { 267 + return userId != null ? query.Where(note => note.UserId == userId) : query; 268 + } 274 269 275 - public static IQueryable<Note> EnsureVisibleFor(this IQueryable<Note> query, User? user) 276 - { 277 - return user == null 278 - ? query.Where(note => note.VisibilityIsPublicOrHome && !note.LocalOnly) 279 - : query.Where(note => note.IsVisibleFor(user)); 270 + public IQueryable<Note> EnsureVisibleFor(User? user) 271 + { 272 + return user == null 273 + ? query.Where(note => note.VisibilityIsPublicOrHome && !note.LocalOnly) 274 + : query.Where(note => note.IsVisibleFor(user)); 275 + } 280 276 } 281 277 282 278 public static IQueryable<TSource> EnsureNoteVisibilityFor<TSource>( ··· 288 284 : predicate.Compose(p => p == null || p.IsVisibleFor(user))); 289 285 } 290 286 291 - public static IQueryable<Note> PrecomputeVisibilities(this IQueryable<Note> query, User? user) 287 + extension(IQueryable<Note> query) 292 288 { 293 - return query.Select(p => p.WithPrecomputedVisibilities(p.Reply != null && p.Reply.IsVisibleFor(user), 294 - p.Renote != null && p.Renote.IsVisibleFor(user), 295 - p.Renote != null && 296 - p.Renote.Renote != null && 297 - p.Renote.Renote.IsVisibleFor(user))); 298 - } 289 + public IQueryable<Note> PrecomputeVisibilities(User? user) 290 + { 291 + return query.Select(p => p.WithPrecomputedVisibilities(p.Reply != null && p.Reply.IsVisibleFor(user), 292 + p.Renote != null && p.Renote.IsVisibleFor(user), 293 + p.Renote != null && 294 + p.Renote.Renote != null && 295 + p.Renote.Renote.IsVisibleFor(user))); 296 + } 299 297 300 - public static IQueryable<Note> PrecomputeNoteContextVisibilities(this IQueryable<Note> query, User? user) 301 - { 302 - return query.Select(p => p.WithPrecomputedVisibilities(p.Reply != null && p.Reply.IsVisibleFor(user), 303 - p.Renote != null && p.Renote.IsVisibleFor(user), 304 - p.Renote != null && 305 - false)); 298 + public IQueryable<Note> PrecomputeNoteContextVisibilities(User? user) 299 + { 300 + return query.Select(p => p.WithPrecomputedVisibilities(p.Reply != null && p.Reply.IsVisibleFor(user), 301 + p.Renote != null && p.Renote.IsVisibleFor(user), 302 + p.Renote != null && 303 + false)); 304 + } 306 305 } 307 306 308 307 public static IQueryable<Notification> PrecomputeNoteVisibilities(this IQueryable<Notification> query, User user) ··· 338 337 return query.Where(p => !hidden.Contains(p.NotifierId) && (p.Note == null || !hidden.Contains(p.Note.Id))); 339 338 } 340 339 341 - public static IQueryable<Note> FilterHiddenConversations(this IQueryable<Note> query, User user, DatabaseContext db) 340 + extension(IQueryable<Note> query) 342 341 { 343 - //TODO: handle muted instances 342 + public IQueryable<Note> FilterHiddenConversations(User user, DatabaseContext db) 343 + { 344 + //TODO: handle muted instances 344 345 345 - var blocks = db.Blockings.Where(i => i.Blocker == user).Select(p => p.BlockeeId); 346 - var mutes = db.Mutings.Where(i => i.Muter == user).Select(p => p.MuteeId); 347 - var hidden = blocks.Concat(mutes); 346 + var blocks = db.Blockings.Where(i => i.Blocker == user).Select(p => p.BlockeeId); 347 + var mutes = db.Mutings.Where(i => i.Muter == user).Select(p => p.MuteeId); 348 + var hidden = blocks.Concat(mutes); 348 349 349 - return query.Where(p => p.VisibleUserIds.IsDisjoint(hidden)); 350 - } 350 + return query.Where(p => p.VisibleUserIds.IsDisjoint(hidden)); 351 + } 351 352 352 - public static IQueryable<Note> FilterMutedThreads(this IQueryable<Note> query, User user, DatabaseContext db) 353 - { 354 - return query.Where(p => p.User == user || 355 - !db.NoteThreadMutings.Any(m => m.User == user && m.ThreadId == p.ThreadId)); 353 + public IQueryable<Note> FilterMutedThreads(User user, DatabaseContext db) 354 + { 355 + return query.Where(p => p.User == user || 356 + !db.NoteThreadMutings.Any(m => m.User == user && m.ThreadId == p.ThreadId)); 357 + } 356 358 } 357 359 358 360 public static IQueryable<Notification> FilterMutedThreads( ··· 568 570 return query; 569 571 } 570 572 571 - public static IQueryable<Note> FilterByPublicTimelineRequest( 572 - this IQueryable<Note> query, TimelineSchemas.PublicTimelineRequest request, DatabaseContext db 573 - ) 573 + extension(IQueryable<Note> query) 574 574 { 575 - if (request.OnlyLocal) 576 - query = query.Where(p => p.UserHost == null); 577 - if (request.OnlyRemote) 578 - query = query.Where(p => p.UserHost != null); 579 - if (request.OnlyMedia) 580 - query = query.Where(p => p.FileIds.Count != 0); 581 - if (request.Bubble) 582 - query = query.Where(p => p.UserHost == null || db.BubbleInstances.Any(i => i.Host == p.UserHost)); 575 + public IQueryable<Note> FilterByPublicTimelineRequest( 576 + TimelineSchemas.PublicTimelineRequest request, DatabaseContext db 577 + ) 578 + { 579 + if (request.OnlyLocal) 580 + query = query.Where(p => p.UserHost == null); 581 + if (request.OnlyRemote) 582 + query = query.Where(p => p.UserHost != null); 583 + if (request.OnlyMedia) 584 + query = query.Where(p => p.FileIds.Count != 0); 585 + if (request.Bubble) 586 + query = query.Where(p => p.UserHost == null || db.BubbleInstances.Any(i => i.Host == p.UserHost)); 583 587 584 - return query; 585 - } 588 + return query; 589 + } 586 590 587 - public static IQueryable<Note> FilterByHashtagTimelineRequest( 588 - this IQueryable<Note> query, TimelineSchemas.HashtagTimelineRequest request, DatabaseContext db 589 - ) 590 - { 591 - if (request.Any.Count > 0) 592 - query = query.Where(p => request.Any.Any(t => p.Tags.Contains(t))); 593 - if (request.All.Count > 0) 594 - query = query.Where(p => request.All.All(t => p.Tags.Contains(t))); 595 - if (request.None.Count > 0) 596 - query = query.Where(p => request.None.All(t => !p.Tags.Contains(t))); 591 + public IQueryable<Note> FilterByHashtagTimelineRequest( 592 + TimelineSchemas.HashtagTimelineRequest request, DatabaseContext db 593 + ) 594 + { 595 + if (request.Any.Count > 0) 596 + query = query.Where(p => request.Any.Any(t => p.Tags.Contains(t))); 597 + if (request.All.Count > 0) 598 + query = query.Where(p => request.All.All(t => p.Tags.Contains(t))); 599 + if (request.None.Count > 0) 600 + query = query.Where(p => request.None.All(t => !p.Tags.Contains(t))); 597 601 598 - return query.FilterByPublicTimelineRequest(request, db); 602 + return query.FilterByPublicTimelineRequest(request, db); 603 + } 599 604 } 600 605 601 606 #pragma warning disable CS8602 // Dereference of a possibly null reference.
+155 -152
Iceshrimp.Backend/Core/Extensions/QueryableFtsExtensions.cs
··· 64 64 internal static string? LocalDomainCheck(string? host, Config.InstanceSection config) => 65 65 host == null || host == config.WebDomain || host == config.AccountDomain ? null : host; 66 66 67 - private static IQueryable<Note> ApplyAfterFilter(this IQueryable<Note> query, AfterFilter filter) 68 - => query.Where(p => p.CreatedAt >= filter.Value.ToDateTime(TimeOnly.MinValue).ToUniversalTime()); 67 + extension(IQueryable<Note> query) 68 + { 69 + private IQueryable<Note> ApplyAfterFilter(AfterFilter filter) 70 + => query.Where(p => p.CreatedAt >= filter.Value.ToDateTime(TimeOnly.MinValue).ToUniversalTime()); 69 71 70 - private static IQueryable<Note> ApplyBeforeFilter(this IQueryable<Note> query, BeforeFilter filter) 71 - => query.Where(p => p.CreatedAt < filter.Value.ToDateTime(TimeOnly.MinValue).ToUniversalTime()); 72 + private IQueryable<Note> ApplyBeforeFilter(BeforeFilter filter) 73 + => query.Where(p => p.CreatedAt < filter.Value.ToDateTime(TimeOnly.MinValue).ToUniversalTime()); 72 74 73 - private static IQueryable<Note> ApplyWordFilter( 74 - this IQueryable<Note> query, WordFilter filter, CaseFilterType caseSensitivity, MatchFilterType matchType 75 - ) => query.Where(p => p.FtsQueryPreEscaped(PreEscapeFtsQuery(filter.Value, matchType), filter.Negated, 76 - caseSensitivity, matchType)); 75 + private IQueryable<Note> ApplyWordFilter( 76 + WordFilter filter, CaseFilterType caseSensitivity, MatchFilterType matchType 77 + ) => query.Where(p => p.FtsQueryPreEscaped(PreEscapeFtsQuery(filter.Value, matchType), filter.Negated, 78 + caseSensitivity, matchType)); 77 79 78 - private static IQueryable<Note> ApplyCwFilter( 79 - this IQueryable<Note> query, CwFilter filter, CaseFilterType caseSensitivity, MatchFilterType matchType 80 - ) => query.Where(p => FtsQueryPreEscapedMatch(p.Cw, PreEscapeFtsQuery(filter.Value, matchType), filter.Negated, 81 - caseSensitivity, matchType)); 80 + private IQueryable<Note> ApplyCwFilter( 81 + CwFilter filter, CaseFilterType caseSensitivity, MatchFilterType matchType 82 + ) => query.Where(p => FtsQueryPreEscapedMatch(p.Cw, PreEscapeFtsQuery(filter.Value, matchType), filter.Negated, 83 + caseSensitivity, matchType)); 82 84 83 - private static IQueryable<Note> ApplyMultiWordFilter( 84 - this IQueryable<Note> query, MultiWordFilter filter, CaseFilterType caseSensitivity, MatchFilterType matchType 85 - ) => filter.Negated 86 - ? query.Where(p => !p.FtsQueryOneOf(filter.Values, caseSensitivity, matchType)) 87 - : query.Where(p => p.FtsQueryOneOf(filter.Values, caseSensitivity, matchType)); 85 + private IQueryable<Note> ApplyMultiWordFilter( 86 + MultiWordFilter filter, CaseFilterType caseSensitivity, MatchFilterType matchType 87 + ) => filter.Negated 88 + ? query.Where(p => !p.FtsQueryOneOf(filter.Values, caseSensitivity, matchType)) 89 + : query.Where(p => p.FtsQueryOneOf(filter.Values, caseSensitivity, matchType)); 88 90 89 - private static IQueryable<Note> ApplyFromFilters( 90 - this IQueryable<Note> query, List<FromFilter> filters, Config.InstanceSection config, DatabaseContext db 91 - ) 92 - { 93 - if (filters.Count == 0) return query; 94 - var expr = ExpressionExtensions.False<Note>(); 95 - expr = filters.Aggregate(expr, (current, filter) => current 96 - .Or(p => p.User.UserSubqueryMatches(filter.Value, filter.Negated, config, db))); 97 - return query.Where(expr); 98 - } 91 + private IQueryable<Note> ApplyFromFilters( 92 + List<FromFilter> filters, Config.InstanceSection config, DatabaseContext db 93 + ) 94 + { 95 + if (filters.Count == 0) return query; 96 + var expr = ExpressionExtensions.False<Note>(); 97 + expr = filters.Aggregate(expr, (current, filter) => current 98 + .Or(p => p.User.UserSubqueryMatches(filter.Value, filter.Negated, config, db))); 99 + return query.Where(expr); 100 + } 99 101 100 - private static IQueryable<Note> ApplyInstanceFilter( 101 - this IQueryable<Note> query, InstanceFilter filter, Config.InstanceSection config 102 - ) => query.Where(p => filter.Negated 103 - ? p.UserHost != LocalDomainCheck(filter.Value, config) 104 - : p.UserHost == LocalDomainCheck(filter.Value, config)); 102 + private IQueryable<Note> ApplyInstanceFilter( 103 + InstanceFilter filter, Config.InstanceSection config 104 + ) => query.Where(p => filter.Negated 105 + ? p.UserHost != LocalDomainCheck(filter.Value, config) 106 + : p.UserHost == LocalDomainCheck(filter.Value, config)); 105 107 106 - private static IQueryable<Note> ApplyMentionFilter( 107 - this IQueryable<Note> query, MentionFilter filter, Config.InstanceSection config, DatabaseContext db 108 - ) => query.Where(p => p.Mentions.UserSubqueryContains(filter.Value, filter.Negated, config, db)); 108 + private IQueryable<Note> ApplyMentionFilter( 109 + MentionFilter filter, Config.InstanceSection config, DatabaseContext db 110 + ) => query.Where(p => p.Mentions.UserSubqueryContains(filter.Value, filter.Negated, config, db)); 109 111 110 - private static IQueryable<Note> ApplyReplyFilter( 111 - this IQueryable<Note> query, ReplyFilter filter, Config.InstanceSection config, DatabaseContext db 112 - ) => query.Where(p => p.Reply != null 113 - && p.Reply.User.UserSubqueryMatches(filter.Value, filter.Negated, config, db)); 112 + private IQueryable<Note> ApplyReplyFilter( 113 + ReplyFilter filter, Config.InstanceSection config, DatabaseContext db 114 + ) => query.Where(p => p.Reply != null 115 + && p.Reply.User.UserSubqueryMatches(filter.Value, filter.Negated, config, db)); 114 116 115 - private static IQueryable<Note> ApplyInFilter( 116 - this IQueryable<Note> query, InFilter filter, User user, DatabaseContext db 117 - ) 118 - { 119 - return filter.Value switch 117 + private IQueryable<Note> ApplyInFilter( 118 + InFilter filter, User user, DatabaseContext db 119 + ) 120 120 { 121 - InFilterType.Likes => query.ApplyInLikesFilter(user, filter.Negated, db), 122 - InFilterType.Bookmarks => query.ApplyInBookmarksFilter(user, filter.Negated, db), 123 - InFilterType.Reactions => query.ApplyInReactionsFilter(user, filter.Negated, db), 124 - InFilterType.Interactions => query.ApplyInInteractionsFilter(user, filter.Negated, db), 125 - _ => throw new ArgumentOutOfRangeException(nameof(filter), filter.Value, null) 126 - }; 127 - } 121 + return filter.Value switch 122 + { 123 + InFilterType.Likes => query.ApplyInLikesFilter(user, filter.Negated, db), 124 + InFilterType.Bookmarks => query.ApplyInBookmarksFilter(user, filter.Negated, db), 125 + InFilterType.Reactions => query.ApplyInReactionsFilter(user, filter.Negated, db), 126 + InFilterType.Interactions => query.ApplyInInteractionsFilter(user, filter.Negated, db), 127 + _ => throw new ArgumentOutOfRangeException(nameof(filter), filter.Value, null) 128 + }; 129 + } 128 130 129 - [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 130 - private static IQueryable<Note> ApplyInBookmarksFilter( 131 - this IQueryable<Note> query, User user, bool negated, DatabaseContext db 132 - ) => query.Where(p => negated 133 - ? !db.Users.First(u => u == user).HasBookmarked(p) 134 - : db.Users.First(u => u == user).HasBookmarked(p)); 131 + [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 132 + private IQueryable<Note> ApplyInBookmarksFilter( 133 + User user, bool negated, DatabaseContext db 134 + ) => query.Where(p => negated 135 + ? !db.Users.First(u => u == user).HasBookmarked(p) 136 + : db.Users.First(u => u == user).HasBookmarked(p)); 135 137 136 - [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 137 - private static IQueryable<Note> ApplyInLikesFilter( 138 - this IQueryable<Note> query, User user, bool negated, DatabaseContext db 139 - ) => query.Where(p => negated 140 - ? !db.Users.First(u => u == user).HasLiked(p) 141 - : db.Users.First(u => u == user).HasLiked(p)); 138 + [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 139 + private IQueryable<Note> ApplyInLikesFilter( 140 + User user, bool negated, DatabaseContext db 141 + ) => query.Where(p => negated 142 + ? !db.Users.First(u => u == user).HasLiked(p) 143 + : db.Users.First(u => u == user).HasLiked(p)); 142 144 143 - [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 144 - private static IQueryable<Note> ApplyInReactionsFilter( 145 - this IQueryable<Note> query, User user, bool negated, DatabaseContext db 146 - ) => query.Where(p => negated 147 - ? !db.Users.First(u => u == user).HasReacted(p) 148 - : db.Users.First(u => u == user).HasReacted(p)); 145 + [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 146 + private IQueryable<Note> ApplyInReactionsFilter( 147 + User user, bool negated, DatabaseContext db 148 + ) => query.Where(p => negated 149 + ? !db.Users.First(u => u == user).HasReacted(p) 150 + : db.Users.First(u => u == user).HasReacted(p)); 149 151 150 - [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 151 - private static IQueryable<Note> ApplyInInteractionsFilter( 152 - this IQueryable<Note> query, User user, bool negated, DatabaseContext db 153 - ) => query.Where(p => negated 154 - ? !db.Users.First(u => u == user).HasInteractedWith(p) 155 - : db.Users.First(u => u == user).HasInteractedWith(p)); 152 + [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 153 + private IQueryable<Note> ApplyInInteractionsFilter( 154 + User user, bool negated, DatabaseContext db 155 + ) => query.Where(p => negated 156 + ? !db.Users.First(u => u == user).HasInteractedWith(p) 157 + : db.Users.First(u => u == user).HasInteractedWith(p)); 156 158 157 - private static IQueryable<Note> ApplyMiscFilter(this IQueryable<Note> query, MiscFilter filter, User user) 158 - { 159 - return filter.Value switch 159 + private IQueryable<Note> ApplyMiscFilter(MiscFilter filter, User user) 160 160 { 161 - MiscFilterType.Followers => query.ApplyFollowersFilter(user, filter.Negated), 162 - MiscFilterType.Following => query.ApplyFollowingFilter(user, filter.Negated), 163 - MiscFilterType.Renotes => query.ApplyBoostsFilter(filter.Negated), 164 - MiscFilterType.Replies => query.ApplyRepliesFilter(filter.Negated), 165 - _ => throw new ArgumentOutOfRangeException(nameof(filter)) 166 - }; 167 - } 161 + return filter.Value switch 162 + { 163 + MiscFilterType.Followers => query.ApplyFollowersFilter(user, filter.Negated), 164 + MiscFilterType.Following => query.ApplyFollowingFilter(user, filter.Negated), 165 + MiscFilterType.Renotes => query.ApplyBoostsFilter(filter.Negated), 166 + MiscFilterType.Replies => query.ApplyRepliesFilter(filter.Negated), 167 + _ => throw new ArgumentOutOfRangeException(nameof(filter)) 168 + }; 169 + } 168 170 169 - private static IQueryable<Note> ApplyVisibilityFilter(this IQueryable<Note> query, VisibilityFilter filter) 170 - { 171 - if (filter.Value is VisibilityFilterType.Local) 172 - return query.Where(p => p.LocalOnly == !filter.Negated); 173 - 174 - var visibility = filter.Value switch 171 + private IQueryable<Note> ApplyVisibilityFilter(VisibilityFilter filter) 175 172 { 176 - VisibilityFilterType.Public => Note.NoteVisibility.Public, 177 - VisibilityFilterType.Home => Note.NoteVisibility.Home, 178 - VisibilityFilterType.Followers => Note.NoteVisibility.Followers, 179 - VisibilityFilterType.Specified => Note.NoteVisibility.Specified, 180 - _ => throw new ArgumentOutOfRangeException() 181 - }; 173 + if (filter.Value is VisibilityFilterType.Local) 174 + return query.Where(p => p.LocalOnly == !filter.Negated); 182 175 183 - return filter.Negated 184 - ? query.Where(p => p.Visibility != visibility) 185 - : query.Where(p => p.Visibility == visibility); 186 - } 176 + var visibility = filter.Value switch 177 + { 178 + VisibilityFilterType.Public => Note.NoteVisibility.Public, 179 + VisibilityFilterType.Home => Note.NoteVisibility.Home, 180 + VisibilityFilterType.Followers => Note.NoteVisibility.Followers, 181 + VisibilityFilterType.Specified => Note.NoteVisibility.Specified, 182 + _ => throw new ArgumentOutOfRangeException() 183 + }; 187 184 188 - [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 189 - private static IQueryable<Note> ApplyFollowersFilter(this IQueryable<Note> query, User user, bool negated) 190 - => query.Where(p => negated ? !p.User.IsFollowing(user) : p.User.IsFollowing(user)); 185 + return filter.Negated 186 + ? query.Where(p => p.Visibility != visibility) 187 + : query.Where(p => p.Visibility == visibility); 188 + } 191 189 192 - [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 193 - private static IQueryable<Note> ApplyFollowingFilter(this IQueryable<Note> query, User user, bool negated) 194 - => query.Where(p => negated ? !p.User.IsFollowedBy(user) : p.User.IsFollowedBy(user)); 190 + [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 191 + private IQueryable<Note> ApplyFollowersFilter(User user, bool negated) 192 + => query.Where(p => negated ? !p.User.IsFollowing(user) : p.User.IsFollowing(user)); 195 193 196 - private static IQueryable<Note> ApplyRepliesFilter(this IQueryable<Note> query, bool negated) 197 - => query.Where(p => negated ? p.Reply == null : p.Reply != null); 194 + [SuppressMessage("ReSharper", "EntityFramework.UnsupportedServerSideFunctionCall", Justification = "Projectables")] 195 + private IQueryable<Note> ApplyFollowingFilter(User user, bool negated) 196 + => query.Where(p => negated ? !p.User.IsFollowedBy(user) : p.User.IsFollowedBy(user)); 198 197 199 - private static IQueryable<Note> ApplyBoostsFilter(this IQueryable<Note> query, bool negated) 200 - => query.Where(p => negated ? !p.IsPureRenote : p.IsPureRenote); 198 + private IQueryable<Note> ApplyRepliesFilter(bool negated) 199 + => query.Where(p => negated ? p.Reply == null : p.Reply != null); 201 200 202 - private static IQueryable<Note> ApplyAttachmentFilter(this IQueryable<Note> query, AttachmentFilter filter) 203 - => filter.Negated ? query.ApplyNegatedAttachmentFilter(filter) : query.ApplyRegularAttachmentFilter(filter); 201 + private IQueryable<Note> ApplyBoostsFilter(bool negated) 202 + => query.Where(p => negated ? !p.IsPureRenote : p.IsPureRenote); 204 203 205 - private static IQueryable<Note> ApplyRegularAttachmentFilter(this IQueryable<Note> query, AttachmentFilter filter) 206 - { 207 - if (filter.Value is AttachmentFilterType.Media) 208 - return query.Where(p => p.AttachedFileTypes.Count != 0); 209 - if (filter.Value is AttachmentFilterType.Poll) 210 - return query.Where(p => p.HasPoll); 204 + private IQueryable<Note> ApplyAttachmentFilter(AttachmentFilter filter) 205 + => filter.Negated ? query.ApplyNegatedAttachmentFilter(filter) : query.ApplyRegularAttachmentFilter(filter); 211 206 212 - if (filter.Value is AttachmentFilterType.Image or AttachmentFilterType.Video or AttachmentFilterType.Audio) 207 + private IQueryable<Note> ApplyRegularAttachmentFilter(AttachmentFilter filter) 213 208 { 214 - return query.Where(p => p.AttachedFileTypes.Count != 0 215 - && EF.Functions.ILike(p.RawAttachments, GetAttachmentILikeQuery(filter.Value))); 216 - } 209 + if (filter.Value is AttachmentFilterType.Media) 210 + return query.Where(p => p.AttachedFileTypes.Count != 0); 211 + if (filter.Value is AttachmentFilterType.Poll) 212 + return query.Where(p => p.HasPoll); 217 213 218 - if (filter.Value is AttachmentFilterType.File) 219 - { 220 - return query.Where(p => p.AttachedFileTypes.Count != 0 221 - && (!EF.Functions.ILike(p.RawAttachments, 222 - GetAttachmentILikeQuery(AttachmentFilterType.Image)) 223 - || !EF.Functions.ILike(p.RawAttachments, 224 - GetAttachmentILikeQuery(AttachmentFilterType.Video)) 225 - || !EF.Functions.ILike(p.RawAttachments, 226 - GetAttachmentILikeQuery(AttachmentFilterType.Audio)))); 227 - } 214 + if (filter.Value is AttachmentFilterType.Image or AttachmentFilterType.Video or AttachmentFilterType.Audio) 215 + { 216 + return query.Where(p => p.AttachedFileTypes.Count != 0 217 + && EF.Functions.ILike(p.RawAttachments, GetAttachmentILikeQuery(filter.Value))); 218 + } 228 219 229 - throw new ArgumentOutOfRangeException(nameof(filter), filter.Value, null); 230 - } 220 + if (filter.Value is AttachmentFilterType.File) 221 + { 222 + return query.Where(p => p.AttachedFileTypes.Count != 0 223 + && (!EF.Functions.ILike(p.RawAttachments, 224 + GetAttachmentILikeQuery(AttachmentFilterType.Image)) 225 + || !EF.Functions.ILike(p.RawAttachments, 226 + GetAttachmentILikeQuery(AttachmentFilterType.Video)) 227 + || !EF.Functions.ILike(p.RawAttachments, 228 + GetAttachmentILikeQuery(AttachmentFilterType.Audio)))); 229 + } 231 230 232 - private static IQueryable<Note> ApplyNegatedAttachmentFilter(this IQueryable<Note> query, AttachmentFilter filter) 233 - { 234 - if (filter.Value is AttachmentFilterType.Media) 235 - return query.Where(p => p.AttachedFileTypes.Count == 0); 236 - if (filter.Value is AttachmentFilterType.Poll) 237 - return query.Where(p => !p.HasPoll); 238 - if (filter.Value is AttachmentFilterType.Image or AttachmentFilterType.Video or AttachmentFilterType.Audio) 239 - return query.Where(p => !EF.Functions.ILike(p.RawAttachments, GetAttachmentILikeQuery(filter.Value))); 231 + throw new ArgumentOutOfRangeException(nameof(filter), filter.Value, null); 232 + } 240 233 241 - if (filter.Value is AttachmentFilterType.File) 234 + private IQueryable<Note> ApplyNegatedAttachmentFilter(AttachmentFilter filter) 242 235 { 243 - return query.Where(p => EF.Functions 244 - .ILike(p.RawAttachments, GetAttachmentILikeQuery(AttachmentFilterType.Image)) 245 - || EF.Functions 246 - .ILike(p.RawAttachments, GetAttachmentILikeQuery(AttachmentFilterType.Video)) 247 - || EF.Functions 248 - .ILike(p.RawAttachments, GetAttachmentILikeQuery(AttachmentFilterType.Audio))); 249 - } 236 + if (filter.Value is AttachmentFilterType.Media) 237 + return query.Where(p => p.AttachedFileTypes.Count == 0); 238 + if (filter.Value is AttachmentFilterType.Poll) 239 + return query.Where(p => !p.HasPoll); 240 + if (filter.Value is AttachmentFilterType.Image or AttachmentFilterType.Video or AttachmentFilterType.Audio) 241 + return query.Where(p => !EF.Functions.ILike(p.RawAttachments, GetAttachmentILikeQuery(filter.Value))); 250 242 251 - throw new ArgumentOutOfRangeException(nameof(filter), filter.Value, null); 243 + if (filter.Value is AttachmentFilterType.File) 244 + { 245 + return query.Where(p => EF.Functions 246 + .ILike(p.RawAttachments, GetAttachmentILikeQuery(AttachmentFilterType.Image)) 247 + || EF.Functions 248 + .ILike(p.RawAttachments, GetAttachmentILikeQuery(AttachmentFilterType.Video)) 249 + || EF.Functions 250 + .ILike(p.RawAttachments, GetAttachmentILikeQuery(AttachmentFilterType.Audio))); 251 + } 252 + 253 + throw new ArgumentOutOfRangeException(nameof(filter), filter.Value, null); 254 + } 252 255 } 253 256 254 257 [SuppressMessage("ReSharper", "MemberCanBePrivate.Global",
+33 -30
Iceshrimp.Backend/Core/Extensions/QueryableTimelineExtensions.cs
··· 15 15 16 16 private const string Prefix = "following-query-heuristic"; 17 17 18 - public static IQueryable<Note> FilterByFollowingAndOwn( 19 - this IQueryable<Note> query, User user, DatabaseContext db, int heuristic 20 - ) 18 + extension(IQueryable<Note> query) 21 19 { 22 - var q = heuristic < Cutoff 23 - ? query.Where(FollowingAndOwnLowFreqExpr(user, db)) 24 - : query.Where(note => note.User == user || note.User.IsFollowedBy(user)); 20 + public IQueryable<Note> FilterByFollowingAndOwn( 21 + User user, DatabaseContext db, int heuristic 22 + ) 23 + { 24 + var q = heuristic < Cutoff 25 + ? query.Where(FollowingAndOwnLowFreqExpr(user, db)) 26 + : query.Where(note => note.User == user || note.User.IsFollowedBy(user)); 25 27 26 - if (user.UserSettings?.HideRepliesNotFollowing == true) 27 - q = q.Where(note => note.User == user || note.MastoReplyUser == null || note.MastoReplyUser.IsFollowedBy(user)); 28 + if (user.UserSettings?.HideRepliesNotFollowing == true) 29 + q = q.Where(note => note.User == user || note.MastoReplyUser == null || note.MastoReplyUser.IsFollowedBy(user)); 28 30 29 - return q; 30 - } 31 + return q; 32 + } 31 33 32 - public static IQueryable<Note> FilterByPublicFollowingAndOwn( 33 - this IQueryable<Note> query, User user, DatabaseContext db, int heuristic 34 - ) 35 - { 36 - return heuristic < Cutoff 37 - ? query.Where(FollowingAndOwnLowFreqExpr(user, db).Or(p => p.Visibility == Note.NoteVisibility.Public)) 38 - : query.Where(note => note.Visibility == Note.NoteVisibility.Public 39 - || note.User == user 40 - || note.User.IsFollowedBy(user)); 41 - } 34 + public IQueryable<Note> FilterByPublicFollowingAndOwn( 35 + User user, DatabaseContext db, int heuristic 36 + ) 37 + { 38 + return heuristic < Cutoff 39 + ? query.Where(FollowingAndOwnLowFreqExpr(user, db).Or(p => p.Visibility == Note.NoteVisibility.Public)) 40 + : query.Where(note => note.Visibility == Note.NoteVisibility.Public 41 + || note.User == user 42 + || note.User.IsFollowedBy(user)); 43 + } 42 44 43 - public static IQueryable<Note> FilterByFollowingOwnAndLocal( 44 - this IQueryable<Note> query, User user, DatabaseContext db, int heuristic 45 - ) 46 - { 47 - return heuristic < Cutoff 48 - ? query.Where(FollowingAndOwnLowFreqExpr(user, db) 49 - .Or(p => p.UserHost == null && p.Visibility == Note.NoteVisibility.Public)) 50 - : query.Where(note => note.User == user 51 - || note.User.IsFollowedBy(user) 52 - || (note.UserHost == null && note.Visibility == Note.NoteVisibility.Public)); 45 + public IQueryable<Note> FilterByFollowingOwnAndLocal( 46 + User user, DatabaseContext db, int heuristic 47 + ) 48 + { 49 + return heuristic < Cutoff 50 + ? query.Where(FollowingAndOwnLowFreqExpr(user, db) 51 + .Or(p => p.UserHost == null && p.Visibility == Note.NoteVisibility.Public)) 52 + : query.Where(note => note.User == user 53 + || note.User.IsFollowedBy(user) 54 + || (note.UserHost == null && note.Visibility == Note.NoteVisibility.Public)); 55 + } 53 56 } 54 57 55 58 private static Expression<Func<Note,bool>> FollowingAndOwnLowFreqExpr(User user, DatabaseContext db)
+279 -273
Iceshrimp.Backend/Core/Extensions/ServiceExtensions.cs
··· 29 29 30 30 public static class ServiceExtensions 31 31 { 32 - public static void AddServices(this IServiceCollection services, IConfiguration configuration) 32 + extension(IServiceCollection services) 33 33 { 34 - var config = configuration.Get<Config>() ?? throw new Exception("Failed to read storage config section"); 34 + public void AddServices(IConfiguration configuration) 35 + { 36 + var config = configuration.Get<Config>() ?? throw new Exception("Failed to read storage config section"); 35 37 36 - var serviceTypes = PluginLoader 37 - .Assemblies.Prepend(Assembly.GetExecutingAssembly()) 38 - .SelectMany(AssemblyLoader.GetImplementationsOfInterface<IService>) 39 - .OrderBy(type => type.GetInterfaceProperty<IService, int?>(nameof(IService.Priority)) ?? 0) 40 - .ToArray(); 38 + var serviceTypes = PluginLoader 39 + .Assemblies.Prepend(Assembly.GetExecutingAssembly()) 40 + .SelectMany(AssemblyLoader.GetImplementationsOfInterface<IService>) 41 + .OrderBy(type => type.GetInterfaceProperty<IService, int?>(nameof(IService.Priority)) ?? 0) 42 + .ToArray(); 41 43 42 - foreach (var type in serviceTypes) 43 - { 44 - if (type.GetInterfaceProperty<IService, ServiceLifetime?>(nameof(IService.Lifetime)) is not { } lifetime) 45 - continue; 44 + foreach (var type in serviceTypes) 45 + { 46 + if (type.GetInterfaceProperty<IService, ServiceLifetime?>(nameof(IService.Lifetime)) is not { } lifetime) 47 + continue; 46 48 47 - if (type.GetInterface(nameof(IConditionalService)) != null) 48 - if (type.CallInterfaceMethod(nameof(IConditionalService.Predicate), config) is not true) 49 - continue; 49 + if (type.GetInterface(nameof(IConditionalService)) != null) 50 + if (type.CallInterfaceMethod(nameof(IConditionalService.Predicate), config) is not true) 51 + continue; 50 52 51 - var serviceType = type.GetInterfaceProperty<IService, Type>(nameof(IService.ServiceType)) ?? type; 52 - services.Add(new ServiceDescriptor(serviceType, type, lifetime)); 53 - } 53 + var serviceType = type.GetInterfaceProperty<IService, Type>(nameof(IService.ServiceType)) ?? type; 54 + services.Add(new ServiceDescriptor(serviceType, type, lifetime)); 55 + } 54 56 55 - var hostedServiceTypes = PluginLoader 56 - .Assemblies.Prepend(Assembly.GetExecutingAssembly()) 57 - .SelectMany(AssemblyLoader.GetImplementationsOfInterface<IHostedService>) 58 - .ToArray(); 57 + var hostedServiceTypes = PluginLoader 58 + .Assemblies.Prepend(Assembly.GetExecutingAssembly()) 59 + .SelectMany(AssemblyLoader.GetImplementationsOfInterface<IHostedService>) 60 + .ToArray(); 59 61 60 - foreach (var type in hostedServiceTypes) 61 - { 62 - if (type.GetInterface(nameof(IService)) == null) 63 - services.Add(new ServiceDescriptor(type, type, ServiceLifetime.Singleton)); 62 + foreach (var type in hostedServiceTypes) 63 + { 64 + if (type.GetInterface(nameof(IService)) == null) 65 + services.Add(new ServiceDescriptor(type, type, ServiceLifetime.Singleton)); 64 66 65 - services.Add(new ServiceDescriptor(typeof(IHostedService), provider => provider.GetRequiredService(type), 66 - ServiceLifetime.Singleton)); 67 + services.Add(new ServiceDescriptor(typeof(IHostedService), provider => provider.GetRequiredService(type), 68 + ServiceLifetime.Singleton)); 69 + } 67 70 } 68 - } 69 71 70 - public static void AddMiddleware(this IServiceCollection services) 71 - { 72 - var types = PluginLoader 73 - .Assemblies.Prepend(Assembly.GetExecutingAssembly()) 74 - .SelectMany(p => AssemblyLoader.GetImplementationsOfInterface(p, typeof(IMiddlewareService))); 75 - 76 - foreach (var type in types) 72 + public void AddMiddleware() 77 73 { 78 - if (type.GetProperty(nameof(IMiddlewareService.Lifetime))?.GetValue(null) is not ServiceLifetime lifetime) 79 - continue; 74 + var types = PluginLoader 75 + .Assemblies.Prepend(Assembly.GetExecutingAssembly()) 76 + .SelectMany(p => AssemblyLoader.GetImplementationsOfInterface(p, typeof(IMiddlewareService))); 80 77 81 - services.Add(new ServiceDescriptor(type, type, lifetime)); 78 + foreach (var type in types) 79 + { 80 + if (type.GetProperty(nameof(IMiddlewareService.Lifetime))?.GetValue(null) is not ServiceLifetime lifetime) 81 + continue; 82 + 83 + services.Add(new ServiceDescriptor(type, type, lifetime)); 84 + } 82 85 } 83 - } 84 86 85 - public static void ConfigureServices(this IServiceCollection services, IConfiguration configuration) 86 - { 87 - // @formatter:off 88 - services.ConfigureWithValidation<Config>(configuration) 89 - .ConfigureWithValidation<Config.InstanceSection>(configuration, "Instance") 90 - .ConfigureWithValidation<Config.SecuritySection>(configuration, "Security") 91 - .ConfigureWithValidation<Config.NetworkSection>(configuration, "Network") 92 - .ConfigureWithValidation<Config.PerformanceSection>(configuration, "Performance") 93 - .ConfigureWithValidation<Config.QueueConcurrencySection>(configuration, "Performance:QueueConcurrency") 94 - .ConfigureWithValidation<Config.BackfillSection>(configuration, "Backfill") 95 - .ConfigureWithValidation<Config.BackfillRepliesSection>(configuration, "Backfill:Replies") 96 - .ConfigureWithValidation<Config.BackfillUserSection>(configuration, "Backfill:User") 97 - .ConfigureWithValidation<Config.QueueSection>(configuration, "Queue") 98 - .ConfigureWithValidation<Config.JobRetentionSection>(configuration, "Queue:JobRetention") 99 - .ConfigureWithValidation<Config.DatabaseSection>(configuration, "Database") 100 - .ConfigureWithValidation<Config.StorageSection>(configuration, "Storage") 101 - .ConfigureWithValidation<Config.LocalStorageSection>(configuration, "Storage:Local") 102 - .ConfigureWithValidation<Config.ObjectStorageSection>(configuration, "Storage:ObjectStorage") 103 - .ConfigureWithValidation<Config.MediaProcessingSection>(configuration, "Storage:MediaProcessing") 104 - .ConfigureWithValidation<Config.ImagePipelineSection>(configuration, "Storage:MediaProcessing:ImagePipeline") 105 - .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Original:Local") 106 - .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Original:Remote") 107 - .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Thumbnail:Local") 108 - .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Thumbnail:Remote") 109 - .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Public:Local") 110 - .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Public:Remote") 111 - .ConfigureWithValidation<Config.OpenTelemetrySection>(configuration, "OpenTelemetry"); 112 - // @formatter:on 113 - 114 - services.Configure<JsonOptions>(options => 87 + public void ConfigureServices(IConfiguration configuration) 115 88 { 116 - options.SerializerOptions.PropertyNamingPolicy = JsonSerialization.Options.PropertyNamingPolicy; 117 - foreach (var converter in JsonSerialization.Options.Converters) 118 - options.SerializerOptions.Converters.Add(converter); 119 - }); 89 + // @formatter:off 90 + services.ConfigureWithValidation<Config>(configuration) 91 + .ConfigureWithValidation<Config.InstanceSection>(configuration, "Instance") 92 + .ConfigureWithValidation<Config.SecuritySection>(configuration, "Security") 93 + .ConfigureWithValidation<Config.NetworkSection>(configuration, "Network") 94 + .ConfigureWithValidation<Config.PerformanceSection>(configuration, "Performance") 95 + .ConfigureWithValidation<Config.QueueConcurrencySection>(configuration, "Performance:QueueConcurrency") 96 + .ConfigureWithValidation<Config.BackfillSection>(configuration, "Backfill") 97 + .ConfigureWithValidation<Config.BackfillRepliesSection>(configuration, "Backfill:Replies") 98 + .ConfigureWithValidation<Config.BackfillUserSection>(configuration, "Backfill:User") 99 + .ConfigureWithValidation<Config.QueueSection>(configuration, "Queue") 100 + .ConfigureWithValidation<Config.JobRetentionSection>(configuration, "Queue:JobRetention") 101 + .ConfigureWithValidation<Config.DatabaseSection>(configuration, "Database") 102 + .ConfigureWithValidation<Config.StorageSection>(configuration, "Storage") 103 + .ConfigureWithValidation<Config.LocalStorageSection>(configuration, "Storage:Local") 104 + .ConfigureWithValidation<Config.ObjectStorageSection>(configuration, "Storage:ObjectStorage") 105 + .ConfigureWithValidation<Config.MediaProcessingSection>(configuration, "Storage:MediaProcessing") 106 + .ConfigureWithValidation<Config.ImagePipelineSection>(configuration, "Storage:MediaProcessing:ImagePipeline") 107 + .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Original:Local") 108 + .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Original:Remote") 109 + .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Thumbnail:Local") 110 + .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Thumbnail:Remote") 111 + .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Public:Local") 112 + .ConfigureWithValidation<Config.ImageFormatConfiguration>(configuration, "Storage:MediaProcessing:ImagePipeline:Public:Remote") 113 + .ConfigureWithValidation<Config.OpenTelemetrySection>(configuration, "OpenTelemetry"); 114 + // @formatter:on 120 115 121 - services.Configure<Microsoft.AspNetCore.Mvc.JsonOptions>(options => 122 - { 123 - options.JsonSerializerOptions.PropertyNamingPolicy = JsonSerialization.Options.PropertyNamingPolicy; 124 - options.JsonSerializerOptions.MaxDepth = 256; 125 - foreach (var converter in JsonSerialization.Options.Converters) 126 - options.JsonSerializerOptions.Converters.Add(converter); 127 - }); 116 + services.Configure<JsonOptions>(options => 117 + { 118 + options.SerializerOptions.PropertyNamingPolicy = JsonSerialization.Options.PropertyNamingPolicy; 119 + foreach (var converter in JsonSerialization.Options.Converters) 120 + options.SerializerOptions.Converters.Add(converter); 121 + }); 128 122 129 - services.PostConfigure<RazorComponentsServiceOptions>(BlazorSsrHandoffMiddleware.DisableBlazorJsInitializers); 130 - } 123 + services.Configure<Microsoft.AspNetCore.Mvc.JsonOptions>(options => 124 + { 125 + options.JsonSerializerOptions.PropertyNamingPolicy = JsonSerialization.Options.PropertyNamingPolicy; 126 + options.JsonSerializerOptions.MaxDepth = 256; 127 + foreach (var converter in JsonSerialization.Options.Converters) 128 + options.JsonSerializerOptions.Converters.Add(converter); 129 + }); 131 130 132 - private static IServiceCollection ConfigureWithValidation<T>( 133 - this IServiceCollection services, IConfiguration config 134 - ) where T : class 135 - { 136 - services.AddOptionsWithValidateOnStart<T>() 137 - .Bind(config) 138 - .ValidateDataAnnotations(); 131 + services.PostConfigure<RazorComponentsServiceOptions>(BlazorSsrHandoffMiddleware.DisableBlazorJsInitializers); 132 + } 139 133 140 - return services; 141 - } 134 + private IServiceCollection ConfigureWithValidation<T>( 135 + IConfiguration config 136 + ) where T : class 137 + { 138 + services.AddOptionsWithValidateOnStart<T>() 139 + .Bind(config) 140 + .ValidateDataAnnotations(); 142 141 143 - private static IServiceCollection ConfigureWithValidation<T>( 144 - this IServiceCollection services, IConfiguration config, string name 145 - ) where T : class 146 - { 147 - services.AddOptionsWithValidateOnStart<T>() 148 - .Bind(config.GetSection(name)) 149 - .ValidateDataAnnotations(); 142 + return services; 143 + } 150 144 151 - return services; 152 - } 145 + private IServiceCollection ConfigureWithValidation<T>( 146 + IConfiguration config, string name 147 + ) where T : class 148 + { 149 + services.AddOptionsWithValidateOnStart<T>() 150 + .Bind(config.GetSection(name)) 151 + .ValidateDataAnnotations(); 153 152 154 - public static void AddDatabaseContext(this IServiceCollection services, IConfiguration configuration) 155 - { 156 - var config = configuration.GetSection("Database").Get<Config.DatabaseSection>() ?? 157 - throw new Exception("Failed to initialize database: Failed to load configuration"); 153 + return services; 154 + } 158 155 159 - var dataSource = DatabaseContext.GetDataSource(config); 160 - services.AddDbContext<DatabaseContext>(options => { DatabaseContext.Configure(options, dataSource, config); }); 161 - services.AddKeyedDatabaseContext<DatabaseContext>("cache"); 162 - services.AddDataProtection() 163 - .PersistKeysToDbContextAsync<DatabaseContext>() 164 - .UseCryptographicAlgorithms(new AuthenticatedEncryptorConfiguration 165 - { 166 - EncryptionAlgorithm = EncryptionAlgorithm.AES_256_CBC, 167 - ValidationAlgorithm = ValidationAlgorithm.HMACSHA256 168 - }); 169 - } 156 + public void AddDatabaseContext(IConfiguration configuration) 157 + { 158 + var config = configuration.GetSection("Database").Get<Config.DatabaseSection>() ?? 159 + throw new Exception("Failed to initialize database: Failed to load configuration"); 170 160 171 - private static void AddKeyedDatabaseContext<T>( 172 - this IServiceCollection services, string key, ServiceLifetime contextLifetime = ServiceLifetime.Scoped 173 - ) where T : DbContext 174 - { 175 - services.TryAdd(new ServiceDescriptor(typeof(T), key, typeof(T), contextLifetime)); 176 - } 161 + var dataSource = DatabaseContext.GetDataSource(config); 162 + services.AddDbContext<DatabaseContext>(options => { DatabaseContext.Configure(options, dataSource, config); }); 163 + services.AddKeyedDatabaseContext<DatabaseContext>("cache"); 164 + services.AddDataProtection() 165 + .PersistKeysToDbContextAsync<DatabaseContext>() 166 + .UseCryptographicAlgorithms(new AuthenticatedEncryptorConfiguration 167 + { 168 + EncryptionAlgorithm = EncryptionAlgorithm.AES_256_CBC, 169 + ValidationAlgorithm = ValidationAlgorithm.HMACSHA256 170 + }); 171 + } 177 172 178 - public static void AddOpenApiWithOptions(this IServiceCollection services) 179 - { 180 - services.AddEndpointsApiExplorer(); 181 - services.AddSwaggerGen(options => 173 + private void AddKeyedDatabaseContext<T>( 174 + string key, ServiceLifetime contextLifetime = ServiceLifetime.Scoped 175 + ) where T : DbContext 182 176 { 183 - options.SupportNonNullableReferenceTypes(); 177 + services.TryAdd(new ServiceDescriptor(typeof(T), key, typeof(T), contextLifetime)); 178 + } 184 179 185 - var version = new Config.InstanceSection().Version; 186 - options.SwaggerDoc("iceshrimp", new OpenApiInfo { Title = "Iceshrimp.NET", Version = version }); 187 - options.SwaggerDoc("federation", new OpenApiInfo { Title = "Federation", Version = version }); 188 - options.SwaggerDoc("mastodon", new OpenApiInfo { Title = "Mastodon", Version = version }); 180 + public void AddOpenApiWithOptions() 181 + { 182 + services.AddEndpointsApiExplorer(); 183 + services.AddSwaggerGen(options => 184 + { 185 + options.SupportNonNullableReferenceTypes(); 189 186 190 - options.AddSecurityDefinition("iceshrimp", 191 - new OpenApiSecurityScheme 192 - { 193 - Name = "Authorization token", 194 - In = ParameterLocation.Header, 195 - Type = SecuritySchemeType.Http, 196 - Scheme = "bearer" 197 - }); 198 - options.AddSecurityDefinition("mastodon", 199 - new OpenApiSecurityScheme 200 - { 201 - Name = "Authorization token", 202 - In = ParameterLocation.Header, 203 - Type = SecuritySchemeType.Http, 204 - Scheme = "bearer" 205 - }); 187 + var version = new Config.InstanceSection().Version; 188 + options.SwaggerDoc("iceshrimp", new OpenApiInfo { Title = "Iceshrimp.NET", Version = version }); 189 + options.SwaggerDoc("federation", new OpenApiInfo { Title = "Federation", Version = version }); 190 + options.SwaggerDoc("mastodon", new OpenApiInfo { Title = "Mastodon", Version = version }); 191 + 192 + options.AddSecurityDefinition("iceshrimp", 193 + new OpenApiSecurityScheme 194 + { 195 + Name = "Authorization token", 196 + In = ParameterLocation.Header, 197 + Type = SecuritySchemeType.Http, 198 + Scheme = "bearer" 199 + }); 200 + options.AddSecurityDefinition("mastodon", 201 + new OpenApiSecurityScheme 202 + { 203 + Name = "Authorization token", 204 + In = ParameterLocation.Header, 205 + Type = SecuritySchemeType.Http, 206 + Scheme = "bearer" 207 + }); 206 208 207 - options.AddFilters(); 208 - }); 209 - } 209 + options.AddFilters(); 210 + }); 211 + } 210 212 211 - public static void AddSlidingWindowRateLimiter(this IServiceCollection services) 212 - { 213 - //TODO: rate limit status headers - maybe switch to https://github.com/stefanprodan/AspNetCoreRateLimit? 214 - //TODO: alternatively just write our own 215 - services.AddRateLimiter(options => 213 + public void AddSlidingWindowRateLimiter() 216 214 { 217 - var sliding = new SlidingWindowRateLimiterOptions 215 + //TODO: rate limit status headers - maybe switch to https://github.com/stefanprodan/AspNetCoreRateLimit? 216 + //TODO: alternatively just write our own 217 + services.AddRateLimiter(options => 218 218 { 219 - PermitLimit = 500, 220 - SegmentsPerWindow = 60, 221 - Window = TimeSpan.FromSeconds(60), 222 - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 223 - QueueLimit = 0 224 - }; 219 + var sliding = new SlidingWindowRateLimiterOptions 220 + { 221 + PermitLimit = 500, 222 + SegmentsPerWindow = 60, 223 + Window = TimeSpan.FromSeconds(60), 224 + QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 225 + QueueLimit = 0 226 + }; 225 227 226 - var auth = new SlidingWindowRateLimiterOptions 227 - { 228 - PermitLimit = 10, 229 - SegmentsPerWindow = 60, 230 - Window = TimeSpan.FromSeconds(60), 231 - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 232 - QueueLimit = 0 233 - }; 228 + var auth = new SlidingWindowRateLimiterOptions 229 + { 230 + PermitLimit = 10, 231 + SegmentsPerWindow = 60, 232 + Window = TimeSpan.FromSeconds(60), 233 + QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 234 + QueueLimit = 0 235 + }; 234 236 235 - var strict = new SlidingWindowRateLimiterOptions 236 - { 237 - PermitLimit = 3, 238 - SegmentsPerWindow = 60, 239 - Window = TimeSpan.FromSeconds(60), 240 - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 241 - QueueLimit = 0 242 - }; 237 + var strict = new SlidingWindowRateLimiterOptions 238 + { 239 + PermitLimit = 3, 240 + SegmentsPerWindow = 60, 241 + Window = TimeSpan.FromSeconds(60), 242 + QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 243 + QueueLimit = 0 244 + }; 243 245 244 - var imports = new SlidingWindowRateLimiterOptions 245 - { 246 - PermitLimit = 2, 247 - SegmentsPerWindow = 30, 248 - Window = TimeSpan.FromMinutes(30), 249 - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 250 - QueueLimit = 0 251 - }; 246 + var imports = new SlidingWindowRateLimiterOptions 247 + { 248 + PermitLimit = 2, 249 + SegmentsPerWindow = 30, 250 + Window = TimeSpan.FromMinutes(30), 251 + QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 252 + QueueLimit = 0 253 + }; 252 254 253 - var proxy = new SlidingWindowRateLimiterOptions 254 - { 255 - PermitLimit = 10, 256 - SegmentsPerWindow = 10, 257 - Window = TimeSpan.FromSeconds(10), 258 - QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 259 - QueueLimit = 0 260 - }; 255 + var proxy = new SlidingWindowRateLimiterOptions 256 + { 257 + PermitLimit = 10, 258 + SegmentsPerWindow = 10, 259 + Window = TimeSpan.FromSeconds(10), 260 + QueueProcessingOrder = QueueProcessingOrder.OldestFirst, 261 + QueueLimit = 0 262 + }; 261 263 262 - // @formatter:off 263 - options.AddPolicy("sliding", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(false),_ => sliding)); 264 - options.AddPolicy("auth", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(false), _ => auth)); 265 - options.AddPolicy("strict", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(true), _ => strict)); 266 - options.AddPolicy("imports", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(true), _ => imports)); 267 - options.AddPolicy("proxy", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(true), _ => proxy)); 268 - // @formatter:on 264 + // @formatter:off 265 + options.AddPolicy("sliding", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(false),_ => sliding)); 266 + options.AddPolicy("auth", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(false), _ => auth)); 267 + options.AddPolicy("strict", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(true), _ => strict)); 268 + options.AddPolicy("imports", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(true), _ => imports)); 269 + options.AddPolicy("proxy", ctx => RateLimitPartition.GetSlidingWindowLimiter(ctx.GetRateLimitPartition(true), _ => proxy)); 270 + // @formatter:on 269 271 270 - options.OnRejected = async (context, token) => 271 - { 272 - context.HttpContext.Response.StatusCode = 429; 273 - context.HttpContext.Response.ContentType = "application/json"; 274 - var res = new ErrorResponse(new Exception()) 272 + options.OnRejected = async (context, token) => 275 273 { 276 - Error = "Too Many Requests", 277 - StatusCode = 429, 278 - RequestId = context.HttpContext.TraceIdentifier 274 + context.HttpContext.Response.StatusCode = 429; 275 + context.HttpContext.Response.ContentType = "application/json"; 276 + var res = new ErrorResponse(new Exception()) 277 + { 278 + Error = "Too Many Requests", 279 + StatusCode = 429, 280 + RequestId = context.HttpContext.TraceIdentifier 281 + }; 282 + await context.HttpContext.Response.WriteAsJsonAsync(res, token); 279 283 }; 280 - await context.HttpContext.Response.WriteAsJsonAsync(res, token); 281 - }; 282 - }); 283 - } 284 + }); 285 + } 284 286 285 - public static void AddCorsPolicies(this IServiceCollection services) 286 - { 287 - services.AddCors(options => 287 + public void AddCorsPolicies() 288 288 { 289 - options.AddPolicy("well-known", policy => 289 + services.AddCors(options => 290 290 { 291 - policy.WithOrigins("*") 292 - .WithMethods("GET") 293 - .WithHeaders("Accept") 294 - .WithExposedHeaders("Vary"); 295 - }); 296 - options.AddPolicy("drive", policy => 297 - { 298 - policy.WithOrigins("*") 299 - .WithMethods("GET", "HEAD"); 300 - }); 301 - options.AddPolicy("mastodon", policy => 302 - { 303 - policy.WithOrigins("*") 304 - .WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT") 305 - .WithHeaders("Authorization", "Content-Type", "Idempotency-Key") 306 - .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); 307 - }); 308 - options.AddPolicy("fallback", policy => 309 - { 310 - policy.WithOrigins("*") 311 - .WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT") 312 - .WithHeaders("Authorization", "Content-Type", "Idempotency-Key") 313 - .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); 291 + options.AddPolicy("well-known", policy => 292 + { 293 + policy.WithOrigins("*") 294 + .WithMethods("GET") 295 + .WithHeaders("Accept") 296 + .WithExposedHeaders("Vary"); 297 + }); 298 + options.AddPolicy("drive", policy => 299 + { 300 + policy.WithOrigins("*") 301 + .WithMethods("GET", "HEAD"); 302 + }); 303 + options.AddPolicy("mastodon", policy => 304 + { 305 + policy.WithOrigins("*") 306 + .WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT") 307 + .WithHeaders("Authorization", "Content-Type", "Idempotency-Key") 308 + .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); 309 + }); 310 + options.AddPolicy("fallback", policy => 311 + { 312 + policy.WithOrigins("*") 313 + .WithMethods("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "CONNECT") 314 + .WithHeaders("Authorization", "Content-Type", "Idempotency-Key") 315 + .WithExposedHeaders("Link", "Connection", "Sec-Websocket-Accept", "Upgrade"); 316 + }); 314 317 }); 315 - }); 316 - } 318 + } 317 319 318 - public static void AddAuthorizationPolicies(this IServiceCollection services) 319 - { 320 - services.AddAuthorizationBuilder() 321 - .AddPolicy("HubAuthorization", policy => 322 - { 323 - policy.Requirements.Add(new HubAuthorizationRequirement()); 324 - policy.AuthenticationSchemes = ["HubAuthenticationScheme"]; 325 - }); 320 + public void AddAuthorizationPolicies() 321 + { 322 + services.AddAuthorizationBuilder() 323 + .AddPolicy("HubAuthorization", policy => 324 + { 325 + policy.Requirements.Add(new HubAuthorizationRequirement()); 326 + policy.AuthenticationSchemes = ["HubAuthenticationScheme"]; 327 + }); 326 328 327 - services.AddAuthentication(options => 328 - { 329 - options.AddScheme<HubAuthenticationHandler>("HubAuthenticationScheme", null); 329 + services.AddAuthentication(options => 330 + { 331 + options.AddScheme<HubAuthenticationHandler>("HubAuthenticationScheme", null); 330 332 331 - // Add a stub authentication handler to bypass strange ASP.NET Core >=7.0 defaults 332 - // Ref: https://github.com/dotnet/aspnetcore/issues/44661 333 - options.AddScheme<IAuthenticationHandler>("StubAuthenticationHandler", null); 334 - }); 335 - } 333 + // Add a stub authentication handler to bypass strange ASP.NET Core >=7.0 defaults 334 + // Ref: https://github.com/dotnet/aspnetcore/issues/44661 335 + options.AddScheme<IAuthenticationHandler>("StubAuthenticationHandler", null); 336 + }); 337 + } 336 338 337 - public static void AddOutputCacheWithOptions(this IServiceCollection services) 338 - { 339 - services.AddOutputCache(options => 339 + public void AddOutputCacheWithOptions() 340 340 { 341 - options.AddPolicy("conditional", o => o.With(ctx => ctx.HttpContext.ShouldCacheOutput())); 342 - options.AddPolicy("federation", o => o.SetVaryByHeader("Accept").Expire(TimeSpan.FromSeconds(60))); 343 - options.DefaultExpirationTimeSpan = TimeSpan.FromDays(365); 344 - }); 341 + services.AddOutputCache(options => 342 + { 343 + options.AddPolicy("conditional", o => o.With(ctx => ctx.HttpContext.ShouldCacheOutput())); 344 + options.AddPolicy("federation", o => o.SetVaryByHeader("Accept").Expire(TimeSpan.FromSeconds(60))); 345 + options.DefaultExpirationTimeSpan = TimeSpan.FromDays(365); 346 + }); 347 + } 345 348 } 346 349 } 347 350 ··· 349 352 { 350 353 private const string CacheKey = "shouldCache"; 351 354 352 - public static string GetRateLimitPartition(this HttpContext ctx, bool includeRoute) => 353 - (includeRoute ? ctx.Request.Path.ToString() + "#" : "") + (GetRateLimitPartitionInternal(ctx) ?? ""); 355 + extension(HttpContext ctx) 356 + { 357 + public string GetRateLimitPartition(bool includeRoute) => 358 + (includeRoute ? ctx.Request.Path.ToString() + "#" : "") + (GetRateLimitPartitionInternal(ctx) ?? ""); 354 359 355 - private static string? GetRateLimitPartitionInternal(this HttpContext ctx) => 356 - ctx.GetUser()?.Id ?? 357 - ctx.Request.Headers["X-Forwarded-For"].FirstOrDefault() ?? 358 - ctx.Connection.RemoteIpAddress?.ToString(); 360 + private string? GetRateLimitPartitionInternal() => 361 + ctx.GetUser()?.Id ?? 362 + ctx.Request.Headers["X-Forwarded-For"].FirstOrDefault() ?? 363 + ctx.Connection.RemoteIpAddress?.ToString(); 359 364 360 - public static void CacheOutput(this HttpContext ctx) => ctx.Items[CacheKey] = true; 365 + public void CacheOutput() => ctx.Items[CacheKey] = true; 361 366 362 - public static bool ShouldCacheOutput(this HttpContext ctx) => 363 - ctx.Items.TryGetValue(CacheKey, out var s) && s is true; 367 + public bool ShouldCacheOutput() => 368 + ctx.Items.TryGetValue(CacheKey, out var s) && s is true; 369 + } 364 370 } 365 371 366 372 public interface IService
+46 -44
Iceshrimp.Backend/Core/Extensions/StreamExtensions.cs
··· 4 4 5 5 public static class StreamExtensions 6 6 { 7 - public static async Task CopyToAsync( 8 - this Stream source, Stream destination, long? maxLength, CancellationToken cancellationToken 9 - ) 7 + extension(Stream source) 10 8 { 11 - var buffer = ArrayPool<byte>.Shared.Rent(81920); 12 - try 9 + public async Task CopyToAsync( 10 + Stream destination, long? maxLength, CancellationToken cancellationToken 11 + ) 13 12 { 14 - int bytesRead; 15 - var totalBytesRead = 0L; 16 - while ((maxLength == null || totalBytesRead <= maxLength) && (bytesRead = await DoReadAsync()) != 0) 13 + var buffer = ArrayPool<byte>.Shared.Rent(81920); 14 + try 15 + { 16 + int bytesRead; 17 + var totalBytesRead = 0L; 18 + while ((maxLength == null || totalBytesRead <= maxLength) && (bytesRead = await DoReadAsync()) != 0) 19 + { 20 + totalBytesRead += bytesRead; 21 + await destination.WriteAsync(new ReadOnlyMemory<byte>(buffer, 0, bytesRead), cancellationToken); 22 + } 23 + } 24 + finally 17 25 { 18 - totalBytesRead += bytesRead; 19 - await destination.WriteAsync(new ReadOnlyMemory<byte>(buffer, 0, bytesRead), cancellationToken); 26 + ArrayPool<byte>.Shared.Return(buffer); 20 27 } 21 - } 22 - finally 23 - { 24 - ArrayPool<byte>.Shared.Return(buffer); 28 + 29 + return; 30 + 31 + ValueTask<int> DoReadAsync() => source.ReadAsync(new Memory<byte>(buffer), cancellationToken); 25 32 } 26 33 27 - return; 34 + /// <summary> 35 + /// We can't trust the Content-Length header, and it might be null. 36 + /// This makes sure that we only ever read up to maxLength into memory. 37 + /// </summary> 38 + /// <param name="maxLength">The maximum length to buffer (null = unlimited)</param> 39 + /// <param name="contentLength">The content length, if known</param> 40 + /// <param name="token">A CancellationToken, if applicable</param> 41 + /// <returns>Either a buffered MemoryStream, or Stream.Null</returns> 42 + public async Task<Stream> GetSafeStreamOrNullAsync( 43 + long? maxLength, long? contentLength, CancellationToken token = default 44 + ) 45 + { 46 + if (maxLength is 0) return Stream.Null; 47 + if (contentLength > maxLength) return Stream.Null; 28 48 29 - ValueTask<int> DoReadAsync() => source.ReadAsync(new Memory<byte>(buffer), cancellationToken); 30 - } 31 - 32 - /// <summary> 33 - /// We can't trust the Content-Length header, and it might be null. 34 - /// This makes sure that we only ever read up to maxLength into memory. 35 - /// </summary> 36 - /// <param name="stream">The response content stream</param> 37 - /// <param name="maxLength">The maximum length to buffer (null = unlimited)</param> 38 - /// <param name="contentLength">The content length, if known</param> 39 - /// <param name="token">A CancellationToken, if applicable</param> 40 - /// <returns>Either a buffered MemoryStream, or Stream.Null</returns> 41 - public static async Task<Stream> GetSafeStreamOrNullAsync( 42 - this Stream stream, long? maxLength, long? contentLength, CancellationToken token = default 43 - ) 44 - { 45 - if (maxLength is 0) return Stream.Null; 46 - if (contentLength > maxLength) return Stream.Null; 49 + MemoryStream buf = new(); 50 + if (contentLength < maxLength) 51 + maxLength = contentLength.Value; 47 52 48 - MemoryStream buf = new(); 49 - if (contentLength < maxLength) 50 - maxLength = contentLength.Value; 53 + await source.CopyToAsync(buf, maxLength, token); 54 + if (maxLength == null || buf.Length <= maxLength) 55 + { 56 + buf.Seek(0, SeekOrigin.Begin); 57 + return buf; 58 + } 51 59 52 - await stream.CopyToAsync(buf, maxLength, token); 53 - if (maxLength == null || buf.Length <= maxLength) 54 - { 55 - buf.Seek(0, SeekOrigin.Begin); 56 - return buf; 60 + await buf.DisposeAsync(); 61 + return Stream.Null; 57 62 } 58 - 59 - await buf.DisposeAsync(); 60 - return Stream.Null; 61 63 } 62 64 }
+49 -39
Iceshrimp.Backend/Core/Extensions/StringExtensions.cs
··· 10 10 { 11 11 private static readonly IdnMapping IdnMapping = new(); 12 12 13 - public static bool EqualsInvariant(this string? s1, string? s2) => 14 - string.Equals(s1, s2, StringComparison.InvariantCulture); 15 - 16 - public static bool EqualsIgnoreCase(this string? s1, string s2) => 17 - string.Equals(s1, s2, StringComparison.InvariantCultureIgnoreCase); 18 - 19 - public static string Truncate(this string target, int maxLength) 20 - { 21 - return target[..Math.Min(target.Length, maxLength)]; 22 - } 23 - 24 - public static string TruncateEllipsis(this string target, int maxLength) 13 + extension(string? s1) 25 14 { 26 - if (target.Length <= maxLength) return target; 27 - return target[..(maxLength-3)] + "..."; 28 - } 15 + public bool EqualsInvariant(string? s2) => 16 + string.Equals(s1, s2, StringComparison.InvariantCulture); 29 17 30 - private static string ToPunycode(this string target) 31 - { 32 - return target.Length > 0 ? IdnMapping.GetAscii(target) : target; 18 + public bool EqualsIgnoreCase(string s2) => 19 + string.Equals(s1, s2, StringComparison.InvariantCultureIgnoreCase); 33 20 } 34 21 35 - public static string ToPunycodeLower(this string target) 22 + extension(string target) 36 23 { 37 - return ToPunycode(target).ToLowerInvariant(); 38 - } 24 + public string Truncate(int maxLength) 25 + { 26 + return target[..Math.Min(target.Length, maxLength)]; 27 + } 39 28 40 - public static string FromPunycode(this string target) 41 - { 42 - return IdnMapping.GetUnicode(target); 43 - } 29 + public string TruncateEllipsis(int maxLength) 30 + { 31 + if (target.Length <= maxLength) return target; 32 + return target[..(maxLength-3)] + "..."; 33 + } 44 34 45 - public static string ToTitleCase(this string input) => input switch 46 - { 47 - null => throw new ArgumentNullException(nameof(input)), 48 - "" => throw new ArgumentException(@$"{nameof(input)} cannot be empty", nameof(input)), 49 - _ => string.Concat(input[0].ToString().ToUpper(), input.AsSpan(1)) 50 - }; 35 + private string ToPunycode() 36 + { 37 + return target.Length > 0 ? IdnMapping.GetAscii(target) : target; 38 + } 51 39 52 - public static string UrlEncode(this string input) => UrlEncoder.Default.Encode(input); 40 + public string ToPunycodeLower() 41 + { 42 + return ToPunycode(target).ToLowerInvariant(); 43 + } 44 + 45 + public string FromPunycode() 46 + { 47 + return IdnMapping.GetUnicode(target); 48 + } 49 + 50 + public string ToTitleCase() => target switch 51 + { 52 + null => throw new ArgumentNullException(nameof(target)), 53 + "" => throw new ArgumentException(@$"{nameof(target)} cannot be empty", nameof(target)), 54 + _ => string.Concat(target[0].ToString().ToUpper(), target.AsSpan(1)) 55 + }; 56 + 57 + public string UrlEncode() => UrlEncoder.Default.Encode(target); 58 + } 53 59 } 54 60 55 - [SuppressMessage("ReSharper", "StringCompareToIsCultureSpecific")] 61 + [SuppressMessage("ReSharper", "StringCompareToIsCultureSpecific", Justification = "SQL")] 62 + [SuppressMessage("ReSharper", "ConvertToExtensionBlock", Justification = "Projectables")] 56 63 public static class ProjectableStringExtensions 57 64 { 58 65 [Projectable] ··· 72 79 { 73 80 private const char NewLineLf = '\n'; 74 81 75 - /// <summary> 76 - /// Equivalent to .AppendLine, but always uses \n instead of Environment.NewLine 77 - /// </summary> 78 - public static StringBuilder AppendLineLf(this StringBuilder sb, string? value) 82 + extension(StringBuilder sb) 79 83 { 80 - sb.Append(value); 81 - return sb.Append(NewLineLf); 84 + /// <summary> 85 + /// Equivalent to .AppendLine, but always uses \n instead of Environment.NewLine 86 + /// </summary> 87 + public StringBuilder AppendLineLf(string? value) 88 + { 89 + sb.Append(value); 90 + return sb.Append(NewLineLf); 91 + } 82 92 } 83 93 }
+100 -79
Iceshrimp.Backend/Core/Extensions/TaskExtensions.cs
··· 5 5 [SuppressMessage("ReSharper", "InconsistentNaming")] 6 6 public static class TaskExtensions 7 7 { 8 - public static async Task SafeWaitAsync(this Task task, TimeSpan timeSpan) 8 + extension(Task task) 9 9 { 10 - try 10 + public async Task SafeWaitAsync(TimeSpan timeSpan) 11 11 { 12 - await task.WaitAsync(timeSpan); 12 + try 13 + { 14 + await task.WaitAsync(timeSpan); 15 + } 16 + catch (TimeoutException) 17 + { 18 + // ignored 19 + } 13 20 } 14 - catch (TimeoutException) 21 + 22 + public async Task SafeWaitAsync(CancellationToken token) 15 23 { 16 - // ignored 24 + try 25 + { 26 + await task.WaitAsync(token); 27 + } 28 + catch (TaskCanceledException) 29 + { 30 + // ignored 31 + } 17 32 } 18 - } 19 33 20 - public static async Task SafeWaitAsync(this Task task, CancellationToken token) 21 - { 22 - try 34 + public async Task SafeWaitAsync(CancellationToken token, Action action) 23 35 { 24 - await task.WaitAsync(token); 36 + try 37 + { 38 + await task.WaitAsync(token); 39 + } 40 + catch (TaskCanceledException) 41 + { 42 + action(); 43 + } 25 44 } 26 - catch (TaskCanceledException) 45 + 46 + public async Task SafeWaitAsync(CancellationToken token, Func<Task> action) 27 47 { 28 - // ignored 48 + try 49 + { 50 + await task.WaitAsync(token); 51 + } 52 + catch (TaskCanceledException) 53 + { 54 + await action(); 55 + } 29 56 } 30 57 } 31 58 32 - public static async Task SafeWaitAsync(this Task task, CancellationToken token, Action action) 59 + extension(Func<Task> factory) 33 60 { 34 - try 35 - { 36 - await task.WaitAsync(token); 37 - } 38 - catch (TaskCanceledException) 39 - { 40 - action(); 41 - } 61 + public List<Task> QueueMany(int n) => 62 + Enumerable.Range(0, n).Select(_ => factory()).ToList(); 42 63 } 43 64 44 - public static async Task SafeWaitAsync(this Task task, CancellationToken token, Func<Task> action) 65 + extension<T>(Task<IEnumerable<T>> task) 45 66 { 46 - try 67 + public async Task<List<T>> ToListAsync() 47 68 { 48 - await task.WaitAsync(token); 69 + return (await task).ToList(); 49 70 } 50 - catch (TaskCanceledException) 71 + 72 + public async Task<T[]> ToArrayAsync() 51 73 { 52 - await action(); 74 + return (await task).ToArray(); 53 75 } 54 - } 55 76 56 - public static List<Task> QueueMany(this Func<Task> factory, int n) => 57 - Enumerable.Range(0, n).Select(_ => factory()).ToList(); 58 - 59 - public static async Task<List<T>> ToListAsync<T>(this Task<IEnumerable<T>> task) 60 - { 61 - return (await task).ToList(); 77 + public async Task<T?> FirstOrDefaultAsync() 78 + { 79 + return (await task).FirstOrDefault(); 80 + } 62 81 } 63 82 64 - public static async Task<T[]> ToArrayAsync<T>(this Task<IEnumerable<T>> task) 83 + extension(Task task) 65 84 { 66 - return (await task).ToArray(); 67 - } 85 + public async Task ContinueWithResult(Action continuation) 86 + { 87 + await task; 88 + continuation(); 89 + } 68 90 69 - public static async Task<T?> FirstOrDefaultAsync<T>(this Task<IEnumerable<T>> task) 70 - { 71 - return (await task).FirstOrDefault(); 91 + public async Task<TNewResult> ContinueWithResult<TNewResult>( 92 + Func<TNewResult> continuation 93 + ) 94 + { 95 + await task; 96 + return continuation(); 97 + } 72 98 } 73 99 74 - public static async Task ContinueWithResult(this Task task, Action continuation) 100 + extension<TResult>(Task<TResult> task) 75 101 { 76 - await task; 77 - continuation(); 78 - } 102 + public async Task ContinueWithResult(Action<TResult> continuation) 103 + { 104 + continuation(await task); 105 + } 79 106 80 - public static async Task<TNewResult> ContinueWithResult<TNewResult>( 81 - this Task task, Func<TNewResult> continuation 82 - ) 83 - { 84 - await task; 85 - return continuation(); 107 + public async Task<TNewResult> ContinueWithResult<TNewResult>( 108 + Func<TResult, TNewResult> continuation 109 + ) 110 + { 111 + return continuation(await task); 112 + } 86 113 } 87 114 88 - public static async Task ContinueWithResult<TResult>(this Task<TResult> task, Action<TResult> continuation) 115 + extension(Task task) 89 116 { 90 - continuation(await task); 91 - } 117 + public async Task ContinueWithResult(Func<Task> continuation) 118 + { 119 + await task; 120 + await continuation(); 121 + } 92 122 93 - public static async Task<TNewResult> ContinueWithResult<TResult, TNewResult>( 94 - this Task<TResult> task, Func<TResult, TNewResult> continuation 95 - ) 96 - { 97 - return continuation(await task); 98 - } 99 - 100 - public static async Task ContinueWithResult(this Task task, Func<Task> continuation) 101 - { 102 - await task; 103 - await continuation(); 104 - } 105 - 106 - public static async Task<TNewResult> ContinueWithResult<TNewResult>( 107 - this Task task, Func<Task<TNewResult>> continuation 108 - ) 109 - { 110 - await task; 111 - return await continuation(); 123 + public async Task<TNewResult> ContinueWithResult<TNewResult>( 124 + Func<Task<TNewResult>> continuation 125 + ) 126 + { 127 + await task; 128 + return await continuation(); 129 + } 112 130 } 113 131 114 - public static async Task ContinueWithResult<TResult>(this Task<TResult> task, Func<TResult, Task> continuation) 132 + extension<TResult>(Task<TResult> task) 115 133 { 116 - await continuation(await task); 117 - } 134 + public async Task ContinueWithResult(Func<TResult, Task> continuation) 135 + { 136 + await continuation(await task); 137 + } 118 138 119 - public static async Task<TNewResult> ContinueWithResult<TResult, TNewResult>( 120 - this Task<TResult> task, Func<TResult, Task<TNewResult>> continuation 121 - ) 122 - { 123 - return await continuation(await task); 139 + public async Task<TNewResult> ContinueWithResult<TNewResult>( 140 + Func<TResult, Task<TNewResult>> continuation 141 + ) 142 + { 143 + return await continuation(await task); 144 + } 124 145 } 125 146 }
+22 -19
Iceshrimp.Backend/Core/Extensions/TimeSpanExtensions.cs
··· 6 6 private static readonly long Minutes = TimeSpan.FromHours(1).Ticks; 7 7 private static readonly long Hours = TimeSpan.FromDays(1).Ticks; 8 8 9 - public static long GetTotalMilliseconds(this TimeSpan timeSpan) => Convert.ToInt64(timeSpan.TotalMilliseconds); 9 + extension(TimeSpan timeSpan) 10 + { 11 + public long GetTotalMilliseconds() => Convert.ToInt64(timeSpan.TotalMilliseconds); 10 12 11 - public static string ToDisplayString(this TimeSpan timeSpan, bool singleNumber = true) 12 - { 13 - if (timeSpan.Ticks < Seconds) 13 + public string ToDisplayString(bool singleNumber = true) 14 14 { 15 - var seconds = (int)timeSpan.TotalSeconds; 16 - return seconds == 1 ? singleNumber ? "1 second" : "second" : $"{seconds} seconds"; 17 - } 15 + if (timeSpan.Ticks < Seconds) 16 + { 17 + var seconds = (int)timeSpan.TotalSeconds; 18 + return seconds == 1 ? singleNumber ? "1 second" : "second" : $"{seconds} seconds"; 19 + } 20 + 21 + if (timeSpan.Ticks < Minutes) 22 + { 23 + var minutes = (int)timeSpan.TotalMinutes; 24 + return minutes == 1 ? singleNumber ? "1 minute" : "minute" : $"{minutes} minutes"; 25 + } 18 26 19 - if (timeSpan.Ticks < Minutes) 20 - { 21 - var minutes = (int)timeSpan.TotalMinutes; 22 - return minutes == 1 ? singleNumber ? "1 minute" : "minute" : $"{minutes} minutes"; 23 - } 27 + if (timeSpan.Ticks < Hours) 28 + { 29 + var hours = (int)timeSpan.TotalHours; 30 + return hours == 1 ? singleNumber ? "1 hour" : "hour" : $"{hours} hours"; 31 + } 24 32 25 - if (timeSpan.Ticks < Hours) 26 - { 27 - var hours = (int)timeSpan.TotalHours; 28 - return hours == 1 ? singleNumber ? "1 hour" : "hour" : $"{hours} hours"; 33 + var days = (int)timeSpan.TotalDays; 34 + return days == 1 ? singleNumber ? "1 day" : "day" : $"{days} days"; 29 35 } 30 - 31 - var days = (int)timeSpan.TotalDays; 32 - return days == 1 ? singleNumber ? "1 day" : "day" : $"{days} days"; 33 36 } 34 37 }
+343 -336
Iceshrimp.Backend/Core/Extensions/WebApplicationExtensions.cs
··· 20 20 21 21 public static class WebApplicationExtensions 22 22 { 23 - public static IApplicationBuilder UseCustomMiddleware(this IApplicationBuilder app) 23 + extension(IApplicationBuilder app) 24 24 { 25 - // Caution: make sure these are in the correct order 26 - return app.UseMiddleware<RequestDurationMiddleware>() 27 - .UseMiddleware<ErrorHandlerMiddleware>() 28 - .UseMiddleware<RequestVerificationMiddleware>() 29 - .UseMiddleware<RequestBufferingMiddleware>() 30 - .UseMiddleware<AuthenticationMiddleware>() 31 - .UseRateLimiter() 32 - .UseMiddleware<AuthorizationMiddleware>() 33 - .UseMiddleware<FederationSemaphoreMiddleware>() 34 - .UseMiddleware<AuthorizedFetchMiddleware>() 35 - .UseMiddleware<InboxValidationMiddleware>() 36 - .UseOutputCache() 37 - .UseMiddleware<BlazorSsrHandoffMiddleware>(); 38 - } 39 - 40 - // Prevents conditional middleware from being invoked on non-matching requests 41 - private static IApplicationBuilder UseMiddleware<T>(this IApplicationBuilder app) where T : IConditionalMiddleware 42 - => app.UseWhen(T.Predicate, builder => UseMiddlewareExtensions.UseMiddleware<T>(builder)); 43 - 44 - public static IApplicationBuilder UseOpenApiWithOptions(this WebApplication app) 45 - { 46 - app.MapSwagger("/openapi/{documentName}.{extension:regex(^(json|ya?ml)$)}", 47 - o => o.OpenApiVersion = OpenApiSpecVersion.OpenApi3_1) 48 - .CacheOutput(p => p.Expire(TimeSpan.FromHours(12))); 49 - 50 - app.UseSwaggerUI(options => 25 + public IApplicationBuilder UseCustomMiddleware() 51 26 { 52 - options.DocumentTitle = "Iceshrimp API documentation"; 53 - options.SwaggerEndpoint("/openapi/iceshrimp.json", "Iceshrimp.NET"); 54 - options.SwaggerEndpoint("/openapi/federation.json", "Federation"); 55 - options.SwaggerEndpoint("/openapi/mastodon.json", "Mastodon"); 56 - options.InjectStylesheet("/css/swagger.css"); 57 - options.EnablePersistAuthorization(); 58 - options.EnableTryItOutByDefault(); 59 - options.DisplayRequestDuration(); 60 - options.DefaultModelsExpandDepth(-1); // Hide "Schemas" section 61 - options.ConfigObject.AdditionalItems.Add("tagsSorter", "alpha"); // Sort tags alphabetically 62 - }); 27 + // Caution: make sure these are in the correct order 28 + return app.UseMiddleware<RequestDurationMiddleware>() 29 + .UseMiddleware<ErrorHandlerMiddleware>() 30 + .UseMiddleware<RequestVerificationMiddleware>() 31 + .UseMiddleware<RequestBufferingMiddleware>() 32 + .UseMiddleware<AuthenticationMiddleware>() 33 + .UseRateLimiter() 34 + .UseMiddleware<AuthorizationMiddleware>() 35 + .UseMiddleware<FederationSemaphoreMiddleware>() 36 + .UseMiddleware<AuthorizedFetchMiddleware>() 37 + .UseMiddleware<InboxValidationMiddleware>() 38 + .UseOutputCache() 39 + .UseMiddleware<BlazorSsrHandoffMiddleware>(); 40 + } 63 41 64 - app.MapScalarApiReference("/scalar", options => 65 - { 66 - options.WithTitle("Iceshrimp API documentation") 67 - .AddDocument("iceshrimp", "Iceshrimp.NET") 68 - .AddDocument("federation", "Federation") 69 - .AddDocument("mastodon", "Mastodon") 70 - .WithOpenApiRoutePattern("/openapi/{documentName}.json") 71 - .HideModels() 72 - .EnablePersistentAuthentication() 73 - .WithCustomCss(""" 74 - .open-api-client-button, .darklight-reference > .flex > .text-sm { display: none !important; } 75 - .darklight-reference > .flex > button > div:nth-child(1) { height: 14px !important; } 76 - .darklight-reference { padding: 15px 14px !important; } 77 - """); 78 - }); 79 - 80 - return app; 42 + // Prevents conditional middleware from being invoked on non-matching requests 43 + private IApplicationBuilder UseMiddleware<T>() where T : IConditionalMiddleware 44 + => app.UseWhen(T.Predicate, builder => UseMiddlewareExtensions.UseMiddleware<T>(builder)); 81 45 } 82 46 83 - public static void MapFrontendRoutes(this WebApplication app, string page) 84 - { 85 - app.MapFallbackToPage(page).WithOrder(int.MaxValue - 2); 86 - app.MapFallbackToPage("/@{user}", page).WithOrder(int.MaxValue - 1); 87 - app.MapFallbackToPage("/@{user}@{host}", page); 88 - } 89 47 90 - public static async Task<Config.InstanceSection> InitializeAsync(this WebApplication app, string[] args) 48 + extension(WebApplication app) 91 49 { 92 - var instanceConfig = app.Configuration.GetSection("Instance").Get<Config.InstanceSection>() ?? 93 - throw new Exception("Failed to read Instance config section"); 94 - 95 - app.Logger.LogInformation("Iceshrimp.NET v{version}, codename \"{codename}\" ({domain})", 96 - instanceConfig.Version, instanceConfig.Codename, instanceConfig.AccountDomain); 50 + public IApplicationBuilder UseOpenApiWithOptions() 51 + { 52 + app.MapSwagger("/openapi/{documentName}.{extension:regex(^(json|ya?ml)$)}", 53 + o => o.OpenApiVersion = OpenApiSpecVersion.OpenApi3_1) 54 + .CacheOutput(p => p.Expire(TimeSpan.FromHours(12))); 97 55 98 - await using var scope = app.Services.CreateAsyncScope(); 99 - var provider = scope.ServiceProvider; 56 + app.UseSwaggerUI(options => 57 + { 58 + options.DocumentTitle = "Iceshrimp API documentation"; 59 + options.SwaggerEndpoint("/openapi/iceshrimp.json", "Iceshrimp.NET"); 60 + options.SwaggerEndpoint("/openapi/federation.json", "Federation"); 61 + options.SwaggerEndpoint("/openapi/mastodon.json", "Mastodon"); 62 + options.InjectStylesheet("/css/swagger.css"); 63 + options.EnablePersistAuthorization(); 64 + options.EnableTryItOutByDefault(); 65 + options.DisplayRequestDuration(); 66 + options.DefaultModelsExpandDepth(-1); // Hide "Schemas" section 67 + options.ConfigObject.AdditionalItems.Add("tagsSorter", "alpha"); // Sort tags alphabetically 68 + }); 100 69 101 - var config = (ConfigurationManager)app.Configuration; 102 - var files = config.Sources.OfType<IniConfigurationSource>().Select(p => p.Path); 103 - app.Logger.LogDebug("Loaded configuration files: \n* {files}", string.Join("\n* ", files)); 70 + app.MapScalarApiReference("/scalar", options => 71 + { 72 + options.WithTitle("Iceshrimp API documentation") 73 + .AddDocument("iceshrimp", "Iceshrimp.NET") 74 + .AddDocument("federation", "Federation") 75 + .AddDocument("mastodon", "Mastodon") 76 + .WithOpenApiRoutePattern("/openapi/{documentName}.json") 77 + .HideModels() 78 + .EnablePersistentAuthentication() 79 + .WithCustomCss(""" 80 + .open-api-client-button, .darklight-reference > .flex > .text-sm { display: none !important; } 81 + .darklight-reference > .flex > button > div:nth-child(1) { height: 14px !important; } 82 + .darklight-reference { padding: 15px 14px !important; } 83 + """); 84 + }); 104 85 105 - try 106 - { 107 - app.Logger.LogInformation("Validating configuration..."); 108 - provider.GetRequiredService<IStartupValidator>().Validate(); 109 - } 110 - catch (OptionsValidationException e) 111 - { 112 - app.Logger.LogCritical("Failed to validate configuration: {error}", e.Message); 113 - Environment.Exit(1); 86 + return app; 114 87 } 115 88 116 - if (app.Environment.IsDevelopment()) 89 + public void MapFrontendRoutes(string page) 117 90 { 118 - app.Logger.LogWarning("The hosting environment is set to Development."); 119 - app.Logger.LogWarning("This application will not validate the Host header for incoming requests."); 120 - app.Logger.LogWarning("If this is not a local development instance, please set the environment to Production."); 91 + app.MapFallbackToPage(page).WithOrder(int.MaxValue - 2); 92 + app.MapFallbackToPage("/@{user}", page).WithOrder(int.MaxValue - 1); 93 + app.MapFallbackToPage("/@{user}@{host}", page); 121 94 } 122 95 123 - await using var db = provider.GetService<DatabaseContext>(); 124 - if (db == null) 96 + public async Task<Config.InstanceSection> InitializeAsync(string[] args) 125 97 { 126 - app.Logger.LogCritical("Failed to initialize database context"); 127 - Environment.Exit(1); 128 - } 98 + var instanceConfig = app.Configuration.GetSection("Instance").Get<Config.InstanceSection>() ?? 99 + throw new Exception("Failed to read Instance config section"); 129 100 130 - app.Logger.LogInformation("Verifying database connection..."); 101 + app.Logger.LogInformation("Iceshrimp.NET v{version}, codename \"{codename}\" ({domain})", 102 + instanceConfig.Version, instanceConfig.Codename, instanceConfig.AccountDomain); 103 + 104 + await using var scope = app.Services.CreateAsyncScope(); 105 + var provider = scope.ServiceProvider; 106 + 107 + var config = (ConfigurationManager)app.Configuration; 108 + var files = config.Sources.OfType<IniConfigurationSource>().Select(p => p.Path); 109 + app.Logger.LogDebug("Loaded configuration files: \n* {files}", string.Join("\n* ", files)); 131 110 132 - if (!await db.Database.CanConnectAsync()) 133 - { 134 111 try 135 112 { 136 - await db.Database.OpenConnectionAsync(); 137 - await db.Database.CloseConnectionAsync(); 113 + app.Logger.LogInformation("Validating configuration..."); 114 + provider.GetRequiredService<IStartupValidator>().Validate(); 138 115 } 139 - catch (Exception e) 116 + catch (OptionsValidationException e) 140 117 { 141 - app.Logger.LogCritical("Failed to connect to database. Please make sure your configuration is correct."); 142 - app.Logger.LogError("Additional information: {e}", e); 118 + app.Logger.LogCritical("Failed to validate configuration: {error}", e.Message); 143 119 Environment.Exit(1); 144 120 } 145 121 146 - app.Logger.LogCritical("Failed to connect to database. Please make sure your configuration is correct."); 147 - Environment.Exit(1); 148 - } 149 - 150 - var unknownMigrations = (await db.Database.GetAppliedMigrationsAsync()) 151 - .Except(db.Database.GetMigrations()) 152 - .ToList(); 153 - 154 - if (unknownMigrations.Count > 0) 155 - { 156 - app.Logger.LogCritical("Database has {Count} unknown migrations applied, refusing to continue startup.", unknownMigrations.Count); 157 - app.Logger.LogCritical("If you tried to downgrade, make sure you know what you are doing and revert all migrations applied since the version you are downgrading to."); 158 - app.Logger.LogCritical("Unknown Migrations: {}", string.Join(", ", unknownMigrations)); 159 - Environment.Exit(1); 160 - } 161 - 162 - // @formatter:off 163 - var pendingMigration = (await db.Database.GetPendingMigrationsAsync()).FirstOrDefault(); 164 - if (args.Contains("--migrate-from-js")) 165 - { 166 - app.Logger.LogInformation("Initializing migration assistant..."); 167 - var initialMigration = typeof(Initial).GetCustomAttribute<MigrationAttribute>()?.Id; 168 - if (pendingMigration != initialMigration || await db.IsDatabaseEmptyAsync()) 122 + if (app.Environment.IsDevelopment()) 169 123 { 170 - app.Logger.LogCritical("Database does not appear to be an iceshrimp-js database."); 171 - Environment.Exit(1); 124 + app.Logger.LogWarning("The hosting environment is set to Development."); 125 + app.Logger.LogWarning("This application will not validate the Host header for incoming requests."); 126 + app.Logger.LogWarning("If this is not a local development instance, please set the environment to Production."); 172 127 } 173 - else if (!args.Contains("--i-reverted-any-extra-migrations") || 174 - !args.Contains("--i-made-a-database-backup") || 175 - !args.Contains("--i-understand-that-this-is-a-one-way-operation")) 128 + 129 + await using var db = provider.GetService<DatabaseContext>(); 130 + if (db == null) 176 131 { 177 - app.Logger.LogCritical("Missing confirmation argument(s), please follow the instructions on https://iceshrimp.net/help/migrate exactly."); 132 + app.Logger.LogCritical("Failed to initialize database context"); 178 133 Environment.Exit(1); 179 134 } 180 - else 135 + 136 + app.Logger.LogInformation("Verifying database connection..."); 137 + 138 + if (!await db.Database.CanConnectAsync()) 181 139 { 182 - app.Logger.LogInformation("Applying initial migration..."); 183 140 try 184 141 { 185 - await db.Database.ExecuteSqlAsync(new MigrationAssistant().InitialMigration); 142 + await db.Database.OpenConnectionAsync(); 143 + await db.Database.CloseConnectionAsync(); 186 144 } 187 145 catch (Exception e) 188 146 { 189 - app.Logger.LogCritical("Failed to apply initial migration: {error}", e); 190 - app.Logger.LogCritical("Manual intervention required, please follow the instructions on https://iceshrimp.net/help/migrate for more information."); 147 + app.Logger.LogCritical("Failed to connect to database. Please make sure your configuration is correct."); 148 + app.Logger.LogError("Additional information: {e}", e); 191 149 Environment.Exit(1); 192 150 } 193 151 194 - app.Logger.LogInformation("Successfully applied the initial migration."); 195 - app.Logger.LogInformation("Please follow the instructions on https://iceshrimp.net/help/migrate to validate the database schema."); 196 - Environment.Exit(0); 197 - } 198 - } 199 - // @formatter:on 200 - 201 - if (pendingMigration != null) 202 - { 203 - var initialMigration = typeof(Initial).GetCustomAttribute<MigrationAttribute>()?.Id; 204 - if (pendingMigration == initialMigration && !await db.IsDatabaseEmptyAsync()) 205 - { 206 - app.Logger.LogCritical("Initial migration is pending but database is not empty."); 207 - app.Logger.LogCritical("If you are trying to migrate from iceshrimp-js, please follow the instructions on https://iceshrimp.net/help/migrate."); 152 + app.Logger.LogCritical("Failed to connect to database. Please make sure your configuration is correct."); 208 153 Environment.Exit(1); 209 154 } 210 155 211 - if (args.Contains("--migrate") || args.Contains("--migrate-and-start")) 156 + var unknownMigrations = (await db.Database.GetAppliedMigrationsAsync()) 157 + .Except(db.Database.GetMigrations()) 158 + .ToList(); 159 + 160 + if (unknownMigrations.Count > 0) 212 161 { 213 - app.Logger.LogInformation("Running migrations..."); 214 - db.Database.SetCommandTimeout(0); 215 - await db.Database.MigrateAsync(); 216 - db.Database.SetCommandTimeout(30); 217 - if (args.Contains("--migrate")) Environment.Exit(0); 218 - } 219 - else 220 - { 221 - app.Logger.LogCritical("Database has pending migrations, please restart with --migrate or --migrate-and-start"); 162 + app.Logger.LogCritical("Database has {Count} unknown migrations applied, refusing to continue startup.", unknownMigrations.Count); 163 + app.Logger.LogCritical("If you tried to downgrade, make sure you know what you are doing and revert all migrations applied since the version you are downgrading to."); 164 + app.Logger.LogCritical("Unknown Migrations: {}", string.Join(", ", unknownMigrations)); 222 165 Environment.Exit(1); 223 166 } 224 - } 225 - else if (args.Contains("--migrate") || args.Contains("--migrate-and-start")) 226 - { 227 - app.Logger.LogInformation("No migrations are pending."); 228 - if (args.Contains("--migrate")) Environment.Exit(0); 229 - } 230 167 231 - if (args.Contains("--recompute-counters")) 232 - { 233 - app.Logger.LogInformation("Recomputing note, user & instance counters, this will take a while..."); 234 - var maintenanceSvc = provider.GetRequiredService<DatabaseMaintenanceService>(); 235 - await maintenanceSvc.RecomputeNoteCountersAsync(); 236 - await maintenanceSvc.RecomputeUserCountersAsync(); 237 - await maintenanceSvc.RecomputeInstanceCountersAsync(); 238 - Environment.Exit(0); 239 - } 168 + // @formatter:off 169 + var pendingMigration = (await db.Database.GetPendingMigrationsAsync()).FirstOrDefault(); 170 + if (args.Contains("--migrate-from-js")) 171 + { 172 + app.Logger.LogInformation("Initializing migration assistant..."); 173 + var initialMigration = typeof(Initial).GetCustomAttribute<MigrationAttribute>()?.Id; 174 + if (pendingMigration != initialMigration || await db.IsDatabaseEmptyAsync()) 175 + { 176 + app.Logger.LogCritical("Database does not appear to be an iceshrimp-js database."); 177 + Environment.Exit(1); 178 + } 179 + else if (!args.Contains("--i-reverted-any-extra-migrations") || 180 + !args.Contains("--i-made-a-database-backup") || 181 + !args.Contains("--i-understand-that-this-is-a-one-way-operation")) 182 + { 183 + app.Logger.LogCritical("Missing confirmation argument(s), please follow the instructions on https://iceshrimp.net/help/migrate exactly."); 184 + Environment.Exit(1); 185 + } 186 + else 187 + { 188 + app.Logger.LogInformation("Applying initial migration..."); 189 + try 190 + { 191 + await db.Database.ExecuteSqlAsync(new MigrationAssistant().InitialMigration); 192 + } 193 + catch (Exception e) 194 + { 195 + app.Logger.LogCritical("Failed to apply initial migration: {error}", e); 196 + app.Logger.LogCritical("Manual intervention required, please follow the instructions on https://iceshrimp.net/help/migrate for more information."); 197 + Environment.Exit(1); 198 + } 240 199 241 - if (args.Contains("--migrate-storage")) 242 - { 243 - app.Logger.LogInformation("Migrating files to object storage, this will take a while..."); 244 - db.Database.SetCommandTimeout(0); 245 - await provider.GetRequiredService<StorageMaintenanceService>() 246 - .MigrateLocalFilesAsync(args.Contains("--purge")); 247 - Environment.Exit(0); 248 - } 249 - 250 - if (args.Contains("--fixup-media")) 251 - { 252 - await provider.GetRequiredService<StorageMaintenanceService>().FixupMediaAsync(args.Contains("--dry-run")); 253 - Environment.Exit(0); 254 - } 255 - 256 - if (args.Contains("--cleanup-storage")) 257 - { 258 - await provider.GetRequiredService<StorageMaintenanceService>() 259 - .CleanupStorageAsync(args.Contains("--dry-run")); 260 - Environment.Exit(0); 261 - } 262 - 263 - string[] userMgmtCommands = 264 - [ 265 - "--create-user", "--create-admin-user", "--reset-password", "--grant-admin", "--revoke-admin" 266 - ]; 267 - 268 - if (args.FirstOrDefault(userMgmtCommands.Contains) is { } cmd) 269 - { 270 - if (args is not [not null, var username]) 271 - { 272 - app.Logger.LogError("Invalid syntax. Usage: {cmd} <username>", cmd); 273 - Environment.Exit(1); 274 - return null!; 200 + app.Logger.LogInformation("Successfully applied the initial migration."); 201 + app.Logger.LogInformation("Please follow the instructions on https://iceshrimp.net/help/migrate to validate the database schema."); 202 + Environment.Exit(0); 203 + } 275 204 } 205 + // @formatter:on 276 206 277 - if (cmd is "--create-user" or "--create-admin-user") 207 + if (pendingMigration != null) 278 208 { 279 - var password = CryptographyHelpers.GenerateRandomString(16); 280 - app.Logger.LogInformation("Creating user {username}...", username); 281 - var userSvc = provider.GetRequiredService<UserService>(); 282 - await userSvc.CreateLocalUserAsync(username, password, null, force: true); 209 + var initialMigration = typeof(Initial).GetCustomAttribute<MigrationAttribute>()?.Id; 210 + if (pendingMigration == initialMigration && !await db.IsDatabaseEmptyAsync()) 211 + { 212 + app.Logger.LogCritical("Initial migration is pending but database is not empty."); 213 + app.Logger.LogCritical("If you are trying to migrate from iceshrimp-js, please follow the instructions on https://iceshrimp.net/help/migrate."); 214 + Environment.Exit(1); 215 + } 283 216 284 - if (args[0] is "--create-admin-user") 217 + if (args.Contains("--migrate") || args.Contains("--migrate-and-start")) 285 218 { 286 - await db.Users 287 - .Where(p => p.Username == username && p.Host == null) 288 - .ExecuteUpdateAsync(p => p.SetProperty(i => i.IsAdmin, true)); 289 - 290 - app.Logger.LogInformation("Successfully created admin user."); 219 + app.Logger.LogInformation("Running migrations..."); 220 + db.Database.SetCommandTimeout(0); 221 + await db.Database.MigrateAsync(); 222 + db.Database.SetCommandTimeout(30); 223 + if (args.Contains("--migrate")) Environment.Exit(0); 291 224 } 292 225 else 293 226 { 294 - app.Logger.LogInformation("Successfully created user."); 227 + app.Logger.LogCritical("Database has pending migrations, please restart with --migrate or --migrate-and-start"); 228 + Environment.Exit(1); 295 229 } 230 + } 231 + else if (args.Contains("--migrate") || args.Contains("--migrate-and-start")) 232 + { 233 + app.Logger.LogInformation("No migrations are pending."); 234 + if (args.Contains("--migrate")) Environment.Exit(0); 235 + } 296 236 297 - app.Logger.LogInformation("Username: {username}", username); 298 - app.Logger.LogInformation("Password: {password}", password); 237 + if (args.Contains("--recompute-counters")) 238 + { 239 + app.Logger.LogInformation("Recomputing note, user & instance counters, this will take a while..."); 240 + var maintenanceSvc = provider.GetRequiredService<DatabaseMaintenanceService>(); 241 + await maintenanceSvc.RecomputeNoteCountersAsync(); 242 + await maintenanceSvc.RecomputeUserCountersAsync(); 243 + await maintenanceSvc.RecomputeInstanceCountersAsync(); 299 244 Environment.Exit(0); 300 245 } 301 246 302 - if (cmd is "--reset-password") 247 + if (args.Contains("--migrate-storage")) 303 248 { 304 - var settings = await db.UserSettings 305 - .FirstOrDefaultAsync(p => p.User.UsernameLower == username.ToLowerInvariant()); 249 + app.Logger.LogInformation("Migrating files to object storage, this will take a while..."); 250 + db.Database.SetCommandTimeout(0); 251 + await provider.GetRequiredService<StorageMaintenanceService>() 252 + .MigrateLocalFilesAsync(args.Contains("--purge")); 253 + Environment.Exit(0); 254 + } 306 255 307 - if (settings == null) 308 - { 309 - app.Logger.LogError("User {username} not found.", username); 310 - Environment.Exit(1); 311 - } 312 - 313 - app.Logger.LogInformation("Resetting password for user {username}...", username); 314 - 315 - var password = CryptographyHelpers.GenerateRandomString(16); 316 - settings.Password = AuthHelpers.HashPassword(password); 317 - await db.SaveChangesAsync(); 256 + if (args.Contains("--fixup-media")) 257 + { 258 + await provider.GetRequiredService<StorageMaintenanceService>().FixupMediaAsync(args.Contains("--dry-run")); 259 + Environment.Exit(0); 260 + } 318 261 319 - app.Logger.LogInformation("Password for user {username} was reset to: {password}", username, password); 262 + if (args.Contains("--cleanup-storage")) 263 + { 264 + await provider.GetRequiredService<StorageMaintenanceService>() 265 + .CleanupStorageAsync(args.Contains("--dry-run")); 320 266 Environment.Exit(0); 321 267 } 322 268 323 - if (cmd is "--grant-admin") 269 + string[] userMgmtCommands = 270 + [ 271 + "--create-user", "--create-admin-user", "--reset-password", "--grant-admin", "--revoke-admin" 272 + ]; 273 + 274 + if (args.FirstOrDefault(userMgmtCommands.Contains) is { } cmd) 324 275 { 325 - var user = await db.Users.FirstOrDefaultAsync(p => p.UsernameLower == username.ToLowerInvariant() && p.Host == null); 326 - if (user == null) 276 + if (args is not [not null, var username]) 327 277 { 328 - app.Logger.LogError("User {username} not found.", username); 278 + app.Logger.LogError("Invalid syntax. Usage: {cmd} <username>", cmd); 329 279 Environment.Exit(1); 280 + return null!; 330 281 } 331 - else 282 + 283 + if (cmd is "--create-user" or "--create-admin-user") 284 + { 285 + var password = CryptographyHelpers.GenerateRandomString(16); 286 + app.Logger.LogInformation("Creating user {username}...", username); 287 + var userSvc = provider.GetRequiredService<UserService>(); 288 + await userSvc.CreateLocalUserAsync(username, password, null, force: true); 289 + 290 + if (args[0] is "--create-admin-user") 291 + { 292 + await db.Users 293 + .Where(p => p.Username == username && p.Host == null) 294 + .ExecuteUpdateAsync(p => p.SetProperty(i => i.IsAdmin, true)); 295 + 296 + app.Logger.LogInformation("Successfully created admin user."); 297 + } 298 + else 299 + { 300 + app.Logger.LogInformation("Successfully created user."); 301 + } 302 + 303 + app.Logger.LogInformation("Username: {username}", username); 304 + app.Logger.LogInformation("Password: {password}", password); 305 + Environment.Exit(0); 306 + } 307 + 308 + if (cmd is "--reset-password") 332 309 { 333 - user.IsAdmin = true; 310 + var settings = await db.UserSettings 311 + .FirstOrDefaultAsync(p => p.User.UsernameLower == username.ToLowerInvariant()); 312 + 313 + if (settings == null) 314 + { 315 + app.Logger.LogError("User {username} not found.", username); 316 + Environment.Exit(1); 317 + } 318 + 319 + app.Logger.LogInformation("Resetting password for user {username}...", username); 320 + 321 + var password = CryptographyHelpers.GenerateRandomString(16); 322 + settings.Password = AuthHelpers.HashPassword(password); 334 323 await db.SaveChangesAsync(); 335 - app.Logger.LogInformation("Granted admin privileges to user {username}.", username); 324 + 325 + app.Logger.LogInformation("Password for user {username} was reset to: {password}", username, password); 336 326 Environment.Exit(0); 337 327 } 328 + 329 + if (cmd is "--grant-admin") 330 + { 331 + var user = await db.Users.FirstOrDefaultAsync(p => p.UsernameLower == username.ToLowerInvariant() && p.Host == null); 332 + if (user == null) 333 + { 334 + app.Logger.LogError("User {username} not found.", username); 335 + Environment.Exit(1); 336 + } 337 + else 338 + { 339 + user.IsAdmin = true; 340 + await db.SaveChangesAsync(); 341 + app.Logger.LogInformation("Granted admin privileges to user {username}.", username); 342 + Environment.Exit(0); 343 + } 344 + } 345 + 346 + if (cmd is "--revoke-admin") 347 + { 348 + var user = await db.Users.FirstOrDefaultAsync(p => p.UsernameLower == username.ToLowerInvariant() && p.Host == null); 349 + if (user == null) 350 + { 351 + app.Logger.LogError("User {username} not found.", username); 352 + Environment.Exit(1); 353 + } 354 + else 355 + { 356 + user.IsAdmin = false; 357 + await db.SaveChangesAsync(); 358 + app.Logger.LogInformation("Revoked admin privileges of user {username}.", username); 359 + Environment.Exit(0); 360 + } 361 + } 338 362 } 339 363 340 - if (cmd is "--revoke-admin") 364 + var storageConfig = app.Configuration.GetSection("Storage").Get<Config.StorageSection>() ?? 365 + throw new Exception("Failed to read Storage config section"); 366 + 367 + if (storageConfig.Provider == Enums.FileStorage.Local) 341 368 { 342 - var user = await db.Users.FirstOrDefaultAsync(p => p.UsernameLower == username.ToLowerInvariant() && p.Host == null); 343 - if (user == null) 369 + if (string.IsNullOrWhiteSpace(storageConfig.Local?.Path) || !Directory.Exists(storageConfig.Local.Path)) 344 370 { 345 - app.Logger.LogError("User {username} not found.", username); 371 + app.Logger.LogCritical("Local storage path does not exist"); 346 372 Environment.Exit(1); 347 373 } 348 374 else 349 375 { 350 - user.IsAdmin = false; 351 - await db.SaveChangesAsync(); 352 - app.Logger.LogInformation("Revoked admin privileges of user {username}.", username); 353 - Environment.Exit(0); 376 + try 377 + { 378 + var path = Path.Combine(storageConfig.Local.Path, Path.GetRandomFileName()); 379 + 380 + await using var fs = File.Create(path, 1, FileOptions.DeleteOnClose); 381 + } 382 + catch 383 + { 384 + app.Logger.LogCritical("Local storage path is not accessible or not writable"); 385 + Environment.Exit(1); 386 + } 354 387 } 355 388 } 356 - } 357 - 358 - var storageConfig = app.Configuration.GetSection("Storage").Get<Config.StorageSection>() ?? 359 - throw new Exception("Failed to read Storage config section"); 360 - 361 - if (storageConfig.Provider == Enums.FileStorage.Local) 362 - { 363 - if (string.IsNullOrWhiteSpace(storageConfig.Local?.Path) || !Directory.Exists(storageConfig.Local.Path)) 389 + else if (storageConfig.Provider == Enums.FileStorage.ObjectStorage) 364 390 { 365 - app.Logger.LogCritical("Local storage path does not exist"); 366 - Environment.Exit(1); 367 - } 368 - else 369 - { 391 + app.Logger.LogInformation("Verifying object storage configuration..."); 392 + var svc = provider.GetRequiredService<ObjectStorageService>(); 370 393 try 371 394 { 372 - var path = Path.Combine(storageConfig.Local.Path, Path.GetRandomFileName()); 373 - 374 - await using var fs = File.Create(path, 1, FileOptions.DeleteOnClose); 395 + await svc.VerifyCredentialsAsync(); 375 396 } 376 - catch 397 + catch (Exception e) 377 398 { 378 - app.Logger.LogCritical("Local storage path is not accessible or not writable"); 399 + app.Logger.LogCritical("Failed to initialize object storage: {message}", e.Message); 379 400 Environment.Exit(1); 380 401 } 381 402 } 382 - } 383 - else if (storageConfig.Provider == Enums.FileStorage.ObjectStorage) 384 - { 385 - app.Logger.LogInformation("Verifying object storage configuration..."); 386 - var svc = provider.GetRequiredService<ObjectStorageService>(); 403 + 404 + var tempPath = Environment.GetEnvironmentVariable("ASPNETCORE_TEMP") ?? Path.GetTempPath(); 387 405 try 388 406 { 389 - await svc.VerifyCredentialsAsync(); 407 + await using var stream = File.Create(Path.Combine(tempPath, ".iceshrimp-test"), 1, FileOptions.DeleteOnClose); 390 408 } 391 - catch (Exception e) 409 + catch 392 410 { 393 - app.Logger.LogCritical("Failed to initialize object storage: {message}", e.Message); 411 + app.Logger.LogCritical("Temporary directory {dir} is not writable. Please adjust permissions or set the ASPNETCORE_TEMP environment variable to a writable directory.", 412 + tempPath); 394 413 Environment.Exit(1); 395 414 } 396 - } 397 415 398 - var tempPath = Environment.GetEnvironmentVariable("ASPNETCORE_TEMP") ?? Path.GetTempPath(); 399 - try 400 - { 401 - await using var stream = File.Create(Path.Combine(tempPath, ".iceshrimp-test"), 1, FileOptions.DeleteOnClose); 402 - } 403 - catch 404 - { 405 - app.Logger.LogCritical("Temporary directory {dir} is not writable. Please adjust permissions or set the ASPNETCORE_TEMP environment variable to a writable directory.", 406 - tempPath); 407 - Environment.Exit(1); 416 + app.Logger.LogInformation("Initializing VAPID keys..."); 417 + var meta = provider.GetRequiredService<MetaService>(); 418 + await meta.EnsureSetAsync([MetaEntity.VapidPublicKey, MetaEntity.VapidPrivateKey], () => 419 + { 420 + var keypair = VapidHelper.GenerateVapidKeys(); 421 + return [keypair.PublicKey, keypair.PrivateKey]; 422 + }); 423 + 424 + app.Logger.LogInformation("Warming up meta cache..."); 425 + await meta.WarmupCacheAsync(); 426 + 427 + // Initialize image processing 428 + provider.GetRequiredService<ImageProcessor>(); 429 + 430 + return instanceConfig; 408 431 } 409 432 410 - app.Logger.LogInformation("Initializing VAPID keys..."); 411 - var meta = provider.GetRequiredService<MetaService>(); 412 - await meta.EnsureSetAsync([MetaEntity.VapidPublicKey, MetaEntity.VapidPrivateKey], () => 433 + public void SetKestrelUnixSocketPermissions() 413 434 { 414 - var keypair = VapidHelper.GenerateVapidKeys(); 415 - return [keypair.PublicKey, keypair.PrivateKey]; 416 - }); 435 + var config = app.Configuration.GetSection("Instance").Get<Config.InstanceSection>() 436 + ?? throw new Exception("Failed to read instance config"); 437 + if (config.ListenSocket == null) return; 438 + using var scope = app.Services.CreateScope(); 439 + var logger = scope.ServiceProvider.GetRequiredService<ILoggerFactory>() 440 + .CreateLogger("Microsoft.Hosting.Lifetime"); 417 441 418 - app.Logger.LogInformation("Warming up meta cache..."); 419 - await meta.WarmupCacheAsync(); 420 - 421 - // Initialize image processing 422 - provider.GetRequiredService<ImageProcessor>(); 423 - 424 - return instanceConfig; 425 - } 442 + if (!OperatingSystem.IsLinux() && !OperatingSystem.IsMacOS() && !OperatingSystem.IsFreeBSD()) 443 + throw new Exception("Can't set unix socket permissions on a non-UNIX system"); 426 444 427 - public static void SetKestrelUnixSocketPermissions(this WebApplication app) 428 - { 429 - var config = app.Configuration.GetSection("Instance").Get<Config.InstanceSection>() 430 - ?? throw new Exception("Failed to read instance config"); 431 - if (config.ListenSocket == null) return; 432 - using var scope = app.Services.CreateScope(); 433 - var logger = scope.ServiceProvider.GetRequiredService<ILoggerFactory>() 434 - .CreateLogger("Microsoft.Hosting.Lifetime"); 445 + int perms; 446 + try 447 + { 448 + perms = Convert.ToInt32(config.ListenSocketPerms, 8); 449 + } 450 + catch 451 + { 452 + logger.LogError("Failed to set Kestrel unix socket permissions to {SocketPerms}: failed to parse octal digits", 453 + config.ListenSocketPerms); 454 + Environment.Exit(1); 455 + return; 456 + } 435 457 436 - if (!OperatingSystem.IsLinux() && !OperatingSystem.IsMacOS() && !OperatingSystem.IsFreeBSD()) 437 - throw new Exception("Can't set unix socket permissions on a non-UNIX system"); 458 + var exitCode = chmod(config.ListenSocket, perms); 459 + if (exitCode < 0) 460 + { 461 + logger.LogError("Failed to set Kestrel unix socket permissions to {SocketPerms}, return code: {ExitCode}", 462 + config.ListenSocketPerms, exitCode); 463 + } 464 + else 465 + { 466 + logger.LogInformation("Kestrel unix socket permissions were set to {SocketPerms}", 467 + config.ListenSocketPerms); 468 + } 438 469 439 - int perms; 440 - try 441 - { 442 - perms = Convert.ToInt32(config.ListenSocketPerms, 8); 443 - } 444 - catch 445 - { 446 - logger.LogError("Failed to set Kestrel unix socket permissions to {SocketPerms}: failed to parse octal digits", 447 - config.ListenSocketPerms); 448 - Environment.Exit(1); 449 470 return; 450 - } 451 471 452 - var exitCode = chmod(config.ListenSocket, perms); 453 - if (exitCode < 0) 454 - { 455 - logger.LogError("Failed to set Kestrel unix socket permissions to {SocketPerms}, return code: {ExitCode}", 456 - config.ListenSocketPerms, exitCode); 472 + [DllImport("libc")] 473 + static extern int chmod(string pathname, int mode); 457 474 } 458 - else 459 - { 460 - logger.LogInformation("Kestrel unix socket permissions were set to {SocketPerms}", 461 - config.ListenSocketPerms); 462 - } 463 - 464 - return; 465 - 466 - [DllImport("libc")] 467 - static extern int chmod(string pathname, int mode); 468 475 } 469 476 } 470 477
+17 -14
Iceshrimp.Backend/Core/Federation/ActivityStreams/LdHelpers.cs
··· 102 102 return result; 103 103 } 104 104 105 - public static async Task<string> SignAndCompactAsync(this ASActivity activity, UserKeypair keypair) 105 + extension(ASActivity activity) 106 106 { 107 - var expanded = Expand(activity) ?? throw new Exception("Failed to expand activity"); 108 - var signed = await LdSignature.SignAsync(expanded, keypair.PrivateKey, 109 - activity.Actor?.PublicKey?.Id ?? $"{activity.Actor!.Id}#main-key") ?? 110 - throw new Exception("Failed to sign activity"); 111 - var compacted = Compact(signed) ?? throw new Exception("Failed to compact signed activity"); 112 - var payload = JsonConvert.SerializeObject(compacted, JsonSerializerSettings); 107 + public async Task<string> SignAndCompactAsync(UserKeypair keypair) 108 + { 109 + var expanded = Expand(activity) ?? throw new Exception("Failed to expand activity"); 110 + var signed = await LdSignature.SignAsync(expanded, keypair.PrivateKey, 111 + activity.Actor?.PublicKey?.Id ?? $"{activity.Actor!.Id}#main-key") ?? 112 + throw new Exception("Failed to sign activity"); 113 + var compacted = Compact(signed) ?? throw new Exception("Failed to compact signed activity"); 114 + var payload = JsonConvert.SerializeObject(compacted, JsonSerializerSettings); 113 115 114 - return payload; 115 - } 116 + return payload; 117 + } 116 118 117 - public static string CompactToPayload(this ASActivity activity) 118 - { 119 - var compacted = Compact(activity) ?? throw new Exception("Failed to compact signed activity"); 120 - var payload = JsonConvert.SerializeObject(compacted, JsonSerializerSettings); 119 + public string CompactToPayload() 120 + { 121 + var compacted = Compact(activity) ?? throw new Exception("Failed to compact signed activity"); 122 + var payload = JsonConvert.SerializeObject(compacted, JsonSerializerSettings); 121 123 122 - return payload; 124 + return payload; 125 + } 123 126 } 124 127 125 128 public static JObject Compact(this ASObject obj)
+43 -40
Iceshrimp.Backend/Core/Middleware/AuthenticationMiddleware.cs
··· 159 159 private const string MastodonKey = "masto-session"; 160 160 private const string HideFooterKey = "hide-login-footer"; 161 161 162 - internal static void SetSession(this HttpContext ctx, Session session) 162 + extension(HttpContext ctx) 163 163 { 164 - ctx.Items.Add(Key, session); 165 - } 164 + internal void SetSession(Session session) 165 + { 166 + ctx.Items.Add(Key, session); 167 + } 166 168 167 - public static Session? GetSession(this HttpContext ctx) 168 - { 169 - ctx.Items.TryGetValue(Key, out var session); 170 - return session as Session; 171 - } 169 + public Session? GetSession() 170 + { 171 + ctx.Items.TryGetValue(Key, out var session); 172 + return session as Session; 173 + } 172 174 173 - public static Session GetSessionOrFail(this HttpContext ctx) 174 - { 175 - return ctx.GetSession() ?? throw new Exception("Failed to get session from HttpContext"); 176 - } 175 + public Session GetSessionOrFail() 176 + { 177 + return ctx.GetSession() ?? throw new Exception("Failed to get session from HttpContext"); 178 + } 179 + 180 + internal void SetOauthToken(OauthToken session) 181 + { 182 + ctx.Items.Add(MastodonKey, session); 183 + } 177 184 178 - internal static void SetOauthToken(this HttpContext ctx, OauthToken session) 179 - { 180 - ctx.Items.Add(MastodonKey, session); 181 - } 185 + public OauthToken? GetOauthToken() 186 + { 187 + ctx.Items.TryGetValue(MastodonKey, out var session); 188 + return session as OauthToken; 189 + } 182 190 183 - public static OauthToken? GetOauthToken(this HttpContext ctx) 184 - { 185 - ctx.Items.TryGetValue(MastodonKey, out var session); 186 - return session as OauthToken; 187 - } 191 + //TODO: Is it faster to check for the MastodonApiControllerAttribute here? 192 + public User? GetUser() 193 + { 194 + if (ctx.Items.TryGetValue(Key, out var session)) 195 + return (session as Session)?.User; 196 + return ctx.Items.TryGetValue(MastodonKey, out var token) 197 + ? (token as OauthToken)?.User 198 + : null; 199 + } 188 200 189 - //TODO: Is it faster to check for the MastodonApiControllerAttribute here? 190 - public static User? GetUser(this HttpContext ctx) 191 - { 192 - if (ctx.Items.TryGetValue(Key, out var session)) 193 - return (session as Session)?.User; 194 - return ctx.Items.TryGetValue(MastodonKey, out var token) 195 - ? (token as OauthToken)?.User 196 - : null; 197 - } 201 + public User GetUserOrFail() 202 + { 203 + return ctx.GetUser() ?? throw new Exception("Failed to get user from HttpContext"); 204 + } 198 205 199 - public static User GetUserOrFail(this HttpContext ctx) 200 - { 201 - return ctx.GetUser() ?? throw new Exception("Failed to get user from HttpContext"); 202 - } 206 + public bool ShouldHideFooter() 207 + { 208 + ctx.Items.TryGetValue(HideFooterKey, out var auth); 209 + return auth is true; 210 + } 203 211 204 - public static bool ShouldHideFooter(this HttpContext ctx) 205 - { 206 - ctx.Items.TryGetValue(HideFooterKey, out var auth); 207 - return auth is true; 212 + public void HideFooter() => ctx.Items.Add(HideFooterKey, true); 208 213 } 209 - 210 - public static void HideFooter(this HttpContext ctx) => ctx.Items.Add(HideFooterKey, true); 211 214 }
+10 -7
Iceshrimp.Backend/Core/Middleware/AuthorizedFetchMiddleware.cs
··· 171 171 { 172 172 private const string ActorKey = "auth-fetch-user"; 173 173 174 - internal static void SetActor(this HttpContext ctx, User actor) 174 + extension(HttpContext ctx) 175 175 { 176 - ctx.Items.Add(ActorKey, actor); 177 - } 176 + internal void SetActor(User actor) 177 + { 178 + ctx.Items.Add(ActorKey, actor); 179 + } 178 180 179 - public static User? GetActor(this HttpContext ctx) 180 - { 181 - ctx.Items.TryGetValue(ActorKey, out var actor); 182 - return actor as User; 181 + public User? GetActor() 182 + { 183 + ctx.Items.TryGetValue(ActorKey, out var actor); 184 + return actor as User; 185 + } 183 186 } 184 187 }
+26 -23
Iceshrimp.Backend/Core/Queues/PreDeliverQueue.cs
··· 163 163 164 164 file static class QueryableExtensions 165 165 { 166 - public static IQueryable<InboxQueryResult> SkipDeadInstances( 167 - this IQueryable<InboxQueryResult> query, ASActivity activity, DatabaseContext db 168 - ) 166 + extension(IQueryable<InboxQueryResult> query) 169 167 { 170 - return activity is ASFollow 171 - ? query.Where(user => !db.Instances.Any(p => p.Host == user.Host && p.IsSuspended)) 172 - : query.Where(user => !db.Instances.Any(p => p.Host == user.Host && 173 - ((p.IsNotResponding && 174 - p.LastCommunicatedAt < 175 - DateTime.UtcNow - TimeSpan.FromDays(7)) || 176 - p.IsSuspended))); 177 - } 168 + public IQueryable<InboxQueryResult> SkipDeadInstances( 169 + ASActivity activity, DatabaseContext db 170 + ) 171 + { 172 + return activity is ASFollow 173 + ? query.Where(user => !db.Instances.Any(p => p.Host == user.Host && p.IsSuspended)) 174 + : query.Where(user => !db.Instances.Any(p => p.Host == user.Host && 175 + ((p.IsNotResponding && 176 + p.LastCommunicatedAt < 177 + DateTime.UtcNow - TimeSpan.FromDays(7)) || 178 + p.IsSuspended))); 179 + } 178 180 179 - public static IQueryable<InboxQueryResult> SkipBlockedInstances( 180 - this IQueryable<InboxQueryResult> query, Enums.FederationMode mode, DatabaseContext db 181 - ) 182 - { 183 - // @formatter:off 184 - Expression<Func<InboxQueryResult, bool>> expr = mode switch 181 + public IQueryable<InboxQueryResult> SkipBlockedInstances( 182 + Enums.FederationMode mode, DatabaseContext db 183 + ) 185 184 { 186 - Enums.FederationMode.BlockList => u => u.Host == null || !db.BlockedInstances.Any(p => u.Host == p.Host || u.Host.EndsWith("." + p.Host)), 187 - Enums.FederationMode.AllowList => u => u.Host == null || db.AllowedInstances.Any(p => u.Host == p.Host || u.Host.EndsWith("." + p.Host)), 185 + // @formatter:off 186 + Expression<Func<InboxQueryResult, bool>> expr = mode switch 187 + { 188 + Enums.FederationMode.BlockList => u => u.Host == null || !db.BlockedInstances.Any(p => u.Host == p.Host || u.Host.EndsWith("." + p.Host)), 189 + Enums.FederationMode.AllowList => u => u.Host == null || db.AllowedInstances.Any(p => u.Host == p.Host || u.Host.EndsWith("." + p.Host)), 188 190 189 - _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, null) 190 - }; 191 - // @formatter:on 191 + _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, null) 192 + }; 193 + // @formatter:on 192 194 193 - return query.Where(expr); 195 + return query.Where(expr); 196 + } 194 197 } 195 198 } 196 199
+21 -19
Iceshrimp.Frontend/Core/InMemoryLogger/InMemoryLoggerExtension.cs
··· 5 5 6 6 internal static class InMemoryLoggerExtension 7 7 { 8 - public static void AddInMemoryLogger( 9 - this ILoggingBuilder builder, IConfiguration configuration 10 - ) 8 + extension(ILoggingBuilder builder) 11 9 { 12 - builder.AddConfiguration(); 13 - builder.Services.AddOptionsWithValidateOnStart<InMemoryLoggerConfiguration>() 14 - .Bind(configuration.GetSection("InMemoryLogger")); 15 - LoggerProviderOptions 16 - .RegisterProviderOptions<InMemoryLoggerConfiguration, InMemoryLoggerProvider>(builder.Services); 17 - builder.Services.TryAddSingleton<InMemoryLogService>(); 18 - builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<ILoggerProvider, InMemoryLoggerProvider>()); 19 - } 10 + public void AddInMemoryLogger( 11 + IConfiguration configuration 12 + ) 13 + { 14 + builder.AddConfiguration(); 15 + builder.Services.AddOptionsWithValidateOnStart<InMemoryLoggerConfiguration>() 16 + .Bind(configuration.GetSection("InMemoryLogger")); 17 + LoggerProviderOptions 18 + .RegisterProviderOptions<InMemoryLoggerConfiguration, InMemoryLoggerProvider>(builder.Services); 19 + builder.Services.TryAddSingleton<InMemoryLogService>(); 20 + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<ILoggerProvider, InMemoryLoggerProvider>()); 21 + } 20 22 21 - public static void AddInMemoryLogger( 22 - this ILoggingBuilder builder, 23 - Action<InMemoryLoggerConfiguration> configure 24 - ) 25 - { 26 - builder.Services.TryAddSingleton<InMemoryLogService>(); 27 - builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<ILoggerProvider, InMemoryLoggerProvider>()); 28 - builder.Services.Configure(configure); 23 + public void AddInMemoryLogger( 24 + Action<InMemoryLoggerConfiguration> configure 25 + ) 26 + { 27 + builder.Services.TryAddSingleton<InMemoryLogService>(); 28 + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton<ILoggerProvider, InMemoryLoggerProvider>()); 29 + builder.Services.Configure(configure); 30 + } 29 31 } 30 32 }