Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,18 @@ Keyword Argument Description
``ignore_conflicts`` Specify True to ignore unique constraint or exclusion
constraint violation errors. The default is False.

``update_conflicts`` Specify True to update rows that fail insertion due
to conflicts. Requires both ``update_fields`` and
``unique_fields`` to be specified. The default is False.

``update_fields`` When ``update_conflicts`` is ``True``, this specifies
which model fields should be updated. If passed, this
must be a list of field names. The default is None.

``unique_fields`` When ``update_conflicts`` is ``True``, this specifies
which fields might be in conflict. If passed, this must
be a list of field names. The default is None.

``using`` Sets the database to use when importing data.
Default is None, which will use the ``'default'``
database.
Expand Down Expand Up @@ -464,6 +476,48 @@ Now you can run that subclass directly rather than via a manager. The only diffe
c.save()


For example, if you wish to return a QuerySet of the models imported into the database, you could do the following:

.. code-block:: python
:emphasize-lines: 9,14,25


from django.db import models
from postgres_copy import CopyMapping, CopyQuerySet


class ResultsCopyMapping(CopyMapping):
def insert_suffix(self) -> str:
"""Add `RETURNING` sql clause to get newly created/updated ids."""
suffix = super().insert_suffix()
suffix = suffix.split(';')[0] + ' RETURNING id;'
return suffix

def post_insert(self, cursor) -> None:
"""Extend to store results from `RETURNING` clause."""
self.obj_ids = [r[0] for r in cursor.fetchall()]

class ResultsCopyQuerySet(CopyQuerySet):
def from_csv(self, csv_path_or_obj, mapping=None, **kwargs):
mapping = ResultsCopyMapping(
self.model,
csv_path_or_obj,
mapping=None,
**kwargs
)
count = mapping.save(silent=True)
objs = self.model.objects.filter(id__in=mapping.obj_ids)
return objs, count


class Person(models.Model):
name = models.CharField(max_length=500)
number = models.IntegerField()
source_csv = models.CharField(max_length=500)

objects = ResultsCopyQuerySet.as_manager()


Export options
==============

Expand Down
57 changes: 46 additions & 11 deletions postgres_copy/copy_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@

from django.contrib.humanize.templatetags.humanize import intcomma
from django.core.exceptions import FieldDoesNotExist
from django.db import NotSupportedError, connections, router
from django.db import connections, router
from django.db.models import Field, Model
from django.db.models.constants import OnConflict
from django.db.backends.utils import CursorWrapper

