Skip to content

Commit 6c903e5

Browse files
committed
Added PgVector integration
Signed-off-by: chandr-andr (Kiselev Aleksandr) <[email protected]>
1 parent 335e591 commit 6c903e5

File tree

6 files changed

+60
-0
lines changed

6 files changed

+60
-0
lines changed

Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,6 @@ itertools = "0.12.1"
5656
openssl-src = "300.2.2"
5757
openssl-sys = "0.9.102"
5858
pg_interval = { git = "https://github.com/chandr-andr/rust-postgres-interval.git", branch = "psqlpy" }
59+
pgvector = { git = "https://github.com/chandr-andr/pgvector-rust.git", branch = "psqlpy", features = [
60+
"postgres",
61+
] }

python/psqlpy/_internal/extra_types.pyi

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,3 +774,16 @@ class IntervalArray:
774774
### Parameters:
775775
- `inner`: inner value, sequence of timedelta values.
776776
"""
777+
778+
class PgVector:
779+
"""Represent VECTOR in PostgreSQL."""
780+
781+
def __init__(
782+
self: Self,
783+
inner: typing.Sequence[float | int],
784+
) -> None:
785+
"""Create new instance of PgVector.
786+
787+
### Parameters:
788+
- `inner`: inner value, sequence of float or int values.
789+
"""

python/psqlpy/extra_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
MoneyArray,
2727
NumericArray,
2828
PathArray,
29+
PgVector,
2930
PointArray,
3031
PyBox,
3132
PyCircle,
@@ -98,4 +99,5 @@
9899
"LsegArray",
99100
"CircleArray",
100101
"IntervalArray",
102+
"PgVector",
101103
]

src/extra_types.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,25 @@ use crate::{
1818
},
1919
};
2020

21+
#[pyclass]
22+
#[derive(Clone)]
23+
pub struct PgVector(Vec<f32>);
24+
25+
#[pymethods]
26+
impl PgVector {
27+
#[new]
28+
fn new(vector: Vec<f32>) -> Self {
29+
Self(vector)
30+
}
31+
}
32+
33+
impl PgVector {
34+
#[must_use]
35+
pub fn inner_value(self) -> Vec<f32> {
36+
self.0
37+
}
38+
}
39+
2140
macro_rules! build_python_type {
2241
($st_name:ident, $rust_type:ty) => {
2342
#[pyclass]
@@ -412,5 +431,6 @@ pub fn extra_types_module(_py: Python<'_>, pymod: &Bound<'_, PyModule>) -> PyRes
412431
pymod.add_class::<LsegArray>()?;
413432
pymod.add_class::<CircleArray>()?;
414433
pymod.add_class::<IntervalArray>()?;
434+
pymod.add_class::<PgVector>()?;
415435
Ok(())
416436
}

src/value_converter.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use crate::{
3333
exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult},
3434
extra_types,
3535
};
36+
use pgvector::Vector as PgVector;
3637
use postgres_array::{array::Array, Dimension};
3738

3839
static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
@@ -268,6 +269,8 @@ pub enum PythonDTO {
268269
PyLsegArray(Array<PythonDTO>),
269270
PyCircleArray(Array<PythonDTO>),
270271
PyIntervalArray(Array<PythonDTO>),
272+
// PgVector
273+
PyPgVector(Vec<f32>),
271274
}
272275

273276
impl ToPyObject for PythonDTO {
@@ -594,6 +597,9 @@ impl ToSql for PythonDTO {
594597
PythonDTO::PyIntervalArray(array) => {
595598
array.to_sql(&Type::INTERVAL_ARRAY, out)?;
596599
}
600+
PythonDTO::PyPgVector(vector) => {
601+
<PgVector as ToSql>::to_sql(&PgVector::from(vector.clone()), ty, out)?;
602+
}
597603
}
598604

599605
if return_is_null_true {
@@ -1139,6 +1145,12 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult<
11391145
._convert_to_python_dto();
11401146
}
11411147

1148+
if parameter.is_instance_of::<extra_types::PgVector>() {
1149+
return Ok(PythonDTO::PyPgVector(
1150+
parameter.extract::<extra_types::PgVector>()?.inner_value(),
1151+
));
1152+
}
1153+
11421154
if let Ok(id_address) = parameter.extract::<IpAddr>() {
11431155
return Ok(PythonDTO::PyIpAddress(id_address));
11441156
}

0 commit comments

Comments
 (0)