diff --git a/library/agent/hooks.test.ts b/library/agent/hooks.test.ts new file mode 100644 index 000000000..e0e9dc6d3 --- /dev/null +++ b/library/agent/hooks.test.ts @@ -0,0 +1,65 @@ +import * as t from "tap"; +import { addHook, removeHook, executeHooks } from "./hooks"; + +t.test("it works", async (t) => { + let hookOneCalls = 0; + let hookTwoCalls = 0; + + function hook1(sql: string) { + t.equal(sql, "SELECT 1"); + hookOneCalls++; + } + function hook2(sql: string) { + t.equal(sql, "SELECT 1"); + hookTwoCalls++; + } + + function hook3() { + throw new Error("hook3 should not be called"); + } + + t.same(hookOneCalls, 0, "hookOneCalls starts at 0"); + t.same(hookTwoCalls, 0, "hookTwoCalls starts at 0"); + + executeHooks("beforeSQLExecute", "SELECT 1"); + + t.same(hookOneCalls, 0, "hookOneCalls still at 0"); + t.same(hookTwoCalls, 0, "hookTwoCalls still at 0"); + + addHook("beforeSQLExecute", hook1); + // @ts-expect-error some other hook is not defined in the types + addHook("someOtherHook", hook3); + executeHooks("beforeSQLExecute", "SELECT 1"); + + t.equal(hookOneCalls, 1, "hook1 called once"); + t.equal(hookTwoCalls, 0, "hook2 not called"); + + addHook("beforeSQLExecute", hook2); + t.same(executeHooks("beforeSQLExecute", "SELECT 1"), [], "no value returned"); + + t.equal(hookOneCalls, 2, "hook1 called twice"); + t.equal(hookTwoCalls, 1, "hook2 called once"); + + removeHook("beforeSQLExecute", hook1); + executeHooks("beforeSQLExecute", "SELECT 1"); + + t.equal(hookOneCalls, 2, "hook1 still called twice"); + t.equal(hookTwoCalls, 2, "hook2 called twice"); + + removeHook("beforeSQLExecute", hook2); + t.same(executeHooks("beforeSQLExecute", "SELECT 1"), [], "no hooks executed"); + + t.equal(hookOneCalls, 2, "hook1 still called twice"); + t.equal(hookTwoCalls, 2, "hook2 still called twice"); + + // @ts-expect-error returnTest is not defined in the types + addHook("returnTest", () => { + return 1; + }); + // @ts-expect-error returnTest is not defined in the types + addHook("returnTest", () => { + return 2; + }); + // @ts-expect-error returnTest is not defined in the types + t.same(executeHooks("returnTest"), [1, 2], "returns values from hooks"); +}); diff --git a/library/agent/hooks.ts b/library/agent/hooks.ts new file mode 100644 index 000000000..15fc3793e --- /dev/null +++ b/library/agent/hooks.ts @@ -0,0 +1,56 @@ +type HookName = "beforeSQLExecute"; + +// Map hook names to argument and return types +interface HookTypes { + beforeSQLExecute: { + args: [sql: string]; + return: void; + }; +} + +const hooks = new Map< + HookName, + Array<(...args: HookTypes[HookName]["args"]) => HookTypes[HookName]["return"]> +>(); + +export function addHook( + name: N, + fn: (...args: HookTypes[N]["args"]) => HookTypes[N]["return"] +) { + if (!hooks.has(name)) { + hooks.set(name, [fn]); + } else { + hooks.get(name)!.push(fn); + } +} + +export function removeHook( + name: N, + fn: (...args: HookTypes[N]["args"]) => HookTypes[N]["return"] +) { + if (hooks.has(name)) { + const fns = hooks.get(name)!; + const index = fns.indexOf(fn); + if (index !== -1) { + fns.splice(index, 1); + } + } +} + +export function executeHooks( + name: N, + ...args: [...HookTypes[N]["args"]] +): Array { + const results: Array = []; + const hookList = hooks.get(name); + + for (const fn of hookList ?? []) { + const result = ( + fn as (...args: HookTypes[N]["args"]) => HookTypes[N]["return"] + )(...args); + if (result !== undefined) { + results.push(result); + } + } + return results; +} diff --git a/library/index.ts b/library/index.ts index 35489dca1..139f386eb 100644 --- a/library/index.ts +++ b/library/index.ts @@ -14,6 +14,7 @@ import { isESM } from "./helpers/isESM"; import { checkIndexImportGuard } from "./helpers/indexImportGuard"; import { setRateLimitGroup } from "./ratelimiting/group"; import { isLibBundled } from "./helpers/isLibBundled"; +import { addHook, removeHook } from "./agent/hooks"; const supported = isFirewallSupported(); const shouldEnable = shouldEnableFirewall(); @@ -47,6 +48,8 @@ export { addKoaMiddleware, addRestifyMiddleware, setRateLimitGroup, + addHook, + removeHook, }; // Required for ESM / TypeScript default export support @@ -63,4 +66,6 @@ export default { addKoaMiddleware, addRestifyMiddleware, setRateLimitGroup, + addHook, + removeHook, }; diff --git a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.test.ts b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.test.ts index 2e0f90333..08bd6b9cf 100644 --- a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.test.ts +++ b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.test.ts @@ -1,6 +1,7 @@ import * as t from "tap"; import { checkContextForSqlInjection } from "./checkContextForSqlInjection"; import { SQLDialectMySQL } from "./dialects/SQLDialectMySQL"; +import { addHook } from "../../agent/hooks"; t.test("it returns correct path", async () => { t.same( @@ -36,3 +37,53 @@ t.test("it returns correct path", async () => { } ); }); + +t.test("it executes hooks", async () => { + let hookCalled = 0; + + function hook(sql: string) { + t.equal( + sql, + "SELECT * FROM users WHERE id = '1' OR 1=1; -- '", + "hook called with correct sql" + ); + hookCalled++; + } + + addHook("beforeSQLExecute", hook); + + t.same( + checkContextForSqlInjection({ + sql: "SELECT * FROM users WHERE id = '1' OR 1=1; -- '", + operation: "mysql.query", + dialect: new SQLDialectMySQL(), + context: { + cookies: {}, + headers: {}, + remoteAddress: "ip", + method: "POST", + url: "url", + query: {}, + body: { + id: "1' OR 1=1; --", + }, + source: "express", + route: "/", + routeParams: {}, + }, + }), + { + operation: "mysql.query", + kind: "sql_injection", + source: "body", + pathsToPayload: [".id"], + metadata: { + sql: "SELECT * FROM users WHERE id = '1' OR 1=1; -- '", + dialect: "MySQL", + }, + payload: "1' OR 1=1; --", + } + ); + + t.equal(hookCalled, 1, "hook called once"); +}); diff --git a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts index 9f0b013e5..7f60ee668 100644 --- a/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts +++ b/library/vulnerabilities/sql-injection/checkContextForSqlInjection.ts @@ -1,5 +1,6 @@ import { getInstance } from "../../agent/AgentSingleton"; import { Context } from "../../agent/Context"; +import { executeHooks } from "../../agent/hooks"; import { InterceptorResult } from "../../agent/hooks/InterceptorResult"; import { SOURCES } from "../../agent/Source"; import { getPathsToPayload } from "../../helpers/attackPath"; @@ -25,6 +26,8 @@ export function checkContextForSqlInjection({ context: Context; dialect: SQLDialect; }): InterceptorResult { + executeHooks("beforeSQLExecute", sql); + for (const source of SOURCES) { const userInput = extractStringsFromUserInputCached(context, source); if (!userInput) {