From c08503d2f3e16efaa4c29129ad0f00e0b0eb8220 Mon Sep 17 00:00:00 2001 From: LittleSheep Date: Sun, 9 Nov 2025 01:46:24 +0800 Subject: [PATCH] :recycle: Refactored files service --- .../Storage/CloudFileUnusedRecyclingJob.cs | 21 +- DysonNetwork.Drive/Storage/FileController.cs | 285 +++++++++++------- .../Storage/FileExpirationJob.cs | 68 +++-- .../Storage/FileReferenceService.cs | 64 +++- DysonNetwork.Drive/Storage/FileService.cs | 136 ++++++--- .../Storage/FileUploadController.cs | 265 +++++++++------- 6 files changed, 528 insertions(+), 311 deletions(-) diff --git a/DysonNetwork.Drive/Storage/CloudFileUnusedRecyclingJob.cs b/DysonNetwork.Drive/Storage/CloudFileUnusedRecyclingJob.cs index f2ae3a9..e3af918 100644 --- a/DysonNetwork.Drive/Storage/CloudFileUnusedRecyclingJob.cs +++ b/DysonNetwork.Drive/Storage/CloudFileUnusedRecyclingJob.cs @@ -6,7 +6,6 @@ namespace DysonNetwork.Drive.Storage; public class CloudFileUnusedRecyclingJob( AppDatabase db, - FileReferenceService fileRefService, ILogger logger, IConfiguration configuration ) @@ -80,15 +79,15 @@ public class CloudFileUnusedRecyclingJob( processedCount += fileBatch.Count; lastProcessedId = fileBatch.Last(); - // Get all relevant file references for this batch - var fileReferences = await fileRefService.GetReferencesAsync(fileBatch); - - // Filter to find files that have no references or all expired references - var filesToMark = fileBatch.Where(fileId => - !fileReferences.TryGetValue(fileId, out var references) || - references.Count == 0 || - references.All(r => r.ExpiredAt.HasValue && r.ExpiredAt.Value <= now) - ).ToList(); + // Optimized query: Find files that have no references OR all references are expired + // This replaces the memory-intensive approach of loading all references + var filesToMark = await db.Files + .Where(f => fileBatch.Contains(f.Id)) + .Where(f => !db.FileReferences.Any(r => r.FileId == f.Id) || // No references at all + !db.FileReferences.Any(r => r.FileId == f.Id && // OR has references but all are expired + (r.ExpiredAt == null || r.ExpiredAt > now))) + .Select(f => f.Id) + .ToListAsync(); if (filesToMark.Count > 0) { @@ -120,4 +119,4 @@ public class CloudFileUnusedRecyclingJob( logger.LogInformation("Completed marking {MarkedCount} files for recycling", markedCount); } -} \ No newline at end of file +} diff --git a/DysonNetwork.Drive/Storage/FileController.cs b/DysonNetwork.Drive/Storage/FileController.cs index ac43f4a..1a9097d 100644 --- a/DysonNetwork.Drive/Storage/FileController.cs +++ b/DysonNetwork.Drive/Storage/FileController.cs @@ -29,137 +29,200 @@ public class FileController( [FromQuery] string? passcode = null ) { - // Support the file extension for client side data recognize - string? fileExtension = null; - if (id.Contains('.')) - { - var splitId = id.Split('.'); - id = splitId.First(); - fileExtension = splitId.Last(); - } - - var file = await fs.GetFileAsync(id); + var (fileId, fileExtension) = ParseFileId(id); + var file = await fs.GetFileAsync(fileId); if (file is null) return NotFound("File not found."); + var accessResult = await ValidateFileAccess(file, passcode); + if (accessResult is not null) return accessResult; + + // Handle direct storage URL redirect + if (!string.IsNullOrWhiteSpace(file.StorageUrl)) + return Redirect(file.StorageUrl); + + // Handle files not yet uploaded to remote storage + if (file.UploadedAt is null) + return await ServeLocalFile(file); + + // Handle uploaded files + return await ServeRemoteFile(file, fileExtension, download, original, thumbnail, overrideMimeType); + } + + private (string fileId, string? extension) ParseFileId(string id) + { + if (!id.Contains('.')) return (id, null); + + var parts = id.Split('.'); + return (parts.First(), parts.Last()); + } + + private async Task ValidateFileAccess(SnCloudFile file, string? passcode) + { if (file.Bundle is not null && !file.Bundle.VerifyPasscode(passcode)) return StatusCode(StatusCodes.Status403Forbidden, "The passcode is incorrect."); + return null; + } - if (!string.IsNullOrWhiteSpace(file.StorageUrl)) return Redirect(file.StorageUrl); - - if (file.UploadedAt is null) + private async Task ServeLocalFile(SnCloudFile file) + { + // Try temp storage first + var tempFilePath = Path.Combine(Path.GetTempPath(), file.Id); + if (System.IO.File.Exists(tempFilePath)) { - // File is not yet uploaded to remote storage. Try to serve from local temp storage. - var tempFilePath = Path.Combine(Path.GetTempPath(), file.Id); - if (System.IO.File.Exists(tempFilePath)) - { - if (file.IsEncrypted) - { - return StatusCode(StatusCodes.Status403Forbidden, "Encrypted files cannot be accessed before they are processed and stored."); - } - return PhysicalFile(tempFilePath, file.MimeType ?? "application/octet-stream", file.Name, enableRangeProcessing: true); - } - - // Fallback for tus uploads that are not processed yet. - var tusStorePath = configuration.GetValue("Tus:StorePath"); - if (!string.IsNullOrEmpty(tusStorePath)) - { - var tusFilePath = Path.Combine(env.ContentRootPath, tusStorePath, file.Id); - if (System.IO.File.Exists(tusFilePath)) - { - return PhysicalFile(tusFilePath, file.MimeType ?? "application/octet-stream", file.Name, enableRangeProcessing: true); - } - } + if (file.IsEncrypted) + return StatusCode(StatusCodes.Status403Forbidden, "Encrypted files cannot be accessed before they are processed and stored."); - return StatusCode(StatusCodes.Status400BadRequest, "File is being processed. Please try again later."); + return PhysicalFile(tempFilePath, file.MimeType ?? "application/octet-stream", file.Name, enableRangeProcessing: true); } + // Fallback for tus uploads + var tusStorePath = configuration.GetValue("Tus:StorePath"); + if (!string.IsNullOrEmpty(tusStorePath)) + { + var tusFilePath = Path.Combine(env.ContentRootPath, tusStorePath, file.Id); + if (System.IO.File.Exists(tusFilePath)) + { + return PhysicalFile(tusFilePath, file.MimeType ?? "application/octet-stream", file.Name, enableRangeProcessing: true); + } + } + + return StatusCode(StatusCodes.Status400BadRequest, "File is being processed. Please try again later."); + } + + private async Task ServeRemoteFile( + SnCloudFile file, + string? fileExtension, + bool download, + bool original, + bool thumbnail, + string? overrideMimeType + ) + { if (!file.PoolId.HasValue) return StatusCode(StatusCodes.Status500InternalServerError, "File is in an inconsistent state: uploaded but no pool ID."); var pool = await fs.GetPoolAsync(file.PoolId.Value); if (pool is null) return StatusCode(StatusCodes.Status410Gone, "The pool of the file no longer exists or not accessible."); + + if (!pool.PolicyConfig.AllowAnonymous && HttpContext.Items["CurrentUser"] is not Account) + return Unauthorized(); + var dest = pool.StorageConfig; + var fileName = BuildRemoteFileName(file, original, thumbnail); - if (!pool.PolicyConfig.AllowAnonymous) - if (HttpContext.Items["CurrentUser"] is not Account currentUser) - return Unauthorized(); - // TODO: Provide ability to add access log + // Try proxy redirects first + var proxyResult = TryProxyRedirect(file, dest, fileName); + if (proxyResult is not null) return proxyResult; + // Handle signed URLs + if (dest.EnableSigned) + return await CreateSignedUrl(file, dest, fileName, fileExtension, download, overrideMimeType); + + // Fallback to direct S3 endpoint + var protocol = dest.EnableSsl ? "https" : "http"; + return Redirect($"{protocol}://{dest.Endpoint}/{dest.Bucket}/{fileName}"); + } + + private string BuildRemoteFileName(SnCloudFile file, bool original, bool thumbnail) + { var fileName = string.IsNullOrWhiteSpace(file.StorageId) ? file.Id : file.StorageId; - switch (thumbnail) + if (thumbnail) { - case true when file.HasThumbnail: - fileName += ".thumbnail"; - break; - case true when !file.HasThumbnail: - return NotFound(); + if (!file.HasThumbnail) throw new InvalidOperationException("Thumbnail not available"); + fileName += ".thumbnail"; + } + else if (!original && file.HasCompression) + { + fileName += ".compressed"; } - if (!original && file.HasCompression) - fileName += ".compressed"; + return fileName; + } + private ActionResult? TryProxyRedirect(SnCloudFile file, RemoteStorageConfig dest, string fileName) + { if (dest.ImageProxy is not null && (file.MimeType?.StartsWith("image/") ?? false)) { - var proxyUrl = dest.ImageProxy; - var baseUri = new Uri(proxyUrl.EndsWith('/') ? proxyUrl : $"{proxyUrl}/"); - var fullUri = new Uri(baseUri, fileName); - return Redirect(fullUri.ToString()); + return Redirect(BuildProxyUrl(dest.ImageProxy, fileName)); } if (dest.AccessProxy is not null) { - var proxyUrl = dest.AccessProxy; - var baseUri = new Uri(proxyUrl.EndsWith('/') ? proxyUrl : $"{proxyUrl}/"); - var fullUri = new Uri(baseUri, fileName); - return Redirect(fullUri.ToString()); + return Redirect(BuildProxyUrl(dest.AccessProxy, fileName)); } - if (dest.EnableSigned) + return null; + } + + private string BuildProxyUrl(string proxyUrl, string fileName) + { + var baseUri = new Uri(proxyUrl.EndsWith('/') ? proxyUrl : $"{proxyUrl}/"); + var fullUri = new Uri(baseUri, fileName); + return fullUri.ToString(); + } + + private async Task CreateSignedUrl( + SnCloudFile file, + RemoteStorageConfig dest, + string fileName, + string? fileExtension, + bool download, + string? overrideMimeType + ) + { + var client = fs.CreateMinioClient(dest); + if (client is null) + return BadRequest("Failed to configure client for remote destination, file got an invalid storage remote."); + + var headers = BuildSignedUrlHeaders(file, fileExtension, overrideMimeType, download); + + var openUrl = await client.PresignedGetObjectAsync( + new PresignedGetObjectArgs() + .WithBucket(dest.Bucket) + .WithObject(fileName) + .WithExpiry(3600) + .WithHeaders(headers) + ); + + return Redirect(openUrl); + } + + private Dictionary BuildSignedUrlHeaders( + SnCloudFile file, + string? fileExtension, + string? overrideMimeType, + bool download + ) + { + var headers = new Dictionary(); + + string? contentType = null; + if (fileExtension is not null && MimeTypes.TryGetMimeType(fileExtension, out var mimeType)) { - var client = fs.CreateMinioClient(dest); - if (client is null) - return BadRequest( - "Failed to configure client for remote destination, file got an invalid storage remote." - ); - - var headers = new Dictionary(); - if (fileExtension is not null) - { - if (MimeTypes.TryGetMimeType(fileExtension, out var mimeType)) - headers.Add("Response-Content-Type", mimeType); - } - else if (overrideMimeType is not null) - { - headers.Add("Response-Content-Type", overrideMimeType); - } - else if (file.MimeType is not null && !file.MimeType!.EndsWith("unknown")) - { - headers.Add("Response-Content-Type", file.MimeType); - } - - if (download) - { - headers.Add("Response-Content-Disposition", $"attachment; filename=\"{file.Name}\""); - } - - var bucket = dest.Bucket; - var openUrl = await client.PresignedGetObjectAsync( - new PresignedGetObjectArgs() - .WithBucket(bucket) - .WithObject(fileName) - .WithExpiry(3600) - .WithHeaders(headers) - ); - - return Redirect(openUrl); + contentType = mimeType; + } + else if (overrideMimeType is not null) + { + contentType = overrideMimeType; + } + else if (file.MimeType is not null && !file.MimeType.EndsWith("unknown")) + { + contentType = file.MimeType; } - // Fallback redirect to the S3 endpoint (public read) - var protocol = dest.EnableSsl ? "https" : "http"; - // Use the path bucket lookup mode - return Redirect($"{protocol}://{dest.Endpoint}/{dest.Bucket}/{fileName}"); + if (contentType is not null) + { + headers.Add("Response-Content-Type", contentType); + } + + if (download) + { + headers.Add("Response-Content-Disposition", $"attachment; filename=\"{file.Name}\""); + } + + return headers; } [HttpGet("{id}/info")] @@ -175,14 +238,7 @@ public class FileController( [HttpPatch("{id}/name")] public async Task> UpdateFileName(string id, [FromBody] string name) { - if (HttpContext.Items["CurrentUser"] is not Account currentUser) return Unauthorized(); - var accountId = Guid.Parse(currentUser.Id); - var file = await db.Files.FirstOrDefaultAsync(f => f.Id == id && f.AccountId == accountId); - if (file is null) return NotFound(); - file.Name = name; - await db.SaveChangesAsync(); - await fs._PurgeCacheAsync(file.Id); - return file; + return await UpdateFileProperty(id, file => file.Name = name); } public class MarkFileRequest @@ -194,27 +250,28 @@ public class FileController( [HttpPut("{id}/marks")] public async Task> MarkFile(string id, [FromBody] MarkFileRequest request) { - if (HttpContext.Items["CurrentUser"] is not Account currentUser) return Unauthorized(); - var accountId = Guid.Parse(currentUser.Id); - var file = await db.Files.FirstOrDefaultAsync(f => f.Id == id && f.AccountId == accountId); - if (file is null) return NotFound(); - file.SensitiveMarks = request.SensitiveMarks; - await db.SaveChangesAsync(); - await fs._PurgeCacheAsync(file.Id); - return file; + return await UpdateFileProperty(id, file => file.SensitiveMarks = request.SensitiveMarks); } [Authorize] [HttpPut("{id}/meta")] public async Task> UpdateFileMeta(string id, [FromBody] Dictionary meta) + { + return await UpdateFileProperty(id, file => file.UserMeta = meta); + } + + private async Task> UpdateFileProperty(string fileId, Action updateAction) { if (HttpContext.Items["CurrentUser"] is not Account currentUser) return Unauthorized(); var accountId = Guid.Parse(currentUser.Id); - var file = await db.Files.FirstOrDefaultAsync(f => f.Id == id && f.AccountId == accountId); + + var file = await db.Files.FirstOrDefaultAsync(f => f.Id == fileId && f.AccountId == accountId); if (file is null) return NotFound(); - file.UserMeta = meta; + + updateAction(file); await db.SaveChangesAsync(); await fs._PurgeCacheAsync(file.Id); + return file; } diff --git a/DysonNetwork.Drive/Storage/FileExpirationJob.cs b/DysonNetwork.Drive/Storage/FileExpirationJob.cs index f53aedb..53aa2e2 100644 --- a/DysonNetwork.Drive/Storage/FileExpirationJob.cs +++ b/DysonNetwork.Drive/Storage/FileExpirationJob.cs @@ -10,53 +10,59 @@ namespace DysonNetwork.Drive.Storage; public class FileExpirationJob(AppDatabase db, FileService fileService, ILogger logger) : IJob { public async Task Execute(IJobExecutionContext context) - { + { var now = SystemClock.Instance.GetCurrentInstant(); logger.LogInformation("Running file reference expiration job at {now}", now); - // Find all expired references - var expiredReferences = await db.FileReferences + // Delete expired references in bulk and get affected file IDs + var affectedFileIds = await db.FileReferences .Where(r => r.ExpiredAt < now && r.ExpiredAt != null) + .Select(r => r.FileId) + .Distinct() .ToListAsync(); - if (!expiredReferences.Any()) + if (!affectedFileIds.Any()) { logger.LogInformation("No expired file references found"); return; } - logger.LogInformation("Found {count} expired file references", expiredReferences.Count); + logger.LogInformation("Found expired references for {count} files", affectedFileIds.Count); - // Get unique file IDs - var fileIds = expiredReferences.Select(r => r.FileId).Distinct().ToList(); - var filesAndReferenceCount = new Dictionary(); + // Delete expired references in bulk + var deletedReferencesCount = await db.FileReferences + .Where(r => r.ExpiredAt < now && r.ExpiredAt != null) + .ExecuteDeleteAsync(); - // Delete expired references - db.FileReferences.RemoveRange(expiredReferences); - await db.SaveChangesAsync(); + logger.LogInformation("Deleted {count} expired file references", deletedReferencesCount); - // Check remaining references for each file - foreach (var fileId in fileIds) - { - var remainingReferences = await db.FileReferences - .Where(r => r.FileId == fileId) - .CountAsync(); + // Find files that now have no remaining references (bulk operation) + var filesToDelete = await db.Files + .Where(f => affectedFileIds.Contains(f.Id)) + .Where(f => !db.FileReferences.Any(r => r.FileId == f.Id)) + .Select(f => f.Id) + .ToListAsync(); - filesAndReferenceCount[fileId] = remainingReferences; + if (filesToDelete.Any()) + { + logger.LogInformation("Deleting {count} files that have no remaining references", filesToDelete.Count); - // If no references remain, delete the file - if (remainingReferences == 0) - { - var file = await db.Files.FirstOrDefaultAsync(f => f.Id == fileId); - if (file == null) continue; - logger.LogInformation("Deleting file {fileId} as all references have expired", fileId); - await fileService.DeleteFileAsync(file); - } - else - { - // Just purge the cache - await fileService._PurgeCacheAsync(fileId); - } + // Get files for deletion + var files = await db.Files + .Where(f => filesToDelete.Contains(f.Id)) + .ToListAsync(); + + // Delete files and their data in parallel + var deleteTasks = files.Select(f => fileService.DeleteFileAsync(f)); + await Task.WhenAll(deleteTasks); + } + + // Purge cache for files that still have references + var filesWithRemainingRefs = affectedFileIds.Except(filesToDelete).ToList(); + if (filesWithRemainingRefs.Any()) + { + var cachePurgeTasks = filesWithRemainingRefs.Select(fileService._PurgeCacheAsync); + await Task.WhenAll(cachePurgeTasks); } logger.LogInformation("Completed file reference expiration job"); diff --git a/DysonNetwork.Drive/Storage/FileReferenceService.cs b/DysonNetwork.Drive/Storage/FileReferenceService.cs index f171d78..258ed65 100644 --- a/DysonNetwork.Drive/Storage/FileReferenceService.cs +++ b/DysonNetwork.Drive/Storage/FileReferenceService.cs @@ -90,13 +90,45 @@ public class FileReferenceService(AppDatabase db, FileService fileService, ICach return references; } - public async Task>> GetReferencesAsync(IEnumerable fileId) + public async Task>> GetReferencesAsync(IEnumerable fileIds) { - var references = await db.FileReferences - .Where(r => fileId.Contains(r.FileId)) - .GroupBy(r => r.FileId) - .ToDictionaryAsync(r => r.Key, r => r.ToList()); - return references; + var fileIdList = fileIds.ToList(); + var result = new Dictionary>(); + + // Check cache for each file ID + var uncachedFileIds = new List(); + foreach (var fileId in fileIdList) + { + var cacheKey = $"{CacheKeyPrefix}list:{fileId}"; + var cachedReferences = await cache.GetAsync>(cacheKey); + if (cachedReferences is not null) + { + result[fileId] = cachedReferences; + } + else + { + uncachedFileIds.Add(fileId); + } + } + + // Fetch uncached references from database + if (uncachedFileIds.Any()) + { + var dbReferences = await db.FileReferences + .Where(r => uncachedFileIds.Contains(r.FileId)) + .GroupBy(r => r.FileId) + .ToDictionaryAsync(r => r.Key, r => r.ToList()); + + // Cache the results + foreach (var kvp in dbReferences) + { + var cacheKey = $"{CacheKeyPrefix}list:{kvp.Key}"; + await cache.SetAsync(cacheKey, kvp.Value, CacheDuration); + result[kvp.Key] = kvp.Value; + } + } + + return result; } /// @@ -150,9 +182,19 @@ public class FileReferenceService(AppDatabase db, FileService fileService, ICach /// A list of file references with the specified usage public async Task> GetUsageReferencesAsync(string usage) { - return await db.FileReferences + var cacheKey = $"{CacheKeyPrefix}usage:{usage}"; + + var cachedReferences = await cache.GetAsync>(cacheKey); + if (cachedReferences is not null) + return cachedReferences; + + var references = await db.FileReferences .Where(r => r.Usage == usage) .ToListAsync(); + + await cache.SetAsync(cacheKey, references, CacheDuration); + + return references; } /// @@ -209,8 +251,9 @@ public class FileReferenceService(AppDatabase db, FileService fileService, ICach public async Task DeleteResourceReferencesBatchAsync(IEnumerable resourceIds, string? usage = null) { + var resourceIdList = resourceIds.ToList(); var references = await db.FileReferences - .Where(r => resourceIds.Contains(r.ResourceId)) + .Where(r => resourceIdList.Contains(r.ResourceId)) .If(usage != null, q => q.Where(q => q.Usage == usage)) .ToListAsync(); @@ -222,8 +265,9 @@ public class FileReferenceService(AppDatabase db, FileService fileService, ICach db.FileReferences.RemoveRange(references); var deletedCount = await db.SaveChangesAsync(); - // Purge caches + // Purge caches for files and resources var tasks = fileIds.Select(fileService._PurgeCacheAsync).ToList(); + tasks.AddRange(resourceIdList.Select(PurgeCacheForResourceAsync)); await Task.WhenAll(tasks); return deletedCount; @@ -473,4 +517,4 @@ public class FileReferenceService(AppDatabase db, FileService fileService, ICach return await SetReferenceExpirationAsync(referenceId, expiredAt); } -} \ No newline at end of file +} diff --git a/DysonNetwork.Drive/Storage/FileService.cs b/DysonNetwork.Drive/Storage/FileService.cs index 901f782..583a866 100644 --- a/DysonNetwork.Drive/Storage/FileService.cs +++ b/DysonNetwork.Drive/Storage/FileService.cs @@ -99,30 +99,75 @@ public class FileService( ) { var accountId = Guid.Parse(account.Id); + var pool = await ValidateAndGetPoolAsync(filePool); + var bundle = await ValidateAndGetBundleAsync(fileBundleId, accountId); + var finalExpiredAt = CalculateFinalExpiration(expiredAt, pool, bundle); + var (managedTempPath, fileSize, finalContentType) = await PrepareFileAsync(fileId, filePath, fileName, contentType); + + var file = CreateFileObject(fileId, fileName, finalContentType, fileSize, finalExpiredAt, bundle, accountId); + + if (!pool.PolicyConfig.NoMetadata) + { + await ExtractMetadataAsync(file, managedTempPath); + } + + var (processingPath, isTempFile) = await ProcessEncryptionAsync(fileId, managedTempPath, encryptPassword, pool, file); + + file.Hash = await HashFileAsync(processingPath); + + await SaveFileToDatabaseAsync(file); + + await PublishFileUploadedEventAsync(file, pool, processingPath, isTempFile); + + return file; + } + + private async Task ValidateAndGetPoolAsync(string filePool) + { var pool = await GetPoolAsync(Guid.Parse(filePool)); if (pool is null) throw new InvalidOperationException("Pool not found"); + return pool; + } + private async Task ValidateAndGetBundleAsync(string? fileBundleId, Guid accountId) + { + if (fileBundleId is null) return null; + + var bundle = await GetBundleAsync(Guid.Parse(fileBundleId), accountId); + if (bundle is null) throw new InvalidOperationException("Bundle not found"); + + return bundle; + } + + private Instant? CalculateFinalExpiration(Instant? expiredAt, FilePool pool, SnFileBundle? bundle) + { + var finalExpiredAt = expiredAt; + + // Apply pool expiration policy if (pool.StorageConfig.Expiration is not null && expiredAt.HasValue) { var expectedExpiration = SystemClock.Instance.GetCurrentInstant() - expiredAt.Value; var effectiveExpiration = pool.StorageConfig.Expiration < expectedExpiration ? pool.StorageConfig.Expiration : expectedExpiration; - expiredAt = SystemClock.Instance.GetCurrentInstant() + effectiveExpiration; - } - - var bundle = fileBundleId is not null - ? await GetBundleAsync(Guid.Parse(fileBundleId), accountId) - : null; - if (fileBundleId is not null && bundle is null) - { - throw new InvalidOperationException("Bundle not found"); + finalExpiredAt = SystemClock.Instance.GetCurrentInstant() + effectiveExpiration; } + // Bundle expiration takes precedence if (bundle?.ExpiredAt != null) - expiredAt = bundle.ExpiredAt.Value; + finalExpiredAt = bundle.ExpiredAt.Value; + return finalExpiredAt; + } + + private async Task<(string tempPath, long fileSize, string contentType)> PrepareFileAsync( + string fileId, + string filePath, + string fileName, + string? contentType + ) + { var managedTempPath = Path.Combine(Path.GetTempPath(), fileId); File.Copy(filePath, managedTempPath, true); @@ -131,49 +176,66 @@ public class FileService( var finalContentType = contentType ?? (!fileName.Contains('.') ? "application/octet-stream" : MimeTypes.GetMimeType(fileName)); - var file = new SnCloudFile + return (managedTempPath, fileSize, finalContentType); + } + + private SnCloudFile CreateFileObject( + string fileId, + string fileName, + string contentType, + long fileSize, + Instant? expiredAt, + SnFileBundle? bundle, + Guid accountId + ) + { + return new SnCloudFile { Id = fileId, Name = fileName, - MimeType = finalContentType, + MimeType = contentType, Size = fileSize, ExpiredAt = expiredAt, BundleId = bundle?.Id, - AccountId = Guid.Parse(account.Id), + AccountId = accountId, }; + } - if (!pool.PolicyConfig.NoMetadata) - { - await ExtractMetadataAsync(file, managedTempPath); - } + private async Task<(string processingPath, bool isTempFile)> ProcessEncryptionAsync( + string fileId, + string managedTempPath, + string? encryptPassword, + FilePool pool, + SnCloudFile file + ) + { + if (string.IsNullOrWhiteSpace(encryptPassword)) + return (managedTempPath, true); - string processingPath = managedTempPath; - bool isTempFile = true; + if (!pool.PolicyConfig.AllowEncryption) + throw new InvalidOperationException("Encryption is not allowed in this pool"); - if (!string.IsNullOrWhiteSpace(encryptPassword)) - { - if (!pool.PolicyConfig.AllowEncryption) - throw new InvalidOperationException("Encryption is not allowed in this pool"); + var encryptedPath = Path.Combine(Path.GetTempPath(), $"{fileId}.encrypted"); + FileEncryptor.EncryptFile(managedTempPath, encryptedPath, encryptPassword); - var encryptedPath = Path.Combine(Path.GetTempPath(), $"{fileId}.encrypted"); - FileEncryptor.EncryptFile(managedTempPath, encryptedPath, encryptPassword); + File.Delete(managedTempPath); - File.Delete(managedTempPath); + file.IsEncrypted = true; + file.MimeType = "application/octet-stream"; + file.Size = new FileInfo(encryptedPath).Length; - processingPath = encryptedPath; - - file.IsEncrypted = true; - file.MimeType = "application/octet-stream"; - file.Size = new FileInfo(processingPath).Length; - } - - file.Hash = await HashFileAsync(processingPath); + return (encryptedPath, true); + } + private async Task SaveFileToDatabaseAsync(SnCloudFile file) + { db.Files.Add(file); await db.SaveChangesAsync(); - file.StorageId ??= file.Id; + } + private async Task PublishFileUploadedEventAsync(SnCloudFile file, FilePool pool, string processingPath, bool isTempFile) + { var js = nats.CreateJetStreamContext(); await js.PublishAsync( FileUploadedEvent.Type, @@ -186,8 +248,6 @@ public class FileService( isTempFile) ).ToByteArray() ); - - return file; } private async Task ExtractMetadataAsync(SnCloudFile file, string filePath) @@ -724,4 +784,4 @@ file class UpdatableCloudFile(SnCloudFile file) .SetProperty(f => f.UserMeta, userMeta) .SetProperty(f => f.IsMarkedRecycle, IsMarkedRecycle); } -} \ No newline at end of file +} diff --git a/DysonNetwork.Drive/Storage/FileUploadController.cs b/DysonNetwork.Drive/Storage/FileUploadController.cs index 367c418..157b933 100644 --- a/DysonNetwork.Drive/Storage/FileUploadController.cs +++ b/DysonNetwork.Drive/Storage/FileUploadController.cs @@ -4,6 +4,7 @@ using DysonNetwork.Drive.Billing; using DysonNetwork.Drive.Storage.Model; using DysonNetwork.Shared.Auth; using DysonNetwork.Shared.Http; +using DysonNetwork.Shared.Models; using DysonNetwork.Shared.Proto; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Mvc; @@ -32,46 +33,82 @@ public class FileUploadController( [HttpPost("create")] public async Task CreateUploadTask([FromBody] CreateUploadTaskRequest request) { - if (HttpContext.Items["CurrentUser"] is not Account currentUser) - { + var currentUser = HttpContext.Items["CurrentUser"] as Account; + if (currentUser is null) return new ObjectResult(ApiError.Unauthorized()) { StatusCode = 401 }; - } - if (!currentUser.IsSuperuser) - { - var allowed = await permission.HasPermissionAsync(new HasPermissionRequest - { Actor = $"user:{currentUser.Id}", Area = "global", Key = "files.create" }); - if (!allowed.HasPermission) - { - return new ObjectResult(ApiError.Unauthorized(forbidden: true)) { StatusCode = 403 }; - } - } + var permissionCheck = await ValidateUserPermissions(currentUser); + if (permissionCheck is not null) return permissionCheck; request.PoolId ??= Guid.Parse(configuration["Storage:PreferredRemote"]!); var pool = await fileService.GetPoolAsync(request.PoolId.Value); if (pool is null) - { return new ObjectResult(ApiError.NotFound("Pool")) { StatusCode = 404 }; - } - if (pool.PolicyConfig.RequirePrivilege is > 0) + var poolValidation = await ValidatePoolAccess(currentUser, pool, request); + if (poolValidation is not null) return poolValidation; + + var policyValidation = ValidatePoolPolicy(pool.PolicyConfig, request); + if (policyValidation is not null) return policyValidation; + + var quotaValidation = await ValidateQuota(currentUser, pool, request.FileSize); + if (quotaValidation is not null) return quotaValidation; + + EnsureTempDirectoryExists(); + + // Check if a file with the same hash already exists + var existingFile = await db.Files.FirstOrDefaultAsync(f => f.Hash == request.Hash); + if (existingFile != null) { - var privilege = - currentUser.PerkSubscription is null ? 0 : - PerkSubscriptionPrivilege.GetPrivilegeFromIdentifier(currentUser.PerkSubscription.Identifier); - if (privilege < pool.PolicyConfig.RequirePrivilege) + return Ok(new CreateUploadTaskResponse { - return new ObjectResult(ApiError.Unauthorized( - $"You need Stellar Program tier {pool.PolicyConfig.RequirePrivilege} to use pool {pool.Name}, you are tier {privilege}", - forbidden: true)) - { - StatusCode = 403 - }; - } + FileExists = true, + File = existingFile + }); } - var policy = pool.PolicyConfig; + var (taskId, task) = await CreateUploadTaskInternal(request); + return Ok(new CreateUploadTaskResponse + { + FileExists = false, + TaskId = taskId, + ChunkSize = task.ChunkSize, + ChunksCount = task.ChunksCount + }); + } + + private async Task ValidateUserPermissions(Account currentUser) + { + if (currentUser.IsSuperuser) return null; + + var allowed = await permission.HasPermissionAsync(new HasPermissionRequest + { Actor = $"user:{currentUser.Id}", Area = "global", Key = "files.create" }); + + return allowed.HasPermission ? null : + new ObjectResult(ApiError.Unauthorized(forbidden: true)) { StatusCode = 403 }; + } + + private async Task ValidatePoolAccess(Account currentUser, FilePool pool, CreateUploadTaskRequest request) + { + if (pool.PolicyConfig.RequirePrivilege <= 0) return null; + + var privilege = currentUser.PerkSubscription is null ? 0 : + PerkSubscriptionPrivilege.GetPrivilegeFromIdentifier(currentUser.PerkSubscription.Identifier); + + if (privilege < pool.PolicyConfig.RequirePrivilege) + { + return new ObjectResult(ApiError.Unauthorized( + $"You need Stellar Program tier {pool.PolicyConfig.RequirePrivilege} to use pool {pool.Name}, you are tier {privilege}", + forbidden: true)) + { StatusCode = 403 }; + } + + return null; + } + + private IActionResult? ValidatePoolPolicy(PolicyConfig policy, CreateUploadTaskRequest request) + { if (!policy.AllowEncryption && !string.IsNullOrEmpty(request.EncryptPassword)) { return new ObjectResult(ApiError.Unauthorized("File encryption is not allowed in this pool", true)) @@ -103,8 +140,7 @@ public class FileUploadController( if (!foundMatch) { return new ObjectResult( - ApiError.Unauthorized($"Content type {request.ContentType} is not allowed by the pool's policy", - true)) + ApiError.Unauthorized($"Content type {request.ContentType} is not allowed by the pool's policy", true)) { StatusCode = 403 }; } } @@ -112,42 +148,41 @@ public class FileUploadController( if (policy.MaxFileSize is not null && request.FileSize > policy.MaxFileSize) { return new ObjectResult(ApiError.Unauthorized( - $"File size {request.FileSize} is larger than the pool's maximum file size {policy.MaxFileSize}", - true)) - { - StatusCode = 403 - }; - } - - var (ok, billableUnit, quota) = await quotaService.IsFileAcceptable( - Guid.Parse(currentUser.Id), - pool.BillingConfig.CostMultiplier ?? 1.0, - request.FileSize - ); - if (!ok) - { - return new ObjectResult( - ApiError.Unauthorized($"File size {billableUnit} MiB is exceeded the user's quota {quota} MiB", - true)) + $"File size {request.FileSize} is larger than the pool's maximum file size {policy.MaxFileSize}", true)) { StatusCode = 403 }; } + return null; + } + + private async Task ValidateQuota(Account currentUser, FilePool pool, long fileSize) + { + var (ok, billableUnit, quota) = await quotaService.IsFileAcceptable( + Guid.Parse(currentUser.Id), + pool.BillingConfig.CostMultiplier ?? 1.0, + fileSize + ); + + if (!ok) + { + return new ObjectResult( + ApiError.Unauthorized($"File size {billableUnit} MiB is exceeded the user's quota {quota} MiB", true)) + { StatusCode = 403 }; + } + + return null; + } + + private void EnsureTempDirectoryExists() + { if (!Directory.Exists(_tempPath)) { Directory.CreateDirectory(_tempPath); } + } - // Check if a file with the same hash already exists - var existingFile = await db.Files.FirstOrDefaultAsync(f => f.Hash == request.Hash); - if (existingFile != null) - { - return Ok(new CreateUploadTaskResponse - { - FileExists = true, - File = existingFile - }); - } - + private async Task<(string taskId, UploadTask task)> CreateUploadTaskInternal(CreateUploadTaskRequest request) + { var taskId = await Nanoid.GenerateAsync(); var taskPath = Path.Combine(_tempPath, taskId); Directory.CreateDirectory(taskPath); @@ -171,14 +206,7 @@ public class FileUploadController( }; await System.IO.File.WriteAllTextAsync(Path.Combine(taskPath, "task.json"), JsonSerializer.Serialize(task)); - - return Ok(new CreateUploadTaskResponse - { - FileExists = false, - TaskId = taskId, - ChunkSize = chunkSize, - ChunksCount = chunksCount - }); + return (taskId, task); } public class UploadChunkRequest @@ -211,68 +239,91 @@ public class FileUploadController( { var taskPath = Path.Combine(_tempPath, taskId); if (!Directory.Exists(taskPath)) - { return new ObjectResult(ApiError.NotFound("Upload task")) { StatusCode = 404 }; - } var taskJsonPath = Path.Combine(taskPath, "task.json"); if (!System.IO.File.Exists(taskJsonPath)) - { return new ObjectResult(ApiError.NotFound("Upload task metadata")) { StatusCode = 404 }; - } var task = JsonSerializer.Deserialize(await System.IO.File.ReadAllTextAsync(taskJsonPath)); if (task == null) - { return new ObjectResult(new ApiError { Code = "BAD_REQUEST", Message = "Invalid task metadata.", Status = 400 }) { StatusCode = 400 }; - } + + var currentUser = HttpContext.Items["CurrentUser"] as Account; + if (currentUser is null) + return new ObjectResult(ApiError.Unauthorized()) { StatusCode = 401 }; var mergedFilePath = Path.Combine(_tempPath, taskId + ".tmp"); - await using (var mergedStream = new FileStream(mergedFilePath, FileMode.Create)) + + try { - for (var i = 0; i < task.ChunksCount; i++) + await MergeChunks(taskPath, mergedFilePath, task.ChunksCount); + + var fileId = await Nanoid.GenerateAsync(); + var cloudFile = await fileService.ProcessNewFileAsync( + currentUser, + fileId, + task.PoolId.ToString(), + task.BundleId?.ToString(), + mergedFilePath, + task.FileName, + task.ContentType, + task.EncryptPassword, + task.ExpiredAt + ); + + return Ok(cloudFile); + } + catch (Exception) + { + // Log the error and clean up + // (Assuming you have a logger - you might want to inject ILogger) + await CleanupTempFiles(taskPath, mergedFilePath); + return new ObjectResult(new ApiError { - var chunkPath = Path.Combine(taskPath, $"{i}.chunk"); - if (!System.IO.File.Exists(chunkPath)) - { - // Clean up partially uploaded file - mergedStream.Close(); - System.IO.File.Delete(mergedFilePath); - Directory.Delete(taskPath, true); - return new ObjectResult(new ApiError - { Code = "CHUNK_MISSING", Message = $"Chunk {i} is missing.", Status = 400 }) - { StatusCode = 400 }; - } - - await using var chunkStream = new FileStream(chunkPath, FileMode.Open); - await chunkStream.CopyToAsync(mergedStream); - } + Code = "UPLOAD_FAILED", + Message = "Failed to complete file upload.", + Status = 500 + }) { StatusCode = 500 }; } - - if (HttpContext.Items["CurrentUser"] is not Account currentUser) + finally { - return new ObjectResult(ApiError.Unauthorized()) { StatusCode = 401 }; + // Always clean up temp files + await CleanupTempFiles(taskPath, mergedFilePath); } + } - var fileId = await Nanoid.GenerateAsync(); + private async Task MergeChunks(string taskPath, string mergedFilePath, int chunksCount) + { + await using var mergedStream = new FileStream(mergedFilePath, FileMode.Create); - var cloudFile = await fileService.ProcessNewFileAsync( - currentUser, - fileId, - task.PoolId.ToString(), - task.BundleId?.ToString(), - mergedFilePath, - task.FileName, - task.ContentType, - task.EncryptPassword, - task.ExpiredAt - ); + for (var i = 0; i < chunksCount; i++) + { + var chunkPath = Path.Combine(taskPath, $"{i}.chunk"); + if (!System.IO.File.Exists(chunkPath)) + { + throw new InvalidOperationException($"Chunk {i} is missing."); + } - // Clean up - Directory.Delete(taskPath, true); - System.IO.File.Delete(mergedFilePath); + await using var chunkStream = new FileStream(chunkPath, FileMode.Open); + await chunkStream.CopyToAsync(mergedStream); + } + } - return Ok(cloudFile); + private async Task CleanupTempFiles(string taskPath, string mergedFilePath) + { + try + { + if (Directory.Exists(taskPath)) + Directory.Delete(taskPath, true); + + if (System.IO.File.Exists(mergedFilePath)) + System.IO.File.Delete(mergedFilePath); + } + catch + { + // Ignore cleanup errors to avoid masking the original exception + } } }