Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions library/agent/hooks.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import * as t from "tap";
import { addHook, removeHook, executeHooks } from "./hooks";

t.test("it works", async (t) => {
let hookOneCalls = 0;
let hookTwoCalls = 0;

function hook1(sql: string) {
t.equal(sql, "SELECT 1");
hookOneCalls++;
}
function hook2(sql: string) {
t.equal(sql, "SELECT 1");
hookTwoCalls++;
}

function hook3() {
throw new Error("hook3 should not be called");
}

t.same(hookOneCalls, 0, "hookOneCalls starts at 0");
t.same(hookTwoCalls, 0, "hookTwoCalls starts at 0");

executeHooks("beforeSQLExecute", "SELECT 1");

t.same(hookOneCalls, 0, "hookOneCalls still at 0");
t.same(hookTwoCalls, 0, "hookTwoCalls still at 0");

addHook("beforeSQLExecute", hook1);
// @ts-expect-error some other hook is not defined in the types
addHook("someOtherHook", hook3);
executeHooks("beforeSQLExecute", "SELECT 1");

t.equal(hookOneCalls, 1, "hook1 called once");
t.equal(hookTwoCalls, 0, "hook2 not called");

addHook("beforeSQLExecute", hook2);
t.same(executeHooks("beforeSQLExecute", "SELECT 1"), [], "no value returned");

t.equal(hookOneCalls, 2, "hook1 called twice");
t.equal(hookTwoCalls, 1, "hook2 called once");

removeHook("beforeSQLExecute", hook1);
executeHooks("beforeSQLExecute", "SELECT 1");

t.equal(hookOneCalls, 2, "hook1 still called twice");
t.equal(hookTwoCalls, 2, "hook2 called twice");

removeHook("beforeSQLExecute", hook2);
t.same(executeHooks("beforeSQLExecute", "SELECT 1"), [], "no hooks executed");

t.equal(hookOneCalls, 2, "hook1 still called twice");
t.equal(hookTwoCalls, 2, "hook2 still called twice");

// @ts-expect-error returnTest is not defined in the types
addHook("returnTest", () => {
return 1;
});
// @ts-expect-error returnTest is not defined in the types
addHook("returnTest", () => {
return 2;
});
// @ts-expect-error returnTest is not defined in the types
t.same(executeHooks("returnTest"), [1, 2], "returns values from hooks");
});
56 changes: 56 additions & 0 deletions library/agent/hooks.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
type HookName = "beforeSQLExecute";

// Map hook names to argument and return types
interface HookTypes {
beforeSQLExecute: {
args: [sql: string];
return: void;
};
}

const hooks = new Map<
HookName,
Array<(...args: HookTypes[HookName]["args"]) => HookTypes[HookName]["return"]>
>();

export function addHook<N extends HookName>(
name: N,
fn: (...args: HookTypes[N]["args"]) => HookTypes[N]["return"]
) {
if (!hooks.has(name)) {
hooks.set(name, [fn]);
} else {
hooks.get(name)!.push(fn);
}
}

export function removeHook<N extends HookName>(
name: N,
fn: (...args: HookTypes[N]["args"]) => HookTypes[N]["return"]
) {
if (hooks.has(name)) {
const fns = hooks.get(name)!;
const index = fns.indexOf(fn);
if (index !== -1) {
fns.splice(index, 1);
}
}
}

export function executeHooks<N extends HookName>(
name: N,
...args: [...HookTypes[N]["args"]]
): Array<HookTypes[N]["return"]> {
const results: Array<HookTypes[N]["return"]> = [];
const hookList = hooks.get(name);

for (const fn of hookList ?? []) {
const result = (
fn as (...args: HookTypes[N]["args"]) => HookTypes[N]["return"]
)(...args);
if (result !== undefined) {
results.push(result);
}
}
return results;
}
5 changes: 5 additions & 0 deletions library/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { isESM } from "./helpers/isESM";
import { checkIndexImportGuard } from "./helpers/indexImportGuard";
import { setRateLimitGroup } from "./ratelimiting/group";
import { isLibBundled } from "./helpers/isLibBundled";
import { addHook, removeHook } from "./agent/hooks";

const supported = isFirewallSupported();
const shouldEnable = shouldEnableFirewall();
Expand Down Expand Up @@ -47,6 +48,8 @@ export {
addKoaMiddleware,
addRestifyMiddleware,
setRateLimitGroup,
addHook,
removeHook,
};

// Required for ESM / TypeScript default export support
Expand All @@ -63,4 +66,6 @@ export default {
addKoaMiddleware,
addRestifyMiddleware,
setRateLimitGroup,
addHook,
removeHook,
};
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import * as t from "tap";
import { checkContextForSqlInjection } from "./checkContextForSqlInjection";
import { SQLDialectMySQL } from "./dialects/SQLDialectMySQL";
import { addHook } from "../../agent/hooks";

t.test("it returns correct path", async () => {
t.same(
Expand Down Expand Up @@ -36,3 +37,53 @@ t.test("it returns correct path", async () => {
}
);
});

t.test("it executes hooks", async () => {
let hookCalled = 0;

function hook(sql: string) {
t.equal(
sql,
"SELECT * FROM users WHERE id = '1' OR 1=1; -- '",
"hook called with correct sql"
);
hookCalled++;
}

addHook("beforeSQLExecute", hook);

t.same(
checkContextForSqlInjection({
sql: "SELECT * FROM users WHERE id = '1' OR 1=1; -- '",
operation: "mysql.query",
dialect: new SQLDialectMySQL(),
context: {
cookies: {},
headers: {},
remoteAddress: "ip",
method: "POST",
url: "url",
query: {},
body: {
id: "1' OR 1=1; --",
},
source: "express",
route: "/",
routeParams: {},
},
}),
{
operation: "mysql.query",
kind: "sql_injection",
source: "body",
pathsToPayload: [".id"],
metadata: {
sql: "SELECT * FROM users WHERE id = '1' OR 1=1; -- '",
dialect: "MySQL",
},
payload: "1' OR 1=1; --",
}
);

t.equal(hookCalled, 1, "hook called once");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably need a test that throws an error in the hook? Just to see what the behaviour is? Like is it caught by our own instrumentation catcher etc?

});
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { getInstance } from "../../agent/AgentSingleton";
import { Context } from "../../agent/Context";
import { executeHooks } from "../../agent/hooks";
import { InterceptorResult } from "../../agent/hooks/InterceptorResult";
import { SOURCES } from "../../agent/Source";
import { getPathsToPayload } from "../../helpers/attackPath";
Expand All @@ -25,6 +26,8 @@ export function checkContextForSqlInjection({
context: Context;
dialect: SQLDialect;
}): InterceptorResult {
executeHooks("beforeSQLExecute", sql);

for (const source of SOURCES) {
const userInput = extractStringsFromUserInputCached(context, source);
if (!userInput) {
Expand Down