Skip to content

Commit f854d4a

Browse files
committed
Introduce backing file
1 parent bd9bea6 commit f854d4a

File tree

3 files changed

+85
-29
lines changed

3 files changed

+85
-29
lines changed

src/python/library/tritonclient/utils/shared_memory/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ def from_param(cls, value):
4848

4949
class ShmFile(Structure):
5050
if sys.platform == "win32":
51-
_fields_ = [("shm_handle_", c_void_p)]
51+
_fields_ = [
52+
("backing_file_handle_", c_void_p),
53+
("shm_mapping_handle_", c_void_p),
54+
]
5255
else:
5356
_fields_ = [("shm_fd_", c_int)]
5457

@@ -334,7 +337,9 @@ def __init__(self, err):
334337
-4: "unable to read/mmap the shared memory region",
335338
-5: "unable to unlink the shared memory region",
336339
-6: "unable to munmap the shared memory region",
337-
-7: "unable to create file mapping",
340+
-7: "unable to create shm directory or backing file",
341+
-8: "unable to create file mapping",
342+
-9: "unable to delete backing file",
338343
}
339344
self._msg = None
340345
if type(err) == str:

src/python/library/tritonclient/utils/shared_memory/shared_memory.cc

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,22 @@
3939
#include "shared_memory.h"
4040
#include "shared_memory_handle.h"
4141

