Skip to content

Commit

Permalink
refactor(api): Ensure all raw sql queries are parameterized
Browse files Browse the repository at this point in the history
  • Loading branch information
alepefe committed Nov 8, 2024
1 parent 274377a commit c2b8feb
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 49 deletions.
87 changes: 54 additions & 33 deletions api/src/infrastructure/postgres-survey-answers.repository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import { Inject, Logger } from '@nestjs/common';
import { InjectDataSource } from '@nestjs/typeorm';
import { WidgetDataFilter } from '@shared/dto/widgets/widget-data-filter';
import { SectionWithDataWidget } from '@shared/dto/sections/section.entity';
import { SQLAdapter } from '@api/infrastructure/sql-adapter';
import {
FilterClauseWithParams,
SQLAdapter,
} from '@api/infrastructure/sql-adapter';
import {
BaseWidgetWithData,
WidgetChartData,
Expand Down Expand Up @@ -48,18 +51,18 @@ export class PostgresSurveyAnswerRepository
sections: SectionWithDataWidget[],
filters?: WidgetDataFilter[],
): Promise<SectionWithDataWidget[]> {
let filterClause: string = '';
if (filters !== undefined) {
filterClause = this.sqlAdapter.generateSqlFromWidgetDataFilters(filters);
}
const filterClauseWithParams: FilterClauseWithParams =
this.sqlAdapter.generateFilterClauseFromWidgetDataFilters(filters);

const widgetDataPromises = [];
for (let sectionIdx = 0; sectionIdx < sections.length; sectionIdx++) {
const section = sections[sectionIdx];
const baseWidgets = section.baseWidgets;
for (let widgetIdx = 0; widgetIdx < baseWidgets.length; widgetIdx++) {
const widget = baseWidgets[widgetIdx];
widgetDataPromises.push(this.addDataToWidget(widget, filterClause));
widgetDataPromises.push(
this.addDataToWidget(widget, filterClauseWithParams),
);
}
}

Expand All @@ -73,41 +76,43 @@ export class PostgresSurveyAnswerRepository
): Promise<BaseWidgetWithData> {
const { filters, breakdownIndicator } = params;
if (breakdownIndicator === undefined) {
const filterClause =
this.sqlAdapter.generateSqlFromWidgetDataFilters(filters);
await this.addDataToWidget(widget, filterClause);
const filterClauseWithParams =
this.sqlAdapter.generateFilterClauseFromWidgetDataFilters(filters);
await this.addDataToWidget(widget, filterClauseWithParams);
} else {
const filterClause = this.sqlAdapter.generateSqlFromWidgetDataFilters(
filters,
'main',
);
const filterClauseWithParams =
this.sqlAdapter.generateFilterClauseFromWidgetDataFilters(filters, {
alias: 'main',
});
await this.addBreakdownDataToWidget(
widget,
breakdownIndicator,
filterClause,
filterClauseWithParams,
);
}
return widget;
}

