From 04a586b1b1304704562062bd2c47a41b26d56905 Mon Sep 17 00:00:00 2001 From: PaParaZz1 Date: Mon, 10 Mar 2025 13:46:21 +0800 Subject: [PATCH 1/6] feature(nyz): add basic math reward model interfaces --- ding/reward_model/__init__.py | 3 + ding/reward_model/math_reward_model.py | 45 +++++++ ding/reward_model/math_rule_reward_model.py | 122 ++++++++++++++++++ .../tests/test_math_rule_reward_model.py | 20 +++ 4 files changed, 190 insertions(+) create mode 100644 ding/reward_model/math_reward_model.py create mode 100644 ding/reward_model/math_rule_reward_model.py create mode 100644 ding/reward_model/tests/test_math_rule_reward_model.py diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py index 4538102861..fa3f344ec2 100644 --- a/ding/reward_model/__init__.py +++ b/ding/reward_model/__init__.py @@ -13,3 +13,6 @@ from .guided_cost_reward_model import GuidedCostRewardModel from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel from .icm_reward_model import ICMRewardModel +# LLM/VLM reward model and verifier +from .math_reward_model import MathRewardModel +from .math_rule_reward_model import MathRuleRewardModel \ No newline at end of file diff --git a/ding/reward_model/math_reward_model.py b/ding/reward_model/math_reward_model.py new file mode 100644 index 0000000000..284ca83d58 --- /dev/null +++ b/ding/reward_model/math_reward_model.py @@ -0,0 +1,45 @@ +from typing import Tuple, Optional, List, Dict +from easydict import EasyDict +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoTokenizer +import re + +from ding.utils import REWARD_MODEL_REGISTRY +from .base_reward_model import BaseRewardModel + + +@REWARD_MODEL_REGISTRY.register('math') +class MathRewardModel(BaseRewardModel): + config = dict( + # (str) The type of the reward model. + type='math', + # (str) The name of the tokenizer, usually the huggingface tokenizer name. + tokenizer_name='Qwen/Qwen2.5-Math-PRM-7B', + ) + + def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None: # noqa + self.cfg = config + self.device = device + self.logger = logger + self.tb_logger = tb_logger + + def estimate(self, data: List[str]) -> List[Dict]: + """ + Arguments: + - data (:obj:`List[str]`): The list of data queries used for estimation, each query is a string. + of the \ + form "1 + 1 = ?" + Returns: + - reward (:obj:`List[Dict]`): The estimated reward. + """ + pass + + # rule-based reward model does not need training, thus the following methods are empty + def train(self): + pass + + def collect_data(self, data: list) -> None: + pass + + def clear_data(self) -> None: + pass \ No newline at end of file diff --git a/ding/reward_model/math_rule_reward_model.py b/ding/reward_model/math_rule_reward_model.py new file mode 100644 index 0000000000..1b38f0d8c5 --- /dev/null +++ b/ding/reward_model/math_rule_reward_model.py @@ -0,0 +1,122 @@ +from typing import Tuple, Optional, List, Dict +from easydict import EasyDict +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoTokenizer +import re + +from ding.utils import REWARD_MODEL_REGISTRY +from .base_reward_model import BaseRewardModel + + +@REWARD_MODEL_REGISTRY.register('math_rule') +class MathRuleRewardModel(BaseRewardModel): + config = dict( + # (str) The type of the reward model. + type='math_rule', + # (str) The name of the dataset, usually the huggingface dataset name. + dataset_name='', + # (str) The name of the tokenizer, usually the huggingface tokenizer name. + tokenizer_name='', + # (float) The score of format error. + format_error_reward=-2, + # (float) The score of answer error. + answer_error_reward=-1, + # (float) The score of correct. + correct_reward=1, + ) + + def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None: # noqa + self.cfg = config + self.device = device + self.logger = logger + self.tb_logger = tb_logger + + def estimate(self, data: List[str]) -> List[Dict]: + """ + Arguments: + - data (:obj:`List[str]`): The list of data queries used for estimation, each query is a string of the \ + form "1 + 1 = ?" + Returns: + - reward (:obj:`List[Dict]`): The estimated reward. + """ + # 1. parse the query to get question and predicted answer + # 2. get the ground truth answer according to the question + # 3. calculate the reward based on the predicted answer and the ground truth answer (format error -2, answer error -1, correct 1) + pass + + # rule-based reward model does not need training, thus the following methods are empty + def train(self): + pass + + def collect_data(self, data: list) -> None: + pass + + def clear_data(self) -> None: + pass + + +def strip_sequence(text: str, pad_token: str, eos_token: str) -> str: + """ + Overview: + Remove leading and trailing sequences of padding/eos tokens from a text. + + .. note:: + This function uses regular expressions to strip all consecutive occurrences + of the specified padding and end-of-sequence tokens from both the beginning + and end of the input text. Tokens in the middle of the text are preserved. + + Arguments: + - text (str): The input text to be processed. + - pad_token (str): The padding token to be stripped (e.g., ""). + - eos_token (str): The end-of-sequence token to be stripped (e.g., ""). + + Returns: + - cleaned_text (str): The cleaned text with leading/trailing padding/eos tokens removed. + + Examples: + >>> strip_sequence("Hello", "", "") + 'Hello' + + >>> strip_sequence("TestMiddleKeep", "", "") + 'TestMiddleKeep' + + >>> strip_sequence("Full removal", "", "") + 'Full removal' + + >>> strip_sequence("No tokens here", "", "") + 'No tokens here' + + >>> strip_sequence("", "", "") + '' + """ + pad_token_escaped = re.escape(pad_token) + eos_token_escaped = re.escape(eos_token) + + # Remove leading tokens + pattern = f"^({eos_token_escaped}|{pad_token_escaped})+" + text = re.sub(pattern, "", text) + + # Remove trailing tokens + pattern = f"({eos_token_escaped}|{pad_token_escaped})+$" + text = re.sub(pattern, "", text) + return text + + +def normalize_text(text: str) -> str: + """ + Overview: + This function is designed to standardize text by: + - Converting all text to lowercase + - Replacing various punctuation marks and special characters with spaces + - Removing import statements + - Normalizing whitespace by replacing multiple spaces with a single space + - Stripping leading and trailing whitespace + Arguments: + - text (str): The input text to be processed. + Returns: + - normalized_text (str): The normalized text. + """ + text = re.sub("[,.:\"'\[\]\-=\+\\|!@#$%^&*();<>?/!¥…()—\{\}:”“《》?]", " ", text.lower()) + text = re.sub("import\s[a-zA-Z\.]+(\sas\s[a-zA-Z\.]+)\n", " ", text) + text = re.sub("\s+", " ", text) + return text.strip() diff --git a/ding/reward_model/tests/test_math_rule_reward_model.py b/ding/reward_model/tests/test_math_rule_reward_model.py new file mode 100644 index 0000000000..b79b05725e --- /dev/null +++ b/ding/reward_model/tests/test_math_rule_reward_model.py @@ -0,0 +1,20 @@ +import pytest +from easydict import EasyDict + +from ding.reward_model import MathRuleRewardModel + + +@pytest.mark.envtest +def test_math_rule_reward_model(): + reward_model = MathRuleRewardModel( + config=EasyDict( + dataset_name='RUC-AIBOX/STILL-3-Preview-RL-Data', + tokenizer_name='unsloth/Meta-Llama-3.1-8B', + ) + ) + + data = [ + "The school now introduces a new color, silver, for the flag design. Crestview's school colors are now purple, gold, and silver. The students are designing a flag using three solid-colored horizontal stripes. Using one, two, or all three of the school colors, how many different flags are possible if adjacent stripes may be the same color?", # noqa + ] + rewards = reward_model.estimate(data) + assert len(rewards) == len(data) From ab5f6e791a63a9e327cc6f0b492b82728242b82e Mon Sep 17 00:00:00 2001 From: PaParaZz1 Date: Mon, 10 Mar 2025 13:53:11 +0800 Subject: [PATCH 2/6] style(nyz): polish flake8 style --- ding/reward_model/__init__.py | 2 +- ding/reward_model/math_reward_model.py | 2 +- ding/reward_model/math_rule_reward_model.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py index fa3f344ec2..9b06a7b109 100644 --- a/ding/reward_model/__init__.py +++ b/ding/reward_model/__init__.py @@ -15,4 +15,4 @@ from .icm_reward_model import ICMRewardModel # LLM/VLM reward model and verifier from .math_reward_model import MathRewardModel -from .math_rule_reward_model import MathRuleRewardModel \ No newline at end of file +from .math_rule_reward_model import MathRuleRewardModel diff --git a/ding/reward_model/math_reward_model.py b/ding/reward_model/math_reward_model.py index 284ca83d58..95fa16f2bf 100644 --- a/ding/reward_model/math_reward_model.py +++ b/ding/reward_model/math_reward_model.py @@ -42,4 +42,4 @@ def collect_data(self, data: list) -> None: pass def clear_data(self) -> None: - pass \ No newline at end of file + pass diff --git a/ding/reward_model/math_rule_reward_model.py b/ding/reward_model/math_rule_reward_model.py index 1b38f0d8c5..92b32f1279 100644 --- a/ding/reward_model/math_rule_reward_model.py +++ b/ding/reward_model/math_rule_reward_model.py @@ -41,7 +41,8 @@ def estimate(self, data: List[str]) -> List[Dict]: """ # 1. parse the query to get question and predicted answer # 2. get the ground truth answer according to the question - # 3. calculate the reward based on the predicted answer and the ground truth answer (format error -2, answer error -1, correct 1) + # 3. calculate the reward based on the predicted answer and the ground truth answer + # (format error -2, answer error -1, correct 1) pass # rule-based reward model does not need training, thus the following methods are empty From 60d88f9fc01b00ec8ce30812124a9db9760aa75a Mon Sep 17 00:00:00 2001 From: Berit-chengyi <2826895005@qq.com> Date: Wed, 12 Mar 2025 15:07:54 +0800 Subject: [PATCH 3/6] (dcy) add math_reward_model and its test file --- ding/reward_model/math_reward_model.py | 126 ++++++++++++++++-- .../tests/test_math_reward_model.py | 87 ++++++++++++ 2 files changed, 203 insertions(+), 10 deletions(-) create mode 100644 ding/reward_model/tests/test_math_reward_model.py diff --git a/ding/reward_model/math_reward_model.py b/ding/reward_model/math_reward_model.py index 95fa16f2bf..ccd4aacf32 100644 --- a/ding/reward_model/math_reward_model.py +++ b/ding/reward_model/math_reward_model.py @@ -1,7 +1,9 @@ from typing import Tuple, Optional, List, Dict from easydict import EasyDict from torch.utils.tensorboard import SummaryWriter -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoModel +import torch +import torch.nn.functional as F import re from ding.utils import REWARD_MODEL_REGISTRY @@ -13,8 +15,8 @@ class MathRewardModel(BaseRewardModel): config = dict( # (str) The type of the reward model. type='math', - # (str) The name of the tokenizer, usually the huggingface tokenizer name. - tokenizer_name='Qwen/Qwen2.5-Math-PRM-7B', + # (str) The name of the tokenizer and model + model_name='Qwen/Qwen2.5-Math-PRM-7B', ) def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None: # noqa @@ -23,23 +25,127 @@ def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWri self.logger = logger self.tb_logger = tb_logger - def estimate(self, data: List[str]) -> List[Dict]: + # 初始化tokenizer和model + self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name, trust_remote_code=True) + self.model = AutoModel.from_pretrained( + self.cfg.model_name, device_map=self.device, torch_dtype=torch.bfloat16, trust_remote_code=True + ) + self.model.eval() + + def make_step_rewards(self, logits: torch.Tensor, token_masks: torch.Tensor) -> List[List[float]]: + """Calculate step-wise rewards from model outputs""" + probabilities = F.softmax(logits, dim=-1) + probabilities = probabilities * token_masks.unsqueeze(-1) # bs, seq_len, num_labels + + all_scores_res = [] + for i in range(probabilities.size(0)): + sample = probabilities[i] # seq_len, num_labels + positive_probs = sample[sample != 0].view(-1, 2)[:, 1] # valid_tokens, num_labels + non_zero_elements_list = positive_probs.cpu().tolist() + all_scores_res.append(non_zero_elements_list) + return all_scores_res + + def estimate(self, data: List[Dict]) -> List[Dict]: """ + Overview: + Estimate rewards for mathematical reasoning steps using Qwen2.5-Math-PRM-7B model. Arguments: - - data (:obj:`List[str]`): The list of data queries used for estimation, each query is a string. - of the \ - form "1 + 1 = ?" + - data (:obj:`List[Dict]`): List of dictionaries containing: + - system (:obj:`str`): System prompt for the model + - query (:obj:`str`): The mathematical query to be evaluated + - response (:obj:`List[str]`): List of reasoning steps Returns: - - reward (:obj:`List[Dict]`): The estimated reward. + - reward (:obj:`List[Dict]`): List of dictionaries containing: + - reward (:obj:`float`): Final reward (last step reward) + - metadata (:obj:`Dict`): Additional information including: + - query (:obj:`str`): Original query + - step_rewards (:obj:`List[float]`): Rewards for each reasoning step + - num_steps (:obj:`int`): Number of reasoning steps + Shapes: + - input_ids (:obj:`torch.LongTensor`): :math:`(B, L)`, where B is batch size and L is sequence length + - outputs (:obj:`torch.FloatTensor`): :math:`(B, L, H)`, where H is hidden size + - token_masks (:obj:`torch.BoolTensor`): :math:`(B, L)` + - step_rewards (:obj:`List[List[float]]`): List of length B, each containing S rewards where S is num steps + Examples: + >>> data = [{ + >>> "system": "Please reason step by step...", + >>> "query": "What is 1 + 1?", + >>> "response": ["First, we have 1", "Then add 1", "Therefore, 1 + 1 = 2"] + >>> }] + >>> results = model.estimate(data) + >>> print(results[0]["reward"]) # 1.0 + >>> print(results[0]["metadata"]["step_rewards"]) # [0.8, 0.9, 1.0] """ - pass + # 批量处理所有样本 + all_messages = [] + for item in data: + messages = [ + { + "role": "system", + "content": item['system'] + }, + { + "role": "user", + "content": item['query'] + }, + { + "role": "assistant", + "content": "".join(item['response']) + "" + }, + ] + all_messages.append(messages) + + # 批量转换为模型输入格式 + conversation_strs = [ + self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + for messages in all_messages + ] + + # 批量编码输入 + input_ids = self.tokenizer( + conversation_strs, return_tensors="pt", padding=True, truncation=True + )["input_ids"].to(self.model.device) + + # 批量获取模型输出 + with torch.no_grad(): + outputs = self.model(input_ids=input_ids) + + # 计算每个样本的步骤奖励 + step_sep_id = self.tokenizer.encode("")[0] + token_masks = (input_ids == step_sep_id) + batch_rewards = self.make_step_rewards(outputs[0], token_masks) + + # 构建详细的结果字典 + results = [] + for item, step_rewards in zip(data, batch_rewards): + results.append( + { + "reward": step_rewards[-1] if step_rewards else 0.0, # 最后一步的奖励作为总体奖励 + "metadata": { + "query": item['query'], + "step_rewards": step_rewards, # 每个步骤的奖励 + "num_steps": len(item['response']), + } + } + ) + + return results - # rule-based reward model does not need training, thus the following methods are empty def train(self): + """ + Training is not implemented for this reward model as it uses a pre-trained model + """ + self.logger.warning("Training is not implemented for this reward model") pass def collect_data(self, data: list) -> None: + """ + Data collection is not needed for this reward model + """ pass def clear_data(self) -> None: + """ + Data clearing is not needed for this reward model + """ pass diff --git a/ding/reward_model/tests/test_math_reward_model.py b/ding/reward_model/tests/test_math_reward_model.py new file mode 100644 index 0000000000..31c59ac85b --- /dev/null +++ b/ding/reward_model/tests/test_math_reward_model.py @@ -0,0 +1,87 @@ +import pytest +from easydict import EasyDict +import torch +from unittest.mock import MagicMock + +from ding.reward_model import MathRewardModel + + +@pytest.mark.envtest +def test_math_reward_model(): + # Create configuration + cfg = EasyDict(dict( + type='math', + model_name='Qwen/Qwen2.5-Math-PRM-7B', + )) + + # Create mock logger and tb_logger + logger = MagicMock() + tb_logger = MagicMock() + + # Initialize reward model + model = MathRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger) + + # Test case 1: Simple math problem + data_simple = [ + { + "system": "Please reason step by step...", + "query": "What is 1 + 1?", + "response": ["First, we have 1", "Then add 1", "Therefore, 1 + 1 = 2"] + } + ] + + # Test case 2: Complex word problem + data_complex = [ + { + "system": "Please reason step by step, and put your final answer within \\boxed{}.", + "query": "Sue lives in a fun neighborhood...", + "response": [ + "To find out how many more pink plastic flamingos...", + "On Saturday, they take back one third of the flamingos...", + "On Sunday, the neighbors add another 18 pink plastic flamingos...", + "To find the difference, subtract the number of white flamingos..." + ] + } + ] + + # Test simple case + results_simple = model.estimate(data_simple) + + # Verify simple case results + assert len(results_simple) == 1, "Should return one result" + assert "reward" in results_simple[0], "Result should contain reward" + assert "metadata" in results_simple[0], "Result should contain metadata" + assert "step_rewards" in results_simple[0]["metadata"], "Metadata should contain step_rewards" + assert len(results_simple[0]["metadata"]["step_rewards"]) == 3, "Should have 3 step rewards" + assert results_simple[0]["metadata"]["num_steps"] == 3, "Should have 3 steps" + + # Test complex case + results_complex = model.estimate(data_complex) + + # Verify complex case results + assert len(results_complex) == 1, "Should return one result" + assert "reward" in results_complex[0], "Result should contain reward" + assert "metadata" in results_complex[0], "Result should contain metadata" + assert "step_rewards" in results_complex[0]["metadata"], "Metadata should contain step_rewards" + assert len(results_complex[0]["metadata"]["step_rewards"]) == 4, "Should have 4 step rewards" + assert results_complex[0]["metadata"]["num_steps"] == 4, "Should have 4 steps" + + # Verify reward value ranges + for result in results_simple + results_complex: + assert 0 <= result["reward"] <= 1, "Reward should be between 0 and 1" + for step_reward in result["metadata"]["step_rewards"]: + assert 0 <= step_reward <= 1, "Step rewards should be between 0 and 1" + + # Test batch processing functionality + batch_data = data_simple + data_complex + batch_results = model.estimate(batch_data) + assert len(batch_results) == 2, "Should return two results for batch processing" + + # Print detailed information for debugging + print("\nSimple problem results:") + print(f"Final reward: {results_simple[0]['reward']}") + print(f"Step rewards: {results_simple[0]['metadata']['step_rewards']}") + + print("\nComplex problem results:") + print(f"Final reward: {results_complex[0]['reward']}") + print(f"Step rewards: {results_complex[0]['metadata']['step_rewards']}") From 68db31ffdc234a3b030f04211c7bc51b3cc56a49 Mon Sep 17 00:00:00 2001 From: Berit-chengyi <2826895005@qq.com> Date: Sun, 16 Mar 2025 12:11:17 +0000 Subject: [PATCH 4/6] (dcy) add math_rule_reward_model and its test file --- ding/reward_model/math_rule_reward_model.py | 687 ++++++++++++++++-- .../tests/test_math_rule_reward_model.py | 111 ++- 2 files changed, 736 insertions(+), 62 deletions(-) diff --git a/ding/reward_model/math_rule_reward_model.py b/ding/reward_model/math_rule_reward_model.py index 92b32f1279..33aafbeb6b 100644 --- a/ding/reward_model/math_rule_reward_model.py +++ b/ding/reward_model/math_rule_reward_model.py @@ -1,8 +1,10 @@ -from typing import Tuple, Optional, List, Dict +from typing import Tuple, Optional, List, Dict, Union, Any from easydict import EasyDict from torch.utils.tensorboard import SummaryWriter from transformers import AutoTokenizer import re +import math +import json from ding.utils import REWARD_MODEL_REGISTRY from .base_reward_model import BaseRewardModel @@ -10,6 +12,10 @@ @REWARD_MODEL_REGISTRY.register('math_rule') class MathRuleRewardModel(BaseRewardModel): + """ + Math rule-based reward model for evaluating mathematical answers. + Supports various mathematical expression formats including LaTeX, fractions, percentages, etc. + """ config = dict( # (str) The type of the reward model. type='math_rule', @@ -23,72 +29,646 @@ class MathRuleRewardModel(BaseRewardModel): answer_error_reward=-1, # (float) The score of correct. correct_reward=1, + # (float) Relative tolerance for numerical comparison + rel_tol=1e-5, + # (float) Absolute tolerance for numerical comparison + abs_tol=1e-8, ) - def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None: # noqa + def __init__( + self, + config: EasyDict, + device: str = 'cpu', + logger=None, + tb_logger: 'SummaryWriter' = None + ) -> None: # noqa + """Initialize the math rule reward model""" self.cfg = config self.device = device self.logger = logger self.tb_logger = tb_logger - def estimate(self, data: List[str]) -> List[Dict]: + # Initialize tokenizer + if hasattr(config, 'tokenizer_name') and config.tokenizer_name: + self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + self.pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]" + self.eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[EOS]" + else: + self.tokenizer = None + self.pad_token = "[PAD]" + self.eos_token = "[EOS]" + + def _process_target_answer(self, text: str) -> Optional[float]: + """Process target answer text and convert to numerical value""" + if text is None or not text.strip(): + return None + + # Clean and normalize text + if self.tokenizer: + text = strip_sequence(text, self.pad_token, self.eos_token) + text = normalize_text(text) + + # Try to process the mathematical expression + try: + result = self._process_math_expression(text) + if result is not None: + return result + except Exception as e: + if self.logger: + self.logger.warning(f"Error processing target answer: {e}") + return None + + def _process_response_answer(self, response: str) -> Tuple[Optional[float], Optional[str]]: + """Process response text, extract and convert to numerical value""" + if response is None or not response.strip(): + return None, None + + # Clean text + if self.tokenizer: + response = strip_sequence(response, self.pad_token, self.eos_token) + + # First try to extract the final answer + final_answer = self._extract_final_answer(response) + + # If a final answer is extracted, try to process it + if final_answer: + try: + value = self._process_math_expression(final_answer) + if value is not None: + return value, final_answer + except Exception as e: + if self.logger: + self.logger.debug(f"Error processing final answer: {e}") + + # If unable to get a valid value from the final answer, try to extract all possible expressions + expressions = self._extract_all_expressions(response) + + # Try to process each expression until a valid answer is found + for expr in expressions: + try: + value = self._process_math_expression(expr) + if value is not None: + return value, expr + except Exception as e: + if self.logger: + self.logger.debug(f"Error processing expression '{expr}': {e}") + + # If all attempts fail, return None + return None, None + + def _check_answer_match(self, pred: Optional[float], target: Optional[float]) -> bool: + """Check if two answers match within tolerance""" + if pred is None or target is None: + return False + try: + return math.isclose( + pred, target, rel_tol=self.cfg.get('rel_tol', 1e-5), abs_tol=self.cfg.get('abs_tol', 1e-8) + ) + except Exception as e: + if self.logger: + self.logger.warning(f"Error comparing answers: {e}") + return False + + def _extract_final_answer(self, text: str) -> Optional[str]: + """ + Extract the final answer from text. + Supports various formats: + 1. "The answer is X" + 2. "Therefore, X is the answer" + 3. "X" (if only one number) + 4. "\\boxed{X}" + 5. "= X" (expression after equals sign) + 6. Last LaTeX expression like \\frac{a}{b}, \\sqrt{x}, etc. + """ + # Try to extract boxed content + boxed_match = re.search(r'\\boxed\{([^}]+)\}', text) + if boxed_match: + return boxed_match.group(0) + + # Try to extract "the answer is X" format + answer_match = re.search(r'(?:the\s+)?answer\s+is\s+([^\.]+)', text, re.IGNORECASE) + if answer_match: + answer_text = answer_match.group(1).strip() + # Check if the extracted answer contains a LaTeX expression + latex_match = re.search(r'(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})', answer_text) + if latex_match: + return latex_match.group(0) + return answer_text + + # Try to extract "therefore, X is the answer" format + therefore_match = re.search(r'therefore,?\s+([^\.]+)\s+is\s+the\s+answer', text, re.IGNORECASE) + if therefore_match: + therefore_text = therefore_match.group(1).strip() + # Check if the extracted answer contains a LaTeX expression + latex_match = re.search(r'(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})', therefore_text) + if latex_match: + return latex_match.group(0) + return therefore_text + + # Try to extract expression after equals sign + equals_matches = re.findall(r'=\s*([^\.=]+?)(?:\.|$|=)', text) + if equals_matches: + last_eq = equals_matches[-1].strip() + # Check if there's a LaTeX expression after the equals sign + latex_match = re.search(r'(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})', last_eq) + if latex_match: + return latex_match.group(0) + return last_eq + + # Try to directly extract LaTeX fraction expression + frac_matches = re.findall(r'(\\frac\{[^}]+\}\{[^}]+\})', text) + if frac_matches: + return frac_matches[-1] + + # Try to directly extract LaTeX square root expression + sqrt_matches = re.findall(r'(\\sqrt\{[^}]+\})', text) + if sqrt_matches: + return sqrt_matches[-1] + + # Try to extract pi-related expressions + pi_expr = self._extract_pi_expressions(text) + if pi_expr: + return pi_expr + + # If there's only one number, return it directly + numbers = re.findall(r'-?\d*\.?\d+', text) + if len(numbers) == 1: + return numbers[0] + + # Try to extract the last number (as a fallback) + if numbers: + return numbers[-1] + + return None + + def _extract_pi_expressions(self, text: str) -> Optional[str]: + """Extract pi-related expressions from text""" + # Try to extract expressions like \frac{a\pi}{b} + pi_frac_matches = re.findall(r'(\\frac\{[^}]*\\pi[^}]*\}\{[^}]+\})', text) + if pi_frac_matches: + return pi_frac_matches[-1] + + # Try to extract expressions like \frac{a}{b}\pi + frac_pi_matches = re.findall(r'(\\frac\{[^}]+\}\{[^}]+\}\\pi)', text) + if frac_pi_matches: + return frac_pi_matches[-1] + + # Try to extract expressions like 11π/6 + text_with_pi = text.replace("\\pi", "π") + pi_div_matches = re.findall(r'(\d+π/\d+)', text_with_pi) + if pi_div_matches: + return pi_div_matches[-1] + + # Try to extract expressions like π/2 + pi_simple_div_matches = re.findall(r'(π/\d+)', text_with_pi) + if pi_simple_div_matches: + return pi_simple_div_matches[-1] + + # Try to extract expressions like 2π + pi_mult_matches = re.findall(r'(\d+π)', text_with_pi) + if pi_mult_matches: + return pi_mult_matches[-1] + + # Check for standalone π + if "π" in text_with_pi or "\\pi" in text: + pi_standalone = re.search(r'(^|[^a-zA-Z0-9])π($|[^a-zA-Z0-9])', text_with_pi) + if pi_standalone: + return "π" + + return None + + def _process_pi_expressions(self, text: str) -> Optional[float]: + """Process pi-related expressions and convert to numerical value""" + # Standardize pi notation + text = text.replace("\\pi", "π") + + # Process expressions like 11π/6 + pi_match = re.search(r'(\d+)π/(\d+)', text) + if pi_match: + num, denom = map(int, pi_match.groups()) + return (num * math.pi) / denom + + # Process expressions like π/2 + pi_div_match = re.search(r'π/(\d+)', text) + if pi_div_match: + denom = int(pi_div_match.group(1)) + return math.pi / denom + + # Process expressions like 2π + pi_mult_match = re.search(r'(\d+)π', text) + if pi_mult_match: + num = int(pi_mult_match.group(1)) + return num * math.pi + + # If just π + if text == "π": + return math.pi + + return None + + def _process_math_expression(self, text: str) -> Optional[float]: + """ + Process special mathematical expressions, such as: + 1. Fractions: 1/2, \\frac{1}{2} + 2. Percentages: 50% + 3. Scientific notation: 1.2e-3 + 4. Mixed expressions: 1 + 2/3 + 5. Square roots: \\sqrt{2} + 6. Mixed fractions: 1\\frac{1}{2} + 7. Max/min functions: \\max(1,2,3), \\min(1,2,3) + 8. Pi-related expressions: 11π/6, \\frac{11\\pi}{6} + """ + if text is None or not text.strip(): + return None + + try: + # Remove all spaces and unnecessary LaTeX commands + text = text.replace(" ", "") + text = text.replace("\\left", "").replace("\\right", "") + + # Process pi-related expressions + if "π" in text or "\\pi" in text: + pi_value = self._process_pi_expressions(text) + if pi_value is not None: + return pi_value + + # Process percentages + if "%" in text: + return float(text.replace("%", "")) / 100 + + # Process LaTeX square roots \sqrt{...} + sqrt_match = re.search(r'\\sqrt\{([^}]+)\}', text) + if sqrt_match: + inner_expr = sqrt_match.group(1) + inner_value = self._process_math_expression(inner_expr) + if inner_value is not None: + return math.sqrt(inner_value) + + # Process LaTeX fractions \frac{...}{...} + frac_match = re.search(r'\\frac\{([^}]+)\}\{([^}]+)\}', text) + if frac_match: + num = frac_match.group(1) + denom = frac_match.group(2) + + # Recursively process numerator and denominator + num_value = self._process_math_expression(num) + denom_value = self._process_math_expression(denom) + + if num_value is not None and denom_value is not None and denom_value != 0: + return num_value / denom_value + + # Process mixed fractions 1\frac{1}{2} + mixed_frac_match = re.search(r'(\d+)\\frac\{([^}]+)\}\{([^}]+)\}', text) + if mixed_frac_match: + whole = int(mixed_frac_match.group(1)) + num = mixed_frac_match.group(2) + denom = mixed_frac_match.group(3) + + # Recursively process numerator and denominator + num_value = self._process_math_expression(num) + denom_value = self._process_math_expression(denom) + + if num_value is not None and denom_value is not None and denom_value != 0: + return whole + (num_value / denom_value) + + # Process max function \max(a,b,c) + max_match = re.search(r'\\max\(([^)]+)\)', text) + if max_match: + values_str = max_match.group(1) + values = values_str.split(',') + processed_values = [] + for val in values: + processed_val = self._process_math_expression(val) + if processed_val is not None: + processed_values.append(processed_val) + if processed_values: + return max(processed_values) + + # Process min function \min(a,b,c) + min_match = re.search(r'\\min\(([^)]+)\)', text) + if min_match: + values_str = min_match.group(1) + values = values_str.split(',') + processed_values = [] + for val in values: + processed_val = self._process_math_expression(val) + if processed_val is not None: + processed_values.append(processed_val) + if processed_values: + return min(processed_values) + + # Process simple arithmetic operations + if any(op in text for op in ['+', '-', '*', '/']): + # Safe eval, only allowing basic operations + safe_dict = {"__builtins__": None} + return float(eval(text, safe_dict)) + + # Process scientific notation + if 'e' in text.lower() and re.match(r'-?\d+\.?\d*e[+-]?\d+', text.lower()): + return float(text) + + # Process regular numbers + return float(text) + except Exception as e: + # Log exception information for debugging + if self.logger: + self.logger.debug(f"Error processing math expression '{text}': {str(e)}") + return None + + def estimate(self, data: List[Dict]) -> List[Dict]: """ + Overview: + Estimate rewards for mathematical answers based on rule-based comparison. Arguments: - - data (:obj:`List[str]`): The list of data queries used for estimation, each query is a string of the \ - form "1 + 1 = ?" + - data (:obj:`List[Dict]`): The list of data queries used for estimation. + Format: [{"question": "...", "answer": "...", "response": "..."}, ...] + Each dictionary may contain: + - question: The mathematical question + - answer: The ground truth answer + - response: The model's response to evaluate + - system: Optional system prompt + - query: Optional alternative to question Returns: - - reward (:obj:`List[Dict]`): The estimated reward. + - rewards (:obj:`List[Dict]`): The estimated rewards. + Each dictionary contains: + - reward: The numerical reward value + - metadata: Additional information about the evaluation + Examples: + >>> data = [{ + >>> "question": "What is 2+2?", + >>> "answer": "4", + >>> "response": "The answer is 4." + >>> }] + >>> results = model.estimate(data) + >>> print(results[0]["reward"]) # 1.0 (correct) + >>> print(results[0]["metadata"]["reason"]) # "correct_answer" """ - # 1. parse the query to get question and predicted answer - # 2. get the ground truth answer according to the question - # 3. calculate the reward based on the predicted answer and the ground truth answer - # (format error -2, answer error -1, correct 1) - pass + rewards = [] + + for item in data: + result = { + 'reward': self.cfg.format_error_reward, + 'metadata': { + 'reason': 'format_error', + 'response_value': None, + 'target_value': None, + 'match_result': False, + 'extracted_code': None, + 'final_answer': None, + 'extracted_expressions': [] + } + } + + try: + # Extract question, answer and response from data item + item_data = self._extract_item_data(item) + if item_data is None: + rewards.append(result) + continue + + question, gt_answer, response = item_data + + # Process target answer + target_value = self._process_target_answer(gt_answer) + result['metadata']['target_value'] = target_value + + # Process response answer + response_value, final_answer = self._process_response_answer(response) + result['metadata']['response_value'] = response_value + result['metadata']['final_answer'] = final_answer + + # Extract Python code (if any) + extracted_code = self._extract_python_code(response) + result['metadata']['extracted_code'] = extracted_code + + # Extract all possible expressions (for debugging) + expressions = self._extract_all_expressions(response) + result['metadata']['extracted_expressions'] = expressions + + # Determine reward based on answer comparison + result = self._determine_reward(result, target_value, response_value) + + except Exception as e: + result['metadata']['reason'] = f'error: {str(e)}' + if self.logger: + self.logger.error(f"Error evaluating data: {str(e)}") + + rewards.append(result) + + return rewards + + def _extract_item_data(self, item) -> Optional[Tuple[str, str, str]]: + """Extract question, answer and response from data item""" + if isinstance(item, dict): + question = item.get('question', '') + gt_answer = item.get('answer', '') + response = item.get('response', '') + system = item.get('system', '') + query = item.get('query', '') + elif isinstance(item, str): + # If input is a string, try to parse as JSON + try: + item_dict = json.loads(item) + question = item_dict.get('question', '') + gt_answer = item_dict.get('answer', '') + response = item_dict.get('response', '') + system = item_dict.get('system', '') + query = item_dict.get('query', '') + except: + # If parsing fails, assume the entire string is the response + question = '' + gt_answer = '' + response = item + system = '' + query = '' + else: + # Unsupported input type + return None + + # If no question but query exists, use query as question + if not question and query: + question = query + + return question, gt_answer, response + + def _determine_reward(self, result: Dict, target_value: Optional[float], response_value: Optional[float]) -> Dict: + """Determine reward based on answer comparison""" + if target_value is None: + result['reward'] = self.cfg.format_error_reward + result['metadata']['reason'] = 'invalid_target_format' + elif response_value is None: + result['reward'] = self.cfg.format_error_reward + result['metadata']['reason'] = 'invalid_response_format' + else: + # Compare answers + is_match = self._check_answer_match(response_value, target_value) + result['metadata']['match_result'] = is_match + + if is_match: + result['reward'] = self.cfg.correct_reward + result['metadata']['reason'] = 'correct_answer' + else: + result['reward'] = self.cfg.answer_error_reward + result['metadata']['reason'] = 'wrong_answer' + + return result + + def _extract_all_expressions(self, text: str) -> List[str]: + """Extract all possible mathematical expressions from text, sorted by priority""" + if text is None or not text.strip(): + return [] + + expressions = [] + + # Extract expressions from LaTeX math environments + self._extract_latex_environments(text, expressions) + + # Extract boxed content (highest priority) + self._extract_boxed_content(text, expressions) + + # Extract expressions after equals sign + self._extract_equals_expressions(text, expressions) + + # Extract expressions from answer phrases + self._extract_answer_phrases(text, expressions) + + # Extract LaTeX expressions + self._extract_latex_expressions(text, expressions) + + # Extract pi-related expressions + self._extract_pi_expressions_for_list(text, expressions) + + # Extract all numbers (lowest priority) + self._extract_numbers(text, expressions) + + # Remove duplicates while preserving order + unique_expressions = [] + for expr in expressions: + if expr not in unique_expressions: + unique_expressions.append(expr) + + return unique_expressions + + def _extract_latex_environments(self, text: str, expressions: List[str]) -> None: + """Extract expressions from LaTeX math environments""" + # Match \(...\) or $...$ format LaTeX expressions + latex_envs = re.findall(r'\\\\?\((.+?)\\\\?\)', text) + re.findall(r'\$(.+?)\$', text) + for latex_env in latex_envs: + expressions.append(latex_env.strip()) + + def _extract_boxed_content(self, text: str, expressions: List[str]) -> None: + """Extract boxed content""" + boxed_matches = re.findall(r'\\boxed\{([^}]+)\}', text) + for match in boxed_matches: + expressions.append(match.strip()) + + def _extract_equals_expressions(self, text: str, expressions: List[str]) -> None: + """Extract expressions after equals sign""" + equals_matches = re.findall(r'=\s*([^\.=]+?)(?:\.|$|=)', text) + for match in equals_matches: + expressions.append(match.strip()) + + def _extract_answer_phrases(self, text: str, expressions: List[str]) -> None: + """Extract expressions from answer phrases""" + # Extract "the answer is X" format + answer_match = re.search(r'(?:the\s+)?answer\s+is\s+([^\.]+)', text, re.IGNORECASE) + if answer_match: + expressions.append(answer_match.group(1).strip()) + + # Extract "therefore, X is the answer" format + therefore_match = re.search(r'therefore,?\s+([^\.]+)\s+is\s+the\s+answer', text, re.IGNORECASE) + if therefore_match: + expressions.append(therefore_match.group(1).strip()) + + def _extract_latex_expressions(self, text: str, expressions: List[str]) -> None: + """Extract LaTeX expressions""" + # Extract LaTeX fraction expressions + frac_matches = re.findall(r'\\frac\{([^}]+)\}\{([^}]+)\}', text) + for num, denom in frac_matches: + expressions.append(f"\\frac{{{num}}}{{{denom}}}") + + # Extract LaTeX square root expressions + sqrt_matches = re.findall(r'\\sqrt\{([^}]+)\}', text) + for inner in sqrt_matches: + expressions.append(f"\\sqrt{{{inner}}}") + + # Extract all LaTeX expressions + latex_expressions = re.findall(r'\\[a-zA-Z]+(?:\{[^}]*\})+', text) + for expr in latex_expressions: + if expr not in expressions: + expressions.append(expr) + + def _extract_pi_expressions_for_list(self, text: str, expressions: List[str]) -> None: + """Extract pi-related expressions for the expressions list""" + # Replace \pi with π for unified processing + text_with_pi = text.replace("\\pi", "π") + + # Extract expressions like 11π/6 + pi_div_matches = re.findall(r'(\d+)π/(\d+)', text_with_pi) + for num, denom in pi_div_matches: + expressions.append(f"{num}π/{denom}") + + # Extract expressions like π/2 + pi_simple_div_matches = re.findall(r'π/(\d+)', text_with_pi) + for denom in pi_simple_div_matches: + expressions.append(f"π/{denom}") + + # Extract expressions like 2π + pi_mult_matches = re.findall(r'(\d+)π', text_with_pi) + for num in pi_mult_matches: + expressions.append(f"{num}π") + + # Extract standalone π + if "π" in text_with_pi: + expressions.append("π") + + def _extract_numbers(self, text: str, expressions: List[str]) -> None: + """Extract all numbers""" + numbers = re.findall(r'-?\d*\.?\d+', text) + expressions.extend(numbers) # rule-based reward model does not need training, thus the following methods are empty def train(self): + """Training method (not needed for rule-based reward model)""" pass def collect_data(self, data: list) -> None: + """Data collection method (not needed for rule-based reward model)""" pass def clear_data(self) -> None: + """Data clearing method (not needed for rule-based reward model)""" pass + def _extract_python_code(self, text: str) -> Optional[str]: + """Extract Python code blocks from text""" + if text is None or not text.strip(): + return None -def strip_sequence(text: str, pad_token: str, eos_token: str) -> str: - """ - Overview: - Remove leading and trailing sequences of padding/eos tokens from a text. - - .. note:: - This function uses regular expressions to strip all consecutive occurrences - of the specified padding and end-of-sequence tokens from both the beginning - and end of the input text. Tokens in the middle of the text are preserved. + # Match code between ```python and ``` + code_blocks = re.findall(r'```python\s*(.*?)\s*```', text, re.DOTALL) + if code_blocks: + return code_blocks[-1].strip() # Return the last code block - Arguments: - - text (str): The input text to be processed. - - pad_token (str): The padding token to be stripped (e.g., ""). - - eos_token (str): The end-of-sequence token to be stripped (e.g., ""). + # Match code between ``` and ``` (without specified language) + code_blocks = re.findall(r'```\s*(.*?)\s*```', text, re.DOTALL) + if code_blocks: + return code_blocks[-1].strip() - Returns: - - cleaned_text (str): The cleaned text with leading/trailing padding/eos tokens removed. + return None - Examples: - >>> strip_sequence("Hello", "", "") - 'Hello' - >>> strip_sequence("TestMiddleKeep", "", "") - 'TestMiddleKeep' - - >>> strip_sequence("Full removal", "", "") - 'Full removal' - - >>> strip_sequence("No tokens here", "", "") - 'No tokens here' - - >>> strip_sequence("", "", "") - '' +def strip_sequence(text: str, pad_token: str, eos_token: str) -> str: + """ + Remove leading and trailing sequences of padding/eos tokens from text. + + Args: + text: Input text + pad_token: Padding token + eos_token: End-of-sequence token + + Returns: + Cleaned text """ pad_token_escaped = re.escape(pad_token) eos_token_escaped = re.escape(eos_token) @@ -105,19 +685,20 @@ def strip_sequence(text: str, pad_token: str, eos_token: str) -> str: def normalize_text(text: str) -> str: """ - Overview: - This function is designed to standardize text by: - - Converting all text to lowercase - - Replacing various punctuation marks and special characters with spaces - - Removing import statements - - Normalizing whitespace by replacing multiple spaces with a single space - - Stripping leading and trailing whitespace - Arguments: - - text (str): The input text to be processed. + Standardize text: + - Convert to lowercase + - Replace punctuation and special characters with spaces + - Remove import statements + - Normalize whitespace + - Strip leading and trailing whitespace + + Args: + text: Input text + Returns: - - normalized_text (str): The normalized text. + Normalized text """ - text = re.sub("[,.:\"'\[\]\-=\+\\|!@#$%^&*();<>?/!¥…()—\{\}:”“《》?]", " ", text.lower()) - text = re.sub("import\s[a-zA-Z\.]+(\sas\s[a-zA-Z\.]+)\n", " ", text) - text = re.sub("\s+", " ", text) + # text = re.sub(r"[,.:\"\'\[\]\-=\+\\|!@#$%^&*();<>?/!¥…()—\{\}:""《》?]", " ", text.lower()) + text = re.sub(r"import\s[a-zA-Z\.]+(\sas\s[a-zA-Z\.]+)\n", " ", text) + text = re.sub(r"\s+", " ", text) return text.strip() diff --git a/ding/reward_model/tests/test_math_rule_reward_model.py b/ding/reward_model/tests/test_math_rule_reward_model.py index b79b05725e..69b7a35d4c 100644 --- a/ding/reward_model/tests/test_math_rule_reward_model.py +++ b/ding/reward_model/tests/test_math_rule_reward_model.py @@ -1,20 +1,113 @@ import pytest from easydict import EasyDict +import sys +import os -from ding.reward_model import MathRuleRewardModel +# Add project root directory to Python path +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../'))) +# Use absolute import +from ding.reward_model.math_rule_reward_model import MathRuleRewardModel -@pytest.mark.envtest -def test_math_rule_reward_model(): - reward_model = MathRuleRewardModel( + +@pytest.fixture +def reward_model(): + return MathRuleRewardModel( config=EasyDict( - dataset_name='RUC-AIBOX/STILL-3-Preview-RL-Data', tokenizer_name='unsloth/Meta-Llama-3.1-8B', + type='math_rule', + format_error_reward=-2, + answer_error_reward=-1, + correct_reward=1, ) ) - data = [ - "The school now introduces a new color, silver, for the flag design. Crestview's school colors are now purple, gold, and silver. The students are designing a flag using three solid-colored horizontal stripes. Using one, two, or all three of the school colors, how many different flags are possible if adjacent stripes may be the same color?", # noqa + +@pytest.mark.envtest +def test_math_rule_reward_model_correct_answer(reward_model): + data_correct = [ + { + "system": "Please answer this math problem...", + "query": "The school now introduces a new color, silver, for the flag design. Crestview's school colors are now purple, gold, and silver. The students are designing a flag using three solid-colored horizontal stripes. Using one, two, or all three of the school colors, how many different flags are possible if adjacent stripes may be the same color?", + "response": r"Crestview's school colors—purple, gold, and silver—can be used to design a flag with three horizontal stripes, where each stripe can be any of the three colors and adjacent stripes may be the same. Since each of the three stripes has three independent color choices, the total number of possible flag designs is 27", + "answer": r"27" + } + ] + + # Test the case with correct answer + rewards = reward_model.estimate(data_correct) + assert len(rewards) == len(data_correct) + assert rewards[0]['reward'] == reward_model.cfg.correct_reward + assert rewards[0]['metadata']['reason'] == 'correct_answer' + assert rewards[0]['metadata']['match_result'] == True + + +@pytest.mark.envtest +def test_math_rule_reward_model_wrong_answer(reward_model): + data_wrong = [ + { + "system": "Please answer this math problem...", + "query": "The school now introduces a new color, silver, for the flag design. Crestview's school colors are now purple, gold, and silver. The students are designing a flag using three solid-colored horizontal stripes. Using one, two, or all three of the school colors, how many different flags are possible if adjacent stripes may be the same color?", + "response": r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on the unit circle, meaning its coordinates correspond to \((\cos \alpha, \sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and \(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the **fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). Therefore, the smallest positive value of \(\alpha\) is \(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\).", + "answer": r"\frac{11\pi}{6}" + } ] - rewards = reward_model.estimate(data) - assert len(rewards) == len(data) + + # Test the case with wrong answer + rewards = reward_model.estimate(data_wrong) + assert len(rewards) == len(data_wrong) + assert rewards[0]['reward'] == reward_model.cfg.answer_error_reward + assert rewards[0]['metadata']['reason'] == 'wrong_answer' + assert rewards[0]['metadata']['match_result'] == False + + +@pytest.mark.envtest +def test_math_rule_reward_model_format_error(reward_model): + data_format_error = [ + { + "system": "Please answer this math problem...", + "query": "What is 2+2?", + "response": "The answer is four.", + "answer": r"4" + } + ] + + rewards_format = reward_model.estimate(data_format_error) + assert len(rewards_format) == len(data_format_error) + # This should be a format error because "four" cannot be processed as a numerical value + assert rewards_format[0]['reward'] == reward_model.cfg.format_error_reward + assert 'format' in rewards_format[0]['metadata']['reason'] + + +@pytest.mark.envtest +def test_math_rule_reward_model_special_expressions(reward_model): + data_edge_cases = [ + { + "query": "What is 1/2?", + "response": r"The answer is \frac{1}{2}.", + "answer": r"0.5" + }, { + "query": "What is 50%?", + "response": "The answer is 50%.", + "answer": r"0.5" + }, { + "query": "What is sqrt(4)?", + "response": r"The answer is \sqrt{4} = 2.", + "answer": r"2" + } + ] + + rewards_edge = reward_model.estimate(data_edge_cases) + assert len(rewards_edge) == len(data_edge_cases) + + # Check fraction processing + assert rewards_edge[0]['metadata']['match_result'] == True + assert rewards_edge[0]['reward'] == reward_model.cfg.correct_reward + + # Check percentage processing + assert rewards_edge[1]['metadata']['match_result'] == True + assert rewards_edge[1]['reward'] == reward_model.cfg.correct_reward + + # Check square root processing + assert rewards_edge[2]['metadata']['match_result'] == True + assert rewards_edge[2]['reward'] == reward_model.cfg.correct_reward From 7314bff3a5053ad5b86842e8053926155ece37f0 Mon Sep 17 00:00:00 2001 From: Berit-chengyi <2826895005@qq.com> Date: Mon, 17 Mar 2025 03:59:10 +0000 Subject: [PATCH 5/6] polish flake8 --- ding/reward_model/math_reward_model.py | 4 +- ding/reward_model/math_rule_reward_model.py | 254 +++++++++--------- .../tests/test_math_rule_reward_model.py | 99 ++++--- 3 files changed, 190 insertions(+), 167 deletions(-) diff --git a/ding/reward_model/math_reward_model.py b/ding/reward_model/math_reward_model.py index ccd4aacf32..4ff4e9de07 100644 --- a/ding/reward_model/math_reward_model.py +++ b/ding/reward_model/math_reward_model.py @@ -1,11 +1,9 @@ -from typing import Tuple, Optional, List, Dict +from typing import List, Dict from easydict import EasyDict from torch.utils.tensorboard import SummaryWriter from transformers import AutoTokenizer, AutoModel import torch import torch.nn.functional as F -import re - from ding.utils import REWARD_MODEL_REGISTRY from .base_reward_model import BaseRewardModel diff --git a/ding/reward_model/math_rule_reward_model.py b/ding/reward_model/math_rule_reward_model.py index 33aafbeb6b..0a43194196 100644 --- a/ding/reward_model/math_rule_reward_model.py +++ b/ding/reward_model/math_rule_reward_model.py @@ -1,28 +1,28 @@ -from typing import Tuple, Optional, List, Dict, Union, Any +from typing import Tuple, Optional, List, Dict from easydict import EasyDict from torch.utils.tensorboard import SummaryWriter from transformers import AutoTokenizer import re import math import json - from ding.utils import REWARD_MODEL_REGISTRY from .base_reward_model import BaseRewardModel -@REWARD_MODEL_REGISTRY.register('math_rule') +@REWARD_MODEL_REGISTRY.register("math_rule") class MathRuleRewardModel(BaseRewardModel): """ Math rule-based reward model for evaluating mathematical answers. Supports various mathematical expression formats including LaTeX, fractions, percentages, etc. """ + config = dict( # (str) The type of the reward model. - type='math_rule', + type="math_rule", # (str) The name of the dataset, usually the huggingface dataset name. - dataset_name='', + dataset_name="", # (str) The name of the tokenizer, usually the huggingface tokenizer name. - tokenizer_name='', + tokenizer_name="", # (float) The score of format error. format_error_reward=-2, # (float) The score of answer error. @@ -38,9 +38,9 @@ class MathRuleRewardModel(BaseRewardModel): def __init__( self, config: EasyDict, - device: str = 'cpu', + device: str = "cpu", logger=None, - tb_logger: 'SummaryWriter' = None + tb_logger: "SummaryWriter" = None, ) -> None: # noqa """Initialize the math rule reward model""" self.cfg = config @@ -49,10 +49,10 @@ def __init__( self.tb_logger = tb_logger # Initialize tokenizer - if hasattr(config, 'tokenizer_name') and config.tokenizer_name: + if hasattr(config, "tokenizer_name") and config.tokenizer_name: self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) - self.pad_token = self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]" - self.eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[EOS]" + self.pad_token = (self.tokenizer.pad_token if self.tokenizer.pad_token else "[PAD]") + self.eos_token = (self.tokenizer.eos_token if self.tokenizer.eos_token else "[EOS]") else: self.tokenizer = None self.pad_token = "[PAD]" @@ -62,7 +62,6 @@ def _process_target_answer(self, text: str) -> Optional[float]: """Process target answer text and convert to numerical value""" if text is None or not text.strip(): return None - # Clean and normalize text if self.tokenizer: text = strip_sequence(text, self.pad_token, self.eos_token) @@ -122,7 +121,10 @@ def _check_answer_match(self, pred: Optional[float], target: Optional[float]) -> return False try: return math.isclose( - pred, target, rel_tol=self.cfg.get('rel_tol', 1e-5), abs_tol=self.cfg.get('abs_tol', 1e-8) + pred, + target, + rel_tol=self.cfg.get("rel_tol", 1e-5), + abs_tol=self.cfg.get("abs_tol", 1e-8), ) except Exception as e: if self.logger: @@ -141,47 +143,47 @@ def _extract_final_answer(self, text: str) -> Optional[str]: 6. Last LaTeX expression like \\frac{a}{b}, \\sqrt{x}, etc. """ # Try to extract boxed content - boxed_match = re.search(r'\\boxed\{([^}]+)\}', text) + boxed_match = re.search(r"\\boxed\{([^}]+)\}", text) if boxed_match: return boxed_match.group(0) # Try to extract "the answer is X" format - answer_match = re.search(r'(?:the\s+)?answer\s+is\s+([^\.]+)', text, re.IGNORECASE) + answer_match = re.search(r"(?:the\s+)?answer\s+is\s+([^\.]+)", text, re.IGNORECASE) if answer_match: answer_text = answer_match.group(1).strip() # Check if the extracted answer contains a LaTeX expression - latex_match = re.search(r'(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})', answer_text) + latex_match = re.search(r"(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})", answer_text) if latex_match: return latex_match.group(0) return answer_text # Try to extract "therefore, X is the answer" format - therefore_match = re.search(r'therefore,?\s+([^\.]+)\s+is\s+the\s+answer', text, re.IGNORECASE) + therefore_match = re.search(r"therefore,?\s+([^\.]+)\s+is\s+the\s+answer", text, re.IGNORECASE) if therefore_match: therefore_text = therefore_match.group(1).strip() # Check if the extracted answer contains a LaTeX expression - latex_match = re.search(r'(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})', therefore_text) + latex_match = re.search(r"(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})", therefore_text) if latex_match: return latex_match.group(0) return therefore_text # Try to extract expression after equals sign - equals_matches = re.findall(r'=\s*([^\.=]+?)(?:\.|$|=)', text) + equals_matches = re.findall(r"=\s*([^\.=]+?)(?:\.|$|=)", text) if equals_matches: last_eq = equals_matches[-1].strip() # Check if there's a LaTeX expression after the equals sign - latex_match = re.search(r'(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})', last_eq) + latex_match = re.search(r"(\\frac\{[^}]+\}\{[^}]+\}|\\sqrt\{[^}]+\})", last_eq) if latex_match: return latex_match.group(0) return last_eq # Try to directly extract LaTeX fraction expression - frac_matches = re.findall(r'(\\frac\{[^}]+\}\{[^}]+\})', text) + frac_matches = re.findall(r"(\\frac\{[^}]+\}\{[^}]+\})", text) if frac_matches: return frac_matches[-1] # Try to directly extract LaTeX square root expression - sqrt_matches = re.findall(r'(\\sqrt\{[^}]+\})', text) + sqrt_matches = re.findall(r"(\\sqrt\{[^}]+\})", text) if sqrt_matches: return sqrt_matches[-1] @@ -191,7 +193,7 @@ def _extract_final_answer(self, text: str) -> Optional[str]: return pi_expr # If there's only one number, return it directly - numbers = re.findall(r'-?\d*\.?\d+', text) + numbers = re.findall(r"-?\d*\.?\d+", text) if len(numbers) == 1: return numbers[0] @@ -204,34 +206,34 @@ def _extract_final_answer(self, text: str) -> Optional[str]: def _extract_pi_expressions(self, text: str) -> Optional[str]: """Extract pi-related expressions from text""" # Try to extract expressions like \frac{a\pi}{b} - pi_frac_matches = re.findall(r'(\\frac\{[^}]*\\pi[^}]*\}\{[^}]+\})', text) + pi_frac_matches = re.findall(r"(\\frac\{[^}]*\\pi[^}]*\}\{[^}]+\})", text) if pi_frac_matches: return pi_frac_matches[-1] # Try to extract expressions like \frac{a}{b}\pi - frac_pi_matches = re.findall(r'(\\frac\{[^}]+\}\{[^}]+\}\\pi)', text) + frac_pi_matches = re.findall(r"(\\frac\{[^}]+\}\{[^}]+\}\\pi)", text) if frac_pi_matches: return frac_pi_matches[-1] # Try to extract expressions like 11π/6 text_with_pi = text.replace("\\pi", "π") - pi_div_matches = re.findall(r'(\d+π/\d+)', text_with_pi) + pi_div_matches = re.findall(r"(\d+π/\d+)", text_with_pi) if pi_div_matches: return pi_div_matches[-1] # Try to extract expressions like π/2 - pi_simple_div_matches = re.findall(r'(π/\d+)', text_with_pi) + pi_simple_div_matches = re.findall(r"(π/\d+)", text_with_pi) if pi_simple_div_matches: return pi_simple_div_matches[-1] # Try to extract expressions like 2π - pi_mult_matches = re.findall(r'(\d+π)', text_with_pi) + pi_mult_matches = re.findall(r"(\d+π)", text_with_pi) if pi_mult_matches: return pi_mult_matches[-1] # Check for standalone π if "π" in text_with_pi or "\\pi" in text: - pi_standalone = re.search(r'(^|[^a-zA-Z0-9])π($|[^a-zA-Z0-9])', text_with_pi) + pi_standalone = re.search(r"(^|[^a-zA-Z0-9])π($|[^a-zA-Z0-9])", text_with_pi) if pi_standalone: return "π" @@ -243,19 +245,19 @@ def _process_pi_expressions(self, text: str) -> Optional[float]: text = text.replace("\\pi", "π") # Process expressions like 11π/6 - pi_match = re.search(r'(\d+)π/(\d+)', text) + pi_match = re.search(r"(\d+)π/(\d+)", text) if pi_match: num, denom = map(int, pi_match.groups()) return (num * math.pi) / denom # Process expressions like π/2 - pi_div_match = re.search(r'π/(\d+)', text) + pi_div_match = re.search(r"π/(\d+)", text) if pi_div_match: denom = int(pi_div_match.group(1)) return math.pi / denom # Process expressions like 2π - pi_mult_match = re.search(r'(\d+)π', text) + pi_mult_match = re.search(r"(\d+)π", text) if pi_mult_match: num = int(pi_mult_match.group(1)) return num * math.pi @@ -297,7 +299,7 @@ def _process_math_expression(self, text: str) -> Optional[float]: return float(text.replace("%", "")) / 100 # Process LaTeX square roots \sqrt{...} - sqrt_match = re.search(r'\\sqrt\{([^}]+)\}', text) + sqrt_match = re.search(r"\\sqrt\{([^}]+)\}", text) if sqrt_match: inner_expr = sqrt_match.group(1) inner_value = self._process_math_expression(inner_expr) @@ -305,7 +307,7 @@ def _process_math_expression(self, text: str) -> Optional[float]: return math.sqrt(inner_value) # Process LaTeX fractions \frac{...}{...} - frac_match = re.search(r'\\frac\{([^}]+)\}\{([^}]+)\}', text) + frac_match = re.search(r"\\frac\{([^}]+)\}\{([^}]+)\}", text) if frac_match: num = frac_match.group(1) denom = frac_match.group(2) @@ -314,11 +316,11 @@ def _process_math_expression(self, text: str) -> Optional[float]: num_value = self._process_math_expression(num) denom_value = self._process_math_expression(denom) - if num_value is not None and denom_value is not None and denom_value != 0: + if (num_value is not None and denom_value is not None and denom_value != 0): return num_value / denom_value # Process mixed fractions 1\frac{1}{2} - mixed_frac_match = re.search(r'(\d+)\\frac\{([^}]+)\}\{([^}]+)\}', text) + mixed_frac_match = re.search(r"(\d+)\\frac\{([^}]+)\}\{([^}]+)\}", text) if mixed_frac_match: whole = int(mixed_frac_match.group(1)) num = mixed_frac_match.group(2) @@ -328,14 +330,14 @@ def _process_math_expression(self, text: str) -> Optional[float]: num_value = self._process_math_expression(num) denom_value = self._process_math_expression(denom) - if num_value is not None and denom_value is not None and denom_value != 0: + if (num_value is not None and denom_value is not None and denom_value != 0): return whole + (num_value / denom_value) # Process max function \max(a,b,c) - max_match = re.search(r'\\max\(([^)]+)\)', text) + max_match = re.search(r"\\max\(([^)]+)\)", text) if max_match: values_str = max_match.group(1) - values = values_str.split(',') + values = values_str.split(",") processed_values = [] for val in values: processed_val = self._process_math_expression(val) @@ -345,10 +347,10 @@ def _process_math_expression(self, text: str) -> Optional[float]: return max(processed_values) # Process min function \min(a,b,c) - min_match = re.search(r'\\min\(([^)]+)\)', text) + min_match = re.search(r"\\min\(([^)]+)\)", text) if min_match: values_str = min_match.group(1) - values = values_str.split(',') + values = values_str.split(",") processed_values = [] for val in values: processed_val = self._process_math_expression(val) @@ -358,13 +360,13 @@ def _process_math_expression(self, text: str) -> Optional[float]: return min(processed_values) # Process simple arithmetic operations - if any(op in text for op in ['+', '-', '*', '/']): + if any(op in text for op in ["+", "-", "*", "/"]): # Safe eval, only allowing basic operations safe_dict = {"__builtins__": None} return float(eval(text, safe_dict)) # Process scientific notation - if 'e' in text.lower() and re.match(r'-?\d+\.?\d*e[+-]?\d+', text.lower()): + if "e" in text.lower() and re.match(r"-?\d+\.?\d*e[+-]?\d+", text.lower()): return float(text) # Process regular numbers @@ -407,16 +409,16 @@ def estimate(self, data: List[Dict]) -> List[Dict]: for item in data: result = { - 'reward': self.cfg.format_error_reward, - 'metadata': { - 'reason': 'format_error', - 'response_value': None, - 'target_value': None, - 'match_result': False, - 'extracted_code': None, - 'final_answer': None, - 'extracted_expressions': [] - } + "reward": self.cfg.format_error_reward, + "metadata": { + "reason": "format_error", + "response_value": None, + "target_value": None, + "match_result": False, + "extracted_code": None, + "final_answer": None, + "extracted_expressions": [], + }, } try: @@ -430,26 +432,26 @@ def estimate(self, data: List[Dict]) -> List[Dict]: # Process target answer target_value = self._process_target_answer(gt_answer) - result['metadata']['target_value'] = target_value + result["metadata"]["target_value"] = target_value # Process response answer response_value, final_answer = self._process_response_answer(response) - result['metadata']['response_value'] = response_value - result['metadata']['final_answer'] = final_answer + result["metadata"]["response_value"] = response_value + result["metadata"]["final_answer"] = final_answer # Extract Python code (if any) extracted_code = self._extract_python_code(response) - result['metadata']['extracted_code'] = extracted_code + result["metadata"]["extracted_code"] = extracted_code # Extract all possible expressions (for debugging) expressions = self._extract_all_expressions(response) - result['metadata']['extracted_expressions'] = expressions + result["metadata"]["extracted_expressions"] = expressions # Determine reward based on answer comparison result = self._determine_reward(result, target_value, response_value) except Exception as e: - result['metadata']['reason'] = f'error: {str(e)}' + result["metadata"]["reason"] = f"error: {str(e)}" if self.logger: self.logger.error(f"Error evaluating data: {str(e)}") @@ -460,27 +462,24 @@ def estimate(self, data: List[Dict]) -> List[Dict]: def _extract_item_data(self, item) -> Optional[Tuple[str, str, str]]: """Extract question, answer and response from data item""" if isinstance(item, dict): - question = item.get('question', '') - gt_answer = item.get('answer', '') - response = item.get('response', '') - system = item.get('system', '') - query = item.get('query', '') + question = item.get("question", "") + gt_answer = item.get("answer", "") + response = item.get("response", "") + query = item.get("query", "") elif isinstance(item, str): # If input is a string, try to parse as JSON try: item_dict = json.loads(item) - question = item_dict.get('question', '') - gt_answer = item_dict.get('answer', '') - response = item_dict.get('response', '') - system = item_dict.get('system', '') - query = item_dict.get('query', '') + question = item_dict.get("question", "") + gt_answer = item_dict.get("answer", "") + response = item_dict.get("response", "") + query = item_dict.get("query", "") except: # If parsing fails, assume the entire string is the response - question = '' - gt_answer = '' + question = "" + gt_answer = "" response = item - system = '' - query = '' + query = "" else: # Unsupported input type return None @@ -491,25 +490,30 @@ def _extract_item_data(self, item) -> Optional[Tuple[str, str, str]]: return question, gt_answer, response - def _determine_reward(self, result: Dict, target_value: Optional[float], response_value: Optional[float]) -> Dict: + def _determine_reward( + self, + result: Dict, + target_value: Optional[float], + response_value: Optional[float], + ) -> Dict: """Determine reward based on answer comparison""" if target_value is None: - result['reward'] = self.cfg.format_error_reward - result['metadata']['reason'] = 'invalid_target_format' + result["reward"] = self.cfg.format_error_reward + result["metadata"]["reason"] = "invalid_target_format" elif response_value is None: - result['reward'] = self.cfg.format_error_reward - result['metadata']['reason'] = 'invalid_response_format' + result["reward"] = self.cfg.format_error_reward + result["metadata"]["reason"] = "invalid_response_format" else: # Compare answers is_match = self._check_answer_match(response_value, target_value) - result['metadata']['match_result'] = is_match + result["metadata"]["match_result"] = is_match if is_match: - result['reward'] = self.cfg.correct_reward - result['metadata']['reason'] = 'correct_answer' + result["reward"] = self.cfg.correct_reward + result["metadata"]["reason"] = "correct_answer" else: - result['reward'] = self.cfg.answer_error_reward - result['metadata']['reason'] = 'wrong_answer' + result["reward"] = self.cfg.answer_error_reward + result["metadata"]["reason"] = "wrong_answer" return result @@ -552,48 +556,48 @@ def _extract_all_expressions(self, text: str) -> List[str]: def _extract_latex_environments(self, text: str, expressions: List[str]) -> None: """Extract expressions from LaTeX math environments""" # Match \(...\) or $...$ format LaTeX expressions - latex_envs = re.findall(r'\\\\?\((.+?)\\\\?\)', text) + re.findall(r'\$(.+?)\$', text) + latex_envs = re.findall(r"\\\\?\((.+?)\\\\?\)", text) + re.findall(r"\$(.+?)\$", text) for latex_env in latex_envs: expressions.append(latex_env.strip()) def _extract_boxed_content(self, text: str, expressions: List[str]) -> None: """Extract boxed content""" - boxed_matches = re.findall(r'\\boxed\{([^}]+)\}', text) + boxed_matches = re.findall(r"\\boxed\{([^}]+)\}", text) for match in boxed_matches: expressions.append(match.strip()) def _extract_equals_expressions(self, text: str, expressions: List[str]) -> None: """Extract expressions after equals sign""" - equals_matches = re.findall(r'=\s*([^\.=]+?)(?:\.|$|=)', text) + equals_matches = re.findall(r"=\s*([^\.=]+?)(?:\.|$|=)", text) for match in equals_matches: expressions.append(match.strip()) def _extract_answer_phrases(self, text: str, expressions: List[str]) -> None: """Extract expressions from answer phrases""" # Extract "the answer is X" format - answer_match = re.search(r'(?:the\s+)?answer\s+is\s+([^\.]+)', text, re.IGNORECASE) + answer_match = re.search(r"(?:the\s+)?answer\s+is\s+([^\.]+)", text, re.IGNORECASE) if answer_match: expressions.append(answer_match.group(1).strip()) # Extract "therefore, X is the answer" format - therefore_match = re.search(r'therefore,?\s+([^\.]+)\s+is\s+the\s+answer', text, re.IGNORECASE) + therefore_match = re.search(r"therefore,?\s+([^\.]+)\s+is\s+the\s+answer", text, re.IGNORECASE) if therefore_match: expressions.append(therefore_match.group(1).strip()) def _extract_latex_expressions(self, text: str, expressions: List[str]) -> None: """Extract LaTeX expressions""" # Extract LaTeX fraction expressions - frac_matches = re.findall(r'\\frac\{([^}]+)\}\{([^}]+)\}', text) + frac_matches = re.findall(r"\\frac\{([^}]+)\}\{([^}]+)\}", text) for num, denom in frac_matches: expressions.append(f"\\frac{{{num}}}{{{denom}}}") # Extract LaTeX square root expressions - sqrt_matches = re.findall(r'\\sqrt\{([^}]+)\}', text) + sqrt_matches = re.findall(r"\\sqrt\{([^}]+)\}", text) for inner in sqrt_matches: expressions.append(f"\\sqrt{{{inner}}}") # Extract all LaTeX expressions - latex_expressions = re.findall(r'\\[a-zA-Z]+(?:\{[^}]*\})+', text) + latex_expressions = re.findall(r"\\[a-zA-Z]+(?:\{[^}]*\})+", text) for expr in latex_expressions: if expr not in expressions: expressions.append(expr) @@ -604,17 +608,17 @@ def _extract_pi_expressions_for_list(self, text: str, expressions: List[str]) -> text_with_pi = text.replace("\\pi", "π") # Extract expressions like 11π/6 - pi_div_matches = re.findall(r'(\d+)π/(\d+)', text_with_pi) + pi_div_matches = re.findall(r"(\d+)π/(\d+)", text_with_pi) for num, denom in pi_div_matches: expressions.append(f"{num}π/{denom}") # Extract expressions like π/2 - pi_simple_div_matches = re.findall(r'π/(\d+)', text_with_pi) + pi_simple_div_matches = re.findall(r"π/(\d+)", text_with_pi) for denom in pi_simple_div_matches: expressions.append(f"π/{denom}") # Extract expressions like 2π - pi_mult_matches = re.findall(r'(\d+)π', text_with_pi) + pi_mult_matches = re.findall(r"(\d+)π", text_with_pi) for num in pi_mult_matches: expressions.append(f"{num}π") @@ -624,7 +628,7 @@ def _extract_pi_expressions_for_list(self, text: str, expressions: List[str]) -> def _extract_numbers(self, text: str, expressions: List[str]) -> None: """Extract all numbers""" - numbers = re.findall(r'-?\d*\.?\d+', text) + numbers = re.findall(r"-?\d*\.?\d+", text) expressions.extend(numbers) # rule-based reward model does not need training, thus the following methods are empty @@ -644,14 +648,12 @@ def _extract_python_code(self, text: str) -> Optional[str]: """Extract Python code blocks from text""" if text is None or not text.strip(): return None - # Match code between ```python and ``` - code_blocks = re.findall(r'```python\s*(.*?)\s*```', text, re.DOTALL) + code_blocks = re.findall(r"```python\s*(.*?)\s*```", text, re.DOTALL) if code_blocks: - return code_blocks[-1].strip() # Return the last code block - + return code_blocks[-1].strip() # Match code between ``` and ``` (without specified language) - code_blocks = re.findall(r'```\s*(.*?)\s*```', text, re.DOTALL) + code_blocks = re.findall(r"```\s*(.*?)\s*```", text, re.DOTALL) if code_blocks: return code_blocks[-1].strip() @@ -660,15 +662,29 @@ def _extract_python_code(self, text: str) -> Optional[str]: def strip_sequence(text: str, pad_token: str, eos_token: str) -> str: """ - Remove leading and trailing sequences of padding/eos tokens from text. - - Args: - text: Input text - pad_token: Padding token - eos_token: End-of-sequence token - + Overview: + Remove leading and trailing sequences of padding/eos tokens from a text. + .. note:: + This function uses regular expressions to strip all consecutive occurrences + of the specified padding and end-of-sequence tokens from both the beginning + and end of the input text. Tokens in the middle of the text are preserved. + Arguments: + - text (str): The input text to be processed. + - pad_token (str): The padding token to be stripped (e.g., ""). + - eos_token (str): The end-of-sequence token to be stripped (e.g., ""). Returns: - Cleaned text + - cleaned_text (str): The cleaned text with leading/trailing padding/eos tokens removed. + Examples: + >>> strip_sequence("Hello", "", "") + 'Hello' + >>> strip_sequence("TestMiddleKeep", "", "") + 'TestMiddleKeep' + >>> strip_sequence("Full removal", "", "") + 'Full removal' + >>> strip_sequence("No tokens here", "", "") + 'No tokens here' + >>> strip_sequence("", "", "") + '' """ pad_token_escaped = re.escape(pad_token) eos_token_escaped = re.escape(eos_token) @@ -685,20 +701,18 @@ def strip_sequence(text: str, pad_token: str, eos_token: str) -> str: def normalize_text(text: str) -> str: """ - Standardize text: - - Convert to lowercase - - Replace punctuation and special characters with spaces - - Remove import statements - - Normalize whitespace - - Strip leading and trailing whitespace - - Args: - text: Input text - + Overview: + This function is designed to standardize text by: + - Converting all text to lowercase + - Replacing various punctuation marks and special characters with spaces + - Removing import statements + - Normalizing whitespace by replacing multiple spaces with a single space + - Stripping leading and trailing whitespace + Arguments: + - text (str): The input text to be processed. Returns: - Normalized text + - normalized_text (str): The normalized text. """ - # text = re.sub(r"[,.:\"\'\[\]\-=\+\\|!@#$%^&*();<>?/!¥…()—\{\}:""《》?]", " ", text.lower()) text = re.sub(r"import\s[a-zA-Z\.]+(\sas\s[a-zA-Z\.]+)\n", " ", text) text = re.sub(r"\s+", " ", text) return text.strip() diff --git a/ding/reward_model/tests/test_math_rule_reward_model.py b/ding/reward_model/tests/test_math_rule_reward_model.py index 69b7a35d4c..85e50c8023 100644 --- a/ding/reward_model/tests/test_math_rule_reward_model.py +++ b/ding/reward_model/tests/test_math_rule_reward_model.py @@ -1,15 +1,9 @@ +import os +import sys import pytest from easydict import EasyDict -import sys -import os - -# Add project root directory to Python path -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../'))) - -# Use absolute import from ding.reward_model.math_rule_reward_model import MathRuleRewardModel - @pytest.fixture def reward_model(): return MathRuleRewardModel( @@ -25,53 +19,72 @@ def reward_model(): @pytest.mark.envtest def test_math_rule_reward_model_correct_answer(reward_model): - data_correct = [ - { - "system": "Please answer this math problem...", - "query": "The school now introduces a new color, silver, for the flag design. Crestview's school colors are now purple, gold, and silver. The students are designing a flag using three solid-colored horizontal stripes. Using one, two, or all three of the school colors, how many different flags are possible if adjacent stripes may be the same color?", - "response": r"Crestview's school colors—purple, gold, and silver—can be used to design a flag with three horizontal stripes, where each stripe can be any of the three colors and adjacent stripes may be the same. Since each of the three stripes has three independent color choices, the total number of possible flag designs is 27", - "answer": r"27" - } - ] + data_correct = [{ + "system": "Please answer this math problem...", + "query": ( + "The school now introduces a new color, silver, for the flag design. " + "Crestview's school colors are now purple, gold, and silver. " + "The students are designing a flag using three solid-colored horizontal stripes. " + "Using one, two, or all three of the school colors, how many different flags " + "are possible if adjacent stripes may be the same color?" + ), + "response": ( + "Crestview's school colors—purple, gold, and silver—can be used to design " + "a flag with three horizontal stripes, where each stripe can be any of the " + "three colors and adjacent stripes may be the same. Since each of the three " + "stripes has three independent color choices, the total number of possible " + "flag designs is 27" + ), + "answer": r"27" +}] # Test the case with correct answer rewards = reward_model.estimate(data_correct) assert len(rewards) == len(data_correct) assert rewards[0]['reward'] == reward_model.cfg.correct_reward assert rewards[0]['metadata']['reason'] == 'correct_answer' - assert rewards[0]['metadata']['match_result'] == True + assert rewards[0]['metadata']['match_result'] @pytest.mark.envtest def test_math_rule_reward_model_wrong_answer(reward_model): - data_wrong = [ - { - "system": "Please answer this math problem...", - "query": "The school now introduces a new color, silver, for the flag design. Crestview's school colors are now purple, gold, and silver. The students are designing a flag using three solid-colored horizontal stripes. Using one, two, or all three of the school colors, how many different flags are possible if adjacent stripes may be the same color?", - "response": r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on the unit circle, meaning its coordinates correspond to \((\cos \alpha, \sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and \(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the **fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). Therefore, the smallest positive value of \(\alpha\) is \(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\).", - "answer": r"\frac{11\pi}{6}" - } - ] + data_wrong = [{ + "system": "Please answer this math problem...", + "query": ( + "The school now introduces a new color, silver, for the flag design. " + "Crestview's school colors are now purple, gold, and silver. " + "The students are designing a flag using three solid-colored horizontal stripes. " + "Using one, two, or all three of the school colors, how many different flags " + "are possible if adjacent stripes may be the same color?" + ), + "response": ( + r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on " + r"the unit circle, meaning its coordinates correspond to \((\cos \alpha, " + r"\sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and " + r"\(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the " + r"**fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). " + r"Therefore, the smallest positive value of \(\alpha\) is " + r"\(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\)." + ), + "answer": r"\frac{11\pi}{6}" +}] # Test the case with wrong answer rewards = reward_model.estimate(data_wrong) assert len(rewards) == len(data_wrong) assert rewards[0]['reward'] == reward_model.cfg.answer_error_reward assert rewards[0]['metadata']['reason'] == 'wrong_answer' - assert rewards[0]['metadata']['match_result'] == False + assert rewards[0]['metadata']['match_result'] is False @pytest.mark.envtest def test_math_rule_reward_model_format_error(reward_model): - data_format_error = [ - { - "system": "Please answer this math problem...", - "query": "What is 2+2?", - "response": "The answer is four.", - "answer": r"4" - } - ] - + data_format_error = [{ + "system": "Please answer this math problem...", + "query": "What is 2+2?", + "response": "The answer is four.", + "answer": r"4" + }] rewards_format = reward_model.estimate(data_format_error) assert len(rewards_format) == len(data_format_error) # This should be a format error because "four" cannot be processed as a numerical value @@ -86,28 +99,26 @@ def test_math_rule_reward_model_special_expressions(reward_model): "query": "What is 1/2?", "response": r"The answer is \frac{1}{2}.", "answer": r"0.5" - }, { + }, + { "query": "What is 50%?", "response": "The answer is 50%.", "answer": r"0.5" - }, { + }, + { "query": "What is sqrt(4)?", "response": r"The answer is \sqrt{4} = 2.", "answer": r"2" } ] - rewards_edge = reward_model.estimate(data_edge_cases) assert len(rewards_edge) == len(data_edge_cases) - # Check fraction processing - assert rewards_edge[0]['metadata']['match_result'] == True + assert rewards_edge[0]['metadata']['match_result'] assert rewards_edge[0]['reward'] == reward_model.cfg.correct_reward - # Check percentage processing - assert rewards_edge[1]['metadata']['match_result'] == True + assert rewards_edge[1]['metadata']['match_result'] assert rewards_edge[1]['reward'] == reward_model.cfg.correct_reward - # Check square root processing - assert rewards_edge[2]['metadata']['match_result'] == True + assert rewards_edge[2]['metadata']['match_result'] assert rewards_edge[2]['reward'] == reward_model.cfg.correct_reward From 1cd74e5460963ed7a25c111fcdd6fbedd68eb983 Mon Sep 17 00:00:00 2001 From: Berit-chengyi <2826895005@qq.com> Date: Wed, 2 Apr 2025 09:48:18 +0000 Subject: [PATCH 6/6] (dcy)polish flake8 add multimodal_rewardmodel and test --- ding/reward_model/__init__.py | 3 +- ding/reward_model/math_reward_model.py | 31 ++-- ding/reward_model/math_rule_reward_model.py | 10 +- ding/reward_model/multi_modal_reward_model.py | 163 ++++++++++++++++++ .../tests/test_math_reward_model.py | 4 +- .../tests/test_math_rule_reward_model.py | 102 +++++------ .../tests/test_multi_modal_reward_model.py | 121 +++++++++++++ 7 files changed, 359 insertions(+), 75 deletions(-) create mode 100644 ding/reward_model/multi_modal_reward_model.py create mode 100644 ding/reward_model/tests/test_multi_modal_reward_model.py diff --git a/ding/reward_model/__init__.py b/ding/reward_model/__init__.py index 9b06a7b109..5202588d9a 100644 --- a/ding/reward_model/__init__.py +++ b/ding/reward_model/__init__.py @@ -13,6 +13,7 @@ from .guided_cost_reward_model import GuidedCostRewardModel from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel from .icm_reward_model import ICMRewardModel -# LLM/VLM reward model and verifier +# LLM/VLM reward models and verifiers from .math_reward_model import MathRewardModel from .math_rule_reward_model import MathRuleRewardModel +from .multi_modal_reward_model import MultiModalRewardModel diff --git a/ding/reward_model/math_reward_model.py b/ding/reward_model/math_reward_model.py index 4ff4e9de07..90a5e3a0d3 100644 --- a/ding/reward_model/math_reward_model.py +++ b/ding/reward_model/math_reward_model.py @@ -49,21 +49,21 @@ def estimate(self, data: List[Dict]) -> List[Dict]: Estimate rewards for mathematical reasoning steps using Qwen2.5-Math-PRM-7B model. Arguments: - data (:obj:`List[Dict]`): List of dictionaries containing: - - system (:obj:`str`): System prompt for the model - - query (:obj:`str`): The mathematical query to be evaluated - - response (:obj:`List[str]`): List of reasoning steps + - system (:obj:`str`): System prompt for the model. + - query (:obj:`str`): The mathematical query to be evaluated. + - response (:obj:`List[str]`): List of reasoning steps. Returns: - reward (:obj:`List[Dict]`): List of dictionaries containing: - - reward (:obj:`float`): Final reward (last step reward) + - reward (:obj:`float`): Final reward (last step reward). - metadata (:obj:`Dict`): Additional information including: - - query (:obj:`str`): Original query - - step_rewards (:obj:`List[float]`): Rewards for each reasoning step - - num_steps (:obj:`int`): Number of reasoning steps + - query (:obj:`str`): Original query. + - step_rewards (:obj:`List[float]`): Rewards for each reasoning step. + - num_steps (:obj:`int`): Number of reasoning steps. Shapes: - - input_ids (:obj:`torch.LongTensor`): :math:`(B, L)`, where B is batch size and L is sequence length - - outputs (:obj:`torch.FloatTensor`): :math:`(B, L, H)`, where H is hidden size - - token_masks (:obj:`torch.BoolTensor`): :math:`(B, L)` - - step_rewards (:obj:`List[List[float]]`): List of length B, each containing S rewards where S is num steps + - input_ids (:obj:`torch.LongTensor`): :math:`(B, L)`, where B is batch size and L is sequence length. + - outputs (:obj:`torch.FloatTensor`): :math:`(B, L, H)`, where H is hidden size. + - token_masks (:obj:`torch.BoolTensor`): :math:`(B, L)`. + - step_rewards (:obj:`List[List[float]]`): List of length B, each containing S rewards where S is num steps. Examples: >>> data = [{ >>> "system": "Please reason step by step...", @@ -74,7 +74,6 @@ def estimate(self, data: List[Dict]) -> List[Dict]: >>> print(results[0]["reward"]) # 1.0 >>> print(results[0]["metadata"]["step_rewards"]) # [0.8, 0.9, 1.0] """ - # 批量处理所有样本 all_messages = [] for item in data: messages = [ @@ -93,7 +92,6 @@ def estimate(self, data: List[Dict]) -> List[Dict]: ] all_messages.append(messages) - # 批量转换为模型输入格式 conversation_strs = [ self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) for messages in all_messages @@ -104,24 +102,21 @@ def estimate(self, data: List[Dict]) -> List[Dict]: conversation_strs, return_tensors="pt", padding=True, truncation=True )["input_ids"].to(self.model.device) - # 批量获取模型输出 with torch.no_grad(): outputs = self.model(input_ids=input_ids) - # 计算每个样本的步骤奖励 step_sep_id = self.tokenizer.encode("")[0] token_masks = (input_ids == step_sep_id) batch_rewards = self.make_step_rewards(outputs[0], token_masks) - # 构建详细的结果字典 results = [] for item, step_rewards in zip(data, batch_rewards): results.append( { - "reward": step_rewards[-1] if step_rewards else 0.0, # 最后一步的奖励作为总体奖励 + "reward": step_rewards[-1] if step_rewards else 0.0, "metadata": { "query": item['query'], - "step_rewards": step_rewards, # 每个步骤的奖励 + "step_rewards": step_rewards, "num_steps": len(item['response']), } } diff --git a/ding/reward_model/math_rule_reward_model.py b/ding/reward_model/math_rule_reward_model.py index 0a43194196..ffaf720c54 100644 --- a/ding/reward_model/math_rule_reward_model.py +++ b/ding/reward_model/math_rule_reward_model.py @@ -135,11 +135,11 @@ def _extract_final_answer(self, text: str) -> Optional[str]: """ Extract the final answer from text. Supports various formats: - 1. "The answer is X" - 2. "Therefore, X is the answer" - 3. "X" (if only one number) - 4. "\\boxed{X}" - 5. "= X" (expression after equals sign) + 1. "The answer is X". + 2. "Therefore, X is the answer". + 3. "X" (if only one number). + 4. "\\boxed{X}". + 5. "= X" (expression after equals sign). 6. Last LaTeX expression like \\frac{a}{b}, \\sqrt{x}, etc. """ # Try to extract boxed content diff --git a/ding/reward_model/multi_modal_reward_model.py b/ding/reward_model/multi_modal_reward_model.py new file mode 100644 index 0000000000..9a12eb1cd4 --- /dev/null +++ b/ding/reward_model/multi_modal_reward_model.py @@ -0,0 +1,163 @@ +from typing import List, Dict +from easydict import EasyDict +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +from ding.utils import REWARD_MODEL_REGISTRY +from .base_reward_model import BaseRewardModel + + +@REWARD_MODEL_REGISTRY.register('multi_modal') +class MultiModalRewardModel(BaseRewardModel): + config = dict( + type='multi_modal', + model_name='internlm/internlm-xcomposer2d5-7b-reward', + hd_num=9, # Number of high-definition patches for image processing + ) + + def __init__(self, config: EasyDict, device: str, logger, tb_logger: 'SummaryWriter') -> None: + self.cfg = config + self.device = device + self.logger = logger + self.tb_logger = tb_logger + + self.tokenizer = AutoTokenizer.from_pretrained( + self.cfg.model_name, trust_remote_code=True, local_files_only=True + ) + self.model = AutoModelForCausalLM.from_pretrained( + self.cfg.model_name, torch_dtype=torch.float16, trust_remote_code=True + ) + + self.model.tokenizer = self.tokenizer + self.model.cuda().eval() + + def estimate(self, data: List[Dict], image: List[str], output_mode: str = 'score') -> List[Dict]: + """ + Estimate rewards for multi-modal inputs using internlm-xcomposer model. + + Arguments: + data (List[Dict]): List of chat dictionaries, each containing: + - chat (List[Dict]): List of messages, each message is a dict with: + - role (str): Either "user" or "assistant" + - content (str): The message content + image (List[str]): List of image paths. If fewer images than chats, last image will be reused + output_mode (str, optional): Evaluation mode. Defaults to 'score'. + - 'score': Return reward scores for each chat + - 'rank': Return ranking indices (0 is best) for all chats + - 'compare': Compare first two chats (returns 1.0 for better, 0.0 for worse) + + Returns: + List[Dict]: Results depending on output_mode: + - For 'score' mode: + [{ + 'reward': float, # Reward score + 'metadata': { + 'mode': 'score', + 'chat_idx': int, # Index of the chat + 'image_path': str # Path of the image used + } + }, ...] + - For 'rank' mode: + [{ + 'rank': int, # Ranking position (0 is best) + 'metadata': { + 'mode': 'rank', + 'chat_idx': int, + 'image_path': str + } + }, ...] + - For 'compare' mode: + [{ + 'reward': float, # 1.0 for better, 0.0 for worse + 'metadata': { + 'mode': 'compare', + 'chat_idx': int, + 'image_path': str, + 'compared_with': int # Index of the compared chat + } + }, ...] + """ + # Get chat data + chats = [item['chat'] for item in data] + + with torch.autocast(device_type='cuda', dtype=torch.float16): + if output_mode == 'score': + # Ensure each chat has a corresponding image, use the last image if not enough + if len(image) < len(chats): + image = image + [image[-1]] * (len(chats) - len(image)) + + # Get scores for each chat + scores = [] + for chat, img in zip(chats, image): + score = self.model.get_score(chat, [img], hd_num=self.cfg.hd_num) + scores.append(score) + + return [ + { + 'reward': float(score), + 'metadata': { + 'mode': 'score', + 'chat_idx': idx, + 'image_path': img + } + } for idx, (score, img) in enumerate(zip(scores, image)) + ] + + elif output_mode == 'rank': + # Use the first image for ranking + img = image[0] + ranks = self.model.rank(chats, [[img]] * len(chats), hd_num=self.cfg.hd_num) + + return [ + { + 'rank': int(rank), + 'metadata': { + 'mode': 'rank', + 'chat_idx': idx, + 'image_path': img + } + } for idx, rank in enumerate(ranks) + ] + + elif output_mode == 'compare': + if len(data) < 2: + raise ValueError("Compare mode requires at least 2 samples") + + # Use the first image for comparison + img = image[0] + is_better = self.model.compare(chats[0], [img], chats[1], [img], hd_num=self.cfg.hd_num) + + return [ + { + 'reward': 1.0 if is_better else 0.0, + 'metadata': { + 'mode': 'compare', + 'chat_idx': 0, + 'image_path': img, + 'compared_with': 1 + } + }, { + 'reward': 0.0 if is_better else 1.0, + 'metadata': { + 'mode': 'compare', + 'chat_idx': 1, + 'image_path': img, + 'compared_with': 0 + } + } + ] + else: + raise ValueError(f"Invalid output mode: {output_mode}") + + def train(self): + """Training is not implemented for this reward model""" + self.logger.warning("Training is not implemented for this reward model") + pass + + def collect_data(self, data: list) -> None: + """Data collection is not needed for this reward model""" + pass + + def clear_data(self) -> None: + """Data clearing is not needed for this reward model""" + pass diff --git a/ding/reward_model/tests/test_math_reward_model.py b/ding/reward_model/tests/test_math_reward_model.py index 31c59ac85b..70a7ba4637 100644 --- a/ding/reward_model/tests/test_math_reward_model.py +++ b/ding/reward_model/tests/test_math_reward_model.py @@ -21,7 +21,7 @@ def test_math_reward_model(): # Initialize reward model model = MathRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger) - # Test case 1: Simple math problem + # Simple math problem data_simple = [ { "system": "Please reason step by step...", @@ -30,7 +30,7 @@ def test_math_reward_model(): } ] - # Test case 2: Complex word problem + # Complex word problem data_complex = [ { "system": "Please reason step by step, and put your final answer within \\boxed{}.", diff --git a/ding/reward_model/tests/test_math_rule_reward_model.py b/ding/reward_model/tests/test_math_rule_reward_model.py index 85e50c8023..d4a7600d91 100644 --- a/ding/reward_model/tests/test_math_rule_reward_model.py +++ b/ding/reward_model/tests/test_math_rule_reward_model.py @@ -4,6 +4,7 @@ from easydict import EasyDict from ding.reward_model.math_rule_reward_model import MathRuleRewardModel + @pytest.fixture def reward_model(): return MathRuleRewardModel( @@ -19,24 +20,26 @@ def reward_model(): @pytest.mark.envtest def test_math_rule_reward_model_correct_answer(reward_model): - data_correct = [{ - "system": "Please answer this math problem...", - "query": ( - "The school now introduces a new color, silver, for the flag design. " - "Crestview's school colors are now purple, gold, and silver. " - "The students are designing a flag using three solid-colored horizontal stripes. " - "Using one, two, or all three of the school colors, how many different flags " - "are possible if adjacent stripes may be the same color?" - ), - "response": ( - "Crestview's school colors—purple, gold, and silver—can be used to design " - "a flag with three horizontal stripes, where each stripe can be any of the " - "three colors and adjacent stripes may be the same. Since each of the three " - "stripes has three independent color choices, the total number of possible " - "flag designs is 27" - ), - "answer": r"27" -}] + data_correct = [ + { + "system": "Please answer this math problem...", + "query": ( + "The school now introduces a new color, silver, for the flag design. " + "Crestview's school colors are now purple, gold, and silver. " + "The students are designing a flag using three solid-colored horizontal stripes. " + "Using one, two, or all three of the school colors, how many different flags " + "are possible if adjacent stripes may be the same color?" + ), + "response": ( + "Crestview's school colors—purple, gold, and silver—can be used to design " + "a flag with three horizontal stripes, where each stripe can be any of the " + "three colors and adjacent stripes may be the same. Since each of the three " + "stripes has three independent color choices, the total number of possible " + "flag designs is 27" + ), + "answer": r"27" + } + ] # Test the case with correct answer rewards = reward_model.estimate(data_correct) @@ -48,28 +51,29 @@ def test_math_rule_reward_model_correct_answer(reward_model): @pytest.mark.envtest def test_math_rule_reward_model_wrong_answer(reward_model): - data_wrong = [{ - "system": "Please answer this math problem...", - "query": ( - "The school now introduces a new color, silver, for the flag design. " - "Crestview's school colors are now purple, gold, and silver. " - "The students are designing a flag using three solid-colored horizontal stripes. " - "Using one, two, or all three of the school colors, how many different flags " - "are possible if adjacent stripes may be the same color?" - ), - "response": ( - r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on " - r"the unit circle, meaning its coordinates correspond to \((\cos \alpha, " - r"\sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and " - r"\(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the " - r"**fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). " - r"Therefore, the smallest positive value of \(\alpha\) is " - r"\(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\)." - ), - "answer": r"\frac{11\pi}{6}" -}] + data_wrong = [ + { + "system": "Please answer this math problem...", + "query": ( + "The school now introduces a new color, silver, for the flag design. " + "Crestview's school colors are now purple, gold, and silver. " + "The students are designing a flag using three solid-colored horizontal stripes. " + "Using one, two, or all three of the school colors, how many different flags " + "are possible if adjacent stripes may be the same color?" + ), + "response": ( + r"The given point \(\left(\frac{\sqrt{3}}{2}, -\frac{1}{2}\right)\) lies on " + r"the unit circle, meaning its coordinates correspond to \((\cos \alpha, " + r"\sin \alpha)\). Since \(\cos \alpha = \frac{\sqrt{3}}{2}\) and " + r"\(\sin \alpha = -\frac{1}{2}\), the angle \(\alpha\) is in the " + r"**fourth quadrant**, where the reference angle is \(\frac{\pi}{6}\). " + r"Therefore, the smallest positive value of \(\alpha\) is " + r"\(2\pi - \frac{\pi}{6} = \frac{17\pi}{6}\)." + ), + "answer": r"\frac{11\pi}{6}" + } + ] - # Test the case with wrong answer rewards = reward_model.estimate(data_wrong) assert len(rewards) == len(data_wrong) assert rewards[0]['reward'] == reward_model.cfg.answer_error_reward @@ -79,12 +83,14 @@ def test_math_rule_reward_model_wrong_answer(reward_model): @pytest.mark.envtest def test_math_rule_reward_model_format_error(reward_model): - data_format_error = [{ - "system": "Please answer this math problem...", - "query": "What is 2+2?", - "response": "The answer is four.", - "answer": r"4" - }] + data_format_error = [ + { + "system": "Please answer this math problem...", + "query": "What is 2+2?", + "response": "The answer is four.", + "answer": r"4" + } + ] rewards_format = reward_model.estimate(data_format_error) assert len(rewards_format) == len(data_format_error) # This should be a format error because "four" cannot be processed as a numerical value @@ -99,13 +105,11 @@ def test_math_rule_reward_model_special_expressions(reward_model): "query": "What is 1/2?", "response": r"The answer is \frac{1}{2}.", "answer": r"0.5" - }, - { + }, { "query": "What is 50%?", "response": "The answer is 50%.", "answer": r"0.5" - }, - { + }, { "query": "What is sqrt(4)?", "response": r"The answer is \sqrt{4} = 2.", "answer": r"2" diff --git a/ding/reward_model/tests/test_multi_modal_reward_model.py b/ding/reward_model/tests/test_multi_modal_reward_model.py new file mode 100644 index 0000000000..7e87aa6f8a --- /dev/null +++ b/ding/reward_model/tests/test_multi_modal_reward_model.py @@ -0,0 +1,121 @@ +import pytest +from easydict import EasyDict +import torch +from ding.reward_model import MultiModalRewardModel +from unittest.mock import MagicMock +import os + + +@pytest.fixture +def reward_model(): + # Create configuration + cfg = EasyDict(dict( + type='multi_modal', + model_name='internlm/internlm-xcomposer2d5-7b-reward', + hd_num=9, + )) + + # Create mock logger and tb_logger + logger = MagicMock() + tb_logger = MagicMock() + + # Initialize reward model + model = MultiModalRewardModel(cfg, "cuda" if torch.cuda.is_available() else "cpu", logger, tb_logger) + return model + + +@pytest.fixture +def test_data(): + # Shared test data + chats = [ + [ # chat_1 + {"role": "user", "content": 'I want to buy a car from the input image, ' + 'analyze the advantages and weaknesses.'}, + {"role": "assistant", "content": "The car in the image is a Mercedes-Benz G-Class..."} + ], + [ # chat_2 + {"role": "user", "content": 'I want to buy a car from the input image, ' + 'analyze the advantages and weaknesses.'}, + {"role": "assistant", "content": "Based on the image, it appears to be a Ferrari F8 Tributo..."} + ] + ] + + images = ['./examples/cars1.jpg'] + + return {'chats': chats, 'images': images, 'hd_num': 9} + + +@pytest.mark.envtest +def test_single_score(reward_model, test_data): + """Test single chat scoring""" + data = [{'chat': test_data['chats'][0]}] + + results = reward_model.estimate(data, test_data['images'], output_mode='score') + print(f"Single score results: {results}") + + assert len(results) == 1 + assert 'reward' in results[0] + assert isinstance(results[0]['reward'], float) + assert results[0]['metadata']['mode'] == 'score' + assert results[0]['metadata']['chat_idx'] == 0 + + +@pytest.mark.envtest +def test_multiple_scores(reward_model, test_data): + """Test multiple chats scoring""" + data = [{'chat': test_data['chats'][0]}, {'chat': test_data['chats'][1]}] + + results = reward_model.estimate(data, test_data['images'], output_mode='score') + print(f"Multiple scores results: {results}") + + assert len(results) == 2 + assert all('reward' in r for r in results) + assert all(isinstance(r['reward'], float) for r in results) + assert all(r['metadata']['mode'] == 'score' for r in results) + + +@pytest.mark.envtest +def test_rank(reward_model, test_data): + """Test ranking functionality""" + data = [{'chat': test_data['chats'][0]}, {'chat': test_data['chats'][1]}] + + results = reward_model.estimate(data, test_data['images'], output_mode='rank') + print(f"Rank results: {results}") + + assert len(results) == 2 + assert all('rank' in r for r in results) + assert set(r['rank'] for r in results) == {0, 1} + + +@pytest.mark.envtest +def test_compare(reward_model, test_data): + """Test comparison functionality""" + data = [{'chat': test_data['chats'][0]}, {'chat': test_data['chats'][1]}] + + results = reward_model.estimate(data, test_data['images'], output_mode='compare') + print(f"Compare results: {results}") + + assert len(results) == 2 + assert sum(r['reward'] for r in results) == 1.0 + assert all(r['metadata']['mode'] == 'compare' for r in results) + + +@pytest.mark.envtest +def test_default_parameters(reward_model, test_data): + """Test default parameters""" + data = [{'chat': test_data['chats'][0]}] + + # Test without specifying optional parameters + results = reward_model.estimate(data, test_data['images']) + + assert len(results) == 1 + assert 'reward' in results[0] + assert results[0]['metadata']['mode'] == 'score' + + +@pytest.mark.envtest +def test_error_handling(reward_model, test_data): + """Test error handling""" + with pytest.raises(Exception): + # Test invalid input format + reward_model.model.get_score(None, test_data['image'], hd_num=test_data['hd_num'])