marco.pms.api/Marco.Pms.Services/Service/RefreshTokenService.cs

293 lines
12 KiB
C#

using Marco.Pms.DataAccess.Data;
using Marco.Pms.Model.Authentication;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.IdentityModel.Tokens;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Text;
#nullable disable
namespace MarcoBMS.Services.Service
{
/// <summary>
/// Issues, persists, revokes and validates the JWTs used for authentication: short-lived access
/// tokens, refresh tokens (also stored in <c>_context.RefreshTokens</c>) and MPIN tokens.
/// Can additionally record access tokens in an in-memory cache for optional blacklisting.
/// </summary>
public class RefreshTokenService
{
    private readonly ApplicationDbContext _context;
    private readonly IMemoryCache _cache; // For optional JWT blacklisting
    private readonly ILoggingService _logger;

    public RefreshTokenService(ApplicationDbContext context, IMemoryCache cache, ILoggingService logger)
    {
        _context = context;
        _cache = cache;
        _logger = logger;
    }

    /// <summary>
    /// Creates a signed (HMAC-SHA256) access token carrying <c>jti</c>, <c>sub</c> and an
    /// <c>OrganizationId</c> claim.
    /// </summary>
    /// <param name="username">Value of the <c>sub</c> claim.</param>
    /// <param name="organizationId">Value of the <c>OrganizationId</c> claim.</param>
    /// <param name="_jwtSettings">Signing key, issuer, audience and lifetime (<c>ExpiresInMinutes</c>).</param>
    /// <returns>The compact-serialized JWT.</returns>
    public string GenerateJwtTokenWithOrganization(string username, Guid organizationId, JwtSettings _jwtSettings)
    {
        return CreateAccessToken(username, "OrganizationId", organizationId.ToString(), _jwtSettings);
    }

    /// <summary>
    /// Creates a refresh token JWT (claims: user id, <c>OrganizationId</c>, <c>token_type=refresh</c>)
    /// and saves it to the <c>RefreshTokens</c> table.
    /// </summary>
    /// <param name="userId">Value of the <see cref="ClaimTypes.NameIdentifier"/> claim and the stored <c>UserId</c>.</param>
    /// <param name="organizationId">Value of the <c>OrganizationId</c> claim.</param>
    /// <param name="jwtSettings">Signing key, issuer, audience and <c>RefreshTokenExpiresInDays</c>.</param>
    /// <returns>The compact-serialized refresh token.</returns>
    /// <remarks>Any exception is logged and rethrown.</remarks>
    public async Task<string> CreateRefreshTokenWithOrganization(string userId, Guid organizationId, JwtSettings jwtSettings)
    {
        return await CreateAndStoreRefreshTokenAsync(userId, "OrganizationId", organizationId.ToString(), jwtSettings);
    }

    /// <summary>
    /// Creates a signed (HMAC-SHA256) access token carrying <c>jti</c>, <c>sub</c> and a
    /// <c>TenantId</c> claim.
    /// </summary>
    /// <param name="username">Value of the <c>sub</c> claim.</param>
    /// <param name="tenantId">Value of the <c>TenantId</c> claim.</param>
    /// <param name="_jwtSettings">Signing key, issuer, audience and lifetime (<c>ExpiresInMinutes</c>).</param>
    /// <returns>The compact-serialized JWT.</returns>
    public string GenerateJwtToken(string username, Guid tenantId, JwtSettings _jwtSettings)
    {
        return CreateAccessToken(username, "TenantId", tenantId.ToString(), _jwtSettings);
    }

