From 72a87929e678d11b9027483ed2f08cc799a9bac3 Mon Sep 17 00:00:00 2001 From: Geoffrey Eisenbarth Date: Wed, 13 Aug 2025 10:40:59 -0500 Subject: [PATCH 1/5] Add tests --- tests/models.py | 54 +++++++++++++++++++++++++------- tests/test_queries.py | 71 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 111 insertions(+), 14 deletions(-) diff --git a/tests/models.py b/tests/models.py index 61e073e..c913ef3 100644 --- a/tests/models.py +++ b/tests/models.py @@ -6,6 +6,17 @@ from .fields import MyIntegerField +class MockObjectMeta: + app_label = "tests" + unique_together = ("name", "number") + + +if django.get_version() <= "5.1": + MockObjectMeta.index_together = ("name", "number") +else: + MockObjectMeta.indexes = [models.Index(fields=["name", "number"])] + + class MockObject(models.Model): name = models.CharField(max_length=500) number = MyIntegerField(null=True, db_column="num") @@ -15,16 +26,8 @@ class MockObject(models.Model): ) objects = CopyManager() - class Meta: - app_label = "tests" - unique_together = ("name", "number") - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if django.get_version() <= "5.1": - self._meta.index_together = ("name", "number") - else: - self._meta.indexes = [models.Index(fields=["name", "number"])] + class Meta(MockObjectMeta): + pass def copy_name_template(self): return 'upper("%(name)s")' @@ -128,6 +131,35 @@ class SecondaryMockObject(models.Model): objects = CopyManager() -class UniqueMockObject(models.Model): +class UniqueFieldConstraintMockObject(models.Model): name = models.CharField(max_length=500, unique=True) objects = CopyManager() + + +class UniqueModelConstraintMockObject(models.Model): + name = models.CharField(max_length=500) + number = MyIntegerField(null=True, db_column="num") + objects = CopyManager() + + class Meta: + constraints = [ + models.UniqueConstraint( + name="constraint", + fields=["name"], + ), + ] + + +class UniqueModelConstraintAsIndexMockObject(models.Model): + name 
= models.CharField(max_length=500) + number = MyIntegerField(null=True, db_column="num") + objects = CopyManager() + + class Meta: + constraints = [ + models.UniqueConstraint( + name="constraint_as_index", + fields=["name"], + include=["number"], # Converts Constraint to Index + ), + ] diff --git a/tests/test_queries.py b/tests/test_queries.py index 8aa6b9e..b941de7 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -22,7 +22,9 @@ MockObject, OverloadMockObject, SecondaryMockObject, - UniqueMockObject, + UniqueFieldConstraintMockObject, + UniqueModelConstraintMockObject, + UniqueModelConstraintAsIndexMockObject, ) try: @@ -536,13 +538,76 @@ def test_encoding_save(self, _): @mock.patch("django.db.connection.validate_no_atomic_block") def test_ignore_conflicts(self, _): - UniqueMockObject.objects.from_csv( + UniqueFieldConstraintMockObject.objects.from_csv( self.name_path, dict(name="NAME"), ignore_conflicts=True ) - UniqueMockObject.objects.from_csv( + UniqueFieldConstraintMockObject.objects.from_csv( self.name_path, dict(name="NAME"), ignore_conflicts=True ) + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_update_conflicts_target_field_update(self, _): + UniqueFieldConstraintMockObject.objects.from_csv( + self.name_path, + dict(name="NAME"), + drop_constraints=False, + drop_indexes=False, + update_conflicts=True, + update_fields=["name"], + unique_fields=["name"], + ) + UniqueFieldConstraintMockObject.objects.from_csv( + self.name_path, + dict(name="NAME"), + drop_constraints=False, + drop_indexes=False, + update_conflicts=True, + update_fields=["name"], + unique_fields=["name"], + ) + + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_update_conflicts_target_constraint_update(self, _): + UniqueModelConstraintMockObject.objects.from_csv( + self.name_path, + dict(name="NAME", number="NUMBER"), + drop_constraints=False, + drop_indexes=False, + update_conflicts=True, + update_fields=["name", "number"], + 
unique_fields=["name"], + ) + UniqueModelConstraintMockObject.objects.from_csv( + self.name_path, + dict(name="NAME", number="NUMBER"), + drop_constraints=False, + drop_indexes=False, + update_conflicts=True, + update_fields=["name", "number"], + unique_fields=["name"], + ) + + @mock.patch("django.db.connection.validate_no_atomic_block") + def test_update_conflicts_target_constraint_as_index_update(self, _): + UniqueModelConstraintAsIndexMockObject.objects.from_csv( + self.name_path, + dict(name="NAME", number="NUMBER"), + drop_constraints=False, + drop_indexes=False, + update_conflicts=True, + update_fields=["name", "number"], + unique_fields=["name"], + ) + UniqueModelConstraintAsIndexMockObject.objects.from_csv( + self.name_path, + dict(name="NAME", number="NUMBER"), + drop_constraints=False, + drop_indexes=False, + update_conflicts=True, + update_fields=["name", "number"], + unique_fields=["name"], + ) + @mock.patch("django.db.connection.validate_no_atomic_block") def test_static_values(self, _): ExtendedMockObject.objects.from_csv( From b78e377005501de3d57cd99e36c002af78fc7ac4 Mon Sep 17 00:00:00 2001 From: Geoffrey Eisenbarth Date: Wed, 13 Aug 2025 10:41:08 -0500 Subject: [PATCH 2/5] Add docs --- docs/index.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 6d3b87d..53dedd4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -229,6 +229,18 @@ Keyword Argument Description ``ignore_conflicts`` Specify True to ignore unique constraint or exclusion constraint violation errors. The default is False. +``update_conflicts`` Specify True to update rows that fail insertion due + to conflicts. Requires both ``update_fields`` and + ``unique_fields`` to be specified. The default is False. + +``update_fields`` When ``update_conflicts`` is ``True``, this specifies + which model fields should be updated. If passed, this + must be a list of field names. The default is None. 
+ +``unique_fields`` When ``update_conflicts`` is ``True``, this specifies + which fields might be in conflict. If passed, this must + be a list of field names. The default is None. + ``using`` Sets the database to use when importing data. Default is None, which will use the ``'default'`` database. From c6883b1904fffb1d2066eb495ab64807dd725c79 Mon Sep 17 00:00:00 2001 From: Geoffrey Eisenbarth Date: Wed, 13 Aug 2025 10:43:21 -0500 Subject: [PATCH 3/5] Support model indexes/constraints --- postgres_copy/managers.py | 45 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/postgres_copy/managers.py b/postgres_copy/managers.py index fb7a356..e3875a2 100644 --- a/postgres_copy/managers.py +++ b/postgres_copy/managers.py @@ -4,6 +4,7 @@ import typing from django.db import connection, models +from django.db.models.constraints import BaseConstraint from django.db.models.fields import Field from django.db.transaction import TransactionManagementError from django.db.backends.base.schema import BaseDatabaseSchemaEditor @@ -30,6 +31,13 @@ def constrained_fields(self) -> typing.List[Field]: if hasattr(f, "db_constraint") and f.db_constraint ] + @property + def model_constraints(self) -> typing.List[BaseConstraint]: + """ + Returns list of model-level constraints. + """ + return getattr(self.model._meta, "constraints", []) + @property def indexed_fields(self) -> typing.List[Field]: """ @@ -37,6 +45,13 @@ def indexed_fields(self) -> typing.List[Field]: """ return [f for f in self.model._meta.fields if f.db_index] + @property + def model_indexes(self) -> typing.List[models.Index]: + """ + Returns list of model-level indexes. 
+ """ + return getattr(self.model._meta, "indexes", []) + def edit_schema( self, schema_editor: BaseDatabaseSchemaEditor, @@ -79,6 +94,14 @@ def drop_constraints(self) -> None: args = (self.model, field, field_copy) self.edit_schema(schema_editor, "alter_field", args) + # Remove any model constraints + for constraint in self.model_constraints: + logger.debug( + f"Dropping constraint '{constraint.name}' from {self.model.__name__}" + ) + args = (self.model, constraint) + self.edit_schema(schema_editor, "remove_constraint", args) + def drop_indexes(self) -> None: """ Drop indexes on the model and its fields. @@ -102,6 +125,14 @@ def drop_indexes(self) -> None: args = (self.model, field, field_copy) self.edit_schema(schema_editor, "alter_field", args) + # Remove any model indexes + for index in self.model_indexes: + logger.debug( + f"Dropping index '{index.name}' from {self.model.__name__}" + ) + args = (self.model, index) + self.edit_schema(schema_editor, "remove_index", args) + def restore_constraints(self) -> None: """ Restore constraints on the model and its fields. @@ -127,6 +158,14 @@ def restore_constraints(self) -> None: args = (self.model, field_copy, field) self.edit_schema(schema_editor, "alter_field", args) + # Add any constraints to the model + for constraint in self.model_constraints: + logger.debug( + f"Adding constraint '{constraint.name}' to {self.model.__name__}" + ) + args = (self.model, constraint) + self.edit_schema(schema_editor, "add_constraint", args) + def restore_indexes(self) -> None: """ Restore indexes on the model and its fields. 
@@ -152,6 +191,12 @@ def restore_indexes(self) -> None: args = (self.model, field_copy, field) self.edit_schema(schema_editor, "alter_field", args) + # Add any indexes to the model + for index in self.model_indexes: + logger.debug(f"Adding index '{index.name}' to {self.model.__name__}") + args = (self.model, index) + self.edit_schema(schema_editor, "add_index", args) + class CopyQuerySet(ConstraintQuerySet): """ From 30312c6c5e0b39c4da81839248a057457114b699 Mon Sep 17 00:00:00 2001 From: Geoffrey Eisenbarth Date: Wed, 13 Aug 2025 10:43:39 -0500 Subject: [PATCH 4/5] Add upsert --- postgres_copy/copy_from.py | 57 ++++++++++++++++++++++++++++++-------- postgres_copy/managers.py | 6 ++++ 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/postgres_copy/copy_from.py b/postgres_copy/copy_from.py index e9fd19e..af025c3 100644 --- a/postgres_copy/copy_from.py +++ b/postgres_copy/copy_from.py @@ -13,8 +13,9 @@ from django.contrib.humanize.templatetags.humanize import intcomma from django.core.exceptions import FieldDoesNotExist -from django.db import NotSupportedError, connections, router +from django.db import connections, router from django.db.models import Field, Model +from django.db.models.constants import OnConflict from django.db.backends.utils import CursorWrapper from .psycopg_compat import copy_from @@ -40,6 +41,9 @@ def __init__( force_null: typing.Optional[typing.List[str]] = None, encoding: typing.Optional[str] = None, ignore_conflicts: bool = False, + update_conflicts: bool = False, + update_fields: typing.Optional[typing.Collection[str]] = None, + unique_fields: typing.Optional[typing.Collection[str]] = None, static_mapping: typing.Optional[typing.Dict[str, str]] = None, temp_table_name: typing.Optional[str] = None, ) -> None: @@ -66,13 +70,28 @@ def __init__( self.force_not_null = force_not_null self.force_null = force_null self.encoding = encoding - self.supports_ignore_conflicts = True self.ignore_conflicts = ignore_conflicts + 
self.update_conflicts = update_conflicts if static_mapping is not None: self.static_mapping = OrderedDict(static_mapping) else: self.static_mapping = OrderedDict() + # Convert field names to fields + opts = self.model._meta + if update_fields: + self.update_fields = [opts.get_field(name) for name in update_fields] + else: + self.update_fields = update_fields + if unique_fields: + # Primary key is allowed in unique_fields + self.unique_fields = [ + opts.get_field(opts.pk.name if name == "pk" else name) + for name in unique_fields + ] + else: + self.unique_fields = unique_fields + # Line up the database connection if using is not None: self.using = using @@ -85,12 +104,13 @@ def __init__( if self.conn.vendor != "postgresql": raise TypeError("Only PostgreSQL backends supported") - # Check if it is PSQL 9.5 or greater, which determines if ignore_conflicts is supported - self.supports_ignore_conflicts = self.is_postgresql_9_5() - if self.ignore_conflicts and not self.supports_ignore_conflicts: - raise NotSupportedError( - "This database backend does not support ignoring conflicts." - ) + # Use Django to validate ON CONFLICT related kwargs + self.on_conflict = self.model.objects.none()._check_bulk_create_options( + ignore_conflicts=self.ignore_conflicts, + update_conflicts=self.update_conflicts, + update_fields=self.update_fields, + unique_fields=self.unique_fields, + ) # Pull the CSV headers self.headers = self.get_headers() @@ -351,12 +371,27 @@ def insert_suffix(self) -> str: """ Preps the suffix to the insert query. 
""" - if self.ignore_conflicts: - return """ + if self.on_conflict == OnConflict.IGNORE: + suffix = """ ON CONFLICT DO NOTHING; """ + elif self.on_conflict == OnConflict.UPDATE: + update_columns = [field.column for field in self.update_fields] + model_table = self.model._meta.db_table + + suffix = """ + ON CONFLICT ({target}) DO UPDATE + SET {values} + WHERE ({new}) IS DISTINCT FROM ({old}); + """.format( + target=", ".join(f.column for f in self.unique_fields), + values=", ".join(f"{c}=EXCLUDED.{c}" for c in update_columns), + new=", ".join(f"{model_table}.{c}" for c in update_columns), + old=", ".join(f"EXCLUDED.{c}" for c in update_columns), + ) else: - return ";" + suffix = ";" + return suffix def prep_insert(self) -> str: """ diff --git a/postgres_copy/managers.py b/postgres_copy/managers.py index e3875a2..0aa944e 100644 --- a/postgres_copy/managers.py +++ b/postgres_copy/managers.py @@ -231,6 +231,12 @@ def from_csv( "drop_constraints=False and drop_indexes=False." ) + if kwargs.get("update_conflicts"): + raise ValueError( + "update_conflicts is mutually exclusive with " + "drop_constraints or drop_indexes." + ) + # Create a mapping dictionary if none was provided mapping_dict = mapping if mapping is not None else {} From 7949b99d9eb62025d19fa2a799a999567fdaef27 Mon Sep 17 00:00:00 2001 From: Geoffrey Eisenbarth Date: Wed, 13 Aug 2025 10:59:43 -0500 Subject: [PATCH 5/5] Close #168 --- docs/index.rst | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 53dedd4..a9f4d92 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -476,6 +476,48 @@ Now you can run that subclass directly rather than via a manager. The only diffe c.save() +For example, if you wish to return a QuerySet of the models imported into the database, you could do the following: + +.. 
code-block:: python + :emphasize-lines: 9,14,25 + + + from django.db import models + from postgres_copy import CopyMapping, CopyQuerySet + + + class ResultsCopyMapping(CopyMapping): + def insert_suffix(self) -> str: + """Add `RETURNING` sql clause to get newly created/updated ids.""" + suffix = super().insert_suffix() + suffix = suffix.split(';')[0] + ' RETURNING id;' + return suffix + + def post_insert(self, cursor) -> None: + """Extend to store results from `RETURNING` clause.""" + self.obj_ids = [r[0] for r in cursor.fetchall()] + + class ResultsCopyQuerySet(CopyQuerySet): + def from_csv(self, csv_path_or_obj, mapping=None, **kwargs): + mapping = ResultsCopyMapping( + self.model, + csv_path_or_obj, + mapping=mapping, + **kwargs + ) + count = mapping.save(silent=True) + objs = self.model.objects.filter(id__in=mapping.obj_ids) + return objs, count + + + class Person(models.Model): + name = models.CharField(max_length=500) + number = models.IntegerField() + source_csv = models.CharField(max_length=500) + + objects = ResultsCopyQuerySet.as_manager() + + Export options ==============