Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Osara-B committed Dec 19, 2024
1 parent 1d8ffdd commit 8313f3f
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 82 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, WSO2 LLC. (http://www.wso2.com).
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
Expand Down Expand Up @@ -87,12 +87,14 @@
import static org.wso2.carbon.identity.core.dao.SAMLSSOServiceProviderConstants.SAML2TableColumns.SP_ID;

import static java.time.ZoneOffset.UTC;
/**
* Implementation of the SAMLSSOServiceProviderDAO interface for JDBC-based persistence.
*/

public class JDBCSAMLSSOServiceProviderDAOImpl implements SAMLSSOServiceProviderDAO {

private static final Calendar CALENDAR = Calendar.getInstance(TimeZone.getTimeZone(UTC));
private static final Log log = LogFactory.getLog(JDBCSAMLSSOServiceProviderDAOImpl.class);
private int tenantId;
private static final Log LOG = LogFactory.getLog(JDBCSAMLSSOServiceProviderDAOImpl.class);
private static final String CERTIFICATE_PROPERTY_NAME = "CERTIFICATE";
private static final String QUERY_TO_GET_APPLICATION_CERTIFICATE_ID = "SELECT " +
"META.VALUE FROM SP_INBOUND_AUTH INBOUND, SP_APP SP, SP_METADATA META WHERE SP.ID = INBOUND.APP_ID AND " +
Expand All @@ -107,24 +109,24 @@ public JDBCSAMLSSOServiceProviderDAOImpl() {

@Override
public boolean addServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO, int tenantId) throws IdentityException {
this.tenantId = tenantId;

validateServiceProvider(serviceProviderDO);
try {
if (processIsServiceProviderExists(serviceProviderDO.getIssuer())) {
if (processIsServiceProviderExists(serviceProviderDO.getIssuer(),tenantId)) {
debugLog(serviceProviderInfo(serviceProviderDO) + " already exists.");
return false;
}
NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
namedJdbcTemplate.withTransaction(template -> {
processAddServiceProvider(serviceProviderDO);
processAddSPProperties(serviceProviderDO);
processAddServiceProvider(serviceProviderDO, tenantId);
processAddSPProperties(serviceProviderDO, tenantId);
return null;
});
debugLog(serviceProviderInfo(serviceProviderDO) + " is added successfully.");
return true;
} catch (TransactionException | DataAccessException e) {
String msg = "Error while adding " + serviceProviderInfo(serviceProviderDO);
log.error(msg, e);
LOG.error(msg, e);
throw new IdentityException(msg, e);
}
}
Expand All @@ -133,41 +135,39 @@ public boolean addServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO, in
public boolean updateServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO, String currentIssuer, int tenantId)
throws IdentityException {

this.tenantId = tenantId;
validateServiceProvider(serviceProviderDO);
String newIssuer = serviceProviderDO.getIssuer();
boolean isIssuerUpdated = !StringUtils.equals(currentIssuer, newIssuer);

try {
if (isIssuerUpdated && processIsServiceProviderExists(newIssuer)) {
if (isIssuerUpdated && processIsServiceProviderExists(newIssuer, tenantId)) {
debugLog(serviceProviderInfo(serviceProviderDO) + " already exists.");
return false;
}
int serviceProviderId = processGetServiceProviderId(currentIssuer);
int serviceProviderId = processGetServiceProviderId(currentIssuer, tenantId);
NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
namedJdbcTemplate.withTransaction(template -> {
processUpdateServiceProvider(serviceProviderDO, serviceProviderId);
processUpdateServiceProvider(serviceProviderDO, serviceProviderId, tenantId);
processUpdateSPProperties(serviceProviderDO, serviceProviderId);
return null;
});
debugLog(serviceProviderInfo(serviceProviderDO) + " is updated successfully.");
return true;
} catch (TransactionException | DataAccessException e) {
String msg = "Error while updating " + serviceProviderInfo(serviceProviderDO);
log.error(msg, e);
LOG.error(msg, e);
throw new IdentityException(msg, e);
}
}

@Override
public SAMLSSOServiceProviderDO[] getServiceProviders(int tenantId) throws IdentityException {

this.tenantId = tenantId;
List<SAMLSSOServiceProviderDO> serviceProvidersList;
try {
serviceProvidersList = processGetServiceProviders();
serviceProvidersList = processGetServiceProviders(tenantId);
} catch (DataAccessException e) {
log.error("Error reading Service Providers", e);
LOG.error("Error reading Service Providers", e);
throw new IdentityException("Error reading Service Providers", e);
}
return serviceProvidersList.toArray(new SAMLSSOServiceProviderDO[0]);
Expand All @@ -176,33 +176,31 @@ public SAMLSSOServiceProviderDO[] getServiceProviders(int tenantId) throws Ident
@Override
public boolean removeServiceProvider(String issuer, int tenantId) throws IdentityException {

this.tenantId = tenantId;
if (issuer == null || StringUtils.isEmpty(issuer.trim())) {
throw new IllegalArgumentException("Trying to delete issuer \'" + issuer + "\'");
}
try {
if (!processIsServiceProviderExists(issuer)) {
if (!processIsServiceProviderExists(issuer, tenantId)) {
debugLog("Service Provider with issuer " + issuer + " does not exist.");
return false;
}
processDeleteServiceProvider(issuer);
processDeleteServiceProvider(issuer, tenantId);
return true;
} catch (DataAccessException e) {
String msg = "Error removing the service provider with name: " + issuer;
log.error(msg, e);
LOG.error(msg, e);
throw new IdentityException(msg, e);
}
}

@Override
public SAMLSSOServiceProviderDO getServiceProvider(String issuer, int tenantId) throws IdentityException {

this.tenantId = tenantId;
SAMLSSOServiceProviderDO serviceProviderDO = null;

try {
if (isServiceProviderExists(issuer, tenantId)) {
serviceProviderDO = processGetServiceProvider(issuer);
serviceProviderDO = processGetServiceProvider(issuer, tenantId);
}
} catch (DataAccessException e) {
throw IdentityException.error(String.format(
Expand Down Expand Up @@ -233,12 +231,12 @@ public SAMLSSOServiceProviderDO getServiceProvider(String issuer, int tenantId)
@Override
public boolean isServiceProviderExists(String issuer, int tenantId) throws IdentityException {

this.tenantId = tenantId;

try {
return processIsServiceProviderExists(issuer);
return processIsServiceProviderExists(issuer, tenantId);
} catch (DataAccessException e) {
String msg = "Error while checking existence of Service Provider with issuer: " + issuer;
log.error(msg, e);
LOG.error(msg, e);
throw new IdentityException(msg, e);
}
}
Expand All @@ -247,41 +245,41 @@ public boolean isServiceProviderExists(String issuer, int tenantId) throws Ident
public SAMLSSOServiceProviderDO uploadServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO, int tenantId)
throws IdentityException {

this.tenantId = tenantId;

validateServiceProvider(serviceProviderDO);
if (serviceProviderDO.getDefaultAssertionConsumerUrl() == null) {
throw new IdentityException("No default assertion consumer URL provided for service provider :" +
serviceProviderDO.getIssuer());
}

try {
if (processIsServiceProviderExists(serviceProviderDO.getIssuer())) {
if (processIsServiceProviderExists(serviceProviderDO.getIssuer(), tenantId)) {
debugLog(serviceProviderInfo(serviceProviderDO) + " already exists.");
throw new IdentityException(serviceProviderInfo(serviceProviderDO) + " already exists.");
}
NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
namedJdbcTemplate.withTransaction(template -> {
processAddServiceProvider(serviceProviderDO);
processAddSPProperties(serviceProviderDO);
processAddServiceProvider(serviceProviderDO, tenantId);
processAddSPProperties(serviceProviderDO, tenantId);
return null;
});
debugLog(serviceProviderInfo(serviceProviderDO) + " is added successfully.");
return serviceProviderDO;
} catch (TransactionException | DataAccessException e) {
String msg = "Error while adding " + serviceProviderInfo(serviceProviderDO);
log.error(msg, e);
LOG.error(msg, e);
throw new IdentityException(msg, e);
}
}

private void debugLog(String message) {

if (log.isDebugEnabled()) {
log.debug(message);
if (LOG.isDebugEnabled()) {
LOG.debug(message);
}
}

private boolean processIsServiceProviderExists(String issuer) throws DataAccessException {
private boolean processIsServiceProviderExists(String issuer, int tenantId) throws DataAccessException {

NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
Integer serviceProviderId =
Expand Down Expand Up @@ -392,7 +390,7 @@ private void addProperties(int serviceProviderId, SAMLSSOServiceProviderDO servi
}

private void setUpdateServiceProviderParameters(NamedPreparedStatement statement,
SAMLSSOServiceProviderDO serviceProviderDO) throws SQLException {
SAMLSSOServiceProviderDO serviceProviderDO, int tenantId) throws SQLException {

statement.setInt(TENANT_ID, tenantId);
statement.setString(ISSUER, serviceProviderDO.getIssuer());
Expand Down Expand Up @@ -428,7 +426,7 @@ private void setUpdateServiceProviderParameters(NamedPreparedStatement statement
}

private void setServiceProviderParameters(NamedPreparedStatement statement,
SAMLSSOServiceProviderDO serviceProviderDO) throws SQLException {
SAMLSSOServiceProviderDO serviceProviderDO, int tenantId) throws SQLException {
Timestamp currentTime = new Timestamp(new Date().getTime());
statement.setInt(TENANT_ID, tenantId);
statement.setString(ISSUER, serviceProviderDO.getIssuer());
Expand Down Expand Up @@ -465,7 +463,7 @@ private void setServiceProviderParameters(NamedPreparedStatement statement,
statement.setTimeStamp(UPDATED_AT, currentTime, CALENDAR);
}

private int processGetServiceProviderId(String issuer) throws DataAccessException {
private int processGetServiceProviderId(String issuer, int tenantId) throws DataAccessException {

NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
Integer serviceProviderId =
Expand All @@ -480,18 +478,18 @@ private int processGetServiceProviderId(String issuer) throws DataAccessExceptio
return serviceProviderId.intValue();
}

private void processAddServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO) throws DataAccessException {
private void processAddServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO,int tenantId) throws DataAccessException {

NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
namedJdbcTemplate.executeInsert(SAMLSSOServiceProviderConstants.SQLQueries.ADD_SAML2_SSO_CONFIG,
namedPreparedStatement -> setServiceProviderParameters(namedPreparedStatement, serviceProviderDO),
namedPreparedStatement -> setServiceProviderParameters(namedPreparedStatement, serviceProviderDO, tenantId),
serviceProviderDO, false);
}

private void processAddSPProperties(SAMLSSOServiceProviderDO serviceProviderDO) throws DataAccessException {
private void processAddSPProperties(SAMLSSOServiceProviderDO serviceProviderDO, int tenantId) throws DataAccessException {

List<ServiceProviderProperty> properties = serviceProviderDO.getMultiValuedProperties();
int serviceProviderId = processGetServiceProviderId(serviceProviderDO.getIssuer());
int serviceProviderId = processGetServiceProviderId(serviceProviderDO.getIssuer(), tenantId);

NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();

Expand All @@ -506,14 +504,14 @@ private void processAddSPProperties(SAMLSSOServiceProviderDO serviceProviderDO)
}), serviceProviderDO);
}

private void processUpdateServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO, int serviceProviderId)
private void processUpdateServiceProvider(SAMLSSOServiceProviderDO serviceProviderDO, int serviceProviderId, int tenantId)
throws DataAccessException {

NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
namedJdbcTemplate.executeUpdate(SAMLSSOServiceProviderConstants.SQLQueries.UPDATE_SAML2_SSO_CONFIG,
namedPreparedStatement -> {
namedPreparedStatement.setInt(ID, serviceProviderId);
setUpdateServiceProviderParameters(namedPreparedStatement, serviceProviderDO);
setUpdateServiceProviderParameters(namedPreparedStatement, serviceProviderDO, tenantId);
});
}

Expand All @@ -537,7 +535,7 @@ private void processUpdateSPProperties(SAMLSSOServiceProviderDO serviceProviderD
}), serviceProviderDO);
}

private SAMLSSOServiceProviderDO processGetServiceProvider(String issuer) throws DataAccessException {
private SAMLSSOServiceProviderDO processGetServiceProvider(String issuer, int tenantId) throws DataAccessException {

NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
SAMLSSOServiceProviderDO serviceProviderDO = namedJdbcTemplate.fetchSingleRecord(
Expand All @@ -548,12 +546,12 @@ private SAMLSSOServiceProviderDO processGetServiceProvider(String issuer) throws
});

if (serviceProviderDO != null) {
addProperties(processGetServiceProviderId(issuer), serviceProviderDO);
addProperties(processGetServiceProviderId(issuer, tenantId), serviceProviderDO);
}
return serviceProviderDO;
}

private List<SAMLSSOServiceProviderDO> processGetServiceProviders() throws DataAccessException {
private List<SAMLSSOServiceProviderDO> processGetServiceProviders(int tenantId) throws DataAccessException {

List<SAMLSSOServiceProviderDO> serviceProvidersList;
NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();
Expand All @@ -563,12 +561,12 @@ private List<SAMLSSOServiceProviderDO> processGetServiceProviders() throws DataA
namedPreparedStatement -> namedPreparedStatement.setInt(TENANT_ID, tenantId));

for (SAMLSSOServiceProviderDO serviceProviderDO : serviceProvidersList) {
addProperties(processGetServiceProviderId(serviceProviderDO.getIssuer()), serviceProviderDO);
addProperties(processGetServiceProviderId(serviceProviderDO.getIssuer(), tenantId), serviceProviderDO);
}
return serviceProvidersList;
}

private void processDeleteServiceProvider(String issuer) throws DataAccessException {
private void processDeleteServiceProvider(String issuer, int tenantId) throws DataAccessException {

NamedJdbcTemplate namedJdbcTemplate = JdbcUtils.getNewNamedJdbcTemplate();

Expand Down Expand Up @@ -637,4 +635,4 @@ private int getApplicationCertificateId(String issuer, int tenantId) throws Data
return certificateId != null ? certificateId : -1;
}

}
}
Loading

0 comments on commit 8313f3f

Please sign in to comment.