    /// <summary>
    /// Creates a refresh token JWT (claims: user id, <c>TenantId</c>, <c>token_type=refresh</c>)
    /// and saves it to the <c>RefreshTokens</c> table.
    /// </summary>
    /// <param name="userId">Value of the <see cref="ClaimTypes.NameIdentifier"/> claim and the stored <c>UserId</c>.</param>
    /// <param name="tenantId">Value of the <c>TenantId</c> claim.</param>
    /// <param name="jwtSettings">Signing key, issuer, audience and <c>RefreshTokenExpiresInDays</c>.</param>
    /// <returns>The compact-serialized refresh token.</returns>
    /// <remarks>Any exception is logged and rethrown.</remarks>
    public async Task<string> CreateRefreshToken(string userId, string tenantId, JwtSettings jwtSettings)
    {
        return await CreateAndStoreRefreshTokenAsync(userId, "TenantId", tenantId, jwtSettings);
    }

    /// <summary>
    /// Creates an MPIN token JWT (claims: user id, <c>OrganizationId</c>, <c>token_type=mpin</c>).
    /// The token is not persisted.
    /// </summary>
    /// <param name="userId">Value of the <see cref="ClaimTypes.NameIdentifier"/> claim.</param>
    /// <param name="organizationId">Value of the <c>OrganizationId</c> claim.</param>
    /// <param name="jwtSettings">Signing key, issuer and audience.</param>
    /// <returns>The compact-serialized MPIN token.</returns>
    /// <remarks>Any exception is logged and rethrown.</remarks>
    public string CreateMPINToken(string userId, string organizationId, JwtSettings jwtSettings)
    {
        try
        {
            var claims = new[]
            {
                new Claim(ClaimTypes.NameIdentifier, userId),
                new Claim("OrganizationId", organizationId),
                new Claim("token_type", "mpin")
            };
            var key = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(jwtSettings.Key));
            var creds = new SigningCredentials(key, SecurityAlgorithms.HmacSha256Signature);
            var tokenDescriptor = new SecurityTokenDescriptor
            {
                Subject = new ClaimsIdentity(claims),
                Issuer = jwtSettings.Issuer,
                Audience = jwtSettings.Audience,
                SigningCredentials = creds
                // No Expires is set here, but JwtSecurityTokenHandler.SetDefaultTimesOnTokenCreation
                // defaults to true, so CreateToken still stamps iat/nbf/exp using the handler's
                // TokenLifetimeInMinutes (60 by default): the token DOES carry an expiry.
                // ValidateToken in this class ignores lifetime, so it will accept the token regardless.
            };
            var tokenHandler = new JwtSecurityTokenHandler();
            var MPINToken = tokenHandler.WriteToken(tokenHandler.CreateToken(tokenDescriptor));
            return MPINToken;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error creating MPIN token for userId: {UserId}, organizationId: {OrganizationId}", userId, organizationId);
            throw;
        }
    }

    /// <summary>
    /// Looks up a stored refresh token by its exact string, ignoring rows that are revoked or already used.
    /// </summary>
    /// <param name="token">The compact-serialized refresh token.</param>
    /// <returns>
    /// The matching entity, or a new empty <see cref="RefreshToken"/> (never <c>null</c>) when there is
    /// no match. <c>ExpiryDate</c> is NOT checked here.
    /// NOTE(review): callers presumably detect "not found" via an unset <c>Token</c>; confirm.
    /// </returns>
    public async Task<RefreshToken> GetRefreshToken(string token)
    {
        return await _context.RefreshTokens.FirstOrDefaultAsync(rt => rt.Token == token && !rt.IsRevoked && !rt.IsUsed) ?? new RefreshToken();
    }

    /// <summary>Sets <c>IsUsed</c> on the given refresh token and saves the change.</summary>
    public async Task MarkRefreshTokenAsUsed(RefreshToken refreshToken)
    {
        refreshToken.IsUsed = true;
        _context.RefreshTokens.Update(refreshToken);
        await _context.SaveChangesAsync();
    }

    /// <summary>Sets <c>IsRevoked</c> and <c>RevokedAt</c> (UTC now) on the given refresh token and saves the change.</summary>
    public async Task RevokeRefreshToken(RefreshToken refreshToken)
    {
        refreshToken.IsRevoked = true;
        refreshToken.RevokedAt = DateTime.UtcNow;
        _context.RefreshTokens.Update(refreshToken);
        await _context.SaveChangesAsync();
    }