42+
#define TRITON_SHM_FILE_ROOT "C:\\triton_shm\\"
43+
4244
//==============================================================================
4345
// SharedMemoryControlContext
4446
namespace {
4547

4648
void*
4749
SharedMemoryHandleCreate(
4850
std::string triton_shm_name, void* shm_addr, std::string shm_key,
49-
ShmFile* shm_file, size_t offset, size_t byte_size)
51+
std::unique_ptr<ShmFile>&& shm_file, size_t offset, size_t byte_size)
5052
{
5153
SharedMemoryHandle* handle = new SharedMemoryHandle();
5254
handle->triton_shm_name_ = triton_shm_name;
5355
handle->base_addr_ = shm_addr;
5456
handle->shm_key_ = shm_key;
55-
handle->platform_handle_.reset(shm_file);
57+
handle->platform_handle_ = std::move(shm_file);
5658
handle->offset_ = offset;
5759
handle->byte_size_ = byte_size;
5860
return static_cast<void*>(handle);
@@ -73,14 +75,14 @@ SharedMemoryRegionMap(
7375
DWORD low_order_offset = upperbound_offset & 0xFFFFFFFF;
7476
// map shared memory to process address space
7577
*shm_addr = MapViewOfFile(
76-
shm_file->shm_handle_, // handle to map object
77-
FILE_MAP_ALL_ACCESS, // read/write permission
78-
high_order_offset, // offset (high-order DWORD)
79-
low_order_offset, // offset (low-order DWORD)
78+
shm_file->shm_mapping_handle_, // handle to map object
79+
FILE_MAP_ALL_ACCESS, // read/write permission
80+
high_order_offset, // offset (high-order DWORD)
81+
low_order_offset, // offset (low-order DWORD)
8082
byte_size);
8183

8284
if (*shm_addr == NULL) {
83-
CloseHandle(shm_file->shm_handle_);
85+
CloseHandle(shm_file->shm_mapping_handle_);
8486
return -1;
8587
}
8688
// For Windows, we cannot close the shared memory handle here. When all
@@ -100,6 +102,38 @@ SharedMemoryRegionMap(
100102
#endif
101103
}
102104

105+
#ifdef _WIN32
106+
int
107+
SharedMemoryCreateBackingFile(const char* shm_key, HANDLE* backing_file_handle)
108+
{
109+
LPCSTR backing_file_directory(TRITON_SHM_FILE_ROOT);
110+
bool success = CreateDirectory(backing_file_directory, NULL);
111+
if (!success && GetLastError() != ERROR_ALREADY_EXISTS) {
112+
return -1;
113+
}
114+
LPCSTR backing_file_path =
115+
std::string(TRITON_SHM_FILE_ROOT + std::string(shm_key)).c_str();
116+
*backing_file_handle = CreateFile(
117+
backing_file_path, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ, NULL,
118+
OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
119+
if (*backing_file_handle == INVALID_HANDLE_VALUE) {
120+
return -1;
121+
}
122+
return 0;
123+
}
124+
125+
int
126+
SharedMemoryDeleteBackingFile(const char* key, HANDLE backing_file_handle)
127+
{
128+
CloseHandle(backing_file_handle);
129+
LPCSTR backing_file_path =
130+
std::string(TRITON_SHM_FILE_ROOT + std::string(key)).c_str();
131+
if (!DeleteFile(backing_file_path)) {
132+
return -1;
133+
}
134+
}
135+
#endif
136+
103137
} // namespace
104138

105139
TRITONCLIENT_DECLSPEC int
@@ -108,6 +142,11 @@ SharedMemoryRegionCreate(
108142
void** shm_handle)
109143
{
110144
#ifdef _WIN32
145+
HANDLE backing_file_handle;
146+
int err = SharedMemoryCreateBackingFile(shm_key, &backing_file_handle);
147+
if (err == -1) {
148+
return -7;
149+
}
111150
// The CreateFileMapping function takes a high-order and low-order DWORD (4
112151
// bytes each) for size. 'size_t' can either be 4 or 8 bytes depending on the
113152
// operating system. To handle both cases agnostically, we cast 'byte_size' to
@@ -118,22 +157,28 @@ SharedMemoryRegionCreate(
118157
DWORD low_order_size = upperbound_size & 0xFFFFFFFF;
119158

120159
HANDLE win_handle = CreateFileMapping(
121-
INVALID_HANDLE_VALUE, // use paging file
122-
NULL, // default security
123-
PAGE_READWRITE, // read/write access
124-
high_order_size, // maximum object size (high-order DWORD)
125-
low_order_size, // maximum object size (low-order DWORD)
126-
shm_key); // name of mapping object
160+
backing_file_handle, // use backing file
161+
NULL, // default security
162+
PAGE_READWRITE, // read/write access
163+
high_order_size, // maximum object size (high-order DWORD)
164+
low_order_size, // maximum object size (low-order DWORD)
165+
shm_key); // name of mapping object
127166

128167
if (win_handle == NULL) {
129-
return -7;
168+
LPCSTR backing_file_path =
169+
std::string(TRITON_SHM_FILE_ROOT + std::string(shm_key)).c_str();
170+
// Cleanup backing file on failure
171+
SharedMemoryDeleteBackingFile(shm_key, backing_file_handle);
172+
return -8;
130173
}
131174

132-
ShmFile* shm_file = new ShmFile(win_handle);
175+
std::unique_ptr<ShmFile> shm_file =
176+
std::make_unique<ShmFile>(backing_file_handle, win_handle);
133177
// get base address of shared memory region
134178
void* shm_addr = nullptr;
135-
int err = SharedMemoryRegionMap(shm_file, 0, byte_size, &shm_addr);
179+
err = SharedMemoryRegionMap(shm_file.get(), 0, byte_size, &shm_addr);
136180
if (err == -1) {
181+
SharedMemoryDeleteBackingFile(shm_key, backing_file_handle);
137182
return -4;
138183
}
139184
#else
@@ -149,18 +194,18 @@ SharedMemoryRegionCreate(
149194
return -3;
150195
}
151196

152-
ShmFile* shm_file = new ShmFile(shm_fd);
197+
std::unique_ptr<ShmFile> shm_file = std::make_unique<ShmFile>(shm_fd);
153198
// get base address of shared memory region
154199
void* shm_addr = nullptr;
155-
int err = SharedMemoryRegionMap(shm_file, 0, byte_size, &shm_addr);
200+
int err = SharedMemoryRegionMap(shm_file.get(), 0, byte_size, &shm_addr);
156201
if (err == -1) {
157202
return -4;
158203
}
159204
#endif
160205
// create a handle for the shared memory region
161206
*shm_handle = SharedMemoryHandleCreate(
162-
std::string(triton_shm_name), shm_addr, std::string(shm_key), shm_file, 0,
163-
byte_size);
207+
std::string(triton_shm_name), shm_addr, std::string(shm_key),
208+
std::move(shm_file), 0, byte_size);
164209
return 0;
165210
}
166211

@@ -186,7 +231,8 @@ GetSharedMemoryHandleInfo(
186231
*offset = handle->offset_;
187232
*byte_size = handle->byte_size_;
188233
#ifdef _WIN32
189-
file->shm_handle_ = handle->platform_handle_->shm_handle_;
234+
file->backing_file_handle_ = handle->platform_handle_->shm_mapping_handle_;
235+
file->shm_mapping_handle_ = handle->platform_handle_->shm_mapping_handle_;
190236
#else
191237
file->shm_fd_ = handle->platform_handle_->shm_fd_;
192238
#endif
@@ -204,10 +250,12 @@ SharedMemoryRegionDestroy(void* shm_handle)
204250
if (!success) {
205251
return -6;
206252
}
207-
// We keep Windows shared memory handles open until we are done
208-
// using them. When all handles are closed, the system will free
209-
// the section of the paging file that the object uses.
210-
CloseHandle(handle->platform_handle_->shm_handle_);
253+
CloseHandle(handle->platform_handle_->shm_mapping_handle_);
254+
int err = SharedMemoryDeleteBackingFile(
255+
handle->shm_key_.c_str(), handle->platform_handle_->backing_file_handle_);
256+
if (err == -1) {
257+
return -9;
258+
}
211259
#else
212260
int status = munmap(shm_addr, handle->byte_size_);
213261
if (status == -1) {

src/python/library/tritonclient/utils/shared_memory/shared_memory_handle.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@
3737

3838
struct ShmFile {
3939
#ifdef _WIN32
40-
HANDLE shm_handle_;
41-
ShmFile(HANDLE shm_handle) : shm_handle_(shm_handle){};
40+
HANDLE backing_file_handle_;
41+
HANDLE shm_mapping_handle_;
42+
ShmFile(HANDLE backing_file_handle, HANDLE shm_mapping_handle)
43+
: backing_file_handle_(backing_file_handle),
44+
shm_mapping_handle_(shm_mapping_handle){};
4245
#else
4346
int shm_fd_;
4447
ShmFile(int shm_fd) : shm_fd_(shm_fd){};

0 commit comments

Comments
 (0)