Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions integration/websockets/e2e/gateway-ack.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,38 @@ describe('WebSocketGateway (ack)', () => {
);
});

it('should handle manual ack for async operations when @Ack() is used (success case)', async () => {
app = await createNestApp(AckGateway);
await app.listen(3000);

ws = io('http://localhost:8080');
const payload = { shouldSucceed: true };

await new Promise<void>(resolve =>
ws.emit('manual-ack', payload, response => {
expect(response).to.eql({ status: 'success', data: payload });
resolve();
}),
);
});

it('should handle manual ack for async operations when @Ack() is used (error case)', async () => {
app = await createNestApp(AckGateway);
await app.listen(3000);

ws = io('http://localhost:8080');
const payload = { shouldSucceed: false };

await new Promise<void>(resolve =>
ws.emit('manual-ack', payload, response => {
expect(response).to.eql({
status: 'error',
message: 'Operation failed',
});
resolve();
}),
);
});

afterEach(() => app.close());
});
22 changes: 21 additions & 1 deletion integration/websockets/src/ack.gateway.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,29 @@
import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets';
import {
Ack,
MessageBody,
SubscribeMessage,
WebSocketGateway,
} from '@nestjs/websockets';

@WebSocketGateway(8080)
export class AckGateway {
@SubscribeMessage('push')
onPush() {
return 'pong';
}

@SubscribeMessage('manual-ack')
async handleManualAck(
@MessageBody() data: any,
@Ack() ack: (response: any) => void,
) {
await new Promise(resolve => setTimeout(resolve, 20));

if (data.shouldSucceed) {
ack({ status: 'success', data });
} else {
ack({ status: 'error', message: 'Operation failed' });
}
return { status: 'ignored' };
}
}
1 change: 1 addition & 0 deletions packages/common/enums/route-paramtypes.enum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ export enum RouteParamtypes {
HOST = 10,
IP = 11,
RAW_BODY = 12,
ACK = 13,
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { Observable } from 'rxjs';
export interface WsMessageHandler<T = string> {
message: T;
callback: (...args: any[]) => Observable<any> | Promise<any>;
isAckHandledManually: boolean;
}

/**
Expand Down
10 changes: 6 additions & 4 deletions packages/platform-socket.io/adapters/io-adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,24 @@ export class IoAdapter extends AbstractWsAdapter {
first(),
);

handlers.forEach(({ message, callback }) => {
handlers.forEach(({ message, callback, isAckHandledManually }) => {
const source$ = fromEvent(socket, message).pipe(
mergeMap((payload: any) => {
const { data, ack } = this.mapPayload(payload);
return transform(callback(data, ack)).pipe(
filter((response: any) => !isNil(response)),
map((response: any) => [response, ack]),
map((response: any) => [response, ack, isAckHandledManually]),
);
}),
takeUntil(disconnect$),
);
source$.subscribe(([response, ack]) => {
source$.subscribe(([response, ack, isAckHandledManually]) => {
if (response.event) {
return socket.emit(response.event, response.data);
}
isFunction(ack) && ack(response);
if (!isAckHandledManually && isFunction(ack)) {
ack(response);
}
});
});
}
Expand Down
1 change: 1 addition & 0 deletions packages/websockets/context/ws-metadata-constants.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { WsParamtype } from '../enums/ws-paramtype.enum';

export const DEFAULT_CALLBACK_METADATA = {
[`${WsParamtype.ACK}:2`]: { index: 2, data: undefined, pipes: [] },
[`${WsParamtype.PAYLOAD}:1`]: { index: 1, data: undefined, pipes: [] },
[`${WsParamtype.SOCKET}:0`]: { index: 0, data: undefined, pipes: [] },
};
28 changes: 28 additions & 0 deletions packages/websockets/decorators/ack.decorator.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { WsParamtype } from '../enums/ws-paramtype.enum';
import { createPipesWsParamDecorator } from '../utils/param.utils';

/**
* WebSockets `ack` parameter decorator.
* Extracts the `ack` callback function from the arguments of a ws event.
*
* This decorator signals to the framework that the `ack` callback will be
* handled manually within the method, preventing the framework from
* automatically sending an acknowledgement based on the return value.
*
* @example
* ```typescript
* @SubscribeMessage('events')
* onEvent(
* @MessageBody() data: string,
* @Ack() ack: (response: any) => void
* ) {
* // Manually call the ack callback
* ack({ status: 'ok' });
* }
* ```
*
* @publicApi
*/
export function Ack(): ParameterDecorator {
return createPipesWsParamDecorator(WsParamtype.ACK)();
}
1 change: 1 addition & 0 deletions packages/websockets/decorators/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ export * from './gateway-server.decorator';
export * from './message-body.decorator';
export * from './socket-gateway.decorator';
export * from './subscribe-message.decorator';
export * from './ack.decorator';
1 change: 1 addition & 0 deletions packages/websockets/enums/ws-paramtype.enum.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ import { RouteParamtypes } from '@nestjs/common/enums/route-paramtypes.enum';
export enum WsParamtype {
SOCKET = RouteParamtypes.REQUEST,
PAYLOAD = RouteParamtypes.BODY,
ACK = RouteParamtypes.ACK,
}
4 changes: 4 additions & 0 deletions packages/websockets/factories/ws-params-factory.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { isFunction } from '@nestjs/common/utils/shared.utils';
import { WsParamtype } from '../enums/ws-paramtype.enum';

export class WsParamsFactory {
Expand All @@ -14,6 +15,9 @@ export class WsParamsFactory {
return args[0];
case WsParamtype.PAYLOAD:
return data ? args[1]?.[data] : args[1];
case WsParamtype.ACK: {
return args.find(arg => isFunction(arg));
}
default:
return null;
}
Expand Down
33 changes: 33 additions & 0 deletions packages/websockets/gateway-metadata-explorer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@ import {
GATEWAY_SERVER_METADATA,
MESSAGE_MAPPING_METADATA,
MESSAGE_METADATA,
PARAM_ARGS_METADATA,
} from './constants';
import { NestGateway } from './interfaces/nest-gateway.interface';
import { ParamsMetadata } from '@nestjs/core/helpers/interfaces';
import { WsParamtype } from './enums/ws-paramtype.enum';
import { ContextUtils } from '@nestjs/core/helpers/context-utils';

export interface MessageMappingProperties {
message: any;
methodName: string;
callback: (...args: any[]) => Observable<any> | Promise<any>;
isAckHandledManually: boolean;
}

export class GatewayMetadataExplorer {
private readonly contextUtils = new ContextUtils();
constructor(private readonly metadataScanner: MetadataScanner) {}

public explore(instance: NestGateway): MessageMappingProperties[] {
Expand All @@ -38,13 +44,40 @@ export class GatewayMetadataExplorer {
return null;
}
const message = Reflect.getMetadata(MESSAGE_METADATA, callback);
const isAckHandledManually = this.hasAckDecorator(
instancePrototype,
methodName,
);

return {
callback,
message,
methodName,
isAckHandledManually,
};
}

private hasAckDecorator(
instancePrototype: object,
methodName: string,
): boolean {
const paramsMetadata: ParamsMetadata = Reflect.getMetadata(
PARAM_ARGS_METADATA,
instancePrototype.constructor,
methodName,
);

if (!paramsMetadata) {
return false;
}
const metadataKeys = Object.keys(paramsMetadata);
return metadataKeys.some(key => {
const type = this.contextUtils.mapParamType(key);

return (Number(type) as WsParamtype) === WsParamtype.ACK;
});
}

public *scanForServerHooks(instance: NestGateway): IterableIterator<string> {
for (const propertyKey in instance) {
if (isFunction(propertyKey)) {
Expand Down
28 changes: 28 additions & 0 deletions packages/websockets/test/decorators/ack.decorator.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import 'reflect-metadata';
import { expect } from 'chai';
import { PARAM_ARGS_METADATA } from '../../constants';
import { Ack } from '../../decorators/ack.decorator';
import { WsParamtype } from '../../enums/ws-paramtype.enum';

class AckTest {
public test(@Ack() ack: Function) {}
}

describe('@Ack', () => {
it('should enhance class with expected request metadata', () => {
const argsMetadata = Reflect.getMetadata(
PARAM_ARGS_METADATA,
AckTest,
'test',
);

const expectedMetadata = {
[`${WsParamtype.ACK}:0`]: {
index: 0,
data: undefined,
pipes: [],
},
};
expect(argsMetadata).to.be.eql(expectedMetadata);
});
});
20 changes: 19 additions & 1 deletion packages/websockets/test/gateway-metadata-explorer.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import { MetadataScanner } from '../../core/metadata-scanner';
import { WebSocketServer } from '../decorators/gateway-server.decorator';
import { WebSocketGateway } from '../decorators/socket-gateway.decorator';
import { SubscribeMessage } from '../decorators/subscribe-message.decorator';
import { Ack } from '../decorators/ack.decorator';
import { GatewayMetadataExplorer } from '../gateway-metadata-explorer';

describe('GatewayMetadataExplorer', () => {
const message = 'test';
const secMessage = 'test2';
const ackMessage = 'ack-test';

@WebSocketGateway()
class Test {
Expand All @@ -28,6 +30,9 @@ describe('GatewayMetadataExplorer', () => {
@SubscribeMessage(secMessage)
public testSec() {}

@SubscribeMessage(ackMessage)
public testWithAck(@Ack() ack: Function) {}

public noMessage() {}
}
let instance: GatewayMetadataExplorer;
Expand Down Expand Up @@ -61,9 +66,22 @@ describe('GatewayMetadataExplorer', () => {
});
it(`should return message mapping properties when "isMessageMapping" metadata is not undefined`, () => {
const metadata = instance.exploreMethodMetadata(test, 'test')!;
expect(metadata).to.have.keys(['callback', 'message', 'methodName']);
expect(metadata).to.have.keys([
'callback',
'message',
'methodName',
'isAckHandledManually',
]);
expect(metadata.message).to.eql(message);
});
it('should set "isAckHandledManually" property to true when @Ack decorator is used', () => {
const metadata = instance.exploreMethodMetadata(test, 'testWithAck')!;
expect(metadata.isAckHandledManually).to.be.true;
});
it('should set "isAckHandledManually" property to false when @Ack decorator is not used', () => {
const metadata = instance.exploreMethodMetadata(test, 'test')!;
expect(metadata.isAckHandledManually).to.be.false;
});
});
describe('scanForServerHooks', () => {
it(`should return properties with @Client decorator`, () => {
Expand Down
34 changes: 32 additions & 2 deletions packages/websockets/test/web-sockets-controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ describe('WebSocketsController', () => {
message: 'message',
methodName: 'methodName',
callback: handlerCallback,
isAckHandledManually: false,
},
];
server = { server: 'test' };
Expand Down Expand Up @@ -173,6 +174,7 @@ describe('WebSocketsController', () => {
message: 'message',
methodName: 'methodName',
callback: messageHandlerCallback,
isAckHandledManually: false,
},
]);
});
Expand All @@ -188,11 +190,13 @@ describe('WebSocketsController', () => {
methodName: 'findOne',
message: 'find',
callback: null!,
isAckHandledManually: false,
},
{
methodName: 'create',
message: 'insert',
callback: null!,
isAckHandledManually: false,
},
];
const insertEntrypointDefinitionSpy = sinon.spy(
Expand Down Expand Up @@ -423,14 +427,40 @@ describe('WebSocketsController', () => {
client = { on: onSpy, off: onSpy };

handlers = [
{ message: 'test', callback: { bind: () => 'testCallback' } },
{ message: 'test2', callback: { bind: () => 'testCallback2' } },
{
message: 'test',
callback: { bind: () => 'testCallback' },
isAckHandledManually: true,
},
{
message: 'test2',
callback: { bind: () => 'testCallback2' },
isAckHandledManually: false,
},
];
});
it('should bind each handler to client', () => {
instance.subscribeMessages(handlers, client, gateway);
expect(onSpy.calledTwice).to.be.true;
});
it('should pass "isAckHandledManually" flag to the adapter', () => {
const adapter = config.getIoAdapter();
const bindMessageHandlersSpy = sinon.spy(adapter, 'bindMessageHandlers');

instance.subscribeMessages(handlers, client, gateway);

const handlersPassedToAdapter = bindMessageHandlersSpy.firstCall.args[1];

expect(handlersPassedToAdapter[0].message).to.equal(handlers[0].message);
expect(handlersPassedToAdapter[0].isAckHandledManually).to.equal(
handlers[0].isAckHandledManually,
);

expect(handlersPassedToAdapter[1].message).to.equal(handlers[1].message);
expect(handlersPassedToAdapter[1].isAckHandledManually).to.equal(
handlers[1].isAckHandledManually,
);
});
});
describe('pickResult', () => {
describe('when deferredResult contains value which', () => {
Expand Down
Loading
Loading