    /// <summary>Revokes the stored refresh token whose string equals <paramref name="refreshToken"/>.</summary>
    /// <returns>
    /// <c>true</c> if a token was revoked; <c>false</c> if it does not exist, is already revoked, or has expired.
    /// </returns>
    public async Task<bool> RevokeRefreshTokenAsync(string refreshToken)
    {
        var token = await _context.RefreshTokens.FirstOrDefaultAsync(t => t.Token == refreshToken);
        if (token == null || token.IsRevoked || token.ExpiryDate <= DateTime.UtcNow)
            return false;
        token.IsRevoked = true;
        token.RevokedAt = DateTime.UtcNow;
        await _context.SaveChangesAsync();
        return true;
    }

    /// <summary>
    /// Optional: records <paramref name="jwtToken"/> in the memory cache (key = raw token string,
    /// value = <c>true</c>) until the token's own expiry.
    /// </summary>
    /// <remarks>
    /// Tokens that cannot be read, carry no <c>exp</c>, or have already expired are skipped.
    /// <see cref="MemoryCacheEntryOptions.AbsoluteExpirationRelativeToNow"/> rejects non-positive values,
    /// and such tokens need no blacklisting anyway.
    /// </remarks>
    public Task BlacklistJwtTokenAsync(string jwtToken)
    {
        var jwtExpiry = GetJwtExpiry(jwtToken);
        if (jwtExpiry.HasValue)
        {
            var timeToLive = jwtExpiry.Value - DateTime.UtcNow;
            if (timeToLive > TimeSpan.Zero)
            {
                _cache.Set(jwtToken, true, timeToLive);
            }
        }
        return Task.CompletedTask;
    }

    /// <summary>
    /// Reads (without validating) the <c>exp</c> of a JWT as a UTC <see cref="DateTime"/>.
    /// </summary>
    /// <returns>
    /// <c>null</c> when the string is not a readable JWT; <see cref="DateTime.MinValue"/> when it has no <c>exp</c>.
    /// </returns>
    private DateTime? GetJwtExpiry(string token)
    {
        var handler = new JwtSecurityTokenHandler();
        // CanReadToken is false for null/blank/malformed input, which would otherwise make ReadToken throw.
        if (!handler.CanReadToken(token))
        {
            return null;
        }
        var jwtToken = handler.ReadToken(token) as JwtSecurityToken;
        return jwtToken?.ValidTo;
    }

    /// <summary>
    /// Validates a token's signature, issuer and audience and returns its principal.
    /// </summary>
    /// <remarks>
    /// Lifetime is deliberately NOT validated, so expired tokens are accepted; callers that care
    /// about expiry must check it themselves.
    /// </remarks>
    /// <returns>The principal, or <c>null</c> (after logging) when validation fails.</returns>
    public ClaimsPrincipal ValidateToken(string token, JwtSettings jwtSettings)
    {
        var tokenHandler = new JwtSecurityTokenHandler();
        // UTF-8 to match how every token in this class is signed. ASCII would map non-ASCII key
        // characters to '?', producing a different key and failing every signature check.
        var key = Encoding.UTF8.GetBytes(jwtSettings.Key);
        var validationParameters = new TokenValidationParameters
        {
            ValidateIssuerSigningKey = true,
            IssuerSigningKey = new SymmetricSecurityKey(key),
            ValidateIssuer = true,
            ValidIssuer = jwtSettings.Issuer,
            ValidateAudience = true,
            ValidAudience = jwtSettings.Audience,
            ValidateLifetime = false, // Disable lifetime validation (ignores expiration)
            ClockSkew = TimeSpan.Zero // Optional: Remove time skew buffer
        };
        try
        {
            return tokenHandler.ValidateToken(token, validationParameters, out _);
        }
        catch (Exception ex)
        {
            // Token is invalid
            _logger.LogError(ex, "Token validation failed");
            return null;
        }
    }

