Skip to content
148 changes: 147 additions & 1 deletion mockito/inorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,157 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import annotations

from .mockito import verify as verify_main
from collections import deque
from functools import partial
from typing import Deque

from .verification import VerificationError
from .invocation import (
RealInvocation,
VerifiableInvocation,
verification_has_lower_bound_of_zero,
)
from .mockito import ArgumentError, verify as verify_main
from .mock_registry import mock_registry


def verify(object, *args, **kwargs):
kwargs['inorder'] = True
return verify_main(object, *args, **kwargs)


class InOrder:

def __init__(self, *objects: object):
objects_ = []
for obj in objects:
if obj in objects_:
raise ValueError(f"{obj} is provided more than once")
objects_.append(obj)
self._objects = objects_
self._attach_all()
self.ordered_invocations: Deque[RealInvocation] = deque()

def _attach_all(self):
for obj in self._objects:
if m := mock_registry.mock_for(obj):
m.attach(self)

def update(self, invocation: RealInvocation) -> None:
self.ordered_invocations.append(invocation)

def verify(
self,
obj: object,
times=None,
atleast=None,
atmost=None,
between=None,
):
"""
Central method of InOrder class.
Use this method to verify the calling order of observed mocks.
:param obj: obj to verify the ordered invocation

"""
expected_mock = mock_registry.mock_for(obj)
if expected_mock is None:
raise ArgumentError(
f"\n{obj} is not setup with any stubbings or expectations."
)

if obj not in self._objects:
raise ArgumentError(
f"\n{obj} is not part of that InOrder."
)

return verify_main(
obj=obj,
times=times,
atleast=atleast,
atmost=atmost,
between=between,
_factory=partial(InOrderVerifiableInvocation, inorder=self),
)

def __enter__(self):
self._attach_all()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
for obj in self._objects:
if m := mock_registry.mock_for(obj):
m.detach(self)


class InOrderVerifiableInvocation(VerifiableInvocation):
def __init__(self, mock, method_name, verification, inorder: InOrder):
super().__init__(mock, method_name, verification)
self._inorder = inorder

def __call__(self, *params, **named_params): # noqa: C901
self._remember_params(params, named_params)

ordered = self._inorder.ordered_invocations

if not ordered:
raise VerificationError(
"\nThere are no recorded invocations."
)

# Find first invocation in global order that hasn't been used
# for "in-order" verification yet.
try:
start_idx, next_invocation = next(
(i, inv)
for i, inv in enumerate(ordered)
if not inv.verified_inorder
)
except StopIteration:
raise VerificationError(
"\nThere are no more recorded invocations."
)

called_mock = next_invocation.mock
if called_mock is not self.mock:
called_obj = mock_registry.obj_for(called_mock)
if called_obj is None:
raise RuntimeError(
f"{called_mock} is not in the registry (anymore)."
)
expected_obj = mock_registry.obj_for(self.mock)
raise VerificationError(
f"\nWanted a call from {expected_obj}, but "
f"got {called_obj}.{next_invocation} instead!"
)

matched_invocations = []

# Walk the contiguous block of this mock in the global queue.
for inv in list(ordered)[start_idx:]:
if inv.verified_inorder:
continue
if inv.mock is not self.mock:
break

if not self.matches(inv):
raise VerificationError(
"\nWanted %s to be invoked,\n"
"got %s instead." % (self, inv)
)

self.capture_arguments(inv)
matched_invocations.append(inv)

self.verification.verify(self, len(matched_invocations))

for inv in matched_invocations:
inv.verified = True
inv.verified_inorder = True

if verification_has_lower_bound_of_zero(self.verification):
for stub in self.mock.stubbed_invocations:
if stub.matches(self) or self.matches(stub):
stub.allow_zero_invocations = True
9 changes: 9 additions & 0 deletions mockito/mock_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def register(self, obj: object, mock: Mock) -> None:
def mock_for(self, obj: object) -> Mock | None:
return self.mocks.get(obj, None)

def obj_for(self, mock: Mock) -> object | None:
return self.mocks.lookup(mock)

def unstub(self, obj: object) -> None:
try:
mock = self.mocks.pop(obj)
Expand Down Expand Up @@ -84,6 +87,12 @@ def get(self, key, default=None):
return value
return default

def lookup(self, value, default=None):
for key, v in self._store:
if v is value:
return key
return default

def values(self):
return [v for k, v in self._store]

Expand Down
16 changes: 15 additions & 1 deletion mockito/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def remembered_invocation_builder(
return invoc(*args, **kwargs)


class Mock(object):
class Mock:
def __init__(
self,
mocked_obj: object,
Expand All @@ -69,8 +69,22 @@ def __init__(
self._methods_to_unstub: dict[str, Callable | None] = {}
self._signatures_store: dict[str, signature.Signature | None] = {}

self._observers: list = []

def attach(self, observer) -> None:
if observer not in self._observers:
self._observers.append(observer)

def detach(self, observer) -> None:
try:
self._observers.remove(observer)
except ValueError:
pass

def remember(self, invocation: invocation.RealInvocation) -> None:
self.invocations.append(invocation)
for observer in self._observers:
observer.update(invocation)

def finish_stubbing(
self, stubbed_invocation: invocation.StubbedInvocation
Expand Down
16 changes: 12 additions & 4 deletions mockito/mockito.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,15 @@ def _get_mock_or_raise(obj: object) -> Mock:
raise ArgumentError("obj '%s' is not registered" % obj)
return theMock

def verify(obj, times=None, atleast=None, atmost=None, between=None,
inorder=False):
def verify(
obj,
times=None,
atleast=None,
atmost=None,
between=None,
inorder=False,
_factory=None,
):
"""Central interface to verify interactions.

`verify` uses a fluent interface::
Expand Down Expand Up @@ -145,10 +152,11 @@ def verify(obj, times=None, atleast=None, atmost=None, between=None,

theMock = _get_mock_or_raise(obj)

factory = _factory or invocation.VerifiableInvocation

class Verify(object):
def __getattr__(self, method_name):
return invocation.VerifiableInvocation(
theMock, method_name, verification_fn)
return factory(theMock, method_name, verification_fn)

return Verify()

Expand Down
Loading
Loading