Skip to content

Commit 7ee2fa9

Browse files
feat: Implement pre-initialized Docker container pool to improve /eval request performance
- Added a pool of pre-initialized Docker containers, each assigned a unique session ID and ready for immediate use. - Pre-configured each container with the necessary startup scripts to ensure environments are ready for /eval requests. - On receiving a new /eval request, the system now allocates a container from the pool, reducing the need to create and initialize a container on demand. - Improved request latency by significantly reducing the time taken to start and initialize Docker containers during each session.
1 parent b359b3c commit 7ee2fa9

File tree

5 files changed

+228
-91
lines changed

5 files changed

+228
-91
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package org.togetherjava.jshellapi.dto;
2+
3+
import java.io.BufferedReader;
4+
import java.io.BufferedWriter;
5+
import java.io.InputStream;
6+
import java.io.OutputStream;
7+
8+
public record ContainerState(boolean isCached, String containerId, BufferedReader containerOutput, BufferedWriter containerInput) {
9+
}

JShellAPI/src/main/java/org/togetherjava/jshellapi/service/DockerService.java

Lines changed: 189 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.github.dockerjava.api.DockerClient;
44
import com.github.dockerjava.api.async.ResultCallback;
5+
import com.github.dockerjava.api.command.InspectContainerResponse;
56
import com.github.dockerjava.api.command.PullImageResultCallback;
67
import com.github.dockerjava.api.model.*;
78
import com.github.dockerjava.core.DefaultDockerClientConfig;
@@ -10,26 +11,33 @@
1011
import org.slf4j.Logger;
1112
import org.slf4j.LoggerFactory;
1213
import org.springframework.beans.factory.DisposableBean;
13-
import org.springframework.lang.Nullable;
1414
import org.springframework.stereotype.Service;
1515

1616
import org.togetherjava.jshellapi.Config;
17+
import org.togetherjava.jshellapi.dto.ContainerState;
1718

1819
import java.io.*;
1920
import java.nio.charset.StandardCharsets;
2021
import java.time.Duration;
2122
import java.util.*;
22-
import java.util.concurrent.TimeUnit;
23+
import java.util.concurrent.*;
2324

