Harden paid registration flow and add unit tests

This commit is contained in:
2026-04-24 21:10:28 -04:00
parent 4153acf3aa
commit 27ac793f62
8 changed files with 817 additions and 79 deletions
@@ -4,6 +4,7 @@ public interface IStripeService
{
Task<string> CreateCheckoutSessionAsync(int companyId, int newPlan, bool isAnnual, string successUrl, string cancelUrl);
Task<string> CreateRegistrationCheckoutSessionAsync(int plan, bool isAnnual, string email, string companyName, string successUrl, string cancelUrl);
Task<bool> IsRegistrationCheckoutPaidAsync(string sessionId);
Task FulfillCheckoutAsync(string sessionId);
Task FulfillRegistrationCheckoutAsync(string sessionId, int companyId, int plan);
Task SyncSubscriptionAsync(int companyId);
@@ -195,6 +195,49 @@ public class StripeService : IStripeService
return session.Url;
}
/// <summary>
/// Verifies that the supplied Stripe Checkout session belongs to the registration flow and has
/// reached the paid/complete state. Returns <c>false</c> for any missing/invalid/unpaid session
/// so the caller can safely stop before creating any local company or user records.
/// </summary>
public async Task<bool> IsRegistrationCheckoutPaidAsync(string sessionId)
{
try
{
var sessionService = new SessionService();
var session = await sessionService.GetAsync(sessionId);
if (!session.Metadata.TryGetValue("registration", out var isRegistration) ||
!string.Equals(isRegistration, "true", StringComparison.OrdinalIgnoreCase))
{
_logger.LogWarning(
"Registration checkout validation failed for session {SessionId}: missing registration metadata",
sessionId);
return false;
}
var isPaidAndComplete = session.PaymentStatus == "paid" && session.Status == "complete";
if (!isPaidAndComplete)
{
_logger.LogWarning(
"Registration checkout validation failed for session {SessionId}: paymentStatus={PaymentStatus}, status={Status}",
sessionId, session.PaymentStatus, session.Status);
}
return isPaidAndComplete;
}
catch (StripeException ex)
{
_logger.LogWarning(ex, "Stripe rejected registration checkout validation for session {SessionId}", sessionId);
return false;
}
catch (Exception ex)
{
_logger.LogError(ex, "Unexpected error validating registration checkout session {SessionId}", sessionId);
return false;
}
}
/// <summary>
/// Finalizes a registration checkout after payment is confirmed. Called from the
/// registration success redirect (not the webhook path). Retrieves the session from Stripe
@@ -262,11 +262,10 @@ public class RegistrationController : Controller
/// Stripe return URL handler called after the customer completes payment in Stripe Checkout.
/// Looks up the <c>PendingRegistrationSession</c> by the opaque <c>reg_token</c> to recover the
/// registration data that could not be stored in a session cookie (which may not survive the
/// Stripe redirect on some browsers/devices). Marks the session as completed before creating the
/// account to prevent duplicate submissions if the user hits reload. Contains a second open-cap
/// check in case the tenant limit was reached between the user starting and completing checkout.
/// Calls <c>FulfillRegistrationCheckoutAsync</c> to link the Stripe subscription (sets the real
/// <c>SubscriptionEndDate</c>); failure is non-fatal — dates can be corrected manually.
/// Stripe redirect on some browsers/devices). The Stripe session is re-validated before any
/// local company or user is created, and the pending session is only left marked completed once
/// registration succeeds. On recoverable failures the completion flag is released so the same
/// success URL can be retried instead of forcing manual support intervention.
/// </summary>
[HttpGet]
public async Task<IActionResult> PaymentSuccess(string? session_id, string? reg_token)
@@ -281,99 +280,123 @@ public class RegistrationController : Controller
}
var pendingSession = await _db.PendingRegistrationSessions
.FirstOrDefaultAsync(p => p.Token == reg_token && !p.IsCompleted);
.FirstOrDefaultAsync(p => p.Token == reg_token);
if (pendingSession == null)
{
TempData["Error"] = "Your registration session was not found or has already been completed. Please fill in your details again.";
TempData["Error"] = "Your registration session was not found. Please fill in your details again.";
return RedirectToAction(nameof(Index));
}
// Map DB record to the internal record used below
var pending = new PendingRegistration(
pendingSession.CompanyName, pendingSession.CompanyPhone,
pendingSession.FirstName, pendingSession.LastName,
pendingSession.Email, pendingSession.Plan, pendingSession.IsAnnual);
// Mark completed to prevent duplicate submissions
var existingUser = await _userManager.FindByEmailAsync(pending.Email);
if (pendingSession.IsCompleted)
{
if (existingUser != null)
{
await SignInExistingRegistrationUserAsync(existingUser);
return RedirectToAction(nameof(Welcome));
}
TempData["Error"] = "Your registration was already submitted, but we couldn't finish signing you in. Please contact support with reference: " + session_id;
return RedirectToAction(nameof(Index));
}
if (!await _stripeService.IsRegistrationCheckoutPaidAsync(session_id))
{
TempData["Error"] = "We couldn't verify a completed payment for this registration session yet. Please try the link again in a moment, or contact support if the issue persists.";
return RedirectToAction(nameof(Index));
}
var keepSessionCompleted = false;
pendingSession.IsCompleted = true;
await _db.SaveChangesAsync();
// Guard against race condition: re-check capacity after Stripe redirect
if (!await IsRegistrationOpenAsync())
{
TempData["Error"] = "Registration is currently closed. Your payment has been received but no account was created. Please contact support.";
return RedirectToAction(nameof(Index));
}
// Guard against race condition (duplicate submission)
if (await _userManager.FindByEmailAsync(pending.Email) != null)
{
// Account already exists — just sign them in and go to dashboard
var existingUser = await _userManager.FindByEmailAsync(pending.Email);
if (existingUser != null)
{
existingUser.LastLoginDate = DateTime.UtcNow;
await _userManager.UpdateAsync(existingUser);
await _signInManager.SignInAsync(existingUser, isPersistent: false);
}
return RedirectToAction("Index", "Dashboard");
}
var companyCode = await GenerateUniqueCompanyCodeAsync(pending.CompanyName);
var company = new Company
{
CompanyName = pending.CompanyName,
CompanyCode = companyCode,
Phone = pending.CompanyPhone,
PrimaryContactEmail = pending.Email,
PrimaryContactName = $"{pending.FirstName} {pending.LastName}",
SubscriptionPlan = pending.Plan,
SubscriptionStatus = SubscriptionStatus.Active,
SubscriptionStartDate = DateTime.UtcNow,
SubscriptionEndDate = DateTime.UtcNow.AddDays(1), // Stripe fulfillment sets real date below
IsActive = true,
CreatedAt = DateTime.UtcNow
};
await _unitOfWork.Companies.AddAsync(company);
await _unitOfWork.CompleteAsync();
var tempPassword = GenerateTemporaryPassword();
var user = BuildUser(pending.Email, pending.FirstName, pending.LastName, company.Id);
var createResult = await _userManager.CreateAsync(user, tempPassword);
if (!createResult.Succeeded)
{
await _unitOfWork.Companies.DeleteAsync(company);
await _unitOfWork.CompleteAsync();
_logger.LogError("Failed to create user after payment for {Email}: {Errors}",
pending.Email, string.Join(", ", createResult.Errors.Select(e => e.Description)));
TempData["Error"] = "Your payment was received but we encountered an error creating your account. " +
"Please contact support with reference: " + session_id;
return RedirectToAction(nameof(Index));
}
// Link the Stripe subscription (sets real SubscriptionEndDate)
try
{
await _stripeService.FulfillRegistrationCheckoutAsync(session_id, company.Id, pending.Plan);
// Guard against race condition: re-check capacity after Stripe redirect
if (!await IsRegistrationOpenAsync())
{
TempData["Error"] = "Registration is currently closed. Your payment has been received but no account was created. Please contact support.";
return RedirectToAction(nameof(Index));
}
// Recover gracefully if a prior attempt already created the user.
if (existingUser != null)
{
keepSessionCompleted = true;
await SignInExistingRegistrationUserAsync(existingUser);
return RedirectToAction(nameof(Welcome));
}
var companyCode = await GenerateUniqueCompanyCodeAsync(pending.CompanyName);
var company = new Company
{
CompanyName = pending.CompanyName,
CompanyCode = companyCode,
Phone = pending.CompanyPhone,
PrimaryContactEmail = pending.Email,
PrimaryContactName = $"{pending.FirstName} {pending.LastName}",
SubscriptionPlan = pending.Plan,
SubscriptionStatus = SubscriptionStatus.Active,
SubscriptionStartDate = DateTime.UtcNow,
SubscriptionEndDate = DateTime.UtcNow.AddDays(1), // Stripe fulfillment sets real date below
IsActive = true,
CreatedAt = DateTime.UtcNow
};
await _unitOfWork.Companies.AddAsync(company);
await _unitOfWork.CompleteAsync();
var tempPassword = GenerateTemporaryPassword();
var user = BuildUser(pending.Email, pending.FirstName, pending.LastName, company.Id);
var createResult = await _userManager.CreateAsync(user, tempPassword);
if (!createResult.Succeeded)
{
await _unitOfWork.Companies.DeleteAsync(company);
await _unitOfWork.CompleteAsync();
_logger.LogError("Failed to create user after payment for {Email}: {Errors}",
pending.Email, string.Join(", ", createResult.Errors.Select(e => e.Description)));
TempData["Error"] = "Your payment was received, but we couldn't finish creating your account. Please try the success link again, or contact support with reference: " + session_id;
return RedirectToAction(nameof(Index));
}
// Link the Stripe subscription (sets real SubscriptionEndDate)
try
{
await _stripeService.FulfillRegistrationCheckoutAsync(session_id, company.Id, pending.Plan);
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to fulfill registration checkout {SessionId} for company {CompanyId}",
session_id, company.Id);
// Non-fatal — subscription dates can be synced manually later
}
await _userManager.AddClaimAsync(user, new Claim("MustChangePassword", "true"));
await FinalizeRegistrationAsync(user, company, pending.Plan);
_ = SendWelcomeEmailAsync(pending.Email, pending.FirstName, tempPassword, pending.Plan,
null, $"{Request.Scheme}://{Request.Host}");
keepSessionCompleted = true;
return RedirectToAction(nameof(Welcome));
}
catch (Exception ex)
finally
{
_logger.LogError(ex, "Failed to fulfill registration checkout {SessionId} for company {CompanyId}",
session_id, company.Id);
// Non-fatal — subscription dates can be synced manually later
if (!keepSessionCompleted)
{
pendingSession.IsCompleted = false;
await _db.SaveChangesAsync();
}
}
await _userManager.AddClaimAsync(user, new Claim("MustChangePassword", "true"));
await FinalizeRegistrationAsync(user, company, pending.Plan);
_ = SendWelcomeEmailAsync(pending.Email, pending.FirstName, tempPassword, pending.Plan,
null, $"{Request.Scheme}://{Request.Host}");
return RedirectToAction(nameof(Welcome));
}
/// <summary>
@@ -480,6 +503,18 @@ public class RegistrationController : Controller
await _signInManager.SignInAsync(user, isPersistent: false);
}
/// <summary>
/// Signs in an already-created registration user on idempotent success-link retries.
/// This is intentionally narrower than <see cref="FinalizeRegistrationAsync"/> because all
/// registration side effects should already have happened on the original successful run.
/// </summary>
private async Task SignInExistingRegistrationUserAsync(ApplicationUser user)
{
user.LastLoginDate = DateTime.UtcNow;
await _userManager.UpdateAsync(user);
await _signInManager.SignInAsync(user, isPersistent: false);
}
/// <summary>
/// Renders the post-registration welcome page. Requires authentication (the user was just signed
/// in by <see cref="FinalizeRegistrationAsync"/>). Determines whether the account is on a free
@@ -0,0 +1 @@
global using Xunit;
@@ -9,6 +9,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore.InMemory" Version="8.0.11" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageReference Include="Moq" Version="4.20.72" />
<PackageReference Include="xunit" Version="2.9.2" />
@@ -26,6 +27,7 @@
<ProjectReference Include="..\..\src\PowderCoating.Core\PowderCoating.Core.csproj" />
<ProjectReference Include="..\..\src\PowderCoating.Application\PowderCoating.Application.csproj" />
<ProjectReference Include="..\..\src\PowderCoating.Infrastructure\PowderCoating.Infrastructure.csproj" />
<ProjectReference Include="..\..\src\PowderCoating.Web\PowderCoating.Web.csproj" />
</ItemGroup>
</Project>
@@ -0,0 +1,244 @@
using Microsoft.Extensions.Logging;
using Moq;
using PowderCoating.Application.DTOs.Quote;
using PowderCoating.Application.Services;
using PowderCoating.Core.Entities;
using PowderCoating.Core.Interfaces;
using Xunit;
namespace PowderCoating.UnitTests;
public class PricingCalculationServiceTests
{
[Fact]
public async Task CalculateCoatPriceAsync_CustomPowder_ChargesFullOrderQuantity()
{
var costs = CreateOperatingCosts();
var unitOfWork = CreateUnitOfWorkMock(costs);
var tenantContext = new Mock<ITenantContext>();
tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false);
var service = new PricingCalculationService(
unitOfWork.Object,
Mock.Of<ILogger<PricingCalculationService>>(),
new MeasurementConversionService(),
tenantContext.Object);
var coat = new CreateQuoteItemCoatDto
{
CoatName = "Custom Red",
PowderCostPerLb = 10m,
PowderToOrder = 3m,
CoverageSqFtPerLb = 30m,
TransferEfficiency = 65m
};
var result = await service.CalculateCoatPriceAsync(
coat,
itemSurfaceAreaSqFt: 5m,
quantity: 2m,
coatIndex: 0,
estimatedMinutesBase: 15,
companyId: 1);
Assert.Equal(30m, result.CoatMaterialCost);
Assert.Equal(30m, result.CoatLaborCost);
Assert.Equal(60m, result.CoatTotalCost);
}
[Fact]
public async Task CalculateQuoteItemPriceAsync_LaborItem_UsesStandardLaborRate()
{
var costs = CreateOperatingCosts();
costs.StandardLaborRate = 80m;
var unitOfWork = CreateUnitOfWorkMock(costs);
var tenantContext = new Mock<ITenantContext>();
tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false);
var service = new PricingCalculationService(
unitOfWork.Object,
Mock.Of<ILogger<PricingCalculationService>>(),
new MeasurementConversionService(),
tenantContext.Object);
var item = new CreateQuoteItemDto
{
Description = "Shop labor",
IsLaborItem = true,
Quantity = 2.5m
};
var result = await service.CalculateQuoteItemPriceAsync(item, companyId: 1);
Assert.Equal(0m, result.MaterialCost);
Assert.Equal(200m, result.LaborCost);
Assert.Equal(80m, result.UnitPrice);
Assert.Equal(200m, result.TotalPrice);
}
[Fact]
public async Task CalculateQuoteItemPriceAsync_AiItem_UsesManualUnitPriceWithoutAdditionalCosts()
{
var unitOfWork = CreateUnitOfWorkMock(CreateOperatingCosts());
var tenantContext = new Mock<ITenantContext>();
tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false);
var service = new PricingCalculationService(
unitOfWork.Object,
Mock.Of<ILogger<PricingCalculationService>>(),
new MeasurementConversionService(),
tenantContext.Object);
var item = new CreateQuoteItemDto
{
Description = "AI wheel estimate",
IsAiItem = true,
ManualUnitPrice = 123m,
Quantity = 2m
};
var result = await service.CalculateQuoteItemPriceAsync(item, companyId: 1);
Assert.Equal(0m, result.MaterialCost);
Assert.Equal(0m, result.LaborCost);
Assert.Equal(123m, result.UnitPrice);
Assert.Equal(246m, result.TotalPrice);
}
[Fact]
public async Task CalculateQuoteTotalsAsync_AppliesTierDiscount_QuoteDiscount_RushFee_AndTax()
{
var costs = CreateOperatingCosts();
costs.StandardLaborRate = 100m;
costs.ShopSuppliesRate = 10m;
costs.RushChargeType = "Percentage";
costs.RushChargePercentage = 20m;
costs.TaxPercent = 5m;
costs.OvenOperatingCostPerHour = 0m;
costs.MonthlyRent = 0m;
costs.MonthlyUtilities = 0m;
var customerRepo = new Mock<IRepository<Customer>>();
customerRepo
.Setup(x => x.FindAsync(It.IsAny<System.Linq.Expressions.Expression<Func<Customer, bool>>>(), false, It.IsAny<System.Linq.Expressions.Expression<Func<Customer, object>>[]>()))
.ReturnsAsync(new[]
{
new Customer { Id = 1, CompanyId = 1, PricingTierId = 10 }
});
var pricingTierRepo = new Mock<IRepository<PricingTier>>();
pricingTierRepo
.Setup(x => x.GetByIdAsync(10, false, It.IsAny<System.Linq.Expressions.Expression<Func<PricingTier, object>>[]>()))
.ReturnsAsync(new PricingTier { Id = 10, CompanyId = 1, DiscountPercent = 10m });
var unitOfWork = CreateUnitOfWorkMock(costs);
unitOfWork.SetupGet(x => x.Customers).Returns(customerRepo.Object);
unitOfWork.SetupGet(x => x.PricingTiers).Returns(pricingTierRepo.Object);
var tenantContext = new Mock<ITenantContext>();
tenantContext.Setup(x => x.UseMetricSystemAsync()).ReturnsAsync(false);
var service = new PricingCalculationService(
unitOfWork.Object,
Mock.Of<ILogger<PricingCalculationService>>(),
new MeasurementConversionService(),
tenantContext.Object);
var items = new List<CreateQuoteItemDto>
{
new()
{
Description = "Labor item",
IsLaborItem = true,
Quantity = 2m
}
};
var result = await service.CalculateQuoteTotalsAsync(
items,
companyId: 1,
customerId: 1,
discountType: "FixedAmount",
discountValue: 5m,
isRushJob: true);
Assert.Equal(200m, result.ItemsSubtotal);
Assert.Equal(20m, result.ShopSuppliesAmount);
Assert.Equal(220m, result.SubtotalBeforeDiscount);
Assert.Equal(22m, result.PricingTierDiscountAmount);
Assert.Equal(5m, result.QuoteDiscountAmount);
Assert.Equal(193m, result.SubtotalAfterDiscount);
Assert.Equal(38.6m, result.RushFee);
Assert.Equal(11.58m, result.TaxAmount);
Assert.Equal(243.18m, result.Total);
}
private static Mock<IUnitOfWork> CreateUnitOfWorkMock(CompanyOperatingCosts costs)
{
var unitOfWork = new Mock<IUnitOfWork>();
var companyOperatingCostsRepo = new Mock<IRepository<CompanyOperatingCosts>>();
companyOperatingCostsRepo
.Setup(x => x.FindAsync(It.IsAny<System.Linq.Expressions.Expression<Func<CompanyOperatingCosts, bool>>>(), false, It.IsAny<System.Linq.Expressions.Expression<Func<CompanyOperatingCosts, object>>[]>()))
.ReturnsAsync(new[] { costs });
var inventoryRepo = new Mock<IRepository<InventoryItem>>();
inventoryRepo
.Setup(x => x.GetByIdAsync(It.IsAny<int>(), false, It.IsAny<System.Linq.Expressions.Expression<Func<InventoryItem, object>>[]>()))
.ReturnsAsync((InventoryItem?)null);
var catalogRepo = new Mock<IRepository<CatalogItem>>();
catalogRepo
.Setup(x => x.GetByIdAsync(It.IsAny<int>(), false, It.IsAny<System.Linq.Expressions.Expression<Func<CatalogItem, object>>[]>()))
.ReturnsAsync((CatalogItem?)null);
var customerRepo = new Mock<IRepository<Customer>>();
customerRepo
.Setup(x => x.FindAsync(It.IsAny<System.Linq.Expressions.Expression<Func<Customer, bool>>>(), false, It.IsAny<System.Linq.Expressions.Expression<Func<Customer, object>>[]>()))
.ReturnsAsync(Array.Empty<Customer>());
var pricingTierRepo = new Mock<IRepository<PricingTier>>();
pricingTierRepo
.Setup(x => x.GetByIdAsync(It.IsAny<int>(), false, It.IsAny<System.Linq.Expressions.Expression<Func<PricingTier, object>>[]>()))
.ReturnsAsync((PricingTier?)null);
unitOfWork.SetupGet(x => x.CompanyOperatingCosts).Returns(companyOperatingCostsRepo.Object);
unitOfWork.SetupGet(x => x.InventoryItems).Returns(inventoryRepo.Object);
unitOfWork.SetupGet(x => x.CatalogItems).Returns(catalogRepo.Object);
unitOfWork.SetupGet(x => x.Customers).Returns(customerRepo.Object);
unitOfWork.SetupGet(x => x.PricingTiers).Returns(pricingTierRepo.Object);
return unitOfWork;
}
private static CompanyOperatingCosts CreateOperatingCosts()
{
return new CompanyOperatingCosts
{
Id = 1,
CompanyId = 1,
StandardLaborRate = 60m,
AdditionalCoatLaborPercent = 50m,
OvenOperatingCostPerHour = 25m,
SandblasterCostPerHour = 20m,
CoatingBoothCostPerHour = 10m,
PowderCoatingCostPerSqFt = 1m,
PricingMode = PowderCoating.Core.Enums.PricingMode.MarkupOnMaterial,
GeneralMarkupPercentage = 20m,
TargetMarginPercent = 40m,
TaxPercent = 5m,
ShopSuppliesRate = 10m,
DefaultOvenCycleMinutes = 60,
RushChargeType = "Percentage",
RushChargePercentage = 15m,
RushChargeFixedAmount = 50m,
ShopMinimumCharge = 0m,
ComplexitySimplePercent = 0m,
ComplexityModeratePercent = 5m,
ComplexityComplexPercent = 15m,
ComplexityExtremePercent = 25m,
MonthlyBillableHours = 160
};
}
}
@@ -0,0 +1,191 @@
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Mvc.ViewFeatures;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
using Moq;
using PowderCoating.Application.Interfaces;
using PowderCoating.Core.Entities;
using PowderCoating.Infrastructure.Data;
using PowderCoating.Infrastructure.Repositories;
using PowderCoating.Web.Controllers;
using Xunit;
namespace PowderCoating.UnitTests;
public class RegistrationControllerTests
{
[Fact]
public async Task PaymentSuccess_WhenStripeSessionIsNotPaid_DoesNotBurnPendingSession()
{
await using var context = CreateContext();
context.PendingRegistrationSessions.Add(CreatePendingSession("token-1", "owner@example.com"));
await context.SaveChangesAsync();
var stripeService = new Mock<IStripeService>();
stripeService.Setup(x => x.IsRegistrationCheckoutPaidAsync("sess_unpaid")).ReturnsAsync(false);
var controller = CreateController(context, stripeService: stripeService);
var result = await controller.PaymentSuccess("sess_unpaid", "token-1");
var redirect = Assert.IsType<Microsoft.AspNetCore.Mvc.RedirectToActionResult>(result);
Assert.Equal("Index", redirect.ActionName);
Assert.False((await context.PendingRegistrationSessions.SingleAsync()).IsCompleted);
Assert.Contains("couldn't verify a completed payment", controller.TempData["Error"]?.ToString());
}
[Fact]
public async Task PaymentSuccess_WhenUserCreationFails_ReleasesPendingSessionAndDeletesCompany()
{
await using var context = CreateContext();
context.PendingRegistrationSessions.Add(CreatePendingSession("token-2", "owner2@example.com"));
await context.SaveChangesAsync();
var userManager = CreateUserManagerMock();
userManager.Setup(x => x.FindByEmailAsync("owner2@example.com")).ReturnsAsync((ApplicationUser?)null);
userManager.Setup(x => x.CreateAsync(It.IsAny<ApplicationUser>(), It.IsAny<string>()))
.ReturnsAsync(IdentityResult.Failed(new IdentityError { Description = "boom" }));
var stripeService = new Mock<IStripeService>();
stripeService.Setup(x => x.IsRegistrationCheckoutPaidAsync("sess_paid")).ReturnsAsync(true);
var controller = CreateController(context, userManager, stripeService: stripeService);
var result = await controller.PaymentSuccess("sess_paid", "token-2");
var redirect = Assert.IsType<Microsoft.AspNetCore.Mvc.RedirectToActionResult>(result);
Assert.Equal("Index", redirect.ActionName);
Assert.False((await context.PendingRegistrationSessions.SingleAsync()).IsCompleted);
Assert.Empty(context.Companies);
Assert.Contains("Please try the success link again", controller.TempData["Error"]?.ToString());
}
[Fact]
public async Task PaymentSuccess_WhenSessionAlreadyCompletedAndUserExists_SignsUserInAndRedirectsToWelcome()
{
await using var context = CreateContext();
context.PendingRegistrationSessions.Add(CreatePendingSession("token-3", "owner3@example.com", isCompleted: true));
await context.SaveChangesAsync();
var existingUser = new ApplicationUser
{
Id = "user-3",
Email = "owner3@example.com",
UserName = "owner3@example.com",
FirstName = "Terry",
LastName = "Tenant",
CompanyId = 3
};
var userManager = CreateUserManagerMock();
userManager.Setup(x => x.FindByEmailAsync("owner3@example.com")).ReturnsAsync(existingUser);
userManager.Setup(x => x.UpdateAsync(existingUser)).ReturnsAsync(IdentityResult.Success);
var signInManager = CreateSignInManagerMock(userManager.Object);
signInManager.Setup(x => x.SignInAsync(existingUser, false, null)).Returns(Task.CompletedTask).Verifiable();
var controller = CreateController(context, userManager, signInManager.Object);
var result = await controller.PaymentSuccess("sess_complete", "token-3");
var redirect = Assert.IsType<Microsoft.AspNetCore.Mvc.RedirectToActionResult>(result);
Assert.Equal("Welcome", redirect.ActionName);
signInManager.Verify(x => x.SignInAsync(existingUser, false, null), Times.Once);
Assert.True((await context.PendingRegistrationSessions.SingleAsync()).IsCompleted);
}
private static RegistrationController CreateController(
ApplicationDbContext context,
Mock<UserManager<ApplicationUser>>? userManager = null,
SignInManager<ApplicationUser>? signInManager = null,
Mock<IStripeService>? stripeService = null)
{
var unitOfWork = new UnitOfWork(context);
var userManagerMock = userManager ?? CreateUserManagerMock();
var signInManagerInstance = signInManager ?? CreateSignInManagerMock(userManagerMock.Object).Object;
var platformSettings = new Mock<IPlatformSettingsService>();
platformSettings.Setup(x => x.GetAsync(It.IsAny<string>())).ReturnsAsync((string?)null);
var controller = new RegistrationController(
unitOfWork,
context,
userManagerMock.Object,
signInManagerInstance,
Mock.Of<ISeedDataService>(),
Mock.Of<IAdminNotificationService>(),
Mock.Of<IInAppNotificationService>(),
platformSettings.Object,
(stripeService ?? new Mock<IStripeService>()).Object,
Mock.Of<IEmailService>(),
Mock.Of<ILogger<RegistrationController>>());
var httpContext = new DefaultHttpContext();
controller.ControllerContext = new Microsoft.AspNetCore.Mvc.ControllerContext
{
HttpContext = httpContext
};
controller.TempData = new TempDataDictionary(httpContext, Mock.Of<ITempDataProvider>());
return controller;
}
private static Mock<UserManager<ApplicationUser>> CreateUserManagerMock()
{
var store = new Mock<IUserStore<ApplicationUser>>();
return new Mock<UserManager<ApplicationUser>>(
store.Object,
null!,
null!,
null!,
null!,
null!,
null!,
null!,
null!);
}
private static Mock<SignInManager<ApplicationUser>> CreateSignInManagerMock(UserManager<ApplicationUser> userManager)
{
var contextAccessor = new Mock<IHttpContextAccessor>();
contextAccessor.Setup(x => x.HttpContext).Returns(new DefaultHttpContext());
var claimsFactory = new Mock<IUserClaimsPrincipalFactory<ApplicationUser>>();
return new Mock<SignInManager<ApplicationUser>>(
userManager,
contextAccessor.Object,
claimsFactory.Object,
null!,
null!,
null!,
null!);
}
private static ApplicationDbContext CreateContext()
{
var options = new DbContextOptionsBuilder<ApplicationDbContext>()
.UseInMemoryDatabase(Guid.NewGuid().ToString())
.Options;
return new ApplicationDbContext(options);
}
private static PendingRegistrationSession CreatePendingSession(string token, string email, bool isCompleted = false)
{
return new PendingRegistrationSession
{
Token = token,
CompanyName = "Retry Co",
CompanyPhone = "555-0100",
FirstName = "Pat",
LastName = "Owner",
Email = email,
Plan = 1,
IsAnnual = false,
IsCompleted = isCompleted,
CreatedAt = DateTime.UtcNow
};
}
}
@@ -0,0 +1,221 @@
using Microsoft.EntityFrameworkCore;
using PowderCoating.Core.Entities;
using PowderCoating.Core.Enums;
using PowderCoating.Infrastructure.Data;
using PowderCoating.Infrastructure.Repositories;
using PowderCoating.Infrastructure.Services;
using Xunit;
namespace PowderCoating.UnitTests;
public class SubscriptionServiceTests
{
[Fact]
public async Task GetUserCountAsync_PrefersCompanyOverrideOverPlanDefault()
{
await using var context = CreateContext();
SeedCompanyAndPlan(context, companyId: 7, plan: 1, maxUsers: 3);
var company = context.Companies.Local.Single(c => c.Id == 7);
company.MaxUsersOverride = 7;
context.Users.AddRange(
new ApplicationUser { Id = "u1", CompanyId = 7, UserName = "u1", Email = "u1@example.com", FirstName = "A", LastName = "One", IsActive = true },
new ApplicationUser { Id = "u2", CompanyId = 7, UserName = "u2", Email = "u2@example.com", FirstName = "B", LastName = "Two", IsActive = true });
await context.SaveChangesAsync();
var service = new SubscriptionService(new UnitOfWork(context), context);
var (used, max) = await service.GetUserCountAsync(7);
Assert.Equal(2, used);
Assert.Equal(7, max);
}
[Fact]
public async Task GetJobCountAsync_ExcludesTerminalStatuses()
{
await using var context = CreateContext();
SeedCompanyAndPlan(context, companyId: 8, plan: 2, maxActiveJobs: 50);
SeedJobStatuses(context, 8);
context.Jobs.AddRange(
new Job { Id = 1, CompanyId = 8, JobNumber = "JOB-1", CustomerId = 1, Description = "Active", JobStatusId = 1, JobPriorityId = 1 },
new Job { Id = 2, CompanyId = 8, JobNumber = "JOB-2", CustomerId = 1, Description = "Done", JobStatusId = 2, JobPriorityId = 1 },
new Job { Id = 3, CompanyId = 8, JobNumber = "JOB-3", CustomerId = 1, Description = "Delivered", JobStatusId = 3, JobPriorityId = 1 });
await context.SaveChangesAsync();
var service = new SubscriptionService(new UnitOfWork(context), context);
var (used, max) = await service.GetJobCountAsync(8);
Assert.Equal(1, used);
Assert.Equal(50, max);
}
[Fact]
public async Task GetQuoteCountAsync_CountsOnlyCurrentMonth()
{
await using var context = CreateContext();
SeedCompanyAndPlan(context, companyId: 9, plan: 3, maxQuotes: 5);
var currentQuote = new Quote
{
Id = 1,
CompanyId = 9,
QuoteNumber = "Q-001",
QuoteStatusId = 1
};
var oldQuote = new Quote
{
Id = 2,
CompanyId = 9,
QuoteNumber = "Q-OLD",
QuoteStatusId = 1
};
context.Quotes.AddRange(currentQuote, oldQuote);
await context.SaveChangesAsync();
oldQuote.CreatedAt = DateTime.UtcNow.AddMonths(-1);
await context.SaveChangesAsync();
var service = new SubscriptionService(new UnitOfWork(context), context);
var (used, max) = await service.GetQuoteCountAsync(9);
Assert.Equal(1, used);
Assert.Equal(5, max);
}
[Fact]
public async Task CanAddCustomerAsync_CompedCompany_BypassesPlanLimits()
{
await using var context = CreateContext();
SeedCompanyAndPlan(context, companyId: 10, plan: 4, maxCustomers: 0);
var company = await context.Companies.FindAsync(10);
company!.IsComped = true;
context.Customers.Add(new Customer { Id = 1, CompanyId = 10, CompanyName = "Customer A" });
await context.SaveChangesAsync();
var service = new SubscriptionService(new UnitOfWork(context), context);
var allowed = await service.CanAddCustomerAsync(10);
Assert.True(allowed);
}
[Fact]
public async Task CanUseAiPhotoQuoteAsync_RequiresFeatureEnabledAndQuotaAvailable()
{
await using var context = CreateContext();
SeedCompanyAndPlan(context, companyId: 11, plan: 5, maxAiPhotoQuotesPerMonth: 2, allowAiPhotoQuotes: true);
context.AiItemPredictions.Add(new AiItemPrediction { Id = 1, CompanyId = 11, CreatedAt = DateTime.UtcNow.AddDays(-1) });
await context.SaveChangesAsync();
var service = new SubscriptionService(new UnitOfWork(context), context);
var allowed = await service.CanUseAiPhotoQuoteAsync(11);
Assert.True(allowed);
}
[Fact]
public async Task CanUseAiPhotoQuoteAsync_ReturnsFalse_WhenPlanDisablesFeature()
{
await using var context = CreateContext();
SeedCompanyAndPlan(context, companyId: 12, plan: 6, maxAiPhotoQuotesPerMonth: 10, allowAiPhotoQuotes: false);
await context.SaveChangesAsync();
var service = new SubscriptionService(new UnitOfWork(context), context);
var allowed = await service.CanUseAiPhotoQuoteAsync(12);
Assert.False(allowed);
}
private static ApplicationDbContext CreateContext()
{
var options = new DbContextOptionsBuilder<ApplicationDbContext>()
.UseInMemoryDatabase(Guid.NewGuid().ToString())
.Options;
return new ApplicationDbContext(options);
}
private static void SeedCompanyAndPlan(
ApplicationDbContext context,
int companyId,
int plan,
int maxUsers = -1,
int maxActiveJobs = -1,
int maxCustomers = -1,
int maxQuotes = -1,
int maxAiPhotoQuotesPerMonth = -1,
bool allowAiPhotoQuotes = true)
{
context.Companies.Add(new Company
{
Id = companyId,
CompanyId = companyId,
CompanyName = $"Company {companyId}",
PrimaryContactName = "Owner",
PrimaryContactEmail = $"owner{companyId}@example.com",
SubscriptionPlan = plan,
SubscriptionStatus = SubscriptionStatus.Active,
IsActive = true
});
context.SubscriptionPlanConfigs.Add(new SubscriptionPlanConfig
{
Id = companyId,
CompanyId = 0,
Plan = plan,
DisplayName = $"Plan {plan}",
IsActive = true,
MaxUsers = maxUsers,
MaxActiveJobs = maxActiveJobs,
MaxCustomers = maxCustomers,
MaxQuotes = maxQuotes,
MaxAiPhotoQuotesPerMonth = maxAiPhotoQuotesPerMonth,
AllowAiPhotoQuotes = allowAiPhotoQuotes
});
context.JobPriorityLookups.Add(new JobPriorityLookup
{
Id = companyId,
CompanyId = companyId,
PriorityCode = "NORMAL",
DisplayName = "Normal",
DisplayOrder = 1
});
}
private static void SeedJobStatuses(ApplicationDbContext context, int companyId)
{
context.JobStatusLookups.AddRange(
new JobStatusLookup
{
Id = 1,
CompanyId = companyId,
StatusCode = "Pending",
DisplayName = "Pending",
DisplayOrder = 1,
IsTerminalStatus = false
},
new JobStatusLookup
{
Id = 2,
CompanyId = companyId,
StatusCode = "Completed",
DisplayName = "Completed",
DisplayOrder = 2,
IsTerminalStatus = true
},
new JobStatusLookup
{
Id = 3,
CompanyId = companyId,
StatusCode = "Delivered",
DisplayName = "Delivered",
DisplayOrder = 3,
IsTerminalStatus = true
});
}
}