from .psycopg_compat import copy_from
Expand All @@ -40,6 +41,9 @@ def __init__(
force_null: typing.Optional[typing.List[str]] = None,
encoding: typing.Optional[str] = None,
ignore_conflicts: bool = False,
update_conflicts: bool = False,
update_fields: typing.Optional[typing.Collection[str]] = None,
unique_fields: typing.Optional[typing.Collection[str]] = None,
static_mapping: typing.Optional[typing.Dict[str, str]] = None,
temp_table_name: typing.Optional[str] = None,
) -> None:
Expand All @@ -66,13 +70,28 @@ def __init__(
self.force_not_null = force_not_null
self.force_null = force_null
self.encoding = encoding
self.supports_ignore_conflicts = True
self.ignore_conflicts = ignore_conflicts
self.update_conflicts = update_conflicts
if static_mapping is not None:
self.static_mapping = OrderedDict(static_mapping)
else:
self.static_mapping = OrderedDict()

# Convert field names to fields
opts = self.model._meta
if update_fields:
self.update_fields = [opts.get_field(name) for name in update_fields]
else:
self.update_fields = update_fields
if unique_fields:
# Primary key is allowed in unique_fields
self.unique_fields = [
opts.get_field(opts.pk.name if name == "pk" else name)
for name in unique_fields
]
else:
self.unique_fields = unique_fields

# Line up the database connection
if using is not None:
self.using = using
Expand All @@ -85,12 +104,13 @@ def __init__(
if self.conn.vendor != "postgresql":
raise TypeError("Only PostgreSQL backends supported")

# Check if it is PSQL 9.5 or greater, which determines if ignore_conflicts is supported
self.supports_ignore_conflicts = self.is_postgresql_9_5()
if self.ignore_conflicts and not self.supports_ignore_conflicts:
raise NotSupportedError(
"This database backend does not support ignoring conflicts."
)
# Use Django to validate ON CONFLICT related kwargs
self.on_conflict = self.model.objects.none()._check_bulk_create_options(
ignore_conflicts=self.ignore_conflicts,
update_conflicts=self.update_conflicts,
update_fields=self.update_fields,
unique_fields=self.unique_fields,
)

# Pull the CSV headers
self.headers = self.get_headers()
Expand Down Expand Up @@ -351,12 +371,27 @@ def insert_suffix(self) -> str:
"""
Preps the suffix to the insert query.
"""
if self.ignore_conflicts:
return """
if self.on_conflict == OnConflict.IGNORE:
suffix = """
ON CONFLICT DO NOTHING;
"""
elif self.on_conflict == OnConflict.UPDATE:
update_columns = [field.column for field in self.update_fields]
model_table = self.model._meta.db_table

suffix = """
ON CONFLICT ({target}) DO UPDATE
SET {values}
WHERE ({new}) IS DISTINCT FROM ({old});
""".format(
target=", ".join(f.column for f in self.unique_fields),
values=", ".join(f"{c}=EXCLUDED.{c}" for c in update_columns),
new=", ".join(f"{model_table}.{c}" for c in update_columns),
old=", ".join(f"EXCLUDED.{c}" for c in update_columns),
)
else:
return ";"
suffix = ";"
return suffix

def prep_insert(self) -> str:
"""
Expand Down
51 changes: 51 additions & 0 deletions postgres_copy/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing

from django.db import connection, models
from django.db.models.constraints import BaseConstraint
from django.db.models.fields import Field
from django.db.transaction import TransactionManagementError
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
Expand All @@ -30,13 +31,27 @@ def constrained_fields(self) -> typing.List[Field]:
if hasattr(f, "db_constraint") and f.db_constraint
]

@property
def model_constraints(self) -> typing.List[BaseConstraint]:
"""
Returns list of model-level constraints.
"""
return getattr(self.model._meta, "constraints", [])

@property
def indexed_fields(self) -> typing.List[Field]:
"""
Returns list of model's fields with db_index set to True.
"""
return [f for f in self.model._meta.fields if f.db_index]

@property
def model_indexes(self) -> typing.List[models.Index]:
"""
Returns list of model-level indexes.
"""
return getattr(self.model._meta, "indexes", [])

def edit_schema(
self,
schema_editor: BaseDatabaseSchemaEditor,
Expand Down Expand Up @@ -79,6 +94,14 @@ def drop_constraints(self) -> None:
args = (self.model, field, field_copy)
self.edit_schema(schema_editor, "alter_field", args)

# Remove any model constraints
for constraint in self.model_constraints:
logger.debug(
f"Dropping constraint '{constraint.name}' from {self.model.__name__}"
)
args = (self.model, constraint)
self.edit_schema(schema_editor, "remove_constraint", args)

def drop_indexes(self) -> None:
"""
Drop indexes on the model and its fields.
Expand All @@ -102,6 +125,14 @@ def drop_indexes(self) -> None:
args = (self.model, field, field_copy)
self.edit_schema(schema_editor, "alter_field", args)

# Remove any model indexes
for index in self.model_indexes:
logger.debug(
f"Dropping index '{index.name}' from {self.model.__name__}"
)
args = (self.model, index)
self.edit_schema(schema_editor, "remove_index", args)

def restore_constraints(self) -> None:
"""
Restore constraints on the model and its fields.
Expand All @@ -127,6 +158,14 @@ def restore_constraints(self) -> None:
args = (self.model, field_copy, field)
self.edit_schema(schema_editor, "alter_field", args)

# Add any constraints to the model
for constraint in self.model_constraints:
logger.debug(
f"Adding constraint '{constraint.name}' to {self.model.__name__}"
)
args = (self.model, constraint)
self.edit_schema(schema_editor, "add_constraint", args)

def restore_indexes(self) -> None:
"""
Restore indexes on the model and its fields.
Expand All @@ -152,6 +191,12 @@ def restore_indexes(self) -> None:
args = (self.model, field_copy, field)
self.edit_schema(schema_editor, "alter_field", args)

# Add any indexes to the model
for index in self.model_indexes:
logger.debug(f"Adding index '{index.name}' to {self.model.__name__}")
args = (self.model, index)
self.edit_schema(schema_editor, "add_index", args)


class CopyQuerySet(ConstraintQuerySet):
"""
Expand Down Expand Up @@ -186,6 +231,12 @@ def from_csv(
"drop_constraints=False and drop_indexes=False."
)

if kwargs.get("update_conflicts"):
raise ValueError(
"update_conflicts is mutually exclusive with "
"drop_constraints or drop_indexes."
)

# Create a mapping dictionary if none was provided
mapping_dict = mapping if mapping is not None else {}

Expand Down
54 changes: 43 additions & 11 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from .fields import MyIntegerField


class MockObjectMeta:
app_label = "tests"
unique_together = ("name", "number")


if django.get_version() <= "5.1":
MockObjectMeta.index_together = ("name", "number")
else:
MockObjectMeta.indexes = [models.Index(fields=["name", "number"])]


class MockObject(models.Model):
name = models.CharField(max_length=500)
number = MyIntegerField(null=True, db_column="num")
Expand All @@ -15,16 +26,8 @@ class MockObject(models.Model):
)
objects = CopyManager()

class Meta:
app_label = "tests"
unique_together = ("name", "number")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if django.get_version() <= "5.1":
self._meta.index_together = ("name", "number")
else:
self._meta.indexes = [models.Index(fields=["name", "number"])]
class Meta(MockObjectMeta):
pass

def copy_name_template(self):
return 'upper("%(name)s")'
Expand Down Expand Up @@ -128,6 +131,35 @@ class SecondaryMockObject(models.Model):
objects = CopyManager()


class UniqueMockObject(models.Model):
class UniqueFieldConstraintMockObject(models.Model):
name = models.CharField(max_length=500, unique=True)
objects = CopyManager()


class UniqueModelConstraintMockObject(models.Model):
name = models.CharField(max_length=500)
number = MyIntegerField(null=True, db_column="num")
objects = CopyManager()

class Meta:
constraints = [
models.UniqueConstraint(
name="constraint",
fields=["name"],
),
]


class UniqueModelConstraintAsIndexMockObject(models.Model):
name = models.CharField(max_length=500)
number = MyIntegerField(null=True, db_column="num")
objects = CopyManager()

class Meta:
constraints = [
models.UniqueConstraint(
name="constraint_as_index",
fields=["name"],
include=["number"], # Converts Constraint to Index
),
]
Loading
Loading