2425
@Service
2526
public class DockerService implements DisposableBean {
2627
private static final Logger LOGGER = LoggerFactory.getLogger(DockerService.class);
2728
private static final String WORKER_LABEL = "jshell-api-worker";
2829
private static final UUID WORKER_UNIQUE_ID = UUID.randomUUID();
30+
private static final String IMAGE_NAME = "togetherjava.org:5001/togetherjava/jshellwrapper";
31+
private static final String IMAGE_TAG = "master";
2932

3033
private final DockerClient client;
34+
private final Config config;
35+
private final ExecutorService executor = Executors.newSingleThreadExecutor();
36+
private final ConcurrentHashMap<StartupScriptId, String> cachedContainers = new ConcurrentHashMap<>();
37+
private final StartupScriptsService startupScriptsService;
3138

32-
public DockerService(Config config) {
39+
public DockerService(Config config, StartupScriptsService startupScriptsService) throws InterruptedException, IOException {
40+
this.startupScriptsService = startupScriptsService;
3341
DefaultDockerClientConfig clientConfig =
3442
DefaultDockerClientConfig.createDefaultConfigBuilder().build();
3543
ApacheDockerHttpClient httpClient =
@@ -39,8 +47,13 @@ public DockerService(Config config) {
3947
.connectionTimeout(Duration.ofSeconds(config.dockerConnectionTimeout()))
4048
.build();
4149
this.client = DockerClientImpl.getInstance(clientConfig, httpClient);
50+
this.config = config;
4251

52+
if (!isImagePresentLocally()) {
53+
pullImage();
54+
}
4355
cleanupLeftovers(WORKER_UNIQUE_ID);
56+
executor.submit(() -> initializeCachedContainer(StartupScriptId.EMPTY));
4457
}
4558

4659
private void cleanupLeftovers(UUID currentId) {
@@ -57,79 +70,198 @@ private void cleanupLeftovers(UUID currentId) {
5770
}
5871
}
5972

60-
public String spawnContainer(long maxMemoryMegs, long cpus, @Nullable String cpuSetCpus,
61-
String name, Duration evalTimeout, long sysoutLimit) throws InterruptedException {
62-
String imageName = "togetherjava.org:5001/togetherjava/jshellwrapper";
63-
boolean presentLocally = client.listImagesCmd()
64-
.withFilter("reference", List.of(imageName))
65-
.exec()
66-
.stream()
67-
.flatMap(it -> Arrays.stream(it.getRepoTags()))
68-
.anyMatch(it -> it.endsWith(":master"));
69-
70-
if (!presentLocally) {
71-
client.pullImageCmd(imageName)
72-
.withTag("master")
73-
.exec(new PullImageResultCallback())
74-
.awaitCompletion(5, TimeUnit.MINUTES);
73+
/**
74+
* Checks if the Docker image with the given name and tag is present locally.
75+
*
76+
* @return true if the image is present, false otherwise.
77+
*/
78+
private boolean isImagePresentLocally() {
79+
return client.listImagesCmd()
80+
.withFilter("reference", List.of(IMAGE_NAME))
81+
.exec()
82+
.stream()
83+
.flatMap(it -> Arrays.stream(it.getRepoTags()))
84+
.anyMatch(it -> it.endsWith(":" + IMAGE_TAG));
85+
}
86+
87+
/**
88+
* Pulls the Docker image.
89+
*/
90+
private void pullImage() throws InterruptedException {
91+
if (!isImagePresentLocally()) {
92+
client.pullImageCmd(IMAGE_NAME)
93+
.withTag(IMAGE_TAG)
94+
.exec(new PullImageResultCallback())
95+
.awaitCompletion(5, TimeUnit.MINUTES);
7596
}
97+
}
7698

77-
return client.createContainerCmd(imageName + ":master")
78-
.withHostConfig(HostConfig.newHostConfig()
99+
/**
100+
* Creates a Docker container with the given name.
101+
*
102+
* @param name The name of the container to create.
103+
* @return The ID of the created container.
104+
*/
105+
public String createContainer(String name) {
106+
HostConfig hostConfig = HostConfig.newHostConfig()
79107
.withAutoRemove(true)
80108
.withInit(true)
81109
.withCapDrop(Capability.ALL)
82110
.withNetworkMode("none")
83111
.withPidsLimit(2000L)
84112
.withReadonlyRootfs(true)
85-
.withMemory(maxMemoryMegs * 1024 * 1024)
86-
.withCpuCount(cpus)
87-
.withCpusetCpus(cpuSetCpus))
88-
.withStdinOpen(true)
89-
.withAttachStdin(true)
90-
.withAttachStderr(true)
91-
.withAttachStdout(true)
92-
.withEnv("evalTimeoutSeconds=" + evalTimeout.toSeconds(),
93-
"sysOutCharLimit=" + sysoutLimit)
94-
.withLabels(Map.of(WORKER_LABEL, WORKER_UNIQUE_ID.toString()))
95-
.withName(name)
96-
.exec()
97-
.getId();
113+
.withMemory((long) config.dockerMaxRamMegaBytes() * 1024 * 1024)
114+
.withCpuCount((long) Math.ceil(config.dockerCPUsUsage()))
115+
.withCpusetCpus(config.dockerCPUSetCPUs());
116+
117+
return client.createContainerCmd(IMAGE_NAME + ":" + IMAGE_TAG)
118+
.withHostConfig(hostConfig)
119+
.withStdinOpen(true)
120+
.withAttachStdin(true)
121+
.withAttachStderr(true)
122+
.withAttachStdout(true)
123+
.withEnv("evalTimeoutSeconds=" + config.evalTimeoutSeconds(),
124+
"sysOutCharLimit=" + config.sysOutCharLimit())
125+
.withLabels(Map.of(WORKER_LABEL, WORKER_UNIQUE_ID.toString()))
126+
.withName(name)
127+
.exec()
128+
.getId();
129+
}
130+
131+
/**
132+
* Spawns a new Docker container with specified configurations.
133+
*
134+
* @param name Name of the container.
135+
* @param startupScriptId Script to initialize the container with.
136+
* @return The ContainerState of the newly created container.
137+
*/
138+
public ContainerState initializeContainer(String name, StartupScriptId startupScriptId) throws IOException {
139+
if (cachedContainers.isEmpty() || !cachedContainers.containsKey(startupScriptId)) {
140+
String containerId = createContainer(name);
141+
return setupContainerWithScript(containerId, true, startupScriptId);
142+
}
143+
String containerId = cachedContainers.get(startupScriptId);
144+
executor.submit(() -> initializeCachedContainer(startupScriptId));
145+
// Rename container with new name.
146+
client.renameContainerCmd(containerId).withName(name).exec();
147+
return setupContainerWithScript(containerId, false, startupScriptId);
148+
}
149+
150+
/**
151+
* Initializes a new cached docker container with specified configurations.
152+
*
153+
* @param startupScriptId Script to initialize the container with.
154+
*/
155+
private void initializeCachedContainer(StartupScriptId startupScriptId) {
156+
String containerName = cachedContainerName();
157+
String id = createContainer(containerName);
158+
startContainer(id);
159+
160+
try (PipedInputStream containerInput = new PipedInputStream();
161+
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new PipedOutputStream(containerInput)))) {
162+
attachToContainer(id, containerInput);
163+
164+
writer.write(Utils.sanitizeStartupScript(startupScriptsService.get(startupScriptId)));
165+
writer.newLine();
166+
writer.flush();
167+
168+
cachedContainers.put(startupScriptId, id);
169+
} catch (IOException e) {
170+
killContainerByName(containerName);
171+
throw new RuntimeException(e);
172+
}
98173
}
99174

100-
public InputStream startAndAttachToContainer(String containerId, InputStream stdin)
101-
throws IOException {
175+
/**
176+
*
177+
* @param containerId The id of the container
178+
* @param isCached Indicator if the container is cached or new
179+
* @param startupScriptId The startup script id of the session
180+
* @return ContainerState of the spawned container.
181+
* @throws IOException if an I/O error occurs
182+
*/
183+
private ContainerState setupContainerWithScript(String containerId, boolean isCached, StartupScriptId startupScriptId) throws IOException {
184+
if (!isCached) {
185+
startContainer(containerId);
186+
}
187+
PipedInputStream containerInput = new PipedInputStream();
188+
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new PipedOutputStream(containerInput)));
189+
190+
InputStream containerOutput = attachToContainer(containerId, containerInput);
191+
BufferedReader reader = new BufferedReader(new InputStreamReader(containerOutput));
192+
193+
if (!isCached) {
194+
writer.write(Utils.sanitizeStartupScript(startupScriptsService.get(startupScriptId)));
195+
writer.newLine();
196+
writer.flush();
197+
}
198+
199+
return new ContainerState(isCached, containerId, reader, writer);
200+
}
201+
202+
/**
203+
* Creates a new container
204+
* @param containerId the ID of the container to start
205+
*/
206+
public void startContainer(String containerId) {
207+
if (!isContainerRunning(containerId)) {
208+
client.startContainerCmd(containerId).exec();
209+
}
210+
}
211+
212+
/**
213+
* Attaches to a running Docker container's input (stdin) and output streams (stdout, stderr).
214+
* Logs any output from stderr and returns an InputStream to read stdout.
215+
*
216+
* @param containerId the ID of the running container to attach to
217+
* @param containerInput the input stream (containerInput) to send to the container
218+
* @return InputStream to read the container's stdout
219+
* @throws IOException if an I/O error occurs
220+
*/
221+
public InputStream attachToContainer(String containerId, InputStream containerInput) throws IOException {
102222
PipedInputStream pipeIn = new PipedInputStream();
103223
PipedOutputStream pipeOut = new PipedOutputStream(pipeIn);
104224

105225
client.attachContainerCmd(containerId)
106-
.withLogs(true)
107-
.withFollowStream(true)
108-
.withStdOut(true)
109-
.withStdErr(true)
110-
.withStdIn(stdin)
111-
.exec(new ResultCallback.Adapter<>() {
112-
@Override
113-
public void onNext(Frame object) {
114-
try {
115-
String payloadString =
116-
new String(object.getPayload(), StandardCharsets.UTF_8);
117-
if (object.getStreamType() == StreamType.STDOUT) {
118-
pipeOut.write(object.getPayload());
119-
} else {
120-
LOGGER.warn("Received STDERR from container {}: {}", containerId,
121-
payloadString);
226+
.withLogs(true)
227+
.withFollowStream(true)
228+
.withStdOut(true)
229+
.withStdErr(true)
230+
.withStdIn(containerInput)
231+
.exec(new ResultCallback.Adapter<>() {
232+
@Override
233+
public void onNext(Frame object) {
234+
try {
235+
String payloadString = new String(object.getPayload(), StandardCharsets.UTF_8);
236+
if (object.getStreamType() == StreamType.STDOUT) {
237+
pipeOut.write(object.getPayload()); // Write stdout data to pipeOut
238+
} else {
239+
LOGGER.warn("Received STDERR from container {}: {}", containerId, payloadString);
240+
}
241+
} catch (IOException e) {
242+
throw new UncheckedIOException(e);
122243
}
123-
} catch (IOException e) {
124-
throw new UncheckedIOException(e);
125244
}
126-
}
127-
});
245+
});
128246

129-
client.startContainerCmd(containerId).exec();
130247
return pipeIn;
131248
}
132249

250+
/**
251+
* Checks if the Docker container with the given ID is currently running.
252+
*
253+
* @param containerId the ID of the container to check
254+
* @return true if the container is running, false otherwise
255+
*/
256+
public boolean isContainerRunning(String containerId) {
257+
InspectContainerResponse containerResponse = client.inspectContainerCmd(containerId).exec();
258+
return Boolean.TRUE.equals(containerResponse.getState().getRunning());
259+
}
260+
261+
private String cachedContainerName() {
262+
return "cached_session_" + UUID.randomUUID();
263+
}
264+
133265
public void killContainerByName(String name) {
134266
LOGGER.debug("Fetching container to kill {}.", name);
135267
List<Container> containers = client.listContainersCmd().withNameFilter(Set.of(name)).exec();
@@ -150,6 +282,7 @@ public boolean isDead(String containerName) {
150282
@Override
151283
public void destroy() throws Exception {
152284
LOGGER.info("destroy() called. Destroying all containers...");
285+
executor.shutdown();
153286
cleanupLeftovers(UUID.randomUUID());
154287
client.close();
155288
}

0 commit comments

Comments
 (0)