Skip to content

Commit

Permalink
Merge pull request #29 from alanenriqueo/fix-public-ip-firewall
Browse files Browse the repository at this point in the history
Add support for flexible server and obtain runner's public IP address from ipify
  • Loading branch information
DaeunYim authored Aug 4, 2022
2 parents 66a9747 + 3928dca commit f82d2b2
Show file tree
Hide file tree
Showing 11 changed files with 330 additions and 90 deletions.
6 changes: 1 addition & 5 deletions lib/Constants.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
"use strict";
Object.defineProperty(exports, "__esModule", { value: true });
exports.PsqlConstants = exports.FirewallConstants = exports.FileConstants = void 0;
exports.PsqlConstants = exports.FileConstants = void 0;
class FileConstants {
}
exports.FileConstants = FileConstants;
// regex checks that string should end with .sql and if folderPath is present, * should not be included in folderPath
FileConstants.singleParentDirRegex = /^((?!\*\/).)*(\.sql)$/g;
class FirewallConstants {
}
exports.FirewallConstants = FirewallConstants;
FirewallConstants.ipv4MatchPattern = /\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/;
class PsqlConstants {
}
exports.PsqlConstants = PsqlConstants;
Expand Down
42 changes: 24 additions & 18 deletions lib/Utils/FirewallUtils/ResourceManager.js
Original file line number Diff line number Diff line change
Expand Up @@ -46,31 +46,27 @@ class AzurePSQLResourceManager {
getPSQLServer() {
return this._resource;
}
_populatePSQLServerData(serverName) {
_getPSQLServer(serverType, apiVersion, serverName) {
return __awaiter(this, void 0, void 0, function* () {
// trim the cloud hostname suffix from servername
serverName = serverName.split('.')[0];
const httpRequest = {
method: 'GET',
uri: this._restClient.getRequestUri('//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/servers', {}, [], '2017-12-01')
uri: this._restClient.getRequestUri(`//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/${serverType}`, {}, [], apiVersion)
};
core.debug(`Get PSQL server '${serverName}' details`);
core.debug(`Get '${serverName}' for PSQL ${serverType} details`);
try {
const httpResponse = yield this._restClient.beginRequest(httpRequest);
if (httpResponse.statusCode !== 200) {
throw AzureRestClient_1.ToError(httpResponse);
}
const sqlServers = httpResponse.body && httpResponse.body.value;
if (sqlServers && sqlServers.length > 0) {
this._resource = sqlServers.filter((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase())[0];
if (!this._resource) {
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
}
core.debug(JSON.stringify(this._resource));
}
else {
throw new Error(`Unable to get details of PSQL server ${serverName}. No PSQL servers were found in the subscription.`);
const sqlServers = ((httpResponse.body && httpResponse.body.value) || []);
const sqlServer = sqlServers.find((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase());
if (sqlServer) {
this._serverType = serverType;
this._apiVersion = apiVersion;
this._resource = sqlServer;
return true;
}
return false;
}
catch (error) {
if (error instanceof AzureRestClient_1.AzureError) {
Expand All @@ -80,12 +76,22 @@ class AzurePSQLResourceManager {
}
});
}
_populatePSQLServerData(serverName) {
return __awaiter(this, void 0, void 0, function* () {
// trim the cloud hostname suffix from servername
serverName = serverName.split('.')[0];
(yield this._getPSQLServer('servers', '2017-12-01', serverName)) || (yield this._getPSQLServer('flexibleServers', '2021-06-01', serverName));
if (!this._resource) {
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
}
});
}
addFirewallRule(startIpAddress, endIpAddress) {
return __awaiter(this, void 0, void 0, function* () {
const firewallRuleName = `ClientIPAddress_${Date.now()}`;
const httpRequest = {
method: 'PUT',
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRuleName}`, {}, [], '2017-12-01'),
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRuleName}`, {}, [], this._apiVersion),
body: JSON.stringify({
'properties': {
'startIpAddress': startIpAddress,
Expand Down Expand Up @@ -122,7 +128,7 @@ class AzurePSQLResourceManager {
return __awaiter(this, void 0, void 0, function* () {
const httpRequest = {
method: 'GET',
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${ruleName}`, {}, [], '2017-12-01')
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${ruleName}`, {}, [], this._apiVersion)
};
try {
const httpResponse = yield this._restClient.beginRequest(httpRequest);
Expand All @@ -143,7 +149,7 @@ class AzurePSQLResourceManager {
return __awaiter(this, void 0, void 0, function* () {
const httpRequest = {
method: 'DELETE',
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRule.name}`, {}, [], '2017-12-01')
uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRule.name}`, {}, [], this._apiVersion)
};
try {
const httpResponse = yield this._restClient.beginRequest(httpRequest);
Expand Down
18 changes: 10 additions & 8 deletions lib/Utils/PsqlUtils/PsqlUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ var __importDefault = (this && this.__importDefault) || function (mod) {
};
Object.defineProperty(exports, "__esModule", { value: true });
const Constants_1 = require("../../Constants");
const Constants_2 = require("../../Constants");
const PsqlToolRunner_1 = __importDefault(require("./PsqlToolRunner"));
const http_client_1 = require("@actions/http-client");
class PsqlUtils {
static detectIPAddress(connectionString) {
var _a;
return __awaiter(this, void 0, void 0, function* () {
let psqlError = '';
let ipAddress = '';
Expand All @@ -31,16 +32,17 @@ class PsqlUtils {
// "SELECT 1" psql command is run to check if psql client is able to connect to DB using the connectionString
try {
yield PsqlToolRunner_1.default.init();
yield PsqlToolRunner_1.default.executePsqlCommand(connectionString, Constants_1.PsqlConstants.SELECT_1, options);
yield PsqlToolRunner_1.default.executePsqlCommand(`${connectionString} connect_timeout=10`, Constants_1.PsqlConstants.SELECT_1, options);
}
catch (err) {
catch (_b) {
if (psqlError) {
const ipAddresses = psqlError.match(Constants_2.FirewallConstants.ipv4MatchPattern);
if (ipAddresses) {
ipAddress = ipAddresses[0];
const http = new http_client_1.HttpClient();
try {
const ipv4 = yield http.getJson('https://api.ipify.org?format=json');
ipAddress = ((_a = ipv4.result) === null || _a === void 0 ? void 0 : _a.ip) || '';
}
else {
throw new Error(`Unable to detect client IP Address: ${psqlError}`);
catch (err) {
throw new Error(`Unable to detect client IP Address: ${err.message}`);
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"dependencies": {
"@actions/core": "^1.2.6",
"@actions/exec": "^1.0.4",
"@actions/http-client": "^2.0.1",
"@actions/io": "^1.0.2",
"azure-actions-webclient": "^1.0.11",
"crypto": "^1.0.1",
Expand Down
4 changes: 0 additions & 4 deletions src/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ export class FileConstants {
static readonly singleParentDirRegex = /^((?!\*\/).)*(\.sql)$/g;
}

export class FirewallConstants {
static readonly ipv4MatchPattern = /\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/;
}

export class PsqlConstants {
static readonly SELECT_1 = "SELECT 1";
// host, port, dbname, user, password must be present in connection string in any order.
Expand Down
46 changes: 27 additions & 19 deletions src/Utils/FirewallUtils/ResourceManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,33 +61,29 @@ export default class AzurePSQLResourceManager {
return this._resource;
}

private async _populatePSQLServerData(serverName: string) {
// trim the cloud hostname suffix from servername
serverName = serverName.split('.')[0];
private async _getPSQLServer(serverType: string, apiVersion: string, serverName: string) {
const httpRequest: WebRequest = {
method: 'GET',
uri: this._restClient.getRequestUri('//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/servers', {}, [], '2017-12-01')
uri: this._restClient.getRequestUri(`//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/${serverType}`, {}, [], apiVersion)
}

core.debug(`Get PSQL server '${serverName}' details`);
core.debug(`Get '${serverName}' for PSQL ${serverType} details`);
try {
const httpResponse = await this._restClient.beginRequest(httpRequest);
if (httpResponse.statusCode !== 200) {
throw ToError(httpResponse);
}

const sqlServers = httpResponse.body && httpResponse.body.value as AzurePSQLServer[];
if (sqlServers && sqlServers.length > 0) {
this._resource = sqlServers.filter((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase())[0];
if (!this._resource) {
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
}

core.debug(JSON.stringify(this._resource));
}
else {
throw new Error(`Unable to get details of PSQL server ${serverName}. No PSQL servers were found in the subscription.`);
const sqlServers = ((httpResponse.body && httpResponse.body.value) || []) as AzurePSQLServer[];
const sqlServer = sqlServers.find((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase());
if (sqlServer) {
this._serverType = serverType;
this._apiVersion = apiVersion;
this._resource = sqlServer;
return true;
}

return false;
}
catch(error) {
if (error instanceof AzureError) {
Expand All @@ -98,11 +94,21 @@ export default class AzurePSQLResourceManager {
}
}

private async _populatePSQLServerData(serverName: string) {
// trim the cloud hostname suffix from servername
serverName = serverName.split('.')[0];

(await this._getPSQLServer('servers', '2017-12-01', serverName)) || (await this._getPSQLServer('flexibleServers', '2021-06-01', serverName));
if (!this._resource) {
throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`);
}
}

public async addFirewallRule(startIpAddress: string, endIpAddress: string): Promise<FirewallRule> {
const firewallRuleName = `ClientIPAddress_${Date.now()}`;
const httpRequest: WebRequest = {
method: 'PUT',
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRuleName}`, {}, [], '2017-12-01'),
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRuleName}`, {}, [], this._apiVersion),
body: JSON.stringify({
'properties': {
'startIpAddress': startIpAddress,
Expand Down Expand Up @@ -141,7 +147,7 @@ export default class AzurePSQLResourceManager {
public async getFirewallRule(ruleName: string): Promise<FirewallRule> {
const httpRequest: WebRequest = {
method: 'GET',
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${ruleName}`, {}, [], '2017-12-01')
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${ruleName}`, {}, [], this._apiVersion)
};

try {
Expand All @@ -164,7 +170,7 @@ export default class AzurePSQLResourceManager {
public async removeFirewallRule(firewallRule: FirewallRule): Promise<void> {
const httpRequest: WebRequest = {
method: 'DELETE',
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRule.name}`, {}, [], '2017-12-01')
uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRule.name}`, {}, [], this._apiVersion)
};

try {
Expand Down Expand Up @@ -228,6 +234,8 @@ export default class AzurePSQLResourceManager {
});
}

private _serverType?: string;
private _apiVersion?: string;
private _resource?: AzurePSQLServer;
private _restClient: AzureRestClient;
}
21 changes: 13 additions & 8 deletions src/Utils/PsqlUtils/PsqlUtils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { PsqlConstants } from "../../Constants";
import { FirewallConstants } from "../../Constants";
import PsqlToolRunner from "./PsqlToolRunner";
import { HttpClient } from '@actions/http-client';

export default class PsqlUtils {
static async detectIPAddress(connectionString: string): Promise<string> {
Expand All @@ -14,21 +14,26 @@ export default class PsqlUtils {
},
silent: true
};

// "SELECT 1" psql command is run to check if psql client is able to connect to DB using the connectionString
try {
await PsqlToolRunner.init();
await PsqlToolRunner.executePsqlCommand(connectionString, PsqlConstants.SELECT_1, options);
} catch(err) {
await PsqlToolRunner.executePsqlCommand(`${connectionString} connect_timeout=10`, PsqlConstants.SELECT_1, options);
} catch {
if (psqlError) {
const ipAddresses = psqlError.match(FirewallConstants.ipv4MatchPattern);
if (ipAddresses) {
ipAddress = ipAddresses[0];
} else {
throw new Error(`Unable to detect client IP Address: ${psqlError}`);
const http = new HttpClient();
try {
const ipv4 = await http.getJson<IPResponse>('https://api.ipify.org?format=json');
ipAddress = ipv4.result?.ip || '';
} catch(err) {
throw new Error(`Unable to detect client IP Address: ${err.message}`);
}
}
}
return ipAddress;
}
}

export interface IPResponse {
ip: string;
}
47 changes: 20 additions & 27 deletions src/__tests__/Utils/PsqlUtils.test.ts
Original file line number Diff line number Diff line change
@@ -1,42 +1,35 @@
import PsqlUtils from "../../Utils/PsqlUtils/PsqlUtils";
import { FirewallConstants } from "../../Constants";
import { HttpClient } from '@actions/http-client';
import PsqlToolRunner from "../../Utils/PsqlUtils/PsqlToolRunner";
import PsqlUtils, { IPResponse } from "../../Utils/PsqlUtils/PsqlUtils";

jest.mock('../../Utils/PsqlUtils/PsqlToolRunner');
const CONFIGURED = "configured";
jest.mock('@actions/http-client');

describe('Testing PsqlUtils', () => {
afterEach(() => {
jest.clearAllMocks()
jest.resetAllMocks();
});

let detectIPAddressSpy: any;
beforeAll(() => {
detectIPAddressSpy = PsqlUtils.detectIPAddress = jest.fn().mockImplementation( (connString: string) => {
let psqlError;
if (connString != CONFIGURED) {
psqlError = `psql: error: could not connect to server: FATAL: no pg_hba.conf entry for host "1.2.3.4", user "<user>", database "<db>"`;
}
let ipAddress = '';
if (psqlError) {
const ipAddresses = psqlError.match(FirewallConstants.ipv4MatchPattern);
if (ipAddresses) {
ipAddress = ipAddresses[0];
} else {
throw new Error(`Unable to detect client IP Address: ${psqlError}`);
}
}
return ipAddress;
test('detectIPAddress should return ip address', async () => {
const psqlError: string = `psql: error: could not connect to server: FATAL: no pg_hba.conf entry for host "1.2.3.4", user "<user>", database "<db>"`;

jest.spyOn(PsqlToolRunner, 'executePsqlCommand').mockImplementation(async (_connectionString: string, _command: string, options: any = {}) => {
options.listeners.stderr(Buffer.from(psqlError));
throw new Error(psqlError);
});
jest.spyOn(HttpClient.prototype, 'getJson').mockResolvedValue({
statusCode: 200,
result: {
ip: '1.2.3.4',
},
headers: {},
});
});

test('detectIPAddress should return ip address', async () => {
await PsqlUtils.detectIPAddress("");
expect(detectIPAddressSpy).toReturnWith("1.2.3.4");
return PsqlUtils.detectIPAddress("").then(ipAddress => expect(ipAddress).toEqual("1.2.3.4"));
});

test('detectIPAddress should return empty string', async () => {
await PsqlUtils.detectIPAddress(CONFIGURED);
expect(detectIPAddressSpy).toReturnWith("");
return PsqlUtils.detectIPAddress("").then(ipAddress => expect(ipAddress).toEqual(""));
})

});
Loading

0 comments on commit f82d2b2

Please sign in to comment.