Skip to content

Commit 5d181f6

Browse files
authored
Move load_npy to C++ (#849)
So that the GIL is released entirely Performance (based on [data_formats.py](https://github.com/facebookresearch/spdl/blob/c7db5be1512f8c5f17b07b047196cfb22e01624c/examples/data_formats.py)) * QPS of loading NPY files with multi-threading | Concurrency | Baseline | C++ | Improvement | |-------------|----------|------|-------------| | 32 | 1577 | 1891 | 17% | | 16 | 1561 | 1717 | 9% | | 8 | 1592 | 1859 | 14% | | 4 | 1590 | 1812 | 12% | | 2 | 1769 | 1820 | 3% | | 1 | 1690 | 1948 | 13% | <img width="668" height="410" alt="Screenshot 2025-07-25 at 4 59 08 PM" src="https://github.com/user-attachments/assets/d8cbcccf-7607-4b84-8b41-612e6ea8a300" /> * QPS of loading NPZ (no compression) files with multi-threading | Concurrency | Baseline | C++ | Improvement | |-------------|----------|------|-------------| | 32 | 1577 | 1726 | 9% | | 16 | 1495 | 1755 | 15% | | 8 | 1607 | 1754 | 8% | | 4 | 1591 | 1677 | 5% | | 2 | 1637 | 1781 | 8% | | 1 | 1658 | 1829 | 9% | <img width="686" height="420" alt="Screenshot 2025-07-25 at 5 02 11 PM" src="https://github.com/user-attachments/assets/26a99e09-38c4-4f5f-b88b-8f9034b914d9" /> * QPS of loading NPZ (with compression) files with multi-threading | Concurrency | Baseline | C++ | Improvement | |-------------|----------|------|-------------| | 32 | 1192 | 1473 | 19% | | 16 | 1241 | 1633 | 24% | | 8 | 1230 | 1640 | 25% | | 4 | 1305 | 1548 | 16% | | 2 | 1379 | 1547 | 11% | | 1 | 1277 | 1571 | 19% | <img width="682" height="418" alt="Screenshot 2025-07-25 at 5 05 32 PM" src="https://github.com/user-attachments/assets/41f33bb6-76f7-46de-b5e3-8bc035b5fe93" /> The performance of multiprocessing stays roughly same.
1 parent 4b3a006 commit 5d181f6

File tree

6 files changed

+291
-108
lines changed

6 files changed

+291
-108
lines changed

src/spdl/io/_array.py

Lines changed: 14 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,9 @@
99
"load_npz",
1010
"NpzFile",
1111
]
12-
import ast
13-
import struct
1412
from collections.abc import Iterator, Mapping
15-
from dataclasses import dataclass
16-
from typing import Any
1713

1814
import numpy as np
19-
from numpy.lib.format import MAGIC_LEN, MAGIC_PREFIX
2015
from numpy.typing import NDArray
2116

2217
# Importing `spdl.io.lib` instead of `spdl.io.lib._archive`
@@ -26,35 +21,8 @@
2621
# pyre-strict
2722

2823

29-
@dataclass
30-
class _ArrayInterface:
31-
shape: tuple[int, ...] # pyre-ignore: [35]
32-
typestr: str # pyre-ignore: [35]
33-
data: memoryview # pyre-ignore: [35]
34-
offset: int = 0 # pyre-ignore: [35]
35-
version: int = 3 # pyre-ignore: [35]
36-
37-
@property
38-
def __array_interface__(self) -> dict[str, Any]:
39-
return {
40-
"shape": self.shape,
41-
"typestr": self.typestr,
42-
"data": self.data,
43-
"offset": self.offset,
44-
"version": self.version,
45-
}
46-
47-
48-
def _get_header_size_info(version: tuple[int, int]) -> tuple[str, str]:
49-
match version:
50-
case (1, 0):
51-
return ("<H", "latin1")
52-
case (2, 0):
53-
return ("<I", "latin1")
54-
case (3, 0):
55-
return ("<I", "utf8")
56-
case _:
57-
raise ValueError(f"Unexpected version {version}.")
24+
def _get_pointer(data: bytes) -> int:
25+
return np.frombuffer(data, dtype=np.byte).ctypes.data
5826

5927

6028
def load_npy(
@@ -111,50 +79,8 @@ def load_npy(
11179
`creates a new array <https://github.com/numpy/numpy/blob/v2.2.0/numpy/_core/records.py#L935-L939>`_.
11280
11381
"""
114-
if len(data) < MAGIC_LEN:
115-
raise ValueError("The input data is too short.")
116-
117-
view = memoryview(data)
118-
magic_str = view[:MAGIC_LEN].tobytes()
119-
if not magic_str.startswith(MAGIC_PREFIX):
120-
raise ValueError(rf"Expected the data to start with {MAGIC_PREFIX}.")
121-
122-
major, minor = magic_str[-2:]
123-
hlength_type, encoding = _get_header_size_info((major, minor))
124-
125-
info_length_size = struct.calcsize(hlength_type)
126-
info_start = MAGIC_LEN + info_length_size
127-
128-
if len(data) < info_start:
129-
raise ValueError("Failed to parse info. The input data is invalid.")
130-
info_length_str = data[MAGIC_LEN:info_start]
131-
info_length = struct.unpack(hlength_type, info_length_str)[0]
132-
133-
data_start = info_start + info_length
134-
if len(data) < data_start:
135-
raise ValueError(
136-
"Failed to parse data. The recorded data size exceeds the provided data size."
137-
)
138-
info_str = view[info_start:data_start].tobytes()
139-
140-
info = ast.literal_eval(info_str.decode(encoding))
141-
142-
if info.get("fortran_order"):
143-
raise ValueError(
144-
"Array saved with `format_order=True is not supported. Please use `numpy.load`."
145-
)
146-
147-
# TODO: Try `numpy.frombuffer``
148-
# https://github.com/numpy/numpy/blob/e20317a43d3714f9085ad959f68c1ba6bc998fcd/numpy/_core/src/multiarray/ctors.c#L3711
149-
aif = _ArrayInterface(
150-
shape=info["shape"],
151-
typestr=info["descr"],
152-
data=view,
153-
offset=data_start,
154-
version=2,
155-
)
156-
157-
return np.array(aif, copy=copy)
82+
buffer = _libspdl._archive.load_npy(_get_pointer(data), len(data))
83+
return np.array(buffer, copy=copy)
15884

15985

16086
class NpzFile(Mapping):
@@ -168,7 +94,8 @@ class NpzFile(Mapping):
16894
"""
16995

17096
def __init__(self, data: bytes, meta: dict[str, tuple[int, int, int, int]]) -> None:
171-
self._data = memoryview(data) # pyre-ignore
97+
self._data: int = _get_pointer(data)
98+
self._len: int = len(data)
17299
self._meta = meta
173100
self.files: list[str] = [f.removesuffix(".npy") for f in meta]
174101

@@ -192,16 +119,18 @@ def __getitem__(self, key: str) -> NDArray:
192119
else:
193120
raise KeyError(f"{key} is not a file in the archive")
194121

195-
start, compressed_size, uncompressed_size, compression_method = self._meta[key]
122+
offset, compressed_size, uncompressed_size, compression_method = self._meta[key]
196123
match compression_method:
197124
case 0:
198-
return load_npy(self._data[start : start + compressed_size])
125+
buffer = _libspdl._archive.load_npy(
126+
self._data, size=compressed_size, offset=offset
127+
)
128+
return np.array(buffer, copy=False)
199129
case 8:
200-
return load_npy(
201-
_libspdl._archive.inflate(
202-
self._data.obj, start, compressed_size, uncompressed_size
203-
)
130+
buffer = _libspdl._archive.load_npy_compressed(
131+
self._data, offset, compressed_size, uncompressed_size
204132
)
133+
return np.array(buffer, copy=False)
205134
case _:
206135
raise ValueError(
207136
"Compression method other than DEFLATE is not supported."

src/spdl/io/lib/archive/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ message(STATUS "########################################")
1111
set(name _archive)
1212
message(STATUS "Building ${name}")
1313

14-
set(srcs register.cpp zip_impl.cpp)
14+
set(srcs register.cpp zip_impl.cpp numpy_support.cpp)
1515
set(deps ZLIB::ZLIB fmt::fmt glog::glog)
1616
set(nb_options
1717
STABLE_ABI
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "numpy_support.h"
10+
#include "zip_impl.h"
11+
12+
#include <algorithm>
13+
#include <cstring>
14+
#include <sstream>
15+
#include <stdexcept>
16+
#include <string>
17+
18+
namespace spdl::archive {
19+
20+
//////////////////////////////////////////////////////////////////////////////
21+
// load_npy
22+
//////////////////////////////////////////////////////////////////////////////
23+
24+
namespace {
25+
26+
// Verify that the buffer starts with the NPY magic prefix and step past it.
//
// `data` is an in/out cursor into the NPY byte stream and `size` is the
// in/out number of bytes available at `*data`. On success the cursor is
// advanced by the prefix length (6 bytes) and `*size` is reduced by the same
// amount.
//
// Throws std::runtime_error when fewer than 6 bytes are available or when the
// bytes do not match the prefix; `*data` and `*size` are not modified then.
void check_magic(const char** data, size_t* size) {
  // Compile-time constants: no run-time strlen and no function-local static
  // initialisation on the first call.
  static constexpr char prefix[] = "\x93NUMPY";
  static constexpr size_t len = sizeof(prefix) - 1; // exclude trailing NUL
  if (*size < len) {
    throw std::runtime_error(
        "Failed to parse the magic prefix. (data too short)");
  }
  // The input is binary data, so compare raw bytes rather than C strings.
  if (std::memcmp(*data, prefix, len) != 0) {
    throw std::runtime_error(
        "The data must start with the prefix '\\x93NUMPY'");
  }
  *data += len;
  *size -= len;
}
40+
41+
// Parse the version bytes and the header-length field that follow the magic
// prefix, and return a view of the header text.
//
// `data` is an in/out cursor positioned right after the magic prefix and
// `size` is the in/out number of bytes available at `*data`. On success the
// cursor is moved to the first byte after the header text and `*size` is
// reduced accordingly. The returned view points into the caller's buffer.
//
// Layout handled here:
//   - 1 byte major version, 1 byte minor version (the minor is not checked)
//   - header length, little endian: 2 bytes for version 1,
//     4 bytes for versions 2 and 3
//   - header text of that length
//
// Throws std::runtime_error on truncated input, an unsupported major version,
// or (versions 2 and 3) a zero header length. `*data` and `*size` are only
// updated on success.
std::string_view extract_header(const char** data, size_t* size) {
  // Read through `unsigned char`: plain `char` may be signed, in which case
  // any length byte >= 0x80 would be sign-extended and corrupt the length.
  const auto* bytes = reinterpret_cast<const unsigned char*>(*data);
  size_t remaining = *size;

  if (remaining < 2) {
    throw std::runtime_error("Failed to parse version number.");
  }
  const unsigned major = bytes[0];
  bytes += 2;
  remaining -= 2;

  size_t length_field_size = 0;
  switch (major) {
    case 1:
      // The next two bytes are header length in little endian.
      length_field_size = 2;
      break;
    case 2:
    case 3:
      // The next four bytes are header length in little endian.
      length_field_size = 4;
      break;
    default:
      throw std::runtime_error(
          "Unexpected format version. Only 1, 2 and 3 are supported.");
  }

  if (remaining < length_field_size) {
    throw std::runtime_error("Failed to parse header length.");
  }
  size_t header_len = 0;
  for (size_t i = 0; i < length_field_size; ++i) {
    header_len |= static_cast<size_t>(bytes[i]) << (8 * i);
  }
  if (length_field_size == 4 && header_len == 0) {
    throw std::runtime_error(
        "Invalid data. The header length must be greater than 0.");
  }
  remaining -= length_field_size;

  if (remaining < header_len) {
    throw std::runtime_error("Failed to parse header");
  }

  const char* header_begin = *data + 2 + length_field_size;
  const std::string_view header{header_begin, header_len};
  *data = header_begin + header_len;
  *size = remaining - header_len;
  return header;
}
103+
104+
// Extract `descr`, `shape` and `fortran_order` from the NPY header text.
//
// NPY header is a string expression of Python dictionary with the following
// keys. See:
// https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html
//
// - "descr": Format description. e.g. "'<i8'", "'<f4'"
// - "fortran_order": "True" or "False".
// - "shape": Tuple of int. e.g. "()", "(3, 4, 5)"
//
// This is a lightweight textual scan, not a Python literal parser:
// - `descr` is the text between the first pair of single quotes after the
//   key, so only a plain string value is handled.
// - `shape` is the comma-separated list between the first '(' and the next
//   ')' after the key; whitespace is ignored and empty items are skipped.
// - `fortran_order` is set from the first non-space character after the key
//   ('T' / 'F'); if the key is absent the default of the struct is kept.
//
// Throws std::runtime_error when `descr` or `shape` cannot be located,
// std::invalid_argument when a shape item does not start with a digit, and
// std::out_of_range when a dimension does not fit in `size_t`.
NPYArray parse_header(const std::string_view header) {
  constexpr auto npos = std::string_view::npos;
  // <cctype> classifiers require a value representable as unsigned char.
  const auto is_space = [](unsigned char c) { return std::isspace(c) != 0; };

  NPYArray ret;
  {
    const std::string_view key = "'descr':";
    size_t pos = header.find(key);
    if (pos == npos) {
      throw std::runtime_error("Failed to parse header `'descr'`.");
    }
    // Opening quote of the value.
    pos = header.find('\'', pos + key.size());
    if (pos == npos) {
      throw std::runtime_error("Failed to parse header `'descr'`.");
    }
    // Closing quote of the value.
    const size_t end_pos = header.find('\'', pos + 1);
    if (end_pos == npos) {
      throw std::runtime_error("Failed to parse header `'descr'`.");
    }
    ret.descr = header.substr(pos + 1, end_pos - pos - 1);
  }
  {
    size_t pos = header.find("'shape':");
    if (pos == npos) {
      throw std::runtime_error("Failed to parse header `'shape'`.");
    }
    pos = header.find('(', pos);
    if (pos == npos) {
      throw std::runtime_error("Failed to parse header `'shape'`.");
    }
    const size_t end_pos = header.find(')', pos);
    if (end_pos == npos) {
      throw std::runtime_error("Failed to parse header `'shape'`.");
    }
    const std::string_view dims = header.substr(pos + 1, end_pos - pos - 1);

    size_t cursor = 0;
    while (cursor < dims.size()) {
      size_t comma = dims.find(',', cursor);
      if (comma == npos) {
        comma = dims.size();
      }
      std::string number(dims.substr(cursor, comma - cursor));
      cursor = comma + 1;

      number.erase(
          std::remove_if(number.begin(), number.end(), is_space),
          number.end());
      if (number.empty()) {
        // e.g. the trailing comma of a one-element tuple "(3,)"
        continue;
      }
      // Reject signs explicitly: std::stoull would silently wrap "-1" to a
      // huge unsigned value.
      if (number.front() < '0' || number.front() > '9') {
        throw std::invalid_argument("Failed to parse header `'shape'`.");
      }
      // Parse at full width so that dimensions above INT_MAX are accepted.
      const unsigned long long dim = std::stoull(number);
      if (static_cast<unsigned long long>(static_cast<size_t>(dim)) != dim) {
        throw std::out_of_range("Failed to parse header `'shape'`.");
      }
      ret.shape.push_back(static_cast<size_t>(dim));
    }
  }
  {
    const std::string_view key = "'fortran_order':";
    size_t pos = header.find(key);
    if (pos != npos) {
      pos += key.size();
      while (pos < header.size() && is_space(header[pos])) {
        ++pos;
      }
      if (pos < header.size() && header[pos] == 'T') {
        ret.fortran_order = true;
      } else if (pos < header.size() && header[pos] == 'F') {
        ret.fortran_order = false;
      }
    }
  }
  return ret;
}
170+
} // namespace
171+
172+
// Parse an in-memory NPY image located at `data` (`size` bytes).
//
// The magic prefix is validated, the header is extracted and parsed, and the
// returned NPYArray's `data` member is pointed at the first byte following
// the header. No bytes are copied: the result refers to the caller's buffer,
// and its `buffer` member is left empty.
//
// Errors raised by the magic check, header extraction and header parsing
// propagate to the caller.
NPYArray load_npy(const char* data, size_t size) {
  const char* cursor = data;
  size_t remaining = size;

  check_magic(&cursor, &remaining);
  const auto header_text = extract_header(&cursor, &remaining);

  NPYArray result = parse_header(header_text);
  // NPYArray exposes the payload as a mutable void*, so constness is dropped.
  result.data = const_cast<char*>(cursor);
  return result;
}
179+
180+
// Inflate a compressed NPY image and parse the result.
//
// A buffer of `uncompressed_size` bytes is allocated, `zip::inflate` is
// called to fill it from the `compressed_size` bytes at `data`, and the
// buffer is then parsed with `load_npy`. Ownership of the buffer is handed to
// the returned NPYArray (`buffer` member), so the `data` pointer inside the
// result stays valid for as long as the result is alive.
NPYArray load_npy_compressed(
    const char* data,
    uint32_t compressed_size,
    uint32_t uncompressed_size) {
  auto inflated = std::make_unique<char[]>(uncompressed_size);
  zip::inflate(data, compressed_size, inflated.get(), uncompressed_size);

  NPYArray result = load_npy(inflated.get(), uncompressed_size);
  // Moving the unique_ptr does not move the allocation, so `result.data`
  // (set by load_npy) keeps pointing into this buffer.
  result.buffer = std::move(inflated);
  return result;
}
190+
191+
} // namespace spdl::archive
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <memory>
12+
#include <string>
13+
#include <vector>
14+
15+
namespace spdl::archive {
16+
17+
// Description of an array parsed from NPY data.
struct NPYArray {
  // Value of the header's 'descr' entry, e.g. "<i8" or "<f4".
  std::string descr;
  // Value of the header's 'fortran_order' entry.
  bool fortran_order{false};
  // Value of the header's 'shape' entry; empty for "()".
  std::vector<size_t> shape;

  // Start of the array payload. This pointer is never owning: it refers
  // either to memory supplied by the caller or to `buffer` below.
  void* data{nullptr};

  // Optional owned storage. Set when the payload had to be materialised
  // (e.g. after decompression); empty otherwise.
  std::unique_ptr<char[]> buffer;
};
28+
29+
// Parse the NPY image stored in the `size` bytes at `data`. The returned
// NPYArray points into the caller's memory; nothing is copied.
NPYArray load_npy(const char* data, size_t size);

// Inflate the `compressed_size` bytes at `data` into a newly allocated buffer
// of `uncompressed_size` bytes and parse it as an NPY image. The returned
// NPYArray owns that buffer.
NPYArray load_npy_compressed(
    const char* data,
    uint32_t compressed_size,
    uint32_t uncompressed_size);
31+
32+
} // namespace spdl::archive

0 commit comments

Comments
 (0)