improve logic around auth server

This commit is contained in:
Tyler Leonhardt 2022-02-06 14:07:13 -08:00
parent a0dd9ed39f
commit e485dc292f
No known key found for this signature in database
GPG key ID: 1BC2B6244363E77E
3 changed files with 189 additions and 178 deletions

View file

@ -1,10 +1,9 @@
<!-- Copyright (C) Microsoft Corporation. All rights reserved. -->
<!DOCTYPE html>
<html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<title>Azure Account - Sign In</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" type="text/css" media="screen" href="auth.css" />

View file

@ -10,12 +10,13 @@ import * as vscode from 'vscode';
import * as nls from 'vscode-nls';
import { v4 as uuid } from 'uuid';
import fetch, { Response } from 'node-fetch';
import { createServer, startServer } from './authServer';
import { Keychain } from './keychain';
import Logger from './logger';
import { toBase64UrlEncoding } from './utils';
import { sha256 } from './env/node/sha256';
import { BetterTokenStorage, IDidChangeInOtherWindowEvent } from './betterSecretStorage';
import { LoopbackAuthServer } from './authServer';
import path = require('path');
const localize = nls.loadMessageBundle();
@ -238,63 +239,42 @@ export class AzureActiveDirectoryService {
}
private async createSessionWithLocalServer(scopeData: IScopeData) {
const nonce = randomBytes(16).toString('base64');
const { server, redirectPromise, codePromise } = createServer(nonce);
const codeVerifier = toBase64UrlEncoding(randomBytes(32).toString('base64'));
const codeChallenge = toBase64UrlEncoding(await sha256(codeVerifier));
const qs = querystring.stringify({
response_type: 'code',
response_mode: 'query',
client_id: scopeData.clientId,
redirect_uri: redirectUrl,
scope: scopeData.scopesToSend,
prompt: 'select_account',
code_challenge_method: 'S256',
code_challenge: codeChallenge,
});
const loginUrl = `${loginEndpointUrl}${scopeData.tenant}/oauth2/v2.0/authorize?${qs}`;
const server = new LoopbackAuthServer(path.join(__dirname, '../media'), loginUrl);
await server.start();
server.state = `${server.port},${encodeURIComponent(server.nonce)}`;
let token: IToken | undefined;
let codeToExchange;
try {
const port = await startServer(server);
vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${port}/signin?nonce=${encodeURIComponent(nonce)}`));
const redirectReq = await redirectPromise;
if ('err' in redirectReq) {
const { err, res } = redirectReq;
res.writeHead(302, { Location: `/?error=${encodeURIComponent(err && err.message || 'Unknown error')}` });
res.end();
throw err;
}
const host = redirectReq.req.headers.host || '';
const updatedPortStr = (/^[^:]+:(\d+)$/.exec(Array.isArray(host) ? host[0] : host) || [])[1];
const updatedPort = updatedPortStr ? parseInt(updatedPortStr, 10) : port;
const state = `${updatedPort},${encodeURIComponent(nonce)}`;
const codeVerifier = toBase64UrlEncoding(randomBytes(32).toString('base64'));
const codeChallenge = toBase64UrlEncoding(await sha256(codeVerifier));
const loginUrl = `${loginEndpointUrl}${scopeData.tenant}/oauth2/v2.0/authorize?response_type=code&response_mode=query&client_id=${encodeURIComponent(scopeData.clientId)}&redirect_uri=${encodeURIComponent(redirectUrl)}&state=${state}&scope=${encodeURIComponent(scopeData.scopesToSend)}&prompt=select_account&code_challenge_method=S256&code_challenge=${codeChallenge}`;
redirectReq.res.writeHead(302, { Location: loginUrl });
redirectReq.res.end();
const codeRes = await codePromise;
const res = codeRes.res;
try {
if ('err' in codeRes) {
throw codeRes.err;
}
token = await this.exchangeCodeForToken(codeRes.code, codeVerifier, scopeData);
if (token.expiresIn) {
this.setSessionTimeout(token.sessionId, token.refreshToken, scopeData, token.expiresIn * AzureActiveDirectoryService.REFRESH_TIMEOUT_MODIFIER);
}
await this.setToken(token, scopeData);
Logger.info(`Login successful for scopes: ${scopeData.scopeStr}`);
res.writeHead(302, { Location: '/' });
const session = await this.convertToSession(token);
return session;
} catch (err) {
res.writeHead(302, { Location: `/?error=${encodeURIComponent(err && err.message || 'Unknown error')}` });
throw err;
} finally {
res.end();
}
vscode.env.openExternal(vscode.Uri.parse(`http://localhost:${server.port}/signin?nonce=${encodeURIComponent(server.nonce)}`));
const { code } = await server.waitForOAuthResponse();
codeToExchange = code;
} finally {
setTimeout(() => {
server.close();
void server.stop();
}, 5000);
}
const token = await this.exchangeCodeForToken(codeToExchange, codeVerifier, scopeData);
if (token.expiresIn) {
this.setSessionTimeout(token.sessionId, token.refreshToken, scopeData, token.expiresIn * AzureActiveDirectoryService.REFRESH_TIMEOUT_MODIFIER);
}
await this.setToken(token, scopeData);
Logger.info(`Login successful for scopes: ${scopeData.scopeStr}`);
const session = await this.convertToSession(token);
return session;
}
private async createSessionWithoutLocalServer(scopeData: IScopeData): Promise<vscode.AuthenticationSession> {

View file

@ -2,65 +2,13 @@
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import * as http from 'http';
import * as url from 'url';
import { URL } from 'url';
import * as fs from 'fs';
import * as path from 'path';
import { randomBytes } from 'crypto';
interface Deferred<T> {
resolve: (result: T | Promise<T>) => void;
reject: (reason: any) => void;
}
/**
* Asserts that the argument passed in is neither undefined nor null.
*/
function assertIsDefined<T>(arg: T | null | undefined): T {
if (typeof (arg) === 'undefined' || arg === null) {
throw new Error('Assertion Failed: argument is undefined or null');
}
return arg;
}
export async function startServer(server: http.Server): Promise<string> {
let portTimer: NodeJS.Timer;
function cancelPortTimer() {
clearTimeout(portTimer);
}
const port = new Promise<string>((resolve, reject) => {
portTimer = setTimeout(() => {
reject(new Error('Timeout waiting for port'));
}, 5000);
server.on('listening', () => {
const address = server.address();
if (typeof address === 'string') {
resolve(address);
} else {
resolve(assertIsDefined(address).port.toString());
}
});
server.on('error', _ => {
reject(new Error('Error listening to server'));
});
server.on('close', () => {
reject(new Error('Closed'));
});
server.listen(0, '127.0.0.1');
});
port.then(cancelPortTimer, cancelPortTimer);
return port;
}
function sendFile(res: http.ServerResponse, filepath: string, contentType: string) {
function sendFile(res: http.ServerResponse, filepath: string) {
fs.readFile(filepath, (err, body) => {
if (err) {
console.error(err);
@ -68,89 +16,173 @@ function sendFile(res: http.ServerResponse, filepath: string, contentType: strin
res.end();
} else {
res.writeHead(200, {
'Content-Length': body.length,
'Content-Type': contentType
'content-length': body.length,
});
res.end(body);
}
});
}
async function callback(nonce: string, reqUrl: url.Url): Promise<string> {
const query = reqUrl.query;
if (!query || typeof query === 'string') {
throw new Error('No query received.');
}
let error = query.error_description || query.error;
if (!error) {
const state = (query.state as string) || '';
const receivedNonce = (state.split(',')[1] || '').replace(/ /g, '+');
if (receivedNonce !== nonce) {
error = 'Nonce does not match.';
}
}
const code = query.code as string;
if (!error && code) {
return code;
}
throw new Error((error as string) || 'No code received.');
interface IOAuthResult {
code: string;
state: string;
}
export function createServer(nonce: string) {
type RedirectResult = { req: http.IncomingMessage; res: http.ServerResponse } | { err: any; res: http.ServerResponse };
let deferredRedirect: Deferred<RedirectResult>;
const redirectPromise = new Promise<RedirectResult>((resolve, reject) => deferredRedirect = { resolve, reject });
interface ILoopbackServer {
/**
* If undefined, the server is not started yet.
*/
port: number | undefined;
type CodeResult = { code: string; res: http.ServerResponse } | { err: any; res: http.ServerResponse };
let deferredCode: Deferred<CodeResult>;
const codePromise = new Promise<CodeResult>((resolve, reject) => deferredCode = { resolve, reject });
/**
* The nonce used
*/
nonce: string;
const codeTimer = setTimeout(() => {
deferredCode.reject(new Error('Timeout waiting for code'));
}, 5 * 60 * 1000);
/**
* The state parameter used in the OAuth flow.
*/
state: string | undefined;
function cancelCodeTimer() {
clearTimeout(codeTimer);
/**
* Starts the server.
* @returns The port to listen on.
* @throws If the server fails to start.
* @throws If the server is already started.
*/
start(): Promise<number>;
/**
* Stops the server.
* @throws If the server is not started.
* @throws If the server fails to stop.
*/
stop(): Promise<void>;
/**
* Returns a promise that resolves to the result of the OAuth flow.
*/
waitForOAuthResponse(): Promise<IOAuthResult>;
}
export class LoopbackAuthServer implements ILoopbackServer {
private readonly _server: http.Server;
private readonly _resultPromise: Promise<IOAuthResult>;
private _startingRedirect: URL;
public nonce = randomBytes(16).toString('base64');
public port: number | undefined;
public set state(state: string | undefined) {
if (state) {
this._startingRedirect.searchParams.set('state', state);
} else {
this._startingRedirect.searchParams.delete('state');
}
}
public get state(): string | undefined {
return this._startingRedirect.searchParams.get('state') ?? undefined;
}
const server = http.createServer(function (req, res) {
const reqUrl = url.parse(req.url!, /* parseQueryString */ true);
switch (reqUrl.pathname) {
case '/signin': {
const receivedNonce = ((reqUrl.query.nonce as string) || '').replace(/ /g, '+');
if (receivedNonce === nonce) {
deferredRedirect.resolve({ req, res });
} else {
const err = new Error('Nonce does not match.');
deferredRedirect.resolve({ err, res });
constructor(serveRoot: string, startingRedirect: string) {
if (!serveRoot) {
throw new Error('serveRoot must be defined');
}
if (!startingRedirect) {
throw new Error('startingRedirect must be defined');
}
this._startingRedirect = new URL(startingRedirect);
let deferred: { resolve: (result: IOAuthResult) => void; reject: (reason: any) => void };
this._resultPromise = new Promise<IOAuthResult>((resolve, reject) => deferred = { resolve, reject });
this._server = http.createServer((req, res) => {
const reqUrl = new URL(req.url!, `http://${req.headers.host}`);
switch (reqUrl.pathname) {
case '/signin': {
const receivedNonce = (reqUrl.searchParams.get('nonce') ?? '').replace(/ /g, '+');
if (receivedNonce !== this.nonce) {
res.writeHead(302, { location: `/?error=${encodeURIComponent('Nonce does not match.')}` });
res.end();
}
res.writeHead(302, { location: this._startingRedirect.toString() });
res.end();
break;
}
break;
case '/callback': {
const code = reqUrl.searchParams.get('code') ?? undefined;
const state = reqUrl.searchParams.get('state') ?? undefined;
if (!code || !state) {
res.writeHead(400);
res.end();
return;
}
if (this.state !== state) {
res.writeHead(302, { location: `/?error=${encodeURIComponent('State does not match.')}` });
res.end();
throw new Error('State does not match.');
}
deferred.resolve({ code, state });
res.writeHead(302, { location: '/' });
res.end();
break;
}
// Serve the static files
case '/':
sendFile(res, path.join(serveRoot, 'index.html'));
break;
default:
// substring to get rid of leading '/'
sendFile(res, path.join(serveRoot, reqUrl.pathname.substring(1)));
break;
}
case '/':
sendFile(res, path.join(__dirname, '../media/auth.html'), 'text/html; charset=utf-8');
break;
case '/auth.css':
sendFile(res, path.join(__dirname, '../media/auth.css'), 'text/css; charset=utf-8');
break;
case '/callback':
deferredCode.resolve(callback(nonce, reqUrl)
.then(code => ({ code, res }), err => ({ err, res })));
break;
default:
res.writeHead(404);
res.end();
break;
}
});
});
}
codePromise.then(cancelCodeTimer, cancelCodeTimer);
return {
server,
redirectPromise,
codePromise
};
public start(): Promise<number> {
return new Promise<number>((resolve, reject) => {
if (this._server.listening) {
throw new Error('Server is already started');
}
const portTimeout = setTimeout(() => {
reject(new Error('Timeout waiting for port'));
}, 5000);
this._server.on('listening', () => {
const address = this._server.address();
if (typeof address === 'string') {
this.port = parseInt(address);
} else if (address instanceof Object) {
this.port = address.port;
} else {
throw new Error('Unable to determine port');
}
clearTimeout(portTimeout);
resolve(this.port);
});
this._server.on('error', err => {
reject(new Error(`Error listening to server: ${err}`));
});
this._server.on('close', () => {
reject(new Error('Closed'));
});
this._server.listen(0, '127.0.0.1');
});
}
public stop(): Promise<void> {
return new Promise<void>((resolve, reject) => {
if (!this._server.listening) {
throw new Error('Server is not started');
}
this._server.close((err) => {
if (err) {
reject(err);
} else {
resolve();
}
});
});
}
public waitForOAuthResponse(): Promise<IOAuthResult> {
return this._resultPromise;
}
}