diff --git a/src/PowderCoating.Application/Interfaces/IStripeService.cs b/src/PowderCoating.Application/Interfaces/IStripeService.cs index a4d52b1..a8a21a6 100644 --- a/src/PowderCoating.Application/Interfaces/IStripeService.cs +++ b/src/PowderCoating.Application/Interfaces/IStripeService.cs @@ -4,6 +4,7 @@ public interface IStripeService { Task CreateCheckoutSessionAsync(int companyId, int newPlan, bool isAnnual, string successUrl, string cancelUrl); Task CreateRegistrationCheckoutSessionAsync(int plan, bool isAnnual, string email, string companyName, string successUrl, string cancelUrl); + Task IsRegistrationCheckoutPaidAsync(string sessionId); Task FulfillCheckoutAsync(string sessionId); Task FulfillRegistrationCheckoutAsync(string sessionId, int companyId, int plan); Task SyncSubscriptionAsync(int companyId); diff --git a/src/PowderCoating.Infrastructure/Services/StripeService.cs b/src/PowderCoating.Infrastructure/Services/StripeService.cs index ac37789..1557f42 100644 --- a/src/PowderCoating.Infrastructure/Services/StripeService.cs +++ b/src/PowderCoating.Infrastructure/Services/StripeService.cs @@ -195,6 +195,49 @@ public class StripeService : IStripeService return session.Url; } + /// + /// Verifies that the supplied Stripe Checkout session belongs to the registration flow and has + /// reached the paid/complete state. Returns false for any missing/invalid/unpaid session + /// so the caller can safely stop before creating any local company or user records. + /// + public async Task IsRegistrationCheckoutPaidAsync(string sessionId) + { + try + { + var sessionService = new SessionService(); + var session = await sessionService.GetAsync(sessionId); + + if (!session.Metadata.TryGetValue("registration", out var isRegistration) || + !string.Equals(isRegistration, "true", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogWarning( + "Registration checkout validation failed for session {SessionId}: missing registration metadata", + sessionId); + return false; + } + + var isPaidAndComplete = session.PaymentStatus == "paid" && session.Status == "complete"; + if (!isPaidAndComplete) + { + _logger.LogWarning( + "Registration checkout validation failed for session {SessionId}: paymentStatus={PaymentStatus}, status={Status}", + sessionId, session.PaymentStatus, session.Status); + } + + return isPaidAndComplete; + } + catch (StripeException ex) + { + _logger.LogWarning(ex, "Stripe rejected registration checkout validation for session {SessionId}", sessionId); + return false; + } + catch (Exception ex) + { + _logger.LogError(ex, "Unexpected error validating registration checkout session {SessionId}", sessionId); + return false; + } + } + /// /// Finalizes a registration checkout after payment is confirmed. Called from the /// registration success redirect (not the webhook path). Retrieves the session from Stripe diff --git a/src/PowderCoating.Web/Controllers/RegistrationController.cs b/src/PowderCoating.Web/Controllers/RegistrationController.cs index 33f2667..f78fd6f 100644 --- a/src/PowderCoating.Web/Controllers/RegistrationController.cs +++ b/src/PowderCoating.Web/Controllers/RegistrationController.cs @@ -262,11 +262,10 @@ public class RegistrationController : Controller /// Stripe return URL handler called after the customer completes payment in Stripe Checkout. /// Looks up the PendingRegistrationSession by the opaque reg_token to recover the /// registration data that could not be stored in a session cookie (which may not survive the - /// Stripe redirect on some browsers/devices). Marks the session as completed before creating the - /// account to prevent duplicate submissions if the user hits reload. Contains a second open-cap - /// check in case the tenant limit was reached between the user starting and completing checkout. - /// Calls FulfillRegistrationCheckoutAsync to link the Stripe subscription (sets the real - /// SubscriptionEndDate); failure is non-fatal — dates can be corrected manually. + /// Stripe redirect on some browsers/devices). The Stripe session is re-validated before any + /// local company or user is created, and the pending session is only left marked completed once + /// registration succeeds. On recoverable failures the completion flag is released so the same + /// success URL can be retried instead of forcing manual support intervention. /// [HttpGet] public async Task PaymentSuccess(string? session_id, string? reg_token) @@ -281,99 +280,123 @@ public class RegistrationController : Controller } var pendingSession = await _db.PendingRegistrationSessions - .FirstOrDefaultAsync(p => p.Token == reg_token && !p.IsCompleted); + .FirstOrDefaultAsync(p => p.Token == reg_token); if (pendingSession == null) { - TempData["Error"] = "Your registration session was not found or has already been completed. Please fill in your details again."; + TempData["Error"] = "Your registration session was not found. Please fill in your details again."; return RedirectToAction(nameof(Index)); } - // Map DB record to the internal record used below var pending = new PendingRegistration( pendingSession.CompanyName, pendingSession.CompanyPhone, pendingSession.FirstName, pendingSession.LastName, pendingSession.Email, pendingSession.Plan, pendingSession.IsAnnual); - // Mark completed to prevent duplicate submissions + var existingUser = await _userManager.FindByEmailAsync(pending.Email); + + if (pendingSession.IsCompleted) + { + if (existingUser != null) + { + await SignInExistingRegistrationUserAsync(existingUser); + return RedirectToAction(nameof(Welcome)); + } + + TempData["Error"] = "Your registration was already submitted, but we couldn't finish signing you in. Please contact support with reference: " + session_id; + return RedirectToAction(nameof(Index)); + } + + if (!await _stripeService.IsRegistrationCheckoutPaidAsync(session_id)) + { + TempData["Error"] = "We couldn't verify a completed payment for this registration session yet. Please try the link again in a moment, or contact support if the issue persists."; + return RedirectToAction(nameof(Index)); + } + + var keepSessionCompleted = false; pendingSession.IsCompleted = true; await _db.SaveChangesAsync(); - // Guard against race condition: re-check capacity after Stripe redirect - if (!await IsRegistrationOpenAsync()) - { - TempData["Error"] = "Registration is currently closed. Your payment has been received but no account was created. Please contact support."; - return RedirectToAction(nameof(Index)); - } - - // Guard against race condition (duplicate submission) - if (await _userManager.FindByEmailAsync(pending.Email) != null) - { - // Account already exists — just sign them in and go to dashboard - var existingUser = await _userManager.FindByEmailAsync(pending.Email); - if (existingUser != null) - { - existingUser.LastLoginDate = DateTime.UtcNow; - await _userManager.UpdateAsync(existingUser); - await _signInManager.SignInAsync(existingUser, isPersistent: false); - } - return RedirectToAction("Index", "Dashboard"); - } - - var companyCode = await GenerateUniqueCompanyCodeAsync(pending.CompanyName); - var company = new Company - { - CompanyName = pending.CompanyName, - CompanyCode = companyCode, - Phone = pending.CompanyPhone, - PrimaryContactEmail = pending.Email, - PrimaryContactName = $"{pending.FirstName} {pending.LastName}", - SubscriptionPlan = pending.Plan, - SubscriptionStatus = SubscriptionStatus.Active, - SubscriptionStartDate = DateTime.UtcNow, - SubscriptionEndDate = DateTime.UtcNow.AddDays(1), // Stripe fulfillment sets real date below - IsActive = true, - CreatedAt = DateTime.UtcNow - }; - - await _unitOfWork.Companies.AddAsync(company); - await _unitOfWork.CompleteAsync(); - - var tempPassword = GenerateTemporaryPassword(); - var user = BuildUser(pending.Email, pending.FirstName, pending.LastName, company.Id); - - var createResult = await _userManager.CreateAsync(user, tempPassword); - if (!createResult.Succeeded) - { - await _unitOfWork.Companies.DeleteAsync(company); - await _unitOfWork.CompleteAsync(); - - _logger.LogError("Failed to create user after payment for {Email}: {Errors}", - pending.Email, string.Join(", ", createResult.Errors.Select(e => e.Description))); - - TempData["Error"] = "Your payment was received but we encountered an error creating your account. " + - "Please contact support with reference: " + session_id; - return RedirectToAction(nameof(Index)); - } - - // Link the Stripe subscription (sets real SubscriptionEndDate) try { - await _stripeService.FulfillRegistrationCheckoutAsync(session_id, company.Id, pending.Plan); + // Guard against race condition: re-check capacity after Stripe redirect + if (!await IsRegistrationOpenAsync()) + { + TempData["Error"] = "Registration is currently closed. Your payment has been received but no account was created. Please contact support."; + return RedirectToAction(nameof(Index)); + } + + // Recover gracefully if a prior attempt already created the user. + if (existingUser != null) + { + keepSessionCompleted = true; + await SignInExistingRegistrationUserAsync(existingUser); + return RedirectToAction(nameof(Welcome)); + } + + var companyCode = await GenerateUniqueCompanyCodeAsync(pending.CompanyName); + var company = new Company + { + CompanyName = pending.CompanyName, + CompanyCode = companyCode, + Phone = pending.CompanyPhone, + PrimaryContactEmail = pending.Email, + PrimaryContactName = $"{pending.FirstName} {pending.LastName}", + SubscriptionPlan = pending.Plan, + SubscriptionStatus = SubscriptionStatus.Active, + SubscriptionStartDate = DateTime.UtcNow, + SubscriptionEndDate = DateTime.UtcNow.AddDays(1), // Stripe fulfillment sets real date below + IsActive = true, + CreatedAt = DateTime.UtcNow + }; + + await _unitOfWork.Companies.AddAsync(company); + await _unitOfWork.CompleteAsync(); + + var tempPassword = GenerateTemporaryPassword(); + var user = BuildUser(pending.Email, pending.FirstName, pending.LastName, company.Id); + + var createResult = await _userManager.CreateAsync(user, tempPassword); + if (!createResult.Succeeded) + { + await _unitOfWork.Companies.DeleteAsync(company); + await _unitOfWork.CompleteAsync(); + + _logger.LogError("Failed to create user after payment for {Email}: {Errors}", + pending.Email, string.Join(", ", createResult.Errors.Select(e => e.Description))); + + TempData["Error"] = "Your payment was received, but we couldn't finish creating your account. Please try the success link again, or contact support with reference: " + session_id; + return RedirectToAction(nameof(Index)); + } + + // Link the Stripe subscription (sets real SubscriptionEndDate) + try + { + await _stripeService.FulfillRegistrationCheckoutAsync(session_id, company.Id, pending.Plan); + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to fulfill registration checkout {SessionId} for company {CompanyId}", + session_id, company.Id); + // Non-fatal — subscription dates can be synced manually later + } + + await _userManager.AddClaimAsync(user, new Claim("MustChangePassword", "true")); + await FinalizeRegistrationAsync(user, company, pending.Plan); + _ = SendWelcomeEmailAsync(pending.Email, pending.FirstName, tempPassword, pending.Plan, + null, $"{Request.Scheme}://{Request.Host}"); + + keepSessionCompleted = true; + return RedirectToAction(nameof(Welcome)); } - catch (Exception ex) + finally { - _logger.LogError(ex, "Failed to fulfill registration checkout {SessionId} for company {CompanyId}", - session_id, company.Id); - // Non-fatal — subscription dates can be synced manually later + if (!keepSessionCompleted) + { + pendingSession.IsCompleted = false; + await _db.SaveChangesAsync(); + } } - - await _userManager.AddClaimAsync(user, new Claim("MustChangePassword", "true")); - await FinalizeRegistrationAsync(user, company, pending.Plan); - _ = SendWelcomeEmailAsync(pending.Email, pending.FirstName, tempPassword, pending.Plan, - null, $"{Request.Scheme}://{Request.Host}"); - - return RedirectToAction(nameof(Welcome)); } /// @@ -480,6 +503,18 @@ public class RegistrationController : Controller await _signInManager.SignInAsync(user, isPersistent: false); } + /// + /// Signs in an already-created registration user on idempotent success-link retries. + /// This is intentionally narrower than because all + /// registration side effects should already have happened on the original successful run. + /// + private async Task SignInExistingRegistrationUserAsync(ApplicationUser user) + { + user.LastLoginDate = DateTime.UtcNow; + await _userManager.UpdateAsync(user); + await _signInManager.SignInAsync(user, isPersistent: false); + } + /// /// Renders the post-registration welcome page. Requires authentication (the user was just signed /// in by ). Determines whether the account is on a free diff --git a/tests/PowderCoating.UnitTests/GlobalUsings.cs b/tests/PowderCoating.UnitTests/GlobalUsings.cs new file mode 100644 index 0000000..c802f44 --- /dev/null +++ b/tests/PowderCoating.UnitTests/GlobalUsings.cs @@ -0,0 +1 @@ +global using Xunit; diff --git a/tests/PowderCoating.UnitTests/PowderCoating.UnitTests.csproj b/tests/PowderCoating.UnitTests/PowderCoating.UnitTests.csproj index 5c870ca..0e1e74f 100644 --- a/tests/PowderCoating.UnitTests/PowderCoating.UnitTests.csproj +++ b/tests/PowderCoating.UnitTests/PowderCoating.UnitTests.csproj @@ -9,6 +9,7 @@ + @@ -26,6 +27,7 @@ + diff --git a/tests/PowderCoating.UnitTests/PricingCalculationServiceTests.cs b/tests/PowderCoating.UnitTests/PricingCalculationServiceTests.cs new file mode 100644 index 0000000..8866f13 --- /dev/null +++ b/tests/PowderCoating.UnitTests/PricingCalculationServiceTests.cs @@ -0,0 +1,244 @@ +using Microsoft.Extensions.Logging; +using Moq; +using PowderCoating.Application.DTOs.Quote; +using PowderCoating.Application.Services; +using PowderCoating.Core.Entities; +using PowderCoating.Core.Interfaces; +using Xunit; + +namespace PowderCoating.UnitTests; + +public class PricingCalculationServiceTests +{ + [Fact] + public async Task CalculateCoatPriceAsync_CustomPowder_ChargesFullOrderQuantity() + { + var costs = CreateOperatingCosts(); + var unitOfWork = CreateUnitOfWorkMock(costs); + var tenantContext = new Mock(); + tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false); + + var service = new PricingCalculationService( + unitOfWork.Object, + Mock.Of>(), + new MeasurementConversionService(), + tenantContext.Object); + + var coat = new CreateQuoteItemCoatDto + { + CoatName = "Custom Red", + PowderCostPerLb = 10m, + PowderToOrder = 3m, + CoverageSqFtPerLb = 30m, + TransferEfficiency = 65m + }; + + var result = await service.CalculateCoatPriceAsync( + coat, + itemSurfaceAreaSqFt: 5m, + quantity: 2m, + coatIndex: 0, + estimatedMinutesBase: 15, + companyId: 1); + + Assert.Equal(30m, result.CoatMaterialCost); + Assert.Equal(30m, result.CoatLaborCost); + Assert.Equal(60m, result.CoatTotalCost); + } + + [Fact] + public async Task CalculateQuoteItemPriceAsync_LaborItem_UsesStandardLaborRate() + { + var costs = CreateOperatingCosts(); + costs.StandardLaborRate = 80m; + + var unitOfWork = CreateUnitOfWorkMock(costs); + var tenantContext = new Mock(); + tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false); + + var service = new PricingCalculationService( + unitOfWork.Object, + Mock.Of>(), + new MeasurementConversionService(), + tenantContext.Object); + + var item = new CreateQuoteItemDto + { + Description = "Shop labor", + IsLaborItem = true, + Quantity = 2.5m + }; + + var result = await service.CalculateQuoteItemPriceAsync(item, companyId: 1); + + Assert.Equal(0m, result.MaterialCost); + Assert.Equal(200m, result.LaborCost); + Assert.Equal(80m, result.UnitPrice); + Assert.Equal(200m, result.TotalPrice); + } + + [Fact] + public async Task CalculateQuoteItemPriceAsync_AiItem_UsesManualUnitPriceWithoutAdditionalCosts() + { + var unitOfWork = CreateUnitOfWorkMock(CreateOperatingCosts()); + var tenantContext = new Mock(); + tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false); + + var service = new PricingCalculationService( + unitOfWork.Object, + Mock.Of>(), + new MeasurementConversionService(), + tenantContext.Object); + + var item = new CreateQuoteItemDto + { + Description = "AI wheel estimate", + IsAiItem = true, + ManualUnitPrice = 123m, + Quantity = 2m + }; + + var result = await service.CalculateQuoteItemPriceAsync(item, companyId: 1); + + Assert.Equal(0m, result.MaterialCost); + Assert.Equal(0m, result.LaborCost); + Assert.Equal(123m, result.UnitPrice); + Assert.Equal(246m, result.TotalPrice); + } + + [Fact] + public async Task CalculateQuoteTotalsAsync_AppliesTierDiscount_QuoteDiscount_RushFee_AndTax() + { + var costs = CreateOperatingCosts(); + costs.StandardLaborRate = 100m; + costs.ShopSuppliesRate = 10m; + costs.RushChargeType = "Percentage"; + costs.RushChargePercentage = 20m; + costs.TaxPercent = 5m; + costs.OvenOperatingCostPerHour = 0m; + costs.MonthlyRent = 0m; + costs.MonthlyUtilities = 0m; + + var customerRepo = new Mock>(); + customerRepo + .Setup(x => x.FindAsync(It.IsAny>>(), false, It.IsAny>[]>())) + .ReturnsAsync(new[] + { + new Customer { Id = 1, CompanyId = 1, PricingTierId = 10 } + }); + + var pricingTierRepo = new Mock>(); + pricingTierRepo + .Setup(x => x.GetByIdAsync(10, false, It.IsAny>[]>())) + .ReturnsAsync(new PricingTier { Id = 10, CompanyId = 1, DiscountPercent = 10m }); + + var unitOfWork = CreateUnitOfWorkMock(costs); + unitOfWork.SetupGet(x => x.Customers).Returns(customerRepo.Object); + unitOfWork.SetupGet(x => x.PricingTiers).Returns(pricingTierRepo.Object); + + var tenantContext = new Mock(); + tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false); + + var service = new PricingCalculationService( + unitOfWork.Object, + Mock.Of>(), + new MeasurementConversionService(), + tenantContext.Object); + + var items = new List + { + new() + { + Description = "Labor item", + IsLaborItem = true, + Quantity = 2m + } + }; + + var result = await service.CalculateQuoteTotalsAsync( + items, + companyId: 1, + customerId: 1, + discountType: "FixedAmount", + discountValue: 5m, + isRushJob: true); + + Assert.Equal(200m, result.ItemsSubtotal); + Assert.Equal(20m, result.ShopSuppliesAmount); + Assert.Equal(220m, result.SubtotalBeforeDiscount); + Assert.Equal(22m, result.PricingTierDiscountAmount); + Assert.Equal(5m, result.QuoteDiscountAmount); + Assert.Equal(193m, result.SubtotalAfterDiscount); + Assert.Equal(38.6m, result.RushFee); + Assert.Equal(11.58m, result.TaxAmount); + Assert.Equal(243.18m, result.Total); + } + + private static Mock CreateUnitOfWorkMock(CompanyOperatingCosts costs) + { + var unitOfWork = new Mock(); + + var companyOperatingCostsRepo = new Mock>(); + companyOperatingCostsRepo + .Setup(x => x.FindAsync(It.IsAny>>(), false, It.IsAny>[]>())) + .ReturnsAsync(new[] { costs }); + + var inventoryRepo = new Mock>(); + inventoryRepo + .Setup(x => x.GetByIdAsync(It.IsAny(), false, It.IsAny>[]>())) + .ReturnsAsync((InventoryItem?)null); + + var catalogRepo = new Mock>(); + catalogRepo + .Setup(x => x.GetByIdAsync(It.IsAny(), false, It.IsAny>[]>())) + .ReturnsAsync((CatalogItem?)null); + + var customerRepo = new Mock>(); + customerRepo + .Setup(x => x.FindAsync(It.IsAny>>(), false, It.IsAny>[]>())) + .ReturnsAsync(Array.Empty()); + + var pricingTierRepo = new Mock>(); + pricingTierRepo + .Setup(x => x.GetByIdAsync(It.IsAny(), false, It.IsAny>[]>())) + .ReturnsAsync((PricingTier?)null); + + unitOfWork.SetupGet(x => x.CompanyOperatingCosts).Returns(companyOperatingCostsRepo.Object); + unitOfWork.SetupGet(x => x.InventoryItems).Returns(inventoryRepo.Object); + unitOfWork.SetupGet(x => x.CatalogItems).Returns(catalogRepo.Object); + unitOfWork.SetupGet(x => x.Customers).Returns(customerRepo.Object); + unitOfWork.SetupGet(x => x.PricingTiers).Returns(pricingTierRepo.Object); + + return unitOfWork; + } + + private static CompanyOperatingCosts CreateOperatingCosts() + { + return new CompanyOperatingCosts + { + Id = 1, + CompanyId = 1, + StandardLaborRate = 60m, + AdditionalCoatLaborPercent = 50m, + OvenOperatingCostPerHour = 25m, + SandblasterCostPerHour = 20m, + CoatingBoothCostPerHour = 10m, + PowderCoatingCostPerSqFt = 1m, + PricingMode = PowderCoating.Core.Enums.PricingMode.MarkupOnMaterial, + GeneralMarkupPercentage = 20m, + TargetMarginPercent = 40m, + TaxPercent = 5m, + ShopSuppliesRate = 10m, + DefaultOvenCycleMinutes = 60, + RushChargeType = "Percentage", + RushChargePercentage = 15m, + RushChargeFixedAmount = 50m, + ShopMinimumCharge = 0m, + ComplexitySimplePercent = 0m, + ComplexityModeratePercent = 5m, + ComplexityComplexPercent = 15m, + ComplexityExtremePercent = 25m, + MonthlyBillableHours = 160 + }; + } +} diff --git a/tests/PowderCoating.UnitTests/RegistrationControllerTests.cs b/tests/PowderCoating.UnitTests/RegistrationControllerTests.cs new file mode 100644 index 0000000..60cea60 --- /dev/null +++ b/tests/PowderCoating.UnitTests/RegistrationControllerTests.cs @@ -0,0 +1,191 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Identity; +using Microsoft.AspNetCore.Mvc.ViewFeatures; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging; +using Moq; +using PowderCoating.Application.Interfaces; +using PowderCoating.Core.Entities; +using PowderCoating.Infrastructure.Data; +using PowderCoating.Infrastructure.Repositories; +using PowderCoating.Web.Controllers; +using Xunit; + +namespace PowderCoating.UnitTests; + +public class RegistrationControllerTests +{ + [Fact] + public async Task PaymentSuccess_WhenStripeSessionIsNotPaid_DoesNotBurnPendingSession() + { + await using var context = CreateContext(); + context.PendingRegistrationSessions.Add(CreatePendingSession("token-1", "owner@example.com")); + await context.SaveChangesAsync(); + + var stripeService = new Mock(); + stripeService.Setup(x => x.IsRegistrationCheckoutPaidAsync("sess_unpaid")).ReturnsAsync(false); + + var controller = CreateController(context, stripeService: stripeService); + + var result = await controller.PaymentSuccess("sess_unpaid", "token-1"); + + var redirect = Assert.IsType(result); + Assert.Equal("Index", redirect.ActionName); + Assert.False((await context.PendingRegistrationSessions.SingleAsync()).IsCompleted); + Assert.Contains("couldn't verify a completed payment", controller.TempData["Error"]?.ToString()); + } + + [Fact] + public async Task PaymentSuccess_WhenUserCreationFails_ReleasesPendingSessionAndDeletesCompany() + { + await using var context = CreateContext(); + context.PendingRegistrationSessions.Add(CreatePendingSession("token-2", "owner2@example.com")); + await context.SaveChangesAsync(); + + var userManager = CreateUserManagerMock(); + userManager.Setup(x => x.FindByEmailAsync("owner2@example.com")).ReturnsAsync((ApplicationUser?)null); + userManager.Setup(x => x.CreateAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(IdentityResult.Failed(new IdentityError { Description = "boom" })); + + var stripeService = new Mock(); + stripeService.Setup(x => x.IsRegistrationCheckoutPaidAsync("sess_paid")).ReturnsAsync(true); + + var controller = CreateController(context, userManager, stripeService: stripeService); + + var result = await controller.PaymentSuccess("sess_paid", "token-2"); + + var redirect = Assert.IsType(result); + Assert.Equal("Index", redirect.ActionName); + Assert.False((await context.PendingRegistrationSessions.SingleAsync()).IsCompleted); + Assert.Empty(context.Companies); + Assert.Contains("Please try the success link again", controller.TempData["Error"]?.ToString()); + } + + [Fact] + public async Task PaymentSuccess_WhenSessionAlreadyCompletedAndUserExists_SignsUserInAndRedirectsToWelcome() + { + await using var context = CreateContext(); + context.PendingRegistrationSessions.Add(CreatePendingSession("token-3", "owner3@example.com", isCompleted: true)); + await context.SaveChangesAsync(); + + var existingUser = new ApplicationUser + { + Id = "user-3", + Email = "owner3@example.com", + UserName = "owner3@example.com", + FirstName = "Terry", + LastName = "Tenant", + CompanyId = 3 + }; + + var userManager = CreateUserManagerMock(); + userManager.Setup(x => x.FindByEmailAsync("owner3@example.com")).ReturnsAsync(existingUser); + userManager.Setup(x => x.UpdateAsync(existingUser)).ReturnsAsync(IdentityResult.Success); + + var signInManager = CreateSignInManagerMock(userManager.Object); + signInManager.Setup(x => x.SignInAsync(existingUser, false, null)).Returns(Task.CompletedTask).Verifiable(); + + var controller = CreateController(context, userManager, signInManager.Object); + + var result = await controller.PaymentSuccess("sess_complete", "token-3"); + + var redirect = Assert.IsType(result); + Assert.Equal("Welcome", redirect.ActionName); + signInManager.Verify(x => x.SignInAsync(existingUser, false, null), Times.Once); + Assert.True((await context.PendingRegistrationSessions.SingleAsync()).IsCompleted); + } + + private static RegistrationController CreateController( + ApplicationDbContext context, + Mock>? userManager = null, + SignInManager? signInManager = null, + Mock? stripeService = null) + { + var unitOfWork = new UnitOfWork(context); + var userManagerMock = userManager ?? CreateUserManagerMock(); + var signInManagerInstance = signInManager ?? CreateSignInManagerMock(userManagerMock.Object).Object; + + var platformSettings = new Mock(); + platformSettings.Setup(x => x.GetAsync(It.IsAny())).ReturnsAsync((string?)null); + + var controller = new RegistrationController( + unitOfWork, + context, + userManagerMock.Object, + signInManagerInstance, + Mock.Of(), + Mock.Of(), + Mock.Of(), + platformSettings.Object, + (stripeService ?? new Mock()).Object, + Mock.Of(), + Mock.Of>()); + + var httpContext = new DefaultHttpContext(); + controller.ControllerContext = new Microsoft.AspNetCore.Mvc.ControllerContext + { + HttpContext = httpContext + }; + controller.TempData = new TempDataDictionary(httpContext, Mock.Of()); + + return controller; + } + + private static Mock> CreateUserManagerMock() + { + var store = new Mock>(); + return new Mock>( + store.Object, + null!, + null!, + null!, + null!, + null!, + null!, + null!, + null!); + } + + private static Mock> CreateSignInManagerMock(UserManager userManager) + { + var contextAccessor = new Mock(); + contextAccessor.Setup(x => x.HttpContext).Returns(new DefaultHttpContext()); + + var claimsFactory = new Mock>(); + + return new Mock>( + userManager, + contextAccessor.Object, + claimsFactory.Object, + null!, + null!, + null!, + null!); + } + + private static ApplicationDbContext CreateContext() + { + var options = new DbContextOptionsBuilder() + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .Options; + + return new ApplicationDbContext(options); + } + + private static PendingRegistrationSession CreatePendingSession(string token, string email, bool isCompleted = false) + { + return new PendingRegistrationSession + { + Token = token, + CompanyName = "Retry Co", + CompanyPhone = "555-0100", + FirstName = "Pat", + LastName = "Owner", + Email = email, + Plan = 1, + IsAnnual = false, + IsCompleted = isCompleted, + CreatedAt = DateTime.UtcNow + }; + } +} diff --git a/tests/PowderCoating.UnitTests/SubscriptionServiceTests.cs b/tests/PowderCoating.UnitTests/SubscriptionServiceTests.cs new file mode 100644 index 0000000..5e4f3a6 --- /dev/null +++ b/tests/PowderCoating.UnitTests/SubscriptionServiceTests.cs @@ -0,0 +1,221 @@ +using Microsoft.EntityFrameworkCore; +using PowderCoating.Core.Entities; +using PowderCoating.Core.Enums; +using PowderCoating.Infrastructure.Data; +using PowderCoating.Infrastructure.Repositories; +using PowderCoating.Infrastructure.Services; +using Xunit; + +namespace PowderCoating.UnitTests; + +public class SubscriptionServiceTests +{ + [Fact] + public async Task GetUserCountAsync_PrefersCompanyOverrideOverPlanDefault() + { + await using var context = CreateContext(); + SeedCompanyAndPlan(context, companyId: 7, plan: 1, maxUsers: 3); + var company = context.Companies.Local.Single(c => c.Id == 7); + company.MaxUsersOverride = 7; + + context.Users.AddRange( + new ApplicationUser { Id = "u1", CompanyId = 7, UserName = "u1", Email = "u1@example.com", FirstName = "A", LastName = "One", IsActive = true }, + new ApplicationUser { Id = "u2", CompanyId = 7, UserName = "u2", Email = "u2@example.com", FirstName = "B", LastName = "Two", IsActive = true }); + await context.SaveChangesAsync(); + + var service = new SubscriptionService(new UnitOfWork(context), context); + + var (used, max) = await service.GetUserCountAsync(7); + + Assert.Equal(2, used); + Assert.Equal(7, max); + } + + [Fact] + public async Task GetJobCountAsync_ExcludesTerminalStatuses() + { + await using var context = CreateContext(); + SeedCompanyAndPlan(context, companyId: 8, plan: 2, maxActiveJobs: 50); + SeedJobStatuses(context, 8); + context.Jobs.AddRange( + new Job { Id = 1, CompanyId = 8, JobNumber = "JOB-1", CustomerId = 1, Description = "Active", JobStatusId = 1, JobPriorityId = 1 }, + new Job { Id = 2, CompanyId = 8, JobNumber = "JOB-2", CustomerId = 1, Description = "Done", JobStatusId = 2, JobPriorityId = 1 }, + new Job { Id = 3, CompanyId = 8, JobNumber = "JOB-3", CustomerId = 1, Description = "Delivered", JobStatusId = 3, JobPriorityId = 1 }); + await context.SaveChangesAsync(); + + var service = new SubscriptionService(new UnitOfWork(context), context); + + var (used, max) = await service.GetJobCountAsync(8); + + Assert.Equal(1, used); + Assert.Equal(50, max); + } + + [Fact] + public async Task GetQuoteCountAsync_CountsOnlyCurrentMonth() + { + await using var context = CreateContext(); + SeedCompanyAndPlan(context, companyId: 9, plan: 3, maxQuotes: 5); + var currentQuote = new Quote + { + Id = 1, + CompanyId = 9, + QuoteNumber = "Q-001", + QuoteStatusId = 1 + }; + var oldQuote = new Quote + { + Id = 2, + CompanyId = 9, + QuoteNumber = "Q-OLD", + QuoteStatusId = 1 + }; + context.Quotes.AddRange(currentQuote, oldQuote); + await context.SaveChangesAsync(); + + oldQuote.CreatedAt = DateTime.UtcNow.AddMonths(-1); + await context.SaveChangesAsync(); + + var service = new SubscriptionService(new UnitOfWork(context), context); + + var (used, max) = await service.GetQuoteCountAsync(9); + + Assert.Equal(1, used); + Assert.Equal(5, max); + } + + [Fact] + public async Task CanAddCustomerAsync_CompedCompany_BypassesPlanLimits() + { + await using var context = CreateContext(); + SeedCompanyAndPlan(context, companyId: 10, plan: 4, maxCustomers: 0); + var company = await context.Companies.FindAsync(10); + company!.IsComped = true; + context.Customers.Add(new Customer { Id = 1, CompanyId = 10, CompanyName = "Customer A" }); + await context.SaveChangesAsync(); + + var service = new SubscriptionService(new UnitOfWork(context), context); + + var allowed = await service.CanAddCustomerAsync(10); + + Assert.True(allowed); + } + + [Fact] + public async Task CanUseAiPhotoQuoteAsync_RequiresFeatureEnabledAndQuotaAvailable() + { + await using var context = CreateContext(); + SeedCompanyAndPlan(context, companyId: 11, plan: 5, maxAiPhotoQuotesPerMonth: 2, allowAiPhotoQuotes: true); + context.AiItemPredictions.Add(new AiItemPrediction { Id = 1, CompanyId = 11, CreatedAt = DateTime.UtcNow.AddDays(-1) }); + await context.SaveChangesAsync(); + + var service = new SubscriptionService(new UnitOfWork(context), context); + + var allowed = await service.CanUseAiPhotoQuoteAsync(11); + + Assert.True(allowed); + } + + [Fact] + public async Task CanUseAiPhotoQuoteAsync_ReturnsFalse_WhenPlanDisablesFeature() + { + await using var context = CreateContext(); + SeedCompanyAndPlan(context, companyId: 12, plan: 6, maxAiPhotoQuotesPerMonth: 10, allowAiPhotoQuotes: false); + await context.SaveChangesAsync(); + + var service = new SubscriptionService(new UnitOfWork(context), context); + + var allowed = await service.CanUseAiPhotoQuoteAsync(12); + + Assert.False(allowed); + } + + private static ApplicationDbContext CreateContext() + { + var options = new DbContextOptionsBuilder() + .UseInMemoryDatabase(Guid.NewGuid().ToString()) + .Options; + + return new ApplicationDbContext(options); + } + + private static void SeedCompanyAndPlan( + ApplicationDbContext context, + int companyId, + int plan, + int maxUsers = -1, + int maxActiveJobs = -1, + int maxCustomers = -1, + int maxQuotes = -1, + int maxAiPhotoQuotesPerMonth = -1, + bool allowAiPhotoQuotes = true) + { + context.Companies.Add(new Company + { + Id = companyId, + CompanyId = companyId, + CompanyName = $"Company {companyId}", + PrimaryContactName = "Owner", + PrimaryContactEmail = $"owner{companyId}@example.com", + SubscriptionPlan = plan, + SubscriptionStatus = SubscriptionStatus.Active, + IsActive = true + }); + + context.SubscriptionPlanConfigs.Add(new SubscriptionPlanConfig + { + Id = companyId, + CompanyId = 0, + Plan = plan, + DisplayName = $"Plan {plan}", + IsActive = true, + MaxUsers = maxUsers, + MaxActiveJobs = maxActiveJobs, + MaxCustomers = maxCustomers, + MaxQuotes = maxQuotes, + MaxAiPhotoQuotesPerMonth = maxAiPhotoQuotesPerMonth, + AllowAiPhotoQuotes = allowAiPhotoQuotes + }); + + context.JobPriorityLookups.Add(new JobPriorityLookup + { + Id = companyId, + CompanyId = companyId, + PriorityCode = "NORMAL", + DisplayName = "Normal", + DisplayOrder = 1 + }); + } + + private static void SeedJobStatuses(ApplicationDbContext context, int companyId) + { + context.JobStatusLookups.AddRange( + new JobStatusLookup + { + Id = 1, + CompanyId = companyId, + StatusCode = "Pending", + DisplayName = "Pending", + DisplayOrder = 1, + IsTerminalStatus = false + }, + new JobStatusLookup + { + Id = 2, + CompanyId = companyId, + StatusCode = "Completed", + DisplayName = "Completed", + DisplayOrder = 2, + IsTerminalStatus = true + }, + new JobStatusLookup + { + Id = 3, + CompanyId = companyId, + StatusCode = "Delivered", + DisplayName = "Delivered", + DisplayOrder = 3, + IsTerminalStatus = true + }); + } +}