From 04e0573802163bd58879a58922e8fe49d8fce74c Mon Sep 17 00:00:00 2001 From: Jonathan Bastnagel Date: Thu, 14 Aug 2025 11:58:11 -0500 Subject: [PATCH 1/4] Replace Python isidentifier() with MariaDB-compliant identifier validation --- src/server.py | 106 ++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 85 insertions(+), 21 deletions(-) diff --git a/src/server.py b/src/server.py index 54c0ef9..dcb920d 100644 --- a/src/server.py +++ b/src/server.py @@ -160,10 +160,74 @@ async def _execute_query(self, sql: str, params: Optional[tuple] = None, databas conn_state = f"Connection: {'acquired' if conn else 'not acquired'}" logger.error(f"Unexpected error during query execution ({conn_state}): {e}", exc_info=True) raise RuntimeError(f"An unexpected error occurred: {e}") from e + + def _has_properly_escaped_backticks(self, identifier: str) -> bool: + """ + Check that any backticks inside quoted identifier are properly escaped (doubled). + Returns True if all backticks are properly escaped, False otherwise. + + - identifier (str): quoted identifier. + """ + i = 0 + while i < len(identifier): + if identifier[i] == '`': + if i + 1 < len(identifier) and identifier[i + 1] == '`': + i += 2 # Skip the escaped pair + else: + return False # Found unescaped backtick + else: + i += 1 + return True + + def _is_valid_identifier(self, identifier: str) -> bool: + """ + Validates MariaDB identifier (database, table, column names, etc.). + Handles both quoted and unquoted identifiers. + + Parameters: + - identifier (str): quoted or unquoted identifier. + """ + if not identifier: + return False + + import re + + if identifier.startswith('`') and identifier.endswith('`'): + # Quoted identifier rules + if len(identifier) <= 2: + return False + + actual_name = identifier[1:-1] + + # Check that any backticks inside are properly escaped (doubled) + if not self._has_properly_escaped_backticks(actual_name): + return False + + # Handle escaped backticks + escaped_name = actual_name.replace('``', '`') + + if len(escaped_name) > 64: + return False + + if escaped_name.endswith(' '): + return False + + # Allow full Unicode BMP except U+0000 + return bool(re.match(r'^[\u0001-\uFFFF]*$', escaped_name)) + else: + # Unquoted identifier rules + if len(identifier) > 64: + return False + if identifier.endswith(' '): + return False + if identifier.isdigit(): + return False + # MariaDB unquoted identifiers: ASCII [0-9,a-z,A-Z$_] + Extended U+0080..U+FFFF + return bool(re.match(r'^[0-9a-zA-Z$_\u0080-\uFFFF]+$', identifier)) async def _database_exists(self, database_name: str) -> bool: """Checks if a database exists.""" - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.warning(f"_database_exists called with invalid database_name: {database_name}") return False @@ -177,8 +241,8 @@ async def _database_exists(self, database_name: str) -> bool: async def _table_exists(self, database_name: str, table_name: str) -> bool: """Checks if a table exists in the given database.""" - if not database_name or not database_name.isidentifier() or \ - not table_name or not table_name.isidentifier(): + if not self._is_valid_identifier(database_name) or \ + not self._is_valid_identifier(table_name): logger.warning(f"_table_exists called with invalid names: db='{database_name}', table='{table_name}'") return False @@ -205,8 +269,8 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool: """ logger.debug(f"Checking if '{database_name}.{table_name}' is a vector store.") - if not database_name or not database_name.isidentifier() or \ - not table_name or not table_name.isidentifier(): + if not self._is_valid_identifier(database_name) or \ + not self._is_valid_identifier(table_name) logger.warning(f"_is_vector_store called with invalid names: db='{database_name}', table='{table_name}'") return False @@ -254,7 +318,7 @@ async def list_databases(self) -> List[str]: async def list_tables(self, database_name: str) -> List[str]: """Lists all tables within the specified database.""" logger.info(f"TOOL START: list_tables called. database_name={database_name}") - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.warning(f"TOOL WARNING: list_tables called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") sql = "SHOW TABLES" @@ -273,10 +337,10 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st for a specific table in a database. """ logger.info(f"TOOL START: get_table_schema called. database_name={database_name}, table_name={table_name}") - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.warning(f"TOOL WARNING: get_table_schema called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") - if not table_name or not table_name.isidentifier(): + if not self._is_valid_identifier(table_name): logger.warning(f"TOOL WARNING: get_table_schema called with invalid table_name: {table_name}") raise ValueError(f"Invalid table name provided: {table_name}") @@ -318,10 +382,10 @@ async def get_table_schema_with_relations(self, database_name: str, table_name: Includes all basic schema info plus foreign key relationships and referenced tables. """ logger.info(f"TOOL START: get_table_schema_with_relations called. database_name={database_name}, table_name={table_name}") - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") - if not table_name or not table_name.isidentifier(): + if not self._is_valid_identifier(table_name): logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid table_name: {table_name}") raise ValueError(f"Invalid table name provided: {table_name}") @@ -389,7 +453,7 @@ async def execute_sql(self, sql_query: str, database_name: str, parameters: Opti Example `parameters`: ["value1", 123] corresponding to %s placeholders in `sql_query`. """ logger.info(f"TOOL START: execute_sql called. database_name={database_name}, sql_query={sql_query[:100]}, parameters={parameters}") - if database_name and not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.warning(f"TOOL WARNING: execute_sql called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") param_tuple = tuple(parameters) if parameters is not None else None @@ -406,7 +470,7 @@ async def create_database(self, database_name: str) -> Dict[str, Any]: Creates a new database if it doesn't exist. """ logger.info(f"TOOL START: create_database called for database: '{database_name}'") - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.error(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") @@ -455,10 +519,10 @@ async def create_vector_store_tool(self, logger.info(f"TOOL START: create_vector_store called. DB: '{database_name}', Store: '{vector_store_name}', Model: '{model_name}', Embedding_Length: {embedding_length}, Distance_Requested: '{distance_function}'") # --- Input Validation --- - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not vector_store_name or not vector_store_name.isidentifier(): + if not self._is_valid_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") @@ -550,7 +614,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]: logger.info(f"TOOL START: list_vector_stores called for database: '{database_name}'") # --- Input Validation --- - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") @@ -614,10 +678,10 @@ async def delete_vector_store(self, logger.info(f"TOOL START: delete_vector_store called for: '{database_name}.{vector_store_name}'") # --- Input Validation for names --- - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not vector_store_name or not vector_store_name.isidentifier(): + if not self._is_valid_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") @@ -670,10 +734,10 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name: If metadata is not provided, an empty dict will be used for each document. """ import json - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'") raise ValueError(f"Invalid database_name: '{database_name}'") - if not vector_store_name or not vector_store_name.isidentifier(): + if not self._is_valid_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") if not isinstance(documents, list) or not documents or not all(isinstance(doc, str) and doc for doc in documents): @@ -723,10 +787,10 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ if not user_query or not isinstance(user_query, str): logger.error("user_query must be a non-empty string.") raise ValueError("user_query must be a non-empty string.") - if not database_name or not database_name.isidentifier(): + if not self._is_valid_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'") raise ValueError(f"Invalid database_name: '{database_name}'") - if not vector_store_name or not vector_store_name.isidentifier(): + if not self._is_valid_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") if not isinstance(k, int) or k <= 0: From aa0131c1669e944384c9707320cce39eb0bd01f7 Mon Sep 17 00:00:00 2001 From: Jonathan Bastnagel Date: Thu, 14 Aug 2025 12:02:32 -0500 Subject: [PATCH 2/4] Fix syntax error --- src/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server.py b/src/server.py index dcb920d..31b3238 100644 --- a/src/server.py +++ b/src/server.py @@ -270,7 +270,7 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool: logger.debug(f"Checking if '{database_name}.{table_name}' is a vector store.") if not self._is_valid_identifier(database_name) or \ - not self._is_valid_identifier(table_name) + not self._is_valid_identifier(table_name): logger.warning(f"_is_vector_store called with invalid names: db='{database_name}', table='{table_name}'") return False From 3a11ed9c61ae1234689f55f4f3c5c6c80fbac6c3 Mon Sep 17 00:00:00 2001 From: Jonathan Bastnagel Date: Thu, 14 Aug 2025 14:35:21 -0500 Subject: [PATCH 3/4] Add _quote_identifier to quote identifiers, if they're already quoted use them as-is. --- src/server.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/server.py b/src/server.py index 31b3238..3e131cd 100644 --- a/src/server.py +++ b/src/server.py @@ -179,6 +179,25 @@ def _has_properly_escaped_backticks(self, identifier: str) -> bool: i += 1 return True + def _quote_identifier(self, identifier: str) -> str: + """ + quote's an identifier + If already quoted, returns as-is. If unquoted, wraps in backticks. + + Parameters: + - identifier (str): The identifier to quote + + Returns: + - str: quoted identifier + """ + if identifier.startswith('`') and identifier.endswith('`'): + # Already quoted, return as-is + return identifier + else: + # Unquoted, wrap in backticks and escape any existing backticks + escaped_content = identifier.replace('`', '``') + return f'`{escaped_content}`' + def _is_valid_identifier(self, identifier: str) -> bool: """ Validates MariaDB identifier (database, table, column names, etc.). @@ -344,7 +363,7 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st logger.warning(f"TOOL WARNING: get_table_schema called with invalid table_name: {table_name}") raise ValueError(f"Invalid table name provided: {table_name}") - sql = f"DESCRIBE `{database_name}`.`{table_name}`" + sql = f"DESCRIBE {self._quote_identifier(database_name)}.{self._quote_identifier(table_name)}" try: schema_results = await self._execute_query(sql) schema_info = {} @@ -480,7 +499,7 @@ async def create_database(self, database_name: str) -> Dict[str, Any]: logger.info(f"TOOL END: create_database. {message}") return {"status": "exists", "message": message, "database_name": database_name} - sql = f"CREATE DATABASE IF NOT EXISTS `{database_name}`;" + sql = f"CREATE DATABASE IF NOT EXISTS {self._quote_identifier(database_name)};" try: await self._execute_query(sql, database=None) @@ -568,7 +587,7 @@ async def create_vector_store_tool(self, # --- SQL Query for Vector Store Table Creation --- schema_query = f""" - CREATE TABLE IF NOT EXISTS `{vector_store_name}` ( + CREATE TABLE IF NOT EXISTS {self._quote_identifier(vector_store_name)} ( id VARCHAR(36) NOT NULL DEFAULT UUID_v7() PRIMARY KEY, document TEXT NOT NULL, embedding VECTOR({embedding_length}) NOT NULL, @@ -704,7 +723,7 @@ async def delete_vector_store(self, return {"status": "not_vector_store", "message": message} # --- SQL Query for Deletion --- - drop_query = f"DROP TABLE IF EXISTS `{vector_store_name}`;" + drop_query = f"DROP TABLE IF EXISTS {self._quote_identifier(vector_store_name)};" try: await self._execute_query(drop_query, database=database_name) @@ -754,7 +773,7 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name: # Prepare metadata JSON metadata_json = [json.dumps(m) for m in metadata] # Prepare values for batch insert - insert_query = f"INSERT INTO `{database_name}`.`{vector_store_name}` (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)" + insert_query = f"INSERT INTO {self._quote_identifier(database_name)}.{self._quote_identifier(vector_store_name)} (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)" inserted = 0 errors = [] for doc, emb, meta in zip(documents, embeddings, metadata_json): @@ -805,7 +824,7 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ document, metadata, VEC_DISTANCE_COSINE(embedding, VEC_FromText(%s)) AS distance - FROM `{database_name}`.`{vector_store_name}` + FROM {self._quote_identifier(database_name)}.{self._quote_identifier(vector_store_name)} ORDER BY distance ASC LIMIT %s """ From 65187f98410f96fede7eaf083d3da6799951fe63 Mon Sep 17 00:00:00 2001 From: Jonathan Bastnagel Date: Wed, 20 Aug 2025 18:55:34 -0500 Subject: [PATCH 4/4] Simplify implementation --- src/server.py | 183 +++++++++++++++++++++++--------------------------- 1 file changed, 85 insertions(+), 98 deletions(-) diff --git a/src/server.py b/src/server.py index 3e131cd..bd0ee98 100644 --- a/src/server.py +++ b/src/server.py @@ -181,15 +181,18 @@ def _has_properly_escaped_backticks(self, identifier: str) -> bool: def _quote_identifier(self, identifier: str) -> str: """ - quote's an identifier + Quote an identifier If already quoted, returns as-is. If unquoted, wraps in backticks. Parameters: - identifier (str): The identifier to quote Returns: - - str: quoted identifier + - str: Quoted identifier """ + if identifier is None: + raise ValueError("Identifier cannot be None") + if identifier.startswith('`') and identifier.endswith('`'): # Already quoted, return as-is return identifier @@ -198,58 +201,78 @@ def _quote_identifier(self, identifier: str) -> str: escaped_content = identifier.replace('`', '``') return f'`{escaped_content}`' + def _is_valid_identifier(self, identifier: str) -> bool: """ - Validates MariaDB identifier (database, table, column names, etc.). - Handles both quoted and unquoted identifiers. + Validates MariaDB identifier that will be quoted when used in SQL. + Accepts both quoted and unquoted identifiers since all identifiers + are treated as "will be quoted when needed". Parameters: - - identifier (str): quoted or unquoted identifier. + - identifier (str): identifier (quoted or unquoted). """ if not identifier: return False - import re + # If unquoted, quote it first, then validate as quoted identifier + if not (identifier.startswith('`') and identifier.endswith('`')): + identifier = self._quote_identifier(identifier) - if identifier.startswith('`') and identifier.endswith('`'): - # Quoted identifier rules - if len(identifier) <= 2: - return False - - actual_name = identifier[1:-1] - - # Check that any backticks inside are properly escaped (doubled) - if not self._has_properly_escaped_backticks(actual_name): - return False - - # Handle escaped backticks - escaped_name = actual_name.replace('``', '`') - - if len(escaped_name) > 64: - return False + # Now validate as quoted identifier + if len(identifier) <= 2: + return False - if escaped_name.endswith(' '): - return False + actual_name = identifier[1:-1] + + # Check that any backticks inside are properly escaped (doubled) + if not self._has_properly_escaped_backticks(actual_name): + return False + + # Handle escaped backticks to get the real length + escaped_name = actual_name.replace('``', '`') + + # Basic length check + if len(escaped_name) > 64: + return False + + # No trailing spaces allowed + if escaped_name.endswith(' '): + return False + + # No null characters + if '\x00' in escaped_name: + return False - # Allow full Unicode BMP except U+0000 - return bool(re.match(r'^[\u0001-\uFFFF]*$', escaped_name)) + return True + + def _normalize_identifier(self, identifier: str, method_name: str) -> str: + """ + Normalizes and validates an identifier for MCP methods. + Validates the identifier and returns the unquoted version. + + Parameters: + - identifier (str): identifier (quoted or unquoted) + - method_name (str): name of the calling method for error messages + + Returns: + - str: unquoted identifier + + Raises: + - ValueError: if identifier is invalid + """ + if not self._is_valid_identifier(identifier): + error_msg = f"Invalid identifier '{identifier}' in {method_name}" + logger.error(error_msg) + raise ValueError(error_msg) + + # Strip quotes if present + if identifier.startswith('`') and identifier.endswith('`'): + return identifier[1:-1].replace('``', '`') else: - # Unquoted identifier rules - if len(identifier) > 64: - return False - if identifier.endswith(' '): - return False - if identifier.isdigit(): - return False - # MariaDB unquoted identifiers: ASCII [0-9,a-z,A-Z$_] + Extended U+0080..U+FFFF - return bool(re.match(r'^[0-9a-zA-Z$_\u0080-\uFFFF]+$', identifier)) + return identifier async def _database_exists(self, database_name: str) -> bool: - """Checks if a database exists.""" - if not self._is_valid_identifier(database_name): - logger.warning(f"_database_exists called with invalid database_name: {database_name}") - return False - + """Checks if a database exists. Expects normalized (unquoted) database name.""" sql = "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = %s" try: results = await self._execute_query(sql, params=(database_name,), database='information_schema') @@ -259,12 +282,7 @@ async def _database_exists(self, database_name: str) -> bool: return False async def _table_exists(self, database_name: str, table_name: str) -> bool: - """Checks if a table exists in the given database.""" - if not self._is_valid_identifier(database_name) or \ - not self._is_valid_identifier(table_name): - logger.warning(f"_table_exists called with invalid names: db='{database_name}', table='{table_name}'") - return False - + """Checks if a table exists in the given database. Expects normalized (unquoted) names.""" sql = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s" try: results = await self._execute_query(sql, params=(database_name, table_name), database='information_schema') @@ -288,10 +306,7 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool: """ logger.debug(f"Checking if '{database_name}.{table_name}' is a vector store.") - if not self._is_valid_identifier(database_name) or \ - not self._is_valid_identifier(table_name): - logger.warning(f"_is_vector_store called with invalid names: db='{database_name}', table='{table_name}'") - return False + # Expects normalized (unquoted) names # SQL query to verify vector store criteria sql_query = """ @@ -337,9 +352,8 @@ async def list_databases(self) -> List[str]: async def list_tables(self, database_name: str) -> List[str]: """Lists all tables within the specified database.""" logger.info(f"TOOL START: list_tables called. database_name={database_name}") - if not self._is_valid_identifier(database_name): - logger.warning(f"TOOL WARNING: list_tables called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") + database_name = self._normalize_identifier(database_name, "list_tables") + sql = "SHOW TABLES" try: results = await self._execute_query(sql, database=database_name) @@ -356,18 +370,15 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st for a specific table in a database. """ logger.info(f"TOOL START: get_table_schema called. database_name={database_name}, table_name={table_name}") - if not self._is_valid_identifier(database_name): - logger.warning(f"TOOL WARNING: get_table_schema called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") - if not self._is_valid_identifier(table_name): - logger.warning(f"TOOL WARNING: get_table_schema called with invalid table_name: {table_name}") - raise ValueError(f"Invalid table name provided: {table_name}") + database_name = self._normalize_identifier(database_name, "get_table_schema") + table_name = self._normalize_identifier(table_name, "get_table_schema") sql = f"DESCRIBE {self._quote_identifier(database_name)}.{self._quote_identifier(table_name)}" try: schema_results = await self._execute_query(sql) schema_info = {} if not schema_results: + # Use normalized names for information_schema query exists_sql = "SELECT COUNT(*) as count FROM information_schema.tables WHERE table_schema = %s AND table_name = %s" exists_result = await self._execute_query(exists_sql, params=(database_name, table_name)) if not exists_result or exists_result[0]['count'] == 0: @@ -401,12 +412,8 @@ async def get_table_schema_with_relations(self, database_name: str, table_name: Includes all basic schema info plus foreign key relationships and referenced tables. """ logger.info(f"TOOL START: get_table_schema_with_relations called. database_name={database_name}, table_name={table_name}") - if not self._is_valid_identifier(database_name): - logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") - if not self._is_valid_identifier(table_name): - logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid table_name: {table_name}") - raise ValueError(f"Invalid table name provided: {table_name}") + database_name = self._normalize_identifier(database_name, "get_table_schema_with_relations") + table_name = self._normalize_identifier(table_name, "get_table_schema_with_relations") try: # 1. Get basic schema information @@ -431,6 +438,7 @@ async def get_table_schema_with_relations(self, database_name: str, table_name: ORDER BY kcu.CONSTRAINT_NAME, kcu.ORDINAL_POSITION """ + # Use normalized names for information_schema query fk_results = await self._execute_query(fk_sql, params=(database_name, table_name)) # 3. Add foreign key information to the basic schema @@ -472,9 +480,7 @@ async def execute_sql(self, sql_query: str, database_name: str, parameters: Opti Example `parameters`: ["value1", 123] corresponding to %s placeholders in `sql_query`. """ logger.info(f"TOOL START: execute_sql called. database_name={database_name}, sql_query={sql_query[:100]}, parameters={parameters}") - if not self._is_valid_identifier(database_name): - logger.warning(f"TOOL WARNING: execute_sql called with invalid database_name: {database_name}") - raise ValueError(f"Invalid database name provided: {database_name}") + database_name = self._normalize_identifier(database_name, "execute_sql") param_tuple = tuple(parameters) if parameters is not None else None try: results = await self._execute_query(sql_query, params=param_tuple, database=database_name) @@ -489,9 +495,7 @@ async def create_database(self, database_name: str) -> Dict[str, Any]: Creates a new database if it doesn't exist. """ logger.info(f"TOOL START: create_database called for database: '{database_name}'") - if not self._is_valid_identifier(database_name): - logger.error(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "create_database") # Check existence first to provide a clear message, though CREATE DATABASE IF NOT EXISTS is idempotent if await self._database_exists(database_name): @@ -538,12 +542,8 @@ async def create_vector_store_tool(self, logger.info(f"TOOL START: create_vector_store called. DB: '{database_name}', Store: '{vector_store_name}', Model: '{model_name}', Embedding_Length: {embedding_length}, Distance_Requested: '{distance_function}'") # --- Input Validation --- - if not self._is_valid_identifier(database_name): - logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not self._is_valid_identifier(vector_store_name): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "create_vector_store_tool") + vector_store_name = self._normalize_identifier(vector_store_name, "create_vector_store_tool") if not isinstance(embedding_length, int) or embedding_length <= 0: logger.error(f"Invalid embedding_length: {embedding_length}. Must be a positive integer.") @@ -633,9 +633,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]: logger.info(f"TOOL START: list_vector_stores called for database: '{database_name}'") # --- Input Validation --- - if not self._is_valid_identifier(database_name): - logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "list_vector_stores") if not await self._database_exists(database_name): logger.warning(f"Database '{database_name}' does not exist. Cannot list vector stores.") @@ -660,6 +658,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]: """ try: + # Use normalized name for information_schema query results = await self._execute_query(sql_query, params=(database_name,), database='information_schema') store_list = [row['TABLE_NAME'] for row in results if 'TABLE_NAME' in row] @@ -697,12 +696,8 @@ async def delete_vector_store(self, logger.info(f"TOOL START: delete_vector_store called for: '{database_name}.{vector_store_name}'") # --- Input Validation for names --- - if not self._is_valid_identifier(database_name): - logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not self._is_valid_identifier(vector_store_name): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") + database_name = self._normalize_identifier(database_name, "delete_vector_store") + vector_store_name = self._normalize_identifier(vector_store_name, "delete_vector_store") # --- Database Existence Check --- if not await self._database_exists(database_name): @@ -753,12 +748,8 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name: If metadata is not provided, an empty dict will be used for each document. """ import json - if not self._is_valid_identifier(database_name): - logger.error(f"Invalid database_name: '{database_name}'") - raise ValueError(f"Invalid database_name: '{database_name}'") - if not self._is_valid_identifier(vector_store_name): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") + database_name = self._normalize_identifier(database_name, "insert_docs_vector_store") + vector_store_name = self._normalize_identifier(vector_store_name, "insert_docs_vector_store") if not isinstance(documents, list) or not documents or not all(isinstance(doc, str) and doc for doc in documents): logger.error("'documents' must be a non-empty list of non-empty strings.") raise ValueError("'documents' must be a non-empty list of non-empty strings.") @@ -806,12 +797,8 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ if not user_query or not isinstance(user_query, str): logger.error("user_query must be a non-empty string.") raise ValueError("user_query must be a non-empty string.") - if not self._is_valid_identifier(database_name): - logger.error(f"Invalid database_name: '{database_name}'") - raise ValueError(f"Invalid database_name: '{database_name}'") - if not self._is_valid_identifier(vector_store_name): - logger.error(f"Invalid vector_store_name: '{vector_store_name}'") - raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") + database_name = self._normalize_identifier(database_name, "search_vector_store") + vector_store_name = self._normalize_identifier(vector_store_name, "search_vector_store") if not isinstance(k, int) or k <= 0: logger.error("k must be a positive integer.") raise ValueError("k must be a positive integer.")