Skip to content

Commit

Permalink
Add concept of document safety to disable parsing large files
Browse files Browse the repository at this point in the history
Signed-off-by: worksofliam <[email protected]>
  • Loading branch information
worksofliam committed Jan 9, 2025
1 parent e24853b commit 03c8680
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 95 deletions.
3 changes: 3 additions & 0 deletions src/language/providers/completionProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,9 @@ export const completionProvider = languages.registerCompletionItemProvider(
const offset = document.offsetAt(position);

const sqlDoc = getSqlDocument(document);

if (!sqlDoc) return;

const currentStatement = sqlDoc.getStatementByOffset(offset);

const allItems: CompletionItem[] = [];
Expand Down
4 changes: 4 additions & 0 deletions src/language/providers/hoverProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ export const openProvider = workspace.onDidOpenTextDocument(async (document) =>
const sqlDoc = getSqlDocument(document);
const defaultSchema = getDefaultSchema();

if (!sqlDoc) return;

for (const statement of sqlDoc.statements) {
const refs = statement.getObjectReferences();
if (refs.length) {
Expand Down Expand Up @@ -64,6 +66,8 @@ export const hoverProvider = languages.registerHoverProvider({ language: `sql` }
const sqlDoc = getSqlDocument(document);
const offset = document.offsetAt(position);

if (!sqlDoc) return;

const tokAt = sqlDoc.getTokenByOffset(offset);
const statementAt = sqlDoc.getStatementByOffset(offset);

Expand Down
9 changes: 8 additions & 1 deletion src/language/providers/logic/parse.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import { TextDocument } from "vscode";
import Document from "../../sql/document";
import { VALID_STATEMENT_LENGTH } from "../../../connection/syntaxChecker/checker";

let cached: Map<string, {ast, version}> = new Map();

export function getSqlDocument(document: TextDocument): Document {
export function getSqlDocument(document: TextDocument): Document|undefined {
if (!isSafeDocument(document)) return undefined;

const uri = document.uri.toString();
const likelyNew = document.uri.scheme === `untitled` && document.version === 1;

Expand All @@ -19,4 +22,8 @@ export function getSqlDocument(document: TextDocument): Document {
cached.set(uri, { ast: newAsp, version: document.version });

return newAsp;
}

export function isSafeDocument(doc: TextDocument): boolean {
return doc.languageId === `sql` && doc.lineCount < VALID_STATEMENT_LENGTH;
}
3 changes: 3 additions & 0 deletions src/language/providers/parameterProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ export const signatureProvider = languages.registerSignatureHelpProvider({ langu
if (remoteAssistIsEnabled()) {

const sqlDoc = getSqlDocument(document);

if (!sqlDoc) return;

const currentStatement = sqlDoc.getStatementByOffset(offset);

if (currentStatement) {
Expand Down
208 changes: 114 additions & 94 deletions src/language/providers/problemProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { commands, CompletionItemKind, Diagnostic, DiagnosticSeverity, languages
import {
SQLType,
} from "../../database/schemas";
import Statement from "../../database/statement";
import Document from "../sql/document";
import { remoteAssistIsEnabled } from "./logic/available";
import Configuration from "../../configuration";
import { SQLStatementChecker, SqlSyntaxError } from "../../connection/syntaxChecker";
import { StatementGroup, StatementType } from "../sql/types";
import { VALID_STATEMENT_LENGTH } from "../../connection/syntaxChecker/checker";
import { getSqlDocument, isSafeDocument } from "./logic/parse";
import path from "path";

export interface CompletionType {
order: string;
Expand Down Expand Up @@ -51,8 +51,10 @@ function shouldShowWarnings() {
const CHECKER_AVAILABLE_CONTEXT = `vscode-db2i.syntax.checkerAvailable`;
const CHECKER_RUNNING_CONTEXT = `vscode-db2i.syntax.checkerRunning`;

export function setCheckerAvailableContext() {
const available = SQLStatementChecker.get() !== undefined;
const checkerAvailable = () => (SQLStatementChecker.get() !== undefined);

export function setCheckerAvailableContext(additionalState = true) {
const available = checkerAvailable() && additionalState;
commands.executeCommand(`setContext`, CHECKER_AVAILABLE_CONTEXT, available);
}

Expand All @@ -73,13 +75,23 @@ export const checkDocumentDefintion = commands.registerCommand(CHECK_DOCUMENT_CO

export const problemProvider = [
workspace.onDidCloseTextDocument(e => {
sqlDiagnosticCollection.delete(e.uri);
// Only clear errors from unsaved files.
if (e.isUntitled) {
sqlDiagnosticCollection.delete(e.uri);
}
}),

workspace.onDidOpenTextDocument(e => {
const isSql = e.languageId === `sql`;
if (isSql && checkOnOpen()) {
validateSqlDocument(e);
if (isSql) {
if (checkerAvailable() && !isSafeDocument(e)) {
const basename = e.fileName ? path.basename(e.fileName) : `Untitled`;
window.showWarningMessage(`The SQL syntax checker is disabled for this document (${basename}) because it is too large.`);
}

if (checkOnOpen()) {
validateSqlDocument(e);
}
}
}),

Expand All @@ -94,8 +106,14 @@ export const problemProvider = [
validateSqlDocument(e.document, e.document.offsetAt(e.contentChanges[0].range.start));
}, getCheckerTimeout());
}
})
}),

window.onDidChangeActiveTextEditor(e => {
const canRun = e && e.document.languageId === `sql` && isSafeDocument(e.document);
setCheckerAvailableContext(canRun);
}),
];

interface SqlDiagnostic extends Diagnostic {
groupId: number;
}
Expand All @@ -104,111 +122,113 @@ async function validateSqlDocument(document: TextDocument, specificStatement?: n
const checker = SQLStatementChecker.get();
if (remoteAssistIsEnabled() && checker && !checkerRunning) {
setCheckerRunningContext(true);
const content = document.getText();
const sqlDocument = new Document(content);

const allGroups = sqlDocument.getStatementGroups();
let statementRanges: StatementRange[] = [];

for (let i = 0; i < allGroups.length; i++) {
const group = allGroups[i];
if (specificStatement) {
// If specificStatement is outside this group, continue
if (
specificStatement < group.range.start ||
(specificStatement > (allGroups[i + 1] ? allGroups[i + 1].range.start : group.range.end))
) {
continue;
const sqlDocument = getSqlDocument(document);

if (sqlDocument) {

const allGroups = sqlDocument.getStatementGroups();
let statementRanges: StatementRange[] = [];

for (let i = 0; i < allGroups.length; i++) {
const group = allGroups[i];
if (specificStatement) {
// If specificStatement is outside this group, continue
if (
specificStatement < group.range.start ||
(specificStatement > (allGroups[i + 1] ? allGroups[i + 1].range.start : group.range.end))
) {
continue;
}
}
}

const range = getStatementRangeFromGroup(group, i);
const range = getStatementRangeFromGroup(group, i);

if (range) {
statementRanges.push(range);
}
if (range) {
statementRanges.push(range);
}

// We also add the surrounding ranges, as we need to check the end of the statement
if (specificStatement) {
for (let j = i - 1; j <= i + 1; j++) {
if (allGroups[j]) {
const nextRange = getStatementRangeFromGroup(allGroups[j], j);
if (nextRange && !statementRanges.some(r => r.groupId === nextRange.groupId)) {
statementRanges.push(nextRange);
// We also add the surrounding ranges, as we need to check the end of the statement
if (specificStatement) {
for (let j = i - 1; j <= i + 1; j++) {
if (allGroups[j]) {
const nextRange = getStatementRangeFromGroup(allGroups[j], j);
if (nextRange && !statementRanges.some(r => r.groupId === nextRange.groupId)) {
statementRanges.push(nextRange);
}
}
}
}

break;
break;
}
}
}


if (statementRanges.length > 0) {
const validStatements = statementRanges.filter(r => r.validate);
const invalidStatements = statementRanges.filter(r => !r.validate);
const sqlStatementContents = validStatements.map(range => content.substring(range.start, range.end));

if (validStatements.length > 0) {
const se = performance.now();
const syntaxChecked = await window.withProgress({ location: ProgressLocation.Window, title: `$(sync-spin) Checking SQL Syntax` }, () => { return checker.checkMultipleStatements(sqlStatementContents) });
const ee = performance.now();
if (statementRanges.length > 0) {
const validStatements = statementRanges.filter(r => r.validate);
const invalidStatements = statementRanges.filter(r => !r.validate);
const sqlStatementContents = validStatements.map(range => sqlDocument.content.substring(range.start, range.end));

if (syntaxChecked) {
if (syntaxChecked.length > 0) {
let currentErrors: SqlDiagnostic[] = specificStatement ? languages.getDiagnostics(document.uri) as SqlDiagnostic[] : [];

// Remove old CL errors.
for (const invalidStatement of invalidStatements) {
const existingError = currentErrors.findIndex(e => e.groupId === invalidStatement.groupId);
if (existingError >= 0) {
currentErrors.splice(existingError, 1);
}
}
if (validStatements.length > 0) {
const se = performance.now();
const syntaxChecked = await window.withProgress({ location: ProgressLocation.Window, title: `$(sync-spin) Checking SQL Syntax` }, () => { return checker.checkMultipleStatements(sqlStatementContents) });
const ee = performance.now();

for (let i = 0; i < validStatements.length; i++) {
const currentRange = validStatements[i];
const groupError = syntaxChecked[i];
let existingError: number = currentErrors.findIndex(e => e.groupId === currentRange.groupId);
if (syntaxChecked) {
if (syntaxChecked.length > 0) {
let currentErrors: SqlDiagnostic[] = specificStatement ? languages.getDiagnostics(document.uri) as SqlDiagnostic[] : [];

if (groupError.type === `none`) {
if (existingError !== -1) {
// Remove old CL errors.
for (const invalidStatement of invalidStatements) {
const existingError = currentErrors.findIndex(e => e.groupId === invalidStatement.groupId);
if (existingError >= 0) {
currentErrors.splice(existingError, 1);
}
}

} else if (shouldShowError(groupError)) {
let baseIndex = () => { return currentRange.start + groupError.offset };

if (baseIndex() > currentRange.end) {
// This is a syntax error that is outside the range of the statement.
groupError.offset = (currentRange.end-currentRange.start);
}

const selectedWord = document.getWordRangeAtPosition(document.positionAt(baseIndex()))
|| new Range(
document.positionAt(baseIndex() - 1),
document.positionAt(baseIndex())
);


const newDiag: SqlDiagnostic = {
message: `${groupError.text} - ${groupError.sqlstate}`,
code: groupError.sqlid,
range: selectedWord,
severity: diagnosticTypeMap[groupError.type],
groupId: currentRange.groupId
};
for (let i = 0; i < validStatements.length; i++) {
const currentRange = validStatements[i];
const groupError = syntaxChecked[i];
let existingError: number = currentErrors.findIndex(e => e.groupId === currentRange.groupId);

if (groupError.type === `none`) {
if (existingError !== -1) {
currentErrors.splice(existingError, 1);
}

} else if (shouldShowError(groupError)) {
let baseIndex = () => { return currentRange.start + groupError.offset };

if (baseIndex() > currentRange.end) {
// This is a syntax error that is outside the range of the statement.
groupError.offset = (currentRange.end - currentRange.start);
}

const selectedWord = document.getWordRangeAtPosition(document.positionAt(baseIndex()))
|| new Range(
document.positionAt(baseIndex() - 1),
document.positionAt(baseIndex())
);


const newDiag: SqlDiagnostic = {
message: `${groupError.text} - ${groupError.sqlstate}`,
code: groupError.sqlid,
range: selectedWord,
severity: diagnosticTypeMap[groupError.type],
groupId: currentRange.groupId
};

if (existingError >= 0) {
currentErrors[existingError] = newDiag;
} else {
currentErrors.push(newDiag);
}

if (existingError >= 0) {
currentErrors[existingError] = newDiag;
} else {
currentErrors.push(newDiag);
}

}
}

sqlDiagnosticCollection.set(document.uri, currentErrors);
sqlDiagnosticCollection.set(document.uri, currentErrors);
}
}
}
}
Expand Down Expand Up @@ -245,7 +265,7 @@ function getStatementRangeFromGroup(currentGroup: StatementGroup, groupId: numbe
}

const stmtLength = currentGroup.range.end - statementRange.start;
if (stmtLength >= VALID_STATEMENT_LENGTH) {
if (stmtLength >= VALID_STATEMENT_LENGTH) {
// Just too long for our API.
statementRange.validate = false;
}
Expand Down

0 comments on commit 03c8680

Please sign in to comment.