Skip to content

Commit

Permalink
Added support for google clientId and improved error handling
Browse files Browse the repository at this point in the history
Added support for google clientId and improved error handling
  • Loading branch information
AlexF4Dev committed Jul 12, 2024
1 parent 33b7d8a commit 995dc91
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 145 deletions.
115 changes: 0 additions & 115 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,6 @@
"webpack-cli": "^5.0.1"
},
"dependencies": {
"@azure/msal-node": "^2.10.0",
"@opentelemetry/tracing": "^0.24.0",
"adal-node": "^0.2.4",
"applicationinsights": "^1.0.5",
"aws4": "^1.9.1",
Expand Down
93 changes: 66 additions & 27 deletions src/utils/auth/oidcClient.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { ILoopbackClient, ServerAuthorizationCodeResponse } from "@azure/msal-node";
import * as crypto from 'crypto';
import * as http from "http";
import * as jws from 'jws';
Expand All @@ -7,7 +6,29 @@ import { v4 as uuid } from 'uuid';
import { env, Uri, window } from "vscode";
import { MemoryCache } from '../memoryCache';

export class CodeLoopbackClient implements ILoopbackClient {
type ServerAuthorizationCodeResponse = {
// Success case
code?: string;
client_info?: string;
state?: string;
cloud_instance_name?: string;
cloud_instance_host_name?: string;
cloud_graph_host_name?: string;
msgraph_host?: string;
// Error case
error?: string;
error_uri?: string;
error_description?: string;
suberror?: string;
timestamp?: string;
trace_id?: string;
correlation_id?: string;
claims?: string;
// Native Account ID
accountId?: string;
};

export class CodeLoopbackClient {
port: number = 0; // default port, which will be set to a random available port
private server!: http.Server;

Expand Down Expand Up @@ -64,6 +85,9 @@ export class CodeLoopbackClient implements ILoopbackClient {
const redirectUri = await this.getRedirectUri();
res.writeHead(302, { location: redirectUri }); // Prevent auth code from being saved in the browser history
res.end();
} else {
res.end(`Authorization Server Error:${JSON.stringify(authCodeResponse)}`);

Check failure

Code scanning / CodeQL

Reflected cross-site scripting High

Cross-site scripting vulnerability due to a
user-provided value
.
reject(new Error(`Authorization Server Error:${JSON.stringify(authCodeResponse)}`));
}
resolve({ url, ...authCodeResponse });
});
Expand Down Expand Up @@ -174,7 +198,6 @@ export class CodeLoopbackClient implements ILoopbackClient {
}
return "";
}

