Skip to content

Add type anotation to methods and functions #779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions push_notifications/api/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from rest_framework.response import Response
from rest_framework.serializers import ModelSerializer, Serializer, ValidationError
from rest_framework.viewsets import ModelViewSet

from ..fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re
from ..models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
from ..settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS
from typing import Any
from push_notifications.fields import UNSIGNED_64BIT_INT_MAX_VALUE, hex_re
from push_notifications.models import APNSDevice, GCMDevice, WebPushDevice, WNSDevice
from push_notifications.settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS
from django.db.models import QuerySet


# Fields
Expand All @@ -15,16 +16,16 @@ class HexIntegerField(IntegerField):
Store an integer represented as a hex string of form "0x01".
"""

def to_internal_value(self, data):
def to_internal_value(self, data: str | int) -> int:
# validate hex string and convert it to the unsigned
# integer representation for internal use
try:
data = int(data, 16) if type(data) != int else data
data = int(data, 16) if not isinstance(data, int) else data
except ValueError:
raise ValidationError("Device ID is not a valid hex number")
return super().to_internal_value(data)

def to_representation(self, value):
def to_representation(self, value: int) -> int:
return value


Expand All @@ -45,7 +46,7 @@ class APNSDeviceSerializer(ModelSerializer):
class Meta(DeviceSerializerMixin.Meta):
model = APNSDevice

def validate_registration_id(self, value):
def validate_registration_id(self, value: str) -> str:

# https://developer.apple.com/documentation/uikit/uiapplicationdelegate/1622958-application
# As of 02/2023 APNS tokens (registration_id) "are of variable length. Do not hard-code their size."
Expand All @@ -56,7 +57,7 @@ def validate_registration_id(self, value):


class UniqueRegistrationSerializerMixin(Serializer):
def validate(self, attrs):
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
devices = None
primary_key = None
request_method = None
Expand Down Expand Up @@ -103,7 +104,7 @@ class Meta(DeviceSerializerMixin.Meta):
)
extra_kwargs = {"id": {"read_only": False, "required": False}}

def validate_device_id(self, value):
def validate_device_id(self, value: int) -> int:
# device ids are 64 bit unsigned values
if value > UNSIGNED_64BIT_INT_MAX_VALUE:
raise ValidationError("Device ID is out of range")
Expand All @@ -126,7 +127,7 @@ class Meta(DeviceSerializerMixin.Meta):

# Permissions
class IsOwner(permissions.BasePermission):
def has_object_permission(self, request, view, obj):
def has_object_permission(self, request: Any, view: Any, obj: Any) -> bool:
# must be the owner to view the object
return obj.user == request.user

Expand All @@ -135,7 +136,7 @@ def has_object_permission(self, request, view, obj):
class DeviceViewSetMixin:
lookup_field = "registration_id"

def create(self, request, *args, **kwargs):
def create(self, request: Any, *args: Any, **kwargs: Any) -> Response:
serializer = None
is_update = False
if SETTINGS.get("UPDATE_ON_DUPLICATE_REG_ID") and self.lookup_field in request.data:
Expand All @@ -157,12 +158,12 @@ def create(self, request, *args, **kwargs):
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

def perform_create(self, serializer):
def perform_create(self, serializer: ModelSerializer) -> None:
if self.request.user.is_authenticated:
serializer.save(user=self.request.user)
return super().perform_create(serializer)

def perform_update(self, serializer):
def perform_update(self, serializer: ModelSerializer) -> None:
if self.request.user.is_authenticated:
serializer.save(user=self.request.user)
return super().perform_update(serializer)
Expand All @@ -171,7 +172,7 @@ def perform_update(self, serializer):
class AuthorizedMixin:
permission_classes = (permissions.IsAuthenticated, IsOwner)

def get_queryset(self):
def get_queryset(self) -> QuerySet:
# filter all devices to only those belonging to the current user
return self.queryset.filter(user=self.request.user)

Expand Down
46 changes: 37 additions & 9 deletions push_notifications/apns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from . import models
from .conf import get_manager
from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError
from typing import Any, Callable


def _apns_create_socket(creds=None, application_id=None):
def _apns_create_socket(creds=None, application_id=None) -> apns2_client.APNsClient:
if creds is None:
if not get_manager().has_auth_token_creds(application_id):
cert = get_manager().get_apns_certificate(application_id)
Expand All @@ -39,9 +40,20 @@ def _apns_create_socket(creds=None, application_id=None):


def _apns_prepare(
token, alert, application_id=None, badge=None, sound=None, category=None,
content_available=False, action_loc_key=None, loc_key=None, loc_args=[],
extra={}, mutable_content=False, thread_id=None, url_args=None):
token: str,
alert: str | dict[str, Any] | None,
application_id: str | None = None,
badge: int | Callable | None = None,
sound: str | None = None,
category: str | None = None,
content_available: bool = False,
action_loc_key: str | None = None,
loc_key: str | None = None,
loc_args: str | None = [],
extra: dict[str, Any] = {},
mutable_content: bool = False,
thread_id: str | None = None,
url_args: list[str] | None = None) -> apns2_payload.Payload:
if action_loc_key or loc_key or loc_args:
apns2_alert = apns2_payload.PayloadAlert(
body=alert if alert else {}, body_localized_key=loc_key,
Expand All @@ -59,8 +71,13 @@ def _apns_prepare(


def _apns_send(
registration_id, alert, batch=False, application_id=None, creds=None, **kwargs
):
registration_id: str | list[str],
alert: str | dict[str, Any] | None,
batch: bool = False,
application_id: str | None = None,
creds: apns2_credentials.Credentials | None = None,
**kwargs: Any
) -> dict[str, str] | None:
client = _apns_create_socket(creds=creds, application_id=application_id)

notification_kwargs = {}
Expand Down Expand Up @@ -95,9 +112,16 @@ def _apns_send(
get_manager().get_apns_topic(application_id=application_id),
**notification_kwargs
)
return None


def apns_send_message(registration_id, alert, application_id=None, creds=None, **kwargs):
def apns_send_message(
registration_id: str,
alert: str | dict[str, Any] | None,
application_id: str | None = None,
creds: apns2_credentials.Credentials | None = None,
**kwargs: Any
) -> None:
"""
Sends an APNS notification to a single registration_id.
This will send the notification as form data.
Expand All @@ -122,8 +146,12 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, *


def apns_send_bulk_message(
registration_ids, alert, application_id=None, creds=None, **kwargs
):
registration_ids: list[str],
alert: str | dict[str, Any] | None,
application_id: str | None = None,
creds: apns2_credentials.Credentials | None = None,
**kwargs: Any
) -> dict[str, str]:
"""
Sends an APNS notification to one or more registration_ids.
The registration_ids argument needs to be a list.
Expand Down
Loading