    /// <summary>
    /// Shared implementation of the access-token generators. Signs <c>jti</c>, <c>sub</c> and one
    /// scope claim (tenant or organization) with HMAC-SHA256, expiring after <c>ExpiresInMinutes</c>.
    /// </summary>
    private static string CreateAccessToken(string username, string scopeClaimType, string scopeClaimValue, JwtSettings jwtSettings)
    {
        // Exactly one jti: a repeated claim type is serialized as a JSON array, but RFC 7519 §4.1.7
        // defines jti as a single string identifier.
        var claims = new List<Claim>
        {
            new Claim(JwtRegisteredClaimNames.Jti, Guid.NewGuid().ToString()),
            new Claim(JwtRegisteredClaimNames.Sub, username),
            new Claim(scopeClaimType, scopeClaimValue)
        };
        var key = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(jwtSettings.Key));
        var credentials = new SigningCredentials(key, SecurityAlgorithms.HmacSha256);
        var token = new JwtSecurityToken(
            issuer: jwtSettings.Issuer,
            audience: jwtSettings.Audience,
            claims: claims,
            expires: DateTime.UtcNow.AddMinutes(jwtSettings.ExpiresInMinutes),
            signingCredentials: credentials);
        return new JwtSecurityTokenHandler().WriteToken(token);
    }

    /// <summary>
    /// Shared implementation of the refresh-token creators. Builds and signs the refresh JWT,
    /// saves a <see cref="RefreshToken"/> row for it, and returns the token string.
    /// Any exception is logged and rethrown.
    /// </summary>
    private async Task<string> CreateAndStoreRefreshTokenAsync(string userId, string scopeClaimType, string scopeClaimValue, JwtSettings jwtSettings)
    {
        try
        {
            var claims = new[]
            {
                new Claim(ClaimTypes.NameIdentifier, userId),
                new Claim(scopeClaimType, scopeClaimValue),
                new Claim("token_type", "refresh")
            };
            var key = new SymmetricSecurityKey(Encoding.UTF8.GetBytes(jwtSettings.Key));
            var credentials = new SigningCredentials(key, SecurityAlgorithms.HmacSha256Signature);

            // One instant for both the JWT "exp" and the persisted ExpiryDate, so they cannot drift apart.
            var expiresAt = DateTime.UtcNow.AddDays(jwtSettings.RefreshTokenExpiresInDays);
            var tokenDescriptor = new SecurityTokenDescriptor
            {
                Subject = new ClaimsIdentity(claims),
                Expires = expiresAt,
                Issuer = jwtSettings.Issuer,
                Audience = jwtSettings.Audience,
                SigningCredentials = credentials
            };
            var tokenHandler = new JwtSecurityTokenHandler();
            var refreshTokenString = tokenHandler.WriteToken(tokenHandler.CreateToken(tokenDescriptor));
            var refreshToken = new RefreshToken
            {
                Token = refreshTokenString,
                UserId = userId,
                ExpiryDate = expiresAt,
                IsRevoked = false
            };

            // NOTE(review): Update() is given a freshly built entity, not the user's existing row.
            // EF Core tracks it as Modified only if its primary key is set, and as Added when a
            // store-generated key is unset. Whether this branch overwrites the existing row or inserts
            // a new one therefore depends on RefreshToken's key configuration (not visible here) — confirm.
            var userHasToken = await _context.RefreshTokens.AnyAsync(c => c.UserId == userId);
            if (userHasToken)
            {
                _context.RefreshTokens.Update(refreshToken);
            }
            else
            {
                _context.RefreshTokens.Add(refreshToken);
            }
            await _context.SaveChangesAsync();
            return refreshTokenString;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error occured while creating new JWT token for user {UserId}", userId);
            throw;
        }
    }
}
}