Skip to content

Commit 38e1965

Browse files
committed
add tests
1 parent 4d87545 commit 38e1965

File tree

1 file changed

+313
-0
lines changed

1 file changed

+313
-0
lines changed
Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
from unittest.mock import patch
2+
3+
from django.conf import settings
4+
from django.db import connection
5+
from django.urls import reverse
6+
from djstripe.enums import BillingScheme
7+
from djstripe.models import Customer
8+
from model_bakery import baker
9+
from rest_framework import status
10+
11+
from kobo.apps.kobo_auth.shortcuts import User
12+
from kobo.apps.organizations.constants import UsageType
13+
from kpi.models.user_reports import BillingAndUsageSnapshot
14+
from kpi.tests.base_test_case import BaseTestCase
15+
16+
17+
class UserReportsViewSetAPITestCase(BaseTestCase):
18+
fixtures = ['test_data']
19+
20+
def setUp(self):
21+
self.client.login(username='adminuser', password='pass')
22+
self.url = reverse(self._get_endpoint('api_v2:user-reports-list'))
23+
24+
# Creat and add a subscription to someuser
25+
self.someuser = User.objects.get(username='someuser')
26+
organization = self.someuser.organization
27+
self.customer = baker.make(Customer, subscriber=organization)
28+
self.subscription = baker.make(
29+
'djstripe.Subscription',
30+
customer=self.customer,
31+
items__price__livemode=False,
32+
items__price__billing_scheme=BillingScheme.per_unit,
33+
livemode=False,
34+
metadata={'organization_id': str(organization.id)},
35+
)
36+
37+
baker.make('kpi.BillingAndUsageSnapshot', organization_id=organization.id)
38+
39+
# Manually refresh the materialized view
40+
with connection.cursor() as cursor:
41+
cursor.execute('REFRESH MATERIALIZED VIEW user_reports_mv;')
42+
43+
def test_list_view_requires_authentication(self):
44+
self.client.logout()
45+
response = self.client.get(self.url)
46+
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
47+
48+
def test_list_view_requires_superuser_permission(self):
49+
self.client.logout()
50+
self.client.force_login(user=self.someuser)
51+
response = self.client.get(self.url)
52+
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
53+
54+
def test_list_view_succeeds_for_superuser(self):
55+
response = self.client.get(self.url)
56+
self.assertEqual(response.status_code, status.HTTP_200_OK)
57+
# Make sure that all 3 users from the 'test_data' are included
58+
self.assertEqual(len(response.data['results']), 3)
59+
60+
def test_endpoint_returns_error_when_stripe_is_disabled(self):
61+
try:
62+
settings.STRIPE_ENABLED = False
63+
response = self.client.get(self.url)
64+
65+
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
66+
self.assertEqual(
67+
response.json(),
68+
{'details': 'Stripe must be enabled to access this endpoint.'},
69+
)
70+
finally:
71+
# Restore the original setting
72+
settings.STRIPE_ENABLED = True
73+
74+
def test_endpoint_returns_error_when_mv_is_missing(self):
75+
# Drop the materialized view before the test
76+
with connection.cursor() as cursor:
77+
cursor.execute('DROP MATERIALIZED VIEW IF EXISTS user_reports_mv CASCADE;')
78+
79+
response = self.client.get(self.url)
80+
81+
self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
82+
self.assertEqual(
83+
response.json(),
84+
{
85+
'details': 'The data source for user reports is missing. '
86+
'Please run migration 0070 to create the materialized '
87+
'view: user_reports_mv.',
88+
},
89+
)
90+
91+
def test_subscription_data_is_correctly_returned(self):
92+
user_with_sub = self._get_someuser_data()
93+
self.assertEqual(len(user_with_sub['subscriptions']), 1)
94+
self.assertEqual(user_with_sub['subscriptions'][0]['id'], self.subscription.id)
95+
96+
subscription_item = user_with_sub['subscriptions'][0]['items'][0]
97+
98+
self.assertEqual(subscription_item['id'], self.subscription.items.first().id)
99+
self.assertEqual(
100+
subscription_item['price']['id'], self.subscription.items.first().price.id
101+
)
102+
self.assertEqual(
103+
subscription_item['price']['product']['id'],
104+
self.subscription.items.first().price.product.id,
105+
)
106+
self.assertEqual(
107+
user_with_sub['subscriptions'][0]['customer'], self.customer.id
108+
)
109+
self.assertEqual(
110+
user_with_sub['subscriptions'][0]['metadata']['organization_id'],
111+
self.subscription.metadata['organization_id'],
112+
)
113+
114+
@patch('kpi.serializers.v2.user_reports.get_organizations_effective_limits')
115+
def test_current_service_usage_data_is_correctly_returned(self, mock_get_limits):
116+
# Update a BillingAndUsageSnapshot with specific usage data
117+
billing_and_usage_snapshot = BillingAndUsageSnapshot.objects.get(
118+
organization_id=self.someuser.organization.id
119+
)
120+
billing_and_usage_snapshot.current_period_submissions = 15
121+
billing_and_usage_snapshot.submission_counts_all_time = 150
122+
billing_and_usage_snapshot.current_period_asr = 120
123+
billing_and_usage_snapshot.nlp_usage_asr_seconds_total = 240
124+
billing_and_usage_snapshot.storage_bytes_total = 200000000
125+
billing_and_usage_snapshot.save()
126+
127+
# Mock `get_organizations_effective_limits` to return test limits.
128+
mock_limits = {
129+
self.someuser.organization.id: {
130+
f'{UsageType.SUBMISSION}_limit': 10, # Exceeded
131+
f'{UsageType.STORAGE_BYTES}_limit': 500000000, # Not exceeded
132+
f'{UsageType.ASR_SECONDS}_limit': 120, # At the limit
133+
f'{UsageType.MT_CHARACTERS}_limit': 5, # Not used
134+
}
135+
}
136+
mock_get_limits.return_value = mock_limits
137+
138+
# Refresh the materialized view to sync with the snapshot
139+
with connection.cursor() as cursor:
140+
cursor.execute('REFRESH MATERIALIZED VIEW user_reports_mv;')
141+
142+
someuser_data = self._get_someuser_data()
143+
144+
service_usage = someuser_data['current_service_usage']
145+
# Assert total usage counts from the snapshot
146+
self.assertEqual(service_usage['total_submission_count']['current_period'], 15)
147+
self.assertEqual(service_usage['total_submission_count']['all_time'], 150)
148+
self.assertEqual(service_usage['total_storage_bytes'], 200000000)
149+
self.assertEqual(
150+
service_usage['total_nlp_usage']['asr_seconds_current_period'], 0
151+
)
152+
self.assertEqual(service_usage['total_nlp_usage']['asr_seconds_all_time'], 0)
153+
self.assertEqual(
154+
service_usage['total_nlp_usage']['mt_characters_current_period'], 0
155+
)
156+
self.assertEqual(service_usage['total_nlp_usage']['mt_characters_all_time'], 0)
157+
158+
# Assert calculated balances based on mock limits and real results
159+
balances = service_usage['balances']
160+
161+
# Submission balance: 15 / 10 = 1.5, so 150% and exceeded.
162+
self.assertIsNotNone(balances['submission'])
163+
self.assertTrue(balances['submission']['exceeded'])
164+
self.assertEqual(balances['submission']['effective_limit'], 10)
165+
self.assertEqual(balances['submission']['balance_value'], -5)
166+
self.assertEqual(balances['submission']['balance_percent'], 150)
167+
168+
# Storage balance: 200,000,000 / 500,000,000 = 0.4, so 40% and not exceeded.
169+
self.assertIsNotNone(balances['storage_bytes'])
170+
self.assertFalse(balances['storage_bytes']['exceeded'])
171+
self.assertEqual(balances['storage_bytes']['effective_limit'], 500000000)
172+
self.assertEqual(balances['storage_bytes']['balance_value'], 300000000)
173+
self.assertEqual(balances['storage_bytes']['balance_percent'], 40)
174+
175+
# ASR Seconds balance: 0 / 120 = 0, so 0% and not exceeded.
176+
self.assertIsNotNone(balances['asr_seconds'])
177+
self.assertFalse(balances['asr_seconds']['exceeded'])
178+
self.assertEqual(balances['asr_seconds']['effective_limit'], 120)
179+
self.assertEqual(balances['asr_seconds']['balance_value'], 120)
180+
self.assertEqual(balances['asr_seconds']['balance_percent'], 0)
181+
182+
# MT Characters balance: 0 / 5 = 0, so 0% and not exceeded.
183+
self.assertIsNotNone(balances['mt_characters'])
184+
self.assertFalse(balances['mt_characters']['exceeded'])
185+
self.assertEqual(balances['mt_characters']['effective_limit'], 5)
186+
self.assertEqual(balances['mt_characters']['balance_value'], 5)
187+
self.assertEqual(balances['mt_characters']['balance_percent'], 0)
188+
189+
def test_organization_data_is_correctly_returned(self):
190+
someuser_data = self._get_someuser_data()
191+
192+
organization_data = someuser_data['organizations']
193+
194+
self.assertEqual(
195+
organization_data['organization_name'], self.someuser.organization.name
196+
)
197+
self.assertEqual(
198+
organization_data['organization_uid'], str(self.someuser.organization.id)
199+
)
200+
self.assertEqual(organization_data['role'], 'owner')
201+
202+
def test_account_restricted_field(self):
203+
# Verify `account_restricted` is initially false
204+
response = self.client.get(self.url)
205+
self.assertEqual(response.status_code, status.HTTP_200_OK)
206+
207+
results = response.data['results']
208+
someuser_data = next(
209+
(user for user in results if user['username'] == 'someuser'),
210+
None,
211+
)
212+
213+
self.assertIsNotNone(someuser_data)
214+
self.assertFalse(someuser_data['account_restricted'])
215+
216+
# Update the BillingAndUsageSnapshot to exceed the desired limit
217+
billing_and_usage_snapshot = BillingAndUsageSnapshot.objects.get(
218+
organization_id=self.someuser.organization.id
219+
)
220+
billing_and_usage_snapshot.current_period_submissions = 10
221+
billing_and_usage_snapshot.save()
222+
223+
# Mock the `get_organizations_effective_limits` function
224+
# to return a predefined limit
225+
mock_limits = {
226+
self.someuser.organization.id: {f'{UsageType.SUBMISSION}_limit': 1}
227+
}
228+
with patch(
229+
'kpi.serializers.v2.user_reports.get_organizations_effective_limits',
230+
return_value=mock_limits,
231+
):
232+
with connection.cursor() as cursor:
233+
cursor.execute('REFRESH MATERIALIZED VIEW user_reports_mv;')
234+
235+
someuser_data = self._get_someuser_data()
236+
self.assertTrue(someuser_data['account_restricted'])
237+
238+
def test_accepted_tos_field(self):
239+
# Verify `accepted_tos` is initially false
240+
response = self.client.get(self.url)
241+
self.assertEqual(response.status_code, status.HTTP_200_OK)
242+
243+
results = response.data['results']
244+
self.assertEqual(results[0]['accepted_tos'], False)
245+
246+
# POST to the tos endpoint to accept the terms of service
247+
tos_url = reverse(self._get_endpoint('tos'))
248+
response = self.client.post(tos_url)
249+
assert response.status_code == status.HTTP_204_NO_CONTENT
250+
251+
with connection.cursor() as cursor:
252+
cursor.execute('REFRESH MATERIALIZED VIEW user_reports_mv;')
253+
254+
# Verify `accepted_tos` has been set to True
255+
response = self.client.get(self.url)
256+
self.assertEqual(response.status_code, status.HTTP_200_OK)
257+
258+
results = response.data['results']
259+
self.assertTrue(results[0]['accepted_tos'])
260+
261+
def test_filter_by_email_icontains(self):
262+
response = self.client.get(self.url, {'email': 'some@user'})
263+
self.assertEqual(response.status_code, status.HTTP_200_OK)
264+
self.assertEqual(response.data['count'], 1)
265+
self.assertEqual(response.data['results'][0]['email'], '[email protected]')
266+
267+
def test_filter_by_username_icontains(self):
268+
response = self.client.get(self.url, {'username': 'some'})
269+
self.assertEqual(response.status_code, status.HTTP_200_OK)
270+
self.assertEqual(response.data['count'], 1)
271+
self.assertEqual(response.data['results'][0]['username'], 'someuser')
272+
273+
def test_filter_by_submission_counts_range(self):
274+
billing_and_usage_snapshot = BillingAndUsageSnapshot.objects.get(
275+
organization_id=self.someuser.organization.id
276+
)
277+
billing_and_usage_snapshot.submission_counts_all_time = 50
278+
billing_and_usage_snapshot.save()
279+
280+
with connection.cursor() as cursor:
281+
cursor.execute('REFRESH MATERIALIZED VIEW user_reports_mv;')
282+
283+
# Test filter for submissions > 40
284+
response = self.client.get(self.url, {'submission_counts_all_time_min': 40})
285+
self.assertEqual(response.status_code, status.HTTP_200_OK)
286+
self.assertEqual(response.data['count'], 1)
287+
self.assertEqual(response.data['results'][0]['username'], 'someuser')
288+
289+
# Test filter for submissions < 40
290+
response = self.client.get(self.url, {'submission_counts_all_time_max': 40})
291+
self.assertEqual(response.status_code, status.HTTP_200_OK)
292+
self.assertGreater(response.data['count'], 0)
293+
294+
def test_ordering_by_date_joined(self):
295+
response = self.client.get(self.url, {'ordering': 'date_joined'})
296+
self.assertEqual(response.status_code, status.HTTP_200_OK)
297+
298+
results = response.data['results']
299+
self.assertEqual(results[0]['username'], 'adminuser')
300+
self.assertEqual(results[1]['username'], 'someuser')
301+
self.assertEqual(results[2]['username'], 'anotheruser')
302+
303+
def _get_someuser_data(self):
304+
305+
response = self.client.get(self.url)
306+
self.assertEqual(response.status_code, status.HTTP_200_OK)
307+
308+
results = response.data['results']
309+
someuser_data = next(
310+
(user for user in results if user['username'] == 'someuser'), None
311+
)
312+
self.assertIsNotNone(someuser_data)
313+
return someuser_data

0 commit comments

Comments
 (0)