Skip to content

Commit

Permalink
Added support for google clientId and improved error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexF4Dev committed Jul 11, 2024
1 parent 33b7d8a commit 81cb1d6
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 144 deletions.
115 changes: 0 additions & 115 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,6 @@
"webpack-cli": "^5.0.1"
},
"dependencies": {
"@azure/msal-node": "^2.10.0",
"@opentelemetry/tracing": "^0.24.0",
"adal-node": "^0.2.4",
"applicationinsights": "^1.0.5",
"aws4": "^1.9.1",
Expand Down
93 changes: 67 additions & 26 deletions src/utils/auth/oidcClient.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import { ILoopbackClient, ServerAuthorizationCodeResponse } from "@azure/msal-node";
import * as crypto from 'crypto';
import * as http from "http";
import * as jws from 'jws';
Expand All @@ -7,7 +6,29 @@ import { v4 as uuid } from 'uuid';
import { env, Uri, window } from "vscode";
import { MemoryCache } from '../memoryCache';

export class CodeLoopbackClient implements ILoopbackClient {
type ServerAuthorizationCodeResponse = {
// Success case
code?: string;
client_info?: string;
state?: string;
cloud_instance_name?: string;
cloud_instance_host_name?: string;
cloud_graph_host_name?: string;
msgraph_host?: string;
// Error case
error?: string;
error_uri?: string;
error_description?: string;
suberror?: string;
timestamp?: string;
trace_id?: string;
correlation_id?: string;
claims?: string;
// Native Account ID
accountId?: string;
};

export class CodeLoopbackClient {
port: number = 0; // default port, which will be set to a random available port
private server!: http.Server;

Expand Down Expand Up @@ -64,6 +85,9 @@ export class CodeLoopbackClient implements ILoopbackClient {
const redirectUri = await this.getRedirectUri();
res.writeHead(302, { location: redirectUri }); // Prevent auth code from being saved in the browser history
res.end();
} else {
res.end(`Authorization Server Error:${JSON.stringify(authCodeResponse)}`);
reject(new Error(`Authorization Server Error:${JSON.stringify(authCodeResponse)}`));
}
resolve({ url, ...authCodeResponse });
});
Expand Down Expand Up @@ -198,7 +222,7 @@ export class CodeLoopbackClient implements ILoopbackClient {

export const CALLBACK_PORT = 7777;

export const remoteOutput = window.createOutputChannel("oidc");
export const remoteOutput = window.createOutputChannel('REST-OIDC');

interface TokenInformation {
access_token: string;
Expand All @@ -224,7 +248,7 @@ export class OidcClient {
authorizeEndpoint: string,
tokenEndpoint: string,
scopes: string,
audience: string,): Promise<string | undefined> {
audience: string): Promise<string | undefined> {
const key = `${clientId}-${callbackPort}-${authorizeEndpoint}-${tokenEndpoint}-${scopes}-${audience}`;
const cache = MemoryCache.createOrGet<OidcClient>('oidc');

Expand All @@ -246,10 +270,18 @@ export class OidcClient {
}

public async getAccessToken(): Promise<string | undefined> {
const tryDecode = (token: string): any => {
try {
const { payload } = jws.decode(token) ?? {};
return JSON.parse(payload);
} catch (ex) {
return null;
}
}

if (this._tokenInformation?.access_token) {
const { payload } = jws.decode(this._tokenInformation.access_token) ?? {};
const payloadJson = JSON.parse(payload);
if (payloadJson.exp && payloadJson.exp > Date.now() / 1000) {
const payloadJson = tryDecode(this._tokenInformation.access_token);
if (payloadJson === null || payloadJson.exp && payloadJson.exp > Date.now() / 1000) {
return this._tokenInformation.access_token;
} else {
return this.getAccessTokenByRefreshToken(this._tokenInformation.refresh_token, this.clientId).then((resp) => {
Expand All @@ -262,7 +294,7 @@ export class OidcClient {
const nonceId = uuid();

// Retrieve all required scopes
const scopes = this.getScopes((this.scopes ?? "").split(' '));
const scopes = this.getScopes((this.scopes ?? "").split(','));

const codeVerifier = toBase64UrlEncoding(crypto.randomBytes(32));
const codeChallenge = toBase64UrlEncoding(sha256(codeVerifier));
Expand Down Expand Up @@ -310,13 +342,13 @@ export class OidcClient {

const loopbackClient = await CodeLoopbackClient.initialize(this.callbackPort);


try {
await env.openExternal(uri);
const callBackResp = await loopbackClient.listenForAuthCode();
const codeExchangePromise = this._handleCallback(Uri.parse(callBackResp.url));

const resp = await Promise.race([
const resp = await Promise.race([
codeExchangePromise,
new Promise<null>((_, reject) => setTimeout(() => reject('Cancelled'), 60000))
]);
Expand Down Expand Up @@ -358,7 +390,7 @@ export class OidcClient {
}


private async _handleCallback(uri: Uri): Promise<TokenInformation> {
private async _handleCallback(uri: Uri): Promise<TokenInformation | undefined> {
const query = new URLSearchParams(uri.query);
const code = query.get('code');
const stateId = query.get('state');
Expand Down Expand Up @@ -389,18 +421,26 @@ export class OidcClient {
code_verifier: codeVerifier,
redirect_uri: this.redirectUri,
}).toString();
try {
const response = await fetch(`${this.tokenEndpoint}`, {
method: 'POST',
headers: {
"Content-Type": "application/x-www-form-urlencoded",
'Content-Length': postData.length.toString()
},
body: postData
});
const json = await response.json();
const { access_token, refresh_token } = json;
if (!access_token) {
remoteOutput.appendLine(`Failed to retrieve access token: ${response.status} ${JSON.stringify(json)}`);
}

const response = await fetch(`${this.tokenEndpoint}`, {
method: 'POST',
headers: {
"Content-Type": "application/x-www-form-urlencoded",
'Content-Length': postData.length.toString()
},
body: postData
});

const { access_token, refresh_token } = await response.json();
return { access_token, refresh_token };
return { access_token, refresh_token };
} catch (ex) {
remoteOutput.appendLine(`Failed to retrieve access token: ${ex}`);
return undefined;
}
}

/**
Expand All @@ -410,9 +450,10 @@ export class OidcClient {
private getScopes(scopes: string[] = []): string[] {
const modifiedScopes = [...scopes];

if (!modifiedScopes.includes('offline_access')) {
modifiedScopes.push('offline_access');
}
// if (!modifiedScopes.includes('offline_access')) {
// modifiedScopes.push('offline_access');
// }

if (!modifiedScopes.includes('openid')) {
modifiedScopes.push('openid');
}
Expand All @@ -436,4 +477,4 @@ export function toBase64UrlEncoding(buffer: Buffer) {

export function sha256(buffer: string | Uint8Array): Buffer {
return crypto.createHash('sha256').update(buffer).digest();
}
}
2 changes: 1 addition & 1 deletion src/utils/httpVariableProviders/systemVariableProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export class SystemVariableProvider implements HttpVariableProvider {
private readonly requestUrlRegex: RegExp = /^(?:[^\s]+\s+)([^:]*:\/\/\/?[^/\s]*\/?)/;

private readonly aadRegex: RegExp = new RegExp(`\\s*\\${Constants.AzureActiveDirectoryVariableName}(\\s+(${Constants.AzureActiveDirectoryForceNewOption}))?(\\s+(ppe|public|cn|de|us))?(\\s+([^\\.]+\\.[^\\}\\s]+|[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}))?(\\s+aud:([^\\.]+\\.[^\\}\\s]+|[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}))?\\s*`);
private readonly oidcRegex: RegExp = new RegExp(`\\s*(\\${Constants.OidcVariableName})(?:\\s+(${Constants.OIdcForceNewOption}))?(?:\\s*clientId:([\\w|-]+))?(?:\\s*issuer:([\\w|\.|:|/]+))?(?:\\s*callbackPort:([\\w|_]+))?(?:\\s*authorizeEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*tokenEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*scopes:([\\w|,]+))?(?:\\s*audience:(\\w+))?`);
private readonly oidcRegex: RegExp = new RegExp(`\\s*(\\${Constants.OidcVariableName})(?:\\s+(${Constants.OIdcForceNewOption}))?(?:\\s*clientId:([\\w|\.|:|/|_|-]+))?(?:\\s*issuer:([\\w|\.|:|/]+))?(?:\\s*callbackPort:([\\w|_]+))?(?:\\s*authorizeEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*tokenEndpoint:([\\w|\.|:|/|_|-]+))?(?:\\s*scopes:([\\w|\.|:|/|_|-]+))?(?:\\s*audience:([\\w|\.|:|/|_|-]+))?`);

private readonly innerSettingsEnvironmentVariableProvider: EnvironmentVariableProvider = EnvironmentVariableProvider.Instance;
private static _instance: SystemVariableProvider;
Expand Down

0 comments on commit 81cb1d6

Please sign in to comment.