/**
* Parses string into an object.
*
Expand All @@ -198,7 +221,7 @@ export class CodeLoopbackClient implements ILoopbackClient {

export const CALLBACK_PORT = 7777;

export const remoteOutput = window.createOutputChannel("oidc");
export const remoteOutput = window.createOutputChannel('REST-OIDC');

interface TokenInformation {
access_token: string;
Expand All @@ -224,7 +247,7 @@ export class OidcClient {
authorizeEndpoint: string,
tokenEndpoint: string,
scopes: string,
audience: string,): Promise<string | undefined> {
audience: string): Promise<string | undefined> {
const key = `${clientId}-${callbackPort}-${authorizeEndpoint}-${tokenEndpoint}-${scopes}-${audience}`;
const cache = MemoryCache.createOrGet<OidcClient>('oidc');

Expand All @@ -246,10 +269,17 @@ export class OidcClient {
}

public async getAccessToken(): Promise<string | undefined> {
const tryDecode = (token: string): any => {
try {
const { payload } = jws.decode(token) ?? {};
return JSON.parse(payload);
} catch (ex) {
return null;
}
};
if (this._tokenInformation?.access_token) {
const { payload } = jws.decode(this._tokenInformation.access_token) ?? {};
const payloadJson = JSON.parse(payload);
if (payloadJson.exp && payloadJson.exp > Date.now() / 1000) {
const payloadJson = tryDecode(this._tokenInformation.access_token);
if (payloadJson === null || payloadJson.exp && payloadJson.exp > Date.now() / 1000) {
return this._tokenInformation.access_token;
} else {
return this.getAccessTokenByRefreshToken(this._tokenInformation.refresh_token, this.clientId).then((resp) => {
Expand All @@ -262,7 +292,7 @@ export class OidcClient {
const nonceId = uuid();

// Retrieve all required scopes
const scopes = this.getScopes((this.scopes ?? "").split(' '));
const scopes = this.getScopes((this.scopes ?? "").split(','));

const codeVerifier = toBase64UrlEncoding(crypto.randomBytes(32));
const codeChallenge = toBase64UrlEncoding(sha256(codeVerifier));
Expand Down Expand Up @@ -310,13 +340,13 @@ export class OidcClient {

const loopbackClient = await CodeLoopbackClient.initialize(this.callbackPort);


try {
await env.openExternal(uri);
const callBackResp = await loopbackClient.listenForAuthCode();
const codeExchangePromise = this._handleCallback(Uri.parse(callBackResp.url));

const resp = await Promise.race([
const resp = await Promise.race([
codeExchangePromise,
new Promise<null>((_, reject) => setTimeout(() => reject('Cancelled'), 60000))
]);
Expand Down Expand Up @@ -358,7 +388,7 @@ export class OidcClient {
}


private async _handleCallback(uri: Uri): Promise<TokenInformation> {
private async _handleCallback(uri: Uri): Promise<TokenInformation | undefined> {
const query = new URLSearchParams(uri.query);
const code = query.get('code');
const stateId = query.get('state');
Expand Down Expand Up @@ -389,18 +419,26 @@ export class OidcClient {
code_verifier: codeVerifier,
redirect_uri: this.redirectUri,
}).toString();
try {
const response = await fetch(`${this.tokenEndpoint}`, {
method: 'POST',
headers: {
"Content-Type": "application/x-www-form-urlencoded",
'Content-Length': postData.length.toString()
},
body: postData
});
const json = await response.json();
const { access_token, refresh_token } = json;
if (!access_token) {
remoteOutput.appendLine(`Failed to retrieve access token: ${response.status} ${JSON.stringify(json)}`);
}

const response = await fetch(`${this.tokenEndpoint}`, {
method: 'POST',
headers: {
"Content-Type": "application/x-www-form-urlencoded",
'Content-Length': postData.length.toString()
},
body: postData
});

const { access_token, refresh_token } = await response.json();
return { access_token, refresh_token };
return { access_token, refresh_token };
} catch (ex) {
remoteOutput.appendLine(`Failed to retrieve access token: ${ex}`);
return undefined;
}
}

/**
Expand All @@ -410,9 +448,10 @@ export class OidcClient {
private getScopes(scopes: string[] = []): string[] {
const modifiedScopes = [...scopes];

if (!modifiedScopes.includes('offline_access')) {
modifiedScopes.push('offline_access');
}
// if (!modifiedScopes.includes('offline_access')) {
// modifiedScopes.push('offline_access');
// }

if (!modifiedScopes.includes('openid')) {
modifiedScopes.push('openid');
}
Expand All @@ -436,4 +475,4 @@ export function toBase64UrlEncoding(buffer: Buffer) {

export function sha256(buffer: string | Uint8Array): Buffer {
return crypto.createHash('sha256').update(buffer).digest();
}
}
2 changes: 1 addition & 1 deletion src/utils/httpVariableProviders/systemVariableProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export class SystemVariableProvider implements HttpVariableProvider {
private readonly requestUrlRegex: RegExp = /^(?:[^\s]+\s+)([^:]*:\/\/\/?[^/\s]*\/?)/;

private readonly aadRegex: RegExp = new RegExp(`\\s*\\${Constants.AzureActiveDirectoryVariableName}(\\s+(${Constants.AzureActiveDirectoryForceNewOption}))?(\\s+(ppe|public|cn|de|us))?(\\s+([^\\.]+\\.[^\\}\\s]+|[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}))?(\\s+aud:([^\\.]+\\.[^\\}\\s]+|[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}))?\\s*`);
private readonly oidcRegex: RegExp = new RegExp(`\\s*(\\${Constants.OidcVariableName})(?:\\s+(${Constants.OIdcForceNewOption}))?(?:\\s*clientId:([\\w|-]+))?(?:\\s*issuer:([\\w|\.|:|/]+))?(?:\\s*callbackPort:([\\w|_]+))?(?:\\s*authorizeEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*tokenEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*scopes:([\\w|,]+))?(?:\\s*audience:(\\w+))?`);
private readonly oidcRegex: RegExp = new RegExp(`\\s*(\\${Constants.OidcVariableName})(?:\\s+(${Constants.OIdcForceNewOption}))?(?:\\s*clientId:([\\w|\.|:|/|_|-]+))?(?:\\s*issuer:([\\w|\.|:|/]+))?(?:\\s*callbackPort:([\\w|_]+))?(?:\\s*authorizeEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*tokenEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*scopes:([\\w|\.|:|/|_|-]+))?(?:\\s*audience:([\\w|\.|:|/|_|-]+))?`);

Check failure

Code scanning / CodeQL

Useless regular-expression character escape High

The escape sequence '.' is equivalent to just '.', so the sequence may still represent a meta-character when it is used in a
regular expression
.

Check failure

Code scanning / CodeQL

Useless regular-expression character escape High

The escape sequence '.' is equivalent to just '.', so the sequence may still represent a meta-character when it is used in a
regular expression
.

Check failure

Code scanning / CodeQL

Useless regular-expression character escape High

The escape sequence '.' is equivalent to just '.', so the sequence may still represent a meta-character when it is used in a
regular expression
.

Check failure

Code scanning / CodeQL

Useless regular-expression character escape High

The escape sequence '.' is equivalent to just '.', so the sequence may still represent a meta-character when it is used in a
regular expression
.

Check failure

Code scanning / CodeQL

Useless regular-expression character escape High

The escape sequence '.' is equivalent to just '.', so the sequence may still represent a meta-character when it is used in a
regular expression
.

Check failure

Code scanning / CodeQL

Useless regular-expression character escape High

The escape sequence '.' is equivalent to just '.', so the sequence may still represent a meta-character when it is used in a
regular expression
.

private readonly innerSettingsEnvironmentVariableProvider: EnvironmentVariableProvider = EnvironmentVariableProvider.Instance;
private static _instance: SystemVariableProvider;
Expand Down

0 comments on commit 995dc91

Please sign in to comment.