diff --git a/lib/Constants.js b/lib/Constants.js index 56705545..4823ddfb 100644 --- a/lib/Constants.js +++ b/lib/Constants.js @@ -1,15 +1,11 @@ "use strict"; Object.defineProperty(exports, "__esModule", { value: true }); -exports.PsqlConstants = exports.FirewallConstants = exports.FileConstants = void 0; +exports.PsqlConstants = exports.FileConstants = void 0; class FileConstants { } exports.FileConstants = FileConstants; // regex checks that string should end with .sql and if folderPath is present, * should not be included in folderPath FileConstants.singleParentDirRegex = /^((?!\*\/).)*(\.sql)$/g; -class FirewallConstants { -} -exports.FirewallConstants = FirewallConstants; -FirewallConstants.ipv4MatchPattern = /\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/; class PsqlConstants { } exports.PsqlConstants = PsqlConstants; diff --git a/lib/Utils/FirewallUtils/ResourceManager.js b/lib/Utils/FirewallUtils/ResourceManager.js index 1020c810..080ef5e4 100644 --- a/lib/Utils/FirewallUtils/ResourceManager.js +++ b/lib/Utils/FirewallUtils/ResourceManager.js @@ -46,31 +46,27 @@ class AzurePSQLResourceManager { getPSQLServer() { return this._resource; } - _populatePSQLServerData(serverName) { + _getPSQLServer(serverType, apiVersion, serverName) { return __awaiter(this, void 0, void 0, function* () { - // trim the cloud hostname suffix from servername - serverName = serverName.split('.')[0]; const httpRequest = { method: 'GET', - uri: this._restClient.getRequestUri('//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/servers', {}, [], '2017-12-01') + uri: this._restClient.getRequestUri(`//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/${serverType}`, {}, [], apiVersion) }; - core.debug(`Get PSQL server '${serverName}' details`); + core.debug(`Get '${serverName}' for PSQL ${serverType} details`); try { const httpResponse = yield this._restClient.beginRequest(httpRequest); if (httpResponse.statusCode !== 200) { throw AzureRestClient_1.ToError(httpResponse); } - const sqlServers = httpResponse.body && httpResponse.body.value; - if (sqlServers && sqlServers.length > 0) { - this._resource = sqlServers.filter((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase())[0]; - if (!this._resource) { - throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`); - } - core.debug(JSON.stringify(this._resource)); - } - else { - throw new Error(`Unable to get details of PSQL server ${serverName}. No PSQL servers were found in the subscription.`); + const sqlServers = ((httpResponse.body && httpResponse.body.value) || []); + const sqlServer = sqlServers.find((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase()); + if (sqlServer) { + this._serverType = serverType; + this._apiVersion = apiVersion; + this._resource = sqlServer; + return true; } + return false; } catch (error) { if (error instanceof AzureRestClient_1.AzureError) { @@ -80,12 +76,22 @@ class AzurePSQLResourceManager { } }); } + _populatePSQLServerData(serverName) { + return __awaiter(this, void 0, void 0, function* () { + // trim the cloud hostname suffix from servername + serverName = serverName.split('.')[0]; + (yield this._getPSQLServer('servers', '2017-12-01', serverName)) || (yield this._getPSQLServer('flexibleServers', '2021-06-01', serverName)); + if (!this._resource) { + throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`); + } + }); + } addFirewallRule(startIpAddress, endIpAddress) { return __awaiter(this, void 0, void 0, function* () { const firewallRuleName = `ClientIPAddress_${Date.now()}`; const httpRequest = { method: 'PUT', - uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRuleName}`, {}, [], '2017-12-01'), + uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRuleName}`, {}, [], this._apiVersion), body: JSON.stringify({ 'properties': { 'startIpAddress': startIpAddress, @@ -122,7 +128,7 @@ class AzurePSQLResourceManager { return __awaiter(this, void 0, void 0, function* () { const httpRequest = { method: 'GET', - uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${ruleName}`, {}, [], '2017-12-01') + uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${ruleName}`, {}, [], this._apiVersion) }; try { const httpResponse = yield this._restClient.beginRequest(httpRequest); @@ -143,7 +149,7 @@ class AzurePSQLResourceManager { return __awaiter(this, void 0, void 0, function* () { const httpRequest = { method: 'DELETE', - uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRule.name}`, {}, [], '2017-12-01') + uri: this._restClient.getRequestUri(`/${this._resource.id}/firewallRules/${firewallRule.name}`, {}, [], this._apiVersion) }; try { const httpResponse = yield this._restClient.beginRequest(httpRequest); diff --git a/lib/Utils/PsqlUtils/PsqlUtils.js b/lib/Utils/PsqlUtils/PsqlUtils.js index 843032c3..1657c90e 100644 --- a/lib/Utils/PsqlUtils/PsqlUtils.js +++ b/lib/Utils/PsqlUtils/PsqlUtils.js @@ -13,10 +13,11 @@ var __importDefault = (this && this.__importDefault) || function (mod) { }; Object.defineProperty(exports, "__esModule", { value: true }); const Constants_1 = require("../../Constants"); -const Constants_2 = require("../../Constants"); const PsqlToolRunner_1 = __importDefault(require("./PsqlToolRunner")); +const http_client_1 = require("@actions/http-client"); class PsqlUtils { static detectIPAddress(connectionString) { + var _a; return __awaiter(this, void 0, void 0, function* () { let psqlError = ''; let ipAddress = ''; @@ -31,16 +32,17 @@ class PsqlUtils { // "SELECT 1" psql command is run to check if psql client is able to connect to DB using the connectionString try { yield PsqlToolRunner_1.default.init(); - yield PsqlToolRunner_1.default.executePsqlCommand(connectionString, Constants_1.PsqlConstants.SELECT_1, options); + yield PsqlToolRunner_1.default.executePsqlCommand(`${connectionString} connect_timeout=10`, Constants_1.PsqlConstants.SELECT_1, options); } - catch (err) { + catch (_b) { if (psqlError) { - const ipAddresses = psqlError.match(Constants_2.FirewallConstants.ipv4MatchPattern); - if (ipAddresses) { - ipAddress = ipAddresses[0]; + const http = new http_client_1.HttpClient(); + try { + const ipv4 = yield http.getJson('https://api.ipify.org?format=json'); + ipAddress = ((_a = ipv4.result) === null || _a === void 0 ? void 0 : _a.ip) || ''; } - else { - throw new Error(`Unable to detect client IP Address: ${psqlError}`); + catch (err) { + throw new Error(`Unable to detect client IP Address: ${err.message}`); } } } diff --git a/package-lock.json b/package-lock.json index 02d1cf11..31f04150 100644 --- a/package-lock.json +++ b/package-lock.json @@ -17,6 +17,14 @@ "@actions/io": "^1.0.1" } }, + "@actions/http-client": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@actions/http-client/-/http-client-2.0.1.tgz", + "integrity": "sha512-PIXiMVtz6VvyaRsGY268qvj57hXQEpsYogYOu2nrQhlf+XCGmZstmuZBbAybUl1nQGnvS1k1eEsQ69ZoD7xlSw==", + "requires": { + "tunnel": "^0.0.6" + } + }, "@actions/io": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/@actions/io/-/io-1.0.2.tgz", diff --git a/package.json b/package.json index 584f9ac4..499f4455 100644 --- a/package.json +++ b/package.json @@ -11,6 +11,7 @@ "dependencies": { "@actions/core": "^1.2.6", "@actions/exec": "^1.0.4", + "@actions/http-client": "^2.0.1", "@actions/io": "^1.0.2", "azure-actions-webclient": "^1.0.11", "crypto": "^1.0.1", diff --git a/src/Constants.ts b/src/Constants.ts index 08aa6067..83f60305 100644 --- a/src/Constants.ts +++ b/src/Constants.ts @@ -3,10 +3,6 @@ export class FileConstants { static readonly singleParentDirRegex = /^((?!\*\/).)*(\.sql)$/g; } -export class FirewallConstants { - static readonly ipv4MatchPattern = /\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b/; -} - export class PsqlConstants { static readonly SELECT_1 = "SELECT 1"; // host, port, dbname, user, password must be present in connection string in any order. diff --git a/src/Utils/FirewallUtils/ResourceManager.ts b/src/Utils/FirewallUtils/ResourceManager.ts index 60b05a03..23318343 100644 --- a/src/Utils/FirewallUtils/ResourceManager.ts +++ b/src/Utils/FirewallUtils/ResourceManager.ts @@ -61,33 +61,29 @@ export default class AzurePSQLResourceManager { return this._resource; } - private async _populatePSQLServerData(serverName: string) { - // trim the cloud hostname suffix from servername - serverName = serverName.split('.')[0]; + private async _getPSQLServer(serverType: string, apiVersion: string, serverName: string) { const httpRequest: WebRequest = { method: 'GET', - uri: this._restClient.getRequestUri('//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/servers', {}, [], '2017-12-01') + uri: this._restClient.getRequestUri(`//subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/${serverType}`, {}, [], apiVersion) } - core.debug(`Get PSQL server '${serverName}' details`); + core.debug(`Get '${serverName}' for PSQL ${serverType} details`); try { const httpResponse = await this._restClient.beginRequest(httpRequest); if (httpResponse.statusCode !== 200) { throw ToError(httpResponse); } - const sqlServers = httpResponse.body && httpResponse.body.value as AzurePSQLServer[]; - if (sqlServers && sqlServers.length > 0) { - this._resource = sqlServers.filter((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase())[0]; - if (!this._resource) { - throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`); - } - - core.debug(JSON.stringify(this._resource)); - } - else { - throw new Error(`Unable to get details of PSQL server ${serverName}. No PSQL servers were found in the subscription.`); + const sqlServers = ((httpResponse.body && httpResponse.body.value) || []) as AzurePSQLServer[]; + const sqlServer = sqlServers.find((sqlResource) => sqlResource.name.toLowerCase() === serverName.toLowerCase()); + if (sqlServer) { + this._serverType = serverType; + this._apiVersion = apiVersion; + this._resource = sqlServer; + return true; } + + return false; } catch(error) { if (error instanceof AzureError) { @@ -98,11 +94,21 @@ export default class AzurePSQLResourceManager { } } + private async _populatePSQLServerData(serverName: string) { + // trim the cloud hostname suffix from servername + serverName = serverName.split('.')[0]; + + (await this._getPSQLServer('servers', '2017-12-01', serverName)) || (await this._getPSQLServer('flexibleServers', '2021-06-01', serverName)); + if (!this._resource) { + throw new Error(`Unable to get details of PSQL server ${serverName}. PSQL server '${serverName}' was not found in the subscription.`); + } + } + public async addFirewallRule(startIpAddress: string, endIpAddress: string): Promise { const firewallRuleName = `ClientIPAddress_${Date.now()}`; const httpRequest: WebRequest = { method: 'PUT', - uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRuleName}`, {}, [], '2017-12-01'), + uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRuleName}`, {}, [], this._apiVersion), body: JSON.stringify({ 'properties': { 'startIpAddress': startIpAddress, @@ -141,7 +147,7 @@ export default class AzurePSQLResourceManager { public async getFirewallRule(ruleName: string): Promise { const httpRequest: WebRequest = { method: 'GET', - uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${ruleName}`, {}, [], '2017-12-01') + uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${ruleName}`, {}, [], this._apiVersion) }; try { @@ -164,7 +170,7 @@ export default class AzurePSQLResourceManager { public async removeFirewallRule(firewallRule: FirewallRule): Promise { const httpRequest: WebRequest = { method: 'DELETE', - uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRule.name}`, {}, [], '2017-12-01') + uri: this._restClient.getRequestUri(`/${this._resource!.id}/firewallRules/${firewallRule.name}`, {}, [], this._apiVersion) }; try { @@ -228,6 +234,8 @@ export default class AzurePSQLResourceManager { }); } + private _serverType?: string; + private _apiVersion?: string; private _resource?: AzurePSQLServer; private _restClient: AzureRestClient; } \ No newline at end of file diff --git a/src/Utils/PsqlUtils/PsqlUtils.ts b/src/Utils/PsqlUtils/PsqlUtils.ts index 999fef18..3e43fafd 100644 --- a/src/Utils/PsqlUtils/PsqlUtils.ts +++ b/src/Utils/PsqlUtils/PsqlUtils.ts @@ -1,6 +1,6 @@ import { PsqlConstants } from "../../Constants"; -import { FirewallConstants } from "../../Constants"; import PsqlToolRunner from "./PsqlToolRunner"; +import { HttpClient } from '@actions/http-client'; export default class PsqlUtils { static async detectIPAddress(connectionString: string): Promise { @@ -14,21 +14,26 @@ export default class PsqlUtils { }, silent: true }; + // "SELECT 1" psql command is run to check if psql client is able to connect to DB using the connectionString try { await PsqlToolRunner.init(); - await PsqlToolRunner.executePsqlCommand(connectionString, PsqlConstants.SELECT_1, options); - } catch(err) { + await PsqlToolRunner.executePsqlCommand(`${connectionString} connect_timeout=10`, PsqlConstants.SELECT_1, options); + } catch { if (psqlError) { - const ipAddresses = psqlError.match(FirewallConstants.ipv4MatchPattern); - if (ipAddresses) { - ipAddress = ipAddresses[0]; - } else { - throw new Error(`Unable to detect client IP Address: ${psqlError}`); + const http = new HttpClient(); + try { + const ipv4 = await http.getJson('https://api.ipify.org?format=json'); + ipAddress = ipv4.result?.ip || ''; + } catch(err) { + throw new Error(`Unable to detect client IP Address: ${err.message}`); } } } return ipAddress; } +} +export interface IPResponse { + ip: string; } \ No newline at end of file diff --git a/src/__tests__/Utils/PsqlUtils.test.ts b/src/__tests__/Utils/PsqlUtils.test.ts index 968c8746..946e172f 100644 --- a/src/__tests__/Utils/PsqlUtils.test.ts +++ b/src/__tests__/Utils/PsqlUtils.test.ts @@ -1,42 +1,35 @@ -import PsqlUtils from "../../Utils/PsqlUtils/PsqlUtils"; -import { FirewallConstants } from "../../Constants"; +import { HttpClient } from '@actions/http-client'; +import PsqlToolRunner from "../../Utils/PsqlUtils/PsqlToolRunner"; +import PsqlUtils, { IPResponse } from "../../Utils/PsqlUtils/PsqlUtils"; jest.mock('../../Utils/PsqlUtils/PsqlToolRunner'); -const CONFIGURED = "configured"; +jest.mock('@actions/http-client'); describe('Testing PsqlUtils', () => { afterEach(() => { - jest.clearAllMocks() + jest.resetAllMocks(); }); - let detectIPAddressSpy: any; - beforeAll(() => { - detectIPAddressSpy = PsqlUtils.detectIPAddress = jest.fn().mockImplementation( (connString: string) => { - let psqlError; - if (connString != CONFIGURED) { - psqlError = `psql: error: could not connect to server: FATAL: no pg_hba.conf entry for host "1.2.3.4", user "", database ""`; - } - let ipAddress = ''; - if (psqlError) { - const ipAddresses = psqlError.match(FirewallConstants.ipv4MatchPattern); - if (ipAddresses) { - ipAddress = ipAddresses[0]; - } else { - throw new Error(`Unable to detect client IP Address: ${psqlError}`); - } - } - return ipAddress; + test('detectIPAddress should return ip address', async () => { + const psqlError: string = `psql: error: could not connect to server: FATAL: no pg_hba.conf entry for host "1.2.3.4", user "", database ""`; + + jest.spyOn(PsqlToolRunner, 'executePsqlCommand').mockImplementation(async (_connectionString: string, _command: string, options: any = {}) => { + options.listeners.stderr(Buffer.from(psqlError)); + throw new Error(psqlError); + }); + jest.spyOn(HttpClient.prototype, 'getJson').mockResolvedValue({ + statusCode: 200, + result: { + ip: '1.2.3.4', + }, + headers: {}, }); - }); - test('detectIPAddress should return ip address', async () => { - await PsqlUtils.detectIPAddress(""); - expect(detectIPAddressSpy).toReturnWith("1.2.3.4"); + return PsqlUtils.detectIPAddress("").then(ipAddress => expect(ipAddress).toEqual("1.2.3.4")); }); test('detectIPAddress should return empty string', async () => { - await PsqlUtils.detectIPAddress(CONFIGURED); - expect(detectIPAddressSpy).toReturnWith(""); + return PsqlUtils.detectIPAddress("").then(ipAddress => expect(ipAddress).toEqual("")); }) }); \ No newline at end of file diff --git a/src/__tests__/Utils/ResourceManager.test.ts b/src/__tests__/Utils/ResourceManager.test.ts new file mode 100644 index 00000000..f446d8be --- /dev/null +++ b/src/__tests__/Utils/ResourceManager.test.ts @@ -0,0 +1,225 @@ +import { IAuthorizer } from 'azure-actions-webclient/Authorizer/IAuthorizer'; +import { ServiceClient as AzureRestClient } from 'azure-actions-webclient/AzureRestClient'; +import AzurePSQLResourceManager, { FirewallRule } from '../../Utils/FirewallUtils/ResourceManager'; + +jest.mock('azure-actions-webclient/AzureRestClient'); + +describe('Testing ResourceManager', () => { + afterEach(() => { + jest.resetAllMocks(); + }); + + it('Initializes resource manager correctly for single server', async () => { + let getRequestUrlSpy = jest.spyOn(AzureRestClient.prototype, 'getRequestUri').mockReturnValue('https://randomUrl/'); + let beginRequestSpy = jest.spyOn(AzureRestClient.prototype, 'beginRequest').mockResolvedValue({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer1', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer1' + }, + { + name: 'testServer2', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer2' + } + ] + }, + statusMessage: 'OK', + headers: [] + }); + + let resourceManager = await AzurePSQLResourceManager.getResourceManager('testServer1', {} as IAuthorizer); + let server = resourceManager.getPSQLServer(); + + expect(server!.name).toEqual('testServer1'); + expect(server!.id).toEqual('/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer1'); + expect(getRequestUrlSpy).toHaveBeenCalledTimes(1); + expect(beginRequestSpy).toHaveBeenCalledTimes(1); + }); + + it('Initializes resource manager correctly for flexible server', async () => { + let getRequestUrlSpy = jest.spyOn(AzureRestClient.prototype, 'getRequestUri').mockReturnValue('https://randomUrl/'); + let beginRequestSpy = jest.spyOn(AzureRestClient.prototype, 'beginRequest').mockResolvedValueOnce({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer1', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer1' + }, + { + name: 'testServer2', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer2' + } + ] + }, + statusMessage: 'OK', + headers: [] + }).mockResolvedValueOnce({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer3', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/flexibleServers/testServer3' + }, + { + name: 'testServer4', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/flexibleServers/testServer4' + } + ] + }, + statusMessage: 'OK', + headers: [] + }); + + let resourceManager = await AzurePSQLResourceManager.getResourceManager('testServer4', {} as IAuthorizer); + let server = resourceManager.getPSQLServer(); + + expect(server!.name).toEqual('testServer4'); + expect(server!.id).toEqual('/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/flexibleServers/testServer4'); + expect(getRequestUrlSpy).toHaveBeenCalledTimes(2); + expect(beginRequestSpy).toHaveBeenCalledTimes(2); + }); + + it('Throws if the server does not exist', async () => { + let getRequestUrlSpy = jest.spyOn(AzureRestClient.prototype, 'getRequestUri').mockReturnValue('https://randomUrl/'); + let beginRequestSpy = jest.spyOn(AzureRestClient.prototype, 'beginRequest').mockResolvedValueOnce({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer1', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer1' + }, + { + name: 'testServer2', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer2' + } + ] + }, + statusMessage: 'OK', + headers: [] + }).mockResolvedValueOnce({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer3', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/flexibleServers/testServer3' + }, + { + name: 'testServer4', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/flexibleServers/testServer4' + } + ] + }, + statusMessage: 'OK', + headers: [] + }); + + let expectedError = `Unable to get details of PSQL server testServer5. PSQL server 'testServer5' was not found in the subscription.`; + try { + await AzurePSQLResourceManager.getResourceManager('testServer5', {} as IAuthorizer); + } catch(error) { + expect(error.message).toEqual(expectedError); + } + expect(getRequestUrlSpy).toHaveBeenCalledTimes(2); + expect(beginRequestSpy).toHaveBeenCalledTimes(2); + }); + + it('Adds firewall rule successfully', async () => { + let getRequestUrlSpy = jest.spyOn(AzureRestClient.prototype, 'getRequestUri').mockReturnValue('https://randomUrl/'); + let beginRequestSpy = jest.spyOn(AzureRestClient.prototype, 'beginRequest').mockResolvedValueOnce({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer1', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer1' + } + ] + }, + statusMessage: 'OK', + headers: [] + }).mockResolvedValueOnce({ + statusCode: 202, + body: {}, + statusMessage: 'OK', + headers: { + 'azure-asyncoperation': 'http://asyncRedirectionURI' + } + }).mockResolvedValueOnce({ + statusCode: 200, + body: { + 'status': 'Succeeded' + }, + statusMessage: 'OK', + headers: {} + }).mockResolvedValueOnce({ + statusCode: 200, + body: { + name: 'FirewallRule' + }, + statusMessage: 'OK', + headers: {} + }); + + let resourceManager = await AzurePSQLResourceManager.getResourceManager('testServer1', {} as IAuthorizer); + let firewallRule = await resourceManager.addFirewallRule('0.0.0.0', '1.1.1.1'); + + expect(firewallRule.name).toEqual('FirewallRule'); + expect(getRequestUrlSpy).toHaveBeenCalledTimes(3); + expect(beginRequestSpy).toHaveBeenCalledTimes(4); + }); + + it('Removes firewall rule successfully', async () => { + let getRequestUrlSpy = jest.spyOn(AzureRestClient.prototype, 'getRequestUri').mockReturnValue('https://randomUrl/'); + let beginRequestSpy = jest.spyOn(AzureRestClient.prototype, 'beginRequest').mockResolvedValueOnce({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer1', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/servers/testServer1' + } + ] + }, + statusMessage: 'OK', + headers: [] + }).mockResolvedValueOnce({ + statusCode: 200, + body: { + value: [ + { + name: 'testServer3', + id: '/subscriptions/SubscriptionId/resourceGroups/testrg/providers/Microsoft.DBforPostgreSQL/flexibleServers/testServer3' + } + ] + }, + statusMessage: 'OK', + headers: [] + }).mockResolvedValueOnce({ + statusCode: 202, + body: {}, + statusMessage: 'OK', + headers: { + 'azure-asyncoperation': 'http://asyncRedirectionURI' + } + }).mockResolvedValueOnce({ + statusCode: 200, + body: { + 'status': 'Succeeded' + }, + statusMessage: 'OK', + headers: {} + }); + + let resourceManager = await AzurePSQLResourceManager.getResourceManager('testServer3', {} as IAuthorizer); + await resourceManager.removeFirewallRule({ name: 'FirewallRule' } as FirewallRule); + + expect(getRequestUrlSpy).toHaveBeenCalledTimes(3); + expect(beginRequestSpy).toHaveBeenCalledTimes(4); + }) +}); \ No newline at end of file diff --git a/tsconfig.json b/tsconfig.json index 6d1f6927..e779dac0 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -4,7 +4,7 @@ // "incremental": true, /* Enable incremental compilation */ "target": "es6", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019', 'ES2020', or 'ESNEXT'. */ "module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', 'es2020', or 'ESNext'. */ - "lib": ["es2020.string"], /* Specify library files to be included in the compilation. */ + "lib": ["es2020.string", "dom"], /* Specify library files to be included in the compilation. */ // "allowJs": true, /* Allow javascript files to be compiled. */ // "checkJs": true, /* Report errors in .js files. */ // "jsx": "preserve", /* Specify JSX code generation: 'preserve', 'react-native', or 'react'. */