
Commit ab65392

Support for BATCHSIZE, MINBATCHSIZE, INPUTS and OUTPUTS on AI.MODELGET (#9)
* [add] Support for BATCHSIZE, MINBATCHSIZE, INPUTS and OUTPUTS on AI.MODELGET
1 parent e4b0a6a commit ab65392

File tree

README.md
src/client.ts
src/model.ts
tests/test_client.ts

4 files changed, +176 −45 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ AI._SCRIPTSCAN | N/A
 AI.DAGRUN | N/A
 AI.DAGRUN_RO | N/A
 AI.INFO | info and infoResetStat (for resetting stats)
-AI.CONFIG * | N/A
+AI.CONFIG * | configLoadBackend and configBackendsPath


 ### Running tests
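
The updated AI.CONFIG row maps the command to two client methods, configLoadBackend and configBackendsPath. A minimal sketch of calling them, modeled on the test further down in this diff; the import paths, backends directory and .so location are illustrative assumptions rather than part of this change:

```ts
import { createClient } from 'redis';
import { Client, Backend } from 'redisai-js'; // assumed public entry point of this package

async function loadTensorFlowBackend(): Promise<void> {
  const aiclient = new Client(createClient());
  // Tell RedisAI where its backend libraries live (path is illustrative).
  await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
  // Load the TensorFlow backend; RedisAI replies with an error if it is already loaded.
  await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
  aiclient.end(true);
}
```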

src/client.ts

Lines changed: 2 additions & 16 deletions
@@ -49,22 +49,8 @@ export class Client {
     });
   }
 
-  public modelset(keName: string, m: Model): Promise<any> {
-    const args: any[] = [keName, m.backend.toString(), m.device];
-    if (m.tag !== undefined) {
-      args.push('TAG');
-      args.push(m.tag.toString());
-    }
-    if (m.inputs.length > 0) {
-      args.push('INPUTS');
-      m.inputs.forEach((value) => args.push(value));
-    }
-    if (m.outputs.length > 0) {
-      args.push('OUTPUTS');
-      m.outputs.forEach((value) => args.push(value));
-    }
-    args.push('BLOB');
-    args.push(m.blob);
+  public modelset(keyName: string, m: Model): Promise<any> {
+    const args: any[] = m.modelSetFlatArgs(keyName);
     return this._sendCommand('ai.modelset', args);
   }
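
With argument construction moved into Model.modelSetFlatArgs, a set/get round trip keeps the batching and node-name metadata, so the Model returned by modelget can be written straight back with modelset (exactly what the new "-loop" tests below exercise). A minimal sketch, assuming the package's public exports and an illustrative protobuf path and tag:

```ts
import * as fs from 'fs';
import { createClient } from 'redis';
import { Client, Model, Backend } from 'redisai-js'; // assumed public entry point

async function roundTrip(): Promise<void> {
  const aiclient = new Client(createClient());
  const blob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
  // batchsize 100 and minbatchsize 5 via the extended constructor; the tag is illustrative.
  const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], blob, 100, 5);
  model.tag = 'v1';
  await aiclient.modelset('mymodel', model);
  const stored: Model = await aiclient.modelget('mymodel');
  // inputs, outputs, batchsize and minbatchsize now survive AI.MODELGET,
  // so the fetched Model can be re-set under another key unchanged.
  await aiclient.modelset('mymodel-copy', stored);
  aiclient.end(true);
}
```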

src/model.ts

Lines changed: 84 additions & 3 deletions
@@ -11,14 +11,32 @@ export class Model {
    * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow models)
    * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow models)
    * @param blob - the Protobuf-serialized model
+   * @param batchsize - when provided with a batchsize greater than 0, the engine will batch incoming requests from multiple clients that use the model with input tensors of the same shape.
+   * @param minbatchsize - when provided with a minbatchsize greater than 0, the engine will postpone calls to AI.MODELRUN until the batch's size has reached minbatchsize
    */
-  constructor(backend: Backend, device: string, inputs: string[], outputs: string[], blob: Buffer | undefined) {
+  constructor(
+    backend: Backend,
+    device: string,
+    inputs: string[],
+    outputs: string[],
+    blob: Buffer | undefined,
+    batchsize?: number,
+    minbatchsize?: number,
+  ) {
     this._backend = backend;
     this._device = device;
     this._inputs = inputs;
     this._outputs = outputs;
     this._blob = blob;
     this._tag = undefined;
+    this._batchsize = batchsize || 0;
+    if (this._batchsize < 0) {
+      this._batchsize = 0;
+    }
+    this._minbatchsize = minbatchsize || 0;
+    if (this._minbatchsize < 0) {
+      this._minbatchsize = 0;
+    }
   }
 
   // tag is an optional string for tagging the model such as a version number or any arbitrary identifier
@@ -86,14 +104,39 @@ export class Model {
     this._blob = value;
   }
 
+  private _batchsize: number;
+
+  get batchsize(): number {
+    return this._batchsize;
+  }
+
+  set batchsize(value: number) {
+    this._batchsize = value;
+  }
+
+  private _minbatchsize: number;
+
+  get minbatchsize(): number {
+    return this._minbatchsize;
+  }
+
+  set minbatchsize(value: number) {
+    this._minbatchsize = value;
+  }
+
   static NewModelFromModelGetReply(reply: any[]) {
     let backend = null;
     let device = null;
     let tag = null;
     let blob = null;
+    let batchsize: number = 0;
+    let minbatchsize: number = 0;
+    const inputs: string[] = [];
+    const outputs: string[] = [];
     for (let i = 0; i < reply.length; i += 2) {
       const key = reply[i];
       const obj = reply[i + 1];
+
       switch (key.toString()) {
         case 'backend':
           backend = BackendMap[obj.toString()];
@@ -106,9 +149,20 @@ export class Model {
           tag = obj.toString();
           break;
         case 'blob':
-          // blob = obj;
           blob = Buffer.from(obj);
           break;
+        case 'batchsize':
+          batchsize = parseInt(obj.toString(), 10);
+          break;
+        case 'minbatchsize':
+          minbatchsize = parseInt(obj.toString(), 10);
+          break;
+        case 'inputs':
+          obj.forEach((input) => inputs.push(input));
+          break;
+        case 'outputs':
+          obj.forEach((output) => outputs.push(output));
+          break;
       }
     }
     if (backend == null || device == null || blob == null) {
@@ -126,10 +180,37 @@ export class Model {
         'AI.MODELGET reply did not had the full elements to build the Model. Missing ' + missingArr.join(',') + '.',
       );
     }
-    const model = new Model(backend, device, [], [], blob);
+    const model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
     if (tag !== null) {
       model.tag = tag;
     }
     return model;
   }
+
+  modelSetFlatArgs(keyName: string) {
+    const args: any[] = [keyName, this.backend.toString(), this.device];
+    if (this.tag !== undefined) {
+      args.push('TAG');
+      args.push(this.tag.toString());
+    }
+    if (this.batchsize > 0) {
+      args.push('BATCHSIZE');
+      args.push(this.batchsize);
+      if (this.minbatchsize > 0) {
+        args.push('MINBATCHSIZE');
+        args.push(this.minbatchsize);
+      }
+    }
+    if (this.inputs.length > 0) {
+      args.push('INPUTS');
+      this.inputs.forEach((value) => args.push(value));
+    }
+    if (this.outputs.length > 0) {
+      args.push('OUTPUTS');
+      this.outputs.forEach((value) => args.push(value));
+    }
+    args.push('BLOB');
+    args.push(this.blob);
+    return args;
+  }
 }
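
modelSetFlatArgs emits the arguments in the order AI.MODELSET expects: key, backend and device first, then the optional TAG, BATCHSIZE/MINBATCHSIZE, INPUTS and OUTPUTS sections, and BLOB last. A rough sketch of the array it would build; the key, tag and blob are illustrative, and it assumes Backend.TF stringifies to 'TF':

```ts
import { Model, Backend } from 'redisai-js'; // assumed public entry point

const blob = Buffer.from('...'); // placeholder for a real protobuf-serialized graph
const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], blob, 100, 5);
model.tag = 'v1';

const args = model.modelSetFlatArgs('mymodel');
// args is now roughly:
// ['mymodel', 'TF', 'CPU',
//  'TAG', 'v1',
//  'BATCHSIZE', 100, 'MINBATCHSIZE', 5,
//  'INPUTS', 'a', 'b',
//  'OUTPUTS', 'c',
//  'BLOB', <Buffer ...>]
```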

tests/test_client.ts

Lines changed: 89 additions & 25 deletions
@@ -331,13 +331,77 @@ it(
     const aiclient = new Client(nativeClient);
 
     const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
-    const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], modelBlob);
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
     model.tag = 'test_tag';
     const resultModelSet = await aiclient.modelset('mymodel', model);
     expect(resultModelSet).to.equal('OK');
 
-    const modelOut = await aiclient.modelget('mymodel');
+    const modelOut: Model = await aiclient.modelget('mymodel');
     expect(modelOut.blob.toString()).to.equal(modelBlob.toString());
+    for (let index = 0; index < modelOut.outputs.length; index++) {
+      expect(modelOut.outputs[index]).to.equal(outputs[index]);
+      expect(modelOut.outputs[index]).to.equal(model.outputs[index]);
+    }
+    for (let index = 0; index < modelOut.inputs.length; index++) {
+      expect(modelOut.inputs[index]).to.equal(inputs[index]);
+      expect(modelOut.inputs[index]).to.equal(model.inputs[index]);
+    }
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+    aiclient.end(true);
+  }),
+);
+
+it(
+  'ai.modelget batching positive testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+
+    const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
+    model.tag = 'test_tag';
+    model.batchsize = 100;
+    model.minbatchsize = 5;
+    const resultModelSet = await aiclient.modelset('mymodel-batching', model);
+    expect(resultModelSet).to.equal('OK');
+    const modelOut: Model = await aiclient.modelget('mymodel-batching');
+    const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop', modelOut);
+    expect(resultModelSet2).to.equal('OK');
+    const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop');
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+    aiclient.end(true);
+  }),
+);
+
+it(
+  'ai.modelget batching via constructor positive testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+
+    const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 100, 5);
+    model.tag = 'test_tag';
+    const resultModelSet = await aiclient.modelset('mymodel-batching-t2', model);
+    expect(resultModelSet).to.equal('OK');
+    const modelOut: Model = await aiclient.modelget('mymodel-batching-t2');
+    const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop-t2', modelOut);
+    expect(resultModelSet2).to.equal('OK');
+    const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop');
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+
+    const model2 = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 1000);
+    expect(model2.batchsize).to.equal(1000);
+    expect(model2.minbatchsize).to.equal(0);
     aiclient.end(true);
   }),
 );
@@ -624,26 +688,26 @@ it(
 );
 
 it(
-  'ai.config positive and negative testing',
-  mochaAsync(async () => {
-    const nativeClient = createClient();
-    const aiclient = new Client(nativeClient);
-    const result = await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
-    expect(result).to.equal('OK');
-    // negative test
-    try {
-      const loadReply = await aiclient.configLoadBackend(Backend.TF, 'notexist/redisai_tensorflow.so');
-    } catch (e) {
-      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
-    }
-
-    try {
-      // may throw error if backend already loaded
-      const loadResult = await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
-      expect(loadResult).to.equal('OK');
-    } catch (e) {
-      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
-    }
-    aiclient.end(true);
-  }),
-);
+  'ai.config positive and negative testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+    const result = await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
+    expect(result).to.equal('OK');
+    // negative test
+    try {
+      const loadReply = await aiclient.configLoadBackend(Backend.TF, 'notexist/redisai_tensorflow.so');
+    } catch (e) {
+      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
+    }
+
+    try {
+      // may throw error if backend already loaded
+      const loadResult = await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
+      expect(loadResult).to.equal('OK');
+    } catch (e) {
+      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
+    }
+    aiclient.end(true);
+  }),
+);