using System.Globalization;
using System.Security.Claims;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Identity;
using Microsoft.EntityFrameworkCore;
using Moq;
using PowderCoating.Core.Entities;
using PowderCoating.Infrastructure.Data;
using PowderCoating.Infrastructure.Services;

namespace PowderCoating.UnitTests;

/// <summary>
/// Unit tests for <see cref="TenantContext"/>: resolving the current company id from claims,
/// session-based SuperAdmin impersonation, user-lookup fallback, platform-admin detection,
/// and company preference/company lookups.
/// </summary>
public class TenantContextTests
{
    // Session key the tests write to simulate a SuperAdmin impersonating a company.
    private const string ImpersonationSessionKey = "ImpersonatingCompanyId";

    // Role name the tests assign to platform-level administrators.
    private const string SuperAdminRole = "SuperAdmin";

    // Claim type CreatePrincipal uses to carry the company id.
    private const string CompanyIdClaimType = "CompanyId";

    [Fact]
    public void GetCurrentCompanyId_WhenUnauthenticated_ReturnsNull()
    {
        using var context = CreateContext();
        var userManager = CreateUserManagerMock();
        var accessor = CreateHttpContextAccessor(CreatePrincipal(isAuthenticated: false));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var companyId = tenantContext.GetCurrentCompanyId();

        Assert.Null(companyId);
    }