private async addDataToWidget(
widget: BaseWidgetWithData,
filterClause: string,
filterClauseWithParams: FilterClauseWithParams,
): Promise<void> {
widget.data = {};

// Check if the indicator is an edge case
const methodName = this.edgeCasesMethodNameMap[widget.indicator];
if (methodName !== undefined) {
return this[methodName](widget, filterClause);
return this[methodName](widget, filterClauseWithParams);
}

const [supportsChart, supportsMap] =
WidgetUtils.getSupportedVisualizations(widget);

const dataPromises = [];
if (supportsChart === true) {
dataPromises.push(this.addChartDataToWidget(widget, filterClause));
dataPromises.push(
this.addChartDataToWidget(widget, filterClauseWithParams),
);
}

if (supportsMap === true) {
Expand All @@ -119,12 +124,19 @@ export class PostgresSurveyAnswerRepository

private async addChartDataToWidget(
widget: BaseWidgetWithData,
filterClause: string,
filterClauseWithParams: FilterClauseWithParams,
): Promise<void> {
const [whereClauseSql, whereClauseParams] =
this.sqlAdapter.addExpressionToFilterClause(filterClauseWithParams, [
'question_indicator',
'=',
widget.indicator,
]);

const totalsSql = `SELECT answer as "key", count(answer)::integer as "count", SUM(COUNT(answer)) OVER ()::integer AS total
FROM ${this.answersTable} ${this.sqlAdapter.appendExpressionToFilterClause(filterClause, 'question_indicator = $1')} GROUP BY answer ORDER BY answer`;
FROM ${this.answersTable} ${whereClauseSql} GROUP BY answer ORDER BY answer`;
const totalsResult: { key: string; count: number; total: number }[] =
await this.dataSource.query(totalsSql, [widget.indicator]);
await this.dataSource.query(totalsSql, whereClauseParams);

const arr: WidgetChartData = [];
for (let rowIdx = 0; rowIdx < totalsResult.length; rowIdx++) {
Expand All @@ -147,38 +159,42 @@ export class PostgresSurveyAnswerRepository

private async addTotalSurveysDataToWidget(
widget: BaseWidgetWithData,
filterClause: string,
filterClauseWithParams: FilterClauseWithParams,
): Promise<void> {
const [filterClause, queryParams] = filterClauseWithParams;

const filteredCount = `SELECT COUNT(count)::integer as count FROM (SELECT COUNT(DISTINCT survey_id) FROM ${this.answersTable} ${filterClause} GROUP BY survey_id) AS survey_count`;
const totalCount = `SELECT COUNT(count)::integer as count FROM (SELECT COUNT(DISTINCT survey_id) FROM ${this.answersTable} GROUP BY survey_id) AS survey_count`;
const [[{ count: value }], [{ count: total }]] = await Promise.all([
this.dataSource.query(filteredCount),
this.dataSource.query(filteredCount, queryParams),
this.dataSource.query(totalCount),
]);
widget.data.counter = { value, total };
}

private async addTotalCountriesDataToWidget(
widget: BaseWidgetWithData,
filterClause: string,
filterClauseWithParams: FilterClauseWithParams,
): Promise<void> {
const [filterClause, queryParams] = filterClauseWithParams;

const filteredCount = `SELECT COUNT(DISTINCT country_code)::integer as "count" FROM ${this.answersTable} ${filterClause}`;
const totalCount = `SELECT COUNT(DISTINCT country_code)::integer as "count" FROM ${this.answersTable};`;
const [[{ count: value }], [{ count: total }]] = await Promise.all([
this.dataSource.query(filteredCount),
this.dataSource.query(filteredCount, queryParams),
this.dataSource.query(totalCount),
]);
widget.data.counter = { value, total };
}

private async addAdoptionOfTechnologyByCountryDataToWidget(
widget: BaseWidgetWithData,
filterClause: string,
filterClauseWithParams: FilterClauseWithParams,
): Promise<void> {
// Best workaround to reference correct question without changing the frontend title ('Adoption of technology by country' once transformed)
widget.indicator = 'digital-technologies-integrated';
await Promise.all([
this.addChartDataToWidget(widget, filterClause),
this.addChartDataToWidget(widget, filterClauseWithParams),
this.addMapDataToWidget(widget),
]);
widget.indicator = 'adoption-of-technology-by-country';
Expand All @@ -187,8 +203,15 @@ export class PostgresSurveyAnswerRepository
private async addBreakdownDataToWidget(
widget: BaseWidgetWithData,
breakdownIndicator: string,
filterClause: string,
filterClauseWithParams: FilterClauseWithParams,
): Promise<void> {
const [filterClause, queryParams] =
this.sqlAdapter.addExpressionToFilterClause(
filterClauseWithParams,
['question_indicator', '=', widget.indicator],
'main',
);

const sqlCode = `WITH breakdown_data AS (
SELECT
main_answer,
Expand All @@ -204,8 +227,8 @@ export class PostgresSurveyAnswerRepository
JOIN
survey_answers AS secondary
ON
main.survey_id = secondary.survey_id AND secondary.question_indicator = $1
${this.sqlAdapter.appendExpressionToFilterClause(filterClause, `question_indicator = $2`, 'main')}
main.survey_id = secondary.survey_id AND secondary.question_indicator = $${queryParams.length + 1}
${filterClause}
) AS s
GROUP BY
main_answer, secondary_answer
Expand All @@ -225,10 +248,8 @@ FROM breakdown_data
GROUP BY main_answer
ORDER BY main_answer`;

const breakdown = await this.dataSource.query(sqlCode, [
breakdownIndicator,
widget.indicator,
]);
queryParams.push(breakdownIndicator);
const breakdown = await this.dataSource.query(sqlCode, queryParams);
widget.data = { breakdown };
}
}
46 changes: 30 additions & 16 deletions api/src/infrastructure/sql-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,49 +2,63 @@ import { WidgetDataFilter } from '@shared/dto/widgets/widget-data-filter';
import { Injectable, Logger } from '@nestjs/common';
import { CountryISO3Map } from '@shared/constants/country-iso3.map';

export type FilterClauseWithParams = [sqlCode: string, queryParams: unknown[]];

@Injectable()
export class SQLAdapter {
public constructor(private readonly logger: Logger) {}

public generateSqlFromWidgetDataFilters(
public generateFilterClauseFromWidgetDataFilters(
filters?: WidgetDataFilter[],
alias?: string,
): string {
if (Array.isArray(filters) === false) return '';
opts: { alias?: string; queryParams?: unknown[] } = {},
): FilterClauseWithParams {
opts.queryParams ??= [];
if (Array.isArray(filters) === false) return ['', opts.queryParams];

const { alias: rawAlias, queryParams } = opts;
const alias = rawAlias === undefined ? '' : `${rawAlias}.`;
let currentParamIdx = queryParams.length;

let filterClause: string = 'WHERE ';
for (const filter of filters) {
// Countries
// Countries edge case
if (filter.name == 'location-country-region') {
filterClause += '(';
for (const filterValue of filter.values) {
filterClause += `${alias === undefined ? '' : `${alias}.`}country_code ${filter.operator} '${CountryISO3Map.getISO3ByCountryName(filterValue)}' OR `;
filterClause += `${alias}country_code ${filter.operator} $${++currentParamIdx} OR `;
queryParams.push(CountryISO3Map.getISO3ByCountryName(filterValue));
}
filterClause = filterClause.slice(0, -4);
filterClause += ') AND ';
continue;
}

filterClause += `(${alias === undefined ? '' : `${alias}.`}question_indicator = '${filter.name}' AND (`;
filterClause += `(${alias}question_indicator = '${filter.name}' AND (`;

for (const filterValue of filter.values) {
filterClause += `${alias === undefined ? '' : `${alias}.`}answer ${filter.operator} '${filterValue}' OR `;
filterClause += `${alias}answer ${filter.operator} $${++currentParamIdx} OR `;
queryParams.push(filterValue);
}
filterClause = filterClause.slice(0, -4);
filterClause += ')) AND ';
}
filterClause = filterClause.slice(0, -4);
return filterClause;
return [filterClause, queryParams];
}

public appendExpressionToFilterClause(
filterClause: string,
newExpression: string,
public addExpressionToFilterClause(
filterClauseWithParams: FilterClauseWithParams = ['', []],
newExpression: [column: string, operator: string, value: unknown],
alias?: string,
): string {
if (filterClause !== '') {
return `${filterClause} AND ${alias === undefined ? newExpression : `${alias}.${newExpression}`}`;
): FilterClauseWithParams {
alias = alias === undefined ? '' : `${alias}.`;

if (filterClauseWithParams[0] !== '') {
const sqlCode = `${filterClauseWithParams[0]} AND ${alias}${newExpression[0]} ${newExpression[1]} $${filterClauseWithParams[1].length + 1}`;
return [sqlCode, [...filterClauseWithParams[1], newExpression[2]]];
}
return `WHERE ${alias === undefined ? newExpression : `${alias}.${newExpression}`}`;

const sqlCode = `WHERE ${alias}${newExpression[0]} ${newExpression[1]} $${filterClauseWithParams[1].length + 1}`;
return [sqlCode, [...filterClauseWithParams[1], newExpression[2]]];
}
}

0 comments on commit c2b8feb

Please sign in to comment.