diff --git a/AGENTS.md b/AGENTS.md index 19b3caa8..90ccd30b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -13,7 +13,7 @@ - Install (dev stack): `uv sync` (always use uv) - Lint & format: `uv run ruff check . --fix && uv run ruff format .` -- Tests (quiet + coverage): `export MESA_FRAMES_RUNTIME_TYPECHECKING = 1 && uv run pytest -q --cov=mesa_frames --cov-report=term-missing` +- Tests (quiet + coverage): `export MESA_FRAMES_RUNTIME_TYPECHECKING=1 && uv run pytest -q --cov=mesa_frames --cov-report=term-missing` - Pre-commit (all files): `uv run pre-commit run -a` - Docs preview: `uv run mkdocs serve` @@ -36,7 +36,7 @@ Always run tools via uv: `uv run `. ## Commit & Pull Request Guidelines - Commits: Imperative mood, concise subject, meaningful body when needed. - Example: `Fix AgentsDF.sets copy binding and tests`. + Example: `Fix AgentSetRegistry.sets copy binding and tests`. - PRs: Link issues, summarize changes, note API impacts, add/adjust tests and docs. - CI hygiene: Run `ruff`, `pytest`, and `pre-commit` locally before pushing. diff --git a/README.md b/README.md index 986b9b22..6a16baad 100644 --- a/README.md +++ b/README.md @@ -88,13 +88,13 @@ pip install -e . ### Creation of an Agent -The agent implementation differs from base mesa. Agents are only defined at the AgentSet level. You can import `AgentSetPolars`. As in mesa, you subclass and make sure to call `super().__init__(model)`. You can use the `add` method or the `+=` operator to add agents to the AgentSet. Most methods mirror the functionality of `mesa.AgentSet`. Additionally, `mesa-frames.AgentSet` implements many dunder methods such as `AgentSet[mask, attr]` to get and set items intuitively. All operations are by default inplace, but if you'd like to use functional programming, mesa-frames implements a fast copy method which aims to reduce memory usage, relying on reference-only and native copy methods. +The agent implementation differs from base mesa. Agents are only defined at the AgentSet level. You can import `AgentSet`. As in mesa, you subclass and make sure to call `super().__init__(model)`. You can use the `add` method or the `+=` operator to add agents to the AgentSet. Most methods mirror the functionality of `mesa.AgentSet`. Additionally, `mesa-frames.AgentSet` implements many dunder methods such as `AgentSet[mask, attr]` to get and set items intuitively. All operations are by default inplace, but if you'd like to use functional programming, mesa-frames implements a fast copy method which aims to reduce memory usage, relying on reference-only and native copy methods. ```python -from mesa-frames import AgentSetPolars +from mesa-frames import AgentSet -class MoneyAgentPolars(AgentSetPolars): - def __init__(self, n: int, model: ModelDF): +class MoneyAgents(AgentSet): + def __init__(self, n: int, model: Model): super().__init__(model) # Adding the agents to the agent set self += pl.DataFrame( @@ -126,20 +126,20 @@ class MoneyAgentPolars(AgentSetPolars): ### Creation of the Model -Creation of the model is fairly similar to the process in mesa. You subclass `ModelDF` and call `super().__init__()`. The `model.agents` attribute has the same interface as `mesa-frames.AgentSet`. You can use `+=` or `self.agents.add` with a `mesa-frames.AgentSet` (or a list of `AgentSet`) to add agents to the model. +Creation of the model is fairly similar to the process in mesa. You subclass `Model` and call `super().__init__()`. The `model.sets` attribute has the same interface as `mesa-frames.AgentSet`. You can use `+=` or `self.sets.add` with a `mesa-frames.AgentSet` (or a list of `AgentSet`) to add agents to the model. ```python -from mesa-frames import ModelDF +from mesa-frames import Model -class MoneyModelDF(ModelDF): +class MoneyModelDF(Model): def __init__(self, N: int, agents_cls): super().__init__() self.n_agents = N - self.agents += MoneyAgentPolars(N, self) + self.sets += MoneyAgents(N, self) def step(self): - # Executes the step method for every agentset in self.agents - self.agents.do("step") + # Executes the step method for every agentset in self.sets + self.sets.do("step") def run_model(self, n): for _ in range(n): diff --git a/ROADMAP.md b/ROADMAP.md index b42b9901..7dd953f5 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -49,7 +49,7 @@ The Sugarscape example demonstrates the need for this abstraction, as multiple a #### Progress and Next Steps -- Create utility functions in `DiscreteSpaceDF` and `AgentContainer` to move agents optimally based on specified attributes +- Create utility functions in `DiscreteSpace` and `AgentSetRegistry` to move agents optimally based on specified attributes - Provide built-in resolution strategies for common concurrency scenarios - Ensure the implementation works efficiently with the vectorized approach of mesa-frames diff --git a/docs/api/conf.py b/docs/api/conf.py index 0dcdded8..43098ec2 100644 --- a/docs/api/conf.py +++ b/docs/api/conf.py @@ -4,6 +4,7 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. import sys +from datetime import datetime from pathlib import Path sys.path.insert(0, str(Path("..").resolve())) @@ -11,7 +12,7 @@ # -- Project information ----------------------------------------------------- project = "mesa-frames" author = "Project Mesa, Adam Amer" -copyright = f"2023, {author}" +copyright = f"{datetime.now().year}, {author}" # -- General configuration --------------------------------------------------- extensions = [ diff --git a/docs/api/reference/agents/index.rst b/docs/api/reference/agents/index.rst index 5d725f02..a1c03126 100644 --- a/docs/api/reference/agents/index.rst +++ b/docs/api/reference/agents/index.rst @@ -4,14 +4,14 @@ Agents .. currentmodule:: mesa_frames -.. autoclass:: AgentSetPolars +.. autoclass:: AgentSet :members: :inherited-members: :autosummary: :autosummary-nosignatures: -.. autoclass:: AgentsDF +.. autoclass:: AgentSetRegistry :members: :inherited-members: :autosummary: - :autosummary-nosignatures: \ No newline at end of file + :autosummary-nosignatures: diff --git a/docs/api/reference/model.rst b/docs/api/reference/model.rst index 0e05d8d7..099e601b 100644 --- a/docs/api/reference/model.rst +++ b/docs/api/reference/model.rst @@ -3,7 +3,7 @@ Model .. currentmodule:: mesa_frames -.. autoclass:: ModelDF +.. autoclass:: Model :members: :inherited-members: :autosummary: diff --git a/docs/api/reference/space/index.rst b/docs/api/reference/space/index.rst index e2afa319..8741b6b6 100644 --- a/docs/api/reference/space/index.rst +++ b/docs/api/reference/space/index.rst @@ -4,7 +4,7 @@ This page provides a high-level overview of possible space objects for mesa-fram .. currentmodule:: mesa_frames -.. autoclass:: GridPolars +.. autoclass:: Grid :members: :inherited-members: :autosummary: diff --git a/docs/general/index.md b/docs/general/index.md index ea3a52d7..9859d2ee 100644 --- a/docs/general/index.md +++ b/docs/general/index.md @@ -41,11 +41,11 @@ pip install -e . Here's a quick example of how to create a model using mesa-frames: ```python -from mesa_frames import AgentSetPolars, ModelDF +from mesa_frames import AgentSet, Model import polars as pl -class MoneyAgentPolars(AgentSetPolars): - def __init__(self, n: int, model: ModelDF): +class MoneyAgents(AgentSet): + def __init__(self, n: int, model: Model): super().__init__(model) self += pl.DataFrame( {"wealth": pl.ones(n, eager=True)} @@ -57,13 +57,13 @@ class MoneyAgentPolars(AgentSetPolars): def give_money(self): # ... (implementation details) -class MoneyModelDF(ModelDF): +class MoneyModel(Model): def __init__(self, N: int): super().__init__() - self.agents += MoneyAgentPolars(N, self) + self.sets += MoneyAgents(N, self) def step(self): - self.agents.do("step") + self.sets.do("step") def run_model(self, n): for _ in range(n): diff --git a/docs/general/user-guide/0_getting-started.md b/docs/general/user-guide/0_getting-started.md index b2917576..1edc1587 100644 --- a/docs/general/user-guide/0_getting-started.md +++ b/docs/general/user-guide/0_getting-started.md @@ -35,14 +35,14 @@ Here's a comparison between mesa-frames and mesa: === "mesa-frames" ```python - class MoneyAgentPolarsConcise(AgentSetPolars): + class MoneyAgents(AgentSet): # initialization... def give_money(self): # Active agents are changed to wealthy agents self.select(self.wealth > 0) # Receiving agents are sampled (only native expressions currently supported) - other_agents = self.agents.sample( + other_agents = self.sets.sample( n=len(self.active_agents), with_replacement=True ) @@ -64,7 +64,7 @@ Here's a comparison between mesa-frames and mesa: def give_money(self): # Verify agent has some wealth if self.wealth > 0: - other_agent = self.random.choice(self.model.agents) + other_agent = self.random.choice(self.model.sets) if other_agent is not None: other_agent.wealth += 1 self.wealth -= 1 @@ -84,7 +84,7 @@ If you're familiar with mesa, this guide will help you understand the key differ === "mesa-frames" ```python - class MoneyAgentSet(AgentSetPolars): + class MoneyAgents(AgentSet): def __init__(self, n, model): super().__init__(model) self += pl.DataFrame({ @@ -92,7 +92,7 @@ If you're familiar with mesa, this guide will help you understand the key differ }) def step(self): givers = self.wealth > 0 - receivers = self.agents.sample(n=len(self.active_agents)) + receivers = self.sets.sample(n=len(self.active_agents)) self[givers, "wealth"] -= 1 new_wealth = receivers.groupby("unique_id").count() self[new_wealth["unique_id"], "wealth"] += new_wealth["count"] @@ -121,13 +121,13 @@ If you're familiar with mesa, this guide will help you understand the key differ === "mesa-frames" ```python - class MoneyModel(ModelDF): + class MoneyModel(Model): def __init__(self, N): super().__init__() - self.agents += MoneyAgentSet(N, self) + self.sets += MoneyAgents(N, self) def step(self): - self.agents.do("step") + self.sets.do("step") ``` diff --git a/docs/general/user-guide/1_classes.md b/docs/general/user-guide/1_classes.md index ac696731..f85c062d 100644 --- a/docs/general/user-guide/1_classes.md +++ b/docs/general/user-guide/1_classes.md @@ -1,18 +1,18 @@ # Classes 📚 -## AgentSetDF 👥 +## AgentSet 👥 -To create your own AgentSetDF class, you need to subclass the AgentSetPolars class and make sure to call `super().__init__(model)`. +To create your own AgentSet class, you need to subclass the AgentSet class and make sure to call `super().__init__(model)`. -Typically, the next step would be to populate the class with your agents. To do that, you need to add a DataFrame to the AgentSetDF. You can do `self += agents` or `self.add(agents)`, where `agents` is a DataFrame or something that could be passed to a DataFrame constructor, like a dictionary or lists of lists. You need to make sure your DataFrame doesn't have a 'unique_id' column because IDs are generated automatically, otherwise you will get an error raised. In the DataFrame, you should also put any attribute of the agent you are using. +Typically, the next step would be to populate the class with your agents. To do that, you need to add a DataFrame to the AgentSet. You can do `self += agents` or `self.add(agents)`, where `agents` is a DataFrame or something that could be passed to a DataFrame constructor, like a dictionary or lists of lists. You need to make sure your DataFrame doesn't have a 'unique_id' column because IDs are generated automatically, otherwise you will get an error raised. In the DataFrame, you should also put any attribute of the agent you are using. How can you choose which agents should be in the same AgentSet? The idea is that you should minimize the missing values in the DataFrame (so they should have similar/same attributes) and mostly everybody should do the same actions. Example: ```python -class MoneyAgent(AgentSetPolars): - def __init__(self, n: int, model: ModelDF): +class MoneyAgents(AgentSet): + def __init__(self, n: int, model: Model): super().__init__(model) self.initial_wealth = pl.ones(n) self += pl.DataFrame({ @@ -25,28 +25,28 @@ class MoneyAgent(AgentSetPolars): You can access the underlying DataFrame where agents are stored with `self.df`. This allows you to use DataFrame methods like `self.df.sample` or `self.df.group_by("wealth")` and more. -## ModelDF 🏗️ +## Model 🏗️ -To add your AgentSetDF to your ModelDF, you should also add it to the agents with `+=` or `add`. +To add your AgentSet to your Model, you should also add it to the sets with `+=` or `add`. -NOTE: ModelDF.agents are stored in a class which is entirely similar to AgentSetDF called AgentsDF. The API of the two are the same. If you try accessing AgentsDF.df, you will get a dictionary of `[AgentSetDF, DataFrame]`. +NOTE: Model.sets are stored in a class which is entirely similar to AgentSet called AgentSetRegistry. The API of the two are the same. If you try accessing AgentSetRegistry.df, you will get a dictionary of `[AgentSet, DataFrame]`. Example: ```python -class EcosystemModel(ModelDF): +class EcosystemModel(Model): def __init__(self, n_prey, n_predators): super().__init__() - self.agents += Preys(n_prey, self) - self.agents += Predators(n_predators, self) + self.sets += Preys(n_prey, self) + self.sets += Predators(n_predators, self) def step(self): - self.agents.do("move") - self.agents.do("hunt") + self.sets.do("move") + self.sets.do("hunt") self.prey.do("reproduce") ``` -## Space: GridDF 🌐 +## Space: Grid 🌐 mesa-frames provides efficient implementations of spatial environments: @@ -55,12 +55,12 @@ mesa-frames provides efficient implementations of spatial environments: Example: ```python -class GridWorld(ModelDF): +class GridWorld(Model): def __init__(self, width, height): super().__init__() - self.space = GridPolars(self, (width, height)) - self.agents += AgentSet(100, self) - self.space.place_to_empty(self.agents) + self.space = Grid(self, (width, height)) + self.sets += AgentSet(100, self) + self.space.place_to_empty(self.sets) ``` A continuous GeoSpace, NetworkSpace, and a collection to have multiple spaces in the models are in the works! 🚧 @@ -73,13 +73,13 @@ You configure what to collect, how to store it, and when to trigger collection. Example: ```python -class ExampleModel(ModelDF): +class ExampleModel(Model): def __init__(self): super().__init__() - self.agents = MoneyAgent(self) + self.sets = MoneyAgent(self) self.datacollector = DataCollector( model=self, - model_reporters={"total_wealth": lambda m: m.agents["wealth"].sum()}, + model_reporters={"total_wealth": lambda m: lambda m: list(m.sets.df.values())[0]["wealth"].sum()}, agent_reporters={"wealth": "wealth"}, storage="csv", storage_uri="./data", @@ -87,7 +87,7 @@ class ExampleModel(ModelDF): ) def step(self): - self.agents.step() + self.sets.step() self.datacollector.conditional_collect() self.datacollector.flush() ``` diff --git a/docs/general/user-guide/2_introductory-tutorial.ipynb b/docs/general/user-guide/2_introductory-tutorial.ipynb index 24742f80..ec1165da 100644 --- a/docs/general/user-guide/2_introductory-tutorial.ipynb +++ b/docs/general/user-guide/2_introductory-tutorial.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 1, "id": "df4d8623", "metadata": {}, "outputs": [], @@ -44,19 +44,34 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "fc0ee981", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ImportError", + "evalue": "cannot import name 'Model' from partially initialized module 'mesa_frames' (most likely due to a circular import) (/home/adam/projects/mesa-frames/mesa_frames/__init__.py)", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mImportError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Model, AgentSet, DataCollector\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mMoneyModelDF\u001b[39;00m(Model):\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, N: \u001b[38;5;28mint\u001b[39m, agents_cls):\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/projects/mesa-frames/mesa_frames/__init__.py:65\u001b[39m\n\u001b[32m 63\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconcrete\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01magentset\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AgentSet\n\u001b[32m 64\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconcrete\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01magentsetregistry\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AgentSetRegistry\n\u001b[32m---> \u001b[39m\u001b[32m65\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconcrete\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatacollector\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m DataCollector\n\u001b[32m 66\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconcrete\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmodel\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Model\n\u001b[32m 67\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconcrete\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mspace\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Grid\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/projects/mesa-frames/mesa_frames/concrete/datacollector.py:62\u001b[39m\n\u001b[32m 60\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtempfile\u001b[39;00m\n\u001b[32m 61\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpsycopg2\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m62\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mabstract\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mdatacollector\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m AbstractDataCollector\n\u001b[32m 63\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Any, Literal\n\u001b[32m 64\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mcollections\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mabc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Callable\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/projects/mesa-frames/mesa_frames/abstract/datacollector.py:50\u001b[39m\n\u001b[32m 48\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Any, Literal\n\u001b[32m 49\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mcollections\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mabc\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Callable\n\u001b[32m---> \u001b[39m\u001b[32m50\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mmesa_frames\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Model\n\u001b[32m 51\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpolars\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpl\u001b[39;00m\n\u001b[32m 52\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mthreading\u001b[39;00m\n", + "\u001b[31mImportError\u001b[39m: cannot import name 'Model' from partially initialized module 'mesa_frames' (most likely due to a circular import) (/home/adam/projects/mesa-frames/mesa_frames/__init__.py)" + ] + } + ], "source": [ - "from mesa_frames import ModelDF, AgentSetPolars, DataCollector\n", + "from mesa_frames import Model, AgentSet, DataCollector\n", "\n", "\n", - "class MoneyModelDF(ModelDF):\n", + "class MoneyModel(Model):\n", " def __init__(self, N: int, agents_cls):\n", " super().__init__()\n", " self.n_agents = N\n", - " self.agents += agents_cls(N, self)\n", + " self.sets += agents_cls(N, self)\n", " self.datacollector = DataCollector(\n", " model=self,\n", " model_reporters={\"total_wealth\": lambda m: m.agents[\"wealth\"].sum()},\n", @@ -67,8 +82,8 @@ " )\n", "\n", " def step(self):\n", - " # Executes the step method for every agentset in self.agents\n", - " self.agents.do(\"step\")\n", + " # Executes the step method for every agentset in self.sets\n", + " self.sets.do(\"step\")\n", "\n", " def run_model(self, n):\n", " for _ in range(n):\n", @@ -84,12 +99,12 @@ "source": [ "## Implementing the AgentSet 👥\n", "\n", - "Now, let's implement our `MoneyAgentSet` using polars backends." + "Now, let's implement our `MoneyAgents` using polars backends." ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "2bac0126", "metadata": {}, "outputs": [], @@ -97,8 +112,8 @@ "import polars as pl\n", "\n", "\n", - "class MoneyAgentPolars(AgentSetPolars):\n", - " def __init__(self, n: int, model: ModelDF):\n", + "class MoneyAgents(AgentSet):\n", + " def __init__(self, n: int, model: Model):\n", " super().__init__(model)\n", " self += pl.DataFrame({\"wealth\": pl.ones(n, eager=True)})\n", "\n", @@ -126,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "id": "65da4e6f", "metadata": {}, "outputs": [ @@ -154,14 +169,11 @@ } ], "source": [ - "# Choose either MoneyAgentPandas or MoneyAgentPolars\n", - "agent_class = MoneyAgentPolars\n", - "\n", "# Create and run the model\n", - "model = MoneyModelDF(1000, agent_class)\n", + "model = MoneyModel(1000, MoneyAgents)\n", "model.run_model(100)\n", "\n", - "wealth_dist = list(model.agents.df.values())[0]\n", + "wealth_dist = list(model.sets.df.values())[0]\n", "\n", "# Print the final wealth distribution\n", "print(wealth_dist.select(pl.col(\"wealth\")).describe())" @@ -182,13 +194,13 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "id": "fbdb540810924de8", "metadata": {}, "outputs": [], "source": [ - "class MoneyAgentPolarsConcise(AgentSetPolars):\n", - " def __init__(self, n: int, model: ModelDF):\n", + "class MoneyAgentsConcise(AgentSet):\n", + " def __init__(self, n: int, model: Model):\n", " super().__init__(model)\n", " ## Adding the agents to the agent set\n", " # 1. Changing the df attribute directly (not recommended, if other agents were added before, they will be lost)\n", @@ -242,8 +254,8 @@ " self[new_wealth, \"wealth\"] += new_wealth[\"len\"]\n", "\n", "\n", - "class MoneyAgentPolarsNative(AgentSetPolars):\n", - " def __init__(self, n: int, model: ModelDF):\n", + "class MoneyAgentsNative(AgentSet):\n", + " def __init__(self, n: int, model: Model):\n", " super().__init__(model)\n", " self += pl.DataFrame({\"wealth\": pl.ones(n, eager=True)})\n", "\n", @@ -286,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "id": "9dbe761af964af5b", "metadata": {}, "outputs": [], @@ -294,7 +306,7 @@ "import mesa\n", "\n", "\n", - "class MoneyAgent(mesa.Agent):\n", + "class MesaMoneyAgent(mesa.Agent):\n", " \"\"\"An agent with fixed initial wealth.\"\"\"\n", "\n", " def __init__(self, model):\n", @@ -307,20 +319,20 @@ " def step(self):\n", " # Verify agent has some wealth\n", " if self.wealth > 0:\n", - " other_agent: MoneyAgent = self.model.random.choice(self.model.agents)\n", + " other_agent: MesaMoneyAgent = self.model.random.choice(self.model.agents)\n", " if other_agent is not None:\n", " other_agent.wealth += 1\n", " self.wealth -= 1\n", "\n", "\n", - "class MoneyModel(mesa.Model):\n", + "class MesaMoneyModel(mesa.Model):\n", " \"\"\"A model with some number of agents.\"\"\"\n", "\n", " def __init__(self, N: int):\n", " super().__init__()\n", " self.num_agents = N\n", " for _ in range(N):\n", - " self.agents.add(MoneyAgent(self))\n", + " self.agents.add(MesaMoneyAgent(self))\n", "\n", " def step(self):\n", " \"\"\"Advance the model by one step.\"\"\"\n", @@ -333,7 +345,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "id": "2d864cd3", "metadata": {}, "outputs": [ @@ -367,7 +379,7 @@ "import time\n", "\n", "\n", - "def run_simulation(model: MoneyModel | MoneyModelDF, n_steps: int):\n", + "def run_simulation(model: MesaMoneyModel | MoneyModel, n_steps: int):\n", " start_time = time.time()\n", " model.run_model(n_steps)\n", " end_time = time.time()\n", @@ -386,15 +398,11 @@ " print(f\"---------------\\n{implementation}:\")\n", " for n_agents in n_agents_list:\n", " if implementation == \"mesa\":\n", - " ntime = run_simulation(MoneyModel(n_agents), n_steps)\n", + " ntime = run_simulation(MesaMoneyModel(n_agents), n_steps)\n", " elif implementation == \"mesa-frames (pl concise)\":\n", - " ntime = run_simulation(\n", - " MoneyModelDF(n_agents, MoneyAgentPolarsConcise), n_steps\n", - " )\n", + " ntime = run_simulation(MoneyModel(n_agents, MoneyAgentsConcise), n_steps)\n", " elif implementation == \"mesa-frames (pl native)\":\n", - " ntime = run_simulation(\n", - " MoneyModelDF(n_agents, MoneyAgentPolarsNative), n_steps\n", - " )\n", + " ntime = run_simulation(MoneyModel(n_agents, MoneyAgentsNative), n_steps)\n", "\n", " print(f\" Number of agents: {n_agents}, Time: {ntime:.2f} seconds\")\n", " print(\"---------------\")" @@ -417,7 +425,7 @@ ], "metadata": { "kernelspec": { - "display_name": "mesa-frames", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/docs/general/user-guide/4_datacollector.ipynb b/docs/general/user-guide/4_datacollector.ipynb index 247dbf70..3fa16b49 100644 --- a/docs/general/user-guide/4_datacollector.ipynb +++ b/docs/general/user-guide/4_datacollector.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "id": "9a63283cbaf04dbcab1f6479b197f3a8", "metadata": { "editable": true @@ -43,7 +43,7 @@ "source": [ "## Minimal Example Model\n", "\n", - "We create a tiny model using the `ModelDF` and an `AgentSetPolars`-style agent container. This is just to demonstrate collection APIs.\n" + "We create a tiny model using the `Model` and an `AgentSet`-style agent container. This is just to demonstrate collection APIs.\n" ] }, { @@ -53,14 +53,54 @@ "metadata": { "editable": true }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'model': shape: (5, 5)\n", + " ┌──────┬─────────────────────────────────┬───────┬──────────────┬──────────┐\n", + " │ step ┆ seed ┆ batch ┆ total_wealth ┆ n_agents │\n", + " │ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + " │ i64 ┆ str ┆ i64 ┆ f64 ┆ i64 │\n", + " ╞══════╪═════════════════════════════════╪═══════╪══════════════╪══════════╡\n", + " │ 2 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 4 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 6 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 8 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " │ 10 ┆ 162681765859364298619846106603… ┆ 0 ┆ 1000.0 ┆ 1000 │\n", + " └──────┴─────────────────────────────────┴───────┴──────────────┴──────────┘,\n", + " 'agent': shape: (5_000, 4)\n", + " ┌────────────────────┬──────┬─────────────────────────────────┬───────┐\n", + " │ wealth_MoneyAgents ┆ step ┆ seed ┆ batch │\n", + " │ --- ┆ --- ┆ --- ┆ --- │\n", + " │ f64 ┆ i32 ┆ str ┆ i32 │\n", + " ╞════════════════════╪══════╪═════════════════════════════════╪═══════╡\n", + " │ 0.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 3.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 1.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 3.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 6.0 ┆ 2 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ … ┆ … ┆ … ┆ … │\n", + " │ 4.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 1.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " │ 0.0 ┆ 10 ┆ 162681765859364298619846106603… ┆ 0 │\n", + " └────────────────────┴──────┴─────────────────────────────────┴───────┘}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "from mesa_frames import ModelDF, AgentSetPolars, DataCollector\n", + "from mesa_frames import Model, AgentSet, DataCollector\n", "import polars as pl\n", "\n", "\n", - "class MoneyAgents(AgentSetPolars):\n", - " def __init__(self, n: int, model: ModelDF):\n", + "class MoneyAgents(AgentSet):\n", + " def __init__(self, n: int, model: Model):\n", " super().__init__(model)\n", " # one column, one unit of wealth each\n", " self += pl.DataFrame({\"wealth\": pl.ones(n, eager=True)})\n", @@ -73,28 +113,28 @@ " self[income[\"unique_id\"], \"wealth\"] += income[\"len\"]\n", "\n", "\n", - "class MoneyModel(ModelDF):\n", + "class MoneyModel(Model):\n", " def __init__(self, n: int):\n", " super().__init__()\n", - " self.agents = MoneyAgents(n, self)\n", + " self.sets.add(MoneyAgents(n, self))\n", " self.dc = DataCollector(\n", " model=self,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: m.agents[\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(m.agents),\n", + " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\", # pull existing column\n", " },\n", " storage=\"memory\", # we'll switch this per example\n", " storage_uri=None,\n", - " trigger=lambda m: m._steps % 2\n", + " trigger=lambda m: m.steps % 2\n", " == 0, # collect every 2 steps via conditional_collect\n", " reset_memory=True,\n", " )\n", "\n", " def step(self):\n", - " self.agents.do(\"step\")\n", + " self.sets.do(\"step\")\n", "\n", " def run(self, steps: int, conditional: bool = True):\n", " for _ in range(steps):\n", @@ -135,10 +175,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "id": "5f14f38c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import os\n", "\n", @@ -147,8 +198,8 @@ "model_csv.dc = DataCollector(\n", " model=model_csv,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: m.agents[\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(m.agents),\n", + " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\",\n", @@ -175,20 +226,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "8763a12b2bbd4a93a75aff182afb95dc", "metadata": { "editable": true }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "os.makedirs(\"./data_parquet\", exist_ok=True)\n", "model_parq = MoneyModel(1000)\n", "model_parq.dc = DataCollector(\n", " model=model_parq,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: m.agents[\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(m.agents),\n", + " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\",\n", @@ -217,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "7cdc8c89c7104fffa095e18ddfef8986", "metadata": { "editable": true @@ -228,8 +290,8 @@ "model_s3.dc = DataCollector(\n", " model=model_s3,\n", " model_reporters={\n", - " \"total_wealth\": lambda m: m.agents[\"wealth\"].sum(),\n", - " \"n_agents\": lambda m: len(m.agents),\n", + " \"total_wealth\": lambda m: list(m.sets.df.values())[0][\"wealth\"].sum(),\n", + " \"n_agents\": lambda m: len(list(m.sets.df.values())[0]),\n", " },\n", " agent_reporters={\n", " \"wealth\": \"wealth\",\n", @@ -257,12 +319,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "id": "938c804e27f84196a10c8828c723f798", "metadata": { "editable": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "CREATE SCHEMA IF NOT EXISTS public;\n", + "CREATE TABLE IF NOT EXISTS public.model_data (\n", + " step INTEGER,\n", + " seed VARCHAR,\n", + " total_wealth BIGINT,\n", + " n_agents INTEGER\n", + ");\n", + "\n", + "\n", + "CREATE TABLE IF NOT EXISTS public.agent_data (\n", + " step INTEGER,\n", + " seed VARCHAR,\n", + " unique_id BIGINT,\n", + " wealth BIGINT\n", + ");\n", + "\n" + ] + } + ], "source": [ "DDL_MODEL = r\"\"\"\n", "CREATE SCHEMA IF NOT EXISTS public;\n", @@ -295,7 +381,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "id": "59bbdb311c014d738909a11f9e486628", "metadata": { "editable": true @@ -324,12 +410,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "id": "8a65eabff63a45729fe45fb5ade58bdc", "metadata": { "editable": true }, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 5)
stepseedbatchtotal_wealthn_agents
i64stri64f64i64
2"732054881101029867447298951813…0100.0100
4"732054881101029867447298951813…0100.0100
6"732054881101029867447298951813…0100.0100
8"732054881101029867447298951813…0100.0100
10"732054881101029867447298951813…0100.0100
" + ], + "text/plain": [ + "shape: (5, 5)\n", + "┌──────┬─────────────────────────────────┬───────┬──────────────┬──────────┐\n", + "│ step ┆ seed ┆ batch ┆ total_wealth ┆ n_agents │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ str ┆ i64 ┆ f64 ┆ i64 │\n", + "╞══════╪═════════════════════════════════╪═══════╪══════════════╪══════════╡\n", + "│ 2 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 4 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 6 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 8 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "│ 10 ┆ 732054881101029867447298951813… ┆ 0 ┆ 100.0 ┆ 100 │\n", + "└──────┴─────────────────────────────────┴───────┴──────────────┴──────────┘" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "m = MoneyModel(100)\n", "m.dc.trigger = lambda model: model._steps % 3 == 0 # every 3rd step\n", @@ -361,13 +479,21 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "mesa-frames (3.12.3)", "language": "python", "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.x" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" } }, "nbformat": 4, diff --git a/examples/boltzmann_wealth/performance_plot.py b/examples/boltzmann_wealth/performance_plot.py index 625c6c56..e565bda3 100644 --- a/examples/boltzmann_wealth/performance_plot.py +++ b/examples/boltzmann_wealth/performance_plot.py @@ -8,16 +8,16 @@ import seaborn as sns from packaging import version -from mesa_frames import AgentSetPolars, ModelDF +from mesa_frames import AgentSet, Model ### ---------- Mesa implementation ---------- ### def mesa_implementation(n_agents: int) -> None: - model = MoneyModel(n_agents) + model = MesaMoneyModel(n_agents) model.run_model(100) -class MoneyAgent(mesa.Agent): +class MesaMoneyAgent(mesa.Agent): """An agent with fixed initial wealth.""" def __init__(self, model): @@ -36,14 +36,14 @@ def step(self): self.wealth -= 1 -class MoneyModel(mesa.Model): +class MesaMoneyModel(mesa.Model): """A model with some number of agents.""" def __init__(self, N): super().__init__() self.num_agents = N for _ in range(self.num_agents): - self.agents.add(MoneyAgent(self)) + self.agents.add(MesaMoneyAgent(self)) def step(self): """Advance the model by one step.""" @@ -55,7 +55,7 @@ def run_model(self, n_steps) -> None: """def compute_gini(model): - agent_wealths = model.agents.get("wealth") + agent_wealths = model.sets.get("wealth") x = sorted(agent_wealths) N = model.num_agents B = sum(xi * (N - i) for i, xi in enumerate(x)) / (N * sum(x)) @@ -65,12 +65,12 @@ def run_model(self, n_steps) -> None: ### ---------- Mesa-frames implementation ---------- ### -class MoneyAgentPolarsConcise(AgentSetPolars): - def __init__(self, n: int, model: ModelDF): +class MoneyAgentsConcise(AgentSet): + def __init__(self, n: int, model: Model): super().__init__(model) ## Adding the agents to the agent set # 1. Changing the agents attribute directly (not recommended, if other agents were added before, they will be lost) - """self.agents = pl.DataFrame( + """self.sets = pl.DataFrame( "wealth": pl.ones(n, eager=True)} )""" # 2. Adding the dataframe with add @@ -120,8 +120,8 @@ def give_money(self): self[new_wealth, "wealth"] += new_wealth["len"] -class MoneyAgentPolarsNative(AgentSetPolars): - def __init__(self, n: int, model: ModelDF): +class MoneyAgentsNative(AgentSet): + def __init__(self, n: int, model: Model): super().__init__(model) self += pl.DataFrame({"wealth": pl.ones(n, eager=True)}) @@ -154,15 +154,15 @@ def give_money(self): ) -class MoneyModelDF(ModelDF): +class MoneyModel(Model): def __init__(self, N: int, agents_cls): super().__init__() self.n_agents = N - self.agents += agents_cls(N, self) + self.sets += agents_cls(N, self) def step(self): - # Executes the step method for every agentset in self.agents - self.agents.do("step") + # Executes the step method for every agentset in self.sets + self.sets.do("step") def run_model(self, n): for _ in range(n): @@ -170,12 +170,12 @@ def run_model(self, n): def mesa_frames_polars_concise(n_agents: int) -> None: - model = MoneyModelDF(n_agents, MoneyAgentPolarsConcise) + model = MoneyModel(n_agents, MoneyAgentsConcise) model.run_model(100) def mesa_frames_polars_native(n_agents: int) -> None: - model = MoneyModelDF(n_agents, MoneyAgentPolarsNative) + model = MoneyModel(n_agents, MoneyAgentsNative) model.run_model(100) diff --git a/examples/sugarscape_ig/ss_polars/agents.py b/examples/sugarscape_ig/ss_polars/agents.py index 2d921761..b0ecbe90 100644 --- a/examples/sugarscape_ig/ss_polars/agents.py +++ b/examples/sugarscape_ig/ss_polars/agents.py @@ -4,13 +4,13 @@ import polars as pl from numba import b1, guvectorize, int32 -from mesa_frames import AgentSetPolars, ModelDF +from mesa_frames import AgentSet, Model -class AntPolarsBase(AgentSetPolars): +class AntDFBase(AgentSet): def __init__( self, - model: ModelDF, + model: Model, n_agents: int, initial_sugar: np.ndarray | None = None, metabolism: np.ndarray | None = None, @@ -169,7 +169,7 @@ def get_best_moves(self, neighborhood: pl.DataFrame) -> pl.DataFrame: raise NotImplementedError("Subclasses must implement this method") -class AntPolarsLoopDF(AntPolarsBase): +class AntPolarsLoopDF(AntDFBase): def get_best_moves(self, neighborhood: pl.DataFrame): best_moves = pl.DataFrame() @@ -224,7 +224,7 @@ def get_best_moves(self, neighborhood: pl.DataFrame): return best_moves.sort("agent_order").select(["dim_0", "dim_1"]) -class AntPolarsLoop(AntPolarsBase): +class AntPolarsLoop(AntDFBase): numba_target = None def get_best_moves(self, neighborhood: pl.DataFrame): diff --git a/examples/sugarscape_ig/ss_polars/model.py b/examples/sugarscape_ig/ss_polars/model.py index be9768c1..56a3a83b 100644 --- a/examples/sugarscape_ig/ss_polars/model.py +++ b/examples/sugarscape_ig/ss_polars/model.py @@ -1,15 +1,15 @@ import numpy as np import polars as pl -from mesa_frames import GridPolars, ModelDF +from mesa_frames import Grid, Model -from .agents import AntPolarsBase +from .agents import AntDFBase -class SugarscapePolars(ModelDF): +class SugarscapePolars(Model): def __init__( self, - agent_type: type[AntPolarsBase], + agent_type: type[AntDFBase], n_agents: int, sugar_grid: np.ndarray | None = None, initial_sugar: np.ndarray | None = None, @@ -24,7 +24,7 @@ def __init__( if sugar_grid is None: sugar_grid = self.random.integers(0, 4, (width, height)) grid_dimensions = sugar_grid.shape - self.space = GridPolars( + self.space = Grid( self, grid_dimensions, neighborhood_type="von_neumann", capacity=1 ) dim_0 = pl.Series("dim_0", pl.arange(grid_dimensions[0], eager=True)).to_frame() @@ -33,15 +33,15 @@ def __init__( sugar=sugar_grid.flatten(), max_sugar=sugar_grid.flatten() ) self.space.set_cells(sugar_grid) - self.agents += agent_type(self, n_agents, initial_sugar, metabolism, vision) + self.sets += agent_type(self, n_agents, initial_sugar, metabolism, vision) if initial_positions is not None: - self.space.place_agents(self.agents, initial_positions) + self.space.place_agents(self.sets, initial_positions) else: - self.space.place_to_empty(self.agents) + self.space.place_to_empty(self.sets) def run_model(self, steps: int) -> list[int]: for _ in range(steps): - if len(self.agents) == 0: + if len(list(self.sets.df.values())[0]) == 0: return empty_cells = self.space.empty_cells full_cells = self.space.full_cells diff --git a/mesa_frames/__init__.py b/mesa_frames/__init__.py index d47087d1..20fcbeef 100644 --- a/mesa_frames/__init__.py +++ b/mesa_frames/__init__.py @@ -11,25 +11,25 @@ - Provides similar syntax to Mesa for ease of transition - Allows for vectorized functions when simultaneous activation of agents is possible - Implements SIMD processing for optimized simultaneous operations -- Includes GridDF for efficient grid-based spatial modeling +- Includes Grid for efficient grid-based spatial modeling Main Components: -- AgentSetPolars: Agent set implementation using Polars backend -- ModelDF: Base model class for mesa-frames -- GridDF: Grid space implementation for spatial modeling +- AgentSet: Agent set implementation using Polars backend +- Model: Base model class for mesa-frames +- Grid: Grid space implementation for spatial modeling Usage: To use mesa-frames, import the necessary components and subclass them as needed: - from mesa_frames import AgentSetPolars, ModelDF, GridDF + from mesa_frames import AgentSet, Model, Grid - class MyAgent(AgentSetPolars): + class MyAgent(AgentSet): # Define your agent logic here - class MyModel(ModelDF): + class MyModel(Model): def __init__(self, width, height): super().__init__() - self.grid = GridDF(width, height, self) + self.grid = Grid(self, [width, height]) # Define your model logic here Note: mesa-frames is in early development. API and usage patterns may change. @@ -60,12 +60,14 @@ def __init__(self, width, height): stacklevel=2, ) -from mesa_frames.concrete.agents import AgentsDF -from mesa_frames.concrete.agentset import AgentSetPolars -from mesa_frames.concrete.model import ModelDF -from mesa_frames.concrete.space import GridPolars +from mesa_frames.concrete.agentset import AgentSet +from mesa_frames.concrete.agentsetregistry import AgentSetRegistry +from mesa_frames.concrete.model import Model + +# DataCollector has to be imported after Model or a circular import error will occur from mesa_frames.concrete.datacollector import DataCollector +from mesa_frames.concrete.space import Grid -__all__ = ["AgentsDF", "AgentSetPolars", "ModelDF", "GridPolars", "DataCollector"] +__all__ = ["AgentSetRegistry", "AgentSet", "Model", "Grid", "DataCollector"] __version__ = "0.1.1.dev0" diff --git a/mesa_frames/abstract/__init__.py b/mesa_frames/abstract/__init__.py index b61914db..bfa358d0 100644 --- a/mesa_frames/abstract/__init__.py +++ b/mesa_frames/abstract/__init__.py @@ -6,17 +6,17 @@ Classes: agents.py: - - AgentContainer: Abstract base class for agent containers. - - AgentSetDF: Abstract base class for agent sets using DataFrames. + - AbstractAgentSetRegistry: Abstract base class for agent containers. + - AbstractAgentSet: Abstract base class for agent sets using DataFrames. mixin.py: - CopyMixin: Mixin class providing fast copy functionality. - DataFrameMixin: Mixin class defining the interface for DataFrame operations. space.py: - - SpaceDF: Abstract base class for all space classes. - - DiscreteSpaceDF: Abstract base class for discrete space classes (Grids and Networks). - - GridDF: Abstract base class for grid classes. + - AbstractSpace: Abstract base class for all space classes. + - AbstractDiscreteSpace: Abstract base class for discrete space classes (Grids and Networks). + - AbstractGrid: Abstract base class for grid classes. These abstract classes and mixins provide the foundation for the concrete implementations in mesa-frames, ensuring consistent interfaces and shared @@ -28,9 +28,9 @@ For example: - from mesa_frames.abstract import AgentSetDF, DataFrameMixin + from mesa_frames.abstract import AbstractAgentSet, DataFrameMixin - class ConcreteAgentSet(AgentSetDF): + class ConcreteAgentSet(AbstractAgentSet): # Implement abstract methods here ... diff --git a/mesa_frames/abstract/agents.py b/mesa_frames/abstract/agents.py deleted file mode 100644 index f4243558..00000000 --- a/mesa_frames/abstract/agents.py +++ /dev/null @@ -1,1139 +0,0 @@ -""" -Abstract base classes for agent containers in mesa-frames. - -This module defines the core abstractions for agent containers in the mesa-frames -extension. It provides the foundation for implementing agent storage and -manipulation using DataFrame-based approaches. - -Classes: - AgentContainer(CopyMixin): - An abstract base class that defines the common interface for all agent - containers in mesa-frames. It inherits from CopyMixin to provide fast - copying functionality. - - AgentSetDF(AgentContainer, DataFrameMixin): - An abstract base class for agent sets that use DataFrames as the underlying - storage mechanism. It inherits from both AgentContainer and DataFrameMixin - to combine agent container functionality with DataFrame operations. - -These abstract classes are designed to be subclassed by concrete implementations -that use Polars library as their backend. - -Usage: - These classes should not be instantiated directly. Instead, they should be - subclassed to create concrete implementations: - - from mesa_frames.abstract.agents import AgentSetDF - - class AgentSetPolars(AgentSetDF): - def __init__(self, model): - super().__init__(model) - # Implementation using polars DataFrame - ... - - # Implement other abstract methods - -Note: - The abstract methods in these classes use Python's @abstractmethod decorator, - ensuring that concrete subclasses must implement these methods. - -Attributes and methods of each class are documented in their respective docstrings. -""" - -from __future__ import annotations # PEP 563: postponed evaluation of type annotations - -from abc import abstractmethod -from collections.abc import Callable, Collection, Iterable, Iterator, Sequence -from contextlib import suppress -from typing import Any, Literal, Self, overload - -from numpy.random import Generator - -from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin -from mesa_frames.types_ import ( - AgentMask, - BoolSeries, - DataFrame, - DataFrameInput, - IdsLike, - Index, - Series, -) - - -class AgentContainer(CopyMixin): - """An abstract class for containing agents. Defines the common interface for AgentSetDF and AgentsDF.""" - - _copy_only_reference: list[str] = [ - "_model", - ] - _model: mesa_frames.concrete.model.ModelDF - - @abstractmethod - def __init__(self) -> None: ... - - def discard( - self, - agents: IdsLike - | AgentMask - | mesa_frames.concrete.agents.AgentSetDF - | Collection[mesa_frames.concrete.agents.AgentSetDF], - inplace: bool = True, - ) -> Self: - """Remove agents from the AgentContainer. Does not raise an error if the agent is not found. - - Parameters - ---------- - agents : IdsLike | AgentMask | mesa_frames.concrete.agents.AgentSetDF | Collection[mesa_frames.concrete.agents.AgentSetDF] - The agents to remove - inplace : bool - Whether to remove the agent in place. Defaults to True. - - Returns - ------- - Self - The updated AgentContainer. - """ - with suppress(KeyError, ValueError): - return self.remove(agents, inplace=inplace) - return self._get_obj(inplace) - - @abstractmethod - def add( - self, - agents: DataFrame - | DataFrameInput - | mesa_frames.concrete.agents.AgentSetDF - | Collection[mesa_frames.concrete.agents.AgentSetDF], - inplace: bool = True, - ) -> Self: - """Add agents to the AgentContainer. - - Parameters - ---------- - agents : DataFrame | DataFrameInput | mesa_frames.concrete.agents.AgentSetDF | Collection[mesa_frames.concrete.agents.AgentSetDF] - The agents to add. - inplace : bool - Whether to add the agents in place. Defaults to True. - - Returns - ------- - Self - The updated AgentContainer. - """ - ... - - @overload - @abstractmethod - def contains(self, agents: int) -> bool: ... - - @overload - @abstractmethod - def contains( - self, agents: mesa_frames.concrete.agents.AgentSetDF | IdsLike - ) -> BoolSeries: ... - - @abstractmethod - def contains( - self, agents: mesa_frames.concrete.agents.AgentSetDF | IdsLike - ) -> bool | BoolSeries: - """Check if agents with the specified IDs are in the AgentContainer. - - Parameters - ---------- - agents : mesa_frames.concrete.agents.AgentSetDF | IdsLike - The ID(s) to check for. - - Returns - ------- - bool | BoolSeries - True if the agent is in the AgentContainer, False otherwise. - """ - - @overload - @abstractmethod - def do( - self, - method_name: str, - *args: Any, - mask: AgentMask | None = None, - return_results: Literal[False] = False, - inplace: bool = True, - **kwargs: Any, - ) -> Self: ... - - @overload - @abstractmethod - def do( - self, - method_name: str, - *args: Any, - mask: AgentMask | None = None, - return_results: Literal[True], - inplace: bool = True, - **kwargs: Any, - ) -> Any | dict[mesa_frames.concrete.agents.AgentSetDF, Any]: ... - - @abstractmethod - def do( - self, - method_name: str, - *args: Any, - mask: AgentMask | None = None, - return_results: bool = False, - inplace: bool = True, - **kwargs: Any, - ) -> Self | Any | dict[mesa_frames.concrete.agents.AgentSetDF, Any]: - """Invoke a method on the AgentContainer. - - Parameters - ---------- - method_name : str - The name of the method to invoke. - *args : Any - Positional arguments to pass to the method - mask : AgentMask | None, optional - The subset of agents on which to apply the method - return_results : bool, optional - Whether to return the result of the method, by default False - inplace : bool, optional - Whether the operation should be done inplace, by default False - **kwargs : Any - Keyword arguments to pass to the method - - Returns - ------- - Self | Any | dict[mesa_frames.concrete.agents.AgentSetDF, Any] - The updated AgentContainer or the result of the method. - """ - ... - - @abstractmethod - @overload - def get(self, attr_names: str) -> Series | dict[str, Series]: ... - - @abstractmethod - @overload - def get( - self, attr_names: Collection[str] | None = None - ) -> DataFrame | dict[str, DataFrame]: ... - - @abstractmethod - def get( - self, - attr_names: str | Collection[str] | None = None, - mask: AgentMask | None = None, - ) -> Series | dict[str, Series] | DataFrame | dict[str, DataFrame]: - """Retrieve the value of a specified attribute for each agent in the AgentContainer. - - Parameters - ---------- - attr_names : str | Collection[str] | None, optional - The attributes to retrieve. If None, all attributes are retrieved. Defaults to None. - mask : AgentMask | None, optional - The AgentMask of agents to retrieve the attribute for. If None, attributes of all agents are returned. Defaults to None. - - Returns - ------- - Series | dict[str, Series] | DataFrame | dict[str, DataFrame] - The attribute values. - """ - ... - - @abstractmethod - def remove( - self, - agents: ( - IdsLike - | AgentMask - | mesa_frames.concrete.agents.AgentSetDF - | Collection[mesa_frames.concrete.agents.AgentSetDF] - ), - inplace: bool = True, - ) -> Self: - """Remove the agents from the AgentContainer. - - Parameters - ---------- - agents : IdsLike | AgentMask | mesa_frames.concrete.agents.AgentSetDF | Collection[mesa_frames.concrete.agents.AgentSetDF] - The agents to remove. - inplace : bool, optional - Whether to remove the agent in place. - - Returns - ------- - Self - The updated AgentContainer. - """ - ... - - @abstractmethod - def select( - self, - mask: AgentMask | None = None, - filter_func: Callable[[Self], AgentMask] | None = None, - n: int | None = None, - negate: bool = False, - inplace: bool = True, - ) -> Self: - """Select agents in the AgentContainer based on the given criteria. - - Parameters - ---------- - mask : AgentMask | None, optional - The AgentMask of agents to be selected, by default None - filter_func : Callable[[Self], AgentMask] | None, optional - A function which takes as input the AgentContainer and returns a AgentMask, by default None - n : int | None, optional - The maximum number of agents to be selected, by default None - negate : bool, optional - If the selection should be negated, by default False - inplace : bool, optional - If the operation should be performed on the same object, by default True - - Returns - ------- - Self - A new or updated AgentContainer. - """ - ... - - @abstractmethod - @overload - def set( - self, - attr_names: dict[str, Any], - values: None, - mask: AgentMask | None = None, - inplace: bool = True, - ) -> Self: ... - - @abstractmethod - @overload - def set( - self, - attr_names: str | Collection[str], - values: Any, - mask: AgentMask | None = None, - inplace: bool = True, - ) -> Self: ... - - @abstractmethod - def set( - self, - attr_names: DataFrameInput | str | Collection[str], - values: Any | None = None, - mask: AgentMask | None = None, - inplace: bool = True, - ) -> Self: - """Set the value of a specified attribute or attributes for each agent in the mask in AgentContainer. - - Parameters - ---------- - attr_names : DataFrameInput | str | Collection[str] - The key can be: - - A string: sets the specified column of the agents in the AgentContainer. - - A collection of strings: sets the specified columns of the agents in the AgentContainer. - - A dictionary: keys should be attributes and values should be the values to set. Value should be None. - values : Any | None - The value to set the attribute to. If None, attr_names must be a dictionary. - mask : AgentMask | None - The AgentMask of agents to set the attribute for. - inplace : bool - Whether to set the attribute in place. - - Returns - ------- - Self - The updated agent set. - """ - ... - - @abstractmethod - def shuffle(self, inplace: bool = False) -> Self: - """Shuffles the order of agents in the AgentContainer. - - Parameters - ---------- - inplace : bool - Whether to shuffle the agents in place. - - Returns - ------- - Self - A new or updated AgentContainer. - """ - - @abstractmethod - def sort( - self, - by: str | Sequence[str], - ascending: bool | Sequence[bool] = True, - inplace: bool = True, - **kwargs, - ) -> Self: - """ - Sorts the agents in the agent set based on the given criteria. - - Parameters - ---------- - by : str | Sequence[str] - The attribute(s) to sort by. - ascending : bool | Sequence[bool] - Whether to sort in ascending order. - inplace : bool - Whether to sort the agents in place. - **kwargs - Keyword arguments to pass to the sort - - Returns - ------- - Self - A new or updated AgentContainer. - """ - - def __add__( - self, - other: DataFrame - | DataFrameInput - | mesa_frames.concrete.agents.AgentSetDF - | Collection[mesa_frames.concrete.agents.AgentSetDF], - ) -> Self: - """Add agents to a new AgentContainer through the + operator. - - Parameters - ---------- - other : DataFrame | DataFrameInput | mesa_frames.concrete.agents.AgentSetDF | Collection[mesa_frames.concrete.agents.AgentSetDF] - The agents to add. - - Returns - ------- - Self - A new AgentContainer with the added agents. - """ - return self.add(agents=other, inplace=False) - - def __contains__(self, agents: int | AgentSetDF) -> bool: - """Check if an agent is in the AgentContainer. - - Parameters - ---------- - agents : int | AgentSetDF - The ID(s) or AgentSetDF to check for. - - Returns - ------- - bool - True if the agent is in the AgentContainer, False otherwise. - """ - return self.contains(agents=agents) - - @overload - def __getitem__( - self, key: str | tuple[AgentMask, str] - ) -> Series | dict[AgentSetDF, Series]: ... - - @overload - def __getitem__( - self, - key: AgentMask | Collection[str] | tuple[AgentMask, Collection[str]], - ) -> DataFrame | dict[AgentSetDF, DataFrame]: ... - - def __getitem__( - self, - key: ( - str - | Collection[str] - | AgentMask - | tuple[AgentMask, str] - | tuple[AgentMask, Collection[str]] - | tuple[dict[AgentSetDF, AgentMask], str] - | tuple[dict[AgentSetDF, AgentMask], Collection[str]] - ), - ) -> Series | DataFrame | dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]: - """Implement the [] operator for the AgentContainer. - - The key can be: - - An attribute or collection of attributes (eg. AgentContainer["str"], AgentContainer[["str1", "str2"]]): returns the specified column(s) of the agents in the AgentContainer. - - An AgentMask (eg. AgentContainer[AgentMask]): returns the agents in the AgentContainer that satisfy the AgentMask. - - A tuple (eg. AgentContainer[AgentMask, "str"]): returns the specified column of the agents in the AgentContainer that satisfy the AgentMask. - - A tuple with a dictionary (eg. AgentContainer[{AgentSetDF: AgentMask}, "str"]): returns the specified column of the agents in the AgentContainer that satisfy the AgentMask from the dictionary. - - A tuple with a dictionary (eg. AgentContainer[{AgentSetDF: AgentMask}, Collection[str]]): returns the specified columns of the agents in the AgentContainer that satisfy the AgentMask from the dictionary. - - Parameters - ---------- - key : str | Collection[str] | AgentMask | tuple[AgentMask, str] | tuple[AgentMask, Collection[str]] | tuple[dict[AgentSetDF, AgentMask], str] | tuple[dict[AgentSetDF, AgentMask], Collection[str]] - The key to retrieve. - - Returns - ------- - Series | DataFrame | dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame] - The attribute values. - """ - # TODO: fix types - if isinstance(key, tuple): - return self.get(mask=key[0], attr_names=key[1]) - else: - if isinstance(key, str) or ( - isinstance(key, Collection) and all(isinstance(k, str) for k in key) - ): - return self.get(attr_names=key) - else: - return self.get(mask=key) - - def __iadd__( - self, - other: ( - DataFrame - | DataFrameInput - | mesa_frames.concrete.agents.AgentSetDF - | Collection[mesa_frames.concrete.agents.AgentSetDF] - ), - ) -> Self: - """Add agents to the AgentContainer through the += operator. - - Parameters - ---------- - other : DataFrame | DataFrameInput | mesa_frames.concrete.agents.AgentSetDF | Collection[mesa_frames.concrete.agents.AgentSetDF] - The agents to add. - - Returns - ------- - Self - The updated AgentContainer. - """ - return self.add(agents=other, inplace=True) - - def __isub__( - self, - other: ( - IdsLike - | AgentMask - | mesa_frames.concrete.agents.AgentSetDF - | Collection[mesa_frames.concrete.agents.AgentSetDF] - ), - ) -> Self: - """Remove agents from the AgentContainer through the -= operator. - - Parameters - ---------- - other : IdsLike | AgentMask | mesa_frames.concrete.agents.AgentSetDF | Collection[mesa_frames.concrete.agents.AgentSetDF] - The agents to remove. - - Returns - ------- - Self - The updated AgentContainer. - """ - return self.discard(other, inplace=True) - - def __sub__( - self, - other: ( - IdsLike - | AgentMask - | mesa_frames.concrete.agents.AgentSetDF - | Collection[mesa_frames.concrete.agents.AgentSetDF] - ), - ) -> Self: - """Remove agents from a new AgentContainer through the - operator. - - Parameters - ---------- - other : IdsLike | AgentMask | mesa_frames.concrete.agents.AgentSetDF | Collection[mesa_frames.concrete.agents.AgentSetDF] - The agents to remove. - - Returns - ------- - Self - A new AgentContainer with the removed agents. - """ - return self.discard(other, inplace=False) - - def __setitem__( - self, - key: ( - str - | Collection[str] - | AgentMask - | tuple[AgentMask, str | Collection[str]] - | tuple[dict[AgentSetDF, AgentMask], str] - | tuple[dict[AgentSetDF, AgentMask], Collection[str]] - ), - values: Any, - ) -> None: - """Implement the [] operator for setting values in the AgentContainer. - - The key can be: - - A string (eg. AgentContainer["str"]): sets the specified column of the agents in the AgentContainer. - - A list of strings(eg. AgentContainer[["str1", "str2"]]): sets the specified columns of the agents in the AgentContainer. - - A tuple (eg. AgentContainer[AgentMask, "str"]): sets the specified column of the agents in the AgentContainer that satisfy the AgentMask. - - A AgentMask (eg. AgentContainer[AgentMask]): sets the attributes of the agents in the AgentContainer that satisfy the AgentMask. - - A tuple with a dictionary (eg. AgentContainer[{AgentSetDF: AgentMask}, "str"]): sets the specified column of the agents in the AgentContainer that satisfy the AgentMask from the dictionary. - - A tuple with a dictionary (eg. AgentContainer[{AgentSetDF: AgentMask}, Collection[str]]): sets the specified columns of the agents in the AgentContainer that satisfy the AgentMask from the dictionary. - - Parameters - ---------- - key : str | Collection[str] | AgentMask | tuple[AgentMask, str | Collection[str]] | tuple[dict[AgentSetDF, AgentMask], str] | tuple[dict[AgentSetDF, AgentMask], Collection[str]] - The key to set. - values : Any - The values to set for the specified key. - """ - # TODO: fix types as in __getitem__ - if isinstance(key, tuple): - self.set(mask=key[0], attr_names=key[1], values=values) - else: - if isinstance(key, str) or ( - isinstance(key, Collection) and all(isinstance(k, str) for k in key) - ): - try: - self.set(attr_names=key, values=values) - except KeyError: # key=AgentMask - self.set(attr_names=None, mask=key, values=values) - else: - self.set(attr_names=None, mask=key, values=values) - - @abstractmethod - def __getattr__(self, name: str) -> Any | dict[str, Any]: - """Fallback for retrieving attributes of the AgentContainer. Retrieve an attribute of the underlying DataFrame(s). - - Parameters - ---------- - name : str - The name of the attribute to retrieve. - - Returns - ------- - Any | dict[str, Any] - The attribute value - """ - - @abstractmethod - def __iter__(self) -> Iterator[dict[str, Any]]: - """Iterate over the agents in the AgentContainer. - - Returns - ------- - Iterator[dict[str, Any]] - An iterator over the agents. - """ - ... - - @abstractmethod - def __len__(self) -> int: - """Get the number of agents in the AgentContainer. - - Returns - ------- - int - The number of agents in the AgentContainer. - """ - ... - - @abstractmethod - def __repr__(self) -> str: - """Get a string representation of the DataFrame in the AgentContainer. - - Returns - ------- - str - A string representation of the DataFrame in the AgentContainer. - """ - pass - - @abstractmethod - def __reversed__(self) -> Iterator: - """Iterate over the agents in the AgentContainer in reverse order. - - Returns - ------- - Iterator - An iterator over the agents in reverse order. - """ - ... - - @abstractmethod - def __str__(self) -> str: - """Get a string representation of the agents in the AgentContainer. - - Returns - ------- - str - A string representation of the agents in the AgentContainer. - """ - ... - - @property - def model(self) -> mesa_frames.concrete.model.ModelDF: - """The model that the AgentContainer belongs to. - - Returns - ------- - mesa_frames.concrete.model.ModelDF - """ - return self._model - - @property - def random(self) -> Generator: - """The random number generator of the model. - - Returns - ------- - Generator - """ - return self.model.random - - @property - def space(self) -> mesa_frames.abstract.space.SpaceDF | None: - """The space of the model. - - Returns - ------- - mesa_frames.abstract.space.SpaceDF | None - """ - return self.model.space - - @property - @abstractmethod - def df(self) -> DataFrame | dict[str, DataFrame]: - """The agents in the AgentContainer. - - Returns - ------- - DataFrame | dict[str, DataFrame] - """ - - @df.setter - @abstractmethod - def df( - self, agents: DataFrame | list[mesa_frames.concrete.agents.AgentSetDF] - ) -> None: - """Set the agents in the AgentContainer. - - Parameters - ---------- - agents : DataFrame | list[mesa_frames.concrete.agents.AgentSetDF] - """ - - @property - @abstractmethod - def active_agents(self) -> DataFrame | dict[str, DataFrame]: - """The active agents in the AgentContainer. - - Returns - ------- - DataFrame | dict[str, DataFrame] - """ - - @active_agents.setter - @abstractmethod - def active_agents( - self, - mask: AgentMask, - ) -> None: - """Set the active agents in the AgentContainer. - - Parameters - ---------- - mask : AgentMask - The mask to apply. - """ - self.select(mask=mask, inplace=True) - - @property - @abstractmethod - def inactive_agents( - self, - ) -> DataFrame | dict[mesa_frames.concrete.agents.AgentSetDF, DataFrame]: - """The inactive agents in the AgentContainer. - - Returns - ------- - DataFrame | dict[mesa_frames.concrete.agents.AgentSetDF, DataFrame] - """ - - @property - @abstractmethod - def index( - self, - ) -> Index | dict[mesa_frames.concrete.agents.AgentSetDF, Index]: - """The ids in the AgentContainer. - - Returns - ------- - Index | dict[mesa_frames.concrete.agents.AgentSetDF, Index] - """ - ... - - @property - @abstractmethod - def pos( - self, - ) -> DataFrame | dict[mesa_frames.concrete.agents.AgentSetDF, DataFrame]: - """The position of the agents in the AgentContainer. - - Returns - ------- - DataFrame | dict[mesa_frames.concrete.agents.AgentSetDF, DataFrame] - """ - ... - - -class AgentSetDF(AgentContainer, DataFrameMixin): - """The AgentSetDF class is a container for agents of the same type. - - Parameters - ---------- - model : mesa_frames.concrete.model.ModelDF - The model that the agent set belongs to. - """ - - _df: DataFrame # The agents in the AgentSetDF - _mask: ( - AgentMask # The underlying mask used for the active agents in the AgentSetDF. - ) - _model: ( - mesa_frames.concrete.model.ModelDF - ) # The model that the AgentSetDF belongs to. - - @abstractmethod - def __init__(self, model: mesa_frames.concrete.model.ModelDF) -> None: ... - - @abstractmethod - def add( - self, - agents: DataFrame | DataFrameInput, - inplace: bool = True, - ) -> Self: - """Add agents to the AgentSetDF. - - Agents can be the input to the DataFrame constructor. So, the input can be: - - A DataFrame: adds the agents from the DataFrame. - - A DataFrameInput: passes the input to the DataFrame constructor. - - Parameters - ---------- - agents : DataFrame | DataFrameInput - The agents to add. - inplace : bool, optional - If True, perform the operation in place, by default True - - Returns - ------- - Self - A new AgentContainer with the added agents. - """ - ... - - def discard(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self: - """Remove an agent from the AgentSetDF. Does not raise an error if the agent is not found. - - Parameters - ---------- - agents : IdsLike | AgentMask - The ids to remove - inplace : bool, optional - Whether to remove the agent in place, by default True - - Returns - ------- - Self - The updated AgentSetDF. - """ - return super().discard(agents, inplace) - - @overload - def do( - self, - method_name: str, - *args, - mask: AgentMask | None = None, - return_results: Literal[False] = False, - inplace: bool = True, - **kwargs, - ) -> Self: ... - - @overload - def do( - self, - method_name: str, - *args, - mask: AgentMask | None = None, - return_results: Literal[True], - inplace: bool = True, - **kwargs, - ) -> Any: ... - - def do( - self, - method_name: str, - *args, - mask: AgentMask | None = None, - return_results: bool = False, - inplace: bool = True, - **kwargs, - ) -> Self | Any: - masked_df = self._get_masked_df(mask) - # If the mask is empty, we can use the object as is - if len(masked_df) == len(self._df): - obj = self._get_obj(inplace) - method = getattr(obj, method_name) - result = method(*args, **kwargs) - else: # If the mask is not empty, we need to create a new masked AgentSetDF and concatenate the AgentSetDFs at the end - obj = self._get_obj(inplace=False) - obj._df = masked_df - original_masked_index = obj._get_obj_copy(obj.index) - method = getattr(obj, method_name) - result = method(*args, **kwargs) - obj._concatenate_agentsets( - [self], - duplicates_allowed=True, - keep_first_only=True, - original_masked_index=original_masked_index, - ) - if inplace: - for key, value in obj.__dict__.items(): - setattr(self, key, value) - obj = self - if return_results: - return result - else: - return obj - - @abstractmethod - @overload - def get( - self, - attr_names: str, - mask: AgentMask | None = None, - ) -> Series: ... - - @abstractmethod - @overload - def get( - self, - attr_names: Collection[str] | None = None, - mask: AgentMask | None = None, - ) -> DataFrame: ... - - @abstractmethod - def get( - self, - attr_names: str | Collection[str] | None = None, - mask: AgentMask | None = None, - ) -> Series | DataFrame: ... - - @abstractmethod - def step(self) -> None: - """Run a single step of the AgentSetDF. This method should be overridden by subclasses.""" - ... - - def remove(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self: - if isinstance(agents, str) and agents == "active": - agents = self.active_agents - if agents is None or (isinstance(agents, Iterable) and len(agents) == 0): - return self._get_obj(inplace) - agents = self._df_index(self._get_masked_df(agents), "unique_id") - agentsdf = self.model.agents.remove(agents, inplace=inplace) - # TODO: Refactor AgentsDF to return dict[str, AgentSetDF] instead of dict[AgentSetDF, DataFrame] - # And assign a name to AgentSetDF? This has to be replaced by a nicer API of AgentsDF - for agentset in agentsdf.df.keys(): - if isinstance(agentset, self.__class__): - return agentset - return self - - @abstractmethod - def _concatenate_agentsets( - self, - objs: Iterable[Self], - duplicates_allowed: bool = True, - keep_first_only: bool = True, - original_masked_index: Index | None = None, - ) -> Self: ... - - @abstractmethod - def _get_bool_mask(self, mask: AgentMask) -> BoolSeries: - """Get the equivalent boolean mask based on the input mask. - - Parameters - ---------- - mask : AgentMask - - Returns - ------- - BoolSeries - """ - ... - - @abstractmethod - def _get_masked_df(self, mask: AgentMask) -> DataFrame: - """Get the df filtered by the input mask. - - Parameters - ---------- - mask : AgentMask - - Returns - ------- - DataFrame - """ - - @overload - @abstractmethod - def _get_obj_copy(self, obj: DataFrame) -> DataFrame: ... - - @overload - @abstractmethod - def _get_obj_copy(self, obj: Series) -> Series: ... - - @overload - @abstractmethod - def _get_obj_copy(self, obj: Index) -> Index: ... - - @abstractmethod - def _get_obj_copy( - self, obj: DataFrame | Series | Index - ) -> DataFrame | Series | Index: ... - - @abstractmethod - def _discard(self, ids: IdsLike) -> Self: - """Remove an agent from the DataFrame of the AgentSetDF. Gets called by self.model.agents.remove and self.model.agents.discard. - - Parameters - ---------- - ids : IdsLike - - The ids to remove - - Returns - ------- - Self - """ - ... - - @abstractmethod - def _update_mask( - self, original_active_indices: Index, new_active_indices: Index | None = None - ) -> None: ... - - def __add__(self, other: DataFrame | DataFrameInput) -> Self: - """Add agents to a new AgentSetDF through the + operator. - - Other can be: - - A DataFrame: adds the agents from the DataFrame. - - A DataFrameInput: passes the input to the DataFrame constructor. - - Parameters - ---------- - other : DataFrame | DataFrameInput - The agents to add. - - Returns - ------- - Self - A new AgentContainer with the added agents. - """ - return super().__add__(other) - - def __iadd__(self, other: DataFrame | DataFrameInput) -> Self: - """ - Add agents to the AgentSetDF through the += operator. - - Other can be: - - A DataFrame: adds the agents from the DataFrame. - - A DataFrameInput: passes the input to the DataFrame constructor. - - Parameters - ---------- - other : DataFrame | DataFrameInput - The agents to add. - - Returns - ------- - Self - The updated AgentContainer. - """ - return super().__iadd__(other) - - @abstractmethod - def __getattr__(self, name: str) -> Any: - if __debug__: # Only execute in non-optimized mode - if name == "_df": - raise AttributeError( - "The _df attribute is not set. You probably forgot to call super().__init__ in the __init__ method." - ) - - @overload - def __getitem__(self, key: str | tuple[AgentMask, str]) -> Series | DataFrame: ... - - @overload - def __getitem__( - self, - key: AgentMask | Collection[str] | tuple[AgentMask, Collection[str]], - ) -> DataFrame: ... - - def __getitem__( - self, - key: ( - str - | Collection[str] - | AgentMask - | tuple[AgentMask, str] - | tuple[AgentMask, Collection[str]] - ), - ) -> Series | DataFrame: - attr = super().__getitem__(key) - assert isinstance(attr, (Series, DataFrame, Index)) - return attr - - def __len__(self) -> int: - return len(self._df) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}\n {str(self._df)}" - - def __str__(self) -> str: - return f"{self.__class__.__name__}\n {str(self._df)}" - - def __reversed__(self) -> Iterator: - return reversed(self._df) - - @property - def df(self) -> DataFrame: - return self._df - - @df.setter - def df(self, agents: DataFrame) -> None: - """Set the agents in the AgentSetDF. - - Parameters - ---------- - agents : DataFrame - The agents to set. - """ - self._df = agents - - @property - @abstractmethod - def active_agents(self) -> DataFrame: ... - - @property - @abstractmethod - def inactive_agents(self) -> DataFrame: ... - - @property - def index(self) -> Index: ... - - @property - def pos(self) -> DataFrame: - if self.space is None: - raise AttributeError( - "Attempted to access `pos`, but the model has no space attached." - ) - pos = self._df_get_masked_df( - df=self.space.agents, index_cols="agent_id", mask=self.index - ) - pos = self._df_reindex( - pos, self.index, new_index_cols="unique_id", original_index_cols="agent_id" - ) - return pos diff --git a/mesa_frames/abstract/agentset.py b/mesa_frames/abstract/agentset.py new file mode 100644 index 00000000..a7da9097 --- /dev/null +++ b/mesa_frames/abstract/agentset.py @@ -0,0 +1,393 @@ +""" +Abstract base classes for agent sets in mesa-frames. + +This module defines the core abstractions for agent sets in the mesa-frames +extension. It provides the foundation for implementing agent set storage and +manipulation. + +Classes: + AbstractAgentSet: + An abstract base class for agent sets that combines agent container + functionality with DataFrame operations. It inherits from both + AbstractAgentSetRegistry and DataFrameMixin to provide comprehensive + agent management capabilities. + +This abstract class is designed to be subclassed to create concrete +implementations that use specific DataFrame backends. +""" + +from __future__ import annotations + +from abc import abstractmethod +from collections.abc import Collection, Iterable, Iterator +from typing import Any, Literal, Self, overload + +from mesa_frames.abstract.agentsetregistry import AbstractAgentSetRegistry +from mesa_frames.abstract.mixin import DataFrameMixin +from mesa_frames.types_ import ( + AgentMask, + BoolSeries, + DataFrame, + DataFrameInput, + IdsLike, + Index, + Series, +) + + +class AbstractAgentSet(AbstractAgentSetRegistry, DataFrameMixin): + """The AbstractAgentSet class is a container for agents of the same type. + + Parameters + ---------- + model : mesa_frames.concrete.model.Model + The model that the agent set belongs to. + """ + + _df: DataFrame # The agents in the AbstractAgentSet + _mask: AgentMask # The underlying mask used for the active agents in the AbstractAgentSet. + _model: ( + mesa_frames.concrete.model.Model + ) # The model that the AbstractAgentSet belongs to. + + @abstractmethod + def __init__(self, model: mesa_frames.concrete.model.Model) -> None: ... + + @abstractmethod + def add( + self, + agents: DataFrame | DataFrameInput, + inplace: bool = True, + ) -> Self: + """Add agents to the AbstractAgentSet. + + Agents can be the input to the DataFrame constructor. So, the input can be: + - A DataFrame: adds the agents from the DataFrame. + - A DataFrameInput: passes the input to the DataFrame constructor. + + Parameters + ---------- + agents : DataFrame | DataFrameInput + The agents to add. + inplace : bool, optional + If True, perform the operation in place, by default True + + Returns + ------- + Self + A new AbstractAgentSetRegistry with the added agents. + """ + ... + + def discard(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self: + """Remove an agent from the AbstractAgentSet. Does not raise an error if the agent is not found. + + Parameters + ---------- + agents : IdsLike | AgentMask + The ids to remove + inplace : bool, optional + Whether to remove the agent in place, by default True + + Returns + ------- + Self + The updated AbstractAgentSet. + """ + return super().discard(agents, inplace) + + @overload + def do( + self, + method_name: str, + *args, + mask: AgentMask | None = None, + return_results: Literal[False] = False, + inplace: bool = True, + **kwargs, + ) -> Self: ... + + @overload + def do( + self, + method_name: str, + *args, + mask: AgentMask | None = None, + return_results: Literal[True], + inplace: bool = True, + **kwargs, + ) -> Any: ... + + def do( + self, + method_name: str, + *args, + mask: AgentMask | None = None, + return_results: bool = False, + inplace: bool = True, + **kwargs, + ) -> Self | Any: + masked_df = self._get_masked_df(mask) + # If the mask is empty, we can use the object as is + if len(masked_df) == len(self._df): + obj = self._get_obj(inplace) + method = getattr(obj, method_name) + result = method(*args, **kwargs) + else: # If the mask is not empty, we need to create a new masked AbstractAgentSet and concatenate the AbstractAgentSets at the end + obj = self._get_obj(inplace=False) + obj._df = masked_df + original_masked_index = obj._get_obj_copy(obj.index) + method = getattr(obj, method_name) + result = method(*args, **kwargs) + obj._concatenate_agentsets( + [self], + duplicates_allowed=True, + keep_first_only=True, + original_masked_index=original_masked_index, + ) + if inplace: + for key, value in obj.__dict__.items(): + setattr(self, key, value) + obj = self + if return_results: + return result + else: + return obj + + @abstractmethod + @overload + def get( + self, + attr_names: str, + mask: AgentMask | None = None, + ) -> Series: ... + + @abstractmethod + @overload + def get( + self, + attr_names: Collection[str] | None = None, + mask: AgentMask | None = None, + ) -> DataFrame: ... + + @abstractmethod + def get( + self, + attr_names: str | Collection[str] | None = None, + mask: AgentMask | None = None, + ) -> Series | DataFrame: ... + + @abstractmethod + def step(self) -> None: + """Run a single step of the AbstractAgentSet. This method should be overridden by subclasses.""" + ... + + def remove(self, agents: IdsLike | AgentMask, inplace: bool = True) -> Self: + if isinstance(agents, str) and agents == "active": + agents = self.active_agents + if agents is None or (isinstance(agents, Iterable) and len(agents) == 0): + return self._get_obj(inplace) + agents = self._df_index(self._get_masked_df(agents), "unique_id") + sets = self.model.sets.remove(agents, inplace=inplace) + # TODO: Refactor AgentSetRegistry to return dict[str, AbstractAgentSet] instead of dict[AbstractAgentSet, DataFrame] + # And assign a name to AbstractAgentSet? This has to be replaced by a nicer API of AgentSetRegistry + for agentset in sets.df.keys(): + if isinstance(agentset, self.__class__): + return agentset + return self + + @abstractmethod + def _concatenate_agentsets( + self, + objs: Iterable[Self], + duplicates_allowed: bool = True, + keep_first_only: bool = True, + original_masked_index: Index | None = None, + ) -> Self: ... + + @abstractmethod + def _get_bool_mask(self, mask: AgentMask) -> BoolSeries: + """Get the equivalent boolean mask based on the input mask. + + Parameters + ---------- + mask : AgentMask + + Returns + ------- + BoolSeries + """ + ... + + @abstractmethod + def _get_masked_df(self, mask: AgentMask) -> DataFrame: + """Get the df filtered by the input mask. + + Parameters + ---------- + mask : AgentMask + + Returns + ------- + DataFrame + """ + + @overload + @abstractmethod + def _get_obj_copy(self, obj: DataFrame) -> DataFrame: ... + + @overload + @abstractmethod + def _get_obj_copy(self, obj: Series) -> Series: ... + + @overload + @abstractmethod + def _get_obj_copy(self, obj: Index) -> Index: ... + + @abstractmethod + def _get_obj_copy( + self, obj: DataFrame | Series | Index + ) -> DataFrame | Series | Index: ... + + @abstractmethod + def _discard(self, ids: IdsLike) -> Self: + """Remove an agent from the DataFrame of the AbstractAgentSet. Gets called by self.model.sets.remove and self.model.sets.discard. + + Parameters + ---------- + ids : IdsLike + + The ids to remove + + Returns + ------- + Self + """ + ... + + @abstractmethod + def _update_mask( + self, original_active_indices: Index, new_active_indices: Index | None = None + ) -> None: ... + + def __add__(self, other: DataFrame | DataFrameInput) -> Self: + """Add agents to a new AbstractAgentSet through the + operator. + + Other can be: + - A DataFrame: adds the agents from the DataFrame. + - A DataFrameInput: passes the input to the DataFrame constructor. + + Parameters + ---------- + other : DataFrame | DataFrameInput + The agents to add. + + Returns + ------- + Self + A new AbstractAgentSetRegistry with the added agents. + """ + return super().__add__(other) + + def __iadd__(self, other: DataFrame | DataFrameInput) -> Self: + """ + Add agents to the AbstractAgentSet through the += operator. + + Other can be: + - A DataFrame: adds the agents from the DataFrame. + - A DataFrameInput: passes the input to the DataFrame constructor. + + Parameters + ---------- + other : DataFrame | DataFrameInput + The agents to add. + + Returns + ------- + Self + The updated AbstractAgentSetRegistry. + """ + return super().__iadd__(other) + + @abstractmethod + def __getattr__(self, name: str) -> Any: + if __debug__: # Only execute in non-optimized mode + if name == "_df": + raise AttributeError( + "The _df attribute is not set. You probably forgot to call super().__init__ in the __init__ method." + ) + + @overload + def __getitem__(self, key: str | tuple[AgentMask, str]) -> Series | DataFrame: ... + + @overload + def __getitem__( + self, + key: AgentMask | Collection[str] | tuple[AgentMask, Collection[str]], + ) -> DataFrame: ... + + def __getitem__( + self, + key: ( + str + | Collection[str] + | AgentMask + | tuple[AgentMask, str] + | tuple[AgentMask, Collection[str]] + ), + ) -> Series | DataFrame: + attr = super().__getitem__(key) + assert isinstance(attr, (Series, DataFrame, Index)) + return attr + + def __len__(self) -> int: + return len(self._df) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}\n {str(self._df)}" + + def __str__(self) -> str: + return f"{self.__class__.__name__}\n {str(self._df)}" + + def __reversed__(self) -> Iterator: + return reversed(self._df) + + @property + def df(self) -> DataFrame: + return self._df + + @df.setter + def df(self, agents: DataFrame) -> None: + """Set the agents in the AbstractAgentSet. + + Parameters + ---------- + agents : DataFrame + The agents to set. + """ + self._df = agents + + @property + @abstractmethod + def active_agents(self) -> DataFrame: ... + + @property + @abstractmethod + def inactive_agents(self) -> DataFrame: ... + + @property + def index(self) -> Index: ... + + @property + def pos(self) -> DataFrame: + if self.space is None: + raise AttributeError( + "Attempted to access `pos`, but the model has no space attached." + ) + pos = self._df_get_masked_df( + df=self.space.agents, index_cols="agent_id", mask=self.index + ) + pos = self._df_reindex( + pos, self.index, new_index_cols="unique_id", original_index_cols="agent_id" + ) + return pos diff --git a/mesa_frames/abstract/agentsetregistry.py b/mesa_frames/abstract/agentsetregistry.py new file mode 100644 index 00000000..c8fa6c60 --- /dev/null +++ b/mesa_frames/abstract/agentsetregistry.py @@ -0,0 +1,798 @@ +""" +Abstract base classes for agent containers in mesa-frames. + +This module defines the core abstractions for agent containers in the mesa-frames +extension. It provides the foundation for implementing agent storage and +manipulation using DataFrame-based approaches. + +Classes: + AbstractAgentSetRegistry(CopyMixin): + An abstract base class that defines the common interface for all agent + containers in mesa-frames. It inherits from CopyMixin to provide fast + copying functionality. + + AbstractAgentSet(AbstractAgentSetRegistry, DataFrameMixin): + An abstract base class for agent sets that use DataFrames as the underlying + storage mechanism. It inherits from both AbstractAgentSetRegistry and DataFrameMixin + to combine agent container functionality with DataFrame operations. + +These abstract classes are designed to be subclassed by concrete implementations +that use Polars library as their backend. + +Usage: + These classes should not be instantiated directly. Instead, they should be + subclassed to create concrete implementations: + + from mesa_frames.abstract.agents import AbstractAgentSet + + class AgentSet(AbstractAgentSet): + def __init__(self, model): + super().__init__(model) + # Implementation using a DataFrame backend + ... + + # Implement other abstract methods + +Note: + The abstract methods in these classes use Python's @abstractmethod decorator, + ensuring that concrete subclasses must implement these methods. + +Attributes and methods of each class are documented in their respective docstrings. +""" + +from __future__ import annotations # PEP 563: postponed evaluation of type annotations + +from abc import abstractmethod +from collections.abc import Callable, Collection, Iterator, Sequence +from contextlib import suppress +from typing import Any, Literal, Self, overload + +from numpy.random import Generator + +from mesa_frames.abstract.mixin import CopyMixin +from mesa_frames.types_ import ( + AgentMask, + BoolSeries, + DataFrame, + DataFrameInput, + IdsLike, + Index, + Series, +) + + +class AbstractAgentSetRegistry(CopyMixin): + """An abstract class for containing agents. Defines the common interface for AbstractAgentSet and AgentSetRegistry.""" + + _copy_only_reference: list[str] = [ + "_model", + ] + _model: mesa_frames.concrete.model.Model + + @abstractmethod + def __init__(self) -> None: ... + + def discard( + self, + agents: IdsLike + | AgentMask + | mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet], + inplace: bool = True, + ) -> Self: + """Remove agents from the AbstractAgentSetRegistry. Does not raise an error if the agent is not found. + + Parameters + ---------- + agents : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The agents to remove + inplace : bool + Whether to remove the agent in place. Defaults to True. + + Returns + ------- + Self + The updated AbstractAgentSetRegistry. + """ + with suppress(KeyError, ValueError): + return self.remove(agents, inplace=inplace) + return self._get_obj(inplace) + + @abstractmethod + def add( + self, + agents: DataFrame + | DataFrameInput + | mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet], + inplace: bool = True, + ) -> Self: + """Add agents to the AbstractAgentSetRegistry. + + Parameters + ---------- + agents : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The agents to add. + inplace : bool + Whether to add the agents in place. Defaults to True. + + Returns + ------- + Self + The updated AbstractAgentSetRegistry. + """ + ... + + @overload + @abstractmethod + def contains(self, agents: int) -> bool: ... + + @overload + @abstractmethod + def contains( + self, agents: mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike + ) -> BoolSeries: ... + + @abstractmethod + def contains( + self, agents: mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike + ) -> bool | BoolSeries: + """Check if agents with the specified IDs are in the AbstractAgentSetRegistry. + + Parameters + ---------- + agents : mesa_frames.abstract.agentset.AbstractAgentSet | IdsLike + The ID(s) to check for. + + Returns + ------- + bool | BoolSeries + True if the agent is in the AbstractAgentSetRegistry, False otherwise. + """ + + @overload + @abstractmethod + def do( + self, + method_name: str, + *args: Any, + mask: AgentMask | None = None, + return_results: Literal[False] = False, + inplace: bool = True, + **kwargs: Any, + ) -> Self: ... + + @overload + @abstractmethod + def do( + self, + method_name: str, + *args: Any, + mask: AgentMask | None = None, + return_results: Literal[True], + inplace: bool = True, + **kwargs: Any, + ) -> Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any]: ... + + @abstractmethod + def do( + self, + method_name: str, + *args: Any, + mask: AgentMask | None = None, + return_results: bool = False, + inplace: bool = True, + **kwargs: Any, + ) -> Self | Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any]: + """Invoke a method on the AbstractAgentSetRegistry. + + Parameters + ---------- + method_name : str + The name of the method to invoke. + *args : Any + Positional arguments to pass to the method + mask : AgentMask | None, optional + The subset of agents on which to apply the method + return_results : bool, optional + Whether to return the result of the method, by default False + inplace : bool, optional + Whether the operation should be done inplace, by default False + **kwargs : Any + Keyword arguments to pass to the method + + Returns + ------- + Self | Any | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Any] + The updated AbstractAgentSetRegistry or the result of the method. + """ + ... + + @abstractmethod + @overload + def get(self, attr_names: str) -> Series | dict[str, Series]: ... + + @abstractmethod + @overload + def get( + self, attr_names: Collection[str] | None = None + ) -> DataFrame | dict[str, DataFrame]: ... + + @abstractmethod + def get( + self, + attr_names: str | Collection[str] | None = None, + mask: AgentMask | None = None, + ) -> Series | dict[str, Series] | DataFrame | dict[str, DataFrame]: + """Retrieve the value of a specified attribute for each agent in the AbstractAgentSetRegistry. + + Parameters + ---------- + attr_names : str | Collection[str] | None, optional + The attributes to retrieve. If None, all attributes are retrieved. Defaults to None. + mask : AgentMask | None, optional + The AgentMask of agents to retrieve the attribute for. If None, attributes of all agents are returned. Defaults to None. + + Returns + ------- + Series | dict[str, Series] | DataFrame | dict[str, DataFrame] + The attribute values. + """ + ... + + @abstractmethod + def remove( + self, + agents: ( + IdsLike + | AgentMask + | mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + ), + inplace: bool = True, + ) -> Self: + """Remove the agents from the AbstractAgentSetRegistry. + + Parameters + ---------- + agents : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The agents to remove. + inplace : bool, optional + Whether to remove the agent in place. + + Returns + ------- + Self + The updated AbstractAgentSetRegistry. + """ + ... + + @abstractmethod + def select( + self, + mask: AgentMask | None = None, + filter_func: Callable[[Self], AgentMask] | None = None, + n: int | None = None, + negate: bool = False, + inplace: bool = True, + ) -> Self: + """Select agents in the AbstractAgentSetRegistry based on the given criteria. + + Parameters + ---------- + mask : AgentMask | None, optional + The AgentMask of agents to be selected, by default None + filter_func : Callable[[Self], AgentMask] | None, optional + A function which takes as input the AbstractAgentSetRegistry and returns a AgentMask, by default None + n : int | None, optional + The maximum number of agents to be selected, by default None + negate : bool, optional + If the selection should be negated, by default False + inplace : bool, optional + If the operation should be performed on the same object, by default True + + Returns + ------- + Self + A new or updated AbstractAgentSetRegistry. + """ + ... + + @abstractmethod + @overload + def set( + self, + attr_names: dict[str, Any], + values: None, + mask: AgentMask | None = None, + inplace: bool = True, + ) -> Self: ... + + @abstractmethod + @overload + def set( + self, + attr_names: str | Collection[str], + values: Any, + mask: AgentMask | None = None, + inplace: bool = True, + ) -> Self: ... + + @abstractmethod + def set( + self, + attr_names: DataFrameInput | str | Collection[str], + values: Any | None = None, + mask: AgentMask | None = None, + inplace: bool = True, + ) -> Self: + """Set the value of a specified attribute or attributes for each agent in the mask in AbstractAgentSetRegistry. + + Parameters + ---------- + attr_names : DataFrameInput | str | Collection[str] + The key can be: + - A string: sets the specified column of the agents in the AbstractAgentSetRegistry. + - A collection of strings: sets the specified columns of the agents in the AbstractAgentSetRegistry. + - A dictionary: keys should be attributes and values should be the values to set. Value should be None. + values : Any | None + The value to set the attribute to. If None, attr_names must be a dictionary. + mask : AgentMask | None + The AgentMask of agents to set the attribute for. + inplace : bool + Whether to set the attribute in place. + + Returns + ------- + Self + The updated agent set. + """ + ... + + @abstractmethod + def shuffle(self, inplace: bool = False) -> Self: + """Shuffles the order of agents in the AbstractAgentSetRegistry. + + Parameters + ---------- + inplace : bool + Whether to shuffle the agents in place. + + Returns + ------- + Self + A new or updated AbstractAgentSetRegistry. + """ + + @abstractmethod + def sort( + self, + by: str | Sequence[str], + ascending: bool | Sequence[bool] = True, + inplace: bool = True, + **kwargs, + ) -> Self: + """ + Sorts the agents in the agent set based on the given criteria. + + Parameters + ---------- + by : str | Sequence[str] + The attribute(s) to sort by. + ascending : bool | Sequence[bool] + Whether to sort in ascending order. + inplace : bool + Whether to sort the agents in place. + **kwargs + Keyword arguments to pass to the sort + + Returns + ------- + Self + A new or updated AbstractAgentSetRegistry. + """ + + def __add__( + self, + other: DataFrame + | DataFrameInput + | mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet], + ) -> Self: + """Add agents to a new AbstractAgentSetRegistry through the + operator. + + Parameters + ---------- + other : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The agents to add. + + Returns + ------- + Self + A new AbstractAgentSetRegistry with the added agents. + """ + return self.add(agents=other, inplace=False) + + def __contains__( + self, agents: int | mesa_frames.abstract.agentset.AbstractAgentSet + ) -> bool: + """Check if an agent is in the AbstractAgentSetRegistry. + + Parameters + ---------- + agents : int | mesa_frames.abstract.agentset.AbstractAgentSet + The ID(s) or AbstractAgentSet to check for. + + Returns + ------- + bool + True if the agent is in the AbstractAgentSetRegistry, False otherwise. + """ + return self.contains(agents=agents) + + @overload + def __getitem__( + self, key: str | tuple[AgentMask, str] + ) -> Series | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series]: ... + + @overload + def __getitem__( + self, + key: AgentMask | Collection[str] | tuple[AgentMask, Collection[str]], + ) -> ( + DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] + ): ... + + def __getitem__( + self, + key: ( + str + | Collection[str] + | AgentMask + | tuple[AgentMask, str] + | tuple[AgentMask, Collection[str]] + | tuple[ + dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str + ] + | tuple[ + dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], + Collection[str], + ] + ), + ) -> ( + Series + | DataFrame + | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series] + | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] + ): + """Implement the [] operator for the AbstractAgentSetRegistry. + + The key can be: + - An attribute or collection of attributes (eg. AbstractAgentSetRegistry["str"], AbstractAgentSetRegistry[["str1", "str2"]]): returns the specified column(s) of the agents in the AbstractAgentSetRegistry. + - An AgentMask (eg. AbstractAgentSetRegistry[AgentMask]): returns the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. + - A tuple (eg. AbstractAgentSetRegistry[AgentMask, "str"]): returns the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. + - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, "str"]): returns the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. + - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, Collection[str]]): returns the specified columns of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. + + Parameters + ---------- + key : str | Collection[str] | AgentMask | tuple[AgentMask, str] | tuple[AgentMask, Collection[str]] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], Collection[str]] + The key to retrieve. + + Returns + ------- + Series | DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Series] | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] + The attribute values. + """ + # TODO: fix types + if isinstance(key, tuple): + return self.get(mask=key[0], attr_names=key[1]) + else: + if isinstance(key, str) or ( + isinstance(key, Collection) and all(isinstance(k, str) for k in key) + ): + return self.get(attr_names=key) + else: + return self.get(mask=key) + + def __iadd__( + self, + other: ( + DataFrame + | DataFrameInput + | mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + ), + ) -> Self: + """Add agents to the AbstractAgentSetRegistry through the += operator. + + Parameters + ---------- + other : DataFrame | DataFrameInput | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The agents to add. + + Returns + ------- + Self + The updated AbstractAgentSetRegistry. + """ + return self.add(agents=other, inplace=True) + + def __isub__( + self, + other: ( + IdsLike + | AgentMask + | mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + ), + ) -> Self: + """Remove agents from the AbstractAgentSetRegistry through the -= operator. + + Parameters + ---------- + other : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The agents to remove. + + Returns + ------- + Self + The updated AbstractAgentSetRegistry. + """ + return self.discard(other, inplace=True) + + def __sub__( + self, + other: ( + IdsLike + | AgentMask + | mesa_frames.abstract.agentset.AbstractAgentSet + | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + ), + ) -> Self: + """Remove agents from a new AbstractAgentSetRegistry through the - operator. + + Parameters + ---------- + other : IdsLike | AgentMask | mesa_frames.abstract.agentset.AbstractAgentSet | Collection[mesa_frames.abstract.agentset.AbstractAgentSet] + The agents to remove. + + Returns + ------- + Self + A new AbstractAgentSetRegistry with the removed agents. + """ + return self.discard(other, inplace=False) + + def __setitem__( + self, + key: ( + str + | Collection[str] + | AgentMask + | tuple[AgentMask, str | Collection[str]] + | tuple[ + dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str + ] + | tuple[ + dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], + Collection[str], + ] + ), + values: Any, + ) -> None: + """Implement the [] operator for setting values in the AbstractAgentSetRegistry. + + The key can be: + - A string (eg. AbstractAgentSetRegistry["str"]): sets the specified column of the agents in the AbstractAgentSetRegistry. + - A list of strings(eg. AbstractAgentSetRegistry[["str1", "str2"]]): sets the specified columns of the agents in the AbstractAgentSetRegistry. + - A tuple (eg. AbstractAgentSetRegistry[AgentMask, "str"]): sets the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. + - A AgentMask (eg. AbstractAgentSetRegistry[AgentMask]): sets the attributes of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask. + - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, "str"]): sets the specified column of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. + - A tuple with a dictionary (eg. AbstractAgentSetRegistry[{AbstractAgentSet: AgentMask}, Collection[str]]): sets the specified columns of the agents in the AbstractAgentSetRegistry that satisfy the AgentMask from the dictionary. + + Parameters + ---------- + key : str | Collection[str] | AgentMask | tuple[AgentMask, str | Collection[str]] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], str] | tuple[dict[mesa_frames.abstract.agentset.AbstractAgentSet, AgentMask], Collection[str]] + The key to set. + values : Any + The values to set for the specified key. + """ + # TODO: fix types as in __getitem__ + if isinstance(key, tuple): + self.set(mask=key[0], attr_names=key[1], values=values) + else: + if isinstance(key, str) or ( + isinstance(key, Collection) and all(isinstance(k, str) for k in key) + ): + try: + self.set(attr_names=key, values=values) + except KeyError: # key=AgentMask + self.set(attr_names=None, mask=key, values=values) + else: + self.set(attr_names=None, mask=key, values=values) + + @abstractmethod + def __getattr__(self, name: str) -> Any | dict[str, Any]: + """Fallback for retrieving attributes of the AbstractAgentSetRegistry. Retrieve an attribute of the underlying DataFrame(s). + + Parameters + ---------- + name : str + The name of the attribute to retrieve. + + Returns + ------- + Any | dict[str, Any] + The attribute value + """ + + @abstractmethod + def __iter__(self) -> Iterator[dict[str, Any]]: + """Iterate over the agents in the AbstractAgentSetRegistry. + + Returns + ------- + Iterator[dict[str, Any]] + An iterator over the agents. + """ + ... + + @abstractmethod + def __len__(self) -> int: + """Get the number of agents in the AbstractAgentSetRegistry. + + Returns + ------- + int + The number of agents in the AbstractAgentSetRegistry. + """ + ... + + @abstractmethod + def __repr__(self) -> str: + """Get a string representation of the DataFrame in the AbstractAgentSetRegistry. + + Returns + ------- + str + A string representation of the DataFrame in the AbstractAgentSetRegistry. + """ + pass + + @abstractmethod + def __reversed__(self) -> Iterator: + """Iterate over the agents in the AbstractAgentSetRegistry in reverse order. + + Returns + ------- + Iterator + An iterator over the agents in reverse order. + """ + ... + + @abstractmethod + def __str__(self) -> str: + """Get a string representation of the agents in the AbstractAgentSetRegistry. + + Returns + ------- + str + A string representation of the agents in the AbstractAgentSetRegistry. + """ + ... + + @property + def model(self) -> mesa_frames.concrete.model.Model: + """The model that the AbstractAgentSetRegistry belongs to. + + Returns + ------- + mesa_frames.concrete.model.Model + """ + return self._model + + @property + def random(self) -> Generator: + """The random number generator of the model. + + Returns + ------- + Generator + """ + return self.model.random + + @property + def space(self) -> mesa_frames.abstract.space.Space | None: + """The space of the model. + + Returns + ------- + mesa_frames.abstract.space.Space | None + """ + return self.model.space + + @property + @abstractmethod + def df(self) -> DataFrame | dict[str, DataFrame]: + """The agents in the AbstractAgentSetRegistry. + + Returns + ------- + DataFrame | dict[str, DataFrame] + """ + + @df.setter + @abstractmethod + def df( + self, agents: DataFrame | list[mesa_frames.abstract.agentset.AbstractAgentSet] + ) -> None: + """Set the agents in the AbstractAgentSetRegistry. + + Parameters + ---------- + agents : DataFrame | list[mesa_frames.abstract.agentset.AbstractAgentSet] + """ + + @property + @abstractmethod + def active_agents(self) -> DataFrame | dict[str, DataFrame]: + """The active agents in the AbstractAgentSetRegistry. + + Returns + ------- + DataFrame | dict[str, DataFrame] + """ + + @active_agents.setter + @abstractmethod + def active_agents( + self, + mask: AgentMask, + ) -> None: + """Set the active agents in the AbstractAgentSetRegistry. + + Parameters + ---------- + mask : AgentMask + The mask to apply. + """ + self.select(mask=mask, inplace=True) + + @property + @abstractmethod + def inactive_agents( + self, + ) -> DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]: + """The inactive agents in the AbstractAgentSetRegistry. + + Returns + ------- + DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] + """ + + @property + @abstractmethod + def index( + self, + ) -> Index | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Index]: + """The ids in the AbstractAgentSetRegistry. + + Returns + ------- + Index | dict[mesa_frames.abstract.agentset.AbstractAgentSet, Index] + """ + ... + + @property + @abstractmethod + def pos( + self, + ) -> DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame]: + """The position of the agents in the AbstractAgentSetRegistry. + + Returns + ------- + DataFrame | dict[mesa_frames.abstract.agentset.AbstractAgentSet, DataFrame] + """ + ... diff --git a/mesa_frames/abstract/datacollector.py b/mesa_frames/abstract/datacollector.py index d93f661d..edbfb11f 100644 --- a/mesa_frames/abstract/datacollector.py +++ b/mesa_frames/abstract/datacollector.py @@ -47,7 +47,7 @@ def flush(self): from abc import ABC, abstractmethod from typing import Any, Literal from collections.abc import Callable -from mesa_frames import ModelDF +from mesa_frames import Model import polars as pl import threading from concurrent.futures import ThreadPoolExecutor @@ -61,7 +61,7 @@ class AbstractDataCollector(ABC): Sub classes must implement logic for the methods """ - _model: ModelDF + _model: Model _model_reporters: dict[str, Callable] | None _agent_reporters: dict[str, str | Callable] | None _trigger: Callable[..., bool] | None @@ -71,7 +71,7 @@ class AbstractDataCollector(ABC): def __init__( self, - model: ModelDF, + model: Model, model_reporters: dict[str, Callable] | None, agent_reporters: dict[str, str | Callable] | None, trigger: Callable[[Any], bool] | None, @@ -86,7 +86,7 @@ def __init__( Parameters ---------- - model : ModelDF + model : Model The model object from which data is collected. model_reporters : dict[str, Callable] | None Functions to collect data at the model level. diff --git a/mesa_frames/abstract/mixin.py b/mesa_frames/abstract/mixin.py index 84b4ec7b..96904eba 100644 --- a/mesa_frames/abstract/mixin.py +++ b/mesa_frames/abstract/mixin.py @@ -81,8 +81,8 @@ def copy( Parameters ---------- deep : bool, optional - Flag indicating whether to perform a deep copy of the AgentContainer. - If True, all attributes of the AgentContainer will be recursively copied (except attributes in self._copy_reference_only). + Flag indicating whether to perform a deep copy of the AbstractAgentSetRegistry. + If True, all attributes of the AbstractAgentSetRegistry will be recursively copied (except attributes in self._copy_reference_only). If False, only the top-level attributes will be copied. Defaults to False. memo : dict | None, optional @@ -95,7 +95,7 @@ def copy( Returns ------- Self - A new instance of the AgentContainer class that is a copy of the original instance. + A new instance of the AbstractAgentSetRegistry class that is a copy of the original instance. """ cls = self.__class__ obj = cls.__new__(cls) @@ -155,17 +155,17 @@ def _get_obj(self, inplace: bool) -> Self: return deepcopy(self) def __copy__(self) -> Self: - """Create a shallow copy of the AgentContainer. + """Create a shallow copy of the AbstractAgentSetRegistry. Returns ------- Self - A shallow copy of the AgentContainer. + A shallow copy of the AbstractAgentSetRegistry. """ return self.copy(deep=False) def __deepcopy__(self, memo: dict) -> Self: - """Create a deep copy of the AgentContainer. + """Create a deep copy of the AbstractAgentSetRegistry. Parameters ---------- @@ -175,7 +175,7 @@ def __deepcopy__(self, memo: dict) -> Self: Returns ------- Self - A deep copy of the AgentContainer. + A deep copy of the AbstractAgentSetRegistry. """ return self.copy(deep=True, memo=memo) diff --git a/mesa_frames/abstract/space.py b/mesa_frames/abstract/space.py index dab5f7b0..74df16e8 100644 --- a/mesa_frames/abstract/space.py +++ b/mesa_frames/abstract/space.py @@ -12,13 +12,13 @@ classes in mesa-frames. It combines fast copying functionality with DataFrame operations. - DiscreteSpaceDF(SpaceDF): + AbstractDiscreteSpace(SpaceDF): An abstract base class for discrete space implementations, such as grids and networks. It extends SpaceDF with methods specific to discrete spaces. - GridDF(DiscreteSpaceDF): + AbstractGrid(AbstractDiscreteSpace): An abstract base class for grid-based spaces. It inherits from - DiscreteSpaceDF and adds grid-specific functionality. + AbstractDiscreteSpace and adds grid-specific functionality. These abstract classes are designed to be subclassed by concrete implementations that use Polars library as their backend. @@ -29,9 +29,9 @@ These classes should not be instantiated directly. Instead, they should be subclassed to create concrete implementations: - from mesa_frames.abstract.space import GridDF + from mesa_frames.abstract.space import AbstractGrid - class GridPolars(GridDF): + class Grid(AbstractGrid): def __init__(self, model, dimensions, torus, capacity, neighborhood_type): super().__init__(model, dimensions, torus, capacity, neighborhood_type) # Implementation using polars DataFrame @@ -59,9 +59,12 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type): import polars as pl from numpy.random import Generator -from mesa_frames.abstract.agents import AgentContainer, AgentSetDF +from mesa_frames.abstract.agentset import AbstractAgentSet +from mesa_frames.abstract.agentsetregistry import ( + AbstractAgentSetRegistry, +) from mesa_frames.abstract.mixin import CopyMixin, DataFrameMixin -from mesa_frames.concrete.agents import AgentsDF +from mesa_frames.concrete.agentsetregistry import AgentSetRegistry from mesa_frames.types_ import ( ArrayLike, BoolSeries, @@ -83,8 +86,8 @@ def __init__(self, model, dimensions, torus, capacity, neighborhood_type): ESPG = int -class SpaceDF(CopyMixin, DataFrameMixin): - """The SpaceDF class is an abstract class that defines the interface for all space classes in mesa_frames.""" +class Space(CopyMixin, DataFrameMixin): + """The Space class is an abstract class that defines the interface for all space classes in mesa_frames.""" _agents: DataFrame # | GeoDataFrame # Stores the agents placed in the space _center_col_names: list[ @@ -94,18 +97,20 @@ class SpaceDF(CopyMixin, DataFrameMixin): str ] # The column names of the positions in the _agents dataframe (eg. ['dim_0', 'dim_1', ...] in Grids, ['node_id', 'edge_id'] in Networks) - def __init__(self, model: mesa_frames.concrete.model.ModelDF) -> None: + def __init__(self, model: mesa_frames.concrete.model.Model) -> None: """Create a new SpaceDF. Parameters ---------- - model : mesa_frames.concrete.model.ModelDF + model : mesa_frames.concrete.model.Model """ self._model = model def move_agents( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], pos: SpaceCoordinate | SpaceCoordinates, inplace: bool = True, ) -> Self: @@ -115,7 +120,7 @@ def move_agents( Parameters ---------- - agents : IdsLike | AgentContainer | Collection[AgentContainer] + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The agents to move pos : SpaceCoordinate | SpaceCoordinates The coordinates for each agents. The length of the coordinates must match the number of agents. @@ -139,7 +144,9 @@ def move_agents( def place_agents( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], pos: SpaceCoordinate | SpaceCoordinates, inplace: bool = True, ) -> Self: @@ -147,7 +154,7 @@ def place_agents( Parameters ---------- - agents : IdsLike | AgentContainer | Collection[AgentContainer] + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The agents to place in the space pos : SpaceCoordinate | SpaceCoordinates The coordinates for each agents. The length of the coordinates must match the number of agents. @@ -190,8 +197,12 @@ def random_agents( def swap_agents( self, - agents0: IdsLike | AgentContainer | Collection[AgentContainer], - agents1: IdsLike | AgentContainer | Collection[AgentContainer], + agents0: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], + agents1: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: """Swap the positions of the agents in the space. @@ -200,9 +211,9 @@ def swap_agents( Parameters ---------- - agents0 : IdsLike | AgentContainer | Collection[AgentContainer] + agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The first set of agents to swap - agents1 : IdsLike | AgentContainer | Collection[AgentContainer] + agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The second set of agents to swap inplace : bool, optional Whether to perform the operation inplace, by default True @@ -245,8 +256,14 @@ def get_directions( self, pos0: SpaceCoordinate | SpaceCoordinates | None = None, pos1: SpaceCoordinate | SpaceCoordinates | None = None, - agents0: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, - agents1: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents0: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, + agents1: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, normalize: bool = False, ) -> DataFrame: """Return the directions from pos0 to pos1 or agents0 and agents1. @@ -261,9 +278,9 @@ def get_directions( The starting positions pos1 : SpaceCoordinate | SpaceCoordinates | None, optional The ending positions - agents0 : IdsLike | AgentContainer | Collection[AgentContainer] | None, optional + agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional The starting agents - agents1 : IdsLike | AgentContainer | Collection[AgentContainer] | None, optional + agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional The ending agents normalize : bool, optional Whether to normalize the vectors to unit norm. By default False @@ -280,8 +297,14 @@ def get_distances( self, pos0: SpaceCoordinate | SpaceCoordinates | None = None, pos1: SpaceCoordinate | SpaceCoordinates | None = None, - agents0: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, - agents1: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents0: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, + agents1: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, ) -> DataFrame: """Return the distances from pos0 to pos1 or agents0 and agents1. @@ -295,9 +318,9 @@ def get_distances( The starting positions pos1 : SpaceCoordinate | SpaceCoordinates | None, optional The ending positions - agents0 : IdsLike | AgentContainer | Collection[AgentContainer] | None, optional + agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional The starting agents - agents1 : IdsLike | AgentContainer | Collection[AgentContainer] | None, optional + agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional The ending agents Returns @@ -312,7 +335,10 @@ def get_neighbors( self, radius: int | float | Sequence[int] | Sequence[float] | ArrayLike, pos: SpaceCoordinate | SpaceCoordinates | None = None, - agents: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, include_center: bool = False, ) -> DataFrame: """Get the neighboring agents from given positions or agents according to the specified radiuses. @@ -325,7 +351,7 @@ def get_neighbors( The radius(es) of the neighborhood pos : SpaceCoordinate | SpaceCoordinates | None, optional The coordinates of the cell to get the neighborhood from, by default None - agents : IdsLike | AgentContainer | Collection[AgentContainer] | None, optional + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional The id of the agents to get the neighborhood from, by default None include_center : bool, optional If the center cells or agents should be included in the result, by default False @@ -346,14 +372,16 @@ def get_neighbors( @abstractmethod def move_to_empty( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: """Move agents to empty cells/positions in the space (cells/positions where there isn't any single agent). Parameters ---------- - agents : IdsLike | AgentContainer | Collection[AgentContainer] + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The agents to move to empty cells/positions inplace : bool, optional Whether to perform the operation inplace, by default True @@ -367,14 +395,16 @@ def move_to_empty( @abstractmethod def place_to_empty( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: """Place agents in empty cells/positions in the space (cells/positions where there isn't any single agent). Parameters ---------- - agents : IdsLike | AgentContainer | Collection[AgentContainer] + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The agents to place in empty cells/positions inplace : bool, optional Whether to perform the operation inplace, by default True @@ -407,7 +437,9 @@ def random_pos( @abstractmethod def remove_agents( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: """Remove agents from the space. @@ -416,7 +448,7 @@ def remove_agents( Parameters ---------- - agents : IdsLike | AgentContainer | Collection[AgentContainer] + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The agents to remove from the space inplace : bool, optional Whether to perform the operation inplace, by default True @@ -433,22 +465,27 @@ def remove_agents( return ... def _get_ids_srs( - self, agents: IdsLike | AgentContainer | Collection[AgentContainer] + self, + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], ) -> Series: if isinstance(agents, Sized) and len(agents) == 0: return self._srs_constructor([], name="agent_id", dtype="uint64") - if isinstance(agents, AgentSetDF): + if isinstance(agents, AbstractAgentSet): return self._srs_constructor( self._df_index(agents.df, "unique_id"), name="agent_id", dtype="uint64", ) - elif isinstance(agents, AgentsDF): + elif isinstance(agents, AgentSetRegistry): return self._srs_constructor(agents._ids, name="agent_id", dtype="uint64") - elif isinstance(agents, Collection) and (isinstance(agents[0], AgentContainer)): + elif isinstance(agents, Collection) and ( + isinstance(agents[0], AbstractAgentSetRegistry) + ): ids = [] for a in agents: - if isinstance(a, AgentSetDF): + if isinstance(a, AbstractAgentSet): ids.append( self._srs_constructor( self._df_index(a.df, "unique_id"), @@ -456,7 +493,7 @@ def _get_ids_srs( dtype="uint64", ) ) - elif isinstance(a, AgentsDF): + elif isinstance(a, AgentSetRegistry): ids.append( self._srs_constructor(a._ids, name="agent_id", dtype="uint64") ) @@ -469,7 +506,9 @@ def _get_ids_srs( @abstractmethod def _place_or_move_agents( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], pos: SpaceCoordinate | SpaceCoordinates, is_move: bool, ) -> Self: @@ -479,7 +518,7 @@ def _place_or_move_agents( Parameters ---------- - agents : IdsLike | AgentContainer | Collection[AgentContainer] + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The agents to move/place pos : SpaceCoordinate | SpaceCoordinates The position to move/place agents to @@ -493,7 +532,7 @@ def _place_or_move_agents( @abstractmethod def __repr__(self) -> str: - """Return a string representation of the SpaceDF. + """Return a string representation of the Space. Returns ------- @@ -503,7 +542,7 @@ def __repr__(self) -> str: @abstractmethod def __str__(self) -> str: - """Return a string representation of the SpaceDF. + """Return a string representation of the Space. Returns ------- @@ -522,12 +561,12 @@ def agents(self) -> DataFrame: # | GeoDataFrame: return self._agents @property - def model(self) -> mesa_frames.concrete.model.ModelDF: + def model(self) -> mesa_frames.concrete.model.Model: """The model to which the space belongs. Returns ------- - 'mesa_frames.concrete.model.ModelDF' + 'mesa_frames.concrete.model.Model' """ return self._model @@ -542,8 +581,8 @@ def random(self) -> Generator: return self.model.random -class DiscreteSpaceDF(SpaceDF): - """The DiscreteSpaceDF class is an abstract class that defines the interface for all discrete space classes (Grids and Networks) in mesa_frames.""" +class AbstractDiscreteSpace(Space): + """The AbstractDiscreteSpace class is an abstract class that defines the interface for all discrete space classes (Grids and Networks) in mesa_frames.""" _agents: DataFrame _capacity: int | None # The maximum capacity for cells (default is infinite) @@ -554,14 +593,14 @@ class DiscreteSpaceDF(SpaceDF): def __init__( self, - model: mesa_frames.concrete.model.ModelDF, + model: mesa_frames.concrete.model.Model, capacity: int | None = None, ): - """Create a new DiscreteSpaceDF. + """Create a new AbstractDiscreteSpace. Parameters ---------- - model : mesa_frames.concrete.model.ModelDF + model : mesa_frames.concrete.model.Model The model to which the space belongs capacity : int | None, optional The maximum capacity for cells (default is infinite), by default None @@ -616,7 +655,9 @@ def is_full(self, pos: DiscreteCoordinate | DiscreteCoordinates) -> DataFrame: def move_to_empty( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -626,14 +667,16 @@ def move_to_empty( def move_to_available( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: """Move agents to available cells/positions in the space (cells/positions where there is at least one spot available). Parameters ---------- - agents : IdsLike | AgentContainer | Collection[AgentContainer] + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] The agents to move to available cells/positions inplace : bool, optional Whether to perform the operation inplace, by default True @@ -649,7 +692,9 @@ def move_to_available( def place_to_empty( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -659,7 +704,9 @@ def place_to_empty( def place_to_available( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -775,7 +822,9 @@ def get_neighborhood( self, radius: int | float | Sequence[int] | Sequence[float] | ArrayLike, pos: DiscreteCoordinate | DiscreteCoordinates | None = None, - agents: IdsLike | AgentContainer | Collection[AgentContainer] = None, + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] = None, include_center: bool = False, ) -> DataFrame: """Get the neighborhood cells from the given positions (pos) or agents according to the specified radiuses. @@ -788,7 +837,7 @@ def get_neighborhood( The radius(es) of the neighborhoods pos : DiscreteCoordinate | DiscreteCoordinates | None, optional The coordinates of the cell(s) to get the neighborhood from - agents : IdsLike | AgentContainer | Collection[AgentContainer], optional + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry], optional The agent(s) to get the neighborhood from include_center : bool, optional If the cell in the center of the neighborhood should be included in the result, by default False @@ -883,7 +932,9 @@ def _check_cells( def _place_or_move_agents_to_cells( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], cell_type: Literal["any", "empty", "available"], is_move: bool, ) -> Self: @@ -892,7 +943,7 @@ def _place_or_move_agents_to_cells( if __debug__: # Check ids presence in model - b_contained = self.model.agents.contains(agents) + b_contained = self.model.sets.contains(agents) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -912,7 +963,10 @@ def _place_or_move_agents_to_cells( def _get_df_coords( self, pos: DiscreteCoordinate | DiscreteCoordinates | None = None, - agents: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, ) -> DataFrame: """Get the DataFrame of coordinates from the specified positions or agents. @@ -920,7 +974,7 @@ def _get_df_coords( ---------- pos : DiscreteCoordinate | DiscreteCoordinates | None, optional The positions to get the DataFrame from, by default None - agents : IdsLike | AgentContainer | Collection[AgentContainer] | None, optional + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional The agents to get the DataFrame from, by default None Returns @@ -1119,10 +1173,10 @@ def remaining_capacity(self) -> int | Infinity: ... -class GridDF(DiscreteSpaceDF): - """The GridDF class is an abstract class that defines the interface for all grid classes in mesa-frames. +class AbstractGrid(AbstractDiscreteSpace): + """The AbstractGrid class is an abstract class that defines the interface for all grid classes in mesa-frames. - Inherits from DiscreteSpaceDF. + Inherits from AbstractDiscreteSpace. Warning ------- @@ -1155,17 +1209,17 @@ class GridDF(DiscreteSpaceDF): def __init__( self, - model: mesa_frames.concrete.model.ModelDF, + model: mesa_frames.concrete.model.Model, dimensions: Sequence[int], torus: bool = False, capacity: int | None = None, neighborhood_type: str = "moore", ): - """Create a new GridDF. + """Create a new AbstractGrid. Parameters ---------- - model : mesa_frames.concrete.model.ModelDF + model : mesa_frames.concrete.model.Model The model to which the space belongs dimensions : Sequence[int] The dimensions of the grid @@ -1204,8 +1258,14 @@ def get_directions( self, pos0: GridCoordinate | GridCoordinates | None = None, pos1: GridCoordinate | GridCoordinates | None = None, - agents0: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, - agents1: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents0: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, + agents1: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, normalize: bool = False, ) -> DataFrame: result = self._calculate_differences(pos0, pos1, agents0, agents1) @@ -1217,8 +1277,14 @@ def get_distances( self, pos0: GridCoordinate | GridCoordinates | None = None, pos1: GridCoordinate | GridCoordinates | None = None, - agents0: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, - agents1: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents0: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, + agents1: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, ) -> DataFrame: result = self._calculate_differences(pos0, pos1, agents0, agents1) return self._df_norm(result, "distance", True) @@ -1227,7 +1293,10 @@ def get_neighbors( self, radius: int | Sequence[int], pos: GridCoordinate | GridCoordinates | None = None, - agents: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, include_center: bool = False, ) -> DataFrame: neighborhood_df = self.get_neighborhood( @@ -1243,7 +1312,10 @@ def get_neighborhood( self, radius: int | Sequence[int] | ArrayLike, pos: GridCoordinate | GridCoordinates | None = None, - agents: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, include_center: bool = False, ) -> DataFrame: pos_df = self._get_df_coords(pos, agents) @@ -1476,7 +1548,9 @@ def out_of_bounds(self, pos: GridCoordinate | GridCoordinates) -> DataFrame: def remove_agents( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -1485,7 +1559,7 @@ def remove_agents( if __debug__: # Check ids presence in model - b_contained = obj.model.agents.contains(agents) + b_contained = obj.model.sets.contains(agents) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -1519,8 +1593,14 @@ def _calculate_differences( self, pos0: GridCoordinate | GridCoordinates | None, pos1: GridCoordinate | GridCoordinates | None, - agents0: IdsLike | AgentContainer | Collection[AgentContainer] | None, - agents1: IdsLike | AgentContainer | Collection[AgentContainer] | None, + agents0: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None, + agents1: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None, ) -> DataFrame: """Calculate the differences between two positions or agents. @@ -1530,9 +1610,9 @@ def _calculate_differences( The starting positions pos1 : GridCoordinate | GridCoordinates | None The ending positions - agents0 : IdsLike | AgentContainer | Collection[AgentContainer] | None + agents0 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None The starting agents - agents1 : IdsLike | AgentContainer | Collection[AgentContainer] | None + agents1 : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None The ending agents Returns @@ -1613,7 +1693,10 @@ def _compute_offsets(self, neighborhood_type: str) -> DataFrame: def _get_df_coords( self, pos: GridCoordinate | GridCoordinates | None = None, - agents: IdsLike | AgentContainer | Collection[AgentContainer] | None = None, + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry] + | None = None, check_bounds: bool = True, ) -> DataFrame: """Get the DataFrame of coordinates from the specified positions or agents. @@ -1622,7 +1705,7 @@ def _get_df_coords( ---------- pos : GridCoordinate | GridCoordinates | None, optional The positions to get the DataFrame from, by default None - agents : IdsLike | AgentContainer | Collection[AgentContainer] | None, optional + agents : IdsLike | AbstractAgentSetRegistry | Collection[AbstractAgentSetRegistry] | None, optional The agents to get the DataFrame from, by default None check_bounds: bool, optional If the positions should be checked for out-of-bounds in non-toroidal grids, by default True @@ -1652,7 +1735,7 @@ def _get_df_coords( if agents is not None: agents = self._get_ids_srs(agents) # Check ids presence in model - b_contained = self.model.agents.contains(agents) + b_contained = self.model.sets.contains(agents) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): @@ -1712,7 +1795,9 @@ def _get_df_coords( def _place_or_move_agents( self, - agents: IdsLike | AgentContainer | Collection[AgentContainer], + agents: IdsLike + | AbstractAgentSetRegistry + | Collection[AbstractAgentSetRegistry], pos: GridCoordinate | GridCoordinates, is_move: bool, ) -> Self: @@ -1728,7 +1813,7 @@ def _place_or_move_agents( warn("Some agents are already present in the grid", RuntimeWarning) # Check if agents are present in the model - b_contained = self.model.agents.contains(agents) + b_contained = self.model.sets.contains(agents) if (isinstance(b_contained, Series) and not b_contained.all()) or ( isinstance(b_contained, bool) and not b_contained ): diff --git a/mesa_frames/concrete/__init__.py b/mesa_frames/concrete/__init__.py index ebccc9e8..069fcf4b 100644 --- a/mesa_frames/concrete/__init__.py +++ b/mesa_frames/concrete/__init__.py @@ -13,15 +13,15 @@ polars: Contains Polars-based implementations of agent sets, mixins, and spatial structures. Modules: - agents: Defines the AgentsDF class, a collection of AgentSetDFs. - model: Provides the ModelDF class, the base class for models in mesa-frames. - agentset: Defines the AgentSetPolars class, a Polars-based implementation of AgentSet. + agents: Defines the AgentSetRegistry class, a collection of AgentSets. + model: Provides the Model class, the base class for models in mesa-frames. + agentset: Defines the AgentSet class, a Polars-based implementation of AgentSet. mixin: Provides the PolarsMixin class, implementing DataFrame operations using Polars. - space: Contains the GridPolars class, a Polars-based implementation of Grid. + space: Contains the Grid class, a Polars-based implementation of Grid. Classes: from agentset: - AgentSetPolars(AgentSetDF, PolarsMixin): + AgentSet(AbstractAgentSet, PolarsMixin): A Polars-based implementation of the AgentSet, using Polars DataFrames for efficient agent storage and manipulation. @@ -30,43 +30,43 @@ A mixin class that implements DataFrame operations using Polars, providing methods for data manipulation and analysis. from space: - GridPolars(GridDF, PolarsMixin): + Grid(AbstractGrid, PolarsMixin): A Polars-based implementation of Grid, using Polars DataFrames for efficient spatial operations and agent positioning. From agents: - AgentsDF(AgentContainer): A collection of AgentSetDFs. All agents of the model are stored here. + AgentSetRegistry(AbstractAgentSetRegistry): A collection of AbstractAgentSets. All agents of the model are stored here. From model: - ModelDF: Base class for models in the mesa-frames library. + Model: Base class for models in the mesa-frames library. Usage: Users can import the concrete implementations directly from this package: - from mesa_frames.concrete import ModelDF, AgentsDF + from mesa_frames.concrete import Model, AgentSetRegistry # For Polars-based implementations - from mesa_frames.concrete import AgentSetPolars, GridPolars - from mesa_frames.concrete.model import ModelDF + from mesa_frames.concrete import AgentSet, Grid + from mesa_frames.concrete.model import Model - class MyModel(ModelDF): + class MyModel(Model): def __init__(self): super().__init__() - self.agents.add(AgentSetPolars(self)) - self.space = GridPolars(self, dimensions=[10, 10]) + self.sets.add(AgentSet(self)) + self.space = Grid(self, dimensions=[10, 10]) # ... other initialization code - from mesa_frames.concrete import AgentSetPolars, GridPolars + from mesa_frames.concrete import AgentSet, Grid - class MyAgents(AgentSetPolars): + class MyAgents(AgentSet): def __init__(self, model): super().__init__(model) # Initialize agents - class MyModel(ModelDF): + class MyModel(Model): def __init__(self, width, height): super().__init__() - self.agents = MyAgents(self) - self.grid = GridPolars(width, height, self) + self.sets = MyAgents(self) + self.grid = Grid(width, height, self) Features: - High-performance DataFrame operations using Polars - Efficient memory usage and fast computation diff --git a/mesa_frames/concrete/agentset.py b/mesa_frames/concrete/agentset.py index 81759b19..5c64aef6 100644 --- a/mesa_frames/concrete/agentset.py +++ b/mesa_frames/concrete/agentset.py @@ -2,29 +2,29 @@ Polars-based implementation of AgentSet for mesa-frames. This module provides a concrete implementation of the AgentSet class using Polars -as the backend for DataFrame operations. It defines the AgentSetPolars class, -which combines the abstract AgentSetDF functionality with Polars-specific +as the backend for DataFrame operations. It defines the AgentSet class, +which combines the abstract AbstractAgentSet functionality with Polars-specific operations for efficient agent management and manipulation. Classes: - AgentSetPolars(AgentSetDF, PolarsMixin): + AgentSet(AbstractAgentSet, PolarsMixin): A Polars-based implementation of the AgentSet. This class uses Polars DataFrames to store and manipulate agent data, providing high-performance operations for large numbers of agents. -The AgentSetPolars class is designed to be used within ModelDF instances or as -part of an AgentsDF collection. It leverages the power of Polars for fast and +The AgentSet class is designed to be used within Model instances or as +part of an AgentSetRegistry collection. It leverages the power of Polars for fast and efficient data operations on agent attributes and behaviors. Usage: - The AgentSetPolars class can be used directly in a model or as part of an - AgentsDF collection: + The AgentSet class can be used directly in a model or as part of an + AgentSetRegistry collection: - from mesa_frames.concrete.model import ModelDF - from mesa_frames.concrete.agentset import AgentSetPolars + from mesa_frames.concrete.model import Model + from mesa_frames.concrete.agentset import AgentSet import polars as pl - class MyAgents(AgentSetPolars): + class MyAgents(AgentSet): def __init__(self, model): super().__init__(model) # Initialize with some agents @@ -32,15 +32,15 @@ def __init__(self, model): def step(self): # Implement step behavior using Polars operations - self.agents = self.agents.with_columns(new_wealth = pl.col('wealth') + 1) + self.sets = self.sets.with_columns(new_wealth = pl.col('wealth') + 1) - class MyModel(ModelDF): + class MyModel(Model): def __init__(self): super().__init__() - self.agents += MyAgents(self) + self.sets += MyAgents(self) def step(self): - self.agents.step() + self.sets.step() Features: - Efficient storage and manipulation of large agent populations @@ -53,7 +53,7 @@ def step(self): is installed and imported. The performance characteristics of this class will depend on the Polars version and the specific operations used. -For more detailed information on the AgentSetPolars class and its methods, +For more detailed information on the AgentSet class and its methods, refer to the class docstring. """ @@ -65,16 +65,15 @@ def step(self): import numpy as np import polars as pl -from mesa_frames.concrete.agents import AgentSetDF +from mesa_frames.abstract.agentset import AbstractAgentSet from mesa_frames.concrete.mixin import PolarsMixin -from mesa_frames.concrete.model import ModelDF from mesa_frames.types_ import AgentPolarsMask, IntoExpr, PolarsIdsLike from mesa_frames.utils import copydoc -@copydoc(AgentSetDF) -class AgentSetPolars(AgentSetDF, PolarsMixin): - """Polars-based implementation of AgentSetDF.""" +@copydoc(AbstractAgentSet) +class AgentSet(AbstractAgentSet, PolarsMixin): + """Polars-based implementation of AgentSet.""" _df: pl.DataFrame _copy_with_method: dict[str, tuple[str, list[str]]] = { @@ -83,12 +82,12 @@ class AgentSetPolars(AgentSetDF, PolarsMixin): _copy_only_reference: list[str] = ["_model", "_mask"] _mask: pl.Expr | pl.Series - def __init__(self, model: mesa_frames.concrete.model.ModelDF) -> None: - """Initialize a new AgentSetPolars. + def __init__(self, model: mesa_frames.concrete.model.Model) -> None: + """Initialize a new AgentSet. Parameters ---------- - model : "mesa_frames.concrete.model.ModelDF" + model : "mesa_frames.concrete.model.Model" The model that the agent set belongs to. """ self._model = model @@ -101,7 +100,7 @@ def add( agents: pl.DataFrame | Sequence[Any] | dict[str, Any], inplace: bool = True, ) -> Self: - """Add agents to the AgentSetPolars. + """Add agents to the AgentSet. Parameters ---------- @@ -113,12 +112,12 @@ def add( Returns ------- Self - The updated AgentSetPolars. + The updated AgentSet. """ obj = self._get_obj(inplace) - if isinstance(agents, AgentSetDF): + if isinstance(agents, AbstractAgentSet): raise TypeError( - "AgentSetPolars.add() does not accept AgentSetDF objects. " + "AgentSet.add() does not accept AgentSet objects. " "Extract the DataFrame with agents.agents.drop('unique_id') first." ) elif isinstance(agents, pl.DataFrame): @@ -314,7 +313,7 @@ def _concatenate_agentsets( all_indices = pl.concat(indices_list) if all_indices.is_duplicated().any(): raise ValueError( - "Some ids are duplicated in the AgentSetDFs that are trying to be concatenated" + "Some ids are duplicated in the AgentSets that are trying to be concatenated" ) if duplicates_allowed & keep_first_only: # Find the original_index list (ie longest index list), to sort correctly the rows after concatenation diff --git a/mesa_frames/concrete/agents.py b/mesa_frames/concrete/agentsetregistry.py similarity index 68% rename from mesa_frames/concrete/agents.py rename to mesa_frames/concrete/agentsetregistry.py index 799a7b33..b9ed1563 100644 --- a/mesa_frames/concrete/agents.py +++ b/mesa_frames/concrete/agentsetregistry.py @@ -2,45 +2,45 @@ Concrete implementation of the agents collection for mesa-frames. This module provides the concrete implementation of the agents collection class -for the mesa-frames library. It defines the AgentsDF class, which serves as a +for the mesa-frames library. It defines the AgentSetRegistry class, which serves as a container for all agent sets in a model, leveraging DataFrame-based storage for improved performance. Classes: - AgentsDF(AgentContainer): - A collection of AgentSetDFs. This class acts as a container for all - agents in the model, organizing them into separate AgentSetDF instances + AgentSetRegistry(AbstractAgentSetRegistry): + A collection of AgentSets. This class acts as a container for all + agents in the model, organizing them into separate AgentSet instances based on their types. -The AgentsDF class is designed to be used within ModelDF instances to manage +The AgentSetRegistry class is designed to be used within Model instances to manage all agents in the simulation. It provides methods for adding, removing, and accessing agents and agent sets, while taking advantage of the performance benefits of DataFrame-based agent storage. Usage: - The AgentsDF class is typically instantiated and used within a ModelDF subclass: + The AgentSetRegistry class is typically instantiated and used within a Model subclass: - from mesa_frames.concrete.model import ModelDF - from mesa_frames.concrete.agents import AgentsDF - from mesa_frames.concrete import AgentSetPolars + from mesa_frames.concrete.model import Model + from mesa_frames.concrete.agents import AgentSetRegistry + from mesa_frames.concrete import AgentSet - class MyCustomModel(ModelDF): + class MyCustomModel(Model): def __init__(self): super().__init__() # Adding agent sets to the collection - self.agents += AgentSetPolars(self) - self.agents += AnotherAgentSetPolars(self) + self.sets += AgentSet(self) + self.sets += AnotherAgentSet(self) def step(self): # Step all agent sets - self.agents.do("step") + self.sets.do("step") Note: - This concrete implementation builds upon the abstract AgentContainer class + This concrete implementation builds upon the abstract AgentSetRegistry class defined in the mesa_frames.abstract package, providing a ready-to-use agents collection that integrates with the DataFrame-based agent storage system. -For more detailed information on the AgentsDF class and its methods, refer to +For more detailed information on the AgentSetRegistry class and its methods, refer to the class docstring. """ @@ -53,7 +53,10 @@ def step(self): import numpy as np import polars as pl -from mesa_frames.abstract.agents import AgentContainer, AgentSetDF +from mesa_frames.abstract.agentsetregistry import ( + AbstractAgentSetRegistry, +) +from mesa_frames.concrete.agentset import AgentSet from mesa_frames.types_ import ( AgentMask, AgnosticAgentMask, @@ -65,50 +68,54 @@ def step(self): ) -class AgentsDF(AgentContainer): - """A collection of AgentSetDFs. All agents of the model are stored here.""" +class AgentSetRegistry(AbstractAgentSetRegistry): + """A collection of AgentSets. All agents of the model are stored here.""" - _agentsets: list[AgentSetDF] + _agentsets: list[AgentSet] _ids: pl.Series - def __init__(self, model: mesa_frames.concrete.model.ModelDF) -> None: - """Initialize a new AgentsDF. + def __init__(self, model: mesa_frames.concrete.model.Model) -> None: + """Initialize a new AgentSetRegistry. Parameters ---------- - model : mesa_frames.concrete.model.ModelDF - The model associated with the AgentsDF. + model : mesa_frames.concrete.model.Model + The model associated with the AgentSetRegistry. """ self._model = model self._agentsets = [] self._ids = pl.Series(name="unique_id", dtype=pl.UInt64) def add( - self, agents: AgentSetDF | Iterable[AgentSetDF], inplace: bool = True + self, + agents: AgentSet | Iterable[AgentSet], + inplace: bool = True, ) -> Self: - """Add an AgentSetDF to the AgentsDF. + """Add an AgentSet to the AgentSetRegistry. Parameters ---------- - agents : AgentSetDF | Iterable[AgentSetDF] - The AgentSetDFs to add. + agents : AgentSet | Iterable[AgentSet] + The AgentSets to add. inplace : bool, optional - Whether to add the AgentSetDFs in place. Defaults to True. + Whether to add the AgentSets in place. Defaults to True. Returns ------- Self - The updated AgentsDF. + The updated AgentSetRegistry. Raises ------ ValueError - If any AgentSetDFs are already present or if IDs are not unique. + If any AgentSets are already present or if IDs are not unique. """ obj = self._get_obj(inplace) other_list = obj._return_agentsets_list(agents) if obj._check_agentsets_presence(other_list).any(): - raise ValueError("Some agentsets are already present in the AgentsDF.") + raise ValueError( + "Some agentsets are already present in the AgentSetRegistry." + ) new_ids = pl.concat( [obj._ids] + [pl.Series(agentset["unique_id"]) for agentset in other_list] ) @@ -119,23 +126,23 @@ def add( return obj @overload - def contains(self, agents: int | AgentSetDF) -> bool: ... + def contains(self, agents: int | AgentSet) -> bool: ... @overload - def contains(self, agents: IdsLike | Iterable[AgentSetDF]) -> pl.Series: ... + def contains(self, agents: IdsLike | Iterable[AgentSet]) -> pl.Series: ... def contains( - self, agents: IdsLike | AgentSetDF | Iterable[AgentSetDF] + self, agents: IdsLike | AgentSet | Iterable[AgentSet] ) -> bool | pl.Series: if isinstance(agents, int): return agents in self._ids - elif isinstance(agents, AgentSetDF): + elif isinstance(agents, AgentSet): return self._check_agentsets_presence([agents]).any() elif isinstance(agents, Iterable): if len(agents) == 0: return True - elif isinstance(next(iter(agents)), AgentSetDF): - agents = cast(Iterable[AgentSetDF], agents) + elif isinstance(next(iter(agents)), AgentSet): + agents = cast(Iterable[AgentSet], agents) return self._check_agentsets_presence(list(agents)) else: # IdsLike agents = cast(IdsLike, agents) @@ -147,7 +154,7 @@ def do( self, method_name: str, *args, - mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, return_results: Literal[False] = False, inplace: bool = True, **kwargs, @@ -158,17 +165,17 @@ def do( self, method_name: str, *args, - mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, return_results: Literal[True], inplace: bool = True, **kwargs, - ) -> dict[AgentSetDF, Any]: ... + ) -> dict[AgentSet, Any]: ... def do( self, method_name: str, *args, - mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, return_results: bool = False, inplace: bool = True, **kwargs, @@ -204,8 +211,8 @@ def do( def get( self, attr_names: str | Collection[str] | None = None, - mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, - ) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]: + mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, + ) -> dict[AgentSet, Series] | dict[AgentSet, DataFrame]: agentsets_masks = self._get_bool_masks(mask) result = {} @@ -232,16 +239,16 @@ def get( def remove( self, - agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike, + agents: AgentSet | Iterable[AgentSet] | IdsLike, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) if agents is None or (isinstance(agents, Iterable) and len(agents) == 0): return obj - if isinstance(agents, AgentSetDF): + if isinstance(agents, AgentSet): agents = [agents] - if isinstance(agents, Iterable) and isinstance(next(iter(agents)), AgentSetDF): - # We have to get the index of the original AgentSetDF because the copy made AgentSetDFs with different hash + if isinstance(agents, Iterable) and isinstance(next(iter(agents)), AgentSet): + # We have to get the index of the original AgentSet because the copy made AgentSets with different hash ids = [self._agentsets.index(agentset) for agentset in iter(agents)] ids.sort(reverse=True) removed_ids = pl.Series(dtype=pl.UInt64) @@ -281,8 +288,8 @@ def remove( def select( self, - mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, - filter_func: Callable[[AgentSetDF], AgentMask] | None = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, + filter_func: Callable[[AgentSet], AgentMask] | None = None, n: int | None = None, inplace: bool = True, negate: bool = False, @@ -301,9 +308,9 @@ def select( def set( self, - attr_names: str | dict[AgentSetDF, Any] | Collection[str], + attr_names: str | dict[AgentSet, Any] | Collection[str], values: Any | None = None, - mask: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] = None, + mask: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] = None, inplace: bool = True, ) -> Self: obj = self._get_obj(inplace) @@ -311,7 +318,7 @@ def set( if isinstance(attr_names, dict): for agentset, values in attr_names.items(): if not inplace: - # We have to get the index of the original AgentSetDF because the copy made AgentSetDFs with different hash + # We have to get the index of the original AgentSet because the copy made AgentSets with different hash id = self._agentsets.index(agentset) agentset = obj._agentsets[id] agentset.set( @@ -346,12 +353,12 @@ def sort( return obj def step(self, inplace: bool = True) -> Self: - """Advance the state of the agents in the AgentsDF by one step. + """Advance the state of the agents in the AgentSetRegistry by one step. Parameters ---------- inplace : bool, optional - Whether to update the AgentsDF in place, by default True + Whether to update the AgentSetRegistry in place, by default True Returns ------- @@ -362,13 +369,13 @@ def step(self, inplace: bool = True) -> Self: agentset.step() return obj - def _check_ids_presence(self, other: list[AgentSetDF]) -> pl.DataFrame: + def _check_ids_presence(self, other: list[AgentSet]) -> pl.DataFrame: """Check if the IDs of the agents to be added are unique. Parameters ---------- - other : list[AgentSetDF] - The AgentSetDFs to check. + other : list[AgentSet] + The AgentSets to check. Returns ------- @@ -395,13 +402,13 @@ def _check_ids_presence(self, other: list[AgentSetDF]) -> pl.DataFrame: presence_df = presence_df.slice(self._ids.len()) return presence_df - def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series: - """Check if the agent sets to be added are already present in the AgentsDF. + def _check_agentsets_presence(self, other: list[AgentSet]) -> pl.Series: + """Check if the agent sets to be added are already present in the AgentSetRegistry. Parameters ---------- - other : list[AgentSetDF] - The AgentSetDFs to check. + other : list[AgentSet] + The AgentSets to check. Returns ------- @@ -411,7 +418,7 @@ def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series: Raises ------ ValueError - If the agent sets are already present in the AgentsDF. + If the agent sets are already present in the AgentSetRegistry. """ other_set = set(other) return pl.Series( @@ -420,8 +427,8 @@ def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series: def _get_bool_masks( self, - mask: (AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask]) = None, - ) -> dict[AgentSetDF, BoolSeries]: + mask: (AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask]) = None, + ) -> dict[AgentSet, BoolSeries]: return_dictionary = {} if not isinstance(mask, dict): # No need to convert numpy integers - let polars handle them directly @@ -431,36 +438,36 @@ def _get_bool_masks( return return_dictionary def _return_agentsets_list( - self, agentsets: AgentSetDF | Iterable[AgentSetDF] - ) -> list[AgentSetDF]: - """Convert the agentsets to a list of AgentSetDF. + self, agentsets: AgentSet | Iterable[AgentSet] + ) -> list[AgentSet]: + """Convert the agentsets to a list of AgentSet. Parameters ---------- - agentsets : AgentSetDF | Iterable[AgentSetDF] + agentsets : AgentSet | Iterable[AgentSet] Returns ------- - list[AgentSetDF] + list[AgentSet] """ - return [agentsets] if isinstance(agentsets, AgentSetDF) else list(agentsets) + return [agentsets] if isinstance(agentsets, AgentSet) else list(agentsets) - def __add__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self: - """Add AgentSetDFs to a new AgentsDF through the + operator. + def __add__(self, other: AgentSet | Iterable[AgentSet]) -> Self: + """Add AgentSets to a new AgentSetRegistry through the + operator. Parameters ---------- - other : AgentSetDF | Iterable[AgentSetDF] - The AgentSetDFs to add. + other : AgentSet | Iterable[AgentSet] + The AgentSets to add. Returns ------- Self - A new AgentsDF with the added AgentSetDFs. + A new AgentSetRegistry with the added AgentSets. """ return super().__add__(other) - def __getattr__(self, name: str) -> dict[AgentSetDF, Any]: + def __getattr__(self, name: str) -> dict[AgentSet, Any]: # Avoids infinite recursion of private attributes if __debug__: # Only execute in non-optimized mode if name.startswith("_"): @@ -471,8 +478,8 @@ def __getattr__(self, name: str) -> dict[AgentSetDF, Any]: @overload def __getitem__( - self, key: str | tuple[dict[AgentSetDF, AgentMask], str] - ) -> dict[AgentSetDF, Series | pl.Expr]: ... + self, key: str | tuple[dict[AgentSet, AgentMask], str] + ) -> dict[AgentSet, Series | pl.Expr]: ... @overload def __getitem__( @@ -481,9 +488,9 @@ def __getitem__( Collection[str] | AgnosticAgentMask | IdsLike - | tuple[dict[AgentSetDF, AgentMask], Collection[str]] + | tuple[dict[AgentSet, AgentMask], Collection[str]] ), - ) -> dict[AgentSetDF, DataFrame]: ... + ) -> dict[AgentSet, DataFrame]: ... def __getitem__( self, @@ -492,42 +499,42 @@ def __getitem__( | Collection[str] | AgnosticAgentMask | IdsLike - | tuple[dict[AgentSetDF, AgentMask], str] - | tuple[dict[AgentSetDF, AgentMask], Collection[str]] + | tuple[dict[AgentSet, AgentMask], str] + | tuple[dict[AgentSet, AgentMask], Collection[str]] ), - ) -> dict[AgentSetDF, Series | pl.Expr] | dict[AgentSetDF, DataFrame]: + ) -> dict[AgentSet, Series | pl.Expr] | dict[AgentSet, DataFrame]: return super().__getitem__(key) - def __iadd__(self, agents: AgentSetDF | Iterable[AgentSetDF]) -> Self: - """Add AgentSetDFs to the AgentsDF through the += operator. + def __iadd__(self, agents: AgentSet | Iterable[AgentSet]) -> Self: + """Add AgentSets to the AgentSetRegistry through the += operator. Parameters ---------- - agents : AgentSetDF | Iterable[AgentSetDF] - The AgentSetDFs to add. + agents : AgentSet | Iterable[AgentSet] + The AgentSets to add. Returns ------- Self - The updated AgentsDF. + The updated AgentSetRegistry. """ return super().__iadd__(agents) def __iter__(self) -> Iterator[dict[str, Any]]: return (agent for agentset in self._agentsets for agent in iter(agentset)) - def __isub__(self, agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike) -> Self: - """Remove AgentSetDFs from the AgentsDF through the -= operator. + def __isub__(self, agents: AgentSet | Iterable[AgentSet] | IdsLike) -> Self: + """Remove AgentSets from the AgentSetRegistry through the -= operator. Parameters ---------- - agents : AgentSetDF | Iterable[AgentSetDF] | IdsLike - The AgentSetDFs or agent IDs to remove. + agents : AgentSet | Iterable[AgentSet] | IdsLike + The AgentSets or agent IDs to remove. Returns ------- Self - The updated AgentsDF. + The updated AgentSetRegistry. """ return super().__isub__(agents) @@ -551,8 +558,8 @@ def __setitem__( | Collection[str] | AgnosticAgentMask | IdsLike - | tuple[dict[AgentSetDF, AgentMask], str] - | tuple[dict[AgentSetDF, AgentMask], Collection[str]] + | tuple[dict[AgentSet, AgentMask], str] + | tuple[dict[AgentSet, AgentMask], Collection[str]] ), values: Any, ) -> None: @@ -561,54 +568,54 @@ def __setitem__( def __str__(self) -> str: return "\n".join([str(agentset) for agentset in self._agentsets]) - def __sub__(self, agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike) -> Self: - """Remove AgentSetDFs from a new AgentsDF through the - operator. + def __sub__(self, agents: AgentSet | Iterable[AgentSet] | IdsLike) -> Self: + """Remove AgentSets from a new AgentSetRegistry through the - operator. Parameters ---------- - agents : AgentSetDF | Iterable[AgentSetDF] | IdsLike - The AgentSetDFs or agent IDs to remove. Supports NumPy integer types. + agents : AgentSet | Iterable[AgentSet] | IdsLike + The AgentSets or agent IDs to remove. Supports NumPy integer types. Returns ------- Self - A new AgentsDF with the removed AgentSetDFs. + A new AgentSetRegistry with the removed AgentSets. """ return super().__sub__(agents) @property - def df(self) -> dict[AgentSetDF, DataFrame]: + def df(self) -> dict[AgentSet, DataFrame]: return {agentset: agentset.df for agentset in self._agentsets} @df.setter - def df(self, other: Iterable[AgentSetDF]) -> None: - """Set the agents in the AgentsDF. + def df(self, other: Iterable[AgentSet]) -> None: + """Set the agents in the AgentSetRegistry. Parameters ---------- - other : Iterable[AgentSetDF] - The AgentSetDFs to set. + other : Iterable[AgentSet] + The AgentSets to set. """ self._agentsets = list(other) @property - def active_agents(self) -> dict[AgentSetDF, DataFrame]: + def active_agents(self) -> dict[AgentSet, DataFrame]: return {agentset: agentset.active_agents for agentset in self._agentsets} @active_agents.setter def active_agents( - self, agents: AgnosticAgentMask | IdsLike | dict[AgentSetDF, AgentMask] + self, agents: AgnosticAgentMask | IdsLike | dict[AgentSet, AgentMask] ) -> None: self.select(agents, inplace=True) @property - def agentsets_by_type(self) -> dict[type[AgentSetDF], Self]: - """Get the agent sets in the AgentsDF grouped by type. + def agentsets_by_type(self) -> dict[type[AgentSet], Self]: + """Get the agent sets in the AgentSetRegistry grouped by type. Returns ------- - dict[type[AgentSetDF], Self] - A dictionary mapping agent set types to the corresponding AgentsDF. + dict[type[AgentSet], Self] + A dictionary mapping agent set types to the corresponding AgentSetRegistry. """ def copy_without_agentsets() -> Self: @@ -624,13 +631,13 @@ def copy_without_agentsets() -> Self: return dictionary @property - def inactive_agents(self) -> dict[AgentSetDF, DataFrame]: + def inactive_agents(self) -> dict[AgentSet, DataFrame]: return {agentset: agentset.inactive_agents for agentset in self._agentsets} @property - def index(self) -> dict[AgentSetDF, Index]: + def index(self) -> dict[AgentSet, Index]: return {agentset: agentset.index for agentset in self._agentsets} @property - def pos(self) -> dict[AgentSetDF, DataFrame]: + def pos(self) -> dict[AgentSet, DataFrame]: return {agentset: agentset.pos for agentset in self._agentsets} diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 02e40423..2b50c76d 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -26,18 +26,18 @@ If true, data is collected during `conditional_collect()`. Usage: - The `DataCollector` class is designed to be used within a `ModelDF` instance + The `DataCollector` class is designed to be used within a `Model` instance to collect model-level and/or agent-level data. Example: -------- - from mesa_frames.concrete.model import ModelDF + from mesa_frames.concrete.model import Model from mesa_frames.concrete.datacollector import DataCollector - class ExampleModel(ModelDF): - def __init__(self, agents: AgentsDF): + class ExampleModel(Model): + def __init__(self, agents: AgentSetRegistry): super().__init__() - self.agents = agents + self.sets = agents self.dc = DataCollector( model=self, # other required arguments @@ -62,14 +62,14 @@ def step(self): from mesa_frames.abstract.datacollector import AbstractDataCollector from typing import Any, Literal from collections.abc import Callable -from mesa_frames import ModelDF +from mesa_frames import Model from psycopg2.extensions import connection class DataCollector(AbstractDataCollector): def __init__( self, - model: ModelDF, + model: Model, model_reporters: dict[str, Callable] | None = None, agent_reporters: dict[str, str | Callable] | None = None, trigger: Callable[[Any], bool] | None = None, @@ -86,7 +86,7 @@ def __init__( Parameters ---------- - model : ModelDF + model : Model The model object from which data is collected. model_reporters : dict[str, Callable] | None Functions to collect data at the model level. @@ -180,7 +180,7 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int): agent_data_dict = {} for col_name, reporter in self._agent_reporters.items(): if isinstance(reporter, str): - for k, v in self._model.agents[reporter].items(): + for k, v in self._model.sets[reporter].items(): agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v else: agent_data_dict[col_name] = reporter(self._model) @@ -463,7 +463,7 @@ def _validate_reporter_table_columns( expected_columns = set() for col_name, required_column in reporter.items(): if isinstance(required_column, str): - for k, v in self._model.agents[required_column].items(): + for k, v in self._model.sets[required_column].items(): expected_columns.add( (col_name + "_" + str(k.__class__.__name__)).lower() ) diff --git a/mesa_frames/concrete/mixin.py b/mesa_frames/concrete/mixin.py index eba00ae6..4900536e 100644 --- a/mesa_frames/concrete/mixin.py +++ b/mesa_frames/concrete/mixin.py @@ -10,7 +10,7 @@ PolarsMixin(DataFrameMixin): A Polars-based implementation of DataFrame operations. This class provides methods for manipulating and analyzing data stored in Polars DataFrames, - tailored for use in mesa-frames components like AgentSetPolars and GridPolars. + tailored for use in mesa-frames components like AgentSet and Grid. The PolarsMixin class is designed to be used as a mixin with other mesa-frames classes, providing them with Polars-specific DataFrame functionality. It implements @@ -20,17 +20,17 @@ Usage: The PolarsMixin is typically used in combination with other base classes: - from mesa_frames.abstract import AgentSetDF + from mesa_frames.abstract import AbstractAgentSet from mesa_frames.concrete.mixin import PolarsMixin - class AgentSetPolars(AgentSetDF, PolarsMixin): + class AgentSet(AgentSet, PolarsMixin): def __init__(self, model): super().__init__(model) - self.agents = pl.DataFrame() # Initialize empty DataFrame + self.sets = pl.DataFrame() # Initialize empty DataFrame def some_method(self): # Use Polars operations provided by the mixin - result = self._df_groupby(self.agents, 'some_column') + result = self._df_groupby(self.sets, 'some_column') # ... further processing ... Features: diff --git a/mesa_frames/concrete/model.py b/mesa_frames/concrete/model.py index befc1812..a10ce240 100644 --- a/mesa_frames/concrete/model.py +++ b/mesa_frames/concrete/model.py @@ -2,31 +2,31 @@ Concrete implementation of the model class for mesa-frames. This module provides the concrete implementation of the base model class for -the mesa-frames library. It defines the ModelDF class, which serves as the +the mesa-frames library. It defines the Model class, which serves as the foundation for creating agent-based models using DataFrame-based agent storage. Classes: - ModelDF: + Model: The base class for models in the mesa-frames library. This class provides the core functionality for initializing and running agent-based simulations using DataFrame-backed agent sets. -The ModelDF class is designed to be subclassed by users to create specific +The Model class is designed to be subclassed by users to create specific model implementations. It provides the basic structure and methods necessary for setting up and running simulations, while leveraging the performance benefits of DataFrame-based agent storage. Usage: - To create a custom model, subclass ModelDF and implement the necessary + To create a custom model, subclass Model and implement the necessary methods: - from mesa_frames.concrete.model import ModelDF - from mesa_frames.concrete.agents import AgentSetPolars + from mesa_frames.concrete.model import Model + from mesa_frames.concrete.agentset import AgentSet - class MyCustomModel(ModelDF): + class MyCustomModel(Model): def __init__(self, num_agents): super().__init__() - self.agents += AgentSetPolars(self) + self.sets += AgentSet(self) # Initialize your model-specific attributes and agent sets def run_model(self): @@ -36,7 +36,7 @@ def run_model(self): # Add any other custom methods for your model -For more detailed information on the ModelDF class and its methods, refer to +For more detailed information on the Model class and its methods, refer to the class docstring. """ @@ -46,12 +46,12 @@ def run_model(self): import numpy as np -from mesa_frames.abstract.agents import AgentSetDF -from mesa_frames.abstract.space import SpaceDF -from mesa_frames.concrete.agents import AgentsDF +from mesa_frames.concrete.agentset import AgentSet +from mesa_frames.abstract.space import Space +from mesa_frames.concrete.agentsetregistry import AgentSetRegistry -class ModelDF: +class Model: """Base class for models in the mesa-frames library. This class serves as a foundational structure for creating agent-based models. @@ -63,8 +63,8 @@ class ModelDF: random: np.random.Generator running: bool _seed: int | Sequence[int] - _agents: AgentsDF # Where the agents are stored - _space: SpaceDF | None # This will be a MultiSpaceDF object + _sets: AgentSetRegistry # Where the agent sets are stored + _space: Space | None # This will be a MultiSpaceDF object def __init__(self, seed: int | Sequence[int] | None = None) -> None: """Create a new model. @@ -82,7 +82,7 @@ def __init__(self, seed: int | Sequence[int] | None = None) -> None: self.reset_randomizer(seed) self.running = True self.current_id = 0 - self._agents = AgentsDF(self) + self._sets = AgentSetRegistry(self) self._space = None self._steps = 0 @@ -99,23 +99,23 @@ def steps(self) -> int: """Get the current step count.""" return self._steps - def get_agents_of_type(self, agent_type: type) -> AgentSetDF: - """Retrieve the AgentSetDF of a specified type. + def get_sets_of_type(self, agent_type: type) -> AgentSet: + """Retrieve the AgentSet of a specified type. Parameters ---------- agent_type : type - The type of AgentSetDF to retrieve. + The type of AgentSet to retrieve. Returns ------- - AgentSetDF - The AgentSetDF of the specified type. + AgentSet + The AgentSet of the specified type. """ - for agentset in self._agents._agentsets: + for agentset in self._sets._agentsets: if isinstance(agentset, agent_type): return agentset - raise ValueError(f"No agents of type {agent_type} found in the model.") + raise ValueError(f"No agent sets of type {agent_type} found in the model.") def reset_randomizer(self, seed: int | Sequence[int] | None) -> None: """Reset the model random number generator. @@ -144,7 +144,7 @@ def step(self) -> None: The default method calls the step() method of all agents. Overload as needed. """ - self.agents.step() + self.sets.step() @property def steps(self) -> int: @@ -158,13 +158,13 @@ def steps(self) -> int: return self._steps @property - def agents(self) -> AgentsDF: - """Get the AgentsDF object containing all agents in the model. + def sets(self) -> AgentSetRegistry: + """Get the AgentSetRegistry object containing all agent sets in the model. Returns ------- - AgentsDF - The AgentsDF object containing all agents in the model. + AgentSetRegistry + The AgentSetRegistry object containing all agent sets in the model. Raises ------ @@ -172,39 +172,39 @@ def agents(self) -> AgentsDF: If the model has not been initialized properly with super().__init__(). """ try: - return self._agents + return self._sets except AttributeError: if __debug__: # Only execute in non-optimized mode raise RuntimeError( "You haven't called super().__init__() in your model. Make sure to call it in your __init__ method." ) - @agents.setter - def agents(self, agents: AgentsDF) -> None: + @sets.setter + def sets(self, sets: AgentSetRegistry) -> None: if __debug__: # Only execute in non-optimized mode - if not isinstance(agents, AgentsDF): - raise TypeError("agents must be an instance of AgentsDF") + if not isinstance(sets, AgentSetRegistry): + raise TypeError("sets must be an instance of AgentSetRegistry") - self._agents = agents + self._sets = sets @property - def agent_types(self) -> list[type]: - """Get a list of different agent types present in the model. + def set_types(self) -> list[type]: + """Get a list of different agent set types present in the model. Returns ------- list[type] - A list of the different agent types present in the model. + A list of the different agent set types present in the model. """ - return [agent.__class__ for agent in self._agents._agentsets] + return [agent.__class__ for agent in self._sets._agentsets] @property - def space(self) -> SpaceDF: + def space(self) -> Space: """Get the space object associated with the model. Returns ------- - SpaceDF + Space The space object associated with the model. Raises @@ -219,11 +219,11 @@ def space(self) -> SpaceDF: return self._space @space.setter - def space(self, space: SpaceDF) -> None: + def space(self, space: Space) -> None: """Set the space of the model. Parameters ---------- - space : SpaceDF + space : Space """ self._space = space diff --git a/mesa_frames/concrete/space.py b/mesa_frames/concrete/space.py index 55a00589..4f55a680 100644 --- a/mesa_frames/concrete/space.py +++ b/mesa_frames/concrete/space.py @@ -2,43 +2,43 @@ Polars-based implementation of spatial structures for mesa-frames. This module provides concrete implementations of spatial structures using Polars -as the backend for DataFrame operations. It defines the GridPolars class, which +as the backend for DataFrame operations. It defines the Grid class, which implements a 2D grid structure using Polars DataFrames for efficient spatial operations and agent positioning. Classes: - GridPolars(GridDF, PolarsMixin): + Grid(AbstractGrid, PolarsMixin): A Polars-based implementation of a 2D grid. This class uses Polars DataFrames to store and manipulate spatial data, providing high-performance operations for large-scale spatial simulations. -The GridPolars class is designed to be used within ModelDF instances to represent +The Grid class is designed to be used within Model instances to represent the spatial environment of the simulation. It leverages the power of Polars for fast and efficient data operations on spatial attributes and agent positions. Usage: - The GridPolars class can be used directly in a model to represent the + The Grid class can be used directly in a model to represent the spatial environment: - from mesa_frames.concrete.model import ModelDF - from mesa_frames.concrete.space import GridPolars - from mesa_frames.concrete.agentset import AgentSetPolars + from mesa_frames.concrete.model import Model + from mesa_frames.concrete.space import Grid + from mesa_frames.concrete.agentset import AgentSet - class MyAgents(AgentSetPolars): + class MyAgents(AgentSet): # ... agent implementation ... - class MyModel(ModelDF): + class MyModel(Model): def __init__(self, width, height): super().__init__() - self.space = GridPolars(self, [width, height]) - self.agents += MyAgents(self) + self.space = Grid(self, [width, height]) + self.sets += MyAgents(self) def step(self): # Move agents - self.space.move_agents(self.agents) + self.space.move_agents(self.sets) # ... other model logic ... -For more detailed information on the GridPolars class and its methods, +For more detailed information on the Grid class and its methods, refer to the class docstring. """ @@ -49,15 +49,15 @@ def step(self): import numpy as np import polars as pl -from mesa_frames.abstract.space import GridDF +from mesa_frames.abstract.space import AbstractGrid from mesa_frames.concrete.mixin import PolarsMixin from mesa_frames.types_ import Infinity from mesa_frames.utils import copydoc -@copydoc(GridDF) -class GridPolars(GridDF, PolarsMixin): - """Polars-based implementation of GridDF.""" +@copydoc(AbstractGrid) +class Grid(AbstractGrid, PolarsMixin): + """Polars-based implementation of AbstractGrid.""" _agents: pl.DataFrame _copy_with_method: dict[str, tuple[str, list[str]]] = { diff --git a/tests/test_agents.py b/tests/test_agents.py index 414bb632..f43d94f6 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -3,34 +3,34 @@ import polars as pl import pytest -from mesa_frames import AgentsDF, ModelDF -from mesa_frames.abstract.agents import AgentSetDF +from mesa_frames import AgentSetRegistry, Model +from mesa_frames import AgentSet from mesa_frames.types_ import AgentMask from tests.test_agentset import ( - ExampleAgentSetPolars, - ExampleAgentSetPolarsNoWealth, - fix1_AgentSetPolars_no_wealth, - fix1_AgentSetPolars, - fix2_AgentSetPolars, - fix3_AgentSetPolars, + ExampleAgentSet, + ExampleAgentSetNoWealth, + fix1_AgentSet_no_wealth, + fix1_AgentSet, + fix2_AgentSet, + fix3_AgentSet, ) @pytest.fixture -def fix_AgentsDF( - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, -) -> AgentsDF: - model = ModelDF() - agents = AgentsDF(model) - agents.add([fix1_AgentSetPolars, fix2_AgentSetPolars]) +def fix_AgentSetRegistry( + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, +) -> AgentSetRegistry: + model = Model() + agents = AgentSetRegistry(model) + agents.add([fix1_AgentSet, fix2_AgentSet]) return agents -class Test_AgentsDF: +class Test_AgentSetRegistry: def test___init__(self): - model = ModelDF() - agents = AgentsDF(model) + model = Model() + agents = AgentSetRegistry(model) assert agents.model == model assert isinstance(agents._agentsets, list) assert len(agents._agentsets) == 0 @@ -40,20 +40,20 @@ def test___init__(self): def test_add( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - model = ModelDF() - agents = AgentsDF(model) - agentset_polars1 = fix1_AgentSetPolars - agentset_polars2 = fix2_AgentSetPolars + model = Model() + agents = AgentSetRegistry(model) + agentset_polars1 = fix1_AgentSet + agentset_polars2 = fix2_AgentSet - # Test with a single AgentSetPolars + # Test with a single AgentSet result = agents.add(agentset_polars1, inplace=False) assert result._agentsets[0] is agentset_polars1 assert result._ids.to_list() == agentset_polars1._df["unique_id"].to_list() - # Test with a list of AgentSetDFs + # Test with a list of AgentSets result = agents.add([agentset_polars1, agentset_polars2], inplace=True) assert result._agentsets[0] is agentset_polars1 assert result._agentsets[1] is agentset_polars2 @@ -63,30 +63,30 @@ def test_add( + agentset_polars2._df["unique_id"].to_list() ) - # Test if adding the same AgentSetDF raises ValueError + # Test if adding the same AgentSet raises ValueError with pytest.raises(ValueError): agents.add(agentset_polars1, inplace=False) def test_contains( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, - fix3_AgentSetPolars: ExampleAgentSetPolars, - fix_AgentsDF: AgentsDF, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, + fix3_AgentSet: ExampleAgentSet, + fix_AgentSetRegistry: AgentSetRegistry, ): - agents = fix_AgentsDF + agents = fix_AgentSetRegistry agentset_polars1 = agents._agentsets[0] - # Test with an AgentSetDF + # Test with an AgentSet assert agents.contains(agentset_polars1) - assert agents.contains(fix1_AgentSetPolars) - assert agents.contains(fix2_AgentSetPolars) + assert agents.contains(fix1_AgentSet) + assert agents.contains(fix2_AgentSet) - # Test with an AgentSetDF not present - assert not agents.contains(fix3_AgentSetPolars) + # Test with an AgentSet not present + assert not agents.contains(fix3_AgentSet) - # Test with an iterable of AgentSetDFs - assert agents.contains([agentset_polars1, fix3_AgentSetPolars]).to_list() == [ + # Test with an iterable of AgentSets + assert agents.contains([agentset_polars1, fix3_AgentSet]).to_list() == [ True, False, ] @@ -100,8 +100,8 @@ def test_contains( False, ] - def test_copy(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_copy(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry agents.test_list = [[1, 2, 3]] # Test with deep=False @@ -113,7 +113,7 @@ def test_copy(self, fix_AgentsDF: AgentsDF): assert (agents._ids == agents2._ids).all() # Test with deep=True - agents2 = fix_AgentsDF.copy(deep=True) + agents2 = fix_AgentSetRegistry.copy(deep=True) agents2.test_list[0].append(4) assert agents.test_list[-1] != agents2.test_list[-1] assert agents.model == agents2.model @@ -121,16 +121,16 @@ def test_copy(self, fix_AgentsDF: AgentsDF): assert (agents._ids == agents2._ids).all() def test_discard( - self, fix_AgentsDF: AgentsDF, fix2_AgentSetPolars: ExampleAgentSetPolars + self, fix_AgentSetRegistry: AgentSetRegistry, fix2_AgentSet: ExampleAgentSet ): - agents = fix_AgentsDF - # Test with a single AgentSetDF + agents = fix_AgentSetRegistry + # Test with a single AgentSet agentset_polars2 = agents._agentsets[1] result = agents.discard(agents._agentsets[0], inplace=False) - assert isinstance(result._agentsets[0], ExampleAgentSetPolars) + assert isinstance(result._agentsets[0], ExampleAgentSet) assert len(result._agentsets) == 1 - # Test with a list of AgentSetDFs + # Test with a list of AgentSets result = agents.discard(agents._agentsets.copy(), inplace=False) assert len(result._agentsets) == 0 @@ -151,15 +151,15 @@ def test_discard( == agentset_polars2._df["unique_id"][1] ) - # Test if removing an AgentSetDF not present raises ValueError - result = agents.discard(fix2_AgentSetPolars, inplace=False) + # Test if removing an AgentSet not present raises ValueError + result = agents.discard(fix2_AgentSet, inplace=False) # Test if removing an ID not present raises KeyError assert 0 not in agents._ids result = agents.discard(0, inplace=False) - def test_do(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_do(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry expected_result_0 = agents._agentsets[0].df["wealth"] expected_result_0 += 1 @@ -212,88 +212,84 @@ def test_do(self, fix_AgentsDF: AgentsDF): def test_get( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, - fix1_AgentSetPolars_no_wealth: ExampleAgentSetPolarsNoWealth, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, + fix1_AgentSet_no_wealth: ExampleAgentSetNoWealth, ): - agents = fix_AgentsDF + agents = fix_AgentSetRegistry # Test with a single attribute assert ( - agents.get("wealth")[fix1_AgentSetPolars].to_list() - == fix1_AgentSetPolars._df["wealth"].to_list() + agents.get("wealth")[fix1_AgentSet].to_list() + == fix1_AgentSet._df["wealth"].to_list() ) assert ( - agents.get("wealth")[fix2_AgentSetPolars].to_list() - == fix2_AgentSetPolars._df["wealth"].to_list() + agents.get("wealth")[fix2_AgentSet].to_list() + == fix2_AgentSet._df["wealth"].to_list() ) # Test with a list of attributes result = agents.get(["wealth", "age"]) - assert result[fix1_AgentSetPolars].columns == ["wealth", "age"] + assert result[fix1_AgentSet].columns == ["wealth", "age"] assert ( - result[fix1_AgentSetPolars]["wealth"].to_list() - == fix1_AgentSetPolars._df["wealth"].to_list() + result[fix1_AgentSet]["wealth"].to_list() + == fix1_AgentSet._df["wealth"].to_list() ) assert ( - result[fix1_AgentSetPolars]["age"].to_list() - == fix1_AgentSetPolars._df["age"].to_list() + result[fix1_AgentSet]["age"].to_list() == fix1_AgentSet._df["age"].to_list() ) - assert result[fix2_AgentSetPolars].columns == ["wealth", "age"] + assert result[fix2_AgentSet].columns == ["wealth", "age"] assert ( - result[fix2_AgentSetPolars]["wealth"].to_list() - == fix2_AgentSetPolars._df["wealth"].to_list() + result[fix2_AgentSet]["wealth"].to_list() + == fix2_AgentSet._df["wealth"].to_list() ) assert ( - result[fix2_AgentSetPolars]["age"].to_list() - == fix2_AgentSetPolars._df["age"].to_list() + result[fix2_AgentSet]["age"].to_list() == fix2_AgentSet._df["age"].to_list() ) # Test with a single attribute and a mask - mask0 = fix1_AgentSetPolars._df["wealth"] > fix1_AgentSetPolars._df["wealth"][0] - mask1 = fix2_AgentSetPolars._df["wealth"] > fix2_AgentSetPolars._df["wealth"][0] - mask_dictionary = {fix1_AgentSetPolars: mask0, fix2_AgentSetPolars: mask1} + mask0 = fix1_AgentSet._df["wealth"] > fix1_AgentSet._df["wealth"][0] + mask1 = fix2_AgentSet._df["wealth"] > fix2_AgentSet._df["wealth"][0] + mask_dictionary = {fix1_AgentSet: mask0, fix2_AgentSet: mask1} result = agents.get("wealth", mask=mask_dictionary) assert ( - result[fix1_AgentSetPolars].to_list() - == fix1_AgentSetPolars._df["wealth"].to_list()[1:] + result[fix1_AgentSet].to_list() == fix1_AgentSet._df["wealth"].to_list()[1:] ) assert ( - result[fix2_AgentSetPolars].to_list() - == fix2_AgentSetPolars._df["wealth"].to_list()[1:] + result[fix2_AgentSet].to_list() == fix2_AgentSet._df["wealth"].to_list()[1:] ) # Test heterogeneous agent sets (different columns) # This tests the fix for the bug where agents_df["column"] would raise # ColumnNotFoundError when some agent sets didn't have that column. - # Create a new AgentsDF with heterogeneous agent sets - model = ModelDF() - hetero_agents = AgentsDF(model) - hetero_agents.add([fix1_AgentSetPolars, fix1_AgentSetPolars_no_wealth]) + # Create a new AgentSetRegistry with heterogeneous agent sets + model = Model() + hetero_agents = AgentSetRegistry(model) + hetero_agents.add([fix1_AgentSet, fix1_AgentSet_no_wealth]) # Test 1: Access column that exists in only one agent set result_wealth = hetero_agents.get("wealth") assert len(result_wealth) == 1, ( "Should only return agent sets that have 'wealth'" ) - assert fix1_AgentSetPolars in result_wealth, ( + assert fix1_AgentSet in result_wealth, ( "Should include the agent set with wealth" ) - assert fix1_AgentSetPolars_no_wealth not in result_wealth, ( + assert fix1_AgentSet_no_wealth not in result_wealth, ( "Should not include agent set without wealth" ) - assert result_wealth[fix1_AgentSetPolars].to_list() == [1, 2, 3, 4] + assert result_wealth[fix1_AgentSet].to_list() == [1, 2, 3, 4] # Test 2: Access column that exists in all agent sets result_age = hetero_agents.get("age") assert len(result_age) == 2, "Should return both agent sets that have 'age'" - assert fix1_AgentSetPolars in result_age - assert fix1_AgentSetPolars_no_wealth in result_age - assert result_age[fix1_AgentSetPolars].to_list() == [10, 20, 30, 40] - assert result_age[fix1_AgentSetPolars_no_wealth].to_list() == [1, 2, 3, 4] + assert fix1_AgentSet in result_age + assert fix1_AgentSet_no_wealth in result_age + assert result_age[fix1_AgentSet].to_list() == [10, 20, 30, 40] + assert result_age[fix1_AgentSet_no_wealth].to_list() == [1, 2, 3, 4] # Test 3: Access column that exists in no agent sets result_nonexistent = hetero_agents.get("nonexistent_column") @@ -306,41 +302,41 @@ def test_get( assert len(result_multi) == 1, ( "Should only include agent sets that have ALL requested columns" ) - assert fix1_AgentSetPolars in result_multi - assert fix1_AgentSetPolars_no_wealth not in result_multi - assert result_multi[fix1_AgentSetPolars].columns == ["wealth", "age"] + assert fix1_AgentSet in result_multi + assert fix1_AgentSet_no_wealth not in result_multi + assert result_multi[fix1_AgentSet].columns == ["wealth", "age"] # Test 5: Access multiple columns where some exist in different sets result_mixed = hetero_agents.get(["age", "income"]) assert len(result_mixed) == 1, ( "Should only include agent set that has both 'age' and 'income'" ) - assert fix1_AgentSetPolars_no_wealth in result_mixed - assert fix1_AgentSetPolars not in result_mixed + assert fix1_AgentSet_no_wealth in result_mixed + assert fix1_AgentSet not in result_mixed # Test 6: Test via __getitem__ syntax (the original bug report case) wealth_via_getitem = hetero_agents["wealth"] assert len(wealth_via_getitem) == 1 - assert fix1_AgentSetPolars in wealth_via_getitem - assert wealth_via_getitem[fix1_AgentSetPolars].to_list() == [1, 2, 3, 4] + assert fix1_AgentSet in wealth_via_getitem + assert wealth_via_getitem[fix1_AgentSet].to_list() == [1, 2, 3, 4] # Test 7: Test get(None) - should return all columns for all agent sets result_none = hetero_agents.get(None) assert len(result_none) == 2, ( "Should return both agent sets when attr_names=None" ) - assert fix1_AgentSetPolars in result_none - assert fix1_AgentSetPolars_no_wealth in result_none + assert fix1_AgentSet in result_none + assert fix1_AgentSet_no_wealth in result_none # Verify each agent set returns all its columns (excluding unique_id) - wealth_set_result = result_none[fix1_AgentSetPolars] + wealth_set_result = result_none[fix1_AgentSet] assert isinstance(wealth_set_result, pl.DataFrame), ( "Should return DataFrame when attr_names=None" ) expected_wealth_cols = {"wealth", "age"} # unique_id should be excluded assert set(wealth_set_result.columns) == expected_wealth_cols - no_wealth_set_result = result_none[fix1_AgentSetPolars_no_wealth] + no_wealth_set_result = result_none[fix1_AgentSet_no_wealth] assert isinstance(no_wealth_set_result, pl.DataFrame), ( "Should return DataFrame when attr_names=None" ) @@ -349,18 +345,18 @@ def test_get( def test_remove( self, - fix_AgentsDF: AgentsDF, - fix3_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix3_AgentSet: ExampleAgentSet, ): - agents = fix_AgentsDF + agents = fix_AgentSetRegistry - # Test with a single AgentSetDF + # Test with a single AgentSet agentset_polars = agents._agentsets[1] result = agents.remove(agents._agentsets[0], inplace=False) - assert isinstance(result._agentsets[0], ExampleAgentSetPolars) + assert isinstance(result._agentsets[0], ExampleAgentSet) assert len(result._agentsets) == 1 - # Test with a list of AgentSetDFs + # Test with a list of AgentSets result = agents.remove(agents._agentsets.copy(), inplace=False) assert len(result._agentsets) == 0 @@ -381,17 +377,17 @@ def test_remove( == agentset_polars2._df["unique_id"][1] ) - # Test if removing an AgentSetDF not present raises ValueError + # Test if removing an AgentSet not present raises ValueError with pytest.raises(ValueError): - result = agents.remove(fix3_AgentSetPolars, inplace=False) + result = agents.remove(fix3_AgentSet, inplace=False) # Test if removing an ID not present raises KeyError assert 0 not in agents._ids with pytest.raises(KeyError): result = agents.remove(0, inplace=False) - def test_select(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_select(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry # Test with default arguments. Should select all agents selected = agents.select(inplace=False) @@ -437,7 +433,7 @@ def test_select(self, fix_AgentsDF: AgentsDF): # Test with filter_func - def filter_func(agentset: AgentSetDF) -> pl.Series: + def filter_func(agentset: AgentSet) -> pl.Series: return agentset.df["wealth"] > agentset.df["wealth"].to_list()[0] selected = agents.select(filter_func=filter_func, inplace=False) @@ -472,8 +468,8 @@ def filter_func(agentset: AgentSetDF) -> pl.Series: ] ) - def test_set(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_set(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry # Test with a single attribute result = agents.set("wealth", 0, inplace=False) @@ -521,8 +517,8 @@ def test_set(self, fix_AgentsDF: AgentsDF): agents._agentsets[1] ) - def test_shuffle(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_shuffle(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry for _ in range(100): original_order_0 = agents._agentsets[0].df["unique_id"].to_list() original_order_1 = agents._agentsets[1].df["unique_id"].to_list() @@ -534,22 +530,22 @@ def test_shuffle(self, fix_AgentsDF: AgentsDF): return assert False - def test_sort(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_sort(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry agents.sort("wealth", ascending=False, inplace=True) assert pl.Series(agents._agentsets[0].df["wealth"]).is_sorted(descending=True) assert pl.Series(agents._agentsets[1].df["wealth"]).is_sorted(descending=True) def test_step( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, - fix_AgentsDF: AgentsDF, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, + fix_AgentSetRegistry: AgentSetRegistry, ): - previous_wealth_0 = fix1_AgentSetPolars._df["wealth"].clone() - previous_wealth_1 = fix2_AgentSetPolars._df["wealth"].clone() + previous_wealth_0 = fix1_AgentSet._df["wealth"].clone() + previous_wealth_1 = fix2_AgentSet._df["wealth"].clone() - agents = fix_AgentsDF + agents = fix_AgentSetRegistry agents.step() assert ( @@ -563,16 +559,16 @@ def test_step( def test__check_ids_presence( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - agents = fix_AgentsDF.remove(fix2_AgentSetPolars, inplace=False) - agents_different_index = deepcopy(fix2_AgentSetPolars) - result = agents._check_ids_presence([fix1_AgentSetPolars]) - assert result.filter( - pl.col("unique_id").is_in(fix1_AgentSetPolars._df["unique_id"]) - )["present"].all() + agents = fix_AgentSetRegistry.remove(fix2_AgentSet, inplace=False) + agents_different_index = deepcopy(fix2_AgentSet) + result = agents._check_ids_presence([fix1_AgentSet]) + assert result.filter(pl.col("unique_id").is_in(fix1_AgentSet._df["unique_id"]))[ + "present" + ].all() assert not result.filter( pl.col("unique_id").is_in(agents_different_index._df["unique_id"]) @@ -580,19 +576,17 @@ def test__check_ids_presence( def test__check_agentsets_presence( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix3_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix3_AgentSet: ExampleAgentSet, ): - agents = fix_AgentsDF - result = agents._check_agentsets_presence( - [fix1_AgentSetPolars, fix3_AgentSetPolars] - ) + agents = fix_AgentSetRegistry + result = agents._check_agentsets_presence([fix1_AgentSet, fix3_AgentSet]) assert result[0] assert not result[1] - def test__get_bool_masks(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test__get_bool_masks(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry # Test with mask = None result = agents._get_bool_masks(mask=None) truth_value = True @@ -637,51 +631,49 @@ def test__get_bool_masks(self, fix_AgentsDF: AgentsDF): len(agents._agentsets[1]) - 1 ) - # Test with mask = dict[AgentSetDF, AgentMask] + # Test with mask = dict[AgentSet, AgentMask] result = agents._get_bool_masks(mask=mask_dictionary) assert result[agents._agentsets[0]].to_list() == mask0.to_list() assert result[agents._agentsets[1]].to_list() == mask1.to_list() - def test__get_obj(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test__get_obj(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry assert agents._get_obj(inplace=True) is agents assert agents._get_obj(inplace=False) is not agents def test__return_agentsets_list( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - agents = fix_AgentsDF - result = agents._return_agentsets_list(fix1_AgentSetPolars) - assert result == [fix1_AgentSetPolars] - result = agents._return_agentsets_list( - [fix1_AgentSetPolars, fix2_AgentSetPolars] - ) - assert result == [fix1_AgentSetPolars, fix2_AgentSetPolars] + agents = fix_AgentSetRegistry + result = agents._return_agentsets_list(fix1_AgentSet) + assert result == [fix1_AgentSet] + result = agents._return_agentsets_list([fix1_AgentSet, fix2_AgentSet]) + assert result == [fix1_AgentSet, fix2_AgentSet] def test___add__( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - model = ModelDF() - agents = AgentsDF(model) - agentset_polars1 = fix1_AgentSetPolars - agentset_polars2 = fix2_AgentSetPolars + model = Model() + agents = AgentSetRegistry(model) + agentset_polars1 = fix1_AgentSet + agentset_polars2 = fix2_AgentSet - # Test with a single AgentSetPolars + # Test with a single AgentSet result = agents + agentset_polars1 assert result._agentsets[0] is agentset_polars1 assert result._ids.to_list() == agentset_polars1._df["unique_id"].to_list() - # Test with a single AgentSetPolars same as above + # Test with a single AgentSet same as above result = agents + agentset_polars2 assert result._agentsets[0] is agentset_polars2 assert result._ids.to_list() == agentset_polars2._df["unique_id"].to_list() - # Test with a list of AgentSetDFs + # Test with a list of AgentSets result = agents + [agentset_polars1, agentset_polars2] assert result._agentsets[0] is agentset_polars1 assert result._agentsets[1] is agentset_polars2 @@ -691,21 +683,21 @@ def test___add__( + agentset_polars2._df["unique_id"].to_list() ) - # Test if adding the same AgentSetDF raises ValueError + # Test if adding the same AgentSet raises ValueError with pytest.raises(ValueError): result + agentset_polars1 def test___contains__( - self, fix_AgentsDF: AgentsDF, fix3_AgentSetPolars: ExampleAgentSetPolars + self, fix_AgentSetRegistry: AgentSetRegistry, fix3_AgentSet: ExampleAgentSet ): # Test with a single value - agents = fix_AgentsDF + agents = fix_AgentSetRegistry agentset_polars1 = agents._agentsets[0] - # Test with an AgentSetDF + # Test with an AgentSet assert agentset_polars1 in agents - # Test with an AgentSetDF not present - assert fix3_AgentSetPolars not in agents + # Test with an AgentSet not present + assert fix3_AgentSet not in agents # Test with single id present assert agentset_polars1["unique_id"][0] in agents @@ -713,8 +705,8 @@ def test___contains__( # Test with single id not present assert 0 not in agents - def test___copy__(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test___copy__(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry agents.test_list = [[1, 2, 3]] # Test with deep=False @@ -725,8 +717,8 @@ def test___copy__(self, fix_AgentsDF: AgentsDF): assert agents._agentsets[0] == agents2._agentsets[0] assert (agents._ids == agents2._ids).all() - def test___deepcopy__(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test___deepcopy__(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry agents.test_list = [[1, 2, 3]] agents2 = deepcopy(agents) @@ -736,9 +728,9 @@ def test___deepcopy__(self, fix_AgentsDF: AgentsDF): assert agents._agentsets[0] != agents2._agentsets[0] assert (agents._ids == agents2._ids).all() - def test___getattr__(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF - assert isinstance(agents.model, ModelDF) + def test___getattr__(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry + assert isinstance(agents.model, Model) result = agents.wealth assert ( result[agents._agentsets[0]].to_list() @@ -751,77 +743,73 @@ def test___getattr__(self, fix_AgentsDF: AgentsDF): def test___getitem__( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - agents = fix_AgentsDF + agents = fix_AgentSetRegistry # Test with a single attribute assert ( - agents["wealth"][fix1_AgentSetPolars].to_list() - == fix1_AgentSetPolars._df["wealth"].to_list() + agents["wealth"][fix1_AgentSet].to_list() + == fix1_AgentSet._df["wealth"].to_list() ) assert ( - agents["wealth"][fix2_AgentSetPolars].to_list() - == fix2_AgentSetPolars._df["wealth"].to_list() + agents["wealth"][fix2_AgentSet].to_list() + == fix2_AgentSet._df["wealth"].to_list() ) # Test with a list of attributes result = agents[["wealth", "age"]] - assert result[fix1_AgentSetPolars].columns == ["wealth", "age"] + assert result[fix1_AgentSet].columns == ["wealth", "age"] assert ( - result[fix1_AgentSetPolars]["wealth"].to_list() - == fix1_AgentSetPolars._df["wealth"].to_list() + result[fix1_AgentSet]["wealth"].to_list() + == fix1_AgentSet._df["wealth"].to_list() ) assert ( - result[fix1_AgentSetPolars]["age"].to_list() - == fix1_AgentSetPolars._df["age"].to_list() + result[fix1_AgentSet]["age"].to_list() == fix1_AgentSet._df["age"].to_list() ) - assert result[fix2_AgentSetPolars].columns == ["wealth", "age"] + assert result[fix2_AgentSet].columns == ["wealth", "age"] assert ( - result[fix2_AgentSetPolars]["wealth"].to_list() - == fix2_AgentSetPolars._df["wealth"].to_list() + result[fix2_AgentSet]["wealth"].to_list() + == fix2_AgentSet._df["wealth"].to_list() ) assert ( - result[fix2_AgentSetPolars]["age"].to_list() - == fix2_AgentSetPolars._df["age"].to_list() + result[fix2_AgentSet]["age"].to_list() == fix2_AgentSet._df["age"].to_list() ) # Test with a single attribute and a mask - mask0 = fix1_AgentSetPolars._df["wealth"] > fix1_AgentSetPolars._df["wealth"][0] - mask1 = fix2_AgentSetPolars._df["wealth"] > fix2_AgentSetPolars._df["wealth"][0] - mask_dictionary: dict[AgentSetDF, AgentMask] = { - fix1_AgentSetPolars: mask0, - fix2_AgentSetPolars: mask1, + mask0 = fix1_AgentSet._df["wealth"] > fix1_AgentSet._df["wealth"][0] + mask1 = fix2_AgentSet._df["wealth"] > fix2_AgentSet._df["wealth"][0] + mask_dictionary: dict[AgentSet, AgentMask] = { + fix1_AgentSet: mask0, + fix2_AgentSet: mask1, } result = agents[mask_dictionary, "wealth"] assert ( - result[fix1_AgentSetPolars].to_list() - == fix1_AgentSetPolars.df["wealth"].to_list()[1:] + result[fix1_AgentSet].to_list() == fix1_AgentSet.df["wealth"].to_list()[1:] ) assert ( - result[fix2_AgentSetPolars].to_list() - == fix2_AgentSetPolars.df["wealth"].to_list()[1:] + result[fix2_AgentSet].to_list() == fix2_AgentSet.df["wealth"].to_list()[1:] ) def test___iadd__( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - model = ModelDF() - agents = AgentsDF(model) - agentset_polars1 = fix1_AgentSetPolars - agentset_polars = fix2_AgentSetPolars + model = Model() + agents = AgentSetRegistry(model) + agentset_polars1 = fix1_AgentSet + agentset_polars = fix2_AgentSet - # Test with a single AgentSetPolars + # Test with a single AgentSet agents_copy = deepcopy(agents) agents_copy += agentset_polars assert agents_copy._agentsets[0] is agentset_polars assert agents_copy._ids.to_list() == agentset_polars._df["unique_id"].to_list() - # Test with a list of AgentSetDFs + # Test with a list of AgentSets agents_copy = deepcopy(agents) agents_copy += [agentset_polars1, agentset_polars] assert agents_copy._agentsets[0] is agentset_polars1 @@ -832,12 +820,12 @@ def test___iadd__( + agentset_polars._df["unique_id"].to_list() ) - # Test if adding the same AgentSetDF raises ValueError + # Test if adding the same AgentSet raises ValueError with pytest.raises(ValueError): agents_copy += agentset_polars1 - def test___iter__(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test___iter__(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry len_agentset0 = len(agents._agentsets[0]) len_agentset1 = len(agents._agentsets[1]) for i, agent in enumerate(agents): @@ -853,36 +841,36 @@ def test___iter__(self, fix_AgentsDF: AgentsDF): def test___isub__( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - # Test with an AgentSetPolars and a DataFrame - agents = fix_AgentsDF - agents -= fix1_AgentSetPolars - assert agents._agentsets[0] == fix2_AgentSetPolars + # Test with an AgentSet and a DataFrame + agents = fix_AgentSetRegistry + agents -= fix1_AgentSet + assert agents._agentsets[0] == fix2_AgentSet assert len(agents._agentsets) == 1 def test___len__( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - assert len(fix_AgentsDF) == len(fix1_AgentSetPolars) + len(fix2_AgentSetPolars) + assert len(fix_AgentSetRegistry) == len(fix1_AgentSet) + len(fix2_AgentSet) - def test___repr__(self, fix_AgentsDF: AgentsDF): - repr(fix_AgentsDF) + def test___repr__(self, fix_AgentSetRegistry: AgentSetRegistry): + repr(fix_AgentSetRegistry) - def test___reversed__(self, fix2_AgentSetPolars: AgentsDF): - agents = fix2_AgentSetPolars + def test___reversed__(self, fix2_AgentSet: AgentSetRegistry): + agents = fix2_AgentSet reversed_wealth = [] for agent in reversed(list(agents)): reversed_wealth.append(agent["wealth"]) assert reversed_wealth == list(reversed(agents["wealth"])) - def test___setitem__(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test___setitem__(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry # Test with a single attribute agents["wealth"] = 0 @@ -918,38 +906,38 @@ def test___setitem__(self, fix_AgentsDF: AgentsDF): len(agents._agentsets[1]) - 1 ) - def test___str__(self, fix_AgentsDF: AgentsDF): - str(fix_AgentsDF) + def test___str__(self, fix_AgentSetRegistry: AgentSetRegistry): + str(fix_AgentSetRegistry) def test___sub__( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - # Test with an AgentSetPolars and a DataFrame - result = fix_AgentsDF - fix1_AgentSetPolars - assert isinstance(result._agentsets[0], ExampleAgentSetPolars) + # Test with an AgentSet and a DataFrame + result = fix_AgentSetRegistry - fix1_AgentSet + assert isinstance(result._agentsets[0], ExampleAgentSet) assert len(result._agentsets) == 1 def test_agents( self, - fix_AgentsDF: AgentsDF, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix_AgentSetRegistry: AgentSetRegistry, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - assert isinstance(fix_AgentsDF.df, dict) - assert len(fix_AgentsDF.df) == 2 - assert fix_AgentsDF.df[fix1_AgentSetPolars] is fix1_AgentSetPolars._df - assert fix_AgentsDF.df[fix2_AgentSetPolars] is fix2_AgentSetPolars._df + assert isinstance(fix_AgentSetRegistry.df, dict) + assert len(fix_AgentSetRegistry.df) == 2 + assert fix_AgentSetRegistry.df[fix1_AgentSet] is fix1_AgentSet._df + assert fix_AgentSetRegistry.df[fix2_AgentSet] is fix2_AgentSet._df # Test agents.setter - fix_AgentsDF.df = [fix1_AgentSetPolars, fix2_AgentSetPolars] - assert fix_AgentsDF._agentsets[0] == fix1_AgentSetPolars - assert fix_AgentsDF._agentsets[1] == fix2_AgentSetPolars + fix_AgentSetRegistry.df = [fix1_AgentSet, fix2_AgentSet] + assert fix_AgentSetRegistry._agentsets[0] == fix1_AgentSet + assert fix_AgentSetRegistry._agentsets[1] == fix2_AgentSet - def test_active_agents(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_active_agents(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry # Test with select mask0 = ( @@ -1002,20 +990,20 @@ def test_active_agents(self, fix_AgentsDF: AgentsDF): ) ) - def test_agentsets_by_type(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_agentsets_by_type(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry result = agents.agentsets_by_type assert isinstance(result, dict) - assert isinstance(result[ExampleAgentSetPolars], AgentsDF) + assert isinstance(result[ExampleAgentSet], AgentSetRegistry) assert ( - result[ExampleAgentSetPolars]._agentsets[0].df.rows() + result[ExampleAgentSet]._agentsets[0].df.rows() == agents._agentsets[1].df.rows() ) - def test_inactive_agents(self, fix_AgentsDF: AgentsDF): - agents = fix_AgentsDF + def test_inactive_agents(self, fix_AgentSetRegistry: AgentSetRegistry): + agents = fix_AgentSetRegistry # Test with select mask0 = ( diff --git a/tests/test_agentset.py b/tests/test_agentset.py index 0c849abe..d475a4fc 100644 --- a/tests/test_agentset.py +++ b/tests/test_agentset.py @@ -4,11 +4,11 @@ import pytest from numpy.random import Generator -from mesa_frames import AgentSetPolars, GridPolars, ModelDF +from mesa_frames import AgentSet, Grid, Model -class ExampleAgentSetPolars(AgentSetPolars): - def __init__(self, model: ModelDF): +class ExampleAgentSet(AgentSet): + def __init__(self, model: Model): super().__init__(model) self.starting_wealth = pl.Series("wealth", [1, 2, 3, 4]) @@ -19,8 +19,8 @@ def step(self) -> None: self.add_wealth(1) -class ExampleAgentSetPolarsNoWealth(AgentSetPolars): - def __init__(self, model: ModelDF): +class ExampleAgentSetNoWealth(AgentSet): + def __init__(self, model: Model): super().__init__(model) self.starting_income = pl.Series("income", [1000, 2000, 3000, 4000]) @@ -32,64 +32,62 @@ def step(self) -> None: @pytest.fixture -def fix1_AgentSetPolars() -> ExampleAgentSetPolars: - model = ModelDF() - agents = ExampleAgentSetPolars(model) +def fix1_AgentSet() -> ExampleAgentSet: + model = Model() + agents = ExampleAgentSet(model) agents["wealth"] = agents.starting_wealth agents["age"] = [10, 20, 30, 40] - model.agents.add(agents) + model.sets.add(agents) return agents @pytest.fixture -def fix2_AgentSetPolars() -> ExampleAgentSetPolars: - model = ModelDF() - agents = ExampleAgentSetPolars(model) +def fix2_AgentSet() -> ExampleAgentSet: + model = Model() + agents = ExampleAgentSet(model) agents["wealth"] = agents.starting_wealth + 10 agents["age"] = [100, 200, 300, 400] - model.agents.add(agents) - space = GridPolars(model, dimensions=[3, 3], capacity=2) + model.sets.add(agents) + space = Grid(model, dimensions=[3, 3], capacity=2) model.space = space space.place_agents(agents=agents["unique_id"][[0, 1]], pos=[[2, 1], [1, 2]]) return agents @pytest.fixture -def fix3_AgentSetPolars() -> ExampleAgentSetPolars: - model = ModelDF() - agents = ExampleAgentSetPolars(model) +def fix3_AgentSet() -> ExampleAgentSet: + model = Model() + agents = ExampleAgentSet(model) agents["wealth"] = agents.starting_wealth + 7 agents["age"] = [12, 13, 14, 116] return agents @pytest.fixture -def fix1_AgentSetPolars_with_pos( - fix1_AgentSetPolars: ExampleAgentSetPolars, -) -> ExampleAgentSetPolars: - space = GridPolars(fix1_AgentSetPolars.model, dimensions=[3, 3], capacity=2) - fix1_AgentSetPolars.model.space = space - space.place_agents( - agents=fix1_AgentSetPolars["unique_id"][[0, 1]], pos=[[0, 0], [1, 1]] - ) - return fix1_AgentSetPolars +def fix1_AgentSet_with_pos( + fix1_AgentSet: ExampleAgentSet, +) -> ExampleAgentSet: + space = Grid(fix1_AgentSet.model, dimensions=[3, 3], capacity=2) + fix1_AgentSet.model.space = space + space.place_agents(agents=fix1_AgentSet["unique_id"][[0, 1]], pos=[[0, 0], [1, 1]]) + return fix1_AgentSet @pytest.fixture -def fix1_AgentSetPolars_no_wealth() -> ExampleAgentSetPolarsNoWealth: - model = ModelDF() - agents = ExampleAgentSetPolarsNoWealth(model) +def fix1_AgentSet_no_wealth() -> ExampleAgentSetNoWealth: + model = Model() + agents = ExampleAgentSetNoWealth(model) agents["income"] = agents.starting_income agents["age"] = [1, 2, 3, 4] - model.agents.add(agents) + model.sets.add(agents) return agents -class Test_AgentSetPolars: +class Test_AgentSet: def test__init__(self): - model = ModelDF() - agents = ExampleAgentSetPolars(model) + model = Model() + agents = ExampleAgentSet(model) agents.add({"age": [0, 1, 2, 3]}) assert agents.model == model assert isinstance(agents.df, pl.DataFrame) @@ -100,9 +98,9 @@ def test__init__(self): def test_add( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, + fix1_AgentSet: ExampleAgentSet, ): - agents = fix1_AgentSetPolars + agents = fix1_AgentSet # Test with a pl.Dataframe result = agents.add( @@ -140,14 +138,14 @@ def test_add( agents.add([10, 20, 30]) # Three values but agents has 2 columns # Test adding sequence to empty AgentSet - should raise ValueError - empty_agents = ExampleAgentSetPolars(agents.model) + empty_agents = ExampleAgentSet(agents.model) with pytest.raises( ValueError, match="Cannot add a sequence to an empty AgentSet" ): empty_agents.add([1, 2]) # Should raise error for empty AgentSet - def test_contains(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_contains(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet # Test with a single value assert agents.contains(agents["unique_id"][0]) @@ -161,8 +159,8 @@ def test_contains(self, fix1_AgentSetPolars: ExampleAgentSetPolars): result = agents.contains(unique_ids[:2]) assert all(result == [True, True]) - def test_copy(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_copy(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet agents.test_list = [[1, 2, 3]] # Test with deep=False @@ -171,12 +169,12 @@ def test_copy(self, fix1_AgentSetPolars: ExampleAgentSetPolars): assert agents.test_list[0][-1] == agents2.test_list[0][-1] # Test with deep=True - agents2 = fix1_AgentSetPolars.copy(deep=True) + agents2 = fix1_AgentSet.copy(deep=True) agents2.test_list[0].append(4) assert agents.test_list[-1] != agents2.test_list[-1] - def test_discard(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars_with_pos + def test_discard(self, fix1_AgentSet_with_pos: ExampleAgentSet): + agents = fix1_AgentSet_with_pos # Test with a single value result = agents.discard(agents["unique_id"][0], inplace=False) @@ -214,8 +212,8 @@ def test_discard(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): result = agents.discard([], inplace=False) assert all(result.df["unique_id"] == agents["unique_id"]) - def test_do(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_do(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet # Test with no return_results, no mask agents.do("add_wealth", 1) @@ -229,8 +227,8 @@ def test_do(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents.do("add_wealth", 1, mask=agents["wealth"] > 3) assert agents.df["wealth"].to_list() == [3, 5, 6, 7] - def test_get(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_get(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet # Test with a single attribute assert agents.get("wealth").to_list() == [1, 2, 3, 4] @@ -245,16 +243,16 @@ def test_get(self, fix1_AgentSetPolars: ExampleAgentSetPolars): selected = agents.select(agents.df["wealth"] > 1, inplace=False) assert selected.get("wealth", mask="active").to_list() == [2, 3, 4] - def test_remove(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_remove(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet remaining_agents_id = agents["unique_id"][2, 3] agents.remove(agents["unique_id"][0, 1]) assert all(agents.df["unique_id"] == remaining_agents_id) with pytest.raises(KeyError): agents.remove([0]) - def test_select(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_select(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet # Test with default arguments. Should select all agents selected = agents.select(inplace=False) @@ -278,7 +276,7 @@ def test_select(self, fix1_AgentSetPolars: ExampleAgentSetPolars): assert all(selected.active_agents["unique_id"] == agents["unique_id"][0, 1]) # Test with filter_func - def filter_func(agentset: AgentSetPolars) -> pl.Series: + def filter_func(agentset: AgentSet) -> pl.Series: return agentset.df["wealth"] > 1 selected = agents.select(filter_func=filter_func, inplace=False) @@ -296,8 +294,8 @@ def filter_func(agentset: AgentSetPolars) -> pl.Series: for id in agents["unique_id"][2, 3] ) - def test_set(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_set(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet # Test with a single attribute result = agents.set("wealth", 0, inplace=False) @@ -322,8 +320,8 @@ def test_set(self, fix1_AgentSetPolars: ExampleAgentSetPolars): result = agents.set("wealth", [100, 200, 300, 400], inplace=False) assert result.df["wealth"].to_list() == [100, 200, 300, 400] - def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_shuffle(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet for _ in range(10): original_order = agents.df["unique_id"].to_list() agents.shuffle() @@ -331,41 +329,41 @@ def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSetPolars): return assert False - def test_sort(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_sort(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet agents.sort("wealth", ascending=False) assert agents.df["wealth"].to_list() == [4, 3, 2, 1] def test__add__( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, + fix1_AgentSet: ExampleAgentSet, ): - agents = fix1_AgentSetPolars + agents = fix1_AgentSet - # Test with an AgentSetPolars and a DataFrame + # Test with an AgentSet and a DataFrame agents3 = agents + pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}) assert agents3.df["wealth"].to_list() == [1, 2, 3, 4, 5, 6] assert agents3.df["age"].to_list() == [10, 20, 30, 40, 50, 60] - # Test with an AgentSetPolars and a list (Sequence[Any]) + # Test with an AgentSet and a list (Sequence[Any]) agents3 = agents + [5, 5] # unique_id, wealth, age assert all(agents3.df["unique_id"].to_list()[:-1] == agents["unique_id"]) assert len(agents3.df) == 5 assert agents3.df["wealth"].to_list() == [1, 2, 3, 4, 5] assert agents3.df["age"].to_list() == [10, 20, 30, 40, 5] - # Test with an AgentSetPolars and a dict + # Test with an AgentSet and a dict agents3 = agents + {"age": 10, "wealth": 5} assert agents3.df["wealth"].to_list() == [1, 2, 3, 4, 5] - def test__contains__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): + def test__contains__(self, fix1_AgentSet: ExampleAgentSet): # Test with a single value - agents = fix1_AgentSetPolars + agents = fix1_AgentSet assert agents["unique_id"][0] in agents assert 0 not in agents - def test__copy__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test__copy__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet agents.test_list = [[1, 2, 3]] # Test with deep=False @@ -373,21 +371,21 @@ def test__copy__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents2.test_list[0].append(4) assert agents.test_list[0][-1] == agents2.test_list[0][-1] - def test__deepcopy__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test__deepcopy__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet agents.test_list = [[1, 2, 3]] agents2 = deepcopy(agents) agents2.test_list[0].append(4) assert agents.test_list[-1] != agents2.test_list[-1] - def test__getattr__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars - assert isinstance(agents.model, ModelDF) + def test__getattr__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet + assert isinstance(agents.model, Model) assert agents.wealth.to_list() == [1, 2, 3, 4] - def test__getitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test__getitem__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet # Testing with a string assert agents["wealth"].to_list() == [1, 2, 3, 4] @@ -405,59 +403,58 @@ def test__getitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): def test__iadd__( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, + fix1_AgentSet: ExampleAgentSet, ): - # Test with an AgentSetPolars and a DataFrame - agents = deepcopy(fix1_AgentSetPolars) + # Test with an AgentSet and a DataFrame + agents = deepcopy(fix1_AgentSet) agents += pl.DataFrame({"wealth": [5, 6], "age": [50, 60]}) assert agents.df["wealth"].to_list() == [1, 2, 3, 4, 5, 6] assert agents.df["age"].to_list() == [10, 20, 30, 40, 50, 60] - # Test with an AgentSetPolars and a list - agents = deepcopy(fix1_AgentSetPolars) + # Test with an AgentSet and a list + agents = deepcopy(fix1_AgentSet) agents += [5, 5] # unique_id, wealth, age assert all( - agents["unique_id"].to_list()[:-1] - == fix1_AgentSetPolars["unique_id"][0, 1, 2, 3] + agents["unique_id"].to_list()[:-1] == fix1_AgentSet["unique_id"][0, 1, 2, 3] ) assert len(agents.df) == 5 assert agents.df["wealth"].to_list() == [1, 2, 3, 4, 5] assert agents.df["age"].to_list() == [10, 20, 30, 40, 5] - # Test with an AgentSetPolars and a dict - agents = deepcopy(fix1_AgentSetPolars) + # Test with an AgentSet and a dict + agents = deepcopy(fix1_AgentSet) agents += {"age": 10, "wealth": 5} assert agents.df["wealth"].to_list() == [1, 2, 3, 4, 5] - def test__iter__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test__iter__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet for i, agent in enumerate(agents): assert isinstance(agent, dict) assert agent["wealth"] == i + 1 - def test__isub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - # Test with an AgentSetPolars and a DataFrame - agents = deepcopy(fix1_AgentSetPolars) + def test__isub__(self, fix1_AgentSet: ExampleAgentSet): + # Test with an AgentSet and a DataFrame + agents = deepcopy(fix1_AgentSet) agents -= agents.df assert agents.df.is_empty() - def test__len__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test__len__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet assert len(agents) == 4 - def test__repr__(self, fix1_AgentSetPolars): - agents: ExampleAgentSetPolars = fix1_AgentSetPolars + def test__repr__(self, fix1_AgentSet): + agents: ExampleAgentSet = fix1_AgentSet repr(agents) - def test__reversed__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test__reversed__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet reversed_wealth = [] for i, agent in reversed(list(enumerate(agents))): reversed_wealth.append(agent["wealth"]) assert reversed_wealth == [4, 3, 2, 1] - def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test__setitem__(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet agents = deepcopy(agents) # To test passing through a df later @@ -479,36 +476,36 @@ def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): assert agents.df.item(0, "wealth") == 9 assert agents.df.item(0, "age") == 99 - def test__str__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents: ExampleAgentSetPolars = fix1_AgentSetPolars + def test__str__(self, fix1_AgentSet: ExampleAgentSet): + agents: ExampleAgentSet = fix1_AgentSet str(agents) - def test__sub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents: ExampleAgentSetPolars = fix1_AgentSetPolars - agents2: ExampleAgentSetPolars = agents - agents.df + def test__sub__(self, fix1_AgentSet: ExampleAgentSet): + agents: ExampleAgentSet = fix1_AgentSet + agents2: ExampleAgentSet = agents - agents.df assert agents2.df.is_empty() assert agents.df["wealth"].to_list() == [1, 2, 3, 4] - def test_get_obj(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_get_obj(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet assert agents._get_obj(inplace=True) is agents assert agents._get_obj(inplace=False) is not agents def test_agents( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): - agents = fix1_AgentSetPolars - agents2 = fix2_AgentSetPolars + agents = fix1_AgentSet + agents2 = fix2_AgentSet assert isinstance(agents.df, pl.DataFrame) # Test agents.setter agents.df = agents2.df assert all(agents["unique_id"] == agents2["unique_id"][0, 1, 2, 3]) - def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_active_agents(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet # Test with select agents.select(agents.df["wealth"] > 2, inplace=True) @@ -518,18 +515,16 @@ def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents.active_agents = agents.df["wealth"] > 2 assert all(agents.active_agents["unique_id"] == agents["unique_id"][2, 3]) - def test_inactive_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): - agents = fix1_AgentSetPolars + def test_inactive_agents(self, fix1_AgentSet: ExampleAgentSet): + agents = fix1_AgentSet agents.select(agents.df["wealth"] > 2, inplace=True) assert all(agents.inactive_agents["unique_id"] == agents["unique_id"][0, 1]) - def test_pos(self, fix1_AgentSetPolars_with_pos: ExampleAgentSetPolars): - pos = fix1_AgentSetPolars_with_pos.pos + def test_pos(self, fix1_AgentSet_with_pos: ExampleAgentSet): + pos = fix1_AgentSet_with_pos.pos assert isinstance(pos, pl.DataFrame) - assert all( - pos["unique_id"] == fix1_AgentSetPolars_with_pos["unique_id"][0, 1, 2, 3] - ) + assert all(pos["unique_id"] == fix1_AgentSet_with_pos["unique_id"][0, 1, 2, 3]) assert pos.columns == ["unique_id", "dim_0", "dim_1"] assert pos["dim_0"].to_list() == [0, 1, None, None] assert pos["dim_1"].to_list() == [0, 1, None, None] diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index beb96632..b7407711 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -1,5 +1,5 @@ from mesa_frames.concrete.datacollector import DataCollector -from mesa_frames import ModelDF, AgentSetPolars, AgentsDF +from mesa_frames import Model, AgentSet, AgentSetRegistry import pytest import polars as pl import beartype @@ -12,8 +12,8 @@ def custom_trigger(model): return model._steps % 2 == 0 -class ExampleAgentSet1(AgentSetPolars): - def __init__(self, model: ModelDF): +class ExampleAgentSet1(AgentSet): + def __init__(self, model: Model): super().__init__(model) self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) self["age"] = pl.Series("age", [10, 20, 30, 40]) @@ -25,8 +25,8 @@ def step(self) -> None: self.add_wealth(1) -class ExampleAgentSet2(AgentSetPolars): - def __init__(self, model: ModelDF): +class ExampleAgentSet2(AgentSet): + def __init__(self, model: Model): super().__init__(model) self["wealth"] = pl.Series("wealth", [10, 20, 30, 40]) self["age"] = pl.Series("age", [11, 22, 33, 44]) @@ -38,8 +38,8 @@ def step(self) -> None: self.add_wealth(2) -class ExampleAgentSet3(AgentSetPolars): - def __init__(self, model: ModelDF): +class ExampleAgentSet3(AgentSet): + def __init__(self, model: Model): super().__init__(model) self["age"] = pl.Series("age", [1, 2, 3, 4]) self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) @@ -51,13 +51,13 @@ def step(self) -> None: self.age_agents(1) -class ExampleModel(ModelDF): - def __init__(self, agents: AgentsDF): +class ExampleModel(Model): + def __init__(self, sets: AgentSetRegistry): super().__init__() - self.agents = agents + self.sets = sets def step(self): - self.agents.do("step") + self.sets.do("step") def run_model(self, n): for _ in range(n): @@ -74,14 +74,14 @@ def run_model_with_conditional_collect(self, n): self.dc.conditional_collect() -class ExampleModelWithMultipleCollects(ModelDF): - def __init__(self, agents: AgentsDF): +class ExampleModelWithMultipleCollects(Model): + def __init__(self, agents: AgentSetRegistry): super().__init__() - self.agents = agents + self.sets = agents def step(self): self.dc.conditional_collect() - self.agents.do("step") + self.sets.do("step") self.dc.conditional_collect() def run_model_with_conditional_collect_multiple_batch(self, n): @@ -95,40 +95,40 @@ def postgres_uri(): @pytest.fixture -def fix1_AgentSetPolars() -> ExampleAgentSet1: - return ExampleAgentSet1(ModelDF()) +def fix1_AgentSet() -> ExampleAgentSet1: + return ExampleAgentSet1(Model()) @pytest.fixture -def fix2_AgentSetPolars() -> ExampleAgentSet2: - return ExampleAgentSet2(ModelDF()) +def fix2_AgentSet() -> ExampleAgentSet2: + return ExampleAgentSet2(Model()) @pytest.fixture -def fix3_AgentSetPolars() -> ExampleAgentSet3: - return ExampleAgentSet3(ModelDF()) +def fix3_AgentSet() -> ExampleAgentSet3: + return ExampleAgentSet3(Model()) @pytest.fixture -def fix_AgentsDF( - fix1_AgentSetPolars: ExampleAgentSet1, - fix2_AgentSetPolars: ExampleAgentSet2, - fix3_AgentSetPolars: ExampleAgentSet3, -) -> AgentsDF: - model = ModelDF() - agents = AgentsDF(model) - agents.add([fix1_AgentSetPolars, fix2_AgentSetPolars, fix3_AgentSetPolars]) +def fix_AgentSetRegistry( + fix1_AgentSet: ExampleAgentSet1, + fix2_AgentSet: ExampleAgentSet2, + fix3_AgentSet: ExampleAgentSet3, +) -> AgentSetRegistry: + model = Model() + agents = AgentSetRegistry(model) + agents.add([fix1_AgentSet, fix2_AgentSet, fix3_AgentSet]) return agents @pytest.fixture -def fix1_model(fix_AgentsDF: AgentsDF) -> ExampleModel: - return ExampleModel(fix_AgentsDF) +def fix1_model(fix_AgentSetRegistry: AgentSetRegistry) -> ExampleModel: + return ExampleModel(fix_AgentSetRegistry) @pytest.fixture -def fix2_model(fix_AgentsDF: AgentsDF) -> ExampleModel: - return ExampleModelWithMultipleCollects(fix_AgentsDF) +def fix2_model(fix_AgentSetRegistry: AgentSetRegistry) -> ExampleModel: + return ExampleModelWithMultipleCollects(fix_AgentSetRegistry) class TestDataCollector: @@ -160,11 +160,11 @@ def test_collect(self, fix1_model): model=model, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], "age": "age", }, ) @@ -219,11 +219,11 @@ def test_collect_step(self, fix1_model): model=model, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], "age": "age", }, ) @@ -275,11 +275,11 @@ def test_conditional_collect(self, fix1_model): trigger=custom_trigger, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], "age": "age", }, ) @@ -357,11 +357,11 @@ def test_flush_local_csv(self, fix1_model): trigger=custom_trigger, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], "age": "age", }, storage="csv", @@ -433,11 +433,11 @@ def test_flush_local_parquet(self, fix1_model): trigger=custom_trigger, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], }, storage="parquet", storage_uri=tmpdir, @@ -509,11 +509,11 @@ def test_postgress(self, fix1_model, postgres_uri): trigger=custom_trigger, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], "age": "age", }, storage="postgresql", @@ -558,11 +558,11 @@ def test_batch_memory(self, fix2_model): trigger=custom_trigger, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], "age": "age", }, ) @@ -703,11 +703,11 @@ def test_batch_save(self, fix2_model): trigger=custom_trigger, model_reporters={ "total_agents": lambda model: sum( - len(agentset) for agentset in model.agents._agentsets + len(agentset) for agentset in model.sets._agentsets ) }, agent_reporters={ - "wealth": lambda model: model.agents._agentsets[0]["wealth"], + "wealth": lambda model: model.sets._agentsets[0]["wealth"], "age": "age", }, storage="csv", diff --git a/tests/test_grid.py b/tests/test_grid.py index 5d8cafa6..6d75f3cc 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -3,36 +3,36 @@ import pytest from polars.testing import assert_frame_equal -from mesa_frames import GridPolars, ModelDF +from mesa_frames import Grid, Model from tests.test_agentset import ( - ExampleAgentSetPolars, - fix1_AgentSetPolars, - fix2_AgentSetPolars, + ExampleAgentSet, + fix1_AgentSet, + fix2_AgentSet, ) -def get_unique_ids(model: ModelDF) -> pl.Series: - # return model.get_agents_of_type(model.agent_types[0])["unique_id"] +def get_unique_ids(model: Model) -> pl.Series: + # return model.get_sets_of_type(model.set_types[0])["unique_id"] series_list = [ - agent_set["unique_id"].cast(pl.UInt64) for agent_set in model.agents.df.values() + agent_set["unique_id"].cast(pl.UInt64) for agent_set in model.sets.df.values() ] return pl.concat(series_list) -class TestGridPolars: +class TestGrid: @pytest.fixture def model( self, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, - ) -> ModelDF: - model = ModelDF() - model.agents.add([fix1_AgentSetPolars, fix2_AgentSetPolars]) + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, + ) -> Model: + model = Model() + model.sets.add([fix1_AgentSet, fix2_AgentSet]) return model @pytest.fixture - def grid_moore(self, model: ModelDF) -> GridPolars: - space = GridPolars(model, dimensions=[3, 3], capacity=2) + def grid_moore(self, model: Model) -> Grid: + space = Grid(model, dimensions=[3, 3], capacity=2) unique_ids = get_unique_ids(model) space.place_agents(agents=unique_ids[[0, 1]], pos=[[0, 0], [1, 1]]) space.set_cells( @@ -41,8 +41,8 @@ def grid_moore(self, model: ModelDF) -> GridPolars: return space @pytest.fixture - def grid_moore_torus(self, model: ModelDF) -> GridPolars: - space = GridPolars(model, dimensions=[3, 3], capacity=2, torus=True) + def grid_moore_torus(self, model: Model) -> Grid: + space = Grid(model, dimensions=[3, 3], capacity=2, torus=True) unique_ids = get_unique_ids(model) space.place_agents(agents=unique_ids[[0, 1]], pos=[[0, 0], [1, 1]]) space.set_cells( @@ -51,23 +51,23 @@ def grid_moore_torus(self, model: ModelDF) -> GridPolars: return space @pytest.fixture - def grid_von_neumann(self, model: ModelDF) -> GridPolars: - space = GridPolars(model, dimensions=[3, 3], neighborhood_type="von_neumann") + def grid_von_neumann(self, model: Model) -> Grid: + space = Grid(model, dimensions=[3, 3], neighborhood_type="von_neumann") unique_ids = get_unique_ids(model) space.place_agents(agents=unique_ids[[0, 1]], pos=[[0, 0], [1, 1]]) return space @pytest.fixture - def grid_hexagonal(self, model: ModelDF) -> GridPolars: - space = GridPolars(model, dimensions=[10, 10], neighborhood_type="hexagonal") + def grid_hexagonal(self, model: Model) -> Grid: + space = Grid(model, dimensions=[10, 10], neighborhood_type="hexagonal") unique_ids = get_unique_ids(model) space.place_agents(agents=unique_ids[[0, 1]], pos=[[0, 0], [1, 1]]) return space - def test___init__(self, model: ModelDF): + def test___init__(self, model: Model): # Test with default parameters - grid1 = GridPolars(model, dimensions=[3, 3]) - assert isinstance(grid1, GridPolars) + grid1 = Grid(model, dimensions=[3, 3]) + assert isinstance(grid1, Grid) assert isinstance(grid1.agents, pl.DataFrame) assert grid1.agents.is_empty() assert isinstance(grid1.cells, pl.DataFrame) @@ -80,26 +80,26 @@ def test___init__(self, model: ModelDF): assert grid1.model == model # Test with capacity = 10 - grid2 = GridPolars(model, dimensions=[3, 3], capacity=10) + grid2 = Grid(model, dimensions=[3, 3], capacity=10) assert grid2.remaining_capacity == (10 * 3 * 3) # Test with torus = True - grid3 = GridPolars(model, dimensions=[3, 3], torus=True) + grid3 = Grid(model, dimensions=[3, 3], torus=True) assert grid3.torus # Test with neighborhood_type = "von_neumann" - grid4 = GridPolars(model, dimensions=[3, 3], neighborhood_type="von_neumann") + grid4 = Grid(model, dimensions=[3, 3], neighborhood_type="von_neumann") assert grid4.neighborhood_type == "von_neumann" # Test with neighborhood_type = "moore" - grid5 = GridPolars(model, dimensions=[3, 3], neighborhood_type="moore") + grid5 = Grid(model, dimensions=[3, 3], neighborhood_type="moore") assert grid5.neighborhood_type == "moore" # Test with neighborhood_type = "hexagonal" - grid6 = GridPolars(model, dimensions=[3, 3], neighborhood_type="hexagonal") + grid6 = Grid(model, dimensions=[3, 3], neighborhood_type="hexagonal") assert grid6.neighborhood_type == "hexagonal" - def test_get_cells(self, grid_moore: GridPolars): + def test_get_cells(self, grid_moore: Grid): # Test with None (all cells) result = grid_moore.get_cells() assert isinstance(result, pl.DataFrame) @@ -132,9 +132,9 @@ def test_get_cells(self, grid_moore: GridPolars): def test_get_directions( self, - grid_moore: GridPolars, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + grid_moore: Grid, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): unique_ids = get_unique_ids(grid_moore.model) # Test with GridCoordinate @@ -151,12 +151,10 @@ def test_get_directions( # Test with missing agents (raises ValueError) with pytest.raises(ValueError): - grid_moore.get_directions( - agents0=fix1_AgentSetPolars, agents1=fix2_AgentSetPolars - ) + grid_moore.get_directions(agents0=fix1_AgentSet, agents1=fix2_AgentSet) # Test with IdsLike - grid_moore.place_agents(fix2_AgentSetPolars, [[0, 1], [0, 2], [1, 0], [1, 2]]) + grid_moore.place_agents(fix2_AgentSet, [[0, 1], [0, 2], [1, 0], [1, 2]]) assert_frame_equal( grid_moore.agents, pl.DataFrame( @@ -177,18 +175,16 @@ def test_get_directions( assert dir.select(pl.col("dim_0")).to_series().to_list() == [0, -1] assert dir.select(pl.col("dim_1")).to_series().to_list() == [1, 1] - # Test with two AgentSetDFs + # Test with two AgentSets grid_moore.place_agents(unique_ids[[2, 3]], [[1, 1], [2, 2]]) - dir = grid_moore.get_directions( - agents0=fix1_AgentSetPolars, agents1=fix2_AgentSetPolars - ) + dir = grid_moore.get_directions(agents0=fix1_AgentSet, agents1=fix2_AgentSet) assert isinstance(dir, pl.DataFrame) assert dir.select(pl.col("dim_0")).to_series().to_list() == [0, -1, 0, -1] assert dir.select(pl.col("dim_1")).to_series().to_list() == [1, 1, -1, 0] - # Test with AgentsDF + # Test with AgentSetRegistry dir = grid_moore.get_directions( - agents0=grid_moore.model.agents, agents1=grid_moore.model.agents + agents0=grid_moore.model.sets, agents1=grid_moore.model.sets ) assert isinstance(dir, pl.DataFrame) assert grid_moore._df_all(dir == 0).all() @@ -215,9 +211,9 @@ def test_get_directions( def test_get_distances( self, - grid_moore: GridPolars, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + grid_moore: Grid, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): # Test with GridCoordinate dist = grid_moore.get_distances(pos0=[1, 1], pos1=[2, 2]) @@ -236,12 +232,10 @@ def test_get_distances( # Test with missing agents (raises ValueError) with pytest.raises(ValueError): - grid_moore.get_distances( - agents0=fix1_AgentSetPolars, agents1=fix2_AgentSetPolars - ) + grid_moore.get_distances(agents0=fix1_AgentSet, agents1=fix2_AgentSet) # Test with IdsLike - grid_moore.place_agents(fix2_AgentSetPolars, [[0, 1], [0, 2], [1, 0], [1, 2]]) + grid_moore.place_agents(fix2_AgentSet, [[0, 1], [0, 2], [1, 0], [1, 2]]) unique_ids = get_unique_ids(grid_moore.model) dist = grid_moore.get_distances( agents0=unique_ids[[0, 1]], agents1=unique_ids[[4, 5]] @@ -251,29 +245,27 @@ def test_get_distances( dist.select(pl.col("distance")).to_series().to_list(), [1.0, np.sqrt(2)] ) - # Test with two AgentSetDFs + # Test with two AgentSets grid_moore.place_agents(unique_ids[[2, 3]], [[1, 1], [2, 2]]) - dist = grid_moore.get_distances( - agents0=fix1_AgentSetPolars, agents1=fix2_AgentSetPolars - ) + dist = grid_moore.get_distances(agents0=fix1_AgentSet, agents1=fix2_AgentSet) assert isinstance(dist, pl.DataFrame) assert np.allclose( dist.select(pl.col("distance")).to_series().to_list(), [1.0, np.sqrt(2), 1.0, 1.0], ) - # Test with AgentsDF + # Test with AgentSetRegistry dist = grid_moore.get_distances( - agents0=grid_moore.model.agents, agents1=grid_moore.model.agents + agents0=grid_moore.model.sets, agents1=grid_moore.model.sets ) assert grid_moore._df_all(dist == 0).all() def test_get_neighborhood( self, - grid_moore: GridPolars, - grid_hexagonal: GridPolars, - grid_von_neumann: GridPolars, - grid_moore_torus: GridPolars, + grid_moore: Grid, + grid_hexagonal: Grid, + grid_von_neumann: Grid, + grid_moore_torus: Grid, ): # Test with radius = int, pos=GridCoordinate neighborhood = grid_moore.get_neighborhood(radius=1, pos=[1, 1]) @@ -621,11 +613,11 @@ def test_get_neighborhood( def test_get_neighbors( self, - fix2_AgentSetPolars: ExampleAgentSetPolars, - grid_moore: GridPolars, - grid_hexagonal: GridPolars, - grid_von_neumann: GridPolars, - grid_moore_torus: GridPolars, + fix2_AgentSet: ExampleAgentSet, + grid_moore: Grid, + grid_hexagonal: Grid, + grid_von_neumann: Grid, + grid_moore_torus: Grid, ): # Place agents in the grid unique_ids = get_unique_ids(grid_moore.model) @@ -759,7 +751,7 @@ def test_get_neighbors( check_column_order=False, ) - def test_is_available(self, grid_moore: GridPolars): + def test_is_available(self, grid_moore: Grid): # Test with GridCoordinate result = grid_moore.is_available([0, 0]) assert isinstance(result, pl.DataFrame) @@ -771,7 +763,7 @@ def test_is_available(self, grid_moore: GridPolars): result = grid_moore.is_available([[0, 0], [1, 1]]) assert result.select(pl.col("available")).to_series().to_list() == [False, True] - def test_is_empty(self, grid_moore: GridPolars): + def test_is_empty(self, grid_moore: Grid): # Test with GridCoordinate result = grid_moore.is_empty([0, 0]) assert isinstance(result, pl.DataFrame) @@ -783,7 +775,7 @@ def test_is_empty(self, grid_moore: GridPolars): result = grid_moore.is_empty([[0, 0], [1, 1]]) assert result.select(pl.col("empty")).to_series().to_list() == [False, False] - def test_is_full(self, grid_moore: GridPolars): + def test_is_full(self, grid_moore: Grid): # Test with GridCoordinate result = grid_moore.is_full([0, 0]) assert isinstance(result, pl.DataFrame) @@ -797,9 +789,9 @@ def test_is_full(self, grid_moore: GridPolars): def test_move_agents( self, - grid_moore: GridPolars, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + grid_moore: Grid, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): # Test with IdsLike unique_ids = get_unique_ids(grid_moore.model) @@ -813,10 +805,10 @@ def test_move_agents( check_row_order=False, ) - # Test with AgentSetDF + # Test with AgentSet with pytest.warns(RuntimeWarning): space = grid_moore.move_agents( - agents=fix2_AgentSetPolars, + agents=fix2_AgentSet, pos=[[0, 0], [1, 0], [2, 0], [0, 1]], inplace=False, ) @@ -833,10 +825,10 @@ def test_move_agents( check_row_order=False, ) - # Test with Collection[AgentSetDF] + # Test with Collection[AgentSet] with pytest.warns(RuntimeWarning): space = grid_moore.move_agents( - agents=[fix1_AgentSetPolars, fix2_AgentSetPolars], + agents=[fix1_AgentSet, fix2_AgentSet], pos=[[0, 2], [1, 2], [2, 2], [0, 1], [1, 1], [2, 1], [0, 0], [1, 0]], inplace=False, ) @@ -859,7 +851,7 @@ def test_move_agents( agents=unique_ids[[0, 1]], pos=[[0, 0], [1, 1], [2, 2]], inplace=False ) - # Test with AgentsDF, pos=DataFrame + # Test with AgentSetRegistry, pos=DataFrame pos = pl.DataFrame( { "dim_0": [0, 1, 2, 0, 1, 2, 0, 1], @@ -869,7 +861,7 @@ def test_move_agents( with pytest.warns(RuntimeWarning): space = grid_moore.move_agents( - agents=grid_moore.model.agents, + agents=grid_moore.model.sets, pos=pos, inplace=False, ) @@ -898,7 +890,7 @@ def test_move_agents( check_row_order=False, ) - def test_move_to_available(self, grid_moore: GridPolars): + def test_move_to_available(self, grid_moore: Grid): # Test with GridCoordinate unique_ids = get_unique_ids(grid_moore.model) last = None @@ -939,12 +931,12 @@ def test_move_to_available(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0", "dim_1")).to_numpy() assert different - # Test with AgentSetDF + # Test with AgentSet last = None different = False for _ in range(10): available_cells = grid_moore.available_cells - space = grid_moore.move_to_available(grid_moore.model.agents, inplace=False) + space = grid_moore.move_to_available(grid_moore.model.sets, inplace=False) if last is not None and not different: if (space.agents.select(pl.col("dim_0")).to_numpy() != last).any(): different = True @@ -958,7 +950,7 @@ def test_move_to_available(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0")).to_numpy() assert different - def test_move_to_empty(self, grid_moore: GridPolars): + def test_move_to_empty(self, grid_moore: Grid): # Test with GridCoordinate unique_ids = get_unique_ids(grid_moore.model) last = None @@ -999,12 +991,12 @@ def test_move_to_empty(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0", "dim_1")).to_numpy() assert different - # Test with AgentSetDF + # Test with AgentSet last = None different = False for _ in range(10): empty_cells = grid_moore.empty_cells - space = grid_moore.move_to_empty(grid_moore.model.agents, inplace=False) + space = grid_moore.move_to_empty(grid_moore.model.sets, inplace=False) if last is not None and not different: if (space.agents.select(pl.col("dim_0")).to_numpy() != last).any(): different = True @@ -1018,7 +1010,7 @@ def test_move_to_empty(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0")).to_numpy() assert different - def test_out_of_bounds(self, grid_moore: GridPolars): + def test_out_of_bounds(self, grid_moore: Grid): # Test with GridCoordinate out_of_bounds = grid_moore.out_of_bounds([11, 11]) assert isinstance(out_of_bounds, pl.DataFrame) @@ -1036,9 +1028,9 @@ def test_out_of_bounds(self, grid_moore: GridPolars): def test_place_agents( self, - grid_moore: GridPolars, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + grid_moore: Grid, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): # Test with IdsLike unique_ids = get_unique_ids(grid_moore.model) @@ -1069,9 +1061,9 @@ def test_place_agents( inplace=False, ) - # Test with AgentSetDF + # Test with AgentSet space = grid_moore.place_agents( - agents=fix2_AgentSetPolars, + agents=fix2_AgentSet, pos=[[0, 0], [1, 0], [2, 0], [0, 1]], inplace=False, ) @@ -1113,10 +1105,10 @@ def test_place_agents( check_row_order=False, ) - # Test with Collection[AgentSetDF] + # Test with Collection[AgentSet] with pytest.warns(RuntimeWarning): space = grid_moore.place_agents( - agents=[fix1_AgentSetPolars, fix2_AgentSetPolars], + agents=[fix1_AgentSet, fix2_AgentSet], pos=[[0, 2], [1, 2], [2, 2], [0, 1], [1, 1], [2, 1], [0, 0], [1, 0]], inplace=False, ) @@ -1163,7 +1155,7 @@ def test_place_agents( check_row_order=False, ) - # Test with AgentsDF, pos=DataFrame + # Test with AgentSetRegistry, pos=DataFrame pos = pl.DataFrame( { "dim_0": [0, 1, 2, 0, 1, 2, 0, 1], @@ -1172,7 +1164,7 @@ def test_place_agents( ) with pytest.warns(RuntimeWarning): space = grid_moore.place_agents( - agents=grid_moore.model.agents, + agents=grid_moore.model.sets, pos=pos, inplace=False, ) @@ -1233,7 +1225,7 @@ def test_place_agents( check_row_order=False, ) - def test_place_to_available(self, grid_moore: GridPolars): + def test_place_to_available(self, grid_moore: Grid): # Test with GridCoordinate unique_ids = get_unique_ids(grid_moore.model) last = None @@ -1274,14 +1266,12 @@ def test_place_to_available(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0", "dim_1")).to_numpy() assert different - # Test with AgentSetDF + # Test with AgentSet last = None different = False for _ in range(10): available_cells = grid_moore.available_cells - space = grid_moore.place_to_available( - grid_moore.model.agents, inplace=False - ) + space = grid_moore.place_to_available(grid_moore.model.sets, inplace=False) if last is not None and not different: if (space.agents.select(pl.col("dim_0")).to_numpy() != last).any(): different = True @@ -1295,7 +1285,7 @@ def test_place_to_available(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0")).to_numpy() assert different - def test_place_to_empty(self, grid_moore: GridPolars): + def test_place_to_empty(self, grid_moore: Grid): # Test with GridCoordinate unique_ids = get_unique_ids(grid_moore.model) last = None @@ -1336,12 +1326,12 @@ def test_place_to_empty(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0", "dim_1")).to_numpy() assert different - # Test with AgentSetDF + # Test with AgentSet last = None different = False for _ in range(10): empty_cells = grid_moore.empty_cells - space = grid_moore.place_to_empty(grid_moore.model.agents, inplace=False) + space = grid_moore.place_to_empty(grid_moore.model.sets, inplace=False) if last is not None and not different: if (space.agents.select(pl.col("dim_0")).to_numpy() != last).any(): different = True @@ -1355,7 +1345,7 @@ def test_place_to_empty(self, grid_moore: GridPolars): last = space.agents.select(pl.col("dim_0")).to_numpy() assert different - def test_random_agents(self, grid_moore: GridPolars): + def test_random_agents(self, grid_moore: Grid): different = False agents0 = grid_moore.random_agents(1) for _ in range(100): @@ -1365,7 +1355,7 @@ def test_random_agents(self, grid_moore: GridPolars): break assert different - def test_random_pos(self, grid_moore: GridPolars): + def test_random_pos(self, grid_moore: Grid): different = False last = None for _ in range(10): @@ -1388,9 +1378,9 @@ def test_random_pos(self, grid_moore: GridPolars): def test_remove_agents( self, - grid_moore: GridPolars, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + grid_moore: Grid, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): unique_ids = get_unique_ids(grid_moore.model) grid_moore.move_agents( @@ -1416,11 +1406,11 @@ def test_remove_agents( ].to_list() ) assert [ - x for id in space.model.agents.index.values() for x in id.to_list() + x for id in space.model.sets.index.values() for x in id.to_list() ] == unique_ids[:8].to_list() - # Test with AgentSetDF - space = grid_moore.remove_agents(fix1_AgentSetPolars, inplace=False) + # Test with AgentSet + space = grid_moore.remove_agents(fix1_AgentSet, inplace=False) assert space.agents.shape == (4, 3) assert space.remaining_capacity == capacity + 4 assert ( @@ -1435,27 +1425,25 @@ def test_remove_agents( ].to_list() ) assert [ - x for id in space.model.agents.index.values() for x in id.to_list() + x for id in space.model.sets.index.values() for x in id.to_list() ] == unique_ids[:8].to_list() - # Test with Collection[AgentSetDF] - space = grid_moore.remove_agents( - [fix1_AgentSetPolars, fix2_AgentSetPolars], inplace=False - ) + # Test with Collection[AgentSet] + space = grid_moore.remove_agents([fix1_AgentSet, fix2_AgentSet], inplace=False) assert [ - x for id in space.model.agents.index.values() for x in id.to_list() + x for id in space.model.sets.index.values() for x in id.to_list() ] == unique_ids[:8].to_list() assert space.agents.is_empty() assert space.remaining_capacity == capacity + 8 - # Test with AgentsDF - space = grid_moore.remove_agents(grid_moore.model.agents, inplace=False) + # Test with AgentSetRegistry + space = grid_moore.remove_agents(grid_moore.model.sets, inplace=False) assert space.remaining_capacity == capacity + 8 assert space.agents.is_empty() assert [ - x for id in space.model.agents.index.values() for x in id.to_list() + x for id in space.model.sets.index.values() for x in id.to_list() ] == unique_ids[:8].to_list() - def test_sample_cells(self, grid_moore: GridPolars): + def test_sample_cells(self, grid_moore: Grid): # Test with default parameters replacement = False same = True @@ -1532,9 +1520,9 @@ def test_sample_cells(self, grid_moore: GridPolars): with pytest.raises(AssertionError): grid_moore.sample_cells(3, cell_type="full", with_replacement=False) - def test_set_cells(self, model: ModelDF): - # Initialize GridPolars - grid_moore = GridPolars(model, dimensions=[3, 3], capacity=2) + def test_set_cells(self, model: Model): + # Initialize Grid + grid_moore = Grid(model, dimensions=[3, 3], capacity=2) # Test with GridCoordinate grid_moore.set_cells( @@ -1583,9 +1571,9 @@ def test_set_cells(self, model: ModelDF): def test_swap_agents( self, - grid_moore: GridPolars, - fix1_AgentSetPolars: ExampleAgentSetPolars, - fix2_AgentSetPolars: ExampleAgentSetPolars, + grid_moore: Grid, + fix1_AgentSet: ExampleAgentSet, + fix2_AgentSet: ExampleAgentSet, ): unique_ids = get_unique_ids(grid_moore.model) grid_moore.move_agents( @@ -1612,10 +1600,8 @@ def test_swap_agents( space.agents.filter(pl.col("agent_id") == unique_ids[3]).row(0)[1:] == grid_moore.agents.filter(pl.col("agent_id") == unique_ids[1]).row(0)[1:] ) - # Test with AgentSetDFs - space = grid_moore.swap_agents( - fix1_AgentSetPolars, fix2_AgentSetPolars, inplace=False - ) + # Test with AgentSets + space = grid_moore.swap_agents(fix1_AgentSet, fix2_AgentSet, inplace=False) assert ( space.agents.filter(pl.col("agent_id") == unique_ids[0]).row(0)[1:] == grid_moore.agents.filter(pl.col("agent_id") == unique_ids[4]).row(0)[1:] @@ -1633,7 +1619,7 @@ def test_swap_agents( == grid_moore.agents.filter(pl.col("agent_id") == unique_ids[7]).row(0)[1:] ) - def test_torus_adj(self, grid_moore: GridPolars, grid_moore_torus: GridPolars): + def test_torus_adj(self, grid_moore: Grid, grid_moore_torus: Grid): # Test with non-toroidal grid with pytest.raises(ValueError): grid_moore.torus_adj([10, 10]) @@ -1653,7 +1639,7 @@ def test_torus_adj(self, grid_moore: GridPolars, grid_moore_torus: GridPolars): assert adj_df.row(0) == (1, 2) assert adj_df.row(1) == (0, 2) - def test___getitem__(self, grid_moore: GridPolars): + def test___getitem__(self, grid_moore: Grid): # Test out of bounds with pytest.raises(ValueError): grid_moore[[5, 5]] @@ -1691,7 +1677,7 @@ def test___getitem__(self, grid_moore: GridPolars): check_dtypes=False, ) - def test___setitem__(self, grid_moore: GridPolars): + def test___setitem__(self, grid_moore: Grid): # Test with out-of-bounds with pytest.raises(ValueError): grid_moore[[5, 5]] = {"capacity": 10} @@ -1706,7 +1692,7 @@ def test___setitem__(self, grid_moore: GridPolars): ).to_series().to_list() == [20, 20] # Property tests - def test_agents(self, grid_moore: GridPolars): + def test_agents(self, grid_moore: Grid): unique_ids = get_unique_ids(grid_moore.model) assert_frame_equal( grid_moore.agents, @@ -1715,13 +1701,13 @@ def test_agents(self, grid_moore: GridPolars): ), ) - def test_available_cells(self, grid_moore: GridPolars): + def test_available_cells(self, grid_moore: Grid): result = grid_moore.available_cells assert len(result) == 8 assert isinstance(result, pl.DataFrame) assert result.columns == ["dim_0", "dim_1"] - def test_cells(self, grid_moore: GridPolars): + def test_cells(self, grid_moore: Grid): result = grid_moore.cells unique_ids = get_unique_ids(grid_moore.model) assert_frame_equal( @@ -1738,17 +1724,17 @@ def test_cells(self, grid_moore: GridPolars): check_dtypes=False, ) - def test_dimensions(self, grid_moore: GridPolars): + def test_dimensions(self, grid_moore: Grid): assert isinstance(grid_moore.dimensions, list) assert len(grid_moore.dimensions) == 2 - def test_empty_cells(self, grid_moore: GridPolars): + def test_empty_cells(self, grid_moore: Grid): result = grid_moore.empty_cells assert len(result) == 7 assert isinstance(result, pl.DataFrame) assert result.columns == ["dim_0", "dim_1"] - def test_full_cells(self, grid_moore: GridPolars): + def test_full_cells(self, grid_moore: Grid): grid_moore.set_cells([[0, 0], [1, 1]], {"capacity": 1}) result = grid_moore.full_cells assert len(result) == 2 @@ -1765,27 +1751,27 @@ def test_full_cells(self, grid_moore: GridPolars): ) ).all() - def test_model(self, grid_moore: GridPolars, model: ModelDF): + def test_model(self, grid_moore: Grid, model: Model): assert grid_moore.model == model def test_neighborhood_type( self, - grid_moore: GridPolars, - grid_von_neumann: GridPolars, - grid_hexagonal: GridPolars, + grid_moore: Grid, + grid_von_neumann: Grid, + grid_hexagonal: Grid, ): assert grid_moore.neighborhood_type == "moore" assert grid_von_neumann.neighborhood_type == "von_neumann" assert grid_hexagonal.neighborhood_type == "hexagonal" - def test_random(self, grid_moore: GridPolars): + def test_random(self, grid_moore: Grid): assert grid_moore.random == grid_moore.model.random - def test_remaining_capacity(self, grid_moore: GridPolars): + def test_remaining_capacity(self, grid_moore: Grid): assert grid_moore.remaining_capacity == (3 * 3 * 2 - 2) - def test_torus(self, model: ModelDF, grid_moore: GridPolars): + def test_torus(self, model: Model, grid_moore: Grid): assert not grid_moore.torus - grid_2 = GridPolars(model, [3, 3], torus=True) + grid_2 = Grid(model, [3, 3], torus=True) assert grid_2.torus diff --git a/tests/test_modeldf.py b/tests/test_model.py similarity index 86% rename from tests/test_modeldf.py rename to tests/test_model.py index afc45405..34a7862b 100644 --- a/tests/test_modeldf.py +++ b/tests/test_model.py @@ -1,7 +1,7 @@ -from mesa_frames import ModelDF +from mesa_frames import Model -class CustomModel(ModelDF): +class CustomModel(Model): def __init__(self): super().__init__() self.custom_step_count = 0 @@ -10,9 +10,9 @@ def step(self): self.custom_step_count += 2 -class Test_ModelDF: +class Test_Model: def test_steps(self): - model = ModelDF() + model = Model() assert model.steps == 0 diff --git a/uv.lock b/uv.lock index f09db044..a72164c0 100644 --- a/uv.lock +++ b/uv.lock @@ -1258,6 +1258,7 @@ dev = [ docs = [ { name = "autodocsumm" }, { name = "beartype" }, + { name = "mesa" }, { name = "mkdocs-git-revision-date-localized-plugin" }, { name = "mkdocs-include-markdown-plugin" }, { name = "mkdocs-jupyter" }, @@ -1319,6 +1320,7 @@ dev = [ docs = [ { name = "autodocsumm", specifier = ">=0.2.14" }, { name = "beartype", specifier = ">=0.21.0" }, + { name = "mesa", specifier = ">=3.2.0" }, { name = "mkdocs-git-revision-date-localized-plugin", specifier = ">=1.4.7" }, { name = "mkdocs-include-markdown-plugin", specifier = ">=7.1.5" }, { name = "mkdocs-jupyter", specifier = ">=0.25.1" },