    [Fact]
    public void GetCurrentCompanyId_WhenSuperAdminIsImpersonating_ReturnsSessionOverride()
    {
        using var context = CreateContext();
        var userManager = CreateUserManagerMock();
        var session = new TestSession();
        session.SetInt32(ImpersonationSessionKey, 42);
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, name: "admin@example.com", roles: [SuperAdminRole]),
            session);
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var companyId = tenantContext.GetCurrentCompanyId();

        Assert.Equal(42, companyId);
    }

    [Fact]
    public void GetCurrentCompanyId_PrefersCompanyClaim()
    {
        using var context = CreateContext();
        // Seed a user whose stored CompanyId (17) disagrees with the claim (9). With an empty
        // user source the fallback could never produce a competing value, so the test could
        // not tell "claim preferred" apart from "claim is the only source".
        context.Users.Add(new ApplicationUser
        {
            Id = "user-9",
            UserName = "user@example.com",
            Email = "user@example.com",
            FirstName = "Claim",
            LastName = "User",
            CompanyId = 17
        });
        context.SaveChanges();
        var userManager = CreateUserManagerMock();
        userManager.Setup(x => x.Users).Returns(context.Users);
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, name: "user@example.com", companyIdClaim: 9));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var companyId = tenantContext.GetCurrentCompanyId();

        Assert.Equal(9, companyId);
    }

    [Fact]
    public async Task GetCurrentCompanyId_WhenClaimMissing_FallsBackToUserLookup()
    {
        await using var context = CreateContext();
        context.Users.Add(new ApplicationUser
        {
            Id = "user-1",
            UserName = "legacy@example.com",
            Email = "legacy@example.com",
            FirstName = "Legacy",
            LastName = "User",
            CompanyId = 17
        });
        await context.SaveChangesAsync();
        var userManager = CreateUserManagerMock();
        userManager.Setup(x => x.Users).Returns(context.Users);
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, name: "legacy@example.com"));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var companyId = tenantContext.GetCurrentCompanyId();

        Assert.Equal(17, companyId);
    }

    [Fact]
    public void IsPlatformAdmin_ReturnsTrue_ForSuperAdminWithoutTenantScope()
    {
        using var context = CreateContext();
        var userManager = CreateUserManagerMock();
        userManager.Setup(x => x.Users).Returns(Enumerable.Empty<ApplicationUser>().AsQueryable());
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, roles: [SuperAdminRole]));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var isPlatformAdmin = tenantContext.IsPlatformAdmin();

        Assert.True(isPlatformAdmin);
    }

    [Fact]
    public void IsPlatformAdmin_ReturnsFalse_ForSuperAdminImpersonatingCompany()
    {
        using var context = CreateContext();
        var userManager = CreateUserManagerMock();
        var session = new TestSession();
        session.SetInt32(ImpersonationSessionKey, 2);
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, name: "admin@example.com", roles: [SuperAdminRole]),
            session);
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var isPlatformAdmin = tenantContext.IsPlatformAdmin();

        Assert.False(isPlatformAdmin);
    }

    [Fact]
    public async Task UseMetricSystemAsync_ReturnsStoredPreference()
    {
        await using var context = CreateContext();
        context.CompanyPreferences.Add(new CompanyPreferences { Id = 1, CompanyId = 25, UseMetricSystem = true });
        await context.SaveChangesAsync();
        var userManager = CreateUserManagerMock();
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, name: "metric@example.com", companyIdClaim: 25));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var useMetric = await tenantContext.UseMetricSystemAsync();

        Assert.True(useMetric);
    }

    [Fact]
    public async Task GetCurrentCompanyAsync_ReturnsCompanyFromUserManager()
    {
        await using var context = CreateContext();
        var company = new Company
        {
            Id = 31,
            CompanyId = 31,
            CompanyName = "Current Co",
            PrimaryContactName = "Owner",
            PrimaryContactEmail = "owner@example.com"
        };
        var principal = CreatePrincipal(isAuthenticated: true, name: "current@example.com", companyIdClaim: 31);
        var user = new ApplicationUser
        {
            Id = "user-31",
            UserName = "current@example.com",
            Email = "current@example.com",
            FirstName = "Current",
            LastName = "User",
            CompanyId = 31,
            Company = company
        };
        var userManager = CreateUserManagerMock();
        userManager.Setup(x => x.GetUserAsync(principal)).ReturnsAsync(user);
        var accessor = CreateHttpContextAccessor(principal);
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var currentCompany = await tenantContext.GetCurrentCompanyAsync();

        Assert.NotNull(currentCompany);
        Assert.Equal("Current Co", currentCompany!.CompanyName);
    }

    [Fact]
    public void IsPlatformAdmin_ReturnsTrue_ForSuperAdminOnCompanyOne()
    {
        using var context = CreateContext();
        var userManager = CreateUserManagerMock();
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, name: "platform@example.com", companyIdClaim: 1, roles: [SuperAdminRole]));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var isPlatformAdmin = tenantContext.IsPlatformAdmin();

        Assert.True(isPlatformAdmin);
    }

    [Fact]
    public async Task UseMetricSystemAsync_WhenNoCompanyContext_ReturnsFalse()
    {
        await using var context = CreateContext();
        var userManager = CreateUserManagerMock();
        var accessor = CreateHttpContextAccessor(
            CreatePrincipal(isAuthenticated: true, name: "nocompany@example.com"));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var useMetric = await tenantContext.UseMetricSystemAsync();

        Assert.False(useMetric);
    }

    [Fact]
    public async Task GetCurrentCompanyAsync_WhenUnauthenticated_ReturnsNull()
    {
        await using var context = CreateContext();
        var userManager = CreateUserManagerMock();
        var accessor = CreateHttpContextAccessor(CreatePrincipal(isAuthenticated: false));
        var tenantContext = new TenantContext(accessor.Object, userManager.Object, context);

        var currentCompany = await tenantContext.GetCurrentCompanyAsync();

        Assert.Null(currentCompany);
    }

    /// <summary>
    /// Builds an <see cref="IHttpContextAccessor"/> mock whose <see cref="HttpContext"/> exposes
    /// the given principal and session.
    /// </summary>
    /// <param name="principal">The user returned by <c>HttpContext.User</c>.</param>
    /// <param name="session">The session returned by <c>HttpContext.Session</c>; a fresh empty
    /// <see cref="TestSession"/> is used when <see langword="null"/>.</param>
    private static Mock<IHttpContextAccessor> CreateHttpContextAccessor(ClaimsPrincipal principal, ISession? session = null)
    {
        var httpContext = new Mock<HttpContext>();
        httpContext.SetupGet(x => x.User).Returns(principal);
        httpContext.SetupGet(x => x.Session).Returns(session ?? new TestSession());

        var accessor = new Mock<IHttpContextAccessor>();
        accessor.SetupGet(x => x.HttpContext).Returns(httpContext.Object);
        return accessor;
    }

    /// <summary>
    /// Creates a test principal.
    /// </summary>
    /// <param name="isAuthenticated">When <see langword="false"/>, returns a principal with an
    /// identity that has no authentication type (and therefore is not authenticated); all other
    /// arguments are ignored.</param>
    /// <param name="name">Optional <see cref="ClaimTypes.Name"/> claim; skipped when null or whitespace.</param>
    /// <param name="companyIdClaim">Optional value for the "CompanyId" claim.</param>
    /// <param name="roles">Optional <see cref="ClaimTypes.Role"/> claims.</param>
    private static ClaimsPrincipal CreatePrincipal(
        bool isAuthenticated,
        string? name = null,
        int? companyIdClaim = null,
        string[]? roles = null)
    {
        if (!isAuthenticated)
        {
            return new ClaimsPrincipal(new ClaimsIdentity());
        }

        var claims = new List<Claim>();
        if (!string.IsNullOrWhiteSpace(name))
        {
            claims.Add(new Claim(ClaimTypes.Name, name));
        }

        if (companyIdClaim.HasValue)
        {
            // Claim values are machine-readable; format culture-invariantly (CA1305).
            claims.Add(new Claim(CompanyIdClaimType, companyIdClaim.Value.ToString(CultureInfo.InvariantCulture)));
        }

        foreach (var role in roles ?? [])
        {
            claims.Add(new Claim(ClaimTypes.Role, role));
        }

        // A non-null authentication type makes the identity authenticated; name/role claim
        // types are set explicitly so Identity.Name and IsInRole read the claims added above.
        var identity = new ClaimsIdentity(claims, "TestAuth", ClaimTypes.Name, ClaimTypes.Role);
        return new ClaimsPrincipal(identity);
    }

    /// <summary>
    /// Creates a <see cref="UserManager{TUser}"/> mock backed by a mocked store. Every other
    /// constructor dependency is passed as null, so tests must set up any virtual member
    /// (e.g. <c>Users</c>, <c>GetUserAsync</c>) whose result they rely on.
    /// </summary>
    private static Mock<UserManager<ApplicationUser>> CreateUserManagerMock()
    {
        var store = new Mock<IUserStore<ApplicationUser>>();
        return new Mock<UserManager<ApplicationUser>>(
            store.Object, null!, null!, null!, null!, null!, null!, null!, null!);
    }

    /// <summary>
    /// Creates an <see cref="ApplicationDbContext"/> on a uniquely named EF Core in-memory
    /// database so tests never share state.
    /// </summary>
    private static ApplicationDbContext CreateContext()
    {
        var options = new DbContextOptionsBuilder<ApplicationDbContext>()
            .UseInMemoryDatabase(Guid.NewGuid().ToString())
            .Options;

        return new ApplicationDbContext(options);
    }

    /// <summary>
    /// Minimal in-memory <see cref="ISession"/> backed by a dictionary; always available, and
    /// load/commit are no-ops.
    /// </summary>
    private sealed class TestSession : ISession
    {
        private readonly Dictionary<string, byte[]> _values = new();

        public IEnumerable<string> Keys => _values.Keys;

        public string Id => "test-session";

        public bool IsAvailable => true;

        public void Clear() => _values.Clear();

        public Task CommitAsync(CancellationToken cancellationToken = default) => Task.CompletedTask;

        public Task LoadAsync(CancellationToken cancellationToken = default) => Task.CompletedTask;

        public void Remove(string key) => _values.Remove(key);

        public void Set(string key, byte[] value) => _values[key] = value;

        public bool TryGetValue(string key, out byte[] value) => _values.TryGetValue(key, out value!);
    }
}