Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from Muscat.Containers.MeshFieldOperations import GetFieldTransferOp
from Muscat.FE.Fields.FEField import FEField
from Muscat.Bridges.CGNSBridge import MeshToCGNS,CGNSToMesh
import Muscat.Containers.ElementsDescription as ED
from Muscat.Containers.ConstantRectilinearMeshTools import CreateConstantRectilinearMesh
from Muscat.Containers.MeshTetrahedrization import Tetrahedrization
from Muscat.Containers.MeshModificationTools import ComputeSkin
Expand Down
1 change: 0 additions & 1 deletion benchmarks/FNO/Rotor37/prepare_rotor37.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from Muscat.FE.Fields.FEField import FEField
from Muscat.Bridges.CGNSBridge import MeshToCGNS,CGNSToMesh
from Muscat.Containers.ConstantRectilinearMeshTools import CreateConstantRectilinearMesh
from Muscat.Containers.MeshTetrahedrization import Tetrahedrization
from Muscat.Containers.MeshModificationTools import ComputeSkin
from Muscat.FE.FETools import PrepareFEComputation
from Muscat.FE.FETools import ComputeNormalsAtPoints
Expand Down
3 changes: 0 additions & 3 deletions benchmarks/FNO/VKI-LS59/prepare_vki.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from plaid.problem_definition import ProblemDefinition
from plaid.containers.sample import Sample
import numpy as np
from Muscat.Bridges.CGNSBridge import MeshToCGNS
import Muscat.Containers.ElementsDescription as ED
from Muscat.Containers.ConstantRectilinearMeshTools import CreateConstantRectilinearMesh
from Muscat.Containers.MeshTetrahedrization import Tetrahedrization
from Muscat.Containers import MeshCreationTools as MCT
import os, time, shutil
from tqdm import tqdm

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/FNO/VKI-LS59/train_and_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@


import torch
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import Dataset

class GridDataset(Dataset):
def __init__(self, inputs, outputs):
Expand Down
87 changes: 87 additions & 0 deletions src/plaid/containers/collections.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Module for implementing collections of features within a Sample."""

import logging
from typing import Optional, Union

from plaid.types import Scalar

logger = logging.getLogger(__name__)
logging.basicConfig(
format="[%(asctime)s:%(levelname)s:%(filename)s:%(funcName)s(%(lineno)d)]:%(message)s",
level=logging.INFO,
)


def _check_names(names: Union[str, list[str]]):
"""Check that names do not contain invalid character ``/``.

Args:
names (Union[str, list[str]]): The names to check.

Raises:
ValueError: If any name contains the invalid character ``/``.
"""
if isinstance(names, str):
names = [names]
for name in names:
if (name is not None) and ("/" in name):
raise ValueError(
f"feature_names containing `/` are not allowed, but {name=}, you should first replace any occurence of `/` with something else, for example: `name.replace('/','__')`"
)


class SampleScalars:
"""A container for scalar features within a Sample.

Provides dict-like operations for adding, retrieving, and removing scalars.
Names must be unique and may not contain the character ``/``.
"""

def __init__(self, scalars: Optional[dict[str, Scalar]]) -> None:
self.data: dict[str, Scalar] = scalars if scalars is not None else {}

def add(self, name: str, value: Scalar) -> None:
"""Add a scalar value to a dictionary.

Args:
name (str): The name of the scalar value.
value (Scalar): The scalar value to add or update in the dictionary.
"""
_check_names(name)
self.data[name] = value

def remove(self, name: str) -> Scalar:
"""Delete a scalar value from the dictionary.

Args:
name (str): The name of the scalar value to be deleted.

Raises:
KeyError: Raised when there is no scalar / there is no scalar with the provided name.

Returns:
Scalar: The value of the deleted scalar.
"""
if name not in self.data:
raise KeyError(f"There is no scalar value with name {name}.")

return self.data.pop(name)

def get(self, name: str) -> Optional[Scalar]:
"""Retrieve a scalar value associated with the given name.

Args:
name (str): The name of the scalar value to retrieve.

Returns:
Scalar or None: The scalar value associated with the given name, or None if the name is not found.
"""
return self.data.get(name)

def get_names(self) -> list[str]:
"""Get a set of scalar names available in the object.

Returns:
list[str]: A set containing the names of the available scalars.
"""
return sorted(self.data.keys())
